In [1]:
!pip install -q torch-summary
!pip install -q patchify

In [2]:
# import the necessary packages
import math
import random
import rasterio
import gc
gc.collect()
import pickle
import numpy as np

# import splitfolders
from pathlib import Path
import matplotlib.pyplot as plt
from datetime import datetime
from patchify import patchify, unpatchify


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchmetrics import Metric
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler, random_split

import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer, LightningDataModule

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [3]:
class Config():
    """Central collection of dataset, model and training settings.

    Calling ``Config()`` yields the notebook defaults. Any setting can be
    overridden by keyword, e.g. ``Config(batch_size=4, seed=7)``.

    Args:
        **overrides: values keyed by the name of an existing setting.

    Raises:
        TypeError: if a keyword does not name an existing setting (guards
            against silently ignored typos such as ``batchsize=...``).
    """

    def __init__(self, **overrides):
        # Dataset/Dataloaders params
        self.years = ['2016','2017','2018','2019','2020','2021','2022']
        self.months = ['1','2','3','4']
        self.base_dir = '../input/cdmipcamergeddataset/CDMI-PCA/CDMI_'
        # NOTE(review): comments in later cells mention 256x256 patches -- confirm this value.
        self.patch_size = 128
        self.batch_size = 12
        self.num_workers = 2

        # Model params
        self.input_dim  = 1              #Number of channels of input tensor.
        self.hidden_dim = 64             #Number of channels of hidden state.
        self.kernel_size = 5             #Size of the convolutional kernel.
        self.num_layers = 1              #Number of ConvLSTM layers

        # Training params
        self.learning_rate = 1e-4
        self.weight_decay = 1e-5
        self.exec_mode = None
        self.num_epochs = 100

        # Reproducibility
        self.seed = 26012022
        self.generator = None            # filled in below unless overridden

        # Apply caller overrides, accepting only names that already exist.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Config got an unexpected setting {name!r}")
            setattr(self, name, value)

        # Derive the generator last so that an overridden seed is honoured.
        if self.generator is None:
            self.generator = torch.Generator().manual_seed(self.seed)

In [4]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Shared settings object used by the cells below.
config = Config()

In [5]:
# def get_CDMI_chronologically(base_dir, years, months):
#     cdmi_paths = [base_dir+j+'_'+i+'.tif' for i in years for j in months]
#     gc.collect()
#     return cdmi_paths

# def read_rasters(base_dir, years, months):
#     cdmi_paths = get_CDMI_chronologically(base_dir, years, months)[:34]
#     cdmi_stack = np.stack([rasterio.open(f).read() for f in cdmi_paths], axis=1)
#     cdmi_stack[np.isnan(cdmi_stack)] = 0
#     gc.collect()
#     return cdmi_stack[:,:,:4352,:5120]  # (1, 34, 4352, 5120)

# def patchify_images(base_dir, years, months, patch_size):
#     cdmi_stack = read_rasters(base_dir, years, months)
#     patchified_images = patchify(cdmi_stack, (1, 34, patch_size, patch_size), step=patch_size).squeeze()  #(17, 20, 128, 128))
#     data_cube = patchified_images.reshape((-1, 34, patch_size, patch_size))
#     np.save('./data_cube.npy', data_cube)
#     gc.collect()

In [6]:
# patchify_images(config.base_dir, config.years, config.months, config.patch_size)

