In [1]:
import os
import pickle
from tqdm import tqdm
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader
from datasets.ascad import AscadDataset
from models.recurrent import LstmModel
from datasets.transforms import ToTensor
from training.training import execute_epoch, train_batch, eval_batch
from training.metrics import get_loss

In [2]:
# --- Run configuration -------------------------------------------------------
# Where pickled metric histories and trained models are written / read.
results_dir = os.path.join('.', 'results', 'lstm_autoencoder')

# When False, the cells guarded by `if retrain:` (training and saving) are skipped.
retrain = False

# Training hyper-parameters.
num_epochs = 10
batch_size = 256
grad_clip_norm = 1.0  # presumably the gradient-clipping threshold for training -- confirm usage

# Metric name -> metric function; handed to execute_epoch as `batch_metric_fns`.
metrics = dict(loss=get_loss)

In [3]:
# Build the train and test splits of the ASCAD dataset, each with its own ToTensor transform.
train_dataset, test_dataset = (
    AscadDataset(train=is_train, transform=ToTensor()) for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

In [None]:
# Sweep grid: one LstmModel is trained per (realistic_hindsight, delay) pair.
realistic_hindsight_values = [True, False]
delay_values = [0, 1, 2, 4, 8, 16]
# Metric histories per configuration:
#   {(realistic_hindsight, delay): {metric_name: [value after each epoch, ...]}}
train_results = {(hs, d): {} for hs in realistic_hindsight_values for d in delay_values}
test_results = {(hs, d): {} for hs in realistic_hindsight_values for d in delay_values}


def _record_epoch(history, epoch_results):
    """Append every metric in `epoch_results` to its running list in `history`."""
    for key, value in epoch_results.items():
        history.setdefault(key, []).append(value)


def _sweep_train_epoch(model, loss_fn, optimizer, device, history):
    """Run one parameter-updating pass over `train_dataloader` and log its metrics into `history`."""
    epoch_results = execute_epoch(train_batch, train_dataloader, model, loss_fn, optimizer, device,
                                  batch_metric_fns=metrics, autoencoder=True,
                                  grad_clip_val=grad_clip_norm)
    _record_epoch(history, epoch_results)


def _sweep_test_epoch(model, loss_fn, device, history):
    """Run one evaluation pass over `test_dataloader` and log its metrics into `history`."""
    epoch_results = execute_epoch(eval_batch, test_dataloader, model, loss_fn, device,
                                  batch_metric_fns=metrics, autoencoder=True)
    _record_epoch(history, epoch_results)


if retrain:
    # torch.save below does not create missing directories.
    os.makedirs(results_dir, exist_ok=True)
    # Fall back to the CPU when no CUDA device is present instead of crashing.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for realistic_hindsight in realistic_hindsight_values:
        for delay in delay_values:
            config = (realistic_hindsight, delay)
            # Move the model to its device before the optimizer captures its parameters.
            model = LstmModel(delay=delay, realistic_hindsight=realistic_hindsight).to(device)
            loss_fn = nn.MSELoss()
            optimizer = optim.Adam(model.parameters())
            print('Realistic hindsight: {}, Delay: {}'.format(realistic_hindsight, delay))

            # Only the test set is evaluated before any parameter update.
            _sweep_test_epoch(model, loss_fn, device, test_results[config])
            for epoch in tqdm(range(num_epochs)):
                _sweep_train_epoch(model, loss_fn, optimizer, device, train_results[config])
                _sweep_test_epoch(model, loss_fn, device, test_results[config])
            torch.save(model, os.path.join(results_dir, 'trained_model__{}_{}.pt'.format(realistic_hindsight, delay)))

Realistic hindsight: True, Delay: 0


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [16:28<00:00, 98.85s/it]


Realistic hindsight: True, Delay: 1


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [16:33<00:00, 99.32s/it]


Realistic hindsight: True, Delay: 2


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [16:39<00:00, 99.97s/it]


Realistic hindsight: True, Delay: 4


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [16:45<00:00, 100.52s/it]


Realistic hindsight: True, Delay: 8


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [16:44<00:00, 100.48s/it]


Realistic hindsight: True, Delay: 16


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [16:53<00:00, 101.31s/it]


Realistic hindsight: False, Delay: 0


  0%|                                                                                                                                                                                  | 0/10 [00:00<?, ?it/s]

In [None]:
# Persist the sweep's metric histories, but only when they were just recomputed.
if retrain:
    # open(..., 'wb') does not create missing directories, so make sure the target exists.
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, 'train_results.pickle'), 'wb') as F:
        pickle.dump(train_results, F)
    with open(os.path.join(results_dir, 'test_results.pickle'), 'wb') as F:
        pickle.dump(test_results, F)

In [None]:
# Read the stored metric histories back from results_dir.
# NOTE: pickle.load executes arbitrary code from the file -- only load files this project wrote.
train_results_path = os.path.join(results_dir, 'train_results.pickle')
test_results_path = os.path.join(results_dir, 'test_results.pickle')

with open(train_results_path, 'rb') as F:
    train_results = pickle.load(F)

with open(test_results_path, 'rb') as F:
    test_results = pickle.load(F)

In [None]:
# Write the per-configuration test-loss curves used by the delay-sweep plot.
# Guarded by `retrain` so a plain re-run never overwrites the cached file.
if retrain:
    # NOTE(review): `test_losses` is not assigned anywhere earlier in this notebook, so it is
    # rebuilt here from `test_results`. Assumes each configuration maps to a dict whose 'loss'
    # entry holds one collection of values per epoch that can be averaged with np.mean (the same
    # reduction the loss-curve plotting cell at the end of this notebook applies) -- TODO confirm.
    test_losses = {
        config: [float(np.mean(epoch_values)) for epoch_values in history['loss']]
        for config, history in test_results.items()
        if 'loss' in history
    }
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, 'delay_sweep.pickle'), 'wb') as F:
        pickle.dump(test_losses, F)

