In [1]:
from utils.data_lightning import preloading
from utils.data_lightning import otf
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib as mat
import matplotlib as mpl
import argparse
import torch as th
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import os
import sys
from datetime import datetime
import torch.optim as optim
import time
from models.ae import seq2seq_ConvLSTM
import random

In [None]:


# -------------- Functions

def rss_loss(input, target):
    return th.sum((target - input) ** 2)

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
   
# -------------- Setting up the run

num_run = len(os.listdir("runs/")) + 1
now = datetime.now()
foldername = "train_{}_{}".format(num_run, now.strftime("%d_%m_%Y_%H_%M_%S"))
os.mkdir("runs/" + foldername)
weights_path = "runs/" + foldername + "/model.weights"

print("[!] Session folder: {}".format("runs/" + foldername))

writer = SummaryWriter("runs/" + foldername)

# -------------------------------
plotsize = 15

dataset = preloading.SWEDataModule(
    root="../datasets/arda/old_256/",
    test_size=0.1,
    val_size=0,
    past_frames=4,
    future_frames=1,
    partial=None,
    filtering=True,
    batch_size=4,
    workers=4,
    image_size=192,
    shuffle=False,
    dynamicity=2e-1,
    caching=False,
    downsampling=True
)

dataset.prepare_data()

[!] Session folder: runs/train_110_31_08_2021_12_36_34
[x] 24 areas found
Area 0 - sequences: 35
- - - - - - x x x x x x x x x x - - - x x x x x x x x x x x - - - - - 
[4%] 21 valid sequences loaded
Area 1 - sequences: 35
- - - - - - x x x x x x x x x x - - - x x x x x x x x x x x - - 

In [2]:
# ---- Model
net = seq2seq_ConvLSTM.EncoderDecoderConvLSTM(nf=32, in_chan=4, out_chan=3)

# Parallelism
if th.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
device = th.device(dev)

if th.cuda.device_count() > 1:
  print("[!] Yay! Using ", th.cuda.device_count(), "GPUs!")
  net = nn.DataParallel(net)

net = net.to(device)

# ---- Training time!
optimizer = optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-2) # L2, Ridge Regression
# L1 Lasso Regression --> https://medium.com/analytics-vidhya/understanding-regularization-with-pytorch-26a838d94058
losses = []
avg_losses = []
errors = []
test_errors = []
print("\n[!] It's training time!")

epochs = 200
plot_graph = False



[!] It's training time!


In [3]:
def accuracy(prediction, target, threshold = 1e-2):

    total = (target * prediction).cpu().detach().numpy()
    total = np.array(total > 0).astype(int) # TP + TN + FP + FN

    diff = np.abs((target - prediction).cpu().detach().numpy())
    correct_cells = (diff < threshold).astype(int)
    correct_cells = correct_cells*total # TP + TN

    accuracy = np.sum(correct_cells)/np.sum(total)
    return accuracy

