In [28]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import random
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
from torch import nn

SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

from run_model import parse_options
from evaluation import spectrum_band
from models import ResNet, Unet, ResNet_UM, Unet_UM, ResNet_Mag, Unet_Mag, ResNet_Rot, Unet_Rot, ResNet_Scale, Unet_Scale
from utils import train_epoch, eval_epoch, test_epoch, Dataset, get_lr, train_epoch_scale, eval_epoch_scale, test_epoch_scale, Dataset_scale

import os, sys
sys.path.append(os.path.join(os.getcwd(), '../../sympde/'))
from models.model_mag import mag_conv2d, mag_resblock
from models.model_noequ import Resblock
from misc.equiv import assert_equiv


from model.networks.single_sym.magnitude import Conv1dMag, Conv2dMag


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
args = parse_options(notebook=True)
args.dataset = 'RBC'
args.architecture = 'ResNet'
args.symmetry = 'Mag' #'None', 'UM', 'Rot', 'Mag', 'Scale'
args.output_length = 3
args.learning_rate = 0.001

args.batch_size = 3

In [6]:
random.seed(args.seed)  # python random generator
np.random.seed(args.seed)  # numpy random generator

torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

symmetry = args.symmetry
model_name = args.architecture + "_" + args.symmetry
num_epoch = args.num_epoch
learning_rate = args.learning_rate # 0.0005 for mag_equ resnet; 0.0001 for scale_equ resnet
batch_size = args.batch_size
input_length = args.input_length
train_output_length = args.output_length # 4 for all Unets
test_output_length = 10
kernel_size = args.kernel_size
lr_decay = args.decay_rate

## Data

In [7]:
if args.dataset == "RBC":
    train_direc = "data_64/sample_"
    valid_direc = "data_64/sample_"
    train_indices = list(range(0, 6000))
    valid_indices = list(range(6000, 8000))

    # test on future time steps
    test_future_direc = "data_64/sample_"
    test_future_indices = list(range(8000, 10000)) 

    # test on data applied with symmetry transformations 
    test_domain_direc = "data_64/sample_" if args.symmetry == "None" else "data_" + symmetry.lower() + "/sample_" 
    print(test_domain_direc)
    test_domain_indices = list(range(8000, 10000)) 

elif args.dataset == "Ocean":
    train_direc = "ocean_train/sample_"
    valid_direc = "ocean_train/sample_"
    train_indices = list(range(0, 8000))
    valid_indices = list(range(8000, 10000))

    # test on future time steps
    test_future_direc = "ocean_train/sample_"
    test_future_indices = list(range(10000, 12000)) 

    # test on data from different domain
    test_domain_direc = "ocean_test/sample_"
    test_domain_indices = list(range(0, 2000)) 
    
else:
    print("Invalid dataset name entered!")

data_mag/sample_


In [8]:
if symmetry != "Scale":
    train_set = Dataset(train_indices, input_length, 30, train_output_length, train_direc, True)
    valid_set = Dataset(valid_indices, input_length, 30, train_output_length, valid_direc, True)
    test_future_set = Dataset(test_future_indices, input_length, 30, test_output_length, test_future_direc, True)
    test_domain_set = Dataset(test_domain_indices, input_length, 30, test_output_length, test_domain_direc, True)
else:
    # use Dataset_scale for scale equivariant models
    train_set = Dataset_scale(train_indices, input_length, 30, train_output_length, train_direc)
    valid_set = Dataset_scale(valid_indices, input_length, 30, train_output_length, train_direc)
    test_future_set = Dataset_scale(test_future_indices, input_length, 30, test_output_length, test_future_direc)
    test_domain_set = Dataset_scale(test_domain_indices, input_length, 30, test_output_length, test_domain_direc)


In [9]:
d = torch.load(train_direc + str(1) + ".pt")
x = d[(train_set.mid-train_set.input_length):train_set.mid].transpose(0,1)
y = d[train_set.mid:(train_set.mid+train_set.output_length)].transpose(0,1)
x.shape, y.shape

(torch.Size([2, 24, 64, 64]), torch.Size([2, 3, 64, 64]))

In [10]:
x, y = train_set[0]
print(x.shape, y.shape)

torch.Size([48, 64, 64]) torch.Size([3, 2, 64, 64])


In [11]:
train_loader = torch.utils.data.DataLoader(train_set, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory=True)
test_future_loader = torch.utils.data.DataLoader(test_future_set, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory=True)
test_domain_loader = torch.utils.data.DataLoader(test_domain_set, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory=True)


