In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import numpy as np
import pandas as pd

import torch

from gaussian_ring_grid_generator import GaussianRingSpaceTimeGrid
import encoder
import importlib
importlib.reload(encoder)

<module 'encoder' from '/Users/paramshah/Documents/Param/NYU/courses/Capstone Project/repo/numerical-relativity-interpolation/dev/encoder_sr/encoder.py'>

In [4]:
# Prefer the first CUDA GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [5]:
# Coarse "context" grid that is later handed to the model as its first argument.
# Per the recorded outputs, `.values` has shape (5, 1, 32, 32, 32) and `.time_axis`
# is [0.5, 1.0, 1.5, 2.0, 2.5].
# NOTE(review): the exact meaning of each keyword is defined in
# gaussian_ring_grid_generator, which is not visible here — confirm there.
spaceTimeContext = GaussianRingSpaceTimeGrid(
    n_space_grid = 32,
    n_time_grid = 5,
    space_min_x = -5,
    space_max_x = 5,
    time_min_t = 0.5,
    time_max_t = 2.5
)

  [torch.tensor(GaussianRing(self.space_grid.cpu(), i, i/2)).unsqueeze(0) for i in self.time_axis]


In [6]:
spaceTimeContext.values.cpu().numpy().shape

(5, 1, 32, 32, 32)

In [7]:
spaceTimeContext.time_axis

tensor([0.5000, 1.0000, 1.5000, 2.0000, 2.5000])

In [8]:
# Finer grid that supplies the training samples (query coordinates + target values).
# Per the recorded outputs, `.values` has shape (4, 1, 48, 48, 48) and `.time_axis`
# is [0.75, 1.25, 1.75, 2.25], i.e. every training time lies strictly between two
# time slices of the context grid.
train_grid = GaussianRingSpaceTimeGrid(
    n_space_grid = 48,
    n_time_grid = 4,
    space_min_x = -5,
    space_max_x = 5,
    time_min_t = 0.75,
    time_max_t = 2.25
)

In [9]:
train_grid.time_axis

tensor([0.7500, 1.2500, 1.7500, 2.2500])

In [10]:
train_grid.values.shape

torch.Size([4, 1, 48, 48, 48])

In [11]:
def get_decimal_index(axis_values, value):
    """Return the fractional index of ``value`` along ``axis_values``.

    The result is the index of the closest axis entry at or below ``value`` plus
    the linear fraction of the way from that entry to the closest entry at or
    above ``value``. For example, on the axis ``[0.5, 1.0, 1.5, 2.0, 2.5]`` the
    value ``2.25`` maps to ``3.5``.

    Parameters
    ----------
    axis_values : array-like
        One-dimensional coordinates of the grid points. Anything accepted by
        ``np.asarray`` works (NumPy array, list, tuple).
    value : float
        Coordinate to locate; must lie within ``[min(axis_values), max(axis_values)]``.

    Returns
    -------
    number
        The integer index when ``value`` coincides with an axis entry, otherwise
        a float strictly between two consecutive bracketing indices.

    Raises
    ------
    ValueError
        If ``axis_values`` is empty or ``value`` lies outside the axis range
        (including NaN), since no bracketing pair exists.
    """
    axis = np.asarray(axis_values)
    if axis.size == 0:
        raise ValueError("axis_values must not be empty")

    at_or_above = axis[axis >= value]
    at_or_below = axis[axis <= value]
    if at_or_above.size == 0 or at_or_below.size == 0:
        # Without this guard numpy raises an opaque "zero-size array to reduction
        # operation" error from .min()/.max() below.
        raise ValueError(
            f"value {value} lies outside the axis range [{axis.min()}, {axis.max()}]"
        )

    closest_above = at_or_above.min()
    closest_below = at_or_below.max()
    # First position holding the lower bracket value.
    index_below = np.where(axis == closest_below)[0][0]

    if closest_above == closest_below:
        # Exact hit on a grid point: no interpolation needed.
        return index_below

    return index_below + ((value - closest_below) / (closest_above - closest_below))

In [12]:
get_decimal_index(spaceTimeContext.time_axis.cpu().numpy(), 2.2500)

3.5

In [13]:
# Build one training sample per train-grid point: the (t, x, y, z) position expressed
# as fractional indices into the context grid, plus the field value at that point.
#
# A fractional index depends only on the position along one axis, so it is computed
# once per axis entry rather than once per grid point.
context_time_axis = spaceTimeContext.time_axis.cpu().numpy()
context_space_axis = spaceTimeContext.space_axis.cpu().numpy()

# train_grid.values is indexed as [time, channel, x, y, z] (channel 0 is the one read).
n_t, _, n_x, n_y, n_z = train_grid.values.shape

