In [1]:
#Imports
import torch
import torch.nn as nn
import torch.cuda.amp as amp
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn.functional import relu
from torch.optim.lr_scheduler import ReduceLROnPlateau

import cv2 as cv
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import wandb
import io
import imageio
from ipywidgets import widgets, HBox
from PIL import Image
import lpips

# Run on the first CUDA device when one is available, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Cell output: number of CUDA devices visible to this process
torch.cuda.device_count()



In [None]:
# Authenticate with Weights & Biases, then open the run the training cell logs to
wandb.login()

# Hyperparameters and run metadata stored with the run
# NOTE(review): "epochs" is 10 here while the training cell sets num_epochs = 5 - confirm which is intended
run_config = {
    "learning_rate": 0.0001,
    "architecture": "Conv-LSTM",
    "dataset": "SHMU",
    "epochs": 10,
}

# `project` is the wandb project this run is filed under
wandb.init(project="convlstm_encoder_decoder", config=run_config)

In [2]:
class UNetEncoder(nn.Module):
    """Contracting path: five stages of two 3x3 convolutions, each followed by ReLU.

    Stages 1-4 end with a 2x2 max-pool, so the spatial size shrinks by a factor
    of 16 overall while the channel count grows 3 -> 16 -> 32 -> 64 -> 128 -> 256.
    For the 512x288 RGB frames used in this notebook the result is 32x18 with
    256 channels.

    Layers are registered as e11, e12, pool1, ..., e51, e52 - these names are
    the state_dict keys used when loading the pre-trained checkpoint.
    """

    # Channel count entering stage 1, then leaving stages 1..5
    _WIDTHS = (3, 16, 32, 64, 128, 256)

    def __init__(self):
        super().__init__()

        last_stage = len(self._WIDTHS) - 1
        for stage, (c_in, c_out) in enumerate(zip(self._WIDTHS, self._WIDTHS[1:]), start=1):
            # Two same-size convolutions: the first changes the channel count
            setattr(self, f"e{stage}1", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"e{stage}2", nn.Conv2d(c_out, c_out, kernel_size=3, padding=1))
            # Every stage except the deepest one halves height and width
            if stage < last_stage:
                setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

    def forward(self, x):
        """Return the stage-5 feature map for a batch of images shaped (N, 3, H, W)."""
        last_stage = len(self._WIDTHS) - 1
        features = x
        for stage in range(1, last_stage + 1):
            first_conv = getattr(self, f"e{stage}1")
            second_conv = getattr(self, f"e{stage}2")
            features = relu(second_conv(relu(first_conv(features))))
            if stage < last_stage:
                features = getattr(self, f"pool{stage}")(features)
        return features


# Load pre-trained encoder weights.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
encoder_model = UNetEncoder()
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a CPU-only
# machine; the weights follow the model when it is later moved to `device`.
_encoder_state = torch.load('Models/unet/unet_80epoch', map_location='cpu')
# strict=False tolerates checkpoint entries the encoder does not have (presumably the
# checkpoint holds a full U-Net - confirm), but it would also silently skip encoder
# layers that are absent from the file, so that case is checked explicitly below.
_encoder_load = encoder_model.load_state_dict(_encoder_state, strict=False)
if _encoder_load.missing_keys:
    raise RuntimeError(
        "Checkpoint 'Models/unet/unet_80epoch' has no weights for encoder layers: "
        + ", ".join(_encoder_load.missing_keys)
    )
# Cell output: the keys that were ignored while loading
_encoder_load

