In [1]:
import numpy as np
import pandas as pd

import torch

from gaussian_ring_grid_generator import GaussianRingSpaceTimeGrid
import encoder
import importlib
importlib.reload(encoder)

<module 'encoder' from '/home/prs392/codes/capstone/numerical-relativity-interpolation/dev/encoder_sr/encoder.py'>

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
spaceTimeContext = GaussianRingSpaceTimeGrid(
    n_space_grid = 32,
    n_time_grid = 5,
    space_min_x = -5,
    space_max_x = 5,
    time_min_t = 0.5,
    time_max_t = 2.5
)

  [torch.tensor(GaussianRing(self.space_grid.cpu(), i, i/2)).unsqueeze(0) for i in self.time_axis]


In [4]:
spaceTimeContext.values.cpu().numpy().shape

(5, 1, 32, 32, 32)

In [5]:
spaceTimeContext.time_axis

tensor([0.5000, 1.0000, 1.5000, 2.0000, 2.5000], device='cuda:0')

In [6]:
train_grid = GaussianRingSpaceTimeGrid(
    n_space_grid = 48,
    n_time_grid = 4,
    space_min_x = -5,
    space_max_x = 5,
    time_min_t = 0.75,
    time_max_t = 2.25
)

In [7]:
train_grid.time_axis

tensor([0.7500, 1.2500, 1.7500, 2.2500], device='cuda:0')

In [8]:
train_grid.values.shape

torch.Size([4, 1, 48, 48, 48])

In [9]:
def get_decimal_index(axis_values, value):
    closest_above = axis_values[axis_values >= value].min()
    closest_below = axis_values[axis_values <= value].max()
    index_below = np.where(axis_values == closest_below)[0][0]
    
    if closest_above == closest_below:
        return index_below
    
    return index_below + ((value - closest_below)/(closest_above - closest_below))

In [10]:
get_decimal_index(spaceTimeContext.time_axis.cpu().numpy(), 2.2500)

3.5

In [11]:
ts = []
xs = []
ys = []
zs = []

vals = []

for i in range(train_grid.values.shape[0]):
    t = get_decimal_index(spaceTimeContext.time_axis.cpu().numpy(), train_grid.time_axis[i].item())
    
    for j in range(train_grid.values.shape[2]):
        x = get_decimal_index(spaceTimeContext.space_axis.cpu().numpy(), train_grid.space_axis[j].item())
        
        for k in range(train_grid.values.shape[3]):
            y = get_decimal_index(spaceTimeContext.space_axis.cpu().numpy(), train_grid.space_axis[k].item())
            
            for l in range(train_grid.values.shape[4]):
                z = get_decimal_index(spaceTimeContext.space_axis.cpu().numpy(), train_grid.space_axis[l].item())
                
                ts.append(t)
                xs.append(x)
                ys.append(y)
                zs.append(z)
                vals.append(train_grid.values[i, 0, j, k, l].item())

In [12]:
# import matplotlib.pyplot as plt

# plt.hist(vals, bins=100)

In [13]:
inps = np.array([ts, xs, ys, zs]).T
outs = np.array(vals)
inps.shape, outs.shape

((442368, 4), (442368,))

In [14]:
input_tensor = torch.Tensor(inps).unsqueeze(1)
output_tensor = torch.Tensor(outs).unsqueeze(1)
input_tensor.shape, output_tensor.shape

(torch.Size([442368, 1, 4]), torch.Size([442368, 1]))

In [15]:
from torch.utils.data import TensorDataset, DataLoader

train_dataset = TensorDataset(input_tensor,output_tensor) # create your datset
train_dataloader = DataLoader(train_dataset, batch_size=32768, shuffle=True) # create your dataloader

In [19]:
import torch.nn as nn

model = encoder.SR(1, 16, 5)

In [20]:
import torch.optim as optim

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [None]:
import tqdm

model.to(device)
model.train()

losses = []

