<a href="https://colab.research.google.com/github/mengmengwoo/CNN-LSTM-GAN/blob/main/CNN_LSTM_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Author
- **Meng-Hsuan (Michelle) Wu** (JHU)

## Project
- CNN-LSTM frame predictor trained with MSE loss

In [None]:
import os
import os.path
import cv2
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as fn
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import math
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import pickle
import shutil

# Model

## Encoder

In [None]:
class ConvNN(nn.Module):
  """Per-image CNN feature extractor returning one 512-d vector per image.

  Input is a batch of images (N, 3, H, W); output is (N, 512). Five
  convolution stages are followed by a global max-pool over whatever spatial
  extent remains. Given the kernel and pooling sizes below, H and W must each
  be at least 191 pixels, otherwise the final 3x3 convolution has no input.
  """
  def __init__(self):
    super().__init__()
    # NOTE(review): in PyTorch, BatchNorm ``momentum`` weights the *new* batch
    # statistic (running = (1 - momentum) * running + momentum * batch), the
    # opposite of the Keras convention. momentum=0.9 therefore makes the
    # running statistics follow mostly the latest batch (PyTorch's default is
    # 0.1). Confirm this is intended.
    self.network = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size = 2, stride = 1),
        nn.ReLU(),
        nn.BatchNorm2d(64,momentum=0.9),
        nn.MaxPool2d(2),

        nn.Conv2d(64, 128, kernel_size = 3, stride = 1),
        nn.ReLU(),
        nn.BatchNorm2d(128,momentum=0.9),
        nn.MaxPool2d(3),

        nn.Conv2d(128, 256, kernel_size = 2, stride = 1),
        nn.ReLU(),
        nn.BatchNorm2d(256,momentum=0.9),
        nn.MaxPool2d(3),

        nn.Conv2d(256, 256, kernel_size = 2, stride = 1),
        nn.ReLU(),

        nn.BatchNorm2d(256,momentum=0.9),
        nn.MaxPool2d(3),

        nn.Conv2d(256, 512, kernel_size = 3, stride = 1),
        nn.ReLU(),
        nn.Dropout(0.2)
    )

  def forward(self, x):
    cnn_val = self.network(x)
    # Global max-pool: the kernel covers the remaining (H', W'), so every
    # channel collapses to a single 1x1 value.
    return_val = F.max_pool2d(cnn_val, kernel_size=cnn_val.size()[2:])
    # Drop the two 1x1 spatial axes. flatten(1) always keeps the batch axis,
    # whereas torch.squeeze() also removed it for a single-image batch and
    # returned (512,) instead of (1, 512).
    return torch.flatten(return_val, 1)

In [None]:
class Encoder(nn.Module):
  """Encode a batch of image sequences into one 512-d vector per sequence.

  ``forward`` receives every frame flattened into a single batch of shape
  (n_sequences * seq_len, C, H, W), ordered series-major (all frames of
  sequence 0, then sequence 1, ...). Each frame is embedded by ``ConvNN`` into
  512 features. The features are regrouped into (n_sequences, seq_len, 512)
  and passed through a 4-layer bidirectional LSTM.

  Args:
    img_length: frames per sequence. ``None`` (the default, used by existing
      callers) reads the notebook-level global ``img_length`` at call time, as
      this class always did. Pass an int to make the module self-contained.
  """
  def __init__(self, img_length=None):
    super(Encoder, self).__init__()
    self.img_length = img_length
    self.convNN = ConvNN()
    self.lstm = nn.LSTM(input_size = 512, hidden_size = 512,
                        batch_first = True, bidirectional = True,
                        num_layers = 4, dropout = 0.2)

  def forward(self, input):
    """Return the (n_sequences, 512) final hidden state for ``input``.

    Raises:
      ValueError: if the number of frames is not a multiple of the sequence
        length.
    """
    seq_len = img_length if self.img_length is None else self.img_length
    n_frames = input.shape[0]
    if n_frames % seq_len != 0:
      raise ValueError(
          "Encoder got {} frames, which is not a multiple of the sequence "
          "length {}".format(n_frames, seq_len))
    batch_size = n_frames // seq_len

    in_features = self.convNN(input)
    in_features = torch.reshape(in_features, (batch_size, seq_len, 512))
    output, (h_n, c_n) = self.lstm(in_features)
    # h_n is (num_layers * 2, batch, 512), with layers stacked and the two
    # directions interleaved. So h_n[-1] is the final state of the top layer's
    # *reverse* direction only (the forward direction's is h_n[-2]).
    # NOTE(review): confirm that dropping the forward direction is intended.
    return h_n[-1]