In [None]:
# Load the cached delay-sweep losses: {(realistic_hindsight, delay): loss values}.
with open(os.path.join(results_dir, 'delay_sweep.pickle'), 'rb') as F:
    test_losses = pickle.load(F)

# Best (minimum) test loss per delay, one curve per hindsight mode.
# The delay itself is used as the x coordinate so the uneven grid (0, 1, 2, 4, 8, 16) is visible.
fig, ax = plt.subplots()
for hindsight_flag, curve_color, curve_label in ((False, 'blue', 'Oracle hindsight'),
                                                 (True, 'red', 'Realistic hindsight')):
    best_losses = [np.min(test_losses[(hindsight_flag, d)]) for d in delay_values]
    ax.plot(delay_values, best_losses, color=curve_color, marker='o', label=curve_label)
ax.set_xlabel('Delay')
ax.set_ylabel('Minimum test loss')
ax.legend()
plt.show()

In [None]:
# Single-model run: delay of 1 with realistic hindsight.
# Use the GPU when one is available, otherwise fall back to the CPU instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Move the model to its device before the optimizer captures its parameters.
model = LstmModel(delay=1, realistic_hindsight=True).to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

print('Model:', model)
print('Loss function:', loss_fn)
print('Optimizer:', optimizer)
print('Device:', device)

In [None]:
# Fresh, independent metric histories for the single-model run:
# metric name -> list with one entry per executed epoch.
train_results, test_results = {}, {}

def train_epoch(update_params=True):
    """Run one pass over `train_dataloader` and append its metrics to `train_results`.

    Args:
        update_params: when True (default) the pass is executed with `train_batch`, the
            optimizer and gradient clipping at `grad_clip_norm`; when False it is executed
            with `eval_batch` and no optimizer, e.g. to log a baseline before training.

    Side effects:
        Each key returned by `execute_epoch` gets its value appended to the list stored
        under the same key in the module-level `train_results` dict.
    """
    if update_params:
        # Use the configured clipping value rather than repeating the literal here.
        results = execute_epoch(train_batch, train_dataloader, model, loss_fn, optimizer, device,
                                batch_metric_fns=metrics, autoencoder=True, grad_clip_val=grad_clip_norm)
    else:
        results = execute_epoch(eval_batch, train_dataloader, model, loss_fn, device,
                                batch_metric_fns=metrics, autoencoder=True)
    for key, value in results.items():
        train_results.setdefault(key, []).append(value)

def test_epoch():
    """Evaluate the model on `test_dataloader` once and log the metrics.

    Every key returned by `execute_epoch` has its value appended to the list kept
    under that key in the module-level `test_results` dict (created on first use).
    """
    epoch_metrics = execute_epoch(eval_batch, test_dataloader, model, loss_fn, device,
                                  batch_metric_fns=metrics, autoencoder=True)
    for metric_name in epoch_metrics:
        test_results.setdefault(metric_name, []).append(epoch_metrics[metric_name])

if retrain:
    # Entry 0 of each history: metrics of the untrained model on both splits.
    train_epoch(update_params=False)
    test_epoch()

    # Then alternate one training pass with one test evaluation per epoch.
    epoch_progress = tqdm(range(num_epochs))
    for epoch in epoch_progress:
        train_epoch()
        test_epoch()

In [None]:
# Persist the single-model run (metric histories and the model object) after a retrain.
if retrain:
    os.makedirs(results_dir, exist_ok=True)
    for filename, payload in (('train_results.pickle', train_results),
                              ('test_results.pickle', test_results)):
        with open(os.path.join(results_dir, filename), 'wb') as F:
            pickle.dump(payload, F)
    # The whole model object is saved, not just its state_dict.
    torch.save(model, os.path.join(results_dir, 'trained_model.pt'))

In [None]:
# Reload the stored metric histories and the trained model from results_dir.
# NOTE: pickle.load / torch.load execute code from the file -- only load files this project wrote.
with open(os.path.join(results_dir, 'train_results.pickle'), 'rb') as F:
    train_results = pickle.load(F)
with open(os.path.join(results_dir, 'test_results.pickle'), 'rb') as F:
    test_results = pickle.load(F)
# map_location places the tensors on the device selected for this session, so a model saved
# from a GPU run can still be loaded on a CPU-only machine.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which rejects
# fully pickled modules; pass weights_only=False if loading fails there -- confirm against the
# installed version.
model = torch.load(os.path.join(results_dir, 'trained_model.pt'), map_location=device)

In [None]:
# Collapse each epoch's logged values to a single mean per metric.
# Averaging an already-averaged entry leaves it unchanged, so re-running this cell is harmless.
for history in (train_results, test_results):
    for metric_name in history:
        history[metric_name] = [np.mean(epoch_values) for epoch_values in history[metric_name]]

# Train vs. test loss per epoch on a log scale, labeled so the curves can be told apart.
fig, ax = plt.subplots()
ax.plot(train_results['loss'], label='Train')
ax.plot(test_results['loss'], label='Test')
ax.set_yscale('log')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
from display_results import plot_autoencoder_traces

# Take the first four items of the first element of the first test batch and hand them,
# together with the model and device, to the project's plotting helper.
first_batch = next(iter(test_dataloader))
traces = first_batch[0][:4]
plot_autoencoder_traces(traces, model, device)