for epoch in range(50):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in tqdm.tqdm(enumerate(train_dataloader, 0)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(spaceTimeContext.values, inputs)
#         print(inputs)
#         print(labels)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        
    torch.save(model.state_dict(), f'/scratch/prs392/capstone/checkpoints/encoder_sr/epoch={epoch}.pt')
    
    print(f"epoch {epoch} loss = {running_loss/(i+1)}")
    losses.append(running_loss/(i+1))

print('Finished Training')

14it [05:17, 22.64s/it]
0it [00:00, ?it/s]

epoch 0 loss = 0.007605520376403417


14it [05:16, 22.59s/it]
0it [00:00, ?it/s]

epoch 1 loss = 0.007448434829711914


14it [05:14, 22.47s/it]
0it [00:00, ?it/s]

epoch 2 loss = 0.007344634471727269


14it [05:13, 22.38s/it]
0it [00:00, ?it/s]

epoch 3 loss = 0.00721325894950756


14it [05:13, 22.39s/it]
0it [00:00, ?it/s]

epoch 4 loss = 0.007068819599226117


14it [05:15, 22.52s/it]
0it [00:00, ?it/s]

epoch 5 loss = 0.006960891420021653


14it [05:12, 22.32s/it]
0it [00:00, ?it/s]

epoch 6 loss = 0.006852412662867989


14it [05:12, 22.34s/it]
0it [00:00, ?it/s]

epoch 7 loss = 0.0067599195587847915


14it [05:12, 22.32s/it]
0it [00:00, ?it/s]

epoch 8 loss = 0.006682845397985407


14it [05:12, 22.33s/it]
0it [00:00, ?it/s]

epoch 9 loss = 0.00660231097468308


14it [05:11, 22.28s/it]
0it [00:00, ?it/s]

epoch 10 loss = 0.006508246463324342


14it [05:13, 22.36s/it]
0it [00:00, ?it/s]

epoch 11 loss = 0.0064395632860916


14it [05:12, 22.30s/it]
0it [00:00, ?it/s]

epoch 12 loss = 0.006383288452135665


14it [05:15, 22.52s/it]
0it [00:00, ?it/s]

epoch 13 loss = 0.006308311929128


14it [05:11, 22.28s/it]
0it [00:00, ?it/s]

epoch 14 loss = 0.006246997243059533


14it [05:12, 22.34s/it]
0it [00:00, ?it/s]

epoch 15 loss = 0.006182679640395301


14it [05:12, 22.31s/it]
0it [00:00, ?it/s]

epoch 16 loss = 0.006106865359470248


14it [05:11, 22.28s/it]
0it [00:00, ?it/s]

epoch 17 loss = 0.0060717032424041206


14it [05:14, 22.46s/it]
0it [00:00, ?it/s]

epoch 18 loss = 0.006024948886728713


14it [05:12, 22.32s/it]
0it [00:00, ?it/s]

epoch 19 loss = 0.005971043470448681


14it [04:39, 19.99s/it]
0it [00:00, ?it/s]

epoch 20 loss = 0.005916495807468891


14it [04:21, 18.68s/it]
0it [00:00, ?it/s]

epoch 21 loss = 0.0058698539755174094


14it [04:21, 18.68s/it]
0it [00:00, ?it/s]

epoch 22 loss = 0.005824423461620297


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 23 loss = 0.005787487308095608


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 24 loss = 0.005751576573987093


14it [04:21, 18.66s/it]
0it [00:00, ?it/s]

epoch 25 loss = 0.0057195795234292746


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 26 loss = 0.0056900015499975


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 27 loss = 0.00565663688550038


14it [04:21, 18.69s/it]
0it [00:00, ?it/s]

epoch 28 loss = 0.0056306247133761644


14it [04:21, 18.69s/it]
0it [00:00, ?it/s]

epoch 29 loss = 0.005593237334064075


14it [04:21, 18.68s/it]
0it [00:00, ?it/s]

epoch 30 loss = 0.005567864681194935


14it [04:21, 18.66s/it]
0it [00:00, ?it/s]

epoch 31 loss = 0.0055362797741379055


14it [04:21, 18.65s/it]
0it [00:00, ?it/s]

epoch 32 loss = 0.005525986930089337


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 33 loss = 0.005500405162040677


14it [04:21, 18.66s/it]
0it [00:00, ?it/s]

epoch 34 loss = 0.005466016408588205


14it [04:21, 18.65s/it]
0it [00:00, ?it/s]

epoch 35 loss = 0.00545191764831543


14it [04:21, 18.65s/it]
0it [00:00, ?it/s]

epoch 36 loss = 0.005439422958131347


14it [04:21, 18.66s/it]
0it [00:00, ?it/s]

epoch 37 loss = 0.00542583512807531


14it [04:21, 18.67s/it]
0it [00:00, ?it/s]

epoch 38 loss = 0.005404135910794139


12it [03:51, 19.33s/it]