In [21]:
import numpy as np
import pandas as pd

import torch

from gaussian_ring_grid_generator import GaussianRingSpaceTimeGrid

In [2]:
# --- Configuration shared by the cells below -------------------------------
# Run on the first GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Optimisation hyper-parameters (not used in the cells shown here).
LR = 1e-3
NUM_EPOCHS = 500

# Extent of the grid: one shared [min, max] range for all three spatial axes,
# and a separate [min, max] range for time.
SPACE_GRID_MIN, SPACE_GRID_MAX = -5, 5
TIME_GRID_MIN, TIME_GRID_MAX = 0.5, 2.5

# Per-dimension lower / upper bounds as float32 tensors on DEVICE:
# the time bound first, followed by the same space bound repeated three times.
T_SPACE_TIME_MIN = torch.tensor(
    [TIME_GRID_MIN] + [SPACE_GRID_MIN] * 3,
    dtype=torch.float32,
    device=DEVICE,
)
T_SPACE_TIME_MAX = torch.tensor(
    [TIME_GRID_MAX] + [SPACE_GRID_MAX] * 3,
    dtype=torch.float32,
    device=DEVICE,
)


In [3]:
# Build the Gaussian-ring space-time grid over the bounds defined in the config
# cell, so editing SPACE_GRID_* / TIME_GRID_* there actually changes the data
# (previously these bounds were duplicated here as literals).
# NOTE(review): running this cell printed a warning pointing at
# `torch.tensor(GaussianRing(...))` inside gaussian_ring_grid_generator —
# presumably torch's "copy construct from a tensor" UserWarning; consider
# `.clone().detach()` upstream. TODO confirm.
spaceTimeContext = GaussianRingSpaceTimeGrid(
    n_space_grid=32,  # points per spatial axis (see the (5, 1, 32, 32, 32) shape output)
    n_time_grid=5,    # number of time frames
    space_min_x=SPACE_GRID_MIN,
    space_max_x=SPACE_GRID_MAX,
    time_min_t=TIME_GRID_MIN,
    time_max_t=TIME_GRID_MAX,
)

  [torch.tensor(GaussianRing(self.space_grid.cpu(), i, i/2)).unsqueeze(0) for i in self.time_axis]


In [7]:
# Shape of the training tensor as a plain tuple (read directly, no host copy).
# Axis 0 is treated as time and axes 2-4 as x, y, z by the cells below.
tuple(spaceTimeContext.train_context.shape)

(5, 1, 32, 32, 32)

In [8]:
# Flatten each time frame of the grid into point records {'x', 'y', 'z', 'val'}
# for plotting. The coordinates are integer *grid indices* (0 .. n-1), not
# physical positions in [SPACE_GRID_MIN, SPACE_GRID_MAX].
grid = spaceTimeContext.train_context.cpu().numpy()

# Index coordinates of every voxel in one frame. np.indices + ravel enumerate
# voxels in C order (x outermost, z innermost), i.e. the same order a nested
# x/y/z loop would produce, so 'val' lines up with the coordinates.
x_idx, y_idx, z_idx = (axis.ravel().tolist() for axis in np.indices(grid.shape[2:]))

data_dicts = [
    {
        # list(...) copies so each frame owns independent coordinate lists.
        'x': list(x_idx),
        'y': list(y_idx),
        'z': list(z_idx),
        # frame has shape grid.shape[1:]; channel 0 holds the values.
        'val': frame[0].ravel().tolist(),
    }
    for frame in grid  # iterates axis 0 (time)
]

In [10]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [36]:
## PLOT 3D
# Interactive 3D scatter of every voxel in one time frame, coloured by value.
# Kept as a function rather than commented-out code so it can be called on
# demand (e.g. `plot_frame_3d(1)` as the last line of a cell) without being
# rendered on every Restart & Run All.
def plot_frame_3d(t, opacity=0.1):
    """Return a plotly 3D scatter figure of time frame ``t`` of ``data_dicts``.

    Args:
        t: index into ``data_dicts`` (one entry per time frame).
        opacity: trace opacity; low values let inner voxels show through.

    Returns:
        A ``go.Figure``; it renders when it is the last expression of a cell.
    """
    frame = data_dicts[t]
    return go.Figure(
        data=[
            go.Scatter3d(
                x=frame['x'],
                y=frame['y'],
                z=frame['z'],
                mode='markers',
                marker=dict(color=frame['val']),
                opacity=opacity,
            )
        ]
    )

In [35]:
# Mid-depth z-slice of every time frame, one subplot per frame (top to bottom),
# to compare how the field evolves over time.
# NOTE(review): no cmin/cmax/coloraxis is set, so each trace autoscales its own
# colour range — colours are not directly comparable across frames.
n_frames = len(data_dicts)  # one entry per time frame; avoids hard-coding 5

fig = make_subplots(
    rows=n_frames,
    cols=1,
    subplot_titles=[f"time frame {t}" for t in range(n_frames)],
)

for t, frame in enumerate(data_dicts):
    df = pd.DataFrame.from_dict(frame)

    # Middle z index (floor of max/2), e.g. 15 for indices 0..31.
    z_mid = int(df['z'].max() / 2)
    df_slice = df[df['z'] == z_mid]

    fig.add_trace(
        go.Scatter(
            x=df_slice['x'],
            y=df_slice['y'],
            mode='markers',
            marker=dict(color=df_slice['val'], size=10),
            name=f"t={t}",
            showlegend=False,  # colour encodes value; a per-trace legend adds nothing
        ),
        row=t + 1,
        col=1,
    )

fig.update_layout(
    autosize=False,
    width=500,
    height=500 * n_frames,  # ~500 px per subplot, as in the original 5 x 500 layout
)
fig.show()