## Decoder

In [None]:
class Decoder(nn.Module):
    """Upsampling decoder from a latent feature map to an RGB image in [0, 1].

    A (B, 512, 1, 1) input becomes (B, 3, 1024, 1024). The first transposed
    conv gives 4x4. Each later stage doubles the size twice: a stride-2
    transposed conv, then a nearest-neighbour 2x upsample. A final stride-2
    transposed conv and a sigmoid finish the image.

    Layer order (and therefore ``main.<index>`` state_dict keys):
    [ConvT, ReLU, Upsample, BatchNorm] x 4, then ConvT, Sigmoid.
    """
    def __init__(self):
        super(Decoder, self).__init__()
        # Stem: 1x1 -> 4x4 (kernel 4, stride 1, no padding), then 2x upsample.
        layers = [
            nn.ConvTranspose2d(in_channels = 512, out_channels = 256,
                               kernel_size = 4, stride = 1, padding = 0,
                               bias=False),
            nn.ReLU(),
            nn.Upsample(scale_factor = 2, mode = 'nearest'),
            nn.BatchNorm2d(256),
        ]
        # Three identical-shaped stages, each growing H and W by 4x overall.
        for c_in, c_out in ((256, 128), (128, 128), (128, 64)):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.ReLU(),
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.BatchNorm2d(c_out),
            ])
        # Head: last 2x growth to RGB, squashed into [0, 1].
        layers.extend([
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Sigmoid(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

# Full Model (CLSTM)

In [None]:
class CLSTM(nn.Module):
  """CNN-LSTM frame predictor: ``Encoder`` followed by ``Decoder``.

  ``forward`` takes the flattened frames expected by ``Encoder`` and returns
  the decoder's image for each sequence.
  """
  def __init__(self):
    super(CLSTM, self).__init__()
    self.encoder = Encoder()
    self.decoder = Decoder()

  def forward(self, x):
    # The encoder yields a (batch, 512) vector per sequence. The decoder
    # expects a (batch, 512, 1, 1) feature map, so append two unit axes.
    latent = self.encoder(x).unsqueeze(2).unsqueeze(3)
    return self.decoder(latent)

# Train

In [None]:
def _split_series(batch, use_cuda):
  """Split a (series, frames, C, H, W) batch into model input and target.

  Returns ``(x, y)``:
  - ``x`` holds every frame except the last of each series, flattened to
    (series * (frames - 1), C, H, W) in series-major order.
  - ``y`` is the last frame of each series, (series, C, H, W).
  """
  x = batch[:, :-1, :, :, :]
  y = batch[:, -1, :, :, :]
  if use_cuda:
    x, y = x.cuda(), y.cuda()
  n_series, n_frames, channels, height, width = x.shape
  x = torch.reshape(x, (n_series * n_frames, channels, height, width))
  return x, y


def _center_crop_loss(model, criterion, x, y, crop_size=256):
  """Predict from ``x`` and score the central ``crop_size`` square against ``y``."""
  y_pred = fn.center_crop(model(x), output_size=[crop_size])
  y_crop = fn.center_crop(y, output_size=[crop_size])
  return criterion(y_pred, y_crop)


def train(start_epochs,n_epochs, train_loader, val_loader, img_length,
          valid_loss_min_input, checkpoint_path,best_model_path,model,
          criterion,opt_function, train_history_lst, val_history_lst,use_cuda
          ):
  """Train ``model`` and checkpoint it after every epoch.

  Runs epochs ``start_epochs`` .. ``n_epochs - 1`` (printed 1-based). A fresh
  run with ``n_epochs=1000`` therefore trains exactly 1000 epochs, and a
  resumed run can pass the ``'epoch'`` value stored in a checkpoint as
  ``start_epochs``.

  Every batch from either loader must be a 5-D tensor (series, frames, C, H, W).
  All but the last frame are fed to the model, flattened to
  (series * (frames - 1), C, H, W), and the last frame is the target. The
  prediction and target are both centre-cropped to 256x256 before
  ``criterion`` is applied.

  Args:
    start_epochs: first (0-based) epoch index to run.
    n_epochs: total number of epochs; training stops after index n_epochs - 1.
    train_loader, val_loader: iterables of batches as described above.
    img_length: not read by this function; kept for call-site compatibility.
    valid_loss_min_input: best validation loss seen so far (np.inf for a fresh
      run). An epoch whose loss is <= this is saved as the best model.
    checkpoint_path: file overwritten with the latest checkpoint every epoch.
    best_model_path: file the checkpoint is copied to on a new best.
    model: the network to train.
    criterion: loss, called as ``criterion(prediction, target)``.
    opt_function: optimizer that is stepped on every training batch and whose
      state is checkpointed.
    train_history_lst, val_history_lst: lists extended in place with each
      epoch's per-series mean training / validation loss.
    use_cuda: if true, every batch is moved to the GPU with ``.cuda()``.

  Each checkpoint stores:
  - 'epoch': the next epoch to run.
  - 'valid_loss_min': a 0-d float64 tensor with the best validation loss so
    far, including this epoch.
  - 'state_dict', 'optimizer', 'train_history' and 'val_history'.

  Raises:
    ValueError: if a loader yields no batches.
  """
  train_history = train_history_lst
  val_history = val_history_lst

  # Best validation loss so far. float() accepts np.inf, NumPy scalars and
  # 0-d tensors alike.
  valid_loss_min = float(valid_loss_min_input)

  for epoch in range(start_epochs, n_epochs):

    ####################
    # training process #
    ####################
    model.train()
    train_loss_sum, train_seen = 0.0, 0
    for batch_train in train_loader:
      x_train, y_train = _split_series(batch_train, use_cuda)
      train_loss = _center_crop_loss(model, criterion, x_train, y_train)

      opt_function.zero_grad()
      train_loss.backward() # backpropagation
      opt_function.step() # update the model weights

      # Weight each batch mean by its size so a short final batch is not
      # over-counted in the epoch average.
      n = batch_train.shape[0]
      train_loss_sum += train_loss.item() * n
      train_seen += n

    if train_seen == 0:
      raise ValueError("train_loader yielded no batches")
    avg_train_loss = train_loss_sum / train_seen
    train_history.append(avg_train_loss)

    ######################
    # validate the model #
    ######################
    model.eval()
    val_loss_sum, val_seen = 0.0, 0
    with torch.no_grad():
      for batch_val in val_loader:
        x_val, y_val = _split_series(batch_val, use_cuda)
        val_loss = _center_crop_loss(model, criterion, x_val, y_val)
        n = batch_val.shape[0]
        val_loss_sum += val_loss.item() * n
        val_seen += n

    if val_seen == 0:
      raise ValueError("val_loader yielded no batches")
    avg_val_loss = val_loss_sum / val_seen
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(epoch+1, avg_train_loss,avg_val_loss))
    val_history.append(avg_val_loss)

    is_best = avg_val_loss <= valid_loss_min
    # Store the running minimum (not this epoch's loss), so that resuming
    # from this checkpoint keeps the correct best-model threshold. A tensor
    # keeps `.item()` working for loaders that call it.
    best_so_far = min(valid_loss_min, avg_val_loss)
    checkpoint = {
        'epoch': epoch + 1,
        'valid_loss_min': torch.tensor(best_so_far, dtype=torch.float64),
        'state_dict': model.state_dict(),
        'optimizer': opt_function.state_dict(),
        'train_history': train_history,
        'val_history': val_history
    }
    # save checkpoint
    save_ckp(checkpoint, False, checkpoint_path, best_model_path)

    # save the model if validation loss has decreased
    if is_best:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,avg_val_loss))
        # save checkpoint as best model
        save_ckp(checkpoint, True, checkpoint_path, best_model_path)
        valid_loss_min = avg_val_loss

