In [None]:
#Imports
import torch
import torch.nn as nn
import torch.cuda.amp as amp
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import cv2 as cv
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import wandb
import io
import imageio
import matplotlib.pyplot as plt
from ipywidgets import widgets, HBox
from PIL import Image
import lpips

# Pick the compute device: CUDA when a GPU is visible, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Tensors and modules created from here on are allocated on `device` by default.
torch.set_default_device(device)



In [None]:
# Authenticate with Weights & Biases and open a run for this experiment.
# NOTE(review): the run is opened unconditionally, while the training cell
# only logs to it when `use_wandb` is True - consider guarding this cell too.
wandb.login()
wandb.init(
    # wandb project where this run will be logged
    project="convlstm_perc_mse",

    # Hyperparameters and run metadata. These must mirror the values the
    # notebook actually trains with (Adam lr=1e-4, num_epochs=5); the logged
    # config previously claimed lr=0.02 and 10 epochs.
    config={
        "learning_rate": 1e-4,
        "architecture": "Conv-LSTM",
        "dataset": "SHMU",
        "epochs": 5,
    },
)

In [None]:
# Original ConvLSTM cell as proposed by Shi et al.
class ConvLSTMCell(nn.Module):
    """One ConvLSTM time step with element-wise (Hadamard) cell-state terms.

    A single convolution over the concatenated input and previous hidden
    state produces the pre-activations of the input, forget, candidate and
    output gates; the learnable tensors ``W_ci``, ``W_cf`` and ``W_co`` add
    element-wise cell-state contributions to the three sigmoid gates.

    Args:
        in_channels: Number of channels of the input frame ``X``.
        out_channels: Number of channels of the hidden state ``H`` and the
            cell state ``C``.
        kernel_size: Kernel size passed to the gate convolution.
        padding: Padding passed to the gate convolution.
        activation: ``"tanh"`` or ``"relu"``; applied to the candidate cell
            input and to the new cell state.
        frame_size: ``(height, width)`` of the frames; fixes the shape of the
            Hadamard weights.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, in_channels, out_channels,
    kernel_size, padding, activation, frame_size):

        super(ConvLSTMCell, self).__init__()

        if activation == "tanh":
            self.activation = torch.tanh
        elif activation == "relu":
            self.activation = torch.relu
        else:
            # Fail here rather than with an AttributeError inside forward().
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'tanh' or 'relu'")

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        # One convolution yields all four gate pre-activations at once.
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding)

        # Weights for the Hadamard products with the cell state.
        # torch.Tensor(...) would return *uninitialized* memory (possibly
        # NaN/inf); start from zeros so the gates initially ignore the cell
        # state and learn these terms during training.
        self.W_ci = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.zeros(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):
        """Advance the cell by one step.

        Args:
            X: Input frame; concatenated with ``H_prev`` along dim 1.
            H_prev: Previous hidden state.
            C_prev: Previous cell state.

        Returns:
            Tuple ``(H, C)`` with the new hidden state and cell state.
        """

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))

        # Split the stacked pre-activations into the four gates (channel dim).
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)
        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev)
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev)

        # Current Cell output
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)

        # The output gate looks at the *updated* cell state.
        output_gate = torch.sigmoid(o_conv + self.W_co * C)

        # Current Hidden State
        H = output_gate * self.activation(C)

        return H, C

In [None]:
#from ConvLSTMCell import ConvLSTMCell
class ConvLSTM(nn.Module):
    """Unrolls a single ConvLSTMCell over the time axis of a sequence.

    Args:
        in_channels: Channels of each input frame.
        out_channels: Channels of the hidden/cell state and of the output.
        kernel_size: Kernel size forwarded to the cell's convolution.
        padding: Padding forwarded to the cell's convolution.
        activation: Activation name forwarded to the cell.
        frame_size: ``(height, width)`` forwarded to the cell.
    """

    def __init__(self, in_channels, out_channels,
    kernel_size, padding, activation, frame_size):

        super(ConvLSTM, self).__init__()

        self.out_channels = out_channels

        # We will unroll this over time steps
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels,
        kernel_size, padding, activation, frame_size)

    def forward(self, X):
        """Run the cell over every time step of ``X``.

        Args:
            X: Tensor of shape (batch, channels, seq_len, height, width).

        Returns:
            Tensor of shape (batch, out_channels, seq_len, height, width)
            holding the hidden state after each time step.
        """

        # Get the dimensions
        batch_size, _, seq_len, height, width = X.size()

        # Allocate on the input's device instead of a notebook-level global,
        # so the layer works wherever the input lives.
        target_device = X.device
        state_shape = (batch_size, self.out_channels, height, width)

        # Initialize output
        output = torch.zeros(batch_size, self.out_channels, seq_len,
                             height, width, device=target_device)

        # Hidden state and cell state both start at zero.
        H = torch.zeros(state_shape, device=target_device)
        C = torch.zeros(state_shape, device=target_device)

        # Unroll over time steps
        for time_step in range(seq_len):

            H, C = self.convLSTMcell(X[:, :, time_step], H, C)

            output[:, :, time_step] = H

        return output

In [None]:
#from ConvLSTM import ConvLSTM
class Seq2Seq(nn.Module):
    """Stack of ConvLSTM + BatchNorm3d layers followed by a 2D convolution.

    The stacked layers process the whole input sequence; the final
    convolution maps the features of the last time step back to
    ``num_channels`` and a sigmoid squashes the prediction into (0, 1).

    Args:
        num_channels: Channels of the input frames and of the predicted frame.
        num_kernels: Channels produced by every ConvLSTM layer.
        kernel_size: Kernel size used by all convolutions.
        padding: Padding used by all convolutions.
        activation: Activation name forwarded to the ConvLSTM layers.
        frame_size: ``(height, width)`` forwarded to the ConvLSTM layers.
        num_layers: Number of ConvLSTM/BatchNorm pairs (the first pair is
            always created).
    """

    def __init__(self, num_channels, num_kernels, kernel_size, padding,
    activation, frame_size, num_layers):

        super().__init__()

        self.sequential = nn.Sequential()

        # Layer 1 always exists and consumes the raw frames; layers
        # 2..num_layers consume the previous layer's feature maps.
        layer_ids = [1] + list(range(2, num_layers + 1))
        for layer_id in layer_ids:
            if layer_id == 1:
                layer_in_channels = num_channels
            else:
                layer_in_channels = num_kernels

            recurrent = ConvLSTM(
                in_channels=layer_in_channels, out_channels=num_kernels,
                kernel_size=kernel_size, padding=padding,
                activation=activation, frame_size=frame_size)
            self.sequential.add_module(f"convlstm{layer_id}", recurrent)

            norm = nn.BatchNorm3d(num_features=num_kernels)
            self.sequential.add_module(f"batchnorm{layer_id}", norm)

        # Convolution that turns the last feature map into the output frame
        self.conv = nn.Conv2d(
            in_channels=num_kernels, out_channels=num_channels,
            kernel_size=kernel_size, padding=padding)

    def forward(self, X):
        """Predict one frame from the input sequence ``X``.

        Returns:
            Sigmoid of the final convolution applied to the last time step
            of the stacked layers' output.
        """
        features = self.sequential(X)

        # Only the final time step is used for the prediction
        last_step = features[:, :, -1]

        return torch.sigmoid(self.conv(last_step))

In [None]:
class SHMUDataset(Dataset):
    """Sliding-window dataset of image sequences listed in a DataFrame.

    Each item is a pair ``(input_frames, target_frames)`` cut from
    consecutive entries of the sub-sampled path list.

    Args:
        data_frame: DataFrame whose first column holds the image paths.
        input_frames_length: Number of frames in each input sequence.
        target_frames_length: Number of frames in each target sequence.
        minutes: Desired time difference between consecutive frames; every
            ``minutes // 5``-th row is kept (this assumes the rows are spaced
            5 minutes apart - TODO confirm against the data source).

    Raises:
        ValueError: If ``minutes // 5`` is 0, i.e. no valid sampling step.
    """

    def __init__(self, data_frame, input_frames_length, target_frames_length, minutes):
        self.data_frame = data_frame # DataFrame containing image paths
        self.input_frames_length = input_frames_length
        self.target_frames_length = target_frames_length
        self.minutes = minutes # Time difference between images in minutes
        step = self.minutes // 5
        if step == 0:
            raise ValueError(
                f"minutes={minutes!r} gives a sampling step of 0; use minutes >= 5")
        # Select every nth image based on the specified time difference
        self.selected_paths = self.data_frame[::step].iloc[:, 0].tolist()

    def transform(self, image_path):
        """Load one image and return it as an RGB uint8 array of shape (288, 517, 3).

        Raises:
            FileNotFoundError: If OpenCV cannot read ``image_path``.
        """
        img = cv.imread(image_path)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # Morphological opening (two iterations, 3x3 cross element)
        element = cv.getStructuringElement(cv.MORPH_CROSS, (3, 3))
        morphed = cv.morphologyEx(src=img, op=cv.MORPH_OPEN, kernel=element, iterations=2)
        # Crop the fixed region of interest, then resize; cv.resize takes (width, height)
        cropped = morphed[283:1147, 537:2087]
        resized = cv.resize(cropped, (517, 288))
        # OpenCV loads BGR; convert to RGB (the result is already a NumPy array)
        return cv.cvtColor(resized, cv.COLOR_BGR2RGB)

    def __len__(self):
        # A window starting at idx needs input + target frames, so the last
        # valid start is len(paths) - window, giving len(paths) - window + 1
        # items. Clamp at 0 because __len__ must not be negative.
        window = self.input_frames_length + self.target_frames_length
        return max(0, len(self.selected_paths) - window + 1)

    def _load_clip(self, start, stop):
        """Stack frames ``selected_paths[start:stop]`` into a float tensor (C, T, H, W) in [0, 1]."""
        frames = np.stack([self.transform(path) for path in self.selected_paths[start:stop]], axis=0)
        # (T, H, W, C) -> (C, T, H, W), then scale pixel values to [0, 1]
        return torch.from_numpy(frames).permute(3, 0, 1, 2).float() / 255.0

    def __getitem__(self, idx):
        """Return ``(input_frames, target_frames)`` for the window starting at ``idx``.

        ``input_frames`` has shape (C, input_frames_length, H, W);
        ``target_frames`` has shape (C, target_frames_length, H, W), or
        (C, H, W) when ``target_frames_length == 1``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(f"index {idx} out of range for dataset of length {length}")

        split = idx + self.input_frames_length
        input_frames = self._load_clip(idx, split)
        target_frames = self._load_clip(split, split + self.target_frames_length)

        # If there's only one target frame, remove the singleton time dimension
        if self.target_frames_length == 1:
            target_frames = target_frames.squeeze(1)

        return input_frames, target_frames


