## Configure environment

In [None]:
%%capture
%pip install torch torchvision h5py xarray matplotlib netcdf4

In [None]:
import random
import time
import os
import datetime
import itertools
import math
import sys
from random import randint

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.v2 as v2
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.utils.data as data
from torch.utils.data import DataLoader
import numpy as np
import h5py
import PIL.Image
from IPython.core import display as idisplay
import matplotlib.pyplot as plt
import xarray as xr

from vae import VAE, vae_loss_fn
sys.path.append('../common')
import common

In [None]:
# Device configuration
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# NOTE(review): helper comes from ../common; presumably it caps process memory
# when no limit is set yet. The unit of 1.5 is not visible here - confirm in `common`.
common.set_memory_limit_if_not_limit(1.5)

In [None]:
# Directory that receives the model checkpoints and the sample images written below.
os.makedirs('save', exist_ok=True)

## Load data

In [None]:
# Resolve all paths relative to the working directory: the dataset lives in
# <two levels up>/data, the mask in ./assets and checkpoints in ./save.
cwd = os.getcwd()
pardir = os.path.dirname(os.path.dirname(cwd))
data_folder = os.path.join(pardir, 'data')
data_path = os.path.join(data_folder, 'video_prediction_dataset.hdf5')
mask_path = os.path.join('assets', 'mask_black_skygptdata.png')
model_name = 'CNNVAE'
output_path = os.path.join(cwd, "save", f"{model_name}.torch")

# Read the names of the top-level HDF5 groups, then close the file again.
with h5py.File(data_path, 'r') as fds:
    group_names = list(fds.keys())
print(group_names)

# One xarray dataset per HDF5 group, keyed by group name.
dss = {gname: xr.open_dataset(data_path, group=gname) for gname in group_names}

In [None]:
# Inspect the 'images_log' variable of the test split and its array shape.
print(dss['test']['images_log'])
print(dss['test']['images_log'].shape)

In [None]:
# Load the mask image, resized to 64x64, as one row of channel values per pixel.
# Column 3 is read as the alpha channel, so the image is assumed to be RGBA.
mask_png = np.array(PIL.Image.open(mask_path).resize((64, 64)).getdata())
print(mask_png.shape)

# Split pixels on the alpha midpoint: opaque pixels are the ones to blacken.
alpha_channel = mask_png[:, 3]
mask_to_black = alpha_channel > 127
not_mask_to_black = alpha_channel <= 127
masked_count = mask_to_black.sum()
kept_count = not_mask_to_black.sum()
print(masked_count)  # alpha = 1 = black visible
print(kept_count)
print(kept_count + masked_count)

# Expand the flat boolean mask to a 64x64x3 array of ones (masked) and zeros (kept).
mask_to_black = mask_to_black.reshape((64, 64))
empty_mask = np.ones((64, 64, 3))
mask_to_black = np.where(mask_to_black[:, :, np.newaxis], empty_mask, empty_mask * 0)


In [None]:
bs = 256 # batch size
# Per-frame preprocessing applied by the dataset, in order.
transform = transforms.Compose([
    # Zero every pixel where the 64x64x3 mask is non-zero; other pixels pass through.
    transforms.Lambda(lambda x: np.where(mask_to_black, x*0, x)),
    #v2.functional.adjust_contrast(),
    transforms.ToPILImage(), # This already normalizes the image
    transforms.Lambda(v2.functional.autocontrast),
    transforms.Resize(64),
    transforms.ToTensor(),
    #transforms.Lambda(lambda x: x.float()),
    #transforms.Normalize(mean=[0.5], std=[0.5])  # Example normalization
])
# Create Dataset and DataLoader
# stack_videos=True is used for the VAE stage; its exact effect is defined in
# common.VideoDataset (not visible here).
train_dataset = common.VideoDataset(dss['trainval']['images_log'], dss['trainval']['images_pred'], transform=transform, stack_videos=True)
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)

# The test split is iterated in a fixed order.
test_dataset = common.VideoDataset(dss['test']['images_log'], dss['test']['images_pred'], transform=transform, stack_videos=True)
test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

print(f"Number of videos: {len(train_dataset.videos)}.")
print(f"Number of video batches: {len(train_loader)}")
print(f"Size of video batches: {bs}.")