# Checkpoint

In [None]:
def save_ckp(state, is_best, checkpoint_path, best_model_path):
  """Write ``state`` to ``checkpoint_path`` and, if ``is_best``, copy it to ``best_model_path``.

  The best-model file is a byte-for-byte copy of the checkpoint just written,
  so the state is serialised only once.
  """
  torch.save(state, checkpoint_path)
  if not is_best:
    return
  print("Saving a new best model")
  shutil.copyfile(checkpoint_path, best_model_path)


In [None]:
def load_ckp(checkpoint_fpath, model, optimizer, map_location=None):
    """Restore model/optimizer state and training bookkeeping from a checkpoint.

    Args:
        checkpoint_fpath: path of a file written by ``torch.save``. It must
            contain the keys 'state_dict', 'optimizer', 'epoch',
            'valid_loss_min', 'train_history' and 'val_history'.
        model: module whose parameters are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        map_location: forwarded to ``torch.load``. For example, ``"cpu"``
            loads a checkpoint saved from GPU tensors on a CPU-only machine.
            ``None`` (the default) keeps the original behaviour.

    Returns:
        ``(model, optimizer, epoch, valid_loss_min, train_history, val_history)``,
        where ``valid_loss_min`` is a Python float.
    """
    import inspect

    load_kwargs = {"map_location": map_location}
    # The checkpoints written by this notebook may hold NumPy scalars (loss
    # values), which PyTorch >= 2.6 rejects under its default
    # ``weights_only=True``. Older PyTorch has no such argument.
    # SECURITY: weights_only=False unpickles arbitrary objects, so only load
    # checkpoints you created yourself.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    checkpoint = torch.load(checkpoint_fpath, **load_kwargs)

    # initialize state_dict from checkpoint to model
    model.load_state_dict(checkpoint['state_dict'])

    # initialize optimizer from checkpoint to optimizer
    optimizer.load_state_dict(checkpoint['optimizer'])

    # float() accepts NumPy scalars, 0-d tensors and plain Python floats;
    # ``.item()`` fails on a plain float.
    valid_loss_min = float(checkpoint['valid_loss_min'])
    train_loss_lst = checkpoint['train_history']
    val_loss_lst = checkpoint['val_history']

    # return model, optimizer, epoch value, min validation loss, histories
    return model, optimizer, checkpoint['epoch'], valid_loss_min, train_loss_lst, val_loss_lst