In [7]:
class CDMIDataset (Dataset):
    """Dataset serving one pre-patchified CDMI sequence per index.

    The cube is loaded once from a ``.npy`` file; its first axis indexes the
    samples.

    Args:
        base_dir: stored on the instance; not used for loading.
        years: stored on the instance; not used for loading.
        months: stored on the instance; not used for loading.
        patch_size: stored on the instance; not used for loading.
        data_path: location of the ``.npy`` data cube. Defaults to the path
            the notebook has always used, so existing calls are unaffected.
    """

    def __init__(self, base_dir, years, months, patch_size,
                 data_path='../input/datacube256-6m/data_cube.npy'):

        self.base_dir = base_dir
        self.years = years
        self.months = months
        self.patch_size = patch_size
        self.patchified_cdmi = np.load(data_path)
        gc.collect()

    def __len__(self):
        """Number of samples, i.e. the length of the cube's first axis."""
        return len(self.patchified_cdmi)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a float32 tensor with a new leading axis."""
        x = self.patchified_cdmi[idx]
        # numpy array --> torch tensor; x[None] prepends a singleton (channel) axis
        x_tensor = torch.tensor(x[None], dtype=torch.float32)
        return x_tensor

In [8]:
class CDMIDataModule(LightningDataModule):
    """Lightning data module giving a 90/10 train/validation split of CDMIDataset.

    Args:
        config: settings object; reads ``base_dir``, ``years``, ``months``,
            ``patch_size``, ``batch_size``, ``num_workers`` and ``generator``.
    """

    def __init__(self, config):
        # The base class must be initialised, otherwise the attributes Lightning
        # sets up in its own __init__ are missing when a Trainer uses this module.
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Load the dataset and create ``self.cdmi_train`` / ``self.cdmi_val``."""
        cdmi_dataset = CDMIDataset (self.config.base_dir, self.config.years,
                                    self.config.months, self.config.patch_size)

        # Take the validation size as the remainder so the two lengths always
        # add up to the dataset size (floor(0.9*n) + ceil(0.1*n) can be off by
        # one through floating-point rounding, which random_split rejects).
        n_total = len(cdmi_dataset)
        n_train = math.floor(0.9 * n_total)
        n_val = n_total - n_train

        # Split with a copy of the configured generator: the shared generator is
        # not consumed, so every setup() call (training cell, evaluation cell)
        # reproduces the same partition instead of a fresh one each time.
        split_generator = torch.Generator()
        split_generator.set_state(self.config.generator.get_state())

        self.cdmi_train, self.cdmi_val = random_split(cdmi_dataset,
                                                      [n_train, n_val],
                                                      generator=split_generator)

    def train_dataloader(self):
        """DataLoader over the training subset, using the configured batch size."""
        return DataLoader(self.cdmi_train, batch_size=self.config.batch_size,
                          num_workers=self.config.num_workers)

    def val_dataloader(self):
        """DataLoader over the validation subset, fixed batch size 3, no shuffling."""
        return DataLoader(self.cdmi_val, batch_size=3,
                          shuffle=False, num_workers=self.config.num_workers,drop_last=False)

