## Imports

In [1]:
# EXPORT
# --- Must haves ---
import os, sys
sys.path.append('..')

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.cuda as cuda
import torch.nn as nn

from surrogates4sims.datasets import MantaFlowDataset, getSingleSim, createMantaFlowTrainTest

from surrogates4sims.utils import create_opt, create_one_cycle, find_lr, printNumModelParams, \
                                    rmse, writeMessage, plotSampleWprediction, plotSampleWpredictionByChannel, \
                                    plotSample, curl, jacobian, stream2uv

from surrogates4sims.models import Generator, Encoder, AE_xhat_z

import numpy as np
from tqdm import tqdm
from copy import deepcopy



## Settings

In [2]:
# model name, for tensorboard recording and checkpointing purposes.
versionName = "EandG_as_oneModel"

# GPU numbers to use. Comma-separate them for multi-GPU runs.

gpu_ids = "3"
# path to load model weights.
pretrained_path = None

# rate at which to record metrics. (number of batches to average over when recording metrics, e.g. "every 5 batches")
tensorboard_rate = 5

# number of epochs to train. This is defined here so we can use the OneCycle LR Scheduler.
epochs = 100

# Data Directory
dataDirec = '/data/mantaFlowSim/data/smoke_pos21_size5_f200/v'
reverseXY = False

# checkpoint directory
cps = 'cps'
tensorboard_direc = "tb"

findLRs = True # only do this if you're trying to set the LR of E, G. It blows up the GPU

# hyper-params
seed = 1234
# Seed every RNG the notebook relies on: NumPy for the data split and PyTorch
# for weight initialisation and DataLoader shuffling.
np.random.seed(seed)
torch.manual_seed(seed)
testSplit = .1
bz = 8
# np.inf means "keep every sample" (no debugging subset). np.infty was an
# alias that NumPy 2.0 removed, so np.inf is used for forward compatibility.
numSamplesToKeep = np.inf
latentDim = 16
filters = 16
num_conv = 4
simLen = 200
stack = True
simVizIndex = 0 # sim in the test set to visualize
createStreamFcn = False
doJacobian = True
# Encode the main hyper-parameters in the run name so tensorboard logs and
# checkpoints from different configurations do not overwrite each other.
versionName = versionName + '_latentDim{}_filters{}_bz{}_numConv{}_stream{}_jacobian{}_epochs{}'.format(latentDim,filters,bz,num_conv,createStreamFcn,doJacobian,epochs)
versionName

'EandG_as_oneModel_latentDim16_filters16_bz8_numConv4_streamFalse_jacobianTrue_epochs100'

### Select Personal GPUs

In [3]:
# Show current GPU usage (shell command) before a device is selected via gpu_ids.
!nvidia-smi

Wed Mar 18 01:31:26 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40       Driver Version: 430.40       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  TITAN Xp            On   | 00000000:02:00.0 Off |                  N/A |
| 23%   21C    P8     8W / 250W |   1042MiB / 12196MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  TITAN Xp            On   | 00000000:03:00.0 Off |                  N/A |
| 23%   21C    P8     8W / 250W |      1MiB / 12196MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  TITAN Xp            On   | 00000000:81:00.0 Off |                  N/

In [4]:
# Enumerate GPUs by PCI bus id and expose only the devices named in gpu_ids to this process.
# NOTE(review): these variables presumably must be set before the first CUDA call — confirm.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]=gpu_ids

In [5]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [6]:
# When running on a GPU, print (in order) CUDA availability, the number of
# visible devices, the index of the current device and its name.
if device.type == 'cuda':
    for cuda_query in (cuda.is_available, cuda.device_count,
                       cuda.current_device, cuda.get_device_name):
        print(cuda_query())

True
1
0
TITAN Xp


In [7]:
# Allocate a tiny tensor on the selected device type, then list GPU usage again.
# NOTE(review): presumably a check of which physical GPU this process landed on — confirm.
a = torch.zeros(5, device=device.type)
!nvidia-smi

Wed Mar 18 01:31:29 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40       Driver Version: 430.40       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  TITAN Xp            On   | 00000000:02:00.0 Off |                  N/A |
| 23%   21C    P8     8W / 250W |   1042MiB / 12196MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  TITAN Xp            On   | 00000000:03:00.0 Off |                  N/A |
| 23%   21C    P8     8W / 250W |      1MiB / 12196MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  TITAN Xp            On   | 00000000:81:00.0 Off |                  N/

## Datasets & Loaders

In [8]:
# Build the train/test split from the data under dataDirec (project helper; it is
# given the simulation length, the test fraction and the seed) and print the split sizes.
trainData, testData = createMantaFlowTrainTest(dataDirec,simLen,testSplit,seed)
print((len(trainData),len(testData)))

(19000, 2000)