# Setting Parameters

In [None]:
# setting parameters
model = CLSTM()
use_cuda = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()
start_epochs = 0
num_epochs = 1000       # total number of training epochs
batch_train_num = 16    # series per training batch
batch_valid_num = 4     # series per validation batch
batch_test_num= 10      # series per test batch

criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(), lr = 0.00001)
# NOTE(review): ngf, nz and nc are not read anywhere else in this notebook
# (they look like DCGAN leftovers); kept in case other code refers to them.
ngf = 64
nz = 100
nc = 3
img_length = 2 # how many images per series I extract
# Despite the *_dir names, these are full checkpoint file paths.
checkpoint_dir = "/content/drive/MyDrive/Capstone/Capstone_checkpoint_models/clstm_checkpoint_mse_final.pt"
best_model_dir = "/content/drive/MyDrive/Capstone/Capstone_models/clstm_best_model_mse_final.pt"
# np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
valid_loss_min_input = np.inf
train_history_lst = []
val_history_lst = []


# Loading Data

In [None]:
# Load every series into channel-first (C, H, W) uint8 arrays.
# Series i lives in <data_root>/<i>/, with frames named "00im.jpg", "01im.jpg", ...
# Frames 0 .. img_length-1 are the model input; frame img_length+1 is the target.
# NOTE(review): that skips frame img_length ("02im.jpg" when img_length = 2).
# Confirm the target is meant to be two steps ahead rather than the next frame.
# uint8 storage is 8x smaller than np.zeros' float64 default. The later "/ 255"
# scaling produces the same float values either way.
num_series = 100
data_root = "/content/drive/MyDrive/Capstone/Capstone_data_jpg/"
target_frame = img_length + 1

x_arr = np.zeros(shape=(num_series, img_length, 3, 288, 432), dtype=np.uint8)
y_arr = np.zeros(shape=(num_series, 1, 3, 288, 432), dtype=np.uint8)


