In [1]:
import numpy as np
import pandas as pd

import torch

from gaussian_ring_grid_generator import GaussianRingSpaceTimeGrid
import encoder
import importlib
importlib.reload(encoder)

<module 'encoder' from '/home/ns4486/repos/numerical-relativity-interpolation/dev/encoder_sr/encoder.py'>

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
spaceTimeContext = GaussianRingSpaceTimeGrid(
    n_space_grid = 32,
    n_time_grid = 5,
    space_min_x = -5,
    space_max_x = 5,
    time_min_t = 0.5,
    time_max_t = 2.5
)

  [torch.tensor(GaussianRing(self.space_grid.cpu(), i, i/2)).unsqueeze(0) for i in self.time_axis]


In [4]:
spaceTimeContext.values.cpu().numpy().shape

(5, 1, 32, 32, 32)

In [5]:
spaceTimeContext.time_axis

tensor([0.5000, 1.0000, 1.5000, 2.0000, 2.5000], device='cuda:0')

In [6]:
train_grid = GaussianRingSpaceTimeGrid(
    n_space_grid = 48,
    n_time_grid = 4,
    space_min_x = -5,
    space_max_x = 5,
    time_min_t = 0.75,
    time_max_t = 2.25
)

In [7]:
train_grid.time_axis

tensor([0.7500, 1.2500, 1.7500, 2.2500], device='cuda:0')

In [8]:
train_grid.values.shape

torch.Size([4, 1, 48, 48, 48])

In [9]:
def get_decimal_index(axis_values, value):
    closest_above = axis_values[axis_values >= value].min()
    closest_below = axis_values[axis_values <= value].max()
    index_below = np.where(axis_values == closest_below)[0][0]
    
    if closest_above == closest_below:
        return index_below
    
    return index_below + ((value - closest_below)/(closest_above - closest_below))

In [10]:
get_decimal_index(spaceTimeContext.time_axis.cpu().numpy(), 2.2500)

3.5

In [11]:
ts = []
xs = []
ys = []
zs = []

vals = []

for i in range(train_grid.values.shape[0]):
    t = get_decimal_index(spaceTimeContext.time_axis.cpu().numpy(), train_grid.time_axis[i].item())
    
    for j in range(train_grid.values.shape[2]):
        x = get_decimal_index(spaceTimeContext.space_axis.cpu().numpy(), train_grid.space_axis[j].item())
        
        for k in range(train_grid.values.shape[3]):
            y = get_decimal_index(spaceTimeContext.space_axis.cpu().numpy(), train_grid.space_axis[k].item())
            
            for l in range(train_grid.values.shape[4]):
                z = get_decimal_index(spaceTimeContext.space_axis.cpu().numpy(), train_grid.space_axis[l].item())
                
                ts.append(t)
                xs.append(x)
                ys.append(y)
                zs.append(z)
                vals.append(train_grid.values[i, 0, j, k, l].item())

In [12]:
# import matplotlib.pyplot as plt

# plt.hist(vals, bins=100)

In [13]:
inps = np.array([ts, xs, ys, zs]).T
outs = np.array(vals)
inps.shape, outs.shape

((442368, 4), (442368,))

In [14]:
input_tensor = torch.Tensor(inps).unsqueeze(1)
output_tensor = torch.Tensor(outs).unsqueeze(1)
input_tensor.shape, output_tensor.shape

(torch.Size([442368, 1, 4]), torch.Size([442368, 1]))

In [15]:
from torch.utils.data import TensorDataset, DataLoader

train_dataset = TensorDataset(input_tensor,output_tensor) # create your datset
train_dataloader = DataLoader(train_dataset, batch_size=32768, shuffle=True) # create your dataloader

In [16]:
import torch.nn as nn

model = encoder.SR(1, 16, 5)

In [17]:
import torch.optim as optim

criterion = nn.MSELoss()
# optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr = 1e-2)

In [18]:
import tqdm

model.to(device)
model.train()

losses = []