In [3]:
class UNetDecoder(nn.Module):
    """Expanding path: four upsampling stages followed by a 1x1 output convolution.

    Each stage doubles height and width with a stride-2 transposed convolution
    (which also halves the channel count: 256 -> 128 -> 64 -> 32 -> 16) and then
    applies two 3x3 convolutions, each followed by ReLU. The final 1x1
    convolution maps the 16 channels to `n_class` channels and has no activation.
    For the 32x18 latent maps used in this notebook the result is 512x288.

    Layers are registered as upconv1, d11, d12, ..., upconv4, d41, d42, outconv -
    these names are the state_dict keys used when loading the checkpoint.
    """

    # Channel count entering stage 1, then leaving stages 1..4
    _WIDTHS = (256, 128, 64, 32, 16)

    def __init__(self, n_class):
        super().__init__()

        for stage, (c_in, c_out) in enumerate(zip(self._WIDTHS, self._WIDTHS[1:]), start=1):
            # Learned 2x upsampling, then two same-size convolutions
            setattr(self, f"upconv{stage}", nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2))
            setattr(self, f"d{stage}1", nn.Conv2d(c_out, c_out, kernel_size=3, padding=1))
            setattr(self, f"d{stage}2", nn.Conv2d(c_out, c_out, kernel_size=3, padding=1))

        # Per-pixel projection to the requested number of output channels
        self.outconv = nn.Conv2d(self._WIDTHS[-1], n_class, kernel_size=1)

    def forward(self, x):
        """Return the `n_class`-channel map decoded from a (N, 256, h, w) feature map."""
        features = x
        for stage in range(1, len(self._WIDTHS)):
            features = getattr(self, f"upconv{stage}")(features)
            features = relu(getattr(self, f"d{stage}1")(features))
            features = relu(getattr(self, f"d{stage}2")(features))
        return self.outconv(features)

# Load pre-trained decoder weights (3 output channels).
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
decoder_model = UNetDecoder(3)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a CPU-only
# machine; the weights follow the model when it is later moved to `device`.
_decoder_state = torch.load('Models/unet/unet_80epoch', map_location='cpu')
# strict=False tolerates checkpoint entries the decoder does not have (presumably the
# checkpoint holds a full U-Net - confirm), but it would also silently skip decoder
# layers that are absent from the file, so that case is checked explicitly below.
_decoder_load = decoder_model.load_state_dict(_decoder_state, strict=False)
if _decoder_load.missing_keys:
    raise RuntimeError(
        "Checkpoint 'Models/unet/unet_80epoch' has no weights for decoder layers: "
        + ", ".join(_decoder_load.missing_keys)
    )
# Cell output: the keys that were ignored while loading
_decoder_load

In [4]:
# ConvLSTM cell with peephole connections, after Shi et al.
class ConvLSTMCell(nn.Module):
    """One time step of a convolutional LSTM with peephole (Hadamard) weights.

    A single convolution over the concatenated input and previous hidden state
    produces the pre-activations of the input, forget, candidate and output
    gates (idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch).

    Args:
        in_channels: channels of the input frame X.
        out_channels: channels of the hidden state H and cell state C.
        kernel_size: kernel size of the gate convolution.
        padding: padding of the gate convolution; it must keep the spatial size
            unchanged because the gates are combined element-wise with the states.
        activation: "tanh" or "relu".
        frame_size: (height, width) of the frames. The peephole weights have
            shape (out_channels, height, width), so the cell only works at this size.

    Raises:
        ValueError: if `activation` is not "tanh" or "relu".
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, padding, activation, frame_size):

        super().__init__()

        # Fail at construction time; previously an unknown name left
        # `self.activation` undefined and only surfaced inside forward().
        activations = {"tanh": torch.tanh, "relu": torch.relu}
        if activation not in activations:
            raise ValueError(
                f"activation must be 'tanh' or 'relu', got {activation!r}"
            )
        self.activation = activations[activation]

        # One convolution yields all four gate pre-activations at once
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding)

        # Peephole weights, multiplied element-wise with the cell state
        self.W_ci = nn.Parameter(nn.init.kaiming_normal_(torch.zeros(out_channels, *frame_size), nonlinearity="relu"))
        self.W_co = nn.Parameter(nn.init.kaiming_normal_(torch.zeros(out_channels, *frame_size), nonlinearity="relu"))
        self.W_cf = nn.Parameter(nn.init.kaiming_normal_(torch.zeros(out_channels, *frame_size), nonlinearity="relu"))

    def forward(self, X, H_prev, C_prev):
        """Advance the cell by one step.

        Args:
            X: input frame, shaped (batch, in_channels, height, width).
            H_prev: previous hidden state, shaped (batch, out_channels, height, width).
            C_prev: previous cell state, same shape as `H_prev`.

        Returns:
            Tuple (H, C) with the new hidden and cell states.
        """
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))

        # Split the 4 * out_channels result into the four gate pre-activations
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)
        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev)
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev)

        # New cell state.
        # NOTE(review): the candidate goes through sigmoid before the activation;
        # Shi et al. apply the activation to C_conv directly. Left unchanged so
        # results match previously trained weights - confirm this is intended.
        C = forget_gate * C_prev + input_gate * self.activation(torch.sigmoid(C_conv))
        # The output gate peeks at the *new* cell state
        output_gate = torch.sigmoid(o_conv + self.W_co * C)

        # New hidden state
        H = output_gate * self.activation(C)
        return H, C

In [5]:
#from ConvLSTMCell import ConvLSTMCell
class ConvLSTM(nn.Module):
    """Unrolls a single ConvLSTMCell over the time axis of a sequence.

    Args:
        in_channels, out_channels, kernel_size, padding, activation, frame_size:
            forwarded unchanged to ConvLSTMCell.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, padding, activation, frame_size):

        super().__init__()

        self.out_channels = out_channels

        # The same cell (shared weights) is applied at every time step
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels,
                                         kernel_size, padding, activation, frame_size)

    def forward(self, X):
        """Run the cell over every time step of X.

        Args:
            X: tensor shaped (batch, in_channels, seq_len, height, width).

        Returns:
            Tensor shaped (batch, out_channels, seq_len, height, width) holding
            the hidden state after each time step.
        """
        batch_size, _, seq_len, height, width = X.size()

        # Allocate on the input's device instead of a notebook-level global, so the
        # layer works wherever the data lives.
        state_shape = (batch_size, self.out_channels, height, width)
        H = torch.zeros(state_shape, device=X.device)  # hidden state
        C = torch.zeros(state_shape, device=X.device)  # cell state

        # Pre-allocated buffer; assigning into it converts each H to the buffer's
        # dtype, which keeps the output dtype independent of autocast.
        output = torch.zeros(batch_size, self.out_channels, seq_len,
                             height, width, device=X.device)

        for time_step in range(seq_len):
            H, C = self.convLSTMcell(X[:, :, time_step], H, C)
            output[:, :, time_step] = H

        return output