In [None]:
# Load the table of image paths into a pandas DataFrame
data = pd.read_csv("dataset.csv")
print(data.shape)

# Train / validation / test splits over consecutive row ranges
train_data = SHMUDataset(data[:6400], minutes = 60, input_frames_length = 20, target_frames_length = 1)
val_data = SHMUDataset(data[6400:7200], minutes = 60, input_frames_length = 20, target_frames_length = 1)
test_data = SHMUDataset(data[7200:8000], minutes = 60, input_frames_length = 20, target_frames_length = 1)

# Training Data Loader (shuffled every epoch).
# The generator is created on `device` because the default tensor device
# was changed with torch.set_default_device.
train_loader = DataLoader(train_data, shuffle=True, batch_size=4, num_workers=24,
                          generator=torch.Generator(device=device))

# Validation Data Loader. Not shuffled: a fixed order keeps the batch
# composition, and therefore the per-epoch validation average, reproducible.
val_loader = DataLoader(val_data, shuffle=False, batch_size=4, num_workers=24,
                        generator=torch.Generator(device=device))

In [None]:
# Show every input frame of the first sequence in a validation batch,
# followed by that sequence's target frame.
batch, target = next(iter(val_loader))

seq = batch[0]
print(seq.shape)

# (C, T, H, W) -> (T, H, W, C): one channels-last image per time step
for frame in seq.permute(1, 2, 3, 0):
    plt.imshow(frame)
    plt.show()

