In [None]:
from utils.data_lightning import preloading
from utils.data_lightning import otf
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib as mat
import argparse
import torch as th
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import os
import sys
from datetime import datetime
import matplotlib as mpl
import torch.optim as optim
import time
from models.ae import seq2seq_ConvLSTM

# Select matplotlib's non-interactive "Agg" backend so the script can render
# figures without a display being available.
mat.use("Agg") # headless mode

# -------------- Functions

def rss_loss(input, target):
    """Residual sum of squares between a prediction and its target.

    Args:
        input: predicted tensor.
        target: reference tensor; must broadcast against ``input``.

    Returns:
        A zero-dimensional tensor with the sum of the squared element-wise
        differences ``target - input``.
    """
    residual = target - input
    squared = residual ** 2
    return th.sum(squared)

def str2bool(v):
    """Convert a command-line style token into a boolean.

    Intended as an ``argparse`` ``type=`` callable.

    Args:
        v: either a ``bool`` (returned unchanged) or a string. Matching is
            case-insensitive: ``yes/true/t/y/1`` map to ``True`` and
            ``no/false/f/n/0`` map to ``False``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a bool and is not one of
            the recognised strings (this includes non-string values such as
            ``None`` or integers).
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Without this guard, ``v.lower()`` below would raise AttributeError,
        # which argparse does not turn into a clean usage error.
        raise argparse.ArgumentTypeError('Boolean value expected.')

    token = v.lower()  # normalise once instead of once per branch
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
   
# -------------- Setting up the run

# Make sure the root folder that collects all runs exists; on a fresh checkout
# os.listdir("runs/") would otherwise fail with FileNotFoundError.
os.makedirs("runs", exist_ok=True)

# The run index is the number of entries already in runs/ plus one; the
# timestamp is appended to the session folder name.
num_run = len(os.listdir("runs/")) + 1
now = datetime.now()
foldername = "train_{}_{}".format(num_run, now.strftime("%d_%m_%Y_%H_%M_%S"))
os.mkdir("runs/" + foldername)

# Checkpoint file written inside the session folder.
weights_path = "runs/" + foldername + "/model.weights"

print("[!] Session folder: {}".format("runs/" + foldername))

# TensorBoard event files are written into the same session folder.
writer = SummaryWriter("runs/" + foldername)

# -------------------------------
# Plot size setting (not referenced by the cells visible here).
plotsize = 15

# Build the pre-loading SWE data module. All options are passed by keyword;
# they are grouped below by apparent purpose. The grouping labels are the
# reviewer's reading of the option names -- confirm against SWEDataModule.
dataset = preloading.SWEDataModule(
    # data location / resolution
    root="../datasets/arda/256/",
    image_size=256,
    downsampling=False,
    # split sizes
    test_size=0,
    val_size=0,
    partial=None,
    # sequence window
    past_frames=4,
    future_frames=1,
    # sequence selection
    filtering=True,
    dynamicity=2e-1,
    # loader settings
    batch_size=4,
    workers=4,
    shuffle=False,
    caching=False,
)

# Per the logged output, this step discovers the areas and loads the valid
# sequences of each one.
dataset.prepare_data()

[!] Session folder: runs/train_83_30_08_2021_18_12_25
[x] 20 areas found
Area 0 - sequences: 245
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - x x x x x - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - x x x x x x x x x x x x x x - - - - - - - - - - - - - - - - - - - - - - x x x x x x x x x x x x - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
[5%] 31 valid sequences loaded
Area 1 - sequences: 245
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

In [None]:
# ---- Model
net = seq2seq_ConvLSTM.EncoderDecoderConvLSTM(nf=32, in_chan=4, out_chan=3)

# Parallelism: use the first CUDA device when one is available, else the CPU
dev = "cuda:0" if th.cuda.is_available() else "cpu"
device = th.device(dev)

# With more than one visible GPU, wrap the model in DataParallel
if th.cuda.device_count() > 1:
    print("[!] Yay! Using ", th.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)

net = net.to(device)

# ---- Training time!
# Adam with weight decay, i.e. an L2 (ridge-style) penalty on the parameters.
# L1 Lasso Regression --> https://medium.com/analytics-vidhya/understanding-regularization-with-pytorch-26a838d94058
optimizer = optim.Adam(net.parameters(), weight_decay=1e-5, lr=1e-4)

# Bookkeeping containers filled during training
losses, avg_losses = [], []
errors, test_errors = [], []
print("\n[!] It's training time!")

# Number of passes over the training data, and a flag recording whether the
# model graph has already been sent to TensorBoard
epochs = 200
plot_graph = False


In [None]:
def accuracy(prediction, target, threshold = 1e-2):
    """Fraction of "active" cells whose prediction is within ``threshold``.

    A cell is counted only when ``target * prediction > 0``, i.e. both values
    are non-zero and share the same sign. Among those cells, one is correct
    when ``|target - prediction| < threshold``.

    Args:
        prediction: predicted tensor (may live on any device and may require
            grad; it is detached before use).
        target: reference tensor, broadcastable against ``prediction``.
        threshold: absolute-error tolerance for a cell to count as correct.

    Returns:
        ``correct / counted`` as a float, or ``0.0`` when no cell is counted
        (previously this case divided 0 by 0, yielding NaN and a
        RuntimeWarning).
    """
    # Detach first so the element-wise product below is not recorded on the
    # autograd graph; the metric is never back-propagated.
    pred = prediction.detach().cpu().numpy()
    tgt = target.detach().cpu().numpy()

    counted = (tgt * pred) > 0
    n_counted = np.sum(counted)
    if n_counted == 0:
        return 0.0

    within_tolerance = np.abs(tgt - pred) < threshold
    return np.sum(within_tolerance & counted) / n_counted

In [None]:
# Main training loop. For every epoch: iterate over the training loader, do a
# forward/backward/optimizer step per batch, log metrics to TensorBoard, then
# print an epoch summary and checkpoint the weights.
for epoch in range(epochs):  # loop over the dataset multiple times

    print("---- Epoch {}".format(epoch))
    epoch_start = time.time()
    training_times = []  # per-batch forward+backward+step durations (this epoch)
    query_times = []     # per-batch time spent waiting for the loader (this epoch)
    epoch_losses = []    # per-batch losses of this epoch only

    # Build the loader once per epoch; the original rebuilt it on every batch
    # just to read its length for the TensorBoard step index.
    train_loader = dataset.train_dataloader()
    steps_per_epoch = len(train_loader)

    query_start = time.time()
    iter_dataset = tqdm(enumerate(train_loader), total=steps_per_epoch, file=sys.stdout)
    for i, batch in iter_dataset:
        query_end = time.time()
        query_times.append(query_end-query_start)

        x, y = batch

        optimizer.zero_grad()

        x = x.float().to(device)
        y = y.float().to(device)

        # first time plot the graph
        if not plot_graph:
            # NOTE(review): the graph is traced with ``x`` only, while the
            # training call below also passes a second argument -- confirm the
            # model's forward has a default for it.
            writer.add_graph(net, x)
            # Flush instead of close: the writer is still used for the scalars
            # logged during the rest of training.
            writer.flush()
            plot_graph = True

        # ---- Predicting
        start = time.time()
        # NOTE(review): the second argument is presumably the number of future
        # frames to predict -- confirm against EncoderDecoderConvLSTM.forward.
        outputs = net(x, 1)

        # ---- Batch Loss
        # The first three channels of the first predicted frame are compared
        # with the first three channels of the first target frame; note the
        # different axis order used to index ``outputs`` and ``y``.
        loss = rss_loss(outputs[:, :3, 0, :, :], y[:, 0, :3, :, :])
        acc = accuracy(outputs[:, :3, 0, :, :], y[:, 0, :3, :, :], threshold=1e-1)

        loss.backward()
        optimizer.step()

        end = time.time()
        training_times.append(end - start)

        batch_loss = loss.item()
        losses.append(batch_loss)        # full history across all epochs
        epoch_losses.append(batch_loss)  # history of the current epoch
        query_start = time.time()

        iter_dataset.set_postfix(
            loss=np.mean(epoch_losses),
            accuracy=acc,
            fwd_time=np.mean(training_times),
            query_time=np.mean(query_times)
        )

        # Monotonic step index across epochs, shared by all scalars so their
        # TensorBoard curves line up.
        global_step = epoch * steps_per_epoch + i

        writer.add_scalar('accuracy',
                          acc,
                          global_step)

        # Log the running epoch loss every third batch. The original test
        # ``if i % 3:`` fired on every batch EXCEPT multiples of three and
        # wrote all points of an epoch at the same step (``epoch``).
        if i % 3 == 0:
            writer.add_scalar('avg training loss',
                              np.mean(epoch_losses),
                              global_step)

    epoch_end = time.time()
    # The original passed ``epoch`` as an extra first argument, shifting every
    # value one placeholder to the right (the epoch index was printed as the
    # average loss, the loss as the duration, and so on).
    print("\navg.loss {:.2f}\ttook {:.2f} s\tavg. inference time {:.2f} s\tavg.query time/batch {:.2f} s"
          .format(np.mean(epoch_losses), epoch_end-epoch_start, np.mean(training_times), np.mean(query_times)))
    # Per-epoch average; the original averaged over every batch seen since the
    # start of training, which flattens the curve of later epochs.
    avg_losses.append(np.mean(epoch_losses))

    # checkpoint weights
    th.save(net.state_dict(), weights_path)

In [None]:

# NOTE(review): ``start`` is last assigned right before the forward pass of the
# final training batch, so this prints the time elapsed since that point, not
# the total training time.
end = time.time()
print(end - start)

print('[!] Finished Training, storing final weights...')
# Persist the final weights, as the message above announces. The training loop
# checkpoints to the same path after each epoch, so this is idempotent when
# training ran to completion.
th.save(net.state_dict(), weights_path)

# Loss plot
mpl.rcParams['text.color'] = 'k'

# One point per completed epoch, saved into the session folder.
plt.title("average loss")
plt.plot(range(len(avg_losses)), avg_losses)
plt.savefig("runs/" + foldername + "/avg_loss.png")
plt.clf()

# ``training_times`` is reset at the start of every epoch, so this average
# covers the batches of the last epoch only.
print("Avg.training time: {}".format(np.mean(training_times)))
