In [1]:
# Re-import edited project modules (e.g. common.loss_utils and the utae-paps src.* code)
# before each cell runs, so source changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import pathlib
from sklearn.cluster import KMeans

import common.loss_utils as loss_utils

import sys
sys.path.append('./utae-paps')
from src.backbones import utae_mod
from src.learning.weight_init import weight_init


In [3]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE

device(type='cuda')

In [4]:
# Number of future frames the model predicts; dataset targets are truncated to this many.
FORECAST = 17
# Samples per DataLoader batch (the training loop accumulates gradients over batches).
BATCH_SIZE = 1

In [None]:
# Load per-frame timestamps and the frame stack from the .npz archive. The `with`
# block closes the archive's file handle once both arrays are read into memory
# (np.load on an .npz returns a lazily-reading NpzFile that otherwise stays open).
p = pathlib.Path('data/data_random_300.npz')
with np.load(p) as f:
    times = f['times']
    data = f['data']

# The dataset indexes `times` and `data` with the same frame index.
assert len(times) == len(data), f'{len(times)} timestamps for {len(data)} frames'

In [None]:
# Shift/scale constants used by `transform` before tanh squashing.
# Presumably the median pixel value and interquartile range of the data — TODO confirm.
_MEDIAN_PIXEL = 212.0
_IQR = 213.0

# 81 evenly spaced change bins on [-2, 2] (step 0.05). The model predicts one class
# per bin, so len(deltas) is the number of output classes. Stored as a column vector
# so it can serve directly as KMeans cluster centres.
deltas = np.linspace(start=-2.0, stop=2.0, num=81).reshape(-1, 1)

# A KMeans estimator that is never fitted: the fixed bin centres are injected by hand
# so that `predict` assigns each value to its nearest bin.
# NOTE(review): this relies on scikit-learn internals (the private `_n_threads`
# attribute) and may break on a scikit-learn upgrade.
_KM = KMeans()
_KM.cluster_centers_ = deltas
_KM._n_threads = 1

def transform(x):
    """Normalise raw pixel values into [-1, 1].

    Subtracts ``_MEDIAN_PIXEL``, divides by ``_IQR`` and squashes with tanh.
    Works element-wise on scalars and NumPy arrays.
    """
    centred = x - _MEDIAN_PIXEL
    return np.tanh(centred / _IQR)

def transform_y(y, starter):
    """Encode target frames as indices of their nearest change bin in ``deltas``.

    ``y`` is normalised with ``transform``, and ``starter`` is subtracted from it.
    ``starter`` is typically the normalised last input frame cropped to the target
    window. Each resulting change is then mapped to the index of the closest value
    in ``deltas``.

    The nearest-bin lookup is done directly with NumPy rather than through a
    hand-patched, never-fitted sklearn ``KMeans``. That approach depends on the
    private ``_n_threads`` attribute and on ``predict`` not validating fit state.

    Args:
        y: raw target frames; any shape broadcast-compatible with ``starter``.
        starter: normalised reference value(s) subtracted from ``transform(y)``.

    Returns:
        Integer array with the shape of ``transform(y) - starter`` holding indices
        into ``deltas``. A value exactly halfway between two bins goes to the
        lower-valued bin.
    """
    centers = np.ravel(deltas)
    # Sort once so a binary search against bin midpoints finds the nearest centre,
    # without materialising a (pixels x bins) distance matrix.
    order = np.argsort(centers, kind='stable')
    sorted_centers = centers[order]
    midpoints = (sorted_centers[:-1] + sorted_centers[1:]) / 2
    changes = transform(y) - starter
    # side='left' sends values exactly on a midpoint to the lower bin.
    return order[np.searchsorted(midpoints, changes, side='left')]

def check_times(tstart, tend):
    """Return True when ``tend`` is 175 whole minutes after ``tstart``.

    It is used on the first and last timestamps of a 36-frame window
    (``times[i]`` and ``times[i + 35]``). Presumably frames are 5 minutes apart,
    so 175 minutes means the window has no gaps. Sub-minute remainders are
    truncated, e.g. 175.5 minutes still passes.

    Example: ``check_times(times[0], times[35])``
    """
    minutes_apart = (tend - tstart) / np.timedelta64(1, 'm')
    return int(minutes_apart) == 175