In [9]:
class ConvLSTMCell(nn.Module):
    """One step of a convolutional LSTM.

    A single convolution over the channel-wise concatenation of the input and
    the previous hidden state yields the pre-activations of all four gates.
    The candidate and the cell read-out both use ReLU.

    Args:
        input_dim: number of channels of the input tensor.
        hidden_dim: number of channels of the hidden and cell state.
        kernel_size: integer size of the convolution kernel; the padding is
            ``kernel_size // 2``.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size

        # Four gates (input, forget, output, candidate) share one convolution.
        stacked_channels = self.input_dim + self.hidden_dim
        self.Gates = nn.Conv2d(stacked_channels,
                               4 * self.hidden_dim,
                               kernel_size=self.kernel_size,
                               padding=self.kernel_size // 2)

    def forward(self, x, h_prev, c_prev):
        """Advance the cell by one step and return ``(h_cur, c_cur)``."""
        stacked = torch.cat((x, h_prev), dim=1)
        pre_i, pre_f, pre_o, pre_g = self.Gates(stacked).split(self.hidden_dim, dim=1)

        input_gate = pre_i.sigmoid()
        forget_gate = pre_f.sigmoid()
        output_gate = pre_o.sigmoid()
        candidate = pre_g.relu()

        # New cell state: keep part of the old one, add the gated candidate.
        c_cur = forget_gate * c_prev + input_gate * candidate
        # New hidden state: gated, rectified view of the cell state.
        h_cur = output_gate * c_cur.relu()
        return h_cur, c_cur

In [10]:
class ConvLSTM(nn.Module):
    """ConvLSTM layer: unrolls one ConvLSTMCell over the time axis of a clip.

    Args:
        input_dim: number of channels of the input tensor.
        hidden_dim: number of channels of the hidden state.
        kernel_size: size of the convolutional kernel.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size):
        super(ConvLSTM, self).__init__()

        self.input_dim  = input_dim      #Number of channels of input tensor.
        self.hidden_dim = hidden_dim     #Number of channels of hidden state.
        self.kernel_size = kernel_size   #Size of the convolutional kernel.

        # We will unroll this over time steps
        self.convLSTMcell = ConvLSTMCell(input_dim=self.input_dim,
                                          hidden_dim=self.hidden_dim,
                                          kernel_size=self.kernel_size)

    def forward(self, x):
        """Run the cell over every time step.

        Args:
            x: tensor of shape (B, C, S, H, W).

        Returns:
            Tensor of shape (B, hidden_dim, S, H, W) holding the hidden state
            after each time step, on the same device and with the same dtype
            as ``x``.
        """
        # Get the dimensions
        batch_size, _, seq_len, height, width = x.size() # shape(B, C, S, H, W)

        # Allocate all state next to the input rather than on a notebook-level
        # global device, so the layer works wherever the data actually lives.
        state_kwargs = dict(device=x.device, dtype=x.dtype)

        # Initialize output
        output = torch.zeros(batch_size, self.hidden_dim, seq_len, height, width, **state_kwargs)
        # Initialize Hidden State
        H = torch.zeros(batch_size, self.hidden_dim, height, width, **state_kwargs)
        # Initialize Cell Input
        C = torch.zeros(batch_size, self.hidden_dim, height, width, **state_kwargs)

        # Unroll over time steps
        for time_step in range(seq_len):
            H, C = self.convLSTMcell(x[:,:,time_step], H, C)
            output[:,:,time_step] = H
        return output