In [6]:
#from ConvLSTM import ConvLSTM
class Seq2Seq(nn.Module):
    """Frozen encoder -> stacked ConvLSTM in latent space -> frozen decoder.

    Every input frame is encoded, the encoded sequence is passed through
    `num_layers` ConvLSTM + BatchNorm3d pairs, and the last time step is
    projected back to `num_channels` channels and decoded into one frame.

    Args:
        num_channels: channels of the encoder output / decoder input.
        num_kernels: hidden channels of every ConvLSTM layer.
        kernel_size, padding, activation: forwarded to each ConvLSTM; kernel_size
            and padding are also used by the final projection convolution.
        frame_size: (height, width) of the full-resolution input frames.
        num_layers: number of ConvLSTM layers (at least one is always created).
        encoder: module mapping (N, C, H, W) frames to (N, num_channels, H/16, W/16)
            features. Defaults to the notebook-level `encoder_model`.
        decoder: module mapping latent features back to frames. Defaults to the
            notebook-level `decoder_model`.

    Both encoder and decoder are frozen (requires_grad=False) in place.
    """

    # Spatial reduction assumed for the encoder (the UNetEncoder in this
    # notebook pools 2x2 four times).
    LATENT_STRIDE = 16

    def __init__(self, num_channels, num_kernels, kernel_size, padding,
                 activation, frame_size, num_layers, encoder=None, decoder=None):
        super().__init__()
        self.num_channels = num_channels

        # Fall back to the pre-trained notebook-level modules when none are given
        self.encoder = encoder_model if encoder is None else encoder
        self.decoder = decoder_model if decoder is None else decoder

        # Freeze the encoder and decoder
        for param in self.encoder.parameters():
            param.requires_grad = False
        for param in self.decoder.parameters():
            param.requires_grad = False

        # The ConvLSTM layers operate on the encoder's reduced resolution
        latent_size = tuple(int(dim // self.LATENT_STRIDE) for dim in frame_size)

        self.sequential = nn.Sequential()

        # First layer (its in_channels differ from the rest)
        self.sequential.add_module(
            "convlstm1", ConvLSTM(
                in_channels=num_channels, out_channels=num_kernels,
                kernel_size=kernel_size, padding=padding,
                activation=activation, frame_size=latent_size)
        )
        self.sequential.add_module(
            "batchnorm1", nn.BatchNorm3d(num_features=num_kernels)
        )

        # Remaining layers
        for l in range(2, num_layers + 1):
            self.sequential.add_module(
                f"convlstm{l}", ConvLSTM(
                    in_channels=num_kernels, out_channels=num_kernels,
                    kernel_size=kernel_size, padding=padding,
                    activation=activation, frame_size=latent_size)
            )
            self.sequential.add_module(
                f"batchnorm{l}", nn.BatchNorm3d(num_features=num_kernels)
            )

        # Projects the last hidden state back to the decoder's input channels
        self.conv = nn.Conv2d(
            in_channels=num_kernels, out_channels=num_channels,
            kernel_size=kernel_size, padding=padding)

    def forward(self, X):
        """Predict the frame that follows the input sequence.

        Args:
            X: tensor shaped (batch, channels, seq_len, height, width).

        Returns:
            The decoder output for the last time step, passed through a sigmoid.
        """
        batch_size, _, seq_len, height, width = X.size()

        # Allocate on the input's device rather than a notebook-level global
        encoded_X = torch.zeros(
            batch_size, self.num_channels, seq_len,
            height // self.LATENT_STRIDE, width // self.LATENT_STRIDE,
            device=X.device)

        # Encode the sequence frame by frame
        for time_step in range(seq_len):
            encoded_X[:, :, time_step] = self.encoder(X[:, :, time_step])

        # Send the encoded sequence through the ConvLSTM stack
        output = self.sequential(encoded_X)

        # Decode the last time step only
        output = self.decoder(self.conv(output[:, :, -1]))
        return torch.sigmoid(output)


In [7]:
class SHMUDataset(Dataset):
    """Sliding-window dataset of image sequences read from a table of file paths.

    Sample `idx` consists of `input_frames_length` consecutive frames starting at
    position `idx` of the selected paths, followed by `target_frames_length`
    frames used as the prediction target.

    Args:
        data_frame: DataFrame whose first column holds the image paths.
        input_frames_length: number of frames in each input sequence.
        target_frames_length: number of frames in each target.
        minutes: spacing between selected images. Every (minutes // 5)-th row is
            kept - presumably the rows are 5 minutes apart; verify against the CSV.

    Raises:
        ValueError: if `minutes` is below 5 (the row step would be zero or negative).
    """

    def __init__(self, data_frame, input_frames_length, target_frames_length, minutes):
        step = minutes // 5
        if step < 1:
            raise ValueError(f"minutes must be at least 5, got {minutes}")

        self.data_frame = data_frame
        self.input_frames_length = input_frames_length
        self.target_frames_length = target_frames_length
        self.minutes = minutes
        # Keep every `step`-th path from the first column
        self.selected_paths = self.data_frame[::step].iloc[:, 0].tolist()

    def transform(self, image_path):
        """Load one image and return it as a (288, 512, 3) RGB uint8 array.

        Raises:
            FileNotFoundError: if OpenCV cannot read `image_path`.
        """
        img = cv.imread(image_path)
        if img is None:
            # cv.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # Morphological opening (3x3 cross, two iterations)
        element = cv.getStructuringElement(cv.MORPH_CROSS, (3, 3))
        morphed = cv.morphologyEx(src=img, op=cv.MORPH_OPEN, kernel=element, iterations=2)
        # Crop rows 283..1146 and columns 537..2086, then resize to width 512, height 288
        cropped = morphed[283:1147, 537:2087]
        resized = cv.resize(cropped, (512, 288))
        # OpenCV loads BGR; convert to RGB
        image_rgb = cv.cvtColor(resized, cv.COLOR_BGR2RGB)
        return np.array(image_rgb)

    def __len__(self):
        # A window needs input + target frames, so N paths give N - window + 1
        # start positions (never negative, which len() would reject).
        window = self.input_frames_length + self.target_frames_length
        return max(0, len(self.selected_paths) - window + 1)

    def _load_window(self, start, count):
        """Load `count` frames starting at `start` as a (C, T, H, W) float tensor in [0, 1]."""
        frames = np.stack(
            [self.transform(path) for path in self.selected_paths[start:start + count]],
            axis=0)
        # (T, H, W, C) -> (C, T, H, W); divide by 255 to scale pixel values to [0, 1]
        return torch.from_numpy(frames).permute(3, 0, 1, 2).float() / 255.0

    def __getitem__(self, idx):
        """Return (input_frames, target_frames) for window `idx`.

        input_frames is (C, input_frames_length, H, W). target_frames is
        (C, target_frames_length, H, W), or (C, H, W) when there is one target frame.

        Raises:
            IndexError: if `idx` is outside [0, len(self)).
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")

        input_frames = self._load_window(idx, self.input_frames_length)
        target_frames = self._load_window(
            idx + self.input_frames_length, self.target_frames_length)

        # If there's only one target frame, remove the singleton time dimension
        if self.target_frames_length == 1:
            target_frames = target_frames.squeeze(1)

        return input_frames, target_frames


In [12]:
# Read the table of image paths into a DataFrame
data = pd.read_csv("dataset.csv")
print(data.shape)

# Windowing shared by all splits: 20 input frames -> 1 target frame, one image every 5 minutes
window_args = dict(minutes=5, input_frames_length=20, target_frames_length=1)

# Train / validation / test splits are consecutive row ranges of the table
train_data = SHMUDataset(data[602245:666245], **window_args)
val_data = SHMUDataset(data[666245:674245], **window_args)
test_data = SHMUDataset(data[674245:682245], **window_args)

# Training and validation loaders use the same settings
loader_args = dict(shuffle=True, batch_size=8, num_workers=24)
train_loader = DataLoader(train_data, **loader_args)
val_loader = DataLoader(val_data, **loader_args)

In [10]:
# The ConvLSTM stack consumes the encoder's feature maps, so num_channels
# matches the encoder's 256 output channels.
model = Seq2Seq(
    num_channels=256,
    num_kernels=128,
    kernel_size=(3, 3),
    padding=(1, 1),
    activation="relu",
    frame_size=(288, 512),
    num_layers=8,
).to(device)

optim = Adam(model.parameters(), lr=1e-4)

# The two loss terms used by the training cell: pixel-wise MSE and LPIPS (VGG backbone)
criterion = nn.MSELoss(reduction='mean')
loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)

In [13]:
# Set `use_wandb` to True to log losses to Weights & Biases (the wandb.init cell must have been run).
use_wandb = True
num_epochs = 5

# Mixed precision only applies on CUDA; on CPU autocast and the scaler are both disabled.
use_amp = device.type == 'cuda'
scaler = amp.GradScaler(enabled=use_amp)


def compute_loss(inputs, target):
    """Return the scalar objective (MSE + mean LPIPS) for one batch.

    Args:
        inputs: input sequences from the data loader.
        target: target frames from the data loader, with values in [0, 1].
    """
    # On CUDA autocast defaults to float16, as in the original training step
    with torch.autocast(device_type=device.type, enabled=use_amp):
        output = model(inputs.to(device))
        target = target.to(device)
        mse_loss = criterion(output.flatten(), target.flatten())
        # Both images are in [0, 1]; normalize=True makes LPIPS rescale them to
        # the [-1, 1] range it expects.
        perceptual_loss = loss_fn_vgg(output, target, normalize=True)
        # LPIPS returns one value per sample. Averaging it keeps the loss
        # independent of batch size; adding the scalar MSE to the per-sample
        # tensor and summing would count the MSE once per sample.
        return mse_loss + perceptual_loss.mean()


for epoch in range(1, num_epochs+1):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    # The description has to be set when the bar is created; setting it after the
    # loop has finished has no visible effect.
    loop = tqdm(train_loader, desc=f"Epoch [{epoch}/{num_epochs}]")
    for inputs, target in loop:
        optim.zero_grad()
        loss = compute_loss(inputs, target)
        # Scale the loss, backpropagate, and update the weights
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        train_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    # Average training loss over the batches of this epoch
    train_loss /= len(train_loader)

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, target in val_loader:
            val_loss += compute_loss(inputs, target).item()

    # Average validation loss over the batches of this epoch
    val_loss /= len(val_loader)
    print("Epoch:{} Training Loss:{:.2f} Validation Loss:{:.2f}\n".format(
        epoch, train_loss, val_loss))
    torch.cuda.empty_cache()
    if use_wandb: wandb.log({"val_loss": val_loss, "train_loss": train_loss})
if use_wandb: wandb.finish()

In [14]:
# Testing Data Loader (batch_size=1 and unshuffled, so consecutive samples are consecutive windows)
test_loader = DataLoader(test_data,shuffle=False, batch_size=1,num_workers=24)

num_of_seq = 10
frames_per_seq = 100
# Predicted and ground-truth frames as 8-bit images: (sequence, channel, frame, height, width)
out = np.zeros((num_of_seq,3,frames_per_seq,288,512), dtype=np.uint8)
tgt = np.zeros((num_of_seq,3,frames_per_seq,288,512), dtype=np.uint8)

seq = 0
timestep = 0
# Inference only: use BatchNorm running statistics and skip building the autograd graph
model.eval()
with torch.no_grad():
    for (inputs, target) in test_loader:
        prediction = model(inputs.to(device)).cpu()
        # Round before converting: a plain float -> uint8 cast truncates, which can
        # turn a value such as 254.99998 into 254.
        out[seq,:,timestep] = (prediction[0] * 255.0).round().clamp(0, 255).to(torch.uint8).numpy()
        tgt[seq,:,timestep] = (target[0] * 255.0).round().clamp(0, 255).to(torch.uint8).numpy()
        timestep += 1
        # Move on once all `frames_per_seq` slots are filled (comparing against
        # frames_per_seq-1 left the last frame of every sequence black)
        if timestep == frames_per_seq:
            timestep = 0
            seq += 1
            print(seq)
        if seq == num_of_seq:
            break
  

1


In [None]:
# Predict from predicted: each sequence starts from real frames, after which the
# model's own predictions are fed back as the newest input frame.
num_of_seq = 10
frames_per_seq = 100
# Predicted and ground-truth frames as 8-bit images: (sequence, channel, frame, height, width)
out = np.zeros((num_of_seq,3,frames_per_seq,288,512), dtype=np.uint8)
tgt = np.zeros((num_of_seq,3,frames_per_seq,288,512), dtype=np.uint8)

seq = 0
timestep = 0
predicted_mse_perc = None

# Inference only: use BatchNorm running statistics and skip building the autograd graph
model.eval()
with torch.no_grad():
    for (batch, target) in test_loader:
        if timestep == 0:
            # Seed every sequence with the real frames of the *current* sample, so the
            # first prediction is compared with this iteration's target. Re-seeding from
            # the previous sample shifted predictions and targets by one frame.
            predicted_mse_perc = batch.to(device)

        prediction = model(predicted_mse_perc)
        # Round before converting: a plain float -> uint8 cast truncates
        out[seq,:,timestep] = (prediction[0] * 255.0).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
        tgt[seq,:,timestep] = (target[0] * 255.0).round().clamp(0, 255).to(torch.uint8).numpy()

        # Slide the window: drop the oldest frame and append the prediction as the
        # newest one. The 8-bit frame stored in `out` is what gets fed back.
        next_frame = torch.from_numpy(out[seq,:,timestep]).to(device).float() / 255.0
        predicted_mse_perc = torch.cat(
            (predicted_mse_perc[:,:,1:], next_frame[None, :, None]), dim=2)

        timestep += 1
        # Move on once all `frames_per_seq` slots are filled (comparing against
        # frames_per_seq-1 left the last frame of every sequence black)
        if timestep == frames_per_seq:
            timestep = 0
            seq += 1
            print(seq)
        if seq == num_of_seq:
            break
        

1


In [None]:
# Save gifs
# Visit every (video, frame) pair in video-major order, so all sequences end up
# back to back in a single animation.
frame_positions = [
    (video_idx, frame_idx)
    for video_idx in range(out.shape[0])
    for frame_idx in range(out.shape[2])
]

# Each frame is stored channel-first; PIL wants height x width x channel uint8
target_frames = [
    Image.fromarray(np.transpose(tgt[video_idx, :, frame_idx].astype(np.uint8), (1, 2, 0)))
    for video_idx, frame_idx in frame_positions
]
output_frames = [
    Image.fromarray(np.transpose(out[video_idx, :, frame_idx].astype(np.uint8), (1, 2, 0)))
    for video_idx, frame_idx in frame_positions
]

# Write one looping GIF per stream, 100 ms per frame
for gif_name, gif_frames in (('target.gif', target_frames), ('output.gif', output_frames)):
    gif_frames[0].save(gif_name, save_all=True, append_images=gif_frames[1:], duration=100, loop=0)


In [None]:
# Persist the trained weights ("path/to/model" is a placeholder - set a real path before running)
trained_weights = model.state_dict()
torch.save(trained_weights, "path/to/model")
# To restore the weights later:
# model.load_state_dict(torch.load("path/to/model"))
# model.eval()

Seq2Seq(
  (encoder): UNetEncoder(
    (e11): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (e12): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (e21): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (e22): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (e31): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (e32): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (e41): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (e42): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=Fa