In [7]:
for epoch in range(epochs):  # loop over the dataset multiple times

    print("---- Epoch {}".format(epoch))
    epoch_start = time.time()
    training_times = []
    query_times = []
    query_start = time.time()

    iter_dataset = tqdm(enumerate(dataset.train_dataloader()), file=sys.stdout)
    for i, batch in iter_dataset:
        query_end = time.time()
        query_times.append(query_end-query_start)

        x, y = batch

        optimizer.zero_grad()

        x = x.float().to(device)
        y = y.float().to(device)

        # first time plot the graph
        if not plot_graph:
            writer.add_graph(net, x)
            writer.close()
            plot_graph = True

        # ---- Predicting
        start = time.time()
        outputs = net(x, 1)  # 0 for layer index, 0 for h index

        # ---- Batch Loss
        loss = rss_loss(outputs[:, :3, 0, :, :], y[:, 0, :3, :, :])
        acc = accuracy(outputs[:, :3, 0, :, :], y[:, 0, :3, :, :], threshold=1e-1)

        loss.backward()
        optimizer.step()

        end = time.time()
        training_times.append(end - start)

        losses.append(loss.item())
        query_start = time.time()
        
        # ----- Testing
        size = len(dataset.datasets[1] )
        random_test_batch = dataset.datasets[1][random.randint(0, size-1)]
        x_test, y_test = batch
        x_test = x_test.float().to(device)
        y_test = y_test.float().to(device)
        test_outputs = net(x_test, 1) 
        test_acc = accuracy(test_outputs[:, :3, 0, :, :], y_test[:, 0, :3, :, :], threshold=1e-1)

        writer.add_scalar('test_accuracy',
                      test_acc,
                      epoch*len(dataset.train_dataloader())+i)

        writer.add_scalar('train_accuracy',
                          acc,
                          epoch*len(dataset.train_dataloader())+i)

        # Plot values
        if i % 3:
            writer.add_scalar('avg training loss',
                              np.mean(losses),
                              epoch)
            
        iter_dataset.set_postfix(
            loss=np.mean(losses),
            train_acc=acc,
            test_acc=test_acc,
            fwd_time=np.mean(training_times),
            query_time=np.mean(query_times)
        )

    epoch_end = time.time()
    print("\navg.loss {:.2f}\ttook {:.2f} s\tavg. inference time {:.2f} s\tavg.query time/batch {:.2f} s"
          .format(epoch, np.mean(losses), epoch_end-epoch_start, np.mean(training_times), np.mean(query_times)))
    avg_losses.append(np.mean(losses))

    # checkpoint weights
    th.save(net.state_dict(), weights_path)

---- Epoch 0
10it [00:00, 17.18it/s, fwd_time=0.0158, loss=6.5e+3, query_time=0.0244, test_acc=0.395, train_acc=0.395]

avg.loss 0.00	took 6502.57 s	avg. inference time 0.67 s	avg.query time/batch 0.02 s
---- Epoch 1
10it [00:00, 18.33it/s, fwd_time=0.0142, loss=6.18e+3, query_time=0.025, test_acc=0.398, train_acc=0.408]

avg.loss 1.00	took 6179.71 s	avg. inference time 0.65 s	avg.query time/batch 0.01 s
---- Epoch 2
10it [00:00, 18.71it/s, fwd_time=0.014, loss=5.88e+3, query_time=0.0239, test_acc=0.393, train_acc=0.394]

avg.loss 2.00	took 5880.99 s	avg. inference time 0.63 s	avg.query time/batch 0.01 s
---- Epoch 3
10it [00:00, 18.30it/s, fwd_time=0.0145, loss=5.6e+3, query_time=0.0232, test_acc=0.389, train_acc=0.392]

avg.loss 3.00	took 5602.27 s	avg. inference time 0.64 s	avg.query time/batch 0.01 s
---- Epoch 4
10it [00:00, 18.34it/s, fwd_time=0.0137, loss=5.34e+3, query_time=0.0258, test_acc=0.372, train_acc=0.392]

avg.loss 4.00	took 5343.34 s	avg. inference time 0.64 s	avg.que

---- Epoch 40
10it [00:00, 18.90it/s, fwd_time=0.0147, loss=2.24e+3, query_time=0.0242, test_acc=0.447, train_acc=0.45]

avg.loss 40.00	took 2244.78 s	avg. inference time 0.63 s	avg.query time/batch 0.01 s
---- Epoch 41
10it [00:00, 17.68it/s, fwd_time=0.0153, loss=2.21e+3, query_time=0.0248, test_acc=0.446, train_acc=0.455]

avg.loss 41.00	took 2212.98 s	avg. inference time 0.67 s	avg.query time/batch 0.02 s
---- Epoch 42
10it [00:00, 17.40it/s, fwd_time=0.0159, loss=2.18e+3, query_time=0.0257, test_acc=0.45, train_acc=0.462]