In [11]:
class Seq2Seq(nn.Module):
    """Next-frame predictor: stacked ConvLSTM + BatchNorm3d, then a Conv2d head.

    The recurrent stack processes the whole input clip; only the features of
    the last time step go through the output convolution and a sigmoid.

    Args:
        config: settings object; reads ``input_dim``, ``hidden_dim``,
            ``kernel_size`` and ``num_layers``.
    """

    def __init__(self, config):
        super().__init__()

        self.sequential = nn.Sequential()
        self.config = config
        cfg = self.config

        # Layer 1 is always present and consumes the raw input channels;
        # every further layer consumes the hidden channels of its predecessor.
        layer_count = max(cfg.num_layers, 1)
        for idx in range(1, layer_count + 1):
            in_channels = cfg.input_dim if idx == 1 else cfg.hidden_dim
            self.sequential.add_module(
                f"convlstm{idx}",
                ConvLSTM(input_dim=in_channels,
                         hidden_dim=cfg.hidden_dim,
                         kernel_size=cfg.kernel_size))
            self.sequential.add_module(
                f"batchnorm{idx}",
                nn.BatchNorm3d(num_features=cfg.hidden_dim))

        # Output head: maps hidden features back to the input channel count.
        self.conv = nn.Conv2d(in_channels=cfg.hidden_dim,
                              out_channels=cfg.input_dim,
                              kernel_size=cfg.kernel_size,
                              padding=cfg.kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

        self._initialize_weights()

    def forward(self, x):
        """Predict the frame following the input clip ``x``."""
        features = self.sequential(x)
        # Keep only the last time step of the recurrent features.
        last_step = features[:, :, -1]
        return self.sigmoid(self.conv(last_step))

    def _initialize_weights(self):
        """Kaiming-normal Conv2d weights with zero bias; identity BatchNorm3d."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm3d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                continue
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight.data, a=0, mode='fan_in', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()

In [12]:
def collate(batch, num_input_frames=6, max_target_frame=40):
    """Cut one random (input window, target frame) pair out of a batch of clips.

    A target time index ``t`` is drawn uniformly, and the ``num_input_frames``
    frames right before it form the model input.

    Args:
        batch: tensor whose axis 2 is time, e.g. (B, C, S, H, W).
        num_input_frames: number of consecutive frames used as input.
        max_target_frame: exclusive upper bound for the target index; it is
            clamped to the clip length so the index is always valid.

    Returns:
        Tuple ``(inputs, target)`` with ``inputs = batch[:, :, t-n:t]`` and
        ``target = batch[:, :, t]``, both on the notebook-level ``device``.

    Raises:
        ValueError: if no frame has ``num_input_frames`` predecessors.
    """
    batch = batch.to(device)

    # Never sample beyond the frames that actually exist in this batch.
    upper = min(max_target_frame, batch.shape[2])
    if upper <= num_input_frames:
        raise ValueError(
            f"need more than {num_input_frames} usable frames to sample a target, got {upper}")

    # Randomly pick the target frame; the num_input_frames frames before it are the input
    rand = np.random.randint(num_input_frames, upper)
    return batch[:,:,rand-num_input_frames:rand], batch[:,:, rand]

def collate_test(batch):
    """Split an evaluation batch into the full clip and its target frames.

    Args:
        batch: tensor whose axis 1 is the channel and axis 2 is time.

    Returns:
        Tuple ``(batch, target)``: ``batch`` moved to the notebook-level
        ``device``, and ``target`` holding the frames from time index 30
        onwards with the channel axis removed.
    """
    # Frames from time index 30 onwards are the target; the slice is taken
    # before the transfer, so it stays on the batch's original device.
    target = batch[:,:,30:]
    batch = batch.to(device)
    # Drop only the channel axis: a bare squeeze() would also remove a batch
    # or time axis of length 1 and shift every index used downstream.
    return batch, target.squeeze(1)

In [13]:
# Single-channel (grayscale) next-frame predictor, placed on the selected device
model = Seq2Seq(config)
model = model.to(device)

# Adam over all model parameters, with the configured step size and weight decay
optimizer = torch.optim.Adam(params=model.parameters(),
                             weight_decay=config.weight_decay,
                             lr=config.learning_rate)

# Mean squared error between the predicted and the true frame
criterion = nn.MSELoss()
# summary(model, (1, 4, 256, 256)) #sanity check

In [None]:
dm = CDMIDataModule(config)
dm.setup()

# Build the loaders once instead of constructing a new DataLoader on every call.
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()

for epoch in range(1, config.num_epochs+1):

    # ---- training pass ----
    train_loss = 0
    model.train()

    for batch in train_loader:
        x, y = collate(batch)

        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        # criterion returns the mean over the batch; weight it by the batch
        # size so that dividing by the dataset size gives a per-sample mean
        # (summing batch means and dividing by the sample count under-reports
        # the loss by roughly the batch size).
        train_loss += loss.item() * x.size(0)
    train_loss /= len(train_loader.dataset)
    torch.cuda.empty_cache()
    gc.collect()

    # ---- validation pass (no gradients, eval-mode batch norm) ----
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            x, y = collate(batch)

            output = model(x)
            loss = criterion(output, y)
            val_loss += loss.item() * x.size(0)
    val_loss /= len(val_loader.dataset)
    torch.cuda.empty_cache()
    gc.collect()

    print("Epoch:{} Training Loss:{:.2f} Validation Loss:{:.2f}\n".format(
        epoch, train_loss, val_loss))

# Persist the weights after the last epoch.
torch.save(model.state_dict(), './model.pt')

## Check Model's outputs

In [14]:
#Load the saved model
# map_location remaps the stored tensors onto the device selected for this
# session, so a checkpoint written on a GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load('../input/convlstm6m6ts10e/model.pt',
                                 map_location=device))
model.eval()
dm = CDMIDataModule(config)
dm.setup()

In [39]:
#Get a batch
# Take the first validation batch; collate_test returns it together with the
# frames from time index 30 onwards, which serve as the comparison target.
# NOTE(review): the shape notes on the collate_test line look stale --
# val_dataloader uses batch_size=3, so the leading dimension should be 3. TODO confirm.
batch = next(iter(dm.val_dataloader()))
batch, target = collate_test(batch) #torch.Size([10, 1, 28, 256, 256]) #torch.Size([10, 1, 18, 256, 256])

# Initialize output sequence
# Zero-filled array with the same shape as target, meant to receive the predictions.
output = np.zeros(target.shape, dtype=np.float32) 
target = target.cpu().detach().numpy()
# target[target== 0] = np.nan
# output[output== 0] = np.nan

# Slide a 6-frame window along the time axis, one step per target frame.
# NOTE(review): the model call below is commented out, so this loop only
# rebinds `input` and `output` stays all zeros -- re-enable it (ideally under
# torch.no_grad()) before reading anything into the comparison plot.
# NOTE(review): the window starts at frame `timestep` while target starts at
# frame 30; confirm which frames each prediction is meant to line up with.
# NOTE(review): `input` shadows the Python builtin of the same name.
for timestep in range(target.shape[1]):
    input = batch[:,:,timestep:timestep+6] 
#     output[:,timestep]= model(input).squeeze(1).cpu().detach().numpy()

In [None]:
# Compare one target frame with the prediction stored at the same position.
# Both arrays are indexed as [sample, frame].
sample_idx, frame_idx = 2, 8

fig, ax = plt.subplots(1,2, figsize = (20,10))
# Title each panel so the two images cannot be mistaken for one another.
ax[0].imshow(target[sample_idx, frame_idx], cmap='gray')
ax[0].set_title(f"Target (sample {sample_idx}, frame {frame_idx})")
ax[1].imshow(output[sample_idx, frame_idx], cmap='gray')
ax[1].set_title(f"Prediction (sample {sample_idx}, frame {frame_idx})");

## Forecasting

In [23]:
# Load data for forecasting
forecasting_cdmi = np.load('../input/foreceasting-data/data_cube.npy') #(340, 6, 256, 256)
# float32 copy with a channel axis inserted at position 1, moved to the compute device
pred_data = torch.tensor(forecasting_cdmi, dtype=torch.float32)[:, None].to(device)

In [50]:
# Run the model over all forecasting patches in chunks and collect the
# predicted frames as numpy arrays.
Allarrays = []
forecast_batch_size = 10

# Inference only: make sure batch norm uses its running statistics and skip
# building the autograd graph, which would otherwise hold memory for every chunk.
model.eval()
with torch.no_grad():
    # Stepping by the chunk size (instead of counting n // 10 full chunks)
    # also covers a final, smaller chunk when the patch count is not a multiple.
    for patch in range(0, pred_data.shape[0], forecast_batch_size):
        output = model(pred_data[patch:patch+forecast_batch_size,...]).squeeze(1).cpu().numpy()
        Allarrays.append(output)

Output = np.concatenate(Allarrays, axis=0, dtype=np.float32)  #(340, 256, 256)

In [64]:
# Quick visual check of a single predicted patch (index 300 of the forecast stack).
plt.imshow(Output[300], cmap='gray')

In [67]:
# reconstruct the prediction for Zambia
# Lay the flat stack of predicted patches out on a 17 x 20 grid of 256 x 256
# tiles, then let unpatchify stitch the grid into one 4352 x 5120 raster.
patch_grid_shape = (17, 20, 256, 256)
Output_reshaped = np.reshape(Output, patch_grid_shape)
recon_CDMI = unpatchify(Output_reshaped, (4352,5120))

In [69]:
# Show the full raster that unpatchify reassembled from the predicted patches.
plt.imshow(recon_CDMI, cmap='gray')