In [None]:
# Fixed input for debugging
# Pull a single batch from the training loader and look at its first element.
fixed_x = next(iter(train_loader))
print(fixed_x[0].shape) # 0 because VideoDataset has batch_size videos, not images.
print(f"Number of images per video: {fixed_x[0].shape[0]}")
print(f"Number of images per batch: {fixed_x[0].shape[0]*bs}")
# Write the frames of that first video to disk as an image grid and display it.
torchvision.utils.save_image(fixed_x[0], 'save/real_image.png')
idisplay.Image('save/real_image.png')

In [None]:
#cnn = torch.nn.Conv2d()
#rnn = torch.nn.LSTM()

In [None]:
# Size of dim 1 of the first video in the debug batch; assumes frames are laid
# out as (frame, channel, height, width) - TODO confirm against VideoDataset.
image_channels = fixed_x[0].size(1)

In [None]:
# Latent size passed to both the VAE and the CNNLSTM built on top of it.
z_dim = 64

In [None]:
#vae = VAE(image_channels=image_channels).to(device)
# Build the VAE on `device` and restore the saved weights when a checkpoint exists.
vae = VAE(image_channels=image_channels, z_dim=z_dim).to(device)
last_epoch = 0
if os.path.exists(output_path):
    # Deserialize onto the CPU; load_state_dict copies the values into the model's tensors.
    vae.load_state_dict(torch.load(output_path, map_location='cpu'))
    # NOTE(review): the resume epoch is hard-coded, not read from the checkpoint;
    # it has to be updated by hand to match the saved file.
    last_epoch = 8
else:
    print('No states loaded')

In [None]:
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
print(count_parameters(vae))

In [None]:
# Adam over all VAE parameters.
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3) 

In [None]:
# Upper bound of the VAE training loop, which starts counting at `last_epoch`.
epochs = 50