In [13]:
save_name = args.dataset + "_model{}_bz{}_inp{}_pred{}_lr{}_decay{}_kernel{}_seed{}".format(model_name,
                                                                                                batch_size,
                                                                                                input_length,
                                                                                                train_output_length,
                                                                                                learning_rate,
                                                                                                lr_decay,
                                                                                                kernel_size, 
                                                                                                args.seed)
                                                                                        
print(save_name)

RBC_modelResNet_Mag_bz3_inp24_pred3_lr0.001_decay0.95_kernel3_seed0


## Model

In [14]:
if model_name == "ResNet_UM":
    model = ResNet_UM(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device)
elif model_name == "Unet_UM":
    model = Unet_UM(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device)
elif model_name == "ResNet_Rot":
    model = ResNet_Rot(input_frames = input_length, output_frames = 1, kernel_size = kernel_size, N = 8).to(device)
elif model_name == "Unet_Rot":
    model = Unet_Rot(input_frames = input_length, output_frames = 1, kernel_size = kernel_size, N = 8).to(device)
elif model_name == "ResNet_Mag":
    model = ResNet_Mag(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device)
elif model_name == "Unet_Mag":  
    model = Unet_Mag(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device)
elif model_name == "ResNet_Scale":
    model = ResNet_Scale(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device)
elif model_name == "Unet_Scale":  
    model = Unet_Scale(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device)
elif model_name == "ResNet_None":
    model = ResNet(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device)
elif model_name == "Unet_None":
    model = Unet(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device)
else:
    print("Invalid model name entered!")

In [15]:
optimizer = torch.optim.Adam(model.parameters(), learning_rate,betas=(0.9, 0.999), weight_decay=4e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=lr_decay)
loss_fun = torch.nn.MSELoss()

## Train

In [16]:
# min_rmse = 1e6
# train_rmse = []
# valid_rmse = []
# test_rmse = []

# for i in tqdm(range(num_epoch), desc = 'Training'):
#     start = time.time()
    
#     if symmetry != "Scale":
#         model.train()
#         train_rmse.append(train_epoch(train_loader, model, optimizer, loss_fun))
#         model.eval()
#         rmse, _, _ = eval_epoch(valid_loader, model, loss_fun)
#         valid_rmse.append(rmse)
#     else:
#         model.train()
#         train_rmse.append(train_epoch_scale(train_loader, model, optimizer, loss_fun))
#         model.eval()
#         rmse, _, _ = eval_epoch_scale(valid_loader, model, loss_fun)
#         valid_rmse.append(rmse)

#     if valid_rmse[-1] < min_rmse:
#         min_rmse = valid_rmse[-1] 
#         best_model = model
#     end = time.time()
    
#     # Early Stopping but train at least for 50 epochs
#     if (len(train_rmse) > 50 and np.mean(valid_rmse[-5:]) >= np.mean(valid_rmse[-10:-5])):
#             break
#     print("Epoch {} | T: {:0.2f} | Train RMSE: {:0.3f} | Valid RMSE: {:0.3f}".format(i + 1, (end-start) / 60, train_rmse[-1], valid_rmse[-1]))
#     scheduler.step()
    
    
# if symmetry != "Scale":
#     test_future_rmse, test_future_preds, test_future_trues, test_future_loss_curve = test_epoch(test_future_loader, best_model, loss_fun)
#     test_domain_rmse, test_domain_preds, test_domain_trues, test_domain_loss_curve = test_epoch(test_domain_loader, best_model, loss_fun)
# else:
#     test_future_rmse, test_future_preds, test_future_trues, test_future_loss_curve = test_epoch_scale(test_future_loader, best_model, loss_fun)
#     test_domain_rmse, test_domain_preds, test_domain_trues, test_domain_loss_curve = test_epoch_scale(test_domain_loader, best_model, loss_fun)


In [21]:
x, y = next(iter(train_loader))
x.shape, y.shape

(torch.Size([3, 48, 64, 64]), torch.Size([3, 3, 2, 64, 64]))

In [25]:
def mag_mult(x):
    shape = x.shape
    batch_size = shape[0]
    lens = np.ones(len(shape)-1)
    lens = lens.astype(int)
    lens = np.append(batch_size, lens)
    mult = torch.rand(*lens)
    return x * mult

mult = torch.rand(batch_size)
def mag_mult(x):
    shape = x.shape
    batch_size = shape[0]
    lens = np.ones(len(shape)-1)
    lens = lens.astype(int)
    lens = np.append(batch_size, lens)
    mult2 = mult.view(*lens)
    return x * mult2