In [9]:
# Wrap both splits in MantaFlowDataset.
# The datasets may be smaller than the raw splits when numSamplesToKeep is finite
# (NOTE(review): numToKeep is assumed to cap the number of samples — confirm in MantaFlowDataset).
testDataset = MantaFlowDataset(testData, reverseXY=reverseXY, numToKeep=numSamplesToKeep, AE=False)
trainDataset = MantaFlowDataset(trainData, reverseXY=reverseXY,numToKeep=numSamplesToKeep, AE=False)
len(trainDataset), len(testDataset)

100%|██████████| 2000/2000 [00:03<00:00, 567.80it/s]
100%|██████████| 19000/19000 [00:34<00:00, 548.28it/s]


(19000, 2000)

In [10]:
# Training batches are shuffled and the last incomplete batch is dropped;
# the test loader is given no shuffle/drop_last arguments, so it uses the DataLoader defaults.
trainDataLoader = DataLoader(dataset=trainDataset, batch_size=bz, shuffle=True, drop_last=True)
testDataLoader = DataLoader(dataset=testDataset, batch_size=bz)

## Model

Currently, the models need a sample batch of data in order to be built. This is a bit awkward; I may look into fixing it later.

In [11]:
# Take one batch (X, p) from the test loader and move it to the device;
# the cells below pass it to the model constructors.
X,p = next(iter(testDataLoader))
X = X.to(device)
p = p.to(device)
X.shape, p.shape

(torch.Size([8, 2, 128, 96]), torch.Size([8, 3]))

In [12]:
# Build the encoder from the sample batch X and the filters/latentDim/num_conv settings,
# move it to the device, and display its layer summary.
E = Encoder(X,filters,latentDim,num_conv=num_conv).to(device)
E

