---

# *Initial* **Setup**

---

## **Library** *Settings*

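The installable package name can differ from the import name used below (e.g. `pl_bolts` is distributed as `lightning-bolts`); the real package name must be looked up on https://pypi.org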

In [2]:
# Library Import
import os
import sys
import io
import pickle
import psutil
import numpy as np
import argparse
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary
import torchvision
import pytorch_lightning as pl
import pl_bolts                     # lightning_bolts
import tensorboard
import tensorflow as tf
import fvcore
import matplotlib.pyplot as plt
import itk
import itkwidgets
import time
import timeit
import warnings
import alive_progress

  warn_missing_pkg("wandb")
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)
  warn_missing_pkg("gym")


In [3]:
# Functionality Import
from pathlib import Path
from typing import List, Literal, Optional, Callable, Dict, Literal, Optional, Union, Tuple
from collections import OrderedDict
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from torch.nn.utils import spectral_norm
from torchsummary import summary
from pytorch_lightning.loggers import TensorBoardLogger
from pl_bolts.models.autoencoders.components import resnet18_encoder, resnet18_decoder
torch.autograd.set_detect_anomaly(True)
from PIL import Image
from ipywidgets import interactive, IntSlider
from tabulate import tabulate
from alive_progress import alive_bar
warnings.filterwarnings('ignore')

## **Control** *Station*

In [4]:
# Dataset Parametrizations Parser Initialization
# (the parser is always evaluated on an empty argument list, i.e. every setting takes its
#  default value; intermediate parses are only used to derive dependent defaults)

# Command-Line Boolean Converter
# (argparse's `type = bool` maps ANY non-empty string, including "False", to True)
def _str2bool(value):
    if isinstance(value, bool): return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'): return True
    if text in ('false', 'f', 'no', 'n', '0'): return False
    raise argparse.ArgumentTypeError(f"Boolean Value expected, got '{value}'")

# Command-Line Integer Array Converter (e.g. "85,128,128" -> array([85, 128, 128]))
# (`type = np.array` would wrap the raw string in a 0-dimensional string array)
def _int_array(value):
    if isinstance(value, np.ndarray): return value
    try: return np.array([int(item) for item in str(value).split(',')])
    except ValueError:
        raise argparse.ArgumentTypeError(f"Comma-separated Integers expected, got '{value}'")

data_parser = argparse.ArgumentParser(
        description = "2/3D MUDI Dataset Settings")
data_parser.add_argument(                               # Dataset Version Variable
        '--version', type = int,
        default = 3,
        help = "Dataset Save Version")
data_parser.add_argument(                               # Dataset Dimensionality (2 or 3)
        '--dim', type = int,
        default = 2,
        help = "Dataset Dimensionality")
data_parser.add_argument(                               # Dataset Batch Size Value
        '--batch_size', type = int,
        default = 500,
        help = "Dataset Batch Size Value")

# --------------------------------------------------------------------------------------------

# Dataset Label Parametrization Arguments
data_parser.add_argument(                       # Control Variable for the Inclusion of Patient ID in Labels
        '--patient_id', type = _str2bool,
        default = False,
        help = "Control Variable for the Inclusion of Patient ID in Labels")
data_parser.add_argument(                       # Control Variable for the Conversion of 3 Gradient Directions
        '--gradient_coord', type = _str2bool,   # Coordinates into 2 Gradient Direction Angles (suggested by prof. Chantal)
        default = False,                        # True keeps the 3 Coordinate Gradient Values
        help = "Control Variable for the Conversion of Gradient Direction Mode")
data_parser.add_argument(                       # Control Variable for the Rescaling & Normalization of Labels
        '--label_norm', type = _str2bool,
        default = True,
        help = "Control Variable for the Rescaling & Normalization of Labels")

# Number of Labels derived from the Label Settings above (7 when nothing is excluded)
data_settings = data_parser.parse_args([])
num_labels = 7
if not data_settings.patient_id: num_labels -= 1            # Exclusion of Patient ID
if not data_settings.gradient_coord: num_labels -= 1        # Conversion of Gradient Coordinates to Angles
data_parser.add_argument(                                   # Dataset Number of Labels
        '--num_labels', type = int,
        default = num_labels,
        help = "MUDI Dataset Number of Labels")

# --------------------------------------------------------------------------------------------

# Addition of File & Folderpath Arguments
data_parser.add_argument(                               # Path for Main Dataset Folder
        '--main_folderpath', type = str,
        default = '../../../Datasets/MUDI Dataset',
        help = 'Main Folderpath for Root Dataset')

# Remaining Paths are built on top of the Main Folderpath & Dataset Version parsed so far
data_settings = data_parser.parse_args([])
data_parser.add_argument(                               # Path for Folder Containing Patient Data Files
        '--patient_folderpath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Patient Data'),
        help = 'Input Folderpath for Segregated Patient Data')
data_parser.add_argument(                               # Path for Folder Containing Mask Data Files
        '--mask_folderpath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Patient Mask'),
        help = 'Input Folderpath for Segregated Patient Mask Data')
data_parser.add_argument(                               # Path for Parameter Value File
        '--param_filepath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Raw Data/parameters_new.xlsx'),
        help = 'Input Filepath for Parameter Value Table')
data_parser.add_argument(                               # Path for Patient Information File
        '--info_filepath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Raw Data/header1_.csv'),
        help = 'Input Filepath for Patient Information Table')
data_parser.add_argument(                               # Path for Dataset Saved Files
        '--save_folderpath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Saved Data/V{data_settings.version}'),
        help = 'Output Folderpath for MUDI Dataset Saved Versions')

# --------------------------------------------------------------------------------------------

# Dataset Splitting Arguments
data_parser.add_argument(                       # Number of Patients to be used in the Test Set
        '--test_patients', type = int,
        default = 1,
        help = "Number of Patients in Test Set")
data_parser.add_argument(                       # Number / Percentage of Parameters for Training Set's Training
        '--train_params', type = int,
        default = 500,
        help = "Number / Percentage of Parameters in the Training of the Training Set")
data_parser.add_argument(                       # Number / Percentage of Parameters for Test Set's Training
        '--test_params', type = int,
        default = 20,
        help = "Number / Percentage of Parameters in the Training of the Test Set")

# Boolean Control Input & Shuffling Arguments
data_parser.add_argument(                       # Control Variable for the Usage of Percentage Values in Parameters
        '--percentage', type = _str2bool,
        default = False,
        help = "Control Variable for the Usage of Percentage Values in Parameters")
data_parser.add_argument(                       # Ability to Shuffle the Patients that compose both Training and Test Sets
        '--patient_shuffle', type = _str2bool,
        default = False,
        help = "Ability to Shuffle the Patients that compose both Training and Test Sets")
data_parser.add_argument(                       # Ability to Shuffle the Samples inside both Training and Validation Sets
        '--sample_shuffle', type = _str2bool,
        default = False,
        help = "Ability to Shuffle the Samples inside both Training and Validation Sets")
data_parser.add_argument(                       # Number of Workers for DataLoader Usage
        '--num_workers', type = int,
        default = 20,
        help = "Number of Workers for DataLoader Usage")

# --------------------------------------------------------------------------------------------

# Dataset Pre-Processing Arguments
data_parser.add_argument(                       # Data Pre-Processing Method
        '--pre_processing', type = str,
        default = 'Zero Padding',
        choices = ['Interpolation', 'Zero Padding', 'CNN'],
        help = "Data Pre-Processing Method")
data_parser.add_argument(                       # Final 3D Image Shape for Pre-Processing Output
        '--img_shape', type = _int_array,
        default = np.array((60, 96, 96)),
        help = "Final 3D Image Shape for Pre-Processing Output")
data_parser.add_argument(                       # Number of Selected Slices in 2D Conversion
        '--num_slices', type = int,
        default = 35,
        help = "Number of Selected Slices in 2D Conversion")

# Final Dataset Settings (all defaults)
data_settings = data_parser.parse_args([])

In [5]:
# All4One 2D VAE Model Parametrizations Parser Initialization
# (options are grouped as: versions -> file & folder paths -> training requirements)
model_parser = argparse.ArgumentParser(description = "All4One 2D VAE Settings")

# Experiment & Dataset Versions
model_parser.add_argument('--model_version', type = int, default = 2,
                          help = "Experiment Version")
model_parser.add_argument('--data_version', type = int, default = 3,
                          help = "MUDI Dataset Version")

# --------------------------------------------------------------------------------------------

# Addition of Filepath Arguments
# (the two Path defaults are derived from the already parsed dataset settings)
model_parser.add_argument('--reader_folderpath', type = Path,
                          default = Path(f'{data_settings.main_folderpath}/Dataset Reader'),
                          help = 'Input Folderpath for MUDI Dataset Reader')
model_parser.add_argument('--data_folderpath', type = Path,
                          default = Path(f'{data_settings.main_folderpath}/Saved Data/V{data_settings.version}'),
                          help = 'Input Folderpath for MUDI Dataset Saved Versions')
model_parser.add_argument('--model_folderpath', type = str, default = 'Model Builds',
                          help = 'Input Folderpath for Model Build & Architecture')
model_parser.add_argument('--script_folderpath', type = str, default = 'Training Scripts',
                          help = 'Input Folderpath for Training & Testing Script Functions')
model_parser.add_argument('--save_folderpath', type = str, default = 'Saved Models',
                          help = 'Output Folderpath for Saved & Saving Models')

# --------------------------------------------------------------------------------------------

# Addition of Training Requirement Arguments
model_parser.add_argument('--num_epochs', type = int, default = 50,         # Number of Epochs
                          help = "Number of Epochs in Training Mode")
model_parser.add_argument('--base_lr', type = float, default = 1e-4,        # Base Learning Rate
                          help = "Base Learning Rate Value in Training Mode")
model_parser.add_argument('--weight_decay', type = float, default = 1e-4,   # Weight Decay Value
                          help = "Weight Decay Value in Training Mode")
model_parser.add_argument('--lr_decay', type = float, default = 0.9,        # Learning Rate Decay Ratio
                          help = "Learning Rate Decay Value in Training Mode")
        
# --------------------------------------------------------------------------------------------

# Addition of Model Architecture Arguments
model_parser.add_argument(              # Dataset Number of Labels (taken from the Dataset Settings)
        '--num_labels', type = int,
        default = data_settings.num_labels,
        help = "MUDI Dataset Number of Labels")
model_parser.add_argument(              # Latent Space Dimensionality
        '--latent_dim', type = int,
        default = 128,
        help = "Latent Space Dimensionality Value")
model_parser.add_argument(              # Convolutional Layer Expansion Value
        '--expansion', type = int,
        default = 1,
        help = "Convolutional Layer Expansion Value")
model_parser.add_argument(              # Kullback-Leibler Loss Weight
        '--kl_alpha', type = float,     # (float: a loss weight must be able to take fractional values)
        default = 1.0,
        help = "Kullback-Leibler Loss Weight")

# Number of Channels: 1 for the 2D Dataset, otherwise one Channel per Selected Slice
num_channel = 1 if data_settings.dim == 2 else data_settings.num_slices
model_parser.add_argument(              # Number of Channels in given 3D Image Batch
        '--num_channel', type = int,
        default = num_channel,
        help = "Number of Channels in given 3D Image Batch")
model_parser.add_argument(              # Image Side Length (No Support for non-Square Images)
        '--img_shape', type = int,      # (last entry of the dataset's image shape, as a plain Python int)
        default = int(data_settings.img_shape[-1]),
        help = "Image Side Length")

# Final Model Settings (all defaults) & Computation Device Selection
model_settings = model_parser.parse_args([])
model_settings.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

---

# *Model* **Building**

---

In [61]:
##############################################################################################
# --------------------------------------- Encoder Build --------------------------------------
##############################################################################################

# Main Encoder Block Construction Class
class EncoderBlock(nn.Module):
    """Residual Downsampling Block used by the Encoder.

    The main branch is Conv3x3 -> BatchNorm -> ReLU -> Conv3x3 -> BatchNorm, where only the
    first convolution is strided. The block outputs `in_channel * stride` channels. When
    `stride == 1` the shortcut is the identity; otherwise it is a strided Conv1x1 + BatchNorm
    that matches the main branch's channel count.

    Args:
        in_channel: Number of Encoder Block's Convolutional Input Channels.
        stride: Stride of the first convolution (also the channel multiplier).
        padding: Padding applied by both 3x3 convolutions.
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        in_channel: int,
        stride: int = 1,
        padding: int = 1
    ):
        super().__init__()
        width = in_channel * stride

        # Main Branch (built as a list first, then registered as `self.block`)
        main_branch = [
            nn.Conv2d(in_channel, width, kernel_size = 3, stride = stride,
                      padding = padding, bias = False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.Conv2d(width, width, kernel_size = 3, stride = 1,
                      padding = padding, bias = False),
            nn.BatchNorm2d(width)]
        self.block = nn.Sequential(*main_branch)

        # Shortcut Branch (empty Sequential == identity when nothing is downsampled)
        skip_branch = []
        if stride != 1:
            skip_branch = [
                nn.Conv2d(in_channel, width, kernel_size = 1, stride = stride, bias = False),
                nn.BatchNorm2d(width)]
        self.shortcut = nn.Sequential(*skip_branch)

    # Main Block Application Function
    def forward(self, X):
        """Return ReLU(main_branch(X) + shortcut(X))."""
        return F.relu(self.block(X) + self.shortcut(X))

# --------------------------------------------------------------------------------------------

# Encoder Model Class
class Encoder(nn.Module):
    """ResNet-style Convolutional Encoder.

    A strided 3x3 stem convolution is followed by four stages of `EncoderBlock`s (the last
    three stages start with a stride-2 block, which doubles the channel count), global average
    pooling and a linear layer with `2 * latent_dim` outputs. The linear output is split in two
    halves which are returned as `(mu, var)`; `All4One.reparam` consumes the second half as a
    log-variance.

    Args:
        in_channel: Width of the stem / first stage; later stages use 2x, 4x and 8x this value
            (64 reproduces the 64-128-256-512 layout).
        num_channel: Number of Channels in each Image (Default: 1 for 2D Dataset).
        latent_dim: Latent Space Dimensionality.
        expansion: Stored on the instance but not used by the architecture.
        num_blocks: Number of Blocks in each of the four stages (must have length 4).
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        in_channel: int = 64,               # Number of Input Channels in ResNet Main Block Intermediate Layers' Blocks
        num_channel: int = 1,               # Number of Channels in each Image (Default: 1 for 2D Dataset)
        latent_dim: int = 64,               # Latent Space Dimensionality
        expansion: int = 1,                 # Expansion Factor for Stride Value in ResNet Main Block Intermediate Layers
        num_blocks: list = [2, 2, 2, 2]     # Number of Blocks in ResNet Main Block Intermediate Layers (never mutated)
    ):

        # Class Variable Logging
        super(Encoder, self).__init__()
        assert(len(num_blocks) == 4), "Number of Blocks provided Not Supported!"
        self.in_channel = in_channel; self.num_channel = num_channel
        self.latent_dim = latent_dim; self.expansion = expansion

        # Every width is derived from `in_channel` (previously hard-coded to 64 ... 512,
        # which only matched the blocks when in_channel == 64)
        base = in_channel

        # Encoder Downsampling Architecture Definition
        # (main_layer advances self.in_channel, so the stages must be created in this order)
        self.net = nn.Sequential(
            nn.Conv2d(      self.num_channel, base, kernel_size = 3,
                            stride = 2, padding = 1, bias = False),
            nn.BatchNorm2d( base),
            nn.ReLU(),
            self.main_layer(out_channel = base, num_blocks = num_blocks[0], stride = 1),
            self.main_layer(out_channel = 2 * base, num_blocks = num_blocks[1], stride = 2),
            self.main_layer(out_channel = 4 * base, num_blocks = num_blocks[2], stride = 2),
            self.main_layer(out_channel = 8 * base, num_blocks = num_blocks[3], stride = 2))
        self.linear = nn.Linear(8 * base, 2 * self.latent_dim)

    # Encoder Repeatable Layer Definition Function
    def main_layer(
        self,
        out_channel: int,
        num_blocks: int,
        stride: int = 2
    ):
        """Build one stage: a block with the given stride followed by stride-1 blocks.

        Each block is created with the current `self.in_channel`, which is then set to
        `out_channel`. The block itself outputs `in_channel * stride` channels, so
        `out_channel` must equal that product for consecutive blocks to fit together.
        """

        # Layer Architecture Creation
        stride = [stride] + [1] * (num_blocks - 1); layer = []
        for s in stride:
            layer.append(EncoderBlock(self.in_channel, stride = s))
            self.in_channel = out_channel
        return nn.Sequential(*layer)

    # Encoder Application Function
    def forward(
        self,
        X: Union[np.ndarray, torch.Tensor]  # Image Batch Input, (batch, num_channel, height, width)
    ):
        """Encode an image batch into `(mu, var)`, each of shape (batch, latent_dim)."""

        # Input Conversion to the Model's own dtype & device
        # (a no-op returning X itself when it already is a matching tensor)
        weight = self.linear.weight
        X = torch.as_tensor(X, dtype = weight.dtype, device = weight.device)

        # Forward Propagation in Encoder Architecture
        out = self.net(X)
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        mu = out[:, :self.latent_dim]
        var = out[:, self.latent_dim:]
        return mu, var

In [104]:
##############################################################################################
# --------------------------------------- Decoder Build --------------------------------------
##############################################################################################

# 2D Resizing Convolution
class ResizeConv2d(nn.Module):
    """Interpolation-based Resizing followed by a 2D Convolution.

    The input is first rescaled with `F.interpolate(scale_factor, mode)` and then passed
    through a stride-1 convolution. The convolution padding is fixed to 1, so the spatial
    size is only preserved by the convolution for `kernel_size = 3`.

    Args:
        in_channel: Number of Input Channels of the convolution.
        out_channel: Number of Output Channels of the convolution.
        kernel_size: Convolution Kernel Size.
        scale_factor: Spatial Rescaling Factor handed to `F.interpolate`.
        mode: Interpolation Mode handed to `F.interpolate`.
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int,
        scale_factor: int,
        mode: str ='nearest'
    ):

        # 2D Resizing Convolution
        super(ResizeConv2d, self).__init__()
        self.scale_factor = scale_factor; self.mode = mode
        self.conv = nn.Conv2d(  in_channel, out_channel,
                                kernel_size, stride = 1, padding = 1)

    # Resizing Convolutional Block Application Function
    def forward(
        self,
        X: Union[np.ndarray, torch.Tensor]  # Image Batch Input, (batch, in_channel, height, width)
    ):
        """Rescale `X` and apply the convolution."""

        # Input Conversion to the Convolution's own dtype & device
        # (a no-op returning X itself when it already is a matching tensor; unlike the legacy
        #  torch.Tensor(...) constructor it also converts float64 tensors)
        weight = self.conv.weight
        X = torch.as_tensor(X, dtype = weight.dtype, device = weight.device)

        # Resizing Convolutional Block Architecture Walkthrough
        out = F.interpolate(X, scale_factor = self.scale_factor,
                            mode = self.mode)
        return self.conv(out)

# --------------------------------------------------------------------------------------------

# Main Decoder Block Construction Class
class DecoderBlock(nn.Module):
    """Residual Upsampling Block used by the Decoder.

    The main branch is Conv3x3 -> BatchNorm -> ReLU (`block1`, channel-preserving) followed by
    `block2`, which outputs `in_channel // stride` channels. For `stride == 1`, `block2` is a
    plain Conv3x3 + BatchNorm and the shortcut is the identity. Otherwise `block2` is a
    `ResizeConv2d` (upscaling by `stride`) + BatchNorm, and the shortcut is a SEPARATE
    `ResizeConv2d` + BatchNorm of the same shape.

    Args:
        in_channel: Number of Decoder Block's Convolutional Input Channels.
        stride: Upscaling Factor (also the channel divisor).
        padding: Padding applied by the plain 3x3 convolutions.
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        in_channel: int,        # Number of Decoder Block's Convolutional Input Channels
        stride: int = 1,
        padding: int = 1
    ):

        # Main Block's Upsampling Architecture Definition
        super(DecoderBlock, self).__init__()
        out_channel = int(in_channel / stride)
        self.block1 = nn.Sequential(
            nn.Conv2d(      in_channel, in_channel, kernel_size = 3,
                            stride = 1, padding = padding, bias = False),
            nn.BatchNorm2d( in_channel),
            nn.ReLU())

        # Main Block's Shortcut Architecture Definition
        if stride == 1:
            self.block2 = nn.Sequential(
                nn.Conv2d(      in_channel, out_channel, kernel_size = 3,
                                stride = 1, padding = padding, bias = False),
                nn.BatchNorm2d( out_channel))
            self.shortcut = nn.Sequential()
        else:
            self.block2 = nn.Sequential(
                ResizeConv2d(   in_channel, out_channel,
                                kernel_size = 3, scale_factor = stride),
                nn.BatchNorm2d( out_channel))

            # The shortcut gets its own modules. Aliasing it to `self.block2` made both
            # branches share one convolution and one set of BatchNorm running statistics.
            # The state-dict keys (`block2.*`, `shortcut.*`) are unchanged by this.
            self.shortcut = nn.Sequential(
                ResizeConv2d(   in_channel, out_channel,
                                kernel_size = 3, scale_factor = stride),
                nn.BatchNorm2d( out_channel))

    # Main Block Application Function
    def forward(
        self,
        X: torch.Tensor         # Feature Map Input, (batch, in_channel, height, width)
    ):
        """Return ReLU(block2(block1(X)) + shortcut(X))."""

        # Main Block Architecture Walkthrough
        out = self.block1(X)
        out = self.block2(out)
        out = out + self.shortcut(X)
        return F.relu(out)

# --------------------------------------------------------------------------------------------

# Decoder Model Class
class Decoder(nn.Module):
    """ResNet-style Convolutional Decoder.

    The latent vector is projected to `img_shape * 2 ** len(num_blocks)` channels, reshaped to
    a 1x1 feature map and upscaled by `2 ** (len(num_blocks) - 1)`. Three stages ending in a
    stride-2 `DecoderBlock` each halve the channels and double the resolution, followed by one
    stride-1 stage, a Sigmoid and a final `ResizeConv2d` that rescales by `img_shape / 64` and
    maps `img_shape * 2` channels to `num_channel`.

    Args:
        num_channel: Number of Channels in each Image (Default: 1 for 2D Dataset).
        img_shape: Square Image Side Length.
        latent_dim: Latent Space Dimensionality.
        expansion: Accepted for signature symmetry with the Encoder; not used here.
        num_blocks: Number of Blocks in each of the four stages (must have length 4).
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        num_channel: int = 1,               # Number of Channels in each Image (Default: 1 for 2D Dataset)
        img_shape: int = 96,                # Square Image Side Length (1/4th of pre-Convolution No. Channels)
        latent_dim: int = 64,               # Latent Space Dimensionality
        expansion: int = 1,                 # Expansion Factor for Stride Value in ResNet Main Block Intermediate Layers
        num_blocks: list = [2, 2, 2, 2]     # Number of Blocks in ResNet Main Block Intermediate Layers (never mutated)
    ):

        # Class Variable Logging
        super(Decoder, self).__init__()
        assert(len(num_blocks) == 4), "Number of Blocks provided Not Supported!"
        self.num_blocks = num_blocks
        self.in_channel = img_shape * (2 ** (len(self.num_blocks)))
        self.channel = self.in_channel; self.num_channel = num_channel
        self.img_shape = img_shape; self.latent_dim = latent_dim

        # Decoder Upsampling Architecture Definition
        # (main_layer updates self.channel, and the arguments below are evaluated in order,
        #  so each `self.channel / 2` refers to the width left by the previous stage)
        # NOTE(review): the Sigmoid sits BEFORE the final convolution, so the output is not
        # bounded to [0, 1]; left as is because reordering would rename the state-dict keys.
        self.linear = nn.Linear(self.latent_dim, self.in_channel)
        self.net = nn.Sequential(
            self.main_layer(out_channel = int(self.channel / 2), num_blocks = num_blocks[3], stride = 2),
            self.main_layer(out_channel = int(self.channel / 2), num_blocks = num_blocks[2], stride = 2),
            self.main_layer(out_channel = int(self.channel / 2), num_blocks = num_blocks[1], stride = 2),
            self.main_layer(out_channel = int(self.channel), num_blocks = num_blocks[0], stride = 1),
            nn.Sigmoid(),
            ResizeConv2d(   self.img_shape * 2, self.num_channel,
                            kernel_size = 3, scale_factor = img_shape / 64))

    # Decoder Repeatable Layer Definition Function
    def main_layer(
        self,
        out_channel: int,
        num_blocks: int,
        stride: int = 2
    ):
        """Build one stage: stride-1 blocks followed by a final block with the given stride.

        All blocks are created with the current `self.channel`; afterwards `self.channel`
        is set to `out_channel` for the next stage.
        """

        # Layer Architecture Creation
        stride = [stride] + [1] * (num_blocks - 1); layer = []
        for s in reversed(stride):
            layer.append(DecoderBlock(self.channel, stride = s))
        self.channel = out_channel
        return nn.Sequential(*layer)

    # Decoder Application Function
    def forward(
        self,
        z: Union[np.ndarray, torch.Tensor]  # Latent Representation Input, (batch, latent_dim)
    ):
        """Decode a latent batch into images of shape (batch, num_channel, img_shape, img_shape)."""

        # Input Conversion to the Model's own dtype & device
        # (a no-op returning z itself when it already is a matching tensor)
        weight = self.linear.weight
        z = torch.as_tensor(z, dtype = weight.dtype, device = weight.device)

        # Forward Propagation in Decoder Architecture
        out = self.linear(z)
        out = out.view(z.size(0), self.in_channel, 1, 1)
        out = F.interpolate(out, scale_factor = 2 ** (len(self.num_blocks) - 1))
        out = self.net(out)
        out = out.view( z.size(0), self.num_channel,
                        self.img_shape, self.img_shape)
        return out


In [5]:
##############################################################################################
# ------------------------------------- All4One VAE Build ------------------------------------
##############################################################################################

# VAE Model Class
class All4One(nn.Module):
    """Variational Autoencoder pairing the `Encoder` and `Decoder` defined in this notebook.

    Args:
        latent_dim: Latent Space Dimensionality.
        num_channel: Number of Channels in each Image (Default: 1 for 2D Dataset).
        img_shape: Square Image Side Length.
        in_channel: Encoder's `in_channel` argument.
        expansion: Forwarded to both Encoder and Decoder.
        num_blocks: Number of Blocks per stage, forwarded to both Encoder and Decoder.
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        latent_dim: int = 64,
        num_channel: int = 1,
        img_shape: int = 96,
        in_channel: int = 64,
        expansion: int = 1,
        num_blocks: list = [2, 2, 2, 2]
    ):
        super().__init__()

        # Encoder & Decoder Construction (encoder first, then decoder)
        self.encoder = Encoder(
            in_channel = in_channel, num_channel = num_channel,
            latent_dim = latent_dim, expansion = expansion, num_blocks = num_blocks)
        self.decoder = Decoder(
            num_channel = num_channel, img_shape = img_shape,
            latent_dim = latent_dim, expansion = expansion, num_blocks = num_blocks)

    # Latent Space Reparametrization
    @staticmethod
    def reparam(mean, logvar):
        """Sample `mean + sigma * noise`, with `sigma = exp(logvar / 2)` and standard normal noise."""
        sigma = torch.exp(logvar / 2)
        noise = torch.randn_like(sigma)
        return noise * sigma + mean

    # All4One VAE Application Function
    def forward(
        self,
        X: np.ndarray or torch.Tensor       # Image Batch Input
    ):
        """Return `(mu, var, z, X_fake)`: encoder outputs, latent sample and reconstruction."""

        # Encode -> Sample -> Decode
        mu, var = self.encoder(X)
        z = self.reparam(mu, var)
        return mu, var, z, self.decoder(z)

In [105]:
# Encoder Initialization Example
# (smoke test: a random batch of 5 single-channel 96x96 images through a default Encoder)
enc = Encoder()
#summary(enc, (1, 96, 96))
X = torch.rand(5, 1, 96, 96)
mu, var = enc(X)        # Two (batch, latent_dim) Halves of the Encoder's Linear Output

def reparameterize(mean, logvar):
    """Draw `mean + sigma * noise` with standard normal noise (same formula as `All4One.reparam`)."""
    sigma = torch.exp(logvar / 2)       # Halving the log-variance == square root of the variance
    noise = torch.randn_like(sigma)
    return noise * sigma + mean

# Decoder Initialization Example
# (smoke test: the latent sample drawn from the encoder outputs goes through a default Decoder)
z = reparameterize(mu, var)
dec = Decoder()
X_fake = dec(z)
print(X_fake.shape)     # Decoder reshapes its output to (batch, num_channel, img_shape, img_shape)
#summary(dec, (5, 128))

torch.Size([5, 1536, 8, 8])
torch.Size([5, 1, 96, 96])
torch.Size([5, 1, 96, 96])


---

# **Running** *Scripts*

---

In [127]:
# VAE Model Training, Validation & Testing Script Class
class LitAll4One(pl.LightningModule):

    ##############################################################################################
    # ----------------------------------- Model & Dataset Setup ----------------------------------
    ##############################################################################################

    # Constructor / Initialization Function
    def __init__(
        self,
        settings: argparse.Namespace,                   # Model Settings & Parametrizations (parsed arguments)
    ):
        """Build the All4One model, its optimizer and scheduler, optionally resuming a checkpoint.

        `settings` must provide: latent_dim, num_channel, img_shape, expansion, base_lr,
        weight_decay, lr_decay, save_folderpath, model_version and device.

        A checkpoint is restored when `model_version != 0` and the file
        `<save_folderpath>/V<model_version>/All4One (V<model_version>).pth` exists. It is
        expected to hold the keys 'ModelSD', 'OptimizerSD', 'Training Epochs' and 'RNG State'.
        """

        # Class Variable Logging
        super().__init__()
        self.settings = settings
        self.lr_decay_epochs = [80, 140]                # Epochs for Learning Rate Decay

        # Model Initialization
        # (the model is moved to its device BEFORE the optimizer is created and restored:
        #  optimizer.load_state_dict places the state on the parameters' current device, so
        #  moving the model afterwards would strand the restored Adam state on the CPU)
        self.model = All4One(               latent_dim = settings.latent_dim,
                                            num_channel = settings.num_channel,
                                            img_shape = settings.img_shape,
                                            expansion = settings.expansion).to(self.settings.device)
        self.optimizer = torch.optim.Adam(  self.model.parameters(),
                                            lr = self.settings.base_lr,
                                            weight_decay = self.settings.weight_decay, )
        self.recon_criterion = nn.MSELoss(); self.past_epochs = 0

        # Existing Model Checkpoint Loading
        self.model_filepath = Path(f"{self.settings.save_folderpath}/V{self.settings.model_version}/All4One (V{self.settings.model_version}).pth")
        if self.settings.model_version != 0 and self.model_filepath.exists():

            # Checkpoint Fixing (due to the use of nn.DataParallel)
            # (loaded onto the CPU so that it can be read on machines without the saving
            #  device and so that the RNG state stays a CPU tensor; load_state_dict then
            #  copies the values onto the model's / optimizer's device)
            print(f"DOWNLOADING All4One 2D VAE (Version {self.settings.model_version})")
            checkpoint = torch.load(self.model_filepath, map_location = 'cpu')
            self.checkpoint_fix = dict()
            for sd, sd_value in checkpoint.items():
                if sd == 'ModelSD' or sd == 'OptimizerSD':

                    # Removal of the 'module.' prefix that nn.DataParallel adds to every key
                    self.checkpoint_fix[sd] = OrderedDict(
                        (key[7:] if key[0:7] == 'module.' else key, value)
                        for key, value in sd_value.items())
                else: self.checkpoint_fix[sd] = sd_value

            # Application of Checkpoint's State Dictionary
            self.model.load_state_dict(self.checkpoint_fix['ModelSD'])
            self.optimizer.load_state_dict(self.checkpoint_fix['OptimizerSD'])
            self.past_epochs = self.checkpoint_fix['Training Epochs']
            torch.set_rng_state(self.checkpoint_fix['RNG State'])
            del checkpoint

        # Learning Rate Decay Schedule & Multi-Device Wrapping
        self.lr_schedule = torch.optim.lr_scheduler.ExponentialLR(  self.optimizer,
                                                    gamma = self.settings.lr_decay)
        self.model = nn.DataParallel(self.model)
        
    # Optimizer Initialization Function
    def configure_optimizers(self):
        """Defer to the LightningModule base implementation.

        The Adam optimizer created in `__init__` is not returned here; `training_step`
        zeroes, back-propagates and steps it directly.
        """
        return super().configure_optimizers()

    # Forward Functionality
    def forward(self, X):
        """Run `X` through the wrapped All4One model and return its output tuple unchanged."""
        return self.model(X)

    # --------------------------------------------------------------------------------------------
    
    # Train Set DataLoader Download
    def train_dataloader(self):
        """Load the training DataLoader and record its number of batches in `self.train_batches`."""
        # NOTE(review): the ('Test', 'Train') split is requested here; the arguments that
        # were commented out in its favour were set_ = 'Train', mode_ = 'Train'
        loader = v3DMUDI.loader(
            Path(f"{self.settings.data_folderpath}"),
            dim = 2,
            version = self.settings.data_version,
            set_ = 'Test',
            mode_ = 'Train')
        self.train_batches = len(loader)
        return loader
    
    # Validation Set DataLoader Download
    def val_dataloader(self):
        """Load the validation DataLoader and record its number of batches in `self.val_batches`."""
        # NOTE(review): this requests the same ('Test', 'Train') split as train_dataloader;
        # the arguments that were commented out in its favour were set_ = 'Train', mode_ = 'Val'
        loader = v3DMUDI.loader(
            Path(f"{self.settings.data_folderpath}"),
            dim = 2,
            version = self.settings.data_version,
            set_ = 'Test',
            mode_ = 'Train')
        self.val_batches = len(loader)
        return loader

    # Test Set DataLoader Download
    def test_dataloader(self):
        """Load the ('Test', 'Val') DataLoader and record its number of batches in `self.test_batches`."""
        loader = v3DMUDI.loader(
            Path(f"{self.settings.data_folderpath}"),
            dim = 2,
            version = self.settings.data_version,
            set_ = 'Test',
            mode_ = 'Val')
        self.test_batches = len(loader)
        return loader

    ##############################################################################################
    # ------------------------------------- Training Script --------------------------------------
    ##############################################################################################

    # Functionality called upon the Start of Training
    def on_train_start(self):
        """Switch the model to training mode, disable automatic optimization and open the logger."""

        # Model Training Mode Setup
        # (automatic optimization is turned off because training_step drives the optimizer itself)
        self.model.train()
        self.automatic_optimization = False

        # TensorBoard Logger Initialization
        log_folderpath = f'{self.settings.save_folderpath}/V{self.settings.model_version}'
        self.train_logger = TensorBoardLogger(log_folderpath, 'Training Performance')

    # Functionality called upon the Start of Training Epoch
    def on_train_epoch_start(self):
        """Reset the running total, KL and reconstruction loss accumulators for the new epoch."""
        self.train_loss = self.train_kl_loss = self.train_recon_loss = 0

    # --------------------------------------------------------------------------------------------

    # Training Step / Batch Loop 
    def training_step(self, batch, batch_idx):
        """Run one manually optimized training step on a `(images, labels)` batch.

        The loss is `recon_loss + kl_alpha * kl_loss`: `recon_loss` is the MSE between the
        reconstruction and the input, and `kl_loss` is the KL term computed from the encoder
        outputs (the second output is treated as a log-variance), summed over the latent
        dimension and averaged over the batch.

        Returns:
            Dictionary with the 'loss', 'kl_loss' and 'recon_loss' tensors.
        """

        # Data Handling (labels are unpacked but not used by the loss)
        X_batch, ygt_batch = batch
        X_batch = X_batch.type(torch.float).to(self.settings.device)
        #ygt_batch = ygt_batch.type(torch.float).to(self.settings.device)

        # Forward Propagation & Loss Computation
        mu_batch, var_batch, z_batch, X_fake_batch = self.model(X_batch)
        kl_loss =  (-0.5 * (1 + var_batch - mu_batch ** 2 - torch.exp(var_batch)).sum(dim = 1)).mean(dim = 0)
        recon_loss = self.recon_criterion(X_fake_batch, X_batch)

        # kl_alpha is declared as the "Kullback-Leibler Loss Weight", so it scales the KL term
        # (it previously multiplied the reconstruction loss; identical for kl_alpha == 1)
        loss = recon_loss + (kl_loss * self.settings.kl_alpha)

        # Backwards Propagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        del X_batch, ygt_batch, mu_batch, var_batch, z_batch, X_fake_batch
        return {'loss': loss, 'kl_loss': kl_loss, 'recon_loss': recon_loss}

    # Functionality called upon the End of a Batch Training Step
    def on_train_batch_end(self, loss, batch, batch_idx):
        """Accumulate the batch losses and keep the last sample of the epoch's final batch.

        Args:
            loss: Dictionary returned by `training_step` ('loss', 'kl_loss', 'recon_loss').
            batch: Pair `(X, y)` that was just processed.
            batch_idx: Index of the batch within the epoch.
        """

        # Loss Values Update
        self.train_loss += loss['loss'].item()
        self.train_kl_loss += loss['kl_loss'].item()
        self.train_recon_loss += loss['recon_loss'].item()

        # Only the Last Batch of the Epoch provides the Example Image
        if batch_idx != self.train_batches - 1:
            return

        # Last Batch's Example Original Image Saving
        images, labels = batch
        example_shape = (   1, self.settings.num_channel,
                            self.settings.img_shape, self.settings.img_shape)
        self.y_example = labels[-1, :]
        self.X_example = images[-1, :, :, :].view(*example_shape)

    # --------------------------------------------------------------------------------------------

    # Example Original vs Reconstructed Example Image Plotting Function
    def img_plot(self, num_epochs: int = 0):
        """Plot the stored example image above its reconstruction.

        The example saved by `on_train_batch_end` is cast to float and moved to
        `self.settings.device` (mirroring `training_step`) before being passed
        through the model. The forward pass runs without gradient tracking, and
        both images are copied to the CPU before conversion to NumPy so that the
        plot also works when the tensors live on a GPU.

        Args:
            num_epochs: Epoch number used as the figure number and in the title.

        Returns:
            The matplotlib figure containing both subplots.

        Side effects:
            `self.X_example` and `self.X_fake_example` are replaced by their
            `(img_shape, img_shape)` views.
        """
        img_size = self.settings.img_shape

        # Original vs Reconstructed Image Sampling (no autograd graph is needed for plotting)
        model_input = self.X_example.type(torch.float).to(self.settings.device)
        with torch.no_grad():
            _, _, _, self.X_fake_example = self.model(model_input)
        self.X_example = self.X_example.view(img_size, img_size)
        self.X_fake_example = self.X_fake_example.view(img_size, img_size)

        # Original Example Image Subplot
        figure = plt.figure(num_epochs, figsize = (60, 60))
        plt.tight_layout(); plt.title(f'Epoch #{num_epochs}')
        plt.subplot(2, 1, 1, title = 'Original')
        plt.xticks([]); plt.yticks([]); plt.grid(False); plt.tight_layout()
        plt.imshow(self.X_example.detach().cpu().numpy(), cmap = plt.cm.binary)

        # Reconstructed Example Image Subplot
        plt.subplot(2, 1, 2, title = 'Reconstruction')
        plt.xticks([]); plt.yticks([]); plt.grid(False)
        plt.imshow(self.X_fake_example.detach().cpu().numpy(), cmap = plt.cm.binary)
        return figure

    # Functionality called upon the End of a Training Epoch
    def on_train_epoch_end(self):
        """Decay the learning rate, log the epoch-averaged losses and save a checkpoint.

        Steps performed:
            1. Step `self.lr_schedule` when the (1-based) current epoch is listed in
               `self.lr_decay_epochs`.
            2. Divide the accumulated losses by `self.train_batches`.
            3. Log the model graph (first epoch only), the three loss scalars and the
               original-vs-reconstruction figure to TensorBoard.
            4. Save model/optimizer state, the number of completed epochs and the
               RNG state to `self.model_filepath`.
        """

        # Learning Rate Decay
        if (self.trainer.current_epoch + 1) in self.lr_decay_epochs:
            self.lr_schedule.step()

        # Loss Value Updating (Batch Division)
        num_epochs = self.past_epochs + self.current_epoch
        self.train_loss = self.train_loss / self.train_batches
        self.train_kl_loss = self.train_kl_loss / self.train_batches
        self.train_recon_loss = self.train_recon_loss / self.train_batches

        # TensorBoard Logger Model Visualizer, Update for Scalar Values & Example Image Plotting
        if num_epochs == 0:
            # The example comes straight from the batch, so it gets the same float cast as in training_step
            self.train_logger.experiment.add_graph(self.model, self.X_example.type(torch.float))
        self.train_logger.experiment.add_scalar("Training Loss", self.train_loss, num_epochs)
        self.train_logger.experiment.add_scalar("Kullback Leibler Divergence", self.train_kl_loss, num_epochs)
        self.train_logger.experiment.add_scalar("Image Reconstruction Loss", self.train_recon_loss, num_epochs)
        plot = self.img_plot(num_epochs)
        self.train_logger.experiment.add_figure("Original vs Reconstruction", plot, num_epochs)

        # Model Checkpoint Saving
        # 'Training Epochs' stores the number of COMPLETED epochs (index + 1): it is read back
        # as `past_epochs`, so a resumed run continues at the next step instead of repeating
        # the last logged one.
        torch.save({'ModelSD': self.model.state_dict(),
                    'OptimizerSD': self.optimizer.state_dict(),
                    'Training Epochs': num_epochs + 1,
                    'RNG State': torch.get_rng_state()},
                    self.model_filepath)

    ##############################################################################################
    # ------------------------------------ Validation Script -------------------------------------
    ##############################################################################################


---

# **Linear** *VAE*

---

In [6]:
#
class Encoder(nn.Module):
    """Fully connected encoder producing the mean and log-variance halves of its output.

    Args:
        latent_dim: Width of each of the two returned halves.
        img_shape: Side length of the square input image; inputs are flattened
            to `img_shape * img_shape` features.
    """

    # Constructor / Initialization Function
    def __init__(self, latent_dim: int = 128, img_shape: int = 96):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = img_shape

        # Linear Stack: Flattened Image -> 512 -> 256 -> 2 * Latent Dimension
        num_pixels = img_shape * img_shape
        self.fc1 = nn.Linear(num_pixels, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, latent_dim * 2)

    # Encoder Application Function
    def forward(self, x):
        """Flatten `x`, apply the linear stack and split the result into (mean, logvar)."""
        flat = x.view(-1, self.img_shape * self.img_shape)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        stats = self.fc3(hidden)

        # First Half of the Columns is the Mean, Second Half the Log-Variance
        mean = stats[:, :self.latent_dim]
        logvar = stats[:, self.latent_dim:]
        return mean, logvar

#
class Decoder(nn.Module):
    """Fully connected decoder mapping latent vectors to single-channel square images.

    Args:
        latent_dim: Number of input features per latent vector.
        img_shape: Side length of the square output image.
    """

    # Constructor / Initialization Function
    def __init__(self, latent_dim: int = 128, img_shape: int = 96):
        super().__init__()
        self.img_shape = img_shape

        # Linear Stack: Latent Dimension -> 256 -> 512 -> Flattened Image
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, img_shape * img_shape)

    # Decoder Application Function
    def forward(self, x):
        """Decode `x` into sigmoid-activated images of shape (-1, 1, img_shape, img_shape)."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        pixels = torch.sigmoid(self.fc3(hidden))
        side = self.img_shape
        return pixels.view(-1, 1, side, side)

#
class All4One(nn.Module):
    """Linear variational autoencoder built from `Encoder` and `Decoder`.

    Args:
        latent_dim: Size of the latent vector sampled between encoder and decoder.
        img_shape: Side length of the square images handled by both sub-networks.
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        latent_dim: int = 128,
        img_shape: int = 96
    ):

        # Encoder & Decoder Initialization
        super(All4One, self).__init__()
        self.encoder = Encoder(latent_dim, img_shape)
        self.decoder = Decoder(latent_dim, img_shape)


    # Latent Space Reparametrization
    @staticmethod
    def reparam(mean, logvar):
        """Sample `mean + epsilon * exp(logvar / 2)` with `epsilon` drawn from a standard normal."""
        std = torch.exp(logvar / 2)
        epsilon = torch.randn_like(std)
        return epsilon * std + mean

    # All4One VAE Application Function
    def forward(
        self,
        X: torch.Tensor                     # Image Batch (the encoder flattens it with Tensor.view)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode `X`, sample a latent vector and decode it.

        Returns:
            Tuple `(mu, logvar, z, X_fake)`: the encoder's mean and log-variance,
            the sampled latent vector and the decoder's reconstruction.
        """

        # Forward Propagation in VAE Architecture
        mu, logvar = self.encoder(X)
        z = self.reparam(mu, logvar)
        X_fake = self.decoder(z)
        return mu, logvar, z, X_fake


In [None]:
# Quick Sanity Check: Random Batch through the VAE, Input vs Reconstruction Plot
# NOTE(review): `VAE` is not defined in the visible part of this notebook (this section
# defines `All4One`) - confirm it exists in an earlier cell before running this one.
vae = VAE()
X = torch.rand(5, 1, 96, 96)
mu, var, z, X_fake = vae(X)

# Input Image Subplot
plt.subplot(2, 1, 1)
plt.imshow(X[0, 0, :, :])

# Reconstructed Image Subplot
plt.subplot(2, 1, 2)
plt.imshow(X_fake[0, 0, :, :].detach().numpy())

In [7]:
# VAE Model Training, Validation & Testing Script Class
class LitAll4One(pl.LightningModule):
    """Lightning wrapper that trains the All4One VAE with a manually driven Adam optimizer.

    The module builds the model, optimizer and learning-rate schedule from `settings`,
    optionally resumes from a previously saved checkpoint, and after every training
    epoch logs the averaged losses plus an original-vs-reconstruction figure to
    TensorBoard before saving a new checkpoint.

    Attributes read from `settings`: `latent_dim`, `img_shape`, `num_channel`, `base_lr`,
    `weight_decay`, `lr_decay`, `kl_alpha`, `device`, `save_folderpath`, `model_version`,
    `data_folderpath` and `data_version`.
    """

    ##############################################################################################
    # ----------------------------------- Model & Dataset Setup ----------------------------------
    ##############################################################################################

    # Constructor / Initialization Function
    def __init__(
        self,
        settings: argparse.ArgumentParser,              # Model Settings & Parametrizations
    ):
        """Create model, optimizer and scheduler, restoring a checkpoint when one exists.

        Args:
            settings: Object exposing the attributes listed in the class docstring.
        """

        # Class Variable Logging
        super().__init__()
        self.settings = settings
        self.lr_decay_epochs = [80, 140]                # Epochs for Learning Rate Decay

        # Manual optimization is declared at construction time, as Lightning expects,
        # because `training_step` drives `self.optimizer` itself.
        self.automatic_optimization = False

        # Model Initialization
        # The model is moved to its device BEFORE the optimizer is built and its state is
        # loaded: `Optimizer.load_state_dict` places the state on the parameters' current
        # device, so moving the model afterwards would leave Adam's buffers behind.
        self.model = All4One(               latent_dim = settings.latent_dim,
                                            img_shape = settings.img_shape).to(self.settings.device)
        self.optimizer = torch.optim.Adam(  self.model.parameters(),
                                            lr = self.settings.base_lr,
                                            weight_decay = self.settings.weight_decay, )
        self.recon_criterion = nn.MSELoss()
        self.past_epochs = 0

        # Existing Model Checkpoint Loading
        self.model_filepath = Path(f"{self.settings.save_folderpath}/V{self.settings.model_version}/All4One (V{self.settings.model_version}).pth")
        if self.settings.model_version != 0 and self.model_filepath.exists():

            # Checkpoint Fixing (due to the use of nn.DataParallel)
            # Everything is deserialized onto the CPU so that a checkpoint written on a GPU
            # machine can be opened anywhere and the RNG state stays a CPU tensor;
            # `load_state_dict` then copies the values onto the model's device.
            print(f"DOWNLOADING All4One 2D VAE (Version {self.settings.model_version})")
            checkpoint = torch.load(self.model_filepath, map_location = 'cpu')
            self.checkpoint_fix = dict()
            for sd, sd_value in checkpoint.items():
                if sd == 'ModelSD' or sd == 'OptimizerSD':
                    # Strip the 'module.' prefix that nn.DataParallel adds to parameter names
                    self.checkpoint_fix[sd] = OrderedDict(
                        (key[7:] if key.startswith('module.') else key, value)
                        for key, value in sd_value.items())
                else: self.checkpoint_fix[sd] = sd_value

            # Application of Checkpoint's State Dictionary
            self.model.load_state_dict(self.checkpoint_fix['ModelSD'])
            self.optimizer.load_state_dict(self.checkpoint_fix['OptimizerSD'])
            self.past_epochs = self.checkpoint_fix['Training Epochs']
            torch.set_rng_state(self.checkpoint_fix['RNG State'])
            del checkpoint
        self.lr_schedule = torch.optim.lr_scheduler.ExponentialLR(  self.optimizer,     # Learning Rate Decay
                                                    gamma = self.settings.lr_decay)     # in Chosen Epochs
        self.model = nn.DataParallel(self.model)

    # Optimizer Initialization Function
    # (defers to the parent implementation; the optimizer actually used is `self.optimizer`)
    def configure_optimizers(self): return super().configure_optimizers()

    # Forward Functionality
    def forward(self, X): return self.model(X)

    # --------------------------------------------------------------------------------------------

    # Train Set DataLoader Download
    def train_dataloader(self):
        """Return the 2D 'Train'/'Train' loader and store its length in `self.train_batches`."""
        TrainTrainLoader = v3DMUDI.loader(  Path(f"{self.settings.data_folderpath}"),
                                            dim = 2, version = self.settings.data_version,
                                            set_ = 'Train', mode_ = 'Train')
        self.train_batches = len(TrainTrainLoader)
        return TrainTrainLoader

    # Validation Set DataLoader Download
    def val_dataloader(self):
        """Return the 2D 'Train'/'Val' loader and store its length in `self.val_batches`."""
        TrainValLoader = v3DMUDI.loader(Path(f"{self.settings.data_folderpath}"),
                                        dim = 2, version = self.settings.data_version,
                                        set_ = 'Train', mode_ = 'Val')
        self.val_batches = len(TrainValLoader)
        return TrainValLoader

    # Test Set DataLoader Download
    def test_dataloader(self):
        """Return the 2D 'Test'/'Val' loader and store its length in `self.test_batches`."""
        TestValLoader = v3DMUDI.loader( Path(f"{self.settings.data_folderpath}"),
                                        dim = 2, version = self.settings.data_version,
                                        set_ = 'Test', mode_ = 'Val')
        self.test_batches = len(TestValLoader)
        return TestValLoader

    ##############################################################################################
    # ------------------------------------- Training Script --------------------------------------
    ##############################################################################################

    # Functionality called upon the Start of Training
    def on_train_start(self):
        """Put the model in training mode and create the TensorBoard logger."""

        # Model Training Mode Setup
        self.model.train()

        # TensorBoard Logger Initialization
        self.train_logger = TensorBoardLogger(f'{self.settings.save_folderpath}/V{self.settings.model_version}', 'Training Performance')

    # Functionality called upon the Start of Training Epoch
    def on_train_epoch_start(self):
        """Reset the running loss accumulators."""
        self.train_loss = 0
        self.train_kl_loss = 0
        self.train_recon_loss = 0

    # --------------------------------------------------------------------------------------------

    # Training Step / Batch Loop
    def training_step(self, batch, batch_idx):
        """Run one manually optimized training step.

        Args:
            batch: Pair `(X, y)`; only `X` is used, after a float cast and device move.
            batch_idx: Index of the batch within the epoch (not used here).

        Returns:
            Dictionary with the keys 'loss', 'kl_loss' and 'recon_loss'.
        """

        # Data Handling (labels are unpacked but not used by the loss)
        X_batch, ygt_batch = batch
        X_batch = X_batch.type(torch.float).to(self.settings.device)

        # Forward Propagation & Loss Computation
        # (the model's second output is the log-variance, hence the exp() in the KL term)
        mu_batch, logvar_batch, z_batch, X_fake_batch = self.model(X_batch)
        kl_loss =  (-0.5 * (1 + logvar_batch - mu_batch ** 2 - torch.exp(logvar_batch)).sum(dim = 1)).mean(dim = 0)
        recon_loss = self.recon_criterion(X_fake_batch, X_batch)
        loss = (recon_loss * self.settings.kl_alpha) + kl_loss

        # Backwards Propagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        del X_batch, ygt_batch, mu_batch, logvar_batch, z_batch, X_fake_batch
        return {'loss': loss, 'kl_loss': kl_loss, 'recon_loss': recon_loss}

    # Functionality called upon the End of a Batch Training Step
    def on_train_batch_end(self, loss, batch, batch_idx):
        """Accumulate the batch losses and keep the last sample of the epoch's final batch."""

        # Loss Values Update
        self.train_loss = self.train_loss + loss['loss'].item()
        self.train_kl_loss = self.train_kl_loss + loss['kl_loss'].item()
        self.train_recon_loss = self.train_recon_loss + loss['recon_loss'].item()

        # Last Batch's Example Original Image Saving
        if batch_idx == self.train_batches - 1:
            self.X_example, self.y_example = batch
            self.y_example = self.y_example[-1, :]
            self.X_example = self.X_example[-1, :, :, :]
            self.X_example = self.X_example.view(1, self.settings.num_channel,
                        self.settings.img_shape, self.settings.img_shape)

    # --------------------------------------------------------------------------------------------

    # Example Original vs Reconstructed Example Image Plotting Function
    def img_plot(self, num_epochs: int = 0):
        """Plot the stored example image above its reconstruction.

        The example is cast to float and moved to `self.settings.device` (mirroring
        `training_step`), the forward pass runs without gradient tracking, and both
        images are copied to the CPU before conversion to NumPy so that plotting
        also works when the tensors live on a GPU.

        Args:
            num_epochs: Epoch number used as the figure number and in the title.

        Returns:
            The matplotlib figure containing both subplots.
        """

        # Original vs Reconstructed Image Sampling
        model_input = self.X_example.type(torch.float).to(self.settings.device)
        with torch.no_grad():
            _, _, _, self.X_fake_example = self.model(model_input)
        self.X_example = self.X_example.view(self.settings.img_shape, self.settings.img_shape)
        self.X_fake_example = self.X_fake_example.view(self.settings.img_shape, self.settings.img_shape)

        # Original Example Image Subplot
        figure = plt.figure(num_epochs, figsize = (60, 60))
        plt.tight_layout(); plt.title(f'Epoch #{num_epochs}')
        plt.subplot(2, 1, 1, title = 'Original')
        plt.xticks([]); plt.yticks([]); plt.grid(False); plt.tight_layout()
        plt.imshow(self.X_example.detach().cpu().numpy(), cmap = plt.cm.binary)

        # Reconstructed Example Image Subplot
        plt.subplot(2, 1, 2, title = 'Reconstruction')
        plt.xticks([]); plt.yticks([]); plt.grid(False)
        plt.imshow(self.X_fake_example.detach().cpu().numpy(), cmap = plt.cm.binary)
        return figure

    # Functionality called upon the End of a Training Epoch
    def on_train_epoch_end(self):
        """Decay the learning rate, log the epoch-averaged losses and save a checkpoint."""

        # Learning Rate Decay
        if (self.current_epoch + 1) in self.lr_decay_epochs:
            self.lr_schedule.step()

        # Loss Value Updating (Batch Division)
        num_epochs = self.past_epochs + self.current_epoch
        self.train_loss = self.train_loss / self.train_batches
        self.train_kl_loss = self.train_kl_loss / self.train_batches
        self.train_recon_loss = self.train_recon_loss / self.train_batches

        # TensorBoard Logger Model Visualizer, Update for Scalar Values & Example Image Plotting
        if num_epochs == 0:
            # The example comes straight from the batch, so it gets the same float cast as in training_step
            self.train_logger.experiment.add_graph(self.model, self.X_example.type(torch.float))
        self.train_logger.experiment.add_scalar("Training Loss", self.train_loss, num_epochs)
        self.train_logger.experiment.add_scalar("Kullback Leibler Divergence", self.train_kl_loss, num_epochs)
        self.train_logger.experiment.add_scalar("Image Reconstruction Loss", self.train_recon_loss, num_epochs)
        plot = self.img_plot(num_epochs)
        self.train_logger.experiment.add_figure("Original vs Reconstruction", plot, num_epochs)

        # Model Checkpoint Saving
        # 'Training Epochs' stores the number of COMPLETED epochs (index + 1): it is read back
        # as `past_epochs`, so a resumed run continues at the next step instead of repeating
        # the last logged one.
        torch.save({'ModelSD': self.model.state_dict(),
                    'OptimizerSD': self.optimizer.state_dict(),
                    'Training Epochs': num_epochs + 1,
                    'RNG State': torch.get_rng_state()},
                    self.model_filepath)

    ##############################################################################################
    # ------------------------------------ Validation Script -------------------------------------
    ##############################################################################################


---

# **Main** *Scripts*

---

In [8]:
# Dataset Access
# The dataset reader lives outside the notebook, so its folder is appended to the import
# path before the `v3DMUDI` class (used by the DataLoader methods above) is imported.
sys.path.append(f"{data_settings.main_folderpath}/Dataset Reader")
from v3DMUDI import v3DMUDI

# Dataset Version Creation
# (disabled; presumably only needed when a new dataset version has to be generated - TODO confirm)
#data = v3DMUDI(data_settings)
#data.split(data_settings)
#data.save()
#v3DMUDI.label_unscale(data_settings.save_folderpath, version = data_settings.version, y = data.params)

# --------------------------------------------------------------------------------------------

# Full All4One 2D VAE Model Class Importing
# (disabled; the Encoder / Decoder / All4One classes defined in this notebook are used instead)
#sys.path.append(model_settings.model_folderpath)
#from Encoder import Encoder
#from Decoder import Decoder
#from All4OneVAE import All4One

# Full All4One 2D VAE Model Training Importing
# (disabled; the LitAll4One class defined in this notebook is used instead)
#sys.path.append(model_settings.script_folderpath)
#from LitAll4One import LitAll4One

# --------------------------------------------------------------------------------------------

# Model Initialization & Training
# `model_settings` is not defined in this cell; it must already exist in the kernel.
# The trainer runs for `model_settings.num_epochs` epochs and requests a single device
# only when CUDA is available.
vae = LitAll4One(model_settings)
vae_trainer = pl.Trainer(   max_epochs = model_settings.num_epochs,
                            devices = 1 if torch.cuda.is_available() else None,
                            enable_progress_bar = True,
                            callbacks = [pl.callbacks.TQDMProgressBar(refresh_rate = 1)])
vae_trainer.fit(vae)
# NOTE(review): the TensorBoard extension is only loaded after training has finished and no
# `%tensorboard --logdir ...` call is visible here - confirm where the logs are inspected.
%load_ext tensorboard

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: c:\Users\pfernan2\Desktop\Experiments\Autoencoders\All4One\lightning_logs

  | Name            | Type         | Params
-------------------------------------------------
0 | model           | DataParallel | 9.8 M 
1 | recon_criterion | MSELoss      | 0     
-------------------------------------------------
9.8 M     Trainable params
0         Non-trainable params
9.8 M     Total params
39.235    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


: 