time_index = [
    get_decimal_index(context_time_axis, train_grid.time_axis[i].item())
    for i in range(n_t)
]
# The same space axis is used for x, y and z, so a single lookup table serves all three.
space_index = [
    get_decimal_index(context_space_axis, train_grid.space_axis[j].item())
    for j in range(max(n_x, n_y, n_z))
]

# indexing='ij' with C-order ravel enumerates points with t outermost and z innermost,
# matching a nested `for t: for x: for y: for z:` traversal.
t_mesh, x_mesh, y_mesh, z_mesh = np.meshgrid(
    time_index,
    space_index[:n_x],
    space_index[:n_y],
    space_index[:n_z],
    indexing='ij',
)

ts = t_mesh.ravel().tolist()
xs = x_mesh.ravel().tolist()
ys = y_mesh.ravel().tolist()
zs = z_mesh.ravel().tolist()

# Flattening [time, x, y, z] in C order yields the values in the same traversal order.
vals = train_grid.values[:, 0].detach().cpu().reshape(-1).tolist()

In [14]:
# import matplotlib.pyplot as plt

# plt.hist(vals, bins=100)

In [15]:
# One row per sample with columns ordered (t, x, y, z); targets as a flat vector.
inps = np.column_stack((ts, xs, ys, zs))
outs = np.asarray(vals)
inps.shape, outs.shape

((442368, 4), (442368,))

In [16]:
# Convert to tensors and insert a singleton axis at position 1:
# inputs become (N, 1, 4) and targets become (N, 1).
input_tensor = torch.Tensor(inps)[:, None, :]
output_tensor = torch.Tensor(outs)[:, None]
input_tensor.shape, output_tensor.shape

(torch.Size([442368, 1, 4]), torch.Size([442368, 1]))

In [17]:
from torch.utils.data import TensorDataset, DataLoader

# Samples per optimisation step. The recorded training output shows 13 batches of
# 32768 samples followed by one of 16384 (13 * 32768 + 16384 = 442368 samples).
BATCH_SIZE = 32768

# Pair every (t, x, y, z) query with its target value.
train_dataset = TensorDataset(input_tensor, output_tensor)
# Reshuffle the samples at the start of every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [49]:
import torch.nn as nn

# Instantiate the model class `SR` from the local `encoder` module.
# NOTE(review): the meaning of the positional arguments (1, 16, 5) is defined by
# encoder.SR, which is not visible here — consider passing them by keyword.
model = encoder.SR(1, 16, 5)

In [50]:
import torch.optim as optim

# Mean-squared-error objective between model outputs and target values.
criterion = nn.MSELoss()
# Adam over all model parameters.
# NOTE(review): lr=0.05 is well above Adam's default of 1e-3, and the recorded
# epoch-0 loss is ~3.9e4 — confirm this learning rate is intentional.
optimizer = optim.Adam(model.parameters(), lr=0.05)

In [53]:
import os

import tqdm

# One checkpoint is written here at the end of every epoch; create the directory
# up front so a missing folder fails before any training time is spent.
checkpoint_dir = '/scratch/prs392/capstone/checkpoints/encoder_sr'
os.makedirs(checkpoint_dir, exist_ok=True)

model.to(device)
model.train()

# The context grid is identical for every batch: move it to the training device once
# so it lives on the same device as the batch inputs and the model.
context_values = spaceTimeContext.values.to(device)

# Mean training loss of each completed epoch.
losses = []

n_batches = len(train_dataloader)

for epoch in range(0, 100):  # loop over the dataset multiple times

    running_loss = 0.0
    # Wrapping the dataloader itself (rather than enumerate) lets tqdm show the total.
    for inputs, labels in tqdm.tqdm(train_dataloader, desc=f"epoch {epoch}"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(context_values, inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate the per-batch loss for the epoch average
        running_loss += loss.item()

    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'epoch={epoch}.pt'))

    epoch_loss = running_loss / n_batches
    print(f"epoch {epoch} loss = {epoch_loss}")
    losses.append(epoch_loss)

print('Finished Training')

0it [00:00, ?it/s]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


1it [00:08,  8.77s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


2it [00:17,  8.67s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


3it [00:30, 10.03s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


4it [00:41, 10.48s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


5it [00:55, 11.30s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


6it [01:04, 10.84s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


7it [01:13, 10.26s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


8it [01:22,  9.80s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


9it [01:32,  9.91s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


10it [01:42,  9.93s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


11it [01:52, 10.01s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


12it [02:01,  9.70s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([32768, 1, 4])
torch.Size([32768, 1])
torch.Size([32768, 1])


13it [02:10,  9.50s/it]

torch.Size([5, 1, 32, 32, 32])
torch.Size([16384, 1, 4])
torch.Size([16384, 1])
torch.Size([16384, 1])


14it [02:16,  9.73s/it]

epoch 0 loss = 38892.905305044995
Finished Training