def _read_chw(path):
  """Read an image as an RGB uint8 array in channel-first (C, H, W) layout."""
  with Image.open(path) as im:  # close the file handle once decoded
    return np.moveaxis(np.asarray(im.convert("RGB")), -1, 0)


for i in range(num_series): # loop through series
  series_dir = os.path.join(data_root, str(i))
  for j in range(img_length): # loop through input frames within the series
    # Index with [i, j], not [[i, j]]: a list index selects whole series i and
    # j along axis 0, which overwrote both series with this single frame.
    x_arr[i, j] = _read_chw(os.path.join(series_dir, "{:02d}im.jpg".format(j)))
  y_arr[i, 0] = _read_chw(os.path.join(series_dir, "{:02d}im.jpg".format(target_frame)))

# Image dimensions: (series, frames per series, channels, height, width).
n_series, n_img_in_series, img_channels, img_height, img_width = x_arr.shape

# 80 / 10 / 10 split: hold out 20% of the series, then halve the hold-out into
# test and validation sets. Fixed seeds keep the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(x_arr, y_arr, test_size = 0.2, random_state = 1)
X_test, X_val, y_test, y_val = train_test_split(X_test, y_test, test_size = 0.5, random_state = 1)

num_train_data, num_val_data, num_test_data = (len(part) for part in (X_train, X_val, X_test))


def _to_unit_tensor(arr):
  """Scale 0-255 pixel values to [0, 1] and return them as a float32 tensor.

  No mean/std normalisation is applied.
  """
  return torch.tensor((np.array(arr) / 255).astype(np.float32))


trans_X_train = _to_unit_tensor(X_train)
trans_y_train = _to_unit_tensor(y_train)
trans_X_val = _to_unit_tensor(X_val)
trans_y_val = _to_unit_tensor(y_val)
trans_X_test = _to_unit_tensor(X_test)
trans_y_test = _to_unit_tensor(y_test)

# Append the target frame after the input frames along the frame axis, so each
# item is one (img_length + 1, C, H, W) series whose last frame is the label.
final_train_data = torch.cat((trans_X_train, trans_y_train), dim = 1)
final_val_data = torch.cat((trans_X_val, trans_y_val), dim = 1)
final_test_data = torch.cat((trans_X_test, trans_y_test), dim = 1)

In [None]:
# Each dataset item is one series (frames, C, H, W), so every batch is
# (batch, frames, C, H, W).
# Only the training set is shuffled. Shuffling the validation set gains
# nothing, and when its size is not a multiple of the batch size (10 series in
# batches of 4 here) it makes a batch-averaged validation loss change from
# epoch to epoch for identical weights.
train_dl = DataLoader(final_train_data, batch_train_num, shuffle = True)
val_dl = DataLoader(final_val_data, batch_valid_num, shuffle = False)
test_dl = DataLoader(final_test_data, batch_test_num, shuffle = False)


# Calling Train

In [None]:
# fitting data to models
# train(start_epochs,num_epochs, train_dl, val_dl, img_length,
#                       valid_loss_min_input, checkpoint_dir, best_model_dir, model,
#                       criterion, optimizer, train_history_lst, val_history_lst, use_cuda)

In [None]:
# Resume from the latest checkpoint (e.g. after the runtime disconnects).
# Reuse checkpoint_dir from the parameter cell so the two paths cannot drift apart.
ckp_path = checkpoint_dir
optimizer = optim.RMSprop(model.parameters(), lr = 0.00001)
model, optimizer, start_epoch, valid_loss_min, train_lst, val_lst = load_ckp(ckp_path, model, optimizer)

print("model = ", model)
print("optimizer = ", optimizer)
print("start_epoch = ", start_epoch)
print("train_lst = ",train_lst)
print("validation_lst = ",val_lst)
print("valid_loss_min = {:.6f}".format(valid_loss_min))


In [None]:
# train(start_epoch,num_epochs, train_dl, val_dl, img_length,
#                       valid_loss_min, ckp_path, best_model_dir, model,
#                       criterion, optimizer, train_lst, val_lst, use_cuda)