In [1]:
import numpy as np
import pandas as pd

import torch

from gaussian_ring_grid_generator import GaussianRingSpaceTimeGrid
import encoder
import importlib
importlib.reload(encoder)

<module 'encoder' from '/home/prs392/codes/capstone/numerical-relativity-interpolation/dev/encoder_sr/encoder.py'>

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
spaceTimeContext = GaussianRingSpaceTimeGrid(
    n_space_grid = 32,
    n_time_grid = 5,
    space_min_x = -5,
    space_max_x = 5,
    time_min_t = 0.5,
    time_max_t = 2.5
)

  [torch.tensor(GaussianRing(self.space_grid.cpu(), i, i/2)).unsqueeze(0) for i in self.time_axis]


In [4]:
spaceTimeContext.values.cpu().numpy().shape

(5, 1, 32, 32, 32)

In [5]:
spaceTimeContext.time_axis

tensor([0.5000, 1.0000, 1.5000, 2.0000, 2.5000], device='cuda:0')

In [6]:
train_grid = GaussianRingSpaceTimeGrid(
    n_space_grid = 48,
    n_time_grid = 4,
    space_min_x = -5,
    space_max_x = 5,
    time_min_t = 0.75,
    time_max_t = 2.25
)

In [7]:
train_grid.time_axis

tensor([0.7500, 1.2500, 1.7500, 2.2500], device='cuda:0')

In [8]:
train_grid.values.shape

torch.Size([4, 1, 48, 48, 48])

In [9]:
def get_decimal_index(axis_values, value):
    closest_above = axis_values[axis_values >= value].min()
    closest_below = axis_values[axis_values <= value].max()
    index_below = np.where(axis_values == closest_below)[0][0]
    
    if closest_above == closest_below:
        return index_below
    
    return index_below + ((value - closest_below)/(closest_above - closest_below))

In [10]:
get_decimal_index(spaceTimeContext.time_axis.cpu().numpy(), 2.2500)

3.5

In [11]:
ts = []
xs = []
ys = []
zs = []

vals = []

for i in range(train_grid.values.shape[0]):
    t = get_decimal_index(spaceTimeContext.time_axis.cpu().numpy(), train_grid.time_axis[i].item())
    
    for j in range(train_grid.values.shape[2]):
        x = get_decimal_index(spaceTimeContext.space_axis.cpu().numpy(), train_grid.space_axis[j].item())
        
        for k in range(train_grid.values.shape[3]):
            y = get_decimal_index(spaceTimeContext.space_axis.cpu().numpy(), train_grid.space_axis[k].item())
            
            for l in range(train_grid.values.shape[4]):
                z = get_decimal_index(spaceTimeContext.space_axis.cpu().numpy(), train_grid.space_axis[l].item())
                
                ts.append(t)
                xs.append(x)
                ys.append(y)
                zs.append(z)
                vals.append(train_grid.values[i, 0, j, k, l].item())

In [12]:
# import matplotlib.pyplot as plt

# plt.hist(vals, bins=100)

In [13]:
inps = np.array([ts, xs, ys, zs]).T
outs = np.array(vals)
inps.shape, outs.shape

((442368, 4), (442368,))

In [14]:
input_tensor = torch.Tensor(inps).unsqueeze(1)
output_tensor = torch.Tensor(outs).unsqueeze(1)
input_tensor.shape, output_tensor.shape

(torch.Size([442368, 1, 4]), torch.Size([442368, 1]))

In [15]:
from torch.utils.data import TensorDataset, DataLoader

train_dataset = TensorDataset(input_tensor,output_tensor) # create your datset
train_dataloader = DataLoader(train_dataset, batch_size=32768, shuffle=True) # create your dataloader

In [19]:
import torch.nn as nn

model = encoder.SR(1, 16, 5)

In [25]:
import torch.optim as optim

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.05)

In [None]:
import tqdm

model.to(device)
model.train()

losses = []