In [26]:
input_channels = input_length*2
output_channels = 32
kernel_size = kernel_size
um_dim = 1

In [32]:
mag_conv = mag_conv2d(
    input_channels = input_length*2,
    output_channels = 32,
    kernel_size = kernel_size,
    um_dim = 2,
)

non_conv = nn.Conv2d(
    in_channels = input_length*2,
    out_channels = 32,
    kernel_size = kernel_size,
    padding = (kernel_size - 1)//2,
)

out = mag_conv(x)

my_mag_conv = Conv2dMag(
    input_channels = input_length*2,
    output_channels = 32,
    kernel_size = kernel_size,
)


# unfolded = conv.unfold(x)
# # print(out.shape)

# transformed, stds = conv.transform(unfolded)
# # print(transformed.shape, stds.shape)

# conved = conv.conv2d(transformed)
# # print(conved.shape)

# inverse_transformed = conv.inverse_transform(conved, stds)

print_only = False
atol = 1e-5

resblock_mag  = mag_resblock(input_channels, 64, kernel_size)
resblock_none = Resblock(input_channels, 64, kernel_size)

print('\nresblock_none')
assert_equiv(x, mag_mult, resblock_none, atol=atol, print_only=True)
print('\nnon_conv')
assert_equiv(x, mag_mult, non_conv, atol=atol, print_only=True)

print('\n')
assert_equiv(x, mag_mult, resblock_mag, atol=atol, print_only=print_only)
assert_equiv(x, mag_mult, mag_conv, atol=atol, print_only=print_only)
assert_equiv(x, mag_mult, my_mag_conv, atol=atol, print_only=print_only)


resblock_none
Equivariance test failed. 
Max difference:  2.74568 
Mean difference: 0.212647

non_conv
Equivariance test failed. 
Max difference:  0.0326888 
Mean difference: 0.0121383


residual upscale add
residual upscale add
in torch.Size([3, 32, 64, 64])
stds torch.Size([3, 1, 64, 64])
in torch.Size([3, 32, 64, 64])
stds torch.Size([3, 1, 64, 64])


In [37]:
resblock_mag  = mag_resblock(input_channels, 64, kernel_size)

assert_equiv(x, mag_mult, resblock_mag, atol=atol, print_only=True)


residual upscale add
torch.Size([3, 64, 64, 64]) tensor(0.2598, grad_fn=<MeanBackward0>)
torch.Size([3, 48, 64, 64]) tensor(-0.1797)
torch.Size([3, 64, 64, 64]) tensor(0.1444, grad_fn=<MeanBackward0>)
residual upscale add
torch.Size([3, 64, 64, 64]) tensor(0.1262, grad_fn=<MeanBackward0>)
torch.Size([3, 48, 64, 64]) tensor(-0.0794)
torch.Size([3, 64, 64, 64]) tensor(0.0683, grad_fn=<MeanBackward0>)
Equivariance test passed.


In [146]:
64*64

4096

In [183]:
my_mag_conv2 = my_mag_conv2d(
    input_channels = input_length*2,
    output_channels = 32,
    kernel_size = kernel_size,
)

out = my_mag_conv2(x)

in torch.Size([3, 32, 64, 64])
stds torch.Size([3, 1, 64, 64])


In [163]:
48*3

144

In [216]:
nx, nt = 10, 15

u = torch.rand(1, nx, nt)
dx = torch.rand(1, nx, nt)
dt = torch.rand(1, nt, nt)

cat = torch.cat([u, dx, dt], dim=1)
cat.shape

torch.Size([1, 35, 15])

In [47]:
input_length

24

In [38]:
my_mag_conv1 = Conv1dMag(
    input_channels = input_length*2,
    output_channels = 32,
    kernel_size = kernel_size,
)

x1d = x[..., 0]
out = my_mag_conv1(x1d)

In [46]:
torch.save(x1d, 'x1d.pt')

In [44]:
print(x1d.shape)
print(x1d.mean(), x1d.std())
y1, y2 = assert_equiv(x1d, mag_mult, my_mag_conv1, atol=1e-6, print_only=True)
print(y1.mean(), y2.mean(), y1.std(), y2.std())

torch.Size([3, 48, 64])
tensor(-0.4090) tensor(0.9056)
Equivariance test passed.
tensor(0.1278, grad_fn=<MeanBackward0>) tensor(0.1278, grad_fn=<MeanBackward0>) tensor(0.1762, grad_fn=<StdBackward0>) tensor(0.1762, grad_fn=<StdBackward0>)