class CustomTensorDataset(torch.utils.data.Dataset):
    """Sliding-window dataset of random spatial crops for frame-difference forecasting.

    Item ``index`` uses the 36 consecutive frames ``data[index:index + 36]``:

    - The first 12 frames are the model input.
    - Of the next 24 frames, only the first ``FORECAST`` are kept as targets.

    Windows rejected by ``check_times`` are skipped by jumping 35 frames ahead,
    wrapping around.

    Args:
        times: per-frame timestamps, indexed in step with ``data``.
        data: frame stack indexed as ``data[t, y, x]``.
        random_state: seed for the crop-position generator.
    """

    def __init__(self, times, data, random_state=7):
        self.times = times
        self.data = data
        # Crop positions come from this seeded generator, so crops are reproducible
        # for a given random_state and access order.
        # NOTE(review): DataLoader worker processes would each get a copy of this
        # generator; reseed it per worker before enabling num_workers > 0.
        self.generator = np.random.RandomState(random_state)

    def _get_crop(self, input_slice, target_slice):
        """Cut the same random window out of the input and target frame stacks.

        Returns the full 128x128 window from ``input_slice``, plus the central 64x64
        part of that window from ``target_slice``.
        """
        height, width = input_slice.shape[1], input_slice.shape[2]
        # randint excludes its upper bound, hence the +1: the window may touch the
        # bottom/right edge, and frames exactly 128 pixels high/wide are accepted.
        rand_x = self.generator.randint(0, width - 128 + 1)
        rand_y = self.generator.randint(0, height - 128 + 1)

        in_crop = input_slice[:, rand_y:rand_y + 128, rand_x:rand_x + 128]
        target_crop = target_slice[:, rand_y + 32:rand_y + 96, rand_x + 32:rand_x + 96]
        return in_crop, target_crop

    def __getitem__(self, index):
        """Return ``(x, y, x_last)`` for the first valid window at or after ``index``.

        - ``x``: the 12 transformed input frames minus the last of them, so
          ``x[-1]`` is all zeros (128x128 spatial crop).
        - ``y``: ``transform_y`` bin indices for the first ``FORECAST`` target
          frames (central 64x64 crop), relative to the centre of ``x_last``.
        - ``x_last``: the transformed last input frame (128x128).

        Raises:
            RuntimeError: if no window reachable by the 35-frame jumps passes
                ``check_times``.
        """
        n = len(self)
        start = index
        # Iterate instead of recursing, and stop after n jumps instead of overflowing
        # the stack when no window is valid.
        for _ in range(n):
            if check_times(self.times[index], self.times[index + 35]):
                break
            index = (index + 35) % n
        else:
            raise RuntimeError(
                f'no window passing check_times reachable from index {start}'
            )

        src = self.data[index:index + 12]
        trg = self.data[index + 12:index + 36]
        x, y = self._get_crop(src, trg)
        y = y[:FORECAST]  # keep only the frames the model forecasts
        x = transform(x)
        x_last = x[-1]
        # Inputs are expressed relative to the last observed frame.
        x = x - x_last
        # Targets are encoded relative to the centre of that frame (the target window).
        y = transform_y(y, x_last[32:96, 32:96])
        return x, y, x_last

    def __len__(self):
        """Number of window start positions (each window needs 36 frames)."""
        return max(len(self.times) - 35, 0)

ds = CustomTensorDataset(times, data)
# Batches are loaded in the main process (num_workers defaults to 0).
# NOTE(review): before enabling num_workers, reseed the dataset's RandomState per
# worker (e.g. via worker_init_fn); each worker otherwise starts from a copy of the
# same generator state and would likely produce duplicate crops.
dl = torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)


In [None]:
# Draw one sample to sanity-check shapes and preprocessing. This advances the
# dataset's crop generator, so later crops differ from a run without this cell.
(ex_x, ex_y, x_last) = ds[0]

In [None]:
tuple(arr.shape for arr in (ex_x, ex_y, x_last))

In [None]:
# Every input frame has the last input frame subtracted, so ex_x[-1] is all zeros;
# show the second-to-last difference frame instead.
fig, ax = plt.subplots()
im = ax.imshow(ex_x[-2], cmap='gray')
ax.set_title('Input frame -2 minus last input frame')
fig.colorbar(im, ax=ax);

In [None]:
# Targets are delta-bin indices (classes), not pixel values.
fig, ax = plt.subplots()
im = ax.imshow(ex_y[0], cmap='gray')
ax.set_title('Target delta-bin indices, first forecast step')
fig.colorbar(im, ax=ax);

In [None]:
# UTAE hyper-parameters. The output head emits one logit per delta bin
# (len(deltas) classes) for each of the FORECAST predicted frames.
utae_kwargs = dict(
    forecast_steps=FORECAST,
    input_dim=1,  # one channel per frame (10 in the paper)
    encoder_widths=[64, 64, 64, 128],
    decoder_widths=[32, 32, 64, 128],
    out_conv=[32, len(deltas)],
    str_conv_k=4,
    str_conv_s=2,
    str_conv_p=1,
    agg_mode="att_group",
    encoder_norm="group",
    n_head=16,
    d_model=256,
    d_k=4,
    encoder=False,
    return_maps=False,
    pad_value=None,  # previously noted alternative: 0
    padding_mode="reflect",
)
# Module.apply and Module.to both return the module itself, so they can be chained.
model = utae_mod.UTAE(**utae_kwargs).apply(weight_init).to(DEVICE)

