In [1]:
from utils.data_lightning import preloading
from utils.data_lightning import otf
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib as mat
import matplotlib as mpl
import argparse
import torch as th
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import os
import sys
from datetime import datetime
import torch.optim as optim
import time
from models.ae import seq2seq_ConvLSTM
import random

In [2]:


# -------------- Functions

def rss_loss(input, target):
    return th.sum((target - input) ** 2)

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
   
# -------------- Setting up the run

num_run = len(os.listdir("runs/")) + 1
now = datetime.now()
foldername = "train_{}_{}".format(num_run, now.strftime("%d_%m_%Y_%H_%M_%S"))
os.mkdir("runs/" + foldername)
weights_path = "runs/" + foldername + "/model.weights"

print("[!] Session folder: {}".format("runs/" + foldername))

writer = SummaryWriter("runs/" + foldername)

# -------------------------------
plotsize = 15

dataset = preloading.SWEDataModule(
    root="../datasets/arda/old_256/",
    test_size=0.1,
    val_size=0,
    past_frames=4,
    future_frames=1,
    partial=None,
    filtering=True,
    batch_size=4,
    workers=4,
    image_size=192,
    shuffle=False,
    dynamicity=2e-1,
    caching=False,
    downsampling=True
)

dataset.prepare_data()

[!] Session folder: runs/train_110_31_08_2021_12_36_34
[x] 24 areas found
Area 0 - sequences: 35
- - - - - - x x x x x x x x x x - - - x x x x x x x x x x x - - - - - 
[4%] 21 valid sequences loaded
Area 1 - sequences: 35
- - - - - - x x x x x x x x x x - - - x x x x x x x x x x x - - - - - 
[8%] 21 valid sequences loaded
Area 2 - sequences: 35
- - - - - - x x x x x x x x x x - - - x x x x x x x x x x x - - - - - 
[12%] 21 valid sequences loaded
Area 3 - sequences: 35
- - - - - - x x x x x x x x x x - - - x x x x x x x x x x x - - - - - 
[17%] 21 valid sequences loaded
Area 4 - sequences: 35
- - - - - - x x x x x x x x x x - - - x x x x x x x x x x x - - - - - 
[21%] 21 valid sequences loaded
Area 5 - sequences: 35
- - - - - - x x x x x x x x x x - - - x x x x x x x x x x x - - - - - 
[25%] 21 valid sequences loaded
Area 6 - sequences: 35
- - - - - - x x x x x x x x x x - - - x x x x x x x x x x x - - - - - 
[29%] 21 valid sequences loaded
Area 7 - sequences: 35
- - - - - - x x x x x x

In [3]:
# ---- Model
net = seq2seq_ConvLSTM.EncoderDecoderConvLSTM(nf=32, in_chan=4, out_chan=3)

# Parallelism
if th.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
device = th.device(dev)

if th.cuda.device_count() > 1:
  print("[!] Yay! Using ", th.cuda.device_count(), "GPUs!")
  net = nn.DataParallel(net)

net = net.to(device)

# ---- Training time!
optimizer = optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-2) # L2, Ridge Regression
# L1 Lasso Regression --> https://medium.com/analytics-vidhya/understanding-regularization-with-pytorch-26a838d94058
losses = []
avg_losses = []
errors = []
test_errors = []
print("\n[!] It's training time!")

epochs = 200
plot_graph = False



[!] It's training time!


In [4]:
def accuracy(prediction, target, threshold = 1e-2):

    total = (target * prediction).cpu().detach().numpy()
    total = np.array(total > 0).astype(int) # TP + TN + FP + FN

    diff = np.abs((target - prediction).cpu().detach().numpy())
    correct_cells = (diff < threshold).astype(int)
    correct_cells = correct_cells*total # TP + TN

    accuracy = np.sum(correct_cells)/np.sum(total)
    return accuracy