for epoch in range(50, 100):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in tqdm.tqdm(enumerate(train_dataloader, 0)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(spaceTimeContext.values, inputs)
#         print(inputs)
#         print(labels)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        
    torch.save(model.state_dict(), f'/scratch/prs392/capstone/checkpoints/encoder_sr/epoch={epoch}.pt')
    
    print(f"epoch {epoch} loss = {running_loss/(i+1)}")
    losses.append(running_loss/(i+1))

print('Finished Training')

14it [04:24, 18.93s/it]
0it [00:00, ?it/s]

epoch 50 loss = 5333928.055937124


14it [04:25, 18.97s/it]
0it [00:00, ?it/s]

epoch 51 loss = 488.6258752005441


14it [04:26, 19.01s/it]
0it [00:00, ?it/s]

epoch 52 loss = 55.14495594160898


14it [04:24, 18.87s/it]
0it [00:00, ?it/s]

epoch 53 loss = 7.323188739163535


14it [04:22, 18.73s/it]
0it [00:00, ?it/s]

epoch 54 loss = 1.5556413808039256


14it [04:22, 18.72s/it]
0it [00:00, ?it/s]

epoch 55 loss = 0.39612047161374775


14it [04:21, 18.71s/it]
0it [00:00, ?it/s]

epoch 56 loss = 0.13036634427096164


14it [04:21, 18.70s/it]
0it [00:00, ?it/s]

epoch 57 loss = 0.05152318014630249


14it [04:22, 18.72s/it]
0it [00:00, ?it/s]

epoch 58 loss = 0.03611288831702301


14it [04:21, 18.70s/it]
0it [00:00, ?it/s]

epoch 59 loss = 0.028987424048994268


14it [04:21, 18.69s/it]
0it [00:00, ?it/s]

epoch 60 loss = 0.024255430458911827


14it [04:21, 18.70s/it]
0it [00:00, ?it/s]

epoch 61 loss = 0.021437791974416802


14it [04:21, 18.71s/it]
0it [00:00, ?it/s]

epoch 62 loss = 0.019567768488611494


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 63 loss = 0.01813037706805127


14it [04:21, 18.66s/it]
0it [00:00, ?it/s]

epoch 64 loss = 0.01699257149760212


14it [04:21, 18.65s/it]
0it [00:00, ?it/s]

epoch 65 loss = 0.01620507992005774


14it [04:21, 18.68s/it]
0it [00:00, ?it/s]

epoch 66 loss = 0.015423483215272427


14it [04:21, 18.69s/it]
0it [00:00, ?it/s]

epoch 67 loss = 0.014770506348993098


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 68 loss = 0.014207832182624511


14it [04:21, 18.65s/it]
0it [00:00, ?it/s]

epoch 69 loss = 0.013702534538294588


14it [04:21, 18.68s/it]
0it [00:00, ?it/s]

epoch 70 loss = 0.013280486688017845


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 71 loss = 0.012883521949074097


14it [04:21, 18.66s/it]
0it [00:00, ?it/s]

epoch 72 loss = 0.012500918337277003


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 73 loss = 0.012195691599377565


14it [04:21, 18.65s/it]
0it [00:00, ?it/s]

epoch 74 loss = 0.011916940699198417


14it [04:21, 18.65s/it]
0it [00:00, ?it/s]

epoch 75 loss = 0.011703362289283956


14it [04:21, 18.66s/it]
0it [00:00, ?it/s]

epoch 76 loss = 0.011397354770451784


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 77 loss = 0.011156127122896058


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 78 loss = 0.010964959859848022


14it [04:21, 18.66s/it]
0it [00:00, ?it/s]

epoch 79 loss = 0.010790758633187838


14it [04:21, 18.66s/it]
0it [00:00, ?it/s]

epoch 80 loss = 0.010610060433724098


14it [04:21, 18.64s/it]
0it [00:00, ?it/s]

epoch 81 loss = 0.010517977577235018


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 82 loss = 0.010368079878389835


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 83 loss = 0.010175659587340695


14it [04:21, 18.68s/it]
0it [00:00, ?it/s]

epoch 84 loss = 0.010026534819709403


14it [04:21, 18.71s/it]
0it [00:00, ?it/s]

epoch 85 loss = 0.009901656503123897


13it [04:11, 19.41s/it]