In [None]:
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

n_trainable = count_parameters(model)
print(f'The model has {n_trainable} trainable parameters')

In [None]:
# Per-pixel cross-entropy over the delta-bin classes, optimised with Adam.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
def train_epoch(model, epoch, dl, optimizer, criterion, accum_steps=8, device=None):
    """Run one training epoch with gradient accumulation and return the mean loss.

    Args:
        model: maps inputs of shape (b, t, 1, H, W) to logits of shape
            (b, t, k, H, W), where k is the number of delta classes.
        epoch: epoch number (unused; kept for call compatibility).
        dl: iterable of ``(srcs, trgs, _)`` batches, e.g. from ``CustomTensorDataset``.
            The third element is ignored.
        optimizer: optimizer over ``model``'s parameters.
        criterion: batch-mean loss taking ``(logits (N, k, H, W), targets (N, H, W))``.
        accum_steps: number of loader batches whose gradients are accumulated
            (and averaged) before each optimizer step.
        device: device batches are moved to; defaults to the notebook's ``DEVICE``.

    Returns:
        Sample-weighted mean of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``accum_steps`` < 1 or ``dl`` yields no batches.
    """
    if accum_steps < 1:
        raise ValueError(f'accum_steps must be >= 1, got {accum_steps}')
    if device is None:
        device = DEVICE

    model.train()
    optimizer.zero_grad()

    epoch_loss = 0.0
    total_count = 0
    pending = 0  # batches whose gradients are accumulated but not yet applied
    pbar = tqdm.tqdm(dl)
    for srcs, trgs, _ in pbar:
        # Add a channel dimension: (b, t, H, W) -> (b, t, 1, H, W).
        x = torch.unsqueeze(srcs.float().to(device), dim=2)
        y = trgs.long().to(device)

        preds = model(x)
        b, t, k, h, w = preds.shape
        # Fold time into the batch dimension so the loss sees (N, k, H, W) logits.
        loss = criterion(preds.reshape(b * t, k, h, w), y.reshape(b * t, h, w))
        # Scale so the accumulated gradient is the mean over the accumulation group.
        (loss / accum_steps).backward()
        pending += 1

        if pending == accum_steps:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        batch_loss = loss.item()
        batch_size = len(srcs)
        # criterion already averages over the batch; weight by batch size so the
        # epoch mean is per sample.
        epoch_loss += batch_loss * batch_size
        total_count += batch_size
        avg_loss = round(epoch_loss / total_count, 4)
        pbar.set_description(f'Avg Loss, Batch Loss: {avg_loss, round(batch_loss, 4)}')

    if total_count == 0:
        raise ValueError('train_epoch: data loader yielded no batches')

    # Apply a trailing partial accumulation group. Skip this when nothing is pending,
    # because stepping on zeroed grads can still move Adam's momentum.
    if pending:
        optimizer.step()
        optimizer.zero_grad()
    return epoch_loss / total_count


In [None]:
EPOCHS = 100
EXISTING = 0  # epochs already trained; > 0 resumes from that epoch's checkpoint

CHECKPOINT_DIR = pathlib.Path(f'weights/300d_imagediff_forecast={FORECAST}')
# torch.save does not create directories. Create the directory up front so the first
# save, which happens only after a full epoch of training, cannot fail.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

if EXISTING > 0:
    # Checkpoint names embed the epoch's loss, so locate the file by pattern.
    matches = sorted(CHECKPOINT_DIR.glob(f'utae_epochs={EXISTING}_loss=*.pt'))
    if not matches:
        raise FileNotFoundError(f'no checkpoint for epoch {EXISTING} in {CHECKPOINT_DIR}')
    # Only model weights are saved, so the optimizer state restarts from scratch.
    # torch.load unpickles the file: only load checkpoints you produced yourself.
    model.load_state_dict(torch.load(matches[-1], map_location=DEVICE))

for epoch in range(EXISTING + 1, EPOCHS + 1):
    print(f"Epoch {epoch}")
    avg_loss = round(train_epoch(model, epoch, dl, optimizer, criterion), 4)
    torch.save(model.state_dict(), CHECKPOINT_DIR / f'utae_epochs={epoch}_loss={avg_loss}.pt')