In [None]:
import torch
import torch.nn as nn
import numpy as np
from lib.model import RNN, DCN, Encoder, EncDec
import random
from lib.utils import Dataset,train_model, inference
from math import exp
from lib.generator import Shift, LorenzRandFGenerator
from torch.utils.tensorboard import SummaryWriter

# Global configuration: seed every RNG the notebook touches so runs are reproducible.
SEED = 1234
DTYPE = torch.float64  # every model below is cast with .double(), so datasets use float64 too
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# cuDNN autotuning (benchmark=True) may select different, nondeterministic kernels from
# run to run, which undermines the seeding above. Disable it and request deterministic
# kernels instead (see the PyTorch "Reproducibility" notes).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Shift task, sequence form: samples come from lib.generator.Shift with path_len=128, shift=40.
data_name = 'Shift'

train_size = 3000
test_size = 500

train_dataset = Dataset(*Shift({'path_len': 128, 'shift': 40}).generate(data_num=train_size), dtype=DTYPE, device=device)
test_dataset = Dataset(*Shift({'path_len': 128, 'shift': 40}).generate(data_num=test_size), dtype=DTYPE, device=device)

# Shuffle training batches so each epoch sees a different batch order.
# drop_last=True keeps every batch at exactly 128 samples.
# NOTE(review): if len(test_dataset) == test_size, the test loader evaluates only
# 3 * 128 = 384 of the 500 test samples.
train_data = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
test_data = torch.utils.data.DataLoader(test_dataset, batch_size=128, drop_last=True)

In [None]:
# RNN baseline on the Shift task: load the best checkpoint and report its test loss.
experiment_name = f'{data_name}_rnn'
rnn = RNN(input_dim=1, output_dim=1, hid_dim=32).double().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only kernel (device falls back to CPU).
rnn.load_state_dict(torch.load(f"saved_model/{experiment_name}/best_valid.pt", map_location=device))
rnn.count_parameters()
# train_model(name=experiment_name,model=rnn,train_data=train_data, test_data=test_data)
print(f'Best Valid Loss: {np.mean(inference(rnn, test_data)):.2e}')

In [None]:
# Dilated-convolution (DCN) baseline on the Shift task: load the best checkpoint and report its test loss.
experiment_name = f'{data_name}_cnn'
cnn = DCN(input_dim=1, output_dim=1, hid_dim=1, kernel_size=2, num_layers=6).double().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only kernel (device falls back to CPU).
cnn.load_state_dict(torch.load(f"saved_model/{experiment_name}/best_valid.pt", map_location=device))
cnn.count_parameters()
# train_model(name=experiment_name,model=cnn,train_data=train_data, test_data=test_data)
print(f'Best Valid Loss: {np.mean(inference(cnn, test_data)):.2e}')

In [None]:
# Shift task, patched form for the transformer: each length-128 path is reshaped into
# consecutive patches of `patch_size` values.
# NOTE: this rebinds train_dataset/test_dataset/train_data/test_data, so later cells that
# expect the sequence form must build their own loaders.
data_name = 'Shift'

train_size = 3000
test_size = 500

path_len = 128
patch_size = 8
# The patch reshape only tiles each path exactly when path_len is a multiple of patch_size.
assert path_len % patch_size == 0, 'path_len must be divisible by patch_size'

x, y = Shift({'path_len': path_len, 'shift': 40}).generate(data_num=train_size)

x = x.reshape(train_size, -1, patch_size)
y = y.reshape(train_size, -1, patch_size)

train_dataset = Dataset(x, y, dtype=DTYPE, device=device)

xx, yy = Shift({'path_len': path_len, 'shift': 40}).generate(data_num=test_size)

xx = xx.reshape(test_size, -1, patch_size)
yy = yy.reshape(test_size, -1, patch_size)

test_dataset = Dataset(xx, yy, dtype=DTYPE, device=device)

# Shuffle training batches so each epoch sees a different batch order.
train_data = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
test_data = torch.utils.data.DataLoader(test_dataset, batch_size=128, drop_last=True)