# (C, H, W) -> (H, W, C) for the target frame
plt.imshow(target[0].permute(1, 2, 0))
plt.show()

In [None]:
# The input video frames are RGB, thus num_channels = 3
model = Seq2Seq(
    num_channels=3,
    num_kernels=32,
    kernel_size=(3, 3),
    padding=(1, 1),
    activation="relu",
    frame_size=(288, 517),  # (height, width) of the frames
    num_layers=1,
).to(device)

optim = Adam(model.parameters(), lr=1e-4)

# LPIPS perceptual loss with a VGG backbone. It is used only as a fixed
# loss function, so put it in eval mode and freeze its weights: gradients
# still flow back to the model output, but none are computed or accumulated
# for the LPIPS parameters, which no optimizer updates.
loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
loss_fn_vgg.eval()
loss_fn_vgg.requires_grad_(False)

# Pixel-wise reconstruction loss
criterion = nn.MSELoss(reduction='mean')

In [None]:
# Set `use_wandb` to True if you want to enable the use of Weights & Biases for experiment tracking and visualization.
use_wandb = False
num_epochs = 5

# Automatic mixed precision (AMP) is only enabled on CUDA; on CPU the
# autocast context and the GradScaler become no-ops.
use_amp = device.type == 'cuda'
# float16 is the CUDA autocast dtype; bfloat16 is only a placeholder that
# CPU autocast accepts while it is disabled.
amp_dtype = torch.float16 if use_amp else torch.bfloat16
scaler = amp.GradScaler(enabled=use_amp)