In [179]:
my_conv = my_mag_conv2d(
    input_channels = input_length*2,
    output_channels = 32,
    kernel_size = kernel_size,
)

my_out = my_conv(x)
assert (my_out == out).all(), torch.mean(torch.abs(my_out - out))

in torch.Size([3, 48, 64, 64])
padded torch.Size([3, 48, 66, 66])
unfold torch.Size([3, 432, 4096])


RuntimeError: The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 2

In [None]:
assert False

In [None]:
reshaped = unfolded.reshape(3, 24, 2, 3, 3, 64, 64)
print(reshaped.shape)
reshaped.max(1).values.unsqueeze(1).max(3).values.shape

torch.Size([3, 24, 2, 3, 3, 64, 64])


torch.Size([3, 1, 2, 3, 64, 64])

In [None]:
conv(x)

torch.Size([3, 48, 64, 64])
torch.Size([3, 48, 3, 3, 64, 64])
torch.Size([3, 48, 192, 192]) torch.Size([3, 1, 2, 64, 64])
torch.Size([3, 64, 64, 64])


AssertionError: (torch.Size([3, 64, 64, 64]), torch.Size([3, 48, 64, 64]))

In [None]:
out2 = out.clone()

In [None]:
(out == out2).all()

tensor(True)

In [None]:
out.shape

torch.Size([3, 48, 3, 3, 64, 64])

In [None]:
model(x)

x torch.Size([3, 48, 64, 64])
torch.Size([3, 48, 64, 64])
torch.Size([3, 48, 3, 3, 64, 64])
torch.Size([3, 48, 192, 192]) torch.Size([3, 1, 2, 64, 64])
torch.Size([3, 64, 64, 64])
activation torch.Size([3, 64, 64, 64])
torch.Size([3, 64, 64, 64])
torch.Size([3, 64, 64, 64])
torch.Size([3, 64, 3, 3, 64, 64])
torch.Size([3, 64, 192, 192]) torch.Size([3, 1, 2, 64, 64])
torch.Size([3, 64, 64, 64])
activation torch.Size([3, 64, 64, 64])
torch.Size([3, 64, 64, 64])
torch.Size([3, 48, 64, 64])
torch.Size([3, 48, 3, 3, 64, 64])
torch.Size([3, 48, 192, 192]) torch.Size([3, 1, 2, 64, 64])
torch.Size([3, 64, 64, 64])
torch.Size([3, 64, 64, 64])
forward torch.Size([3, 64, 64, 64])
torch.Size([3, 64, 64, 64])
torch.Size([3, 64, 3, 3, 64, 64])
torch.Size([3, 64, 192, 192]) torch.Size([3, 1, 2, 64, 64])
torch.Size([3, 64, 64, 64])
activation torch.Size([3, 64, 64, 64])
torch.Size([3, 64, 64, 64])
torch.Size([3, 64, 64, 64])
torch.Size([3, 64, 3, 3, 64, 64])
torch.Size([3, 64, 192, 192]) torch.Size([3

tensor([[[[-5.6323e-02, -5.0285e-02, -4.9234e-02,  ..., -5.5318e-02,
           -5.7749e-02, -6.2793e-02],
          [-6.7118e-02, -6.6979e-02, -7.2402e-02,  ..., -4.1469e-02,
           -4.2685e-02, -4.7675e-02],
          [-7.7572e-02, -7.4063e-02, -8.0540e-02,  ..., -1.3340e-02,
           -1.6663e-02, -2.1257e-02],
          ...,
          [-5.6505e-02, -5.4493e-02, -5.4724e-02,  ..., -1.1731e-02,
           -1.7330e-02, -2.3117e-02],
          [-5.3550e-02, -5.2407e-02, -5.6243e-02,  ..., -1.9907e-03,
            2.6783e-03, -4.3141e-04],
          [-4.6165e-02, -4.5276e-02, -4.9311e-02,  ..., -5.3422e-02,
           -4.7472e-02, -4.9616e-02]],

         [[-2.4309e-03, -6.7737e-04,  8.8923e-03,  ...,  4.9826e-02,
            5.5412e-02,  5.7955e-02],
          [-1.4533e-02, -1.2620e-02, -1.1789e-02,  ...,  3.3725e-02,
            3.5838e-02,  3.8212e-02],
          [-4.8281e-02, -5.2890e-02, -5.5887e-02,  ...,  5.0041e-02,
            5.0288e-02,  4.8916e-02],
          ...,
     

## Workbench

In [None]:
assert False

In [None]:
# Shapes in ResNet_rot
# for i in [16, 32, 64, 128, 192]:
#     print(i, i*8)