In [None]:
# Train the VAE on single frames, resuming at `last_epoch`.
for epoch in range(last_epoch, epochs):
    vae.train()
    for idx, videos in enumerate(train_loader):
        # One random frame per video is used instead of every frame
        # (videos.flatten(0, 1)), since frames within a video are very similar.
        images = torch.stack([img[randint(0, len(img) - 1)] for img in videos])
        # The model was moved to `device`; the batch has to live there too,
        # otherwise the forward pass fails on CUDA with a device mismatch.
        images = images.to(device)
        recon_images, mu, logvar = vae(images)
        loss, bce, kld = vae_loss_fn(recon_images, images, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 10 == 0:
            # Normalise by the real batch size: the last batch of an epoch can
            # hold fewer than `bs` samples.
            n_images = images.size(0)
            to_print = (
                f"Epoch[{epoch+1}/{epochs}] B[{idx+1}/{len(train_loader)}] Loss: {loss.item()/n_images:.4g} "
                f"{bce.item()/n_images:.4g} {kld.item()/n_images:.3g}"
            )
            print(to_print)
    # Checkpoint the weights after every epoch.
    torch.save(vae.state_dict(), output_path)


In [None]:
def compare(x):
    """Return ``x`` followed by its VAE reconstruction, concatenated along dim 0.

    ``x`` is moved to the global ``device`` first so that it matches the
    model and so that the concatenation does not mix devices. The forward
    pass runs without gradient tracking because the result is only used for
    visualisation. The returned tensor lives on ``device``.
    """
    x = x.to(device)
    with torch.no_grad():
        recon_x, _, _ = vae(x)
    return torch.cat([x, recon_x])

In [None]:
def show(data):
    """Write ``data`` to 'save/sample_image.png' as an image grid and display it inline."""
    sample_path = 'save/sample_image.png'
    torchvision.utils.save_image(data, sample_path)
    rendered = idisplay.Image(sample_path, width=700, unconfined=True)
    display(rendered)

In [None]:
# Reconstruct one randomly chosen training sample and show it next to the original.
vae.eval()
# Draw the index from the real dataset length instead of a hard-coded range,
# which could run past the end of the dataset and never selected index 0.
sample_idx = randint(0, len(train_dataset) - 1)
# First frame of the sample, with a leading batch dimension, on the model's device.
fixed_x = train_dataset[sample_idx][0].unsqueeze(0).to(device)
compare_x = compare(fixed_x)

show(compare_x.data.cpu())

In [None]:
# Reconstruct one randomly chosen test sample and show it next to the original.
vae.eval()
# Draw the index from the real dataset length instead of a hard-coded upper bound.
sample_idx = randint(0, len(test_dataset) - 1)
# First frame of the sample, with a leading batch dimension, on the model's device.
fixed_x = test_dataset[sample_idx][0].unsqueeze(0).to(device)
compare_x = compare(fixed_x)

show(compare_x.data.cpu())

## LSTM

In [None]:
class CNNLSTM(nn.Module):
    """Sequence-to-sequence frame predictor working in the latent space of a VAE.

    Input frames are encoded with ``cnnvae.encode``, summarised by an encoder
    LSTM, and a decoder LSTM is then unrolled for ``target_len`` steps, feeding
    each predicted latent vector back in as the next input. The predicted
    latents are turned back into frames with ``cnnvae.decode``.

    ``cnnvae`` is registered as a sub-module but is deliberately left out of
    :meth:`parameters`, so an optimizer built from ``model.parameters()`` only
    updates the two LSTMs and the output projection.
    """

    def __init__(self, cnnvae: "VAE", z_dim, hidden_dim, num_layers=1):
        """Build the encoder/decoder LSTMs around an existing VAE.

        Args:
            cnnvae: Model providing ``encode`` (returning a 3-tuple whose first
                item is the latent batch) and ``decode``.
            z_dim: Size of the latent vectors produced by ``cnnvae.encode``.
            hidden_dim: Hidden size of both LSTMs.
            num_layers: Number of stacked layers in both LSTMs.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.cnnvae = cnnvae
        self.lstm_enc = nn.LSTM(z_dim, hidden_dim, num_layers, batch_first=True)
        # The decoder consumes its own previous prediction, a latent vector,
        # so its input size is z_dim as well.
        self.lstm_dec = nn.LSTM(z_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, z_dim)

    def forward(self, video, target_len: int):
        """Predict ``target_len`` frames following ``video``.

        Args:
            video: Tensor whose first two dims are (batch, sequence); the
                remaining dims are whatever ``cnnvae.encode`` expects per frame.
            target_len: Number of future frames to generate.

        Returns:
            Tensor of shape (batch, target_len, ...) holding the decoded frames.
        """
        seq_dims = video.shape[0:2]
        frames = video.flatten(0, 1)
        # The VAE is not trained here (see `parameters`), so no graph is needed
        # for the encoder; the inputs carry no gradient either.
        with torch.no_grad():
            zvects, _, _ = self.cnnvae.encode(frames)
        zvects = torch.unflatten(zvects, 0, seq_dims)

        batch_size = zvects.size(0)
        target_size = zvects.size(2)

        # Zero initial hidden/cell state, created directly with the latents'
        # device and dtype.
        h_0 = zvects.new_zeros(self.num_layers, batch_size, self.hidden_dim)
        c_0 = zvects.new_zeros(self.num_layers, batch_size, self.hidden_dim)
        _, (h, c) = self.lstm_enc(zvects, (h_0, c_0))

        # Autoregressive decoding: start from an all-zero latent and feed every
        # prediction back in as the next input.
        lstm_dec_input = zvects.new_zeros(batch_size, 1, target_size)
        steps = []
        for _ in range(target_len):
            out, (h, c) = self.lstm_dec(lstm_dec_input, (h, c))
            lstm_dec_input = self.fc(out)
            steps.append(lstm_dec_input)
        if steps:
            lstm_out = torch.cat(steps, dim=1)
        else:
            # target_len == 0: keep the empty (batch, 0, z_dim) result.
            lstm_out = zvects.new_zeros(batch_size, 0, target_size)

        predicted = self.cnnvae.decode(lstm_out.flatten(0, 1))
        predicted = torch.unflatten(predicted, 0, (seq_dims[0], target_len))
        return predicted

    def parameters(self, recurse: bool = True):
        """Yield only the trainable sequence-model parameters (LSTMs and projection).

        The wrapped ``cnnvae`` is intentionally excluded. ``recurse`` is accepted
        and forwarded so the signature stays compatible with
        ``nn.Module.parameters``.
        """
        return itertools.chain(
            self.lstm_enc.parameters(recurse),
            self.lstm_dec.parameters(recurse),
            self.fc.parameters(recurse),
        )

In [None]:
model_name = 'CNNLSTM'
output_path = os.path.join(cwd,"save", f"{model_name}.torch")

# Reuse the trained VAE as a fixed encoder/decoder.
cnnvae = vae
cnnvae.eval()
# Move the whole model to `device`: the VAE already lives there, and leaving the
# LSTM layers on the CPU makes the forward pass fail on CUDA.
lstm = CNNLSTM(cnnvae, z_dim, 128, 5).to(device)


last_epoch = 0
if os.path.exists(output_path):
    # Deserialize onto the CPU; load_state_dict copies the values into the model's tensors.
    lstm.load_state_dict(torch.load(output_path, map_location='cpu'))
    # NOTE(review): the resume epoch is hard-coded, not read from the checkpoint.
    last_epoch = 1
else:
    print('No states loaded')

# Freeze the VAE so that only the LSTM part is trained.
for param in lstm.cnnvae.parameters():
    param.requires_grad = False

In [None]:
# Smaller batches for the LSTM stage.
bs = 32
# Recreate Dataset and DataLoader, now without stacking
# With stack_videos=False each sample is consumed below as an (input, target) pair.
train_dataset = common.VideoDataset(dss['trainval']['images_log'], dss['trainval']['images_pred'], transform=transform, stack_videos=False)
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)

test_dataset = common.VideoDataset(dss['test']['images_log'], dss['test']['images_pred'], transform=transform, stack_videos=False)
test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [None]:
# Visual check: predicted frames next to the VAE reconstruction of the real target.
with torch.no_grad():
    # Draw the index from the real dataset length instead of a hard-coded range,
    # which could run past the end of the dataset and never selected index 0.
    sample_idx = randint(0, len(train_dataset) - 1)
    input_video, target_video = train_dataset[sample_idx]
    # The models live on `device`; the dataset samples have to be moved there too.
    input_video = input_video.to(device)
    target_video = target_video.to(device)
    # The target is passed through the VAE so both images share its blur.
    target_fuzzy, _, _ = lstm.cnnvae(target_video)
    pred = lstm(input_video.unsqueeze(0), 15).squeeze(0)
    show(pred.data.cpu())
    show(target_fuzzy.data.cpu())

In [None]:
def loss_function(o, t):
    """Return the summed binary cross-entropy between outputs ``o`` and targets ``t``."""
    return F.binary_cross_entropy(o, t, reduction='sum')


optimizer = torch.optim.Adam(lstm.parameters(), lr=1e-3)
epochs = 30

# Train the LSTM part only, resuming at `last_epoch` like the VAE loop does.
for epoch in range(last_epoch, epochs):
    lstm.train()
    # train() above also switched the frozen VAE to training mode; undo that.
    lstm.cnnvae.eval()
    for idx, (videos, targets) in enumerate(train_loader):
        # The models live on `device`; the batch has to be moved there too.
        videos = videos.to(device)
        targets = targets.to(device)
        outputs = lstm(videos, targets.size(1))
        # The training target is the VAE reconstruction of the real future
        # frames. It is a constant for the loss, so no graph is built for it.
        seq_dims = targets.shape[0:2]
        with torch.no_grad():
            targets, _, _ = lstm.cnnvae(targets.flatten(0, 1))
        targets = torch.unflatten(targets, 0, seq_dims)
        loss = loss_function(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 10 == 0:
            # Normalise by the real batch size: the last batch of an epoch can
            # hold fewer than `bs` samples.
            to_print = (
                f"Epoch[{epoch+1}/{epochs}] B[{idx+1}/{len(train_loader)}]"
                f" Loss: {loss.item()/videos.size(0):.4g} "
            )
            print(to_print)
    # Save the CNNLSTM weights. Saving the VAE's state dict here would make the
    # checkpoint unloadable by lstm.load_state_dict.
    torch.save(lstm.state_dict(), output_path)

In [None]:
# Visual check after training: predicted frames next to the VAE reconstruction
# of the real target.
with torch.no_grad():
    # Draw the index from the real dataset length instead of a hard-coded range,
    # which could run past the end of the dataset and never selected index 0.
    sample_idx = randint(0, len(train_dataset) - 1)
    input_video, target_video = train_dataset[sample_idx]
    # The models live on `device`; the dataset samples have to be moved there too.
    input_video = input_video.to(device)
    target_video = target_video.to(device)
    # The target is passed through the VAE so both images share its blur.
    target_fuzzy, _, _ = lstm.cnnvae(target_video)
    pred = lstm(input_video.unsqueeze(0), 15).squeeze(0)
    show(pred.data.cpu())
    show(target_fuzzy.data.cpu())