def _combined_loss(output, target):
    """Return MSE + LPIPS perceptual loss as a scalar averaged over the batch."""
    mse_loss = criterion(output.flatten(), target.flatten())
    # LPIPS expects inputs in [-1, 1] unless normalize=True is passed; the
    # sigmoid output and the /255-scaled targets are both in [0, 1].
    perceptual_loss = loss_fn_vgg(output, target, normalize=True)
    # LPIPS returns one value per sample; take the mean (not the sum) so the
    # loss does not scale with the batch size.
    return mse_loss + perceptual_loss.mean()


for epoch in range(1, num_epochs+1):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    # Progress bar for the training loop; the description is set up front
    # because changing it after the loop has finished has no visible effect.
    loop = tqdm(train_loader, desc=f"Epoch [{epoch}/{num_epochs}]")
    for inputs, target in loop:
        inputs = inputs.to(device)
        target = target.to(device)

        optim.zero_grad()
        with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            output = model(inputs)
            loss = _combined_loss(output, target)

        # Scale the loss, perform backpropagation, and update the weights
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

        train_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    # Average training loss per batch for the epoch
    train_loss /= max(len(train_loader), 1)

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, target in val_loader:
            inputs = inputs.to(device)
            target = target.to(device)
            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                loss = _combined_loss(model(inputs), target)
            val_loss += loss.item()

    # Average validation loss per batch for the epoch
    val_loss /= max(len(val_loader), 1)

    print("Epoch:{} Training Loss:{:.2f} Validation Loss:{:.2f}\n".format(
        epoch, train_loss, val_loss))
    torch.cuda.empty_cache()
    if use_wandb: wandb.log({"val_loss": val_loss, "train_loss": train_loss})
if use_wandb: wandb.finish()

In [None]:
# Testing Data Loader
test_loader = DataLoader(test_data,shuffle=False, batch_size=1,num_workers=24,generator=torch.Generator(device=device))

num_of_seq = 10
frames_per_seq = 100
# Buffers for predictions and ground truth, laid out as
# (sequence, channel, frame, height, width). Slots that are never written
# (e.g. when the test loader runs out of samples first) stay zero.
out = np.zeros((num_of_seq,3,frames_per_seq,288,517), dtype=np.uint8)
tgt = np.zeros((num_of_seq,3,frames_per_seq,288,517), dtype=np.uint8)


def _to_uint8(frame):
    """Convert a [0, 1] float tensor to a uint8 NumPy array, rounding to the nearest level."""
    scaled = (frame.detach().float().cpu() * 255.0).round().clamp(0, 255)
    return scaled.to(torch.uint8).numpy()


seq = 0
timestep = 0
# Inference only: fixed BatchNorm statistics and no autograd graph.
model.eval()
with torch.no_grad():
    for (inputs, target) in test_loader:
        prediction = model(inputs.to(device))
        # batch_size is 1, so index 0 selects the single (C, H, W) frame
        out[seq,:,timestep] = _to_uint8(prediction[0])
        tgt[seq,:,timestep] = _to_uint8(target[0])
        timestep += 1
        # Move on only after all frames_per_seq slots have been written;
        # comparing against frames_per_seq-1 left the last frame of every
        # sequence empty.
        if timestep == frames_per_seq:
            timestep = 0
            seq += 1
            print(seq)
        if seq == num_of_seq:
            break
  

In [None]:
#Save gifs
def _frame_as_pil(videos, video_idx, frame_idx):
    """Return frame `frame_idx` of video `video_idx` as a channels-last PIL image."""
    frame = videos[video_idx, :, frame_idx].astype(np.uint8)
    # (C, H, W) -> (H, W, C), the layout Image.fromarray expects
    return Image.fromarray(np.transpose(frame, (1, 2, 0)))


# Visit every (video, frame) pair in order; the extents come from `out`
frame_positions = [
    (video_idx, frame_idx)
    for video_idx in range(out.shape[0])
    for frame_idx in range(out.shape[2])
]

target_frames = [_frame_as_pil(tgt, v, f) for v, f in frame_positions]
output_frames = [_frame_as_pil(out, v, f) for v, f in frame_positions]

# Write each list as one looping GIF, 100 ms per frame
target_frames[0].save('target.gif', save_all=True, append_images=target_frames[1:], duration=100, loop=0)
output_frames[0].save('output.gif', save_all=True, append_images=output_frames[1:], duration=100, loop=0)


In [None]:
# Persist the trained weights. torch.save does not create missing folders,
# so make sure the parent directory exists first.
model_path = "path/to/model"
os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
torch.save(model.state_dict(), model_path)
# To restore the weights later (map_location keeps this working on a
# machine whose device differs from the one used for training):
#model.load_state_dict(torch.load(model_path, map_location=device))
#model.eval()