Encoder(
  (act): LeakyReLU(negative_slope=0.01)
  (conv1): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convs): Sequential(
    (0): convBlock(
      (act): LeakyReLU(negative_slope=0.01)
      (convs): Sequential(
        (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.01)
        (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): LeakyReLU(negative_slope=0.01)
        (4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): LeakyReLU(negative_slope=0.01)
        (6): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): LeakyReLU(negative_slope=0.01)
      )
      (downSampleLayer): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (1): convBlock(
      (act): LeakyReLU(negative_slope=0.01)
      (convs): Sequential(
        (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1

In [13]:
# Report how many encoder layers and parameters require gradients.
printNumModelParams(E)

54 layers require gradients (unfrozen) out of 54 layers
99,792 parameters require gradients (unfrozen) out of 99,792 parameters


In [14]:
# Encode the sample batch and inspect the latent shape.
z = E(X)
z.shape

torch.Size([8, 16])

In [15]:
# Shape of a single sample (X[0], i.e. without the batch dimension) as a tensor;
# it is passed to the Generator constructor below.
output_shape = torch.tensor(X[0].shape)
output_shape

tensor([  2, 128,  96])

In [16]:
# Build the generator from the sample latent batch z and the target output shape,
# then move it to the device and display its layer summary.
G = Generator(z, filters, output_shape,
                 num_conv=num_conv, conv_k=3, last_k=3, repeat=0, 
                 skip_connection=False, act=nn.LeakyReLU(), stack=stack)
G = G.to(device)
G

Generator(
  (linear): Linear(in_features=16, out_features=768, bias=True)
  (convTransBlockLayers): Sequential(
    (0): convTransBlock(
      (act): LeakyReLU(negative_slope=0.01)
      (upsample): Upsample(scale_factor=2.0, mode=nearest)
      (seq): Sequential(
        (0): ConvTranspose2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.01)
        (2): ConvTranspose2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): LeakyReLU(negative_slope=0.01)
        (4): ConvTranspose2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): LeakyReLU(negative_slope=0.01)
        (6): ConvTranspose2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (7): LeakyReLU(negative_slope=0.01)
      )
    )
    (1): convTransBlock(
      (act): LeakyReLU(negative_slope=0.01)
      (upsample): Upsample(scale_factor=2.0, mode=nearest)
      (seq): Sequential(
        (0): Con

In [17]:
# Report how many generator layers and parameters require gradients.
printNumModelParams(G)

36 layers require gradients (unfrozen) out of 36 layers
65,442 parameters require gradients (unfrozen) out of 65,442 parameters


In [18]:
# Decode the sample latent batch and inspect the reconstruction shape.
Xhat = G(z)
Xhat.shape

torch.Size([8, 2, 128, 96])

In [19]:
# Combine E and G into a single model; the training code calls it as `X_hat, z = model(X)`.
model = AE_xhat_z(E,G).to(device)
model

AE_xhat_z(
  (encoder): Encoder(
    (act): LeakyReLU(negative_slope=0.01)
    (conv1): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (convs): Sequential(
      (0): convBlock(
        (act): LeakyReLU(negative_slope=0.01)
        (convs): Sequential(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): LeakyReLU(negative_slope=0.01)
          (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): LeakyReLU(negative_slope=0.01)
          (4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (5): LeakyReLU(negative_slope=0.01)
          (6): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (7): LeakyReLU(negative_slope=0.01)
        )
        (downSampleLayer): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
      (1): convBlock(
        (act): LeakyReLU(negative_slope=0.01)
        (convs): Sequential(
         

In [20]:
# Report how many layers and parameters of the combined model require gradients.
printNumModelParams(model)

90 layers require gradients (unfrozen) out of 90 layers
165,234 parameters require gradients (unfrozen) out of 165,234 parameters


## Loss Function

In [21]:
def L1_loss(pred, target):
    """Return the mean absolute difference between ``pred`` and ``target``."""
    residual = pred - target
    return residual.abs().mean()


def jacobian_loss(pred, target, device='cpu'):
    """L1 distance between ``jacobian(pred, device)`` and ``jacobian(target, device)``."""
    pred_jac = jacobian(pred, device)
    target_jac = jacobian(target, device)
    return L1_loss(pred_jac, target_jac)


def curl_loss(pred, target, device):
    """L1 distance between ``curl(pred, device)`` and ``curl(target, device)``."""
    pred_curl = curl(pred, device)
    target_curl = curl(target, device)
    return L1_loss(pred_curl, target_curl)


# Mean-squared-error criterion used by p_loss.
L = nn.MSELoss()


def p_loss(pred, target):
    """MSE between the trailing columns of ``pred`` and ``target``.

    Only the last ``target.shape[1]`` columns of ``pred`` are compared with
    ``target``; any leading columns of ``pred`` are ignored.
    """
    n_cols = target.shape[1]
    trailing = pred[:, -n_cols:]
    return L(trailing, target)


def loss(pred, target, device):
    """Reconstruction loss: L1 term plus an optional Jacobian L1 term.

    Reads the notebook-level flags ``createStreamFcn`` and ``doJacobian``.

    * If ``createStreamFcn`` is set, ``pred`` is first converted with
      ``stream2uv(pred, device)`` so that it is compared with ``target`` in
      the same space.  The converted field is used for BOTH terms; before
      this fix the Jacobian term compared the raw (unconverted) ``pred`` with
      ``target``, i.e. two different quantities.
    * If ``doJacobian`` is set, ``jacobian_loss`` of the (converted)
      prediction against ``target`` is added.

    With ``createStreamFcn`` False the result is unchanged from the previous
    implementation.
    """
    # Bring the prediction into the target's space once, then reuse it.
    field = stream2uv(pred, device) if createStreamFcn else pred

    L1 = L1_loss(field, target)
    Lj = jacobian_loss(field, target, device) if doJacobian else 0
    return L1 + Lj

In [22]:
# Sanity check: the loss of a prediction against itself, and against the input batch.
loss(Xhat,Xhat,device), loss(Xhat,X,device)

(tensor(0., device='cuda:0', grad_fn=<AddBackward0>),
 tensor(0.0375, device='cuda:0', grad_fn=<AddBackward0>))

In [23]:
# Sanity check: parameter loss between the trailing latent columns of z and p.
p_loss(z,p)

tensor(0.4960, device='cuda:0', grad_fn=<MseLossBackward>)

In [24]:
# Optimizer and one-cycle LR schedule for the combined model.
# NOTE(review): start_lr is only referenced by the commented-out Adam line below.
max_lr = .00001
start_lr = 5*max_lr/10
opt = create_opt(max_lr,model)
#opt = torch.optim.Adam(model.parameters(),lr=start_lr,betas=(.5,.999))
# The schedule is built from max_lr, the number of epochs and the training loader.
lr_scheduler = create_one_cycle(opt,max_lr,epochs,trainDataLoader)

## Train

In [25]:
# EXPORT
def trainEpoch(myDataLoader, tensorboard_writer, model, opt, p_loss, loss,
               metric, lr_scheduler, tensorboard_rate, device,
               tensorboard_recorder_step, total_steps):
    """Train ``model`` for one pass over ``myDataLoader`` and log to TensorBoard.

    For every batch ``(X, p)`` the model is called as ``X_hat, z = model(X)``
    and optimised on ``p_loss(z, p) + loss(X_hat, X, device)``.  The learning
    rate is logged and ``lr_scheduler.step()`` is called once per batch.
    Every ``tensorboard_rate`` batches the mean loss and mean metric over
    that window are written under the tags "Loss" and ``metric.__name__``.

    Args:
        myDataLoader: iterable of batches; element 0 is X, element 1 is p.
        tensorboard_writer: object exposing ``add_scalar(tag, scalar_value, global_step)``.
        model: callable returning ``(X_hat, z)``.
        opt: optimizer stepped once per batch.
        p_loss: callable ``(z, p) -> scalar tensor``.
        loss: callable ``(X_hat, X, device) -> scalar tensor``.
        metric: callable ``(X_hat, X) -> scalar``; its ``__name__`` is the log tag.
        lr_scheduler: scheduler stepped once per batch.
        tensorboard_rate: number of batches averaged per TensorBoard write.
        device: device the batch is moved to.
        tensorboard_recorder_step: running count of "Loss"/metric writes.
        total_steps: running count of optimizer steps (x-axis of the "LR" plot).

    Returns:
        ``(batch_loss, tensorboard_recorder_step, total_steps)`` where
        ``batch_loss`` is the loss of the LAST batch (NaN if the loader
        yielded no batches) and the two counters are the updated values.
    """
    running_loss = 0.0
    running_rmse = 0.0
    # Defined up front so an empty loader returns NaN instead of raising
    # UnboundLocalError at the return statement.
    batch_loss = float('nan')
    for i, sampleBatch in enumerate(myDataLoader, start=1):

        # --- Main Training ---

        # gpu
        X,p = sampleBatch[0],sampleBatch[1]
        X = X.to(device)
        p = p.to(device)

        # zero the parameter gradients
        opt.zero_grad()

        X_hat, z = model(X)
        combined_loss = p_loss(z,p) + loss(X_hat,X,device)
        combined_loss.backward()
        opt.step()

        # loss
        batch_loss = combined_loss.item()
        running_loss += batch_loss

        # --- Metrics Recording ---

        # metrics: computed on a detached prediction with autograd disabled so
        # the running sum never keeps a computation graph alive between batches.
        with torch.no_grad():
            r = metric(X_hat.detach(), X)
        running_rmse += r

        # record lr change
        total_steps += 1
        tensorboard_writer.add_scalar(tag="LR", scalar_value=opt.param_groups[0]['lr'], global_step=total_steps)
        lr_scheduler.step()

        # tensorboard writes
        if (i % tensorboard_rate == 0):
            tensorboard_recorder_step += 1
            avg_running_loss = running_loss/tensorboard_rate
            avg_running_rmse = running_rmse/tensorboard_rate
            tensorboard_writer.add_scalar(tag="Loss", scalar_value=avg_running_loss, global_step=tensorboard_recorder_step)
            tensorboard_writer.add_scalar(tag=metric.__name__, scalar_value=avg_running_rmse, global_step=tensorboard_recorder_step)
            # reset running_loss for the next set of batches. (tensorboard_rate number of batches)
            running_loss = 0.0
            running_rmse = 0.0

    return batch_loss, tensorboard_recorder_step, total_steps


In [26]:
# EXPORT
def validEpoch(myDataLoader, tensorboard_writer, model, p_loss, loss, metric,
               device, tensorboard_recorder_step):
    """Evaluate ``model`` on every batch of ``myDataLoader`` and log the result.

    Each batch is weighted by its share of ``len(myDataLoader.dataset)``, so
    when the loader covers the whole dataset the accumulated values are
    sample-weighted means.  The weighted loss and metric are written to
    ``tensorboard_writer`` under "Loss" and ``metric.__name__`` at
    ``tensorboard_recorder_step``.

    Returns:
        The weighted sum of ``p_loss(z, p) + loss(X_hat, X, device)``.
    """
    weighted_loss = 0.0
    weighted_metric = 0.0
    for batch in myDataLoader:
        # move the batch to the evaluation device
        inputs = batch[0].to(device)
        params = batch[1].to(device)

        # fraction of the dataset represented by this batch
        share = len(inputs) / len(myDataLoader.dataset)

        # forward pass only; autograd is switched off
        with torch.no_grad():
            recon, latent = model(inputs)

        batch_total = p_loss(latent, params) + loss(recon, inputs, device)
        weighted_loss += share * (batch_total.item())

        weighted_metric += share * metric(recon, inputs)

    tensorboard_writer.add_scalar(tag="Loss", scalar_value=weighted_loss, global_step=tensorboard_recorder_step)
    tensorboard_writer.add_scalar(tag=metric.__name__, scalar_value=weighted_metric, global_step=tensorboard_recorder_step)

    return weighted_loss

In [27]:
# Create the checkpoint directory. Only "already exists" is treated as benign;
# any other failure (permissions, missing parent, ...) is allowed to surface
# instead of being misreported as "already exists".
try:
    os.mkdir(cps)
except FileExistsError:
    print("checkpoints directory already exists :)")

checkpoints directory already exists :)


In [28]:
# One TensorBoard run directory per split, both under tensorboard_direc/versionName.
log_root = os.path.join(tensorboard_direc, versionName)
train_writer = SummaryWriter(os.path.join(log_root, 'train'))
test_writer = SummaryWriter(os.path.join(log_root, 'valid'))
# Counters passed through trainEpoch: TensorBoard write index and optimizer-step index.
tensorboard_recorder_step = total_steps = 0

In [29]:
writeMessage('---------- Started Training ----------', versionName)
# Best validation loss seen so far. np.inf replaces np.infty, an alias that
# NumPy 2.0 removed.
bestLoss = np.inf

for epoch in tqdm(range(1, epochs+1)):  # loop over the dataset multiple times

    writeMessage("--- Epoch {0}/{1} ---".format(epoch, epochs), versionName)

    model.train()
    # trainLoss is the loss of the last training batch of the epoch.
    trainLoss, tensorboard_recorder_step, total_steps = trainEpoch(trainDataLoader,
                                                                   train_writer, model, opt, p_loss, loss,
                                                                   rmse, lr_scheduler,
                                                                   tensorboard_rate, device,
                                                                   tensorboard_recorder_step, total_steps)

    writeMessage("trainLoss: {:.4e}".format(trainLoss),versionName)
    model.eval()
    valLoss = validEpoch(testDataLoader, test_writer, model, p_loss, loss, rmse, device, tensorboard_recorder_step)

    # checkpoint progress: keep only the weights with the lowest validation loss
    if valLoss < bestLoss:
        bestLoss = valLoss
        writeMessage("Better valLoss: {:.4e}, Saving models...".format(bestLoss),versionName)
        torch.save(model.state_dict(), os.path.join(cps,versionName))

writeMessage('---------- Finished Training ----------', versionName)

  0%|          | 0/100 [00:00<?, ?it/s]

---------- Started Training ----------
--- Epoch 1/100 ---
trainLoss: 5.4962e-01


  1%|          | 1/100 [01:39<2:44:51, 99.91s/it]

Better valLoss: 5.6767e-01, Saving models...
--- Epoch 2/100 ---
trainLoss: 5.4540e-01


  2%|▏         | 2/100 [03:22<2:44:31, 100.73s/it]

Better valLoss: 5.6134e-01, Saving models...
--- Epoch 3/100 ---
trainLoss: 4.9954e-01


  3%|▎         | 3/100 [04:56<2:39:21, 98.57s/it] 

Better valLoss: 5.5638e-01, Saving models...
--- Epoch 4/100 ---
trainLoss: 4.5588e-01


  4%|▍         | 4/100 [06:38<2:39:40, 99.79s/it]

Better valLoss: 5.5220e-01, Saving models...
--- Epoch 5/100 ---
trainLoss: 4.4632e-01


  5%|▌         | 5/100 [08:11<2:34:40, 97.69s/it]

Better valLoss: 5.4803e-01, Saving models...
--- Epoch 6/100 ---
trainLoss: 4.8047e-01


  6%|▌         | 6/100 [09:55<2:36:02, 99.61s/it]

Better valLoss: 5.4215e-01, Saving models...
--- Epoch 7/100 ---
trainLoss: 3.7257e-01


  7%|▋         | 7/100 [11:39<2:36:24, 100.91s/it]

Better valLoss: 5.2141e-01, Saving models...
--- Epoch 8/100 ---
trainLoss: 4.0506e-01


  8%|▊         | 8/100 [13:10<2:30:21, 98.06s/it] 

Better valLoss: 3.6591e-01, Saving models...
--- Epoch 9/100 ---
trainLoss: 2.0759e-01


  9%|▉         | 9/100 [14:56<2:32:18, 100.43s/it]

Better valLoss: 3.0263e-01, Saving models...
--- Epoch 10/100 ---
trainLoss: 2.2957e-01


 10%|█         | 10/100 [16:38<2:31:14, 100.82s/it]

Better valLoss: 2.7988e-01, Saving models...
--- Epoch 11/100 ---
trainLoss: 2.9072e-01


 11%|█         | 11/100 [18:20<2:29:58, 101.10s/it]

--- Epoch 12/100 ---
trainLoss: 2.9711e-01


 12%|█▏        | 12/100 [20:02<2:28:49, 101.47s/it]

Better valLoss: 2.6537e-01, Saving models...
--- Epoch 13/100 ---
trainLoss: 1.4218e-01


 13%|█▎        | 13/100 [21:41<2:26:01, 100.71s/it]

Better valLoss: 2.2796e-01, Saving models...
--- Epoch 14/100 ---
trainLoss: 2.2348e-01


 14%|█▍        | 14/100 [23:20<2:23:42, 100.27s/it]

--- Epoch 15/100 ---
trainLoss: 9.8036e-02


 15%|█▌        | 15/100 [25:03<2:23:00, 100.95s/it]

--- Epoch 16/100 ---
trainLoss: 5.6088e-01


 16%|█▌        | 16/100 [26:39<2:19:17, 99.49s/it] 

--- Epoch 17/100 ---
trainLoss: 2.3574e+00


 17%|█▋        | 17/100 [28:17<2:16:55, 98.98s/it]

--- Epoch 18/100 ---
trainLoss: 4.2828e+00


 18%|█▊        | 18/100 [29:54<2:14:20, 98.30s/it]

--- Epoch 19/100 ---
trainLoss: 6.9889e+00


 19%|█▉        | 19/100 [31:35<2:13:48, 99.12s/it]

--- Epoch 20/100 ---
trainLoss: 8.9120e+01


 20%|██        | 20/100 [33:17<2:13:38, 100.24s/it]

--- Epoch 21/100 ---
trainLoss: 2.3837e+02


 21%|██        | 21/100 [34:58<2:12:12, 100.41s/it]

--- Epoch 22/100 ---
trainLoss: 6.4788e+03


 22%|██▏       | 22/100 [36:47<2:13:58, 103.05s/it]

--- Epoch 23/100 ---
trainLoss: 1.0699e+09


 23%|██▎       | 23/100 [38:23<2:09:31, 100.92s/it]

--- Epoch 24/100 ---
trainLoss: 2.0505e+13


 24%|██▍       | 24/100 [40:00<2:06:21, 99.76s/it] 

--- Epoch 25/100 ---
trainLoss: 9.4251e+14


 25%|██▌       | 25/100 [41:36<2:03:13, 98.58s/it]

--- Epoch 26/100 ---
trainLoss: 3.1843e+16


 26%|██▌       | 26/100 [43:20<2:03:36, 100.22s/it]

--- Epoch 27/100 ---
trainLoss: 4.3563e+17


 27%|██▋       | 27/100 [44:56<2:00:18, 98.88s/it] 

--- Epoch 28/100 ---
trainLoss: 9.8963e+17


 28%|██▊       | 28/100 [46:28<1:56:11, 96.83s/it]

--- Epoch 29/100 ---
trainLoss: 9.4321e+18


 29%|██▉       | 29/100 [48:10<1:56:32, 98.48s/it]

--- Epoch 30/100 ---
trainLoss: 8.8430e+19


 30%|███       | 30/100 [49:43<1:52:52, 96.76s/it]

--- Epoch 31/100 ---
trainLoss: 1.1968e+20


 31%|███       | 31/100 [51:24<1:52:29, 97.82s/it]

--- Epoch 32/100 ---
trainLoss: 2.2566e+20


 32%|███▏      | 32/100 [52:54<1:48:13, 95.49s/it]

--- Epoch 33/100 ---
trainLoss: 2.1486e+20


 33%|███▎      | 33/100 [54:37<1:49:24, 97.98s/it]

--- Epoch 34/100 ---
trainLoss: 1.7400e+20


 34%|███▍      | 34/100 [56:14<1:47:18, 97.55s/it]

--- Epoch 35/100 ---
trainLoss: 1.4559e+20


 35%|███▌      | 35/100 [57:54<1:46:22, 98.19s/it]

--- Epoch 36/100 ---
trainLoss: 1.5817e+20


 36%|███▌      | 36/100 [59:31<1:44:28, 97.95s/it]

--- Epoch 37/100 ---
trainLoss: 3.5627e+20


 37%|███▋      | 37/100 [1:01:18<1:45:43, 100.68s/it]

--- Epoch 38/100 ---
trainLoss: 1.5123e+20


 38%|███▊      | 38/100 [1:02:48<1:40:43, 97.48s/it] 

--- Epoch 39/100 ---
trainLoss: 1.8066e+20


 39%|███▉      | 39/100 [1:04:29<1:40:09, 98.51s/it]

--- Epoch 40/100 ---
trainLoss: 2.5007e+20


 40%|████      | 40/100 [1:06:09<1:38:58, 98.98s/it]

--- Epoch 41/100 ---
trainLoss: 4.6751e+20


 41%|████      | 41/100 [1:07:52<1:38:32, 100.22s/it]

--- Epoch 42/100 ---
trainLoss: 2.9852e+20


 42%|████▏     | 42/100 [1:09:38<1:38:30, 101.90s/it]

--- Epoch 43/100 ---
trainLoss: 7.5548e+20


 43%|████▎     | 43/100 [1:11:25<1:38:21, 103.53s/it]

--- Epoch 44/100 ---
trainLoss: 1.6779e+21


 44%|████▍     | 44/100 [1:13:06<1:35:58, 102.83s/it]

--- Epoch 45/100 ---
trainLoss: 4.6512e+20


 45%|████▌     | 45/100 [1:14:44<1:32:46, 101.20s/it]

--- Epoch 46/100 ---
trainLoss: 1.3076e+21


 46%|████▌     | 46/100 [1:16:32<1:33:04, 103.42s/it]

--- Epoch 47/100 ---
trainLoss: 6.2041e+20


 47%|████▋     | 47/100 [1:18:08<1:29:11, 100.97s/it]

--- Epoch 48/100 ---
trainLoss: 5.3842e+20


 48%|████▊     | 48/100 [1:19:47<1:27:08, 100.55s/it]

--- Epoch 49/100 ---
trainLoss: 5.4306e+20


 49%|████▉     | 49/100 [1:21:22<1:23:55, 98.73s/it] 

--- Epoch 50/100 ---
trainLoss: 7.6217e+20


 50%|█████     | 50/100 [1:23:03<1:22:55, 99.51s/it]

--- Epoch 51/100 ---
trainLoss: 1.2851e+21


 51%|█████     | 51/100 [1:24:39<1:20:17, 98.32s/it]

--- Epoch 52/100 ---
trainLoss: 1.2217e+21


 52%|█████▏    | 52/100 [1:26:18<1:18:59, 98.73s/it]

--- Epoch 53/100 ---
trainLoss: 1.4519e+21


 53%|█████▎    | 53/100 [1:27:40<1:13:15, 93.51s/it]

--- Epoch 54/100 ---
trainLoss: 1.2666e+21


 54%|█████▍    | 54/100 [1:29:18<1:12:43, 94.86s/it]

--- Epoch 55/100 ---
trainLoss: 6.7463e+20


 55%|█████▌    | 55/100 [1:31:01<1:13:02, 97.39s/it]

--- Epoch 56/100 ---
trainLoss: 8.8348e+20


 56%|█████▌    | 56/100 [1:32:45<1:12:54, 99.42s/it]

--- Epoch 57/100 ---
trainLoss: 1.5935e+21


 57%|█████▋    | 57/100 [1:34:24<1:11:04, 99.17s/it]

--- Epoch 58/100 ---
trainLoss: 7.5907e+20


 58%|█████▊    | 58/100 [1:35:59<1:08:35, 97.98s/it]

--- Epoch 59/100 ---
trainLoss: 2.0471e+21


 59%|█████▉    | 59/100 [1:37:36<1:06:49, 97.79s/it]

--- Epoch 60/100 ---
trainLoss: 1.9490e+21


 60%|██████    | 60/100 [1:39:17<1:05:47, 98.69s/it]

--- Epoch 61/100 ---
trainLoss: 1.2235e+21


 61%|██████    | 61/100 [1:41:02<1:05:24, 100.62s/it]

--- Epoch 62/100 ---
trainLoss: 6.4475e+20


 62%|██████▏   | 62/100 [1:42:41<1:03:17, 99.93s/it] 

--- Epoch 63/100 ---
trainLoss: 9.0391e+20


 63%|██████▎   | 63/100 [1:44:22<1:01:53, 100.35s/it]

--- Epoch 64/100 ---
trainLoss: 2.1864e+21


 64%|██████▍   | 64/100 [1:46:06<1:00:53, 101.48s/it]

--- Epoch 65/100 ---
trainLoss: 2.0815e+21


 65%|██████▌   | 65/100 [1:47:51<59:44, 102.42s/it]  

--- Epoch 66/100 ---
trainLoss: 1.7615e+21


 66%|██████▌   | 66/100 [1:49:29<57:23, 101.28s/it]

--- Epoch 67/100 ---
trainLoss: 4.6639e+20


 67%|██████▋   | 67/100 [1:51:02<54:18, 98.75s/it] 

--- Epoch 68/100 ---
trainLoss: 1.0603e+21


 68%|██████▊   | 68/100 [1:52:30<50:55, 95.49s/it]

--- Epoch 69/100 ---
trainLoss: 4.9149e+21


 69%|██████▉   | 69/100 [1:54:12<50:17, 97.35s/it]

--- Epoch 70/100 ---
trainLoss: 1.0919e+21


 70%|███████   | 70/100 [1:55:53<49:19, 98.65s/it]

--- Epoch 71/100 ---
trainLoss: 9.9858e+20


 71%|███████   | 71/100 [1:57:35<48:11, 99.70s/it]

--- Epoch 72/100 ---
trainLoss: 1.4904e+21


 72%|███████▏  | 72/100 [1:59:11<46:00, 98.58s/it]

--- Epoch 73/100 ---
trainLoss: 7.7704e+20


 73%|███████▎  | 73/100 [2:00:52<44:41, 99.32s/it]

--- Epoch 74/100 ---
trainLoss: 9.6056e+20


 74%|███████▍  | 74/100 [2:02:29<42:38, 98.41s/it]

--- Epoch 75/100 ---
trainLoss: 1.1525e+21


 75%|███████▌  | 75/100 [2:03:59<39:56, 95.85s/it]

--- Epoch 76/100 ---
trainLoss: 1.1264e+21


 76%|███████▌  | 76/100 [2:05:46<39:46, 99.43s/it]

--- Epoch 77/100 ---
trainLoss: 3.0393e+21


 77%|███████▋  | 77/100 [2:07:23<37:48, 98.63s/it]

--- Epoch 78/100 ---
trainLoss: 1.5472e+21


 78%|███████▊  | 78/100 [2:08:49<34:42, 94.66s/it]

--- Epoch 79/100 ---
trainLoss: 1.9906e+21


 79%|███████▉  | 79/100 [2:10:31<33:56, 96.99s/it]

--- Epoch 80/100 ---
trainLoss: 1.4218e+21


 80%|████████  | 80/100 [2:12:17<33:11, 99.56s/it]

--- Epoch 81/100 ---
trainLoss: 1.6178e+21


 81%|████████  | 81/100 [2:13:56<31:31, 99.57s/it]

--- Epoch 82/100 ---
trainLoss: 2.2104e+21


 82%|████████▏ | 82/100 [2:15:26<28:57, 96.52s/it]

--- Epoch 83/100 ---
trainLoss: 5.0764e+21


 83%|████████▎ | 83/100 [2:17:04<27:29, 97.05s/it]

--- Epoch 84/100 ---
trainLoss: 1.7475e+21


 84%|████████▍ | 84/100 [2:18:52<26:46, 100.41s/it]

--- Epoch 85/100 ---
trainLoss: 1.6908e+21


 85%|████████▌ | 85/100 [2:20:31<24:58, 99.90s/it] 

--- Epoch 86/100 ---
trainLoss: 3.9375e+21


 86%|████████▌ | 86/100 [2:22:17<23:45, 101.79s/it]

--- Epoch 87/100 ---
trainLoss: 3.0183e+21


 87%|████████▋ | 87/100 [2:23:58<21:58, 101.41s/it]

--- Epoch 88/100 ---
trainLoss: 2.2620e+21


 88%|████████▊ | 88/100 [2:25:39<20:16, 101.39s/it]

--- Epoch 89/100 ---
trainLoss: 1.1120e+21


 89%|████████▉ | 89/100 [2:27:21<18:36, 101.51s/it]

--- Epoch 90/100 ---
trainLoss: 2.2152e+21


 90%|█████████ | 90/100 [2:29:08<17:13, 103.31s/it]

--- Epoch 91/100 ---
trainLoss: 4.3085e+21


 91%|█████████ | 91/100 [2:30:49<15:22, 102.52s/it]

--- Epoch 92/100 ---


KeyboardInterrupt: 

## Compare: Generated vs. Simulated

In [None]:
# Reconstruct one test batch: encode X, overwrite the trailing latent columns
# with the true parameters p, then decode.
# NOTE(review): this cell uses whatever weights E and G currently hold; it does not
# reload the best checkpoint saved during training — confirm that is intended.
E.eval()
G.eval()
sampleBatch = next(iter(testDataLoader))
X,p = sampleBatch
X = X.to(device)
p = p.to(device)

with torch.no_grad():
    z = E(X)
    # replace the last p.shape[1] latent entries with the ground-truth parameters
    z[:,-p.shape[1]:] = p
    X_hat = G(z)
    
X.shape, p.shape, z.shape, X_hat.shape

In [None]:
# Pick one frame of the batch, move the sample and its reconstruction to the CPU,
# and plot them with the per-channel plotting helper.
idx = 0 # frame in the batch 
XX = X[idx].detach().cpu().squeeze()
XX_hat = X_hat[idx].detach().cpu().squeeze()
plotSampleWpredictionByChannel(XX, XX_hat)

In [None]:
# Plot the same sample and its reconstruction with the combined plotting helper.
plotSampleWprediction(XX, XX_hat)

In [None]:
# Plot the residual (sample minus reconstruction).
plotSample(XX-XX_hat)

### Visualize full simulation

In [None]:
# Select simulation number simVizIndex from the test split and wrap it in a
# dataset/loader that yields one frame at a time (batch_size=1).
simData = getSingleSim(sim=simVizIndex,dataDirec=testData)
simDataset = MantaFlowDataset(simData, reverseXY=reverseXY, numToKeep=numSamplesToKeep, AE=False)
simDataLoader = DataLoader(simDataset,batch_size=1)

In [None]:
# Reconstruct every frame of the selected simulation and plot it next to the
# ground truth. As in the single-batch comparison, the trailing latent entries
# are overwritten with the true parameters before decoding.
E.eval()
G.eval()
for X, p in simDataLoader:
    with torch.no_grad():
        X = X.to(device)
        p = p.to(device)

        z = E(X)
        n_params = p.shape[1]
        z[:, -n_params:] = p
        X_hat = G(z)

        # back to the CPU, dropping singleton dimensions, for plotting
        X = X.detach().cpu().squeeze()
        X_hat = X_hat.detach().cpu().squeeze()
        plotSampleWprediction(X, X_hat)

