## Configure environment

In [None]:
%%capture
# Install the third-party packages this notebook relies on; %%capture hides pip's output.
%pip install torch torchvision h5py xarray matplotlib netcdf4

In [None]:
import random
import time
import os
import datetime
import itertools
import math
import sys
from random import randint

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.utils.data as data
from torch.utils.data import DataLoader
import numpy as np
import h5py
import PIL
from IPython.core import display as idisplay
import matplotlib.pyplot as plt
import xarray as xr

from vae import VAE, vae_loss_fn
sys.path.append('../common')
import common

In [None]:
# Device configuration: prefer the GPU when CUDA is usable, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Project helper from ../common. NOTE(review): judging by its name it applies a
# memory limit of 1.5 only when none is set yet; the unit is not visible here — confirm.
common.set_memory_limit_if_not_limit(1.5)

In [None]:
# Directory for the files this notebook writes under 'save/'; exist_ok avoids an error on re-runs.
os.makedirs('save', exist_ok=True)

## Load data

In [None]:
# The dataset lives in a 'data' folder two directory levels above the working directory.
cwd = os.getcwd()
pardir = os.path.dirname(os.path.dirname(cwd))
data_folder = os.path.join(pardir, 'data')
data_path = os.path.join(data_folder, 'video_prediction_dataset.hdf5')

# Checkpoint location for the model, inside the local 'save' directory.
model_name = 'CNNLSTM'
output_path = os.path.join(cwd, "save", f"{model_name}.torch")

# List the top-level HDF5 groups; the file is closed again before xarray opens it.
with h5py.File(data_path, 'r') as fds:
    group_names = list(fds.keys())
print(group_names)

# One xarray dataset per HDF5 group, keyed by group name.
dss = {gname: xr.open_dataset(data_path, group=gname) for gname in group_names}

In [None]:
# Show the 'images_log' variable of the test group, followed by its shape.
print(dss['test']['images_log'], dss['test']['images_log'].shape, sep='\n')

In [None]:
bs = 64  # batch size (number of videos per batch)

# Per-image preprocessing: convert to PIL so Resize can be applied, then back
# to a tensor. No additional Normalize step is applied.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(64),
    transforms.ToTensor(),
])

# Training data: shuffled batches from the 'trainval' group.
train_dataset = common.VideoDataset(
    dss['trainval']['images_log'],
    dss['trainval']['images_pred'],
    transform=transform,
    stack_videos=True,
)
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)

# Test data: same preprocessing, kept in dataset order.
test_dataset = common.VideoDataset(
    dss['test']['images_log'],
    dss['test']['images_pred'],
    transform=transform,
)
test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

print(
    f"Number of videos: {len(train_dataset.videos)}.",
    f"Number of video batches: {len(train_loader)}",
    f"Size of video batches: {bs}.",
    sep="\n",
)

In [None]:
# Pull a single batch from the training loader to use as a debugging input.
fixed_x = next(iter(train_loader))
# Index 0 selects one video: a batch holds batch_size videos, not images.
print(
    fixed_x[0].shape,
    f"Number of images per video: {fixed_x[0].shape[0]}",
    f"Number of images per batch: {fixed_x[0].shape[0]*bs}",
    sep="\n",
)
# Write the frames of that first video to disk and show the saved image in the notebook.
torchvision.utils.save_image(fixed_x[0], 'save/real_image.png')
idisplay.Image('save/real_image.png')

In [None]:
#cnn = torch.nn.Conv2d()
#rnn = torch.nn.LSTM()

In [None]:
# Channel count read from dimension 1 of the first video in the debug batch
# (assumes a (frames, channels, height, width) layout — TODO confirm).
image_channels = fixed_x[0].shape[1]

In [None]:
# Build the model on the selected device, then restore saved weights if a checkpoint exists.
vae = VAE(image_channels=image_channels).to(device)
if not os.path.exists(output_path):
    print('No states loaded')
else:
    # The checkpoint is deserialized on the CPU before being loaded into the model.
    vae.load_state_dict(torch.load(output_path, map_location='cpu'))

In [None]:
# Adam over all VAE parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=vae.parameters(), lr=0.001)

In [None]:
epochs = 10  # number of full passes over train_loader in the training loop

In [None]:
# Train the VAE using one randomly chosen frame from every video in a batch.
for epoch in range(epochs):
    vae.train()
    for idx, videos in enumerate(train_loader):
        # Frames within a video are very similar, so a single random frame per
        # video is used instead of flattening all frames into the batch.
        images = torch.stack([video[randint(0, len(video) - 1)] for video in videos])
        # The model was moved to `device`; its inputs must live on the same device.
        images = images.to(device)

        recon_images, mu, logvar = vae(images)
        loss, bce, kld = vae_loss_fn(recon_images, images, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report per-image values using the actual batch size: the last batch
        # of an epoch can contain fewer than `bs` videos.
        n_images = images.size(0)
        print(
            f"Epoch[{epoch+1}/{epochs}] B[{idx+1}/{len(train_loader)}] Loss: {loss.item()/n_images:.4g} "
            f"{bce.item()/n_images:.4g} {kld.item()/n_images:.3g}"
        )

# Save to the same path the model-loading cell reads from, so that a later run
# actually resumes from these weights.
torch.save(vae.state_dict(), output_path)

In [None]:
def compare(x):
    """Return the input images stacked with their VAE reconstructions.

    Args:
        x: Batch of images accepted by the global ``vae`` model.

    Returns:
        A tensor on ``device`` with ``x`` followed by the reconstructions,
        concatenated along the batch dimension.
    """
    # The model lives on `device`; move the input there so the forward pass
    # and the concatenation below operate on tensors of one device.
    x = x.to(device)
    # Pure inference: no autograd graph is needed for a visual comparison.
    with torch.no_grad():
        recon_x, _, _ = vae(x)
    return torch.cat([x, recon_x])

In [None]:
# Visualise one training image next to its reconstruction.
vae.eval()
# Sample over the whole dataset instead of a hard-coded 1..100 range, which
# could run past the end of a small dataset and never selected index 0.
sample_idx = randint(0, len(train_dataset) - 1)
# Element 0 of the dataset item gets a leading batch dimension and is moved
# to the model's device before the forward pass.
fixed_x = train_dataset[sample_idx][0].unsqueeze(0).to(device)
compare_x = compare(fixed_x)

torchvision.utils.save_image(compare_x.detach().cpu(), 'sample_image.png')
display(idisplay.Image('sample_image.png', width=700, unconfined=True))