In [None]:
# Transformer encoder on the patched Shift data. Each token is one patch of `patch_size`
# values, and each length-128 path yields 128 // patch_size tokens. Deriving these from
# patch_size keeps the model consistent if the patching above is changed.
experiment_name = f'{data_name}_transformer'
transformer = Encoder(patch_size, patch_size, hid_dim=16, n_layers=2, n_heads=2, pf_dim=8, dropout=0.1,
                      device=device, max_length=128 // patch_size, pos_embed=True, param='full').double().to(device)
# transformer.load_state_dict(torch.load(f"saved_model/{experiment_name}/best_valid.pt", map_location=device))
transformer.count_parameters()
# NOTE: this trains for 10000 epochs on every full re-run of the notebook.
train_model(name=experiment_name, model=transformer, train_data=train_data, test_data=test_data, epochs=10000)
print(f'Best Valid Loss: {np.mean(inference(transformer, test_data)):.2e}')

In [None]:
# Encoder-decoder on the Shift task in sequence form (input_dim=1, output_len=128).
# NOTE(review): the patched-data cell rebinds train_data/test_data to the
# reshape(..., -1, patch_size) layout used by the transformer. That layout presumably does not
# match input_dim=1, so sequence-form loaders are rebuilt here rather than reusing whatever
# loaders happen to be in scope.
experiment_name = f'{data_name}_encdec'
seq_train_dataset = Dataset(*Shift({'path_len': 128, 'shift': 40}).generate(data_num=train_size), dtype=DTYPE, device=device)
seq_test_dataset = Dataset(*Shift({'path_len': 128, 'shift': 40}).generate(data_num=test_size), dtype=DTYPE, device=device)
seq_train_data = torch.utils.data.DataLoader(seq_train_dataset, batch_size=128, shuffle=True, drop_last=True)
seq_test_data = torch.utils.data.DataLoader(seq_test_dataset, batch_size=128, drop_last=True)

encdec = EncDec(1, 1, hid_dim=8, output_len=128).double().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only kernel (device falls back to CPU).
encdec.load_state_dict(torch.load(f"saved_model/{experiment_name}/best_valid.pt", map_location=device))
# train_model(name=experiment_name,model=encdec,train_data=seq_train_data, test_data=seq_test_data)
print(f'Best Valid Loss: {np.mean(inference(encdec, seq_test_data)):.2e}')

In [None]:
# Dataset generated by lib.generator.LorenzRandFGenerator (path_len=128, K=1, J=6).
# NOTE: data_name keeps the 'Lorentz' spelling because it is part of the
# saved_model/<data_name>_*/ checkpoint paths loaded by the cells below; renaming it breaks loading.
data_name = 'Lorentz'

train_size = 3000
test_size = 500

# n_init is set to each split's sample count (its exact meaning is defined in lib.generator).
train_dataset = Dataset(*LorenzRandFGenerator({'path_len': 128, 'n_init': train_size, 'K': 1, 'J': 6}).generate(data_num=train_size), dtype=DTYPE, device=device)
test_dataset = Dataset(*LorenzRandFGenerator({'path_len': 128, 'n_init': test_size, 'K': 1, 'J': 6}).generate(data_num=test_size), dtype=DTYPE, device=device)

# Shuffle training batches so each epoch sees a different batch order.
train_data = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
test_data = torch.utils.data.DataLoader(test_dataset, batch_size=128, drop_last=True)

In [None]:
# tanh RNN on the Lorentz data: load the best checkpoint and report its test loss.
experiment_name = f'{data_name}_rnn'
rnn = RNN(input_dim=1, output_dim=1, hid_dim=32, activation='tanh').double().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only kernel (device falls back to CPU).
rnn.load_state_dict(torch.load(f"saved_model/{experiment_name}/best_valid.pt", map_location=device))
rnn.count_parameters()
# train_model(name=experiment_name,model=rnn,train_data=train_data, test_data=test_data)
print(f'Best Valid Loss: {np.mean(inference(rnn, test_data)):.2e}')

In [None]:
# tanh dilated-convolution (DCN) model on the Lorentz data: load the best checkpoint and report its test loss.
experiment_name = f'{data_name}_cnn'
cnn = DCN(input_dim=1, output_dim=1, hid_dim=16, kernel_size=2, num_layers=8, activation='tanh').double().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only kernel (device falls back to CPU).
cnn.load_state_dict(torch.load(f"saved_model/{experiment_name}/best_valid.pt", map_location=device))
cnn.count_parameters()
# train_model(name=experiment_name,model=cnn,train_data=train_data, test_data=test_data)
print(f'Best Valid Loss: {np.mean(inference(cnn, test_data)):.2e}')

In [None]:
# tanh encoder-decoder on the Lorentz data. This cell trains the model; the commented
# line restores the saved checkpoint into `encdec` instead.
experiment_name = f'{data_name}_encdec'
encdec = EncDec(1, 1, hid_dim=64, output_len=128, activation='tanh').double().to(device)
# encdec.load_state_dict(torch.load(f"saved_model/{experiment_name}/best_valid.pt", map_location=device))
train_model(name=experiment_name, model=encdec, train_data=train_data, test_data=test_data)
print(f'Best Valid Loss: {np.mean(inference(encdec, test_data)):.2e}')