avg.loss 42.00	took 2182.19 s	avg. inference time 0.68 s	avg.query time/batch 0.02 s
---- Epoch 43
10it [00:00, 18.37it/s, fwd_time=0.0142, loss=2.15e+3, query_time=0.0238, test_acc=0.456, train_acc=0.461]

avg.loss 43.00	took 2152.36 s	avg. inference time 0.63 s	avg.query time/batch 0.01 s
---- Epoch 44
10it [00:00, 19.31it/s, fwd_time=0.0141, loss=2.12e+3, query_time=0.0235, test_acc=0.462, train_acc=0.456]

avg.loss 44.00	took 2123.45 s	avg. inference time 0.

---- Epoch 80
10it [00:00, 18.05it/s, fwd_time=0.0139, loss=1.46e+3, query_time=0.0239, test_acc=0.548, train_acc=0.529]

avg.loss 80.00	took 1458.54 s	avg. inference time 0.65 s	avg.query time/batch 0.01 s
---- Epoch 81
10it [00:00, 19.37it/s, fwd_time=0.0139, loss=1.45e+3, query_time=0.024, test_acc=0.55, train_acc=0.532] 

avg.loss 81.00	took 1446.53 s	avg. inference time 0.62 s	avg.query time/batch 0.01 s
---- Epoch 82
10it [00:00, 17.45it/s, fwd_time=0.0151, loss=1.43e+3, query_time=0.0256, test_acc=0.552, train_acc=0.532]

avg.loss 82.00	took 1434.74 s	avg. inference time 0.67 s	avg.query time/batch 0.02 s
---- Epoch 83
10it [00:00, 18.58it/s, fwd_time=0.0146, loss=1.42e+3, query_time=0.0219, test_acc=0.549, train_acc=0.535]

avg.loss 83.00	took 1423.16 s	avg. inference time 0.63 s	avg.query time/batch 0.01 s
---- Epoch 84
10it [00:00, 17.81it/s, fwd_time=0.0158, loss=1.41e+3, query_time=0.0269, test_acc=0.553, train_acc=0.532]

avg.loss 84.00	took 1411.78 s	avg. inference time 0

10it [00:00, 15.82it/s, fwd_time=0.0149, loss=904, query_time=0.0372, test_acc=0.598, train_acc=0.616]

avg.loss 159.00	took 904.25 s	avg. inference time 0.84 s	avg.query time/batch 0.01 s
---- Epoch 160
10it [00:00, 16.71it/s, fwd_time=0.016, loss=900, query_time=0.0283, test_acc=0.609, train_acc=0.602]

avg.loss 160.00	took 900.19 s	avg. inference time 0.71 s	avg.query time/batch 0.02 s
---- Epoch 161
10it [00:00, 16.78it/s, fwd_time=0.0157, loss=896, query_time=0.028, test_acc=0.6, train_acc=0.615]  

avg.loss 161.00	took 896.09 s	avg. inference time 0.70 s	avg.query time/batch 0.02 s
---- Epoch 162
10it [00:00, 16.61it/s, fwd_time=0.0154, loss=892, query_time=0.0279, test_acc=0.611, train_acc=0.606]

avg.loss 162.00	took 892.11 s	avg. inference time 0.70 s	avg.query time/batch 0.02 s
---- Epoch 163
10it [00:00, 15.31it/s, fwd_time=0.0188, loss=888, query_time=0.0253, test_acc=0.6, train_acc=0.617] 

avg.loss 163.00	took 888.10 s	avg. inference time 0.75 s	avg.query time/batch 0.02 

In [None]:

end = time.time()
print(end - start)

print('[!] Finished Training, storing final weights...')

# Loss plot
mpl.rcParams['text.color'] = 'k'

plt.title("average loss")
plt.plot(range(len(avg_losses)), avg_losses)
plt.savefig("runs/" + foldername + "/avg_loss.png")
plt.clf()

print("Avg.training time: {}".format(np.mean(training_times)))