In [5]:
for epoch in range(epochs):  # loop over the dataset multiple times

    print("---- Epoch {}".format(epoch))
    epoch_start = time.time()
    training_times = []
    query_times = []
    query_start = time.time()

    iter_dataset = tqdm(enumerate(dataset.train_dataloader()), file=sys.stdout)
    for i, batch in iter_dataset:
        query_end = time.time()
        query_times.append(query_end-query_start)

        x, y = batch

        optimizer.zero_grad()

        x = x.float().to(device)
        y = y.float().to(device)

        # first time plot the graph
        if not plot_graph:
            writer.add_graph(net, x)
            writer.close()
            plot_graph = True

        # ---- Predicting
        start = time.time()
        outputs = net(x, 1)  # 0 for layer index, 0 for h index

        # ---- Batch Loss
        loss = rss_loss(outputs[:, :3, 0, :, :], y[:, 0, :3, :, :])
        acc = accuracy(outputs[:, :3, 0, :, :], y[:, 0, :3, :, :], threshold=1e-1)

        loss.backward()
        optimizer.step()

        end = time.time()
        training_times.append(end - start)

        losses.append(loss.item())
        query_start = time.time()
        
        # ----- Testing
        size = len(dataset.datasets[1] )
        random_test_batch = dataset.datasets[1][random.randint(0, size-1)]
        x_test, y_test = batch
        x_test = x_test.float().to(device)
        y_test = y_test.float().to(device)
        test_outputs = net(x_test, 1) 
        test_acc = accuracy(test_outputs[:, :3, 0, :, :], y_test[:, 0, :3, :, :], threshold=1e-1)

        writer.add_scalar('test_accuracy',
                      test_acc,
                      epoch*len(dataset.train_dataloader())+i)

        writer.add_scalar('train_accuracy',
                          acc,
                          epoch*len(dataset.train_dataloader())+i)

        # Plot values
        if i % 3:
            writer.add_scalar('avg training loss',
                              np.mean(losses),
                              epoch)
            
        iter_dataset.set_postfix(
            loss=np.mean(losses),
            train_acc=acc,
            test_acc=test_acc,
            fwd_time=np.mean(training_times),
            query_time=np.mean(query_times)
        )

    epoch_end = time.time()
    print("\navg.loss {:.2f}\ttook {:.2f} s\tavg. inference time {:.2f} s\tavg.query time/batch {:.2f} s"
          .format(epoch, np.mean(losses), epoch_end-epoch_start, np.mean(training_times), np.mean(query_times)))
    avg_losses.append(np.mean(losses))

    # checkpoint weights
    th.save(net.state_dict(), weights_path)

---- Epoch 0
0it [00:00, ?it/s]

  for t in range(seq_len):


109it [00:07, 14.99it/s, fwd_time=0.0145, loss=6.54e+3, query_time=0.0165, test_acc=0.177, train_acc=0.152] 

avg.loss 0.00	took 6539.29 s	avg. inference time 7.35 s	avg.query time/batch 0.01 s
---- Epoch 1
109it [00:05, 19.56it/s, fwd_time=0.0141, loss=4.46e+3, query_time=0.0161, test_acc=0.4, train_acc=0.398]  

avg.loss 1.00	took 4459.56 s	avg. inference time 5.67 s	avg.query time/batch 0.01 s
---- Epoch 2
109it [00:05, 20.19it/s, fwd_time=0.014, loss=3.45e+3, query_time=0.0155, test_acc=0.45, train_acc=0.457]  

avg.loss 2.00	took 3453.52 s	avg. inference time 5.50 s	avg.query time/batch 0.01 s
---- Epoch 3
109it [00:05, 19.88it/s, fwd_time=0.0142, loss=2.84e+3, query_time=0.0154, test_acc=0.481, train_acc=0.471]

avg.loss 3.00	took 2842.58 s	avg. inference time 5.57 s	avg.query time/batch 0.01 s
---- Epoch 4
109it [00:05, 19.11it/s, fwd_time=0.0145, loss=2.42e+3, query_time=0.0165, test_acc=0.493, train_acc=0.491]

avg.loss 4.00	took 2419.04 s	avg. inference time 5.80 s	avg.query 

KeyboardInterrupt: 

In [None]:

end = time.time()
print(end - start)

print('[!] Finished Training, storing final weights...')

# Loss plot
mpl.rcParams['text.color'] = 'k'

plt.title("average loss")
plt.plot(range(len(avg_losses)), avg_losses)
plt.savefig("runs/" + foldername + "/avg_loss.png")
plt.clf()

print("Avg.training time: {}".format(np.mean(training_times)))