for epoch in range(50):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in tqdm.tqdm(enumerate(train_dataloader, 0)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(spaceTimeContext.values, inputs)
#         print(inputs)
#         print(labels)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        
    torch.save(model.state_dict(), f'/scratch/ns4486/capstone/checkpoints/encoder_sr/epoch={epoch}.pt')
    
    print(f"epoch {epoch} loss = {running_loss/(i+1)}")
    losses.append(running_loss/(i+1))

print('Finished Training')

14it [05:45, 24.68s/it]
0it [00:00, ?it/s]

epoch 0 loss = 6770.706713551922


14it [05:44, 24.58s/it]
0it [00:00, ?it/s]

epoch 1 loss = 0.6194949597120285


14it [05:44, 24.57s/it]
0it [00:00, ?it/s]

epoch 2 loss = 0.05988189165613481


14it [05:42, 24.47s/it]
0it [00:00, ?it/s]

epoch 3 loss = 0.011697415528552873


14it [05:43, 24.57s/it]
0it [00:00, ?it/s]

epoch 4 loss = 0.011148180413459028


14it [05:45, 24.69s/it]
0it [00:00, ?it/s]

epoch 5 loss = 0.00949632063774126


14it [05:44, 24.62s/it]
0it [00:00, ?it/s]

epoch 6 loss = 0.009279511735907622


14it [05:44, 24.63s/it]
0it [00:00, ?it/s]

epoch 7 loss = 0.009130594919302635


14it [05:44, 24.57s/it]
0it [00:00, ?it/s]

epoch 8 loss = 0.009056243247219495


14it [05:44, 24.64s/it]
0it [00:00, ?it/s]

epoch 9 loss = 0.009019255837691682


14it [05:44, 24.64s/it]
0it [00:00, ?it/s]

epoch 10 loss = 0.008979667643351215


14it [05:44, 24.64s/it]
0it [00:00, ?it/s]

epoch 11 loss = 0.008941244400505508


14it [05:43, 24.55s/it]
0it [00:00, ?it/s]

epoch 12 loss = 0.008970942880426134


14it [05:43, 24.56s/it]
0it [00:00, ?it/s]

epoch 13 loss = 0.008920830142285143


14it [05:44, 24.57s/it]
0it [00:00, ?it/s]

epoch 14 loss = 0.008952116061534201


14it [05:42, 24.45s/it]
0it [00:00, ?it/s]

epoch 15 loss = 0.008936675465000527


14it [05:43, 24.55s/it]
0it [00:00, ?it/s]

epoch 16 loss = 0.008942103212965387


14it [05:43, 24.53s/it]
0it [00:00, ?it/s]

epoch 17 loss = 0.008928659571600812


14it [05:41, 24.41s/it]
0it [00:00, ?it/s]

epoch 18 loss = 0.008928511225219284


14it [05:41, 24.40s/it]
0it [00:00, ?it/s]

epoch 19 loss = 0.008907151741108723


14it [05:41, 24.40s/it]
0it [00:00, ?it/s]

epoch 20 loss = 0.008932368497231178


14it [05:41, 24.40s/it]
0it [00:00, ?it/s]

epoch 21 loss = 0.00893037925873484


14it [05:41, 24.42s/it]
0it [00:00, ?it/s]

epoch 22 loss = 0.0089134038425982


14it [05:45, 24.71s/it]
0it [00:00, ?it/s]

epoch 23 loss = 0.008915816034589494


14it [05:41, 24.41s/it]
0it [00:00, ?it/s]

epoch 24 loss = 0.00890451782782163


14it [05:41, 24.36s/it]
0it [00:00, ?it/s]

epoch 25 loss = 0.008918393137199538


14it [05:41, 24.41s/it]
0it [00:00, ?it/s]

epoch 26 loss = 0.00889779135052647


14it [05:44, 24.59s/it]
0it [00:00, ?it/s]

epoch 27 loss = 0.008881916757673025


14it [05:41, 24.39s/it]
0it [00:00, ?it/s]

epoch 28 loss = 0.008878881510879313


14it [05:41, 24.43s/it]
0it [00:00, ?it/s]

epoch 29 loss = 0.008869468633617674


14it [05:43, 24.50s/it]
0it [00:00, ?it/s]

epoch 30 loss = 0.008863791357725859


14it [05:42, 24.49s/it]
0it [00:00, ?it/s]

epoch 31 loss = 0.00879957526922226


14it [05:44, 24.61s/it]
0it [00:00, ?it/s]

epoch 32 loss = 0.008749103639274836


14it [05:45, 24.65s/it]
0it [00:00, ?it/s]

epoch 33 loss = 0.008651306054421834


14it [05:44, 24.64s/it]
0it [00:00, ?it/s]

epoch 34 loss = 0.00854050527725901


14it [05:45, 24.69s/it]
0it [00:00, ?it/s]

epoch 35 loss = 0.008476507317806994


14it [05:43, 24.57s/it]
0it [00:00, ?it/s]

epoch 36 loss = 0.008470666195665087


14it [05:45, 24.66s/it]
0it [00:00, ?it/s]

epoch 37 loss = 0.00847233725445611


14it [05:41, 24.41s/it]
0it [00:00, ?it/s]

epoch 38 loss = 0.008444930360253369


14it [05:41, 24.43s/it]
0it [00:00, ?it/s]

epoch 39 loss = 0.008449646910386426


14it [05:40, 24.32s/it]
0it [00:00, ?it/s]

epoch 40 loss = 0.008399678061583213


14it [05:45, 24.68s/it]
0it [00:00, ?it/s]

epoch 41 loss = 0.008413996547460556


14it [05:46, 24.72s/it]
0it [00:00, ?it/s]

epoch 42 loss = 0.008396371533828122


14it [05:45, 24.66s/it]
0it [00:00, ?it/s]

epoch 43 loss = 0.008414726970451218


14it [05:45, 24.67s/it]
0it [00:00, ?it/s]

epoch 44 loss = 0.008371946135801929


14it [05:43, 24.53s/it]
0it [00:00, ?it/s]

epoch 45 loss = 0.00836835369201643


14it [05:43, 24.51s/it]
0it [00:00, ?it/s]

epoch 46 loss = 0.008345689424978835


14it [05:44, 24.62s/it]
0it [00:00, ?it/s]

epoch 47 loss = 0.008342076625142778


14it [05:43, 24.56s/it]
0it [00:00, ?it/s]

epoch 48 loss = 0.008340467299733843


14it [05:44, 24.62s/it]

epoch 49 loss = 0.008309325868529933
Finished Training



