---

# *Initial* **Setup**

---

## **Library** *Settings*

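The real (installable) package name for each import below can be looked up on https://pypi.org, since it does not always match the module name used in `import`.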

In [1]:
# Library Import
import os
import sys
import pickle
import psutil
import numpy as np
import argparse
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary
import torchvision
import pytorch_lightning as pl
import tensorboard
import sklearn
import fvcore
import matplotlib.pyplot as plt
import itk
import itkwidgets
import time
import timeit
import warnings
import alive_progress

In [2]:
# Functionality Import
from pathlib import Path
from typing import List, Literal, Optional, Callable, Dict, Literal, Optional, Union, Tuple
from collections import OrderedDict
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from torch.nn.utils import spectral_norm
from torchsummary import summary
from pytorch_lightning.loggers import TensorBoardLogger
torch.autograd.set_detect_anomaly(True)
from sklearn.preprocessing import StandardScaler
from PIL import Image
from ipywidgets import interactive, IntSlider
from tabulate import tabulate
from alive_progress import alive_bar
warnings.filterwarnings('ignore')

## **Control** *Station*

In [None]:
# Dataset Parametrizations Parser Initialization
def str2bool(value) -> bool:
    """Convert a command-line string such as 'True' / 'false' / '1' / 'no' into a bool.

    argparse's `type = bool` calls bool() on the raw string, and bool('False') is True, so every
    explicit command-line value would be read as True. Values that are already bools are returned as-is.
    """
    if isinstance(value, bool): return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'): return True
    if text in ('false', 'f', 'no', 'n', '0'): return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

def str2shape(value) -> np.ndarray:
    """Convert an image shape given as '85,128,128' (or any sequence of ints) into an integer array.

    argparse's `type = np.array` would wrap the raw command-line string into a 0-d string array instead.
    """
    parts = value.split(',') if isinstance(value, str) else value
    return np.array([int(part) for part in parts])

# NOTE: the "# Default:" comments record reference values; the `default =` lines are what is actually used.
data_parser = argparse.ArgumentParser(
description = "2/3D MUDI Dataset Settings")
data_parser.add_argument(                               # Dataset Version Variable
        '--version', type = int,                        # Default: 0
        default = 2,
        help = "Dataset Save Version")
data_parser.add_argument(                               # Dataset Dimensionality
        '--dim', type = int,                            # Default: 3
        default = 2,
        help = "Dataset Dimensionality")
data_parser.add_argument(                               # Dataset Batch Size Value
        '--batch_size', type = int,                     # Default: 500
        default = 500,
        help = "Dataset Batch Size Value")

# --------------------------------------------------------------------------------------------

# Dataset Label Parametrization Arguments
data_parser.add_argument(                       # Control Variable for the Inclusion of Patient ID in Labels
        '--patient_id', type = str2bool,        # Default: True
        default = False,
        help = "Control Variable for the Inclusion of Patient ID in Labels")
data_parser.add_argument(                       # Control Variable for the Conversion of 3 Gradient Directions
        '--gradient_coord', type = str2bool,    # Coordinates into 2 Gradient Direction Angles (suggested by prof. Chantal)
        default = False,                        # Default: True (3 Coordinate Gradient Values)
        help = "Control Variable for the Conversion of Gradient Direction Mode")
data_parser.add_argument(                       # Control Variable for the Rescaling & Normalization of Labels
        '--label_norm', type = str2bool,        # Default: True
        default = True,
        help = "Control Variable for the Rescaling & Normalization of Labels")
data_settings = data_parser.parse_args("")
num_labels = 7
if not(data_settings.patient_id): num_labels -= 1       # Exclusion of Patient ID
if not(data_settings.gradient_coord): num_labels -= 1   # Conversion of Gradient Coordinates to Angles
data_parser.add_argument(                               # Dataset Number of Labels (derived from the flags above)
        '--num_labels', type = int,             # Default: 7
        default = num_labels,
        help = "MUDI Dataset Number of Labels")

# --------------------------------------------------------------------------------------------

# Addition of File & Folderpath Arguments
data_parser.add_argument(                               # Path for Main Dataset Folder
        '--main_folderpath', type = str,
        default = '../../../Datasets/MUDI Dataset',
        help = 'Main Folderpath for Root Dataset')
data_settings = data_parser.parse_args("")              # Re-parse so the paths below can derive from main_folderpath
data_parser.add_argument(                               # Path for Folder Containing Patient Data Files
        '--patient_folderpath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Patient Data'),
        help = 'Input Folderpath for Segregated Patient Data')
data_parser.add_argument(                               # Path for Folder Containing Mask Data Files
        '--mask_folderpath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Patient Mask'),
        help = 'Input Folderpath for Segregated Patient Mask Data')
data_parser.add_argument(                               # Path for Parameter Value File
        '--param_filepath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Raw Data/parameters_new.xlsx'),
        help = 'Input Filepath for Parameter Value Table')
data_parser.add_argument(                               # Path for Patient Information File
        '--info_filepath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Raw Data/header1_.csv'),
        help = 'Input Filepath for Patient Information Table')
data_parser.add_argument(                               # Path for Dataset Saved Files
        '--save_folderpath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Saved Data/V{data_settings.version}'),
        help = 'Output Folderpath for MUDI Dataset Saved Versions')

# --------------------------------------------------------------------------------------------

# Dataset Splitting Arguments
data_parser.add_argument(                       # Number of Patients to be used in the Test Set
        '--test_patients', type = int,          # Default: 1
        default = 1,
        help = "Number of Patients in Test Set")
data_parser.add_argument(                       # Number / Percentage of Parameters for Training Set's Training
        '--train_params', type = int,           # Default: 500
        default = 500,
        help = "Number / Percentage of Parameters in the Training of the Training Set")
data_parser.add_argument(                       # Number / Percentage of Parameters for Test Set's Training
        '--test_params', type = int,            # Default: 20
        default = 20,
        help = "Number / Percentage of Parameters in the Training of the Test Set")

# Boolean Control Input & Shuffling Arguments
data_parser.add_argument(                       # Control Variable for the Usage of Percentage Values in Parameters
        '--percentage', type = str2bool,        # Default: False
        default = False,
        help = "Control Variable for the Usage of Percentage Values in Parameters")
data_parser.add_argument(                       # Ability to Shuffle the Patients that compose both Training and Test Sets
        '--patient_shuffle', type = str2bool,   # Default: False
        default = False,
        help = "Ability to Shuffle the Patients that compose both Training and Test Sets")
data_parser.add_argument(                       # Ability to Shuffle the Samples inside both Training and Validation Sets
        '--sample_shuffle', type = str2bool,    # Default: False
        default = False,
        help = "Ability to Shuffle the Samples inside both Training and Validation Sets")

# --------------------------------------------------------------------------------------------

# Dataset Pre-Processing Arguments
data_parser.add_argument(                       # Data Pre-Processing Method
        '--pre_processing', type = str,         # Default: Zero Padding
        default = 'Zero Padding',
        choices = ['Interpolation', 'Zero Padding', 'CNN'],
        help = "Data Pre-Processing Method")
data_parser.add_argument(                       # Final 3D Image Shape for Pre-Processing Output
        '--img_shape', type = str2shape,        # Default: [85, 128, 128] (command line: "85,128,128")
        default = np.array((85, 128, 128)),
        help = "Final 3D Image Shape for Pre-Processing Output")
data_parser.add_argument(                       # Number of Selected Slices in 2D Conversion
        '--num_slices', type = int,             # Default: 35
        default = 35,
        help = "Number of Selected Slices in 2D Conversion")
data_settings = data_parser.parse_args("")

In [3]:
# 2D CVAE-mCcGAN Model Parametrizations Parser Initialization
# NOTE: the "# Default:" comments record reference values; the `default =` lines are what is actually used.
model_parser = argparse.ArgumentParser(
description = "2D CcGAN Settings")
model_parser.add_argument(              # Model Version Variable
        '--model_version', type = int,  # Default: 0
        default = 0,
        help = "Experiment Version")
model_parser.add_argument(              # Dataset Version Variable
        '--data_version', type = int,   # Default: 0
        default = data_settings.version,        # Tied to the dataset parser, which also names --data_folderpath below
        help = "MUDI Dataset Version")

# --------------------------------------------------------------------------------------------

# Addition of Filepath Arguments
model_parser.add_argument(
        '--reader_folderpath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Dataset Reader'),
        help = 'Input Folderpath for MUDI Dataset Reader')
model_parser.add_argument(
        '--data_folderpath', type = Path,
        default = Path(f'{data_settings.main_folderpath}/Saved Data/V{data_settings.version}'),
        help = 'Input Folderpath for MUDI Dataset Saved Versions')
model_parser.add_argument(
        '--model_folderpath', type = str,
        default = 'Model Builds',
        help = 'Input Folderpath for Model Build & Architecture')
model_parser.add_argument(
        '--script_folderpath', type = str,
        default = 'Training Scripts',
        help = 'Input Folderpath for Training & Testing Script Functions')
model_parser.add_argument(
        '--save_folderpath', type = str,
        default = 'Saved Models',
        help = 'Output Folderpath for Saved & Saving Models')

# --------------------------------------------------------------------------------------------

# Addition of Label Embedding Models Training Requirement Arguments
model_parser.add_argument(              # Number of Epochs
        '--num_epochs', type = int,     # Default: 1
        default = 50,
        help = "Number of Epochs in Training Mode")
model_parser.add_argument(              # Base Learning Rate Value for Label Embedding
        '--base_lr', type = float,      # Default: 0.01
        default = 0.01,
        help = "Base Learning Rate Value in Training Mode")
model_parser.add_argument(              # Weight Decay Value
        '--weight_decay', type = float, # Default: 0.0001
        default = 1e-4,
        help = "Weight Decay Value in Training Mode")
model_parser.add_argument(              # Learning Rate Decay Ratio
        '--lr_decay', type = float,     # Default: 0.9
        default = 0.9,
        help = "Learning Rate Decay Value in Training Mode")
model_parser.add_argument(              # Control Variable for the Inclusion of Label Gamma Noise
        '--noise',                      # Default: True
        # `type = bool` would read the string 'False' as True (bool of any non-empty string), so parse the text
        type = lambda value: str(value).strip().lower() in ('true', 't', 'yes', 'y', '1'),
        default = True,
        help = "Control Variable for the Inclusion of Label Gamma Noise")

# --------------------------------------------------------------------------------------------

# Addition of Model Architecture Arguments
model_parser.add_argument(              # Embedding Space Dimensionality
        '--dim_embedding', type = int,  # Default: 128
        default = 128,
        help = "Embedding Space Dimensionality Value")
model_parser.add_argument(              # Z Space Dimensionality
        '--dim_z', type = int,          # Default: 256
        default = 256,
        help = "Z Space Dimensionality Value")
model_parser.add_argument(              # Generator & Discriminator No. Intermediate Channels
        '--num_channels', type = int,   # Default: 64
        default = 64,
        help = "Generator & Discriminator No. Intermediate Channels")
model_parser.add_argument(              # T1 & T2 Model Expansion
        '--expansion', type = int,      # Default: 1
        default = 1,
        help = "T1 & T2 Model Expansion Value")
model_parser.add_argument(              # Batch Size Value (inherited from the dataset settings)
        '--batch_size', type = int,     # Default: 500
        default = data_settings.batch_size,
        help = "Batch Size Value")
model_parser.add_argument(              # Dataset Number of Labels (inherited from the dataset settings)
        '--num_labels', type = int,     # Default: 7
        default = data_settings.num_labels,
        help = "MUDI Dataset Number of Labels")

# --------------------------------------------------------------------------------------------

# Addition of CcGAN Training Requirement Arguments
model_parser.add_argument(              # Base Learning Rate Value
        '--lr_ccgan', type = float,     # Default: 0.0001
        default = 1e-4,
        help = "Base Learning Rate Value in Training Mode")
model_parser.add_argument(              # Sigma Kernel Size in Gaussian Noise Added
        '--kernel_eps', type = float,   # Default: -1.0 (np.array would turn a CLI string into a 0-d string array)
        default = 1.0,
        help = "Sigma Kernel Size in Gaussian Noise Added")
model_parser.add_argument(              # Discriminator's Number of Updates per Iteration
        '--dis_update', type = int,     # Default: 4
        default = 4,
        help = "Number of Updates / Iteration for Discriminator")
model_parser.add_argument(              # Generator's Number of Updates per Iteration
        '--gen_update', type = int,     # Default: 4
        default = 4,
        help = "Number of Updates / Iteration for Generator")
model_parser.add_argument(              # GAN Loss Function
        '--loss', type = str,           # Default: Hinge Loss
        choices = ['vanilla', 'hinge'],
        default = 'hinge',
        help = 'GAN Loss Function')

model_settings = model_parser.parse_args("")
model_settings.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

---

# **Data** *Access*

---

In [None]:
# Dataset Access Requirements
# Make the project's dataset-reader module importable. The folder root lives on data_settings
# (there is no standalone `main_folderpath` variable), and Path joins it with a proper separator.
reader_folderpath = str(Path(data_settings.main_folderpath) / 'Dataset Reader')
if reader_folderpath not in sys.path: sys.path.append(reader_folderpath)     # Avoid duplicate entries on cell re-runs
from v3DMUDI import v3DMUDI

# Dataset Initialization & Saving Example
#mudi = v3DMUDI(data_settings)
#mudi.split(); mudi.save()

In [None]:
# 3D Interactive Plotting Function
def plot(
    sample_number,
    slice_number,
):
    """Show one slice of one sample from the notebook-level array ``X`` as a grayscale figure.

    Parameters
    ----------
    sample_number : index into the first axis of ``X``.
    slice_number  : index into the first axis of the selected sample.
    """
    # NOTE(review): reads the global X, which must exist in the kernel before this is called
    image_slice = X[sample_number][slice_number]
    plt.figure(figsize = (10, 20))
    plt.imshow(image_slice, cmap = 'gray')
    plt.axis('off')
    plt.title(f"Patient #11 to 14 | Sample #{sample_number} | Slice #{slice_number}")

# ----------------------------------------------------------------------------------------------------------------------------

# Patient Data Visualization Function
#patient_number = 0; data = mudi.get_patient(patient_number)
#X = data.train_set['X_train']
# NOTE(review): X is not assigned above this cell (the loading lines are commented out), so on a fresh
# kernel this cell raises NameError — uncomment the loading lines or define X before running it.
# IntSlider's `max` is inclusive: the largest valid index is shape - 1 (shape itself raises IndexError in plot).
sample_slider = IntSlider(value = 0, min = 0, max = X.shape[0] - 1, description = 'Sample', continuous_update = False)
slice_slider = IntSlider(value = 0, min = 0, max = X.shape[1] - 1, description = 'Slice', continuous_update = False)
interactive(plot, sample_number = sample_slider, slice_number = slice_slider)


---

# *Model* **Building**

---

## **CVAE** *Model*

In [None]:
##############################################################################################
# --------------------------------------- Encoder Build --------------------------------------
##############################################################################################

# Encoder Model Class
class Encoder(nn.Module):
    """Voxel-wise CVAE encoder: (scalar value, label vector) -> latent Gaussian (mean, log-variance).

    Each scalar of X is paired with its label vector along the channel axis and passed through
    ``settings.num_layers`` Conv1d(kernel 1) + LeakyReLU stages whose widths grow to
    ``settings.num_channel``, then projected to ``settings.latent_dim`` means and log-variances.

    Settings attributes read: num_labels, num_layers, num_channel, latent_dim, device, model_version.
    NOTE(review): num_layers, num_channel and latent_dim are not defined by model_parser in this
    notebook (it defines num_channels) — add them to the settings before building this model.
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        settings: argparse.Namespace        # Model Settings & Parametrizations (parsed Namespace)
    ):

        # Class Variable Logging
        super(Encoder, self).__init__()
        self.settings = settings

        # Encoder Architecture: channel widths num_channel / 2**(L-1), ..., num_channel / 2, num_channel
        net = []; in_channel = 1 + settings.num_labels              # 1 signal channel + 1 channel per label
        for i in reversed(range(settings.num_layers)):
            out_channel = int(settings.num_channel / (2 ** i))      # Current Main Layer's Output Channels
            net.append(nn.Sequential(                               # Main Layer Block Repeatable Architecture
                nn.Conv1d(      in_channels = in_channel,
                                out_channels = out_channel,
                                kernel_size = 1, stride = 2, padding = 0),      # Length-1 input stays length 1
                nn.LeakyReLU(   inplace = True)))
            in_channel = out_channel                                # Next Main Layer's Input Channels
        self.net = nn.Sequential(*net)

        # Mean and LogVariance Computation Linear Layers
        self.mean_layer = nn.Linear(    in_features = settings.num_channel,
                                        out_features = settings.latent_dim)
        self.logvar_layer = nn.Linear(  in_features = settings.num_channel,
                                        out_features = settings.latent_dim)

    # --------------------------------------------------------------------------------------------

    # Input Conversion Helper
    def _to_float_tensor(self, data):
        """Return ``data`` as a float32 tensor on settings.device (tensors keep their autograd graph)."""
        if isinstance(data, torch.Tensor):
            return data.to(device = self.settings.device, dtype = torch.float32)
        return torch.as_tensor(np.asarray(data), dtype = torch.float32, device = self.settings.device)

    # Encoder Application Function
    def forward(
        self,
        X: Union[np.ndarray, torch.Tensor],     # Signal values; flattened so each element is one sample
        y: Union[np.ndarray, torch.Tensor]      # Labels; reshaped to num_labels values per sample
    ):
        """Encode a batch and return ``(z_mean, z_logvar)``, each of shape [batch_size, latent_dim]."""

        # Net Input Handling (reshape also works on non-contiguous inputs, unlike view)
        X = self._to_float_tensor(X).reshape(-1, 1, 1)                                         # Input Features | [batch_size, 1,              1]
        y = self._to_float_tensor(y).reshape(-1, self.settings.num_labels, 1)                  # Input Labels   | [batch_size, num_labels,     1]
        encoder_input = torch.cat((X, y), dim = 1)                                             # Encoder Input  | [batch_size, 1+num_labels,   1]

        # Forward Propagation in Encoder Architecture
        output = self.net(encoder_input)                                                       # Encoder Output | [batch_size, num_channel,    1]
        features = output.reshape(-1, self.settings.num_channel)
        z_mean = self.mean_layer(features)                                                     # Latent Mean    | [batch_size, latent_dim]
        z_logvar = self.logvar_layer(features)                                                 # Latent LogVar  | [batch_size, latent_dim]

        # Display Settings for Experimental Model Version
        if self.settings.model_version == 0:
            print(f"Encoder Input  | {list(encoder_input.shape)}")
            print(f"Encoder Output | {list(output.shape)}")
            print(f"Latent Mean    | {list(z_mean.shape)}")
            print(f"Latent LogVar  | {list(z_logvar.shape)}\n")
        return z_mean, z_logvar

import tensorboard

## **GAN** *Model*

### *Image* **Generator**

### *Image* **Discriminator**

---

# *Model* **Building**

---

## **Label** *Embedding*

In [5]:
# Main / Repeatable ResNet Block Construction Class
class SimpleBlock(nn.Module):
    """Residual block: 3x3 conv (strided) -> 3x3 conv -> 1x1 conv to ``expansion * out_channel``.

    Every convolution is followed by BatchNorm2d; ReLU follows the first two stages and the residual sum.
    The shortcut is a strided 1x1 conv + BatchNorm2d whenever the stride or channel count changes,
    otherwise it is the identity (an empty nn.Sequential).
    """

    def __init__(
        self,
        in_channel:int,
        out_channel: int,
        stride: int = 1,
        expansion: int = 1
    ):
        super(SimpleBlock, self).__init__()
        widened = expansion * out_channel

        # Module order fixes the state_dict indices (block.0 ... block.7)
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size = 3, stride = stride, padding = 1, bias = False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),          # nn.ReLU() is numerically identical to applying F.relu at this point
            nn.Conv2d(out_channel, out_channel, kernel_size = 3, stride = 1, padding = 1, bias = False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(out_channel, widened, kernel_size = 1, bias = False),
            nn.BatchNorm2d(widened))

        # Projection shortcut only when the residual would not match the block output's shape
        needs_projection = stride != 1 or in_channel != widened
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channel, widened, kernel_size = 1, stride = stride, bias = False),
                nn.BatchNorm2d(widened))
            if needs_projection else nn.Sequential())

    def forward(self, X):
        """Return ReLU(block(X) + shortcut(X))."""
        return F.relu(self.block(X) + self.shortcut(X))

In [6]:
# Label Embedding CNN Class (X -> h -> y)
class LabelEmbedding(nn.Module):
    """ResNet-style label-embedding network: image X -> embedding h (T1) -> predicted labels y (T2).

    Backbone: 3x3 stem conv + BatchNorm + ReLU + 2x2 max-pool, then four stages of SimpleBlocks
    (64, 128, 256, 512 channels times ``expansion``, each stage downsampling by 2) and a global
    average pool. T1 maps the pooled features to a ``dim_embedding``-dim embedding h; T2 maps h to
    ``num_labels`` outputs through a final ReLU (outputs are non-negative).
    Input is expected as (batch, 1, H, W); T1's BatchNorm1d needs batch > 1 in training mode.

    Args:
        num_blocks:     SimpleBlocks per stage (exactly 4 entries).
        in_channel:     Stem output channels, i.e. the input channels of the first stage.
        expansion:      SimpleBlock channel expansion factor.
        dim_embedding:  Embedding space dimensionality (size of h).
        num_labels:     Number of predicted labels (size of y); 7 matches the previous fixed head.
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        num_blocks: list = [3, 4, 6, 3],    # Number of Blocks in ResNet Main Block Intermediate Layers
        in_channel: int = 64,               # Number of Input Channels in ResNet Main Block Intermediate Layers' Blocks
        expansion: int = 1,                 # Expansion Factor for Channels in ResNet Main Block Intermediate Layers
        dim_embedding: int = 128,           # Embedding Space Dimensionality (WIP)
        num_labels: int = 7,                # Number of Labels predicted by T2
    ):

        # Class Variable Logging
        super(LabelEmbedding, self).__init__()
        assert(len(num_blocks) == 4), "Number of Blocks provided Not Supported!"
        self.num_blocks = list(num_blocks)  # Copy so the shared default list is never exposed for mutation
        self.in_channel = in_channel        # NOTE: main_layer() advances this as each stage is built
        self.expansion = expansion
        self.dim_embedding = dim_embedding
        self.num_labels = num_labels

        # Main ResNet Backbone Construction. The stem must emit `in_channel` channels because the first
        # stage's SimpleBlock is built with self.in_channel inputs (a fixed 64 breaks other values).
        stem_channel = in_channel
        self.mainNet = nn.Sequential(
            nn.Conv2d(1, stem_channel, kernel_size = 3, stride = 1, padding = 1, bias = False),
            nn.BatchNorm2d(stem_channel),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            self.main_layer(64, self.num_blocks[0]),
            self.main_layer(128, self.num_blocks[1]),
            self.main_layer(256, self.num_blocks[2]),
            self.main_layer(512, self.num_blocks[3]),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Pooled feature width is the last stage's output width (a fixed 512 breaks expansion != 1)
        num_features = 512 * self.expansion

        # 1st SubNetwork for Label Embedding (X -> h)
        self.t1Net = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, self.dim_embedding),
            nn.BatchNorm1d(self.dim_embedding),
            nn.ReLU())

        # 2nd SubNetwork for Label Embedding (h -> y)
        self.t2Net = nn.Sequential(
            nn.Linear(self.dim_embedding, self.num_labels),
            nn.ReLU())

    # --------------------------------------------------------------------------------------------

    # ResNet Repeatable Layer Definition Function
    def main_layer(
        self,
        out_channel: int,
        num_blocks: int,
        stride: int = 2
    ):
        """Build one stage of SimpleBlocks; only the first block uses ``stride``.

        Side effect: sets self.in_channel to this stage's output width so the next stage chains on.
        """
        strides = [stride] + [1] * (num_blocks - 1); layer = []
        for s in strides:
            layer.append(SimpleBlock(self.in_channel, out_channel, s, expansion = self.expansion))
            self.in_channel = out_channel * self.expansion
        return nn.Sequential(*layer)

    # --------------------------------------------------------------------------------------------

    # Label Embedding Application Function
    def forward(self,
        X: Union[np.ndarray, torch.Tensor]      # Image Input, (batch, 1, H, W)
    ):
        """Return ``(h, y)``: the embedding [batch, dim_embedding] and label predictions [batch, num_labels]."""

        # Convert to float32 on the same device as the model's parameters
        device = next(self.parameters()).device
        if isinstance(X, torch.Tensor): X = X.to(device = device, dtype = torch.float32)
        else: X = torch.as_tensor(np.asarray(X), dtype = torch.float32, device = device)

        # Forward Propagation in CNN Architecture
        h = self.mainNet(X)                 # Main ResNet Application
        h = h.view(h.size(0), -1)           # Linearization of ResNet Features
        h = self.t1Net(h)                   # 1st SubNetwork Application (X -> h)
        y = self.t2Net(h)                   # 2nd SubNetwork Application (h -> y)
        return h, y
    

In [7]:
# Label Embedding SubNetwork Class (y -> h)
class t3Net(nn.Module):
    """T3 label-embedding MLP mapping a label vector y to the embedding space h.

    A stack of ``num_blocks + 1`` blocks: the first maps ``num_labels`` -> ``dim_embedding``, the rest
    keep ``dim_embedding``. Every block except the last is Linear -> GroupNorm(8) -> ReLU; the last is
    Linear -> ReLU, so outputs are non-negative. GroupNorm(8, dim_embedding) needs dim_embedding % 8 == 0.
    """

    # Constructor / Initialization Function
    def __init__(self,
        dim_embedding: int = 128,           # Embedding Space Dimensionality
        num_blocks: int = 4,                # Number of Normalized Blocks in T3 SubNetwork (plus 1 final block)
        num_labels: int = 7                 # Number of Labels in Dataset provided
    ):

        # Class Variable Logging
        super(t3Net, self).__init__()
        self.dim_embedding = dim_embedding      # h Variable Dimension
        self.num_blocks = num_blocks            # Number of Blocks in Embedding MLP
        self.num_labels = num_labels            # Number of Labels in Dataset provided
        self.mlp = nn.Sequential()              # Empty Embedding MLP Variable

        # MLP Architecture Definition
        for i in range(num_blocks + 1):
            in_channel = self.num_labels if i == 0 else self.dim_embedding
            # Only the final block skips GroupNorm. A flag switched off at i == num_blocks - 1 and never
            # switched back on would also strip the normalisation from the second-to-last block.
            norm = i != num_blocks
            self.mlp.add_module(f'Block #{i + 1}',
                                self.main_block(in_channel, self.dim_embedding, norm = norm))

    # --------------------------------------------------------------------------------------------

    # MLP Repeatable Main Block Definition
    def main_block(self,
        in_channel: int,            # Input Features for Linear Layer
        out_channel: int,           # Output Features for Linear Layer
        norm: bool = True,          # Group Normalization Boolean Control Variable
    ):
        """Return Linear(in_channel, out_channel) [-> GroupNorm(8, out_channel)] -> ReLU."""
        layers = [nn.Linear(in_channel, out_channel)]
        if norm: layers.append(nn.GroupNorm(8, out_channel))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    # --------------------------------------------------------------------------------------------

    # Label Embedding Application Function
    def forward(self,
        y: pd.DataFrame       # Labels, 2D: (batch, num_labels); arrays and tensors are also accepted
    ):
        """Embed the labels and return h with shape [batch, dim_embedding]."""

        # Label Embedding using MLP (input converted to float32 on the model's device)
        assert(y.ndim == 2), f"ERROR: Input Labels not Correctly Dimensioned!"
        device = next(self.parameters()).device
        if isinstance(y, torch.Tensor): labels = y.to(device = device, dtype = torch.float32)
        else: labels = torch.as_tensor(np.asarray(y), dtype = torch.float32, device = device)
        return self.mlp(labels)                              # h Embedded Labels Variable


In [8]:
# 2D Data Sample Creation
# SECURITY: pickle.load can execute arbitrary code stored in the file — only load trusted local artifacts.
with open('Test MUDI Data', 'rb') as data_file: X = pickle.load(data_file)
with open('Test MUDI Labels', 'rb') as label_file: y = pickle.load(label_file)
# .copy() makes the 35-row label slice an independent frame, so adding the 'Patient' column
# does not write into a view of the loaded labels (SettingWithCopy)
X = X[0:35,:]; y = y.iloc[0:35].copy(); y['Patient'] = 11
plt.imshow(X[0, 0, :, :], cmap = 'gray')

# Label Embedding Network Testing (X -> h -> y)
embedNet = LabelEmbedding()                 # T1 & T2 Model Creation
h1, y2 = embedNet(X)                        # T1 (X -> h) & T2 (h -> y) Models Application
y2h_model = t3Net()                         # T3 Model Creation
h3 = y2h_model(y)                           # T3 (y -> h) Model Application
print(tabulate([["X", X.shape, "->", "h", list(h1.shape)],
                ["h", list(h1.shape), "->", "y", list(y2.shape)],
                ["y", y.shape, "->", "h", list(h3.shape)]],
                headers = ["Input", "Input Shape", "->", "Output", "Output Shape"],
                showindex = ["T1", "T2", "T3"], tablefmt = 'fancy_grid'))


FileNotFoundError: [Errno 2] No such file or directory: 'Test MUDI Data'

## *Image* **Generator**

In [26]:
# Linear Spectral Normalization Layer Class
class LinearSpectralNorm(nn.Module):
    """Spectrally-normalised nn.Linear whose output is reshaped to a (num_channels * 16) x 4 x 4 map.

    ``out_channel`` must therefore equal ``num_channels * 16 * 4 * 4`` (per-sample element count).
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        num_channels: int = 64,
        bias: bool = True
    ):
        super().__init__()
        self.num_channels = num_channels
        linear = nn.Linear(in_features = in_channel, out_features = out_channel, bias = bias)
        self.layer = spectral_norm(linear)

    def forward(
        self,
        z: torch.Tensor
    ):
        """Project z and reshape it to [batch, num_channels * 16, 4, 4]."""
        feature_maps = self.num_channels * 16
        return self.layer(z).view(-1, feature_maps, 4, 4)

# --------------------------------------------------------------------------------------------

# 2D Convolutional Spectral Normalization Layer Class
class Conv2DSpectralNorm(nn.Module):
    """nn.Conv2d wrapped in spectral normalisation; the arguments mirror nn.Conv2d's positional order."""

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True
    ):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size,
            stride = stride, padding = padding, dilation = dilation, groups = groups, bias = bias)
        self.layer = spectral_norm(convolution)

    def forward(
        self,
        z: torch.Tensor
    ):
        """Apply the spectrally-normalised convolution to z."""
        return self.layer(z)

# --------------------------------------------------------------------------------------------

# Self Attention Layer Class
class SelfAttention(nn.Module):
    """SAGAN-style self-attention over 2D feature maps with a learnable residual gate.

    Returns ``X + sigma * attention(X)``. ``sigma`` is a registered parameter initialised to zero, so the
    block starts as the identity and training learns how much attention to mix in.
    Assumes in_channel is divisible by 8 and the spatial dims H, W are even (phi/gamma are 2x2 max-pooled).
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        in_channel: int
    ):

        # Block Architecture Construction
        super().__init__()
        self.in_channel = in_channel
        self.conv2DSN_theta = Conv2DSpectralNorm(   in_channel, in_channel // 8,
                                                    kernel_size = 1, stride = 1, padding = 0)
        self.conv2DSN_phi = Conv2DSpectralNorm(     in_channel, in_channel // 8,
                                                    kernel_size = 1, stride = 1, padding = 0)
        self.conv2DSN_gamma = Conv2DSpectralNorm(   in_channel, in_channel // 2,
                                                    kernel_size = 1, stride = 1, padding = 0)
        self.conv2DSN_attent = Conv2DSpectralNorm(  in_channel // 2, in_channel,
                                                    kernel_size = 1, stride = 1, padding = 0)
        self.MaxPool2D = nn.MaxPool2d(2, stride = 2, padding = 0)
        self.SoftMax = nn.Softmax(dim = -1)

        # Learnable attention gate. It must be registered here: a Parameter created inside forward() is
        # rebuilt as zero on every call, is never seen by the optimiser, and is not moved by .to(device).
        self.sigma = nn.Parameter(torch.zeros(1))

    # Layer Application Function
    def forward(
        self,
        X: torch.Tensor                 # Feature maps, [batch, in_channel, H, W]
    ):

        # Theta Path (queries): [batch, C/8, H*W]
        theta = self.conv2DSN_theta(X)
        theta = theta.view(-1, X.shape[1] // 8, X.shape[2] * X.shape[3])

        # Phi Path (keys, pooled): [batch, C/8, H*W/4]
        phi = self.conv2DSN_phi(X)
        phi = self.MaxPool2D(phi)
        phi = phi.view(-1, X.shape[1] // 8, (X.shape[2] * X.shape[3]) // 4)

        # Gamma Path (values, pooled): [batch, C/2, H*W/4]
        gamma = self.conv2DSN_gamma(X)
        gamma = self.MaxPool2D(gamma)
        gamma = gamma.view(-1, X.shape[1] // 2, (X.shape[2] * X.shape[3]) // 4)

        # Attention Map: softmax over pooled positions, then weighted sum of values
        attent = torch.bmm(theta.permute(0, 2, 1), phi)
        attent = self.SoftMax(attent)
        attent = torch.bmm(gamma, attent.permute(0, 2, 1))
        attent = attent.view(-1, X.shape[1] // 2, X.shape[2], X.shape[3])
        attent = self.conv2DSN_attent(attent)

        return X + (self.sigma * attent)

In [27]:
# 2D Conditional Batch Normalization Layer Class
class c2DBatchNorm(nn.Module):
    """Conditional BatchNorm2d: normalise X without affine terms, then apply label-dependent scale and shift.

    out = BN(X) + gamma(y) * BN(X) + beta(y), where gamma and beta are separate bias-free linear
    projections of the label embedding y (shape [batch, dim_embedding]).
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        num_feats: int,                 # Number of feature channels in X (BatchNorm2d num_features)
        dim_embedding: int = 128,       # Dimensionality of the label embedding y
        momentum: float = 0.001         # BatchNorm2d running-statistics momentum
    ):

        # Layer Architecture Construction
        super().__init__()
        self.num_feats = num_feats
        self.layer = nn.BatchNorm2d(num_feats, momentum = momentum, affine = False)
        # Scale projection (name kept so existing `embedding.weight` keys still match)
        self.embedding = nn.Linear(dim_embedding, num_feats, bias = False)
        # Shift projection. Scale and shift need independent weights; reusing one projection for both
        # ties beta to gamma and removes the layer's ability to shift independently of scaling.
        self.embedding_beta = nn.Linear(dim_embedding, num_feats, bias = False)

    # Layer Application Function
    def forward(
        self,
        X: torch.Tensor,                # Tensor containing Data, [batch, num_feats, H, W]
        y: torch.Tensor                 # Tensor containing Labels' embedding, [batch, dim_embedding]
    ):

        # Layer Walkthrough
        out = self.layer(X)
        gamma = self.embedding(y).view(-1, self.num_feats, 1, 1)
        beta = self.embedding_beta(y).view(-1, self.num_feats, 1, 1)
        out = out + (gamma * out) + beta
        return out

# --------------------------------------------------------------------------------------------

# Main / Repeatable Generator Block Construction Class
class GeneratorBlock(nn.Module):
    """
    Residual up-sampling block of the generator (doubles H and W).

      * shortcut path: nearest x2 upsample -> 1x1 SN conv (in_channel -> out_channel)
      * residual path: cBN(y) -> ReLU -> nearest x2 upsample -> 3x3 SN conv
                       -> cBN(y) -> ReLU -> 3x3 SN conv
    The two paths are summed. Both conditional BN layers are driven by ``y``.
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        dim_embedding: int = 128
    ):
        super().__init__()

        # Submodules are registered in the original order (fixes state-dict keys / param order)
        self.c2DBN_1 = c2DBatchNorm(in_channel, dim_embedding)
        self.conv2DSN_1 = Conv2DSpectralNorm(in_channel, out_channel, kernel_size = 3, stride = 1, padding = 1)
        self.c2DBN_2 = c2DBatchNorm(out_channel, dim_embedding)
        self.conv2DSN_2 = Conv2DSpectralNorm(out_channel, out_channel, kernel_size = 3, stride = 1, padding = 1)
        self.conv2DSN_X = Conv2DSpectralNorm(in_channel, out_channel, kernel_size = 1, stride = 1, padding = 0)

    # Layer Application Function
    def forward(
        self,
        X: torch.Tensor,                # Tensor containing Data
        y: torch.Tensor                 # Tensor containing Labels
    ):
        # Shortcut: upsample first, then project the channels
        shortcut = self.conv2DSN_X(F.interpolate(X, scale_factor = 2, mode = 'nearest'))

        # Residual path (in-place ReLUs act on fresh cBN outputs, never on X itself)
        residual = F.relu(self.c2DBN_1(X, y), inplace = True)
        residual = self.conv2DSN_1(F.interpolate(residual, scale_factor = 2, mode = 'nearest'))
        residual = F.relu(self.c2DBN_2(residual, y), inplace = True)
        residual = self.conv2DSN_2(residual)

        return residual + shortcut


In [28]:
# Weight Initialization Function
def weightInit(module):
    """
    Xavier-uniform initialise the weight of plain ``nn.Linear`` / ``nn.Conv2d``
    modules and zero their bias, if they have one. Intended for ``Module.apply``.

    The check is on the exact type, so subclasses of these layers are left untouched.
    """
    if type(module) not in (nn.Linear, nn.Conv2d):
        return
    nn.init.xavier_uniform_(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)

# Generator Model Class
class Generator(nn.Module):
    """
    Conditional image generator.

    A spectral-norm linear layer maps the noise ``z`` to a
    ``(num_channels * 16, 4, 4)`` feature map. Five GeneratorBlocks double the
    resolution up to 128 x 128, with a SelfAttention block at 32 x 32. A
    BatchNorm -> ReLU -> 3x3 SN conv -> Tanh head then emits a 1-channel image.
    Every GeneratorBlock is conditioned on the label embedding ``y``.
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        dim_z: int = 256,               # Z Space Dimensionality
        dim_embedding: int = 128,       # Embedding Space Dimensionality
        num_channels: int = 64,         # Base Channel Width (multiplied per Block)
        momentum: float = 0.0001,       # Momentum of the Output BatchNorm
        eps: float = 1e-5,              # Epsilon of the Output BatchNorm
    ):
        super().__init__()
        self.dim_z = dim_z
        self.num_channels = num_channels
        width = num_channels

        # Submodules keep their original registration order: it fixes the state-dict keys,
        # the parameter order and the random draws consumed while constructing them.
        # NOTE(review): the third positional argument handed to LinearSpectralNorm
        # (num_channels) is passed through unchanged -- confirm its meaning in that helper.
        self.linearSN = LinearSpectralNorm(dim_z, width * 16 * 4 * 4, width)     # -> (4 x 4)
        self.genBlock1 = GeneratorBlock(width * 16, width * 16, dim_embedding)  # -> (8 x 8)
        self.genBlock2 = GeneratorBlock(width * 16, width * 8, dim_embedding)   # -> (16 x 16)
        self.genBlock3 = GeneratorBlock(width * 8, width * 4, dim_embedding)    # -> (32 x 32)
        self.selfAttention = SelfAttention(width * 4)                           #    (32 x 32)
        self.genBlock4 = GeneratorBlock(width * 4, width * 2, dim_embedding)    # -> (64 x 64)
        self.genBlock5 = GeneratorBlock(width * 2, width, dim_embedding)        # -> (128 x 128)
        self.genPost = nn.Sequential(                                           # Image Post-Processing
            nn.BatchNorm2d(width, eps = eps, momentum = momentum, affine = True),
            nn.ReLU(inplace = True),
            Conv2DSpectralNorm(width, 1, kernel_size = 3, stride = 1, padding = 1),
            nn.Tanh())

        # Report parameter counts, then (re-)initialise Linear / Conv2d weights
        self.visualizer()
        self.apply(weightInit)

    # Model Visualizer Function
    def visualizer(self):
        """Print the total and the trainable parameter counts of the generator."""
        num_total, num_train = 0, 0
        for param in self.parameters():
            num_total += param.numel()
            if param.requires_grad:
                num_train += param.numel()
        print(f"Generator | Total Parameters: {num_total}\n          | Trainable Parameters: {num_train}")

    # Layer Application Function
    def forward(
        self,
        z: torch.Tensor,                # Tensor containing Z Space Data
        y: torch.Tensor                 # Tensor containing Labels
    ):
        out = self.linearSN(z).view(-1, self.num_channels * 16, 4, 4)          # (4 x 4)
        for block in (self.genBlock1, self.genBlock2, self.genBlock3):          # 8 -> 16 -> 32
            out = block(out, y)
        out = self.selfAttention(out)                                           # (32 x 32)
        for block in (self.genBlock4, self.genBlock5):                          # 64 -> 128
            out = block(out, y)
        return self.genPost(out)


In [None]:
# 2D Data Sample Creation
# NOTE: pickle.load can execute arbitrary code -- only open trusted, locally produced files.
with open('Test MUDI Data', 'rb') as data_file: X = pickle.load(data_file)
with open('Test MUDI Labels', 'rb') as label_file: y = pickle.load(label_file)
X = X[0:35, :]
y = y.iloc[0:35].copy()                                 # explicit copy: the column write below must not target a slice
y['Patient'] = 11
h = torch.Tensor(y2h_model(y))                          # Embedded Labels (Generator conditioning)
X = torch.Tensor(X); y = torch.Tensor(y.values)

# Generator Model Testing Example
gen = Generator()
z = torch.randn(X.shape[0], gen.dim_z, dtype = torch.float)     # Random Noise Generation
with torch.no_grad(): genX = gen(z, y = h)                      # inference only: no autograd graph needed

# Real vs. Generated Samples, one 5 x 7 grid each (X holds at most 35 samples)
# Assumes X is (N, 1, H, W) like genX -- TODO confirm the pickled layout.
trueFig, trueAxes = plt.subplots(5, 7); trueFig.suptitle('Real Samples')
genFig, genAxes = plt.subplots(5, 7); genFig.suptitle('Generated Samples')
for i, (trueAxis, genAxis) in enumerate(zip(trueAxes.flat, genAxes.flat)):
    trueAxis.axis('off'); genAxis.axis('off')
    if i < genX.shape[0]:
        trueAxis.imshow(X[i, 0].numpy(), cmap = 'gray')
        genAxis.imshow(genX[i, 0].numpy(), cmap = 'gray')
plt.show()
    

## *Image* **Discriminator**

In [51]:
# Main / Repeatable Discriminator Block Construction Class
class DiscriminatorBlock(nn.Module):
    """
    Residual (optionally down-sampling) block of the discriminator.

      * main path:     ReLU -> 3x3 SN conv -> ReLU -> 3x3 SN conv -> [AvgPool 2x2]
      * shortcut path: identity, or 1x1 SN conv -> [AvgPool 2x2] when the block
                       down-samples or changes the number of channels
    The two paths are summed.
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        downsample: bool = True         # Boolean Control Variable for Downsampling
    ):

        # Block Architecture Construction (registration order kept for state-dict compatibility)
        super().__init__()
        self.mismatch = in_channel != out_channel
        self.downsample = downsample
        self.conv2DSN_1 = Conv2DSpectralNorm(   in_channel, out_channel,
                                                kernel_size = 3, stride = 1, padding = 1)
        self.conv2DSN_2 = Conv2DSpectralNorm(   out_channel, out_channel,
                                                kernel_size = 3, stride = 1, padding = 1)
        self.downsampleLayer = nn.AvgPool2d(2)
        self.conv2DSN_X = Conv2DSpectralNorm(   in_channel, out_channel,
                                                kernel_size = 1, stride = 1, padding = 0)

    # Block Application Function
    def forward(
        self,
        X: torch.Tensor                 # Tensor containing Data
    ):

        # Shortcut path. X is NOT detached: gradients must flow back through the skip
        # connection to earlier layers, otherwise the block stops acting as a residual.
        X_0 = X
        if self.downsample or self.mismatch:
            X_0 = self.conv2DSN_X(X_0)
            if self.downsample: X_0 = self.downsampleLayer(X_0)

        # Main path. The first ReLU is deliberately out-of-place: an in-place ReLU would
        # overwrite X, which the shortcut (and the 1x1 conv's backward pass) still uses.
        X = F.relu(X)
        X = self.conv2DSN_1(X)
        X = F.relu(X, inplace = True)       # safe: acts on a fresh convolution output
        X = self.conv2DSN_2(X)
        if self.downsample: X = self.downsampleLayer(X)

        return X + X_0

# --------------------------------------------------------------------------------------------

# Optimal Discriminator Block Construction Class
class OptimalBlock(nn.Module):
    """
    Input block of the discriminator; always halves H and W.

      * shortcut: AvgPool 2x2 -> 1x1 SN conv
      * main:     3x3 SN conv -> ReLU -> 3x3 SN conv -> AvgPool 2x2
    Unlike DiscriminatorBlock, the main path starts directly with a convolution
    (no leading ReLU).
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        in_channel: int,
        out_channel: int,
    ):
        super().__init__()

        # The Sequential layouts (and thus keys such as "X0Block.1.*", "XBlock.2.*") stay fixed
        shortcut_layers = [
            nn.AvgPool2d(2),
            Conv2DSpectralNorm(in_channel, out_channel, kernel_size = 1, stride = 1, padding = 0),
        ]
        self.X0Block = nn.Sequential(*shortcut_layers)

        main_layers = [
            Conv2DSpectralNorm(in_channel, out_channel, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(inplace = True),
            Conv2DSpectralNorm(out_channel, out_channel, kernel_size = 3, stride = 1, padding = 1),
            nn.AvgPool2d(2),
        ]
        self.XBlock = nn.Sequential(*main_layers)

    # Block Application Function
    def forward(
        self,
        X: torch.Tensor                 # Tensor containing Data
    ):
        skip = self.X0Block(X)
        return self.XBlock(X) + skip
        

In [52]:
# Discriminator Model Class
class Discriminator(nn.Module):
    """
    Conditional projection discriminator.

    ``main`` reduces a 1-channel 128 x 128 image to a (num_channels * 16, 4, 4)
    feature map. The map is flattened into ``phi`` and the score is::

        linearSN(phi) + <phi, embedding(h)>

    i.e. an unconditional linear term plus a projection onto the embedded label
    ``h``. The result has shape (batch, 1).
    """

    # Constructor / Initialization Function
    def __init__(
        self,
        dim_embedding: int = 128,       # Embedding Space Dimensionality
        num_channels: int = 64,         # Number of Neural Net Channels
    ):
        super().__init__()
        self.num_channels = num_channels
        width = num_channels
        flat_dim = width * 16 * 4 * 4   # Flattened size of the final (4 x 4) feature map

        # Layer order / Sequential indices are kept as-is (state-dict compatibility)
        self.main = nn.Sequential(
            OptimalBlock(1, width),                                         # 128 -> 64
            DiscriminatorBlock(width, width * 2),                           #  64 -> 32
            SelfAttention(width * 2),                                       #  32
            DiscriminatorBlock(width * 2, width * 4),                       #  32 -> 16
            DiscriminatorBlock(width * 4, width * 8),                       #  16 -> 8
            DiscriminatorBlock(width * 8, width * 16),                      #   8 -> 4
            DiscriminatorBlock(width * 16, width * 16, downsample = False), #   4
            nn.ReLU(inplace = True))
        self.linearSN = self.LinearSpectralNorm(flat_dim, 1)
        self.embedding = self.LinearSpectralNorm(dim_embedding, flat_dim, bias = False)

        # Report parameter counts, then (re-)initialise weights; the explicit second
        # Xavier init of the embedding is kept as in the original
        self.visualizer()
        self.apply(weightInit)
        nn.init.xavier_uniform_(self.embedding.weight)

    # Model Visualizer Function
    def visualizer(self):
        """Print the total and the trainable parameter counts of the discriminator."""
        num_total, num_train = 0, 0
        for param in self.parameters():
            num_total += param.numel()
            if param.requires_grad:
                num_train += param.numel()
        print(f"Discriminator | Total Parameters: {num_total}\n              | Trainable Parameters: {num_train}")

    # Linear Spectral Normalization Layer Function
    def LinearSpectralNorm(
        self,
        in_channel: int,
        out_channel: int,
        bias: bool = True
    ):
        """Return ``nn.Linear(in_channel, out_channel, bias)`` wrapped with spectral normalisation."""
        layer = nn.Linear(in_channel, out_channel, bias = bias)
        return spectral_norm(layer)

    # Layer Application Function
    def forward(
        self,
        X: torch.Tensor,                # Tensor containing Data
        h: torch.Tensor                 # Tensor containing Embedded Labels
    ):
        features = self.main(X).view(-1, self.num_channels * 16 * 4 * 4)
        linear_term = torch.squeeze(self.linearSN(features))                # unconditional score
        projection_term = (features * self.embedding(h)).sum(dim = 1)      # label projection
        return (linear_term + projection_term).unsqueeze(-1)

In [None]:
# 2D Data Sample Creation
# NOTE: pickle.load can execute arbitrary code -- only open trusted, locally produced files.
with open('Test MUDI Data', 'rb') as data_file: X = pickle.load(data_file)
with open('Test MUDI Labels', 'rb') as label_file: y = pickle.load(label_file)
X = X[0:35, :]
y = y.iloc[0:35].copy()                                 # explicit copy: the column write below must not target a slice
y['Patient'] = 11
h = torch.Tensor(y2h_model(y))                          # Embedded Labels -- assumed (N, dim_embedding = 128); TODO confirm
X = torch.Tensor(X); y = torch.Tensor(y.values)

# Discriminator Model Testing Example
# The Discriminator projects onto the *embedded* labels h, not onto the raw label matrix y.
testX = torch.randn(X.shape[0], 1, 128, 128)
dis = Discriminator()
with torch.no_grad(): out = dis(testX, h)
assert out.shape == (X.shape[0], 1), out.shape

---

# **Running** *Scripts*

---

In [6]:
# Label Embedding Models Training, Validation & Testing Script Class
class LitT12Net(pl.LightningModule):
    """
    PyTorch Lightning wrapper that trains the conjoint T1 + T2 label-embedding
    network (``X -> h -> y``) by regressing the ground-truth labels with an MSE loss.

    At the end of every training epoch, a checkpoint (model / optimizer state,
    epoch, CPU RNG state) is written to
    ``{save_folderpath}/V{model_version}/Embedding Net (V{model_version}).pth``.
    That file is reloaded on construction when ``model_version != 0`` and it exists.
    """

    ##############################################################################################
    # --------------------------------------- Initial Setup --------------------------------------
    ##############################################################################################

    # Constructor / Initialization Function
    def __init__(
        self,
        settings: argparse.ArgumentParser,              # Model Settings & Parametrizations
    ):

        # Class Variable Logging
        super().__init__()
        self.settings = settings
        self.lr_decay_epochs = [80, 140]                # Epochs for Learning Rate Decay

        # Model Initialization
        self.model = LabelEmbedding(in_channel = 64,
                                    expansion= settings.expansion,
                                    dim_embedding = settings.dim_embedding,
                                    num_labels = settings.num_labels)
        self.criterion = nn.MSELoss()

        # TensorBoard Loggers Initialization. Created here rather than in on_train_start
        # so they also exist for validation epochs that run before / outside training.
        self.train_loss_logger = TensorBoardLogger(f'{self.settings.save_folderpath}/V{self.settings.model_version}/T12 Net', 'Training Performance')
        self.val_loss_logger = TensorBoardLogger(f'{self.settings.save_folderpath}/V{self.settings.model_version}/T12 Net', 'Validation Performance')

        # Existing Model Checkpoint Loading
        self.model_filepath = Path(f"{self.settings.save_folderpath}/V{self.settings.model_version}/Embedding Net (V{self.settings.model_version}).pth")
        if self.settings.model_version != 0 and self.model_filepath.exists():

            # Checkpoint Fixing (strip the 'module.' prefix added by nn.DataParallel)
            checkpoint = torch.load(self.model_filepath); self.checkpoint_fix = dict()
            for sd, sd_value in checkpoint.items():
                if sd == 'ModelSD' or sd == 'OptimizerSD':
                    self.checkpoint_fix[sd] = OrderedDict()
                    for key, value in checkpoint[sd].items():
                        if key[0:7] == 'module.':
                            self.checkpoint_fix[sd][key[7:]] = value
                        else: self.checkpoint_fix[sd][key] = value
                else: self.checkpoint_fix[sd] = sd_value

            # Application of Checkpoint's State Dictionary
            self.model.load_state_dict(self.checkpoint_fix['ModelSD'])
            torch.set_rng_state(self.checkpoint_fix['RNG State'])
            del checkpoint
        self.model = nn.DataParallel(self.model.to(self.settings.device))

    # Optimizer Initialization Function
    def configure_optimizers(self):
        self.optimizer = torch.optim.SGD(   self.model.parameters(),                # T1 + T2 Model Optimizer
                                            lr = self.settings.base_lr,
                                            momentum = 0.9,
                                            weight_decay = self.settings.weight_decay)
        if self.settings.model_version != 0 and self.model_filepath.exists():       # Optimizer State
            self.optimizer.load_state_dict(self.checkpoint_fix['OptimizerSD'])      # Checkpoint Loading
        # ExponentialLR is stepped manually, only at the epochs in lr_decay_epochs
        self.lr_schedule = torch.optim.lr_scheduler.ExponentialLR(  self.optimizer,
                                                    gamma = self.settings.lr_decay)
        return self.optimizer

    # Foward Functionality
    def forward(self, X): return self.model(X)

    # --------------------------------------------------------------------------------------------

    # Train Set DataLoader Download
    def train_dataloader(self):
        TrainTrainLoader = v3DMUDI.loader(  Path(f"{self.settings.data_folderpath}"),
                                            dim = 2, version = self.settings.data_version,
                                            set_ = 'Train', mode_ = 'Train')
        self.train_batches = len(TrainTrainLoader)
        return TrainTrainLoader

    # Validation Set DataLoader Download
    def val_dataloader(self):
        TrainValLoader = v3DMUDI.loader(Path(f"{self.settings.data_folderpath}"),
                                        dim = 2, version = self.settings.data_version,
                                        set_ = 'Train', mode_ = 'Val')
        self.val_batches = len(TrainValLoader)
        return TrainValLoader

    # Test Set DataLoader Download
    def test_dataloader(self):
        TestValLoader = v3DMUDI.loader( Path(f"{self.settings.data_folderpath}"),
                                        dim = 2, version = self.settings.data_version,
                                        set_ = 'Test', mode_ = 'Val')
        self.test_batches = len(TestValLoader)
        return TestValLoader

    ##############################################################################################
    # ------------------------------------- Training Script --------------------------------------
    ##############################################################################################

    # Functionality called upon the Start of Training
    def on_train_start(self):

        # Model Training Mode Setup
        self.model.train()
        if self.settings.model_version == 0:
            # NOTE(review): LightningModule.current_epoch is a read-only property, so this
            # assignment raises AttributeError when model_version == 0. The intent is not
            # visible here -- confirm it and replace (e.g. via Trainer(max_epochs = ...)).
            self.current_epoch = self.settings.num_epochs - 1

    # Functionality called upon the Start of Training Epoch
    def on_train_epoch_start(self): self.train_loss = 0

    # --------------------------------------------------------------------------------------------

    # Training Step / Batch Loop
    def training_step(self, batch, batch_idx):

        # Data & Label Handling
        X_batch, ygt_batch = batch
        X_batch = X_batch.type(torch.float).to(self.settings.device)
        ygt_batch = ygt_batch.type(torch.float).to(self.settings.device)

        # Forward Pass
        h_batch, y_batch = self.model(X_batch)                  # T12 Model (X -> h -> y)
        loss = self.criterion(y_batch, ygt_batch)               # Loss Computation
        del X_batch, ygt_batch, h_batch
        return loss

    # Functionality called upon the End of a Batch Training Step
    def on_train_batch_end(self, outputs, batch = None, batch_idx = None):
        # Lightning calls this hook as (outputs, batch, batch_idx); in Lightning 2.x
        # `outputs` is the training_step result as a dict {'loss': ...}. A bare tensor
        # is accepted as well.
        loss = outputs['loss'] if isinstance(outputs, dict) else outputs
        self.train_loss = self.train_loss + loss.detach().cpu().item()

    # --------------------------------------------------------------------------------------------

    # Functionality called upon the End of a Training Epoch
    def on_train_epoch_end(self):

        # Learning Rate Decay
        if (self.trainer.current_epoch + 1) in self.lr_decay_epochs:
            self.lr_schedule.step()

        # TensorBoard Logger Update
        self.train_loss = self.train_loss / self.train_batches
        self.train_loss_logger.experiment.add_scalar("Training Loss", self.train_loss, self.current_epoch)

        # Model Checkpoint Saving (the version folder may not exist yet)
        self.model_filepath.parent.mkdir(parents = True, exist_ok = True)
        torch.save({'ModelSD': self.model.state_dict(),
                    'OptimizerSD': self.optimizer.state_dict(),
                    'Training Epochs': self.current_epoch,
                    'RNG State': torch.get_rng_state()},
                    self.model_filepath)

    ##############################################################################################
    # ------------------------------------ Validation Script -------------------------------------
    ##############################################################################################

    # Functionality called upon the Start of Validation
    def on_validation_start(self): self.model.eval()

    # Functionality called upon the Start of Validation Epoch
    def on_validation_epoch_start(self): self.val_loss = 0

    # --------------------------------------------------------------------------------------------

    # Validation Step / Batch Loop
    def validation_step(self, batch, batch_idx):

        # Data & Label Handling
        X_batch, ygt_batch = batch
        X_batch = X_batch.type(torch.float).to(self.settings.device)
        ygt_batch = ygt_batch.type(torch.float).to(self.settings.device)

        # Forward Pass
        h_batch, y_batch = self.model(X_batch)                  # T12 Model (X -> h -> y)
        loss = self.criterion(y_batch, ygt_batch)               # Loss Computation
        self.val_loss = self.val_loss + loss.cpu().item()
        del X_batch, ygt_batch, h_batch
        return loss

    # --------------------------------------------------------------------------------------------

    # Functionality called upon the End of a Validation Epoch
    # (previously also named on_train_epoch_end, which silently replaced the training hook)
    def on_validation_epoch_end(self):

        # Lightning's pre-training sanity check only runs a few validation batches; it is
        # not a full pass over the set, so it is not logged.
        if self.trainer.sanity_checking: return

        # TensorBoard Logger Update -- same tag as the training curve but a separate run
        # directory, presumably so TensorBoard overlays the two curves in one chart
        self.val_loss = self.val_loss / self.val_batches
        self.val_loss_logger.experiment.add_scalar("Training Loss", self.val_loss, self.current_epoch)

    ##############################################################################################
    # ------------------------------------- Testing Script ---------------------------------------
    ##############################################################################################


In [7]:
# T3 Embedding Model Training, Validation & Testing Script Class
class LitT3Net(pl.LightningModule):
    """
    PyTorch Lightning wrapper that trains the T3 network (``y -> h``), which maps
    labels into the embedding space of the trained T1 + T2 network.

    Training signal: ground-truth labels get N(0, 0.2) noise and are clamped to
    [0, 1]. T3 embeds them, the trained T2 network (``h -> y``) decodes them back,
    and the MSE against the noisy labels is minimised. Only T3's parameters are
    optimised; T2 acts as a fixed decoder.
    """

    ##############################################################################################
    # --------------------------------------- Initial Setup --------------------------------------
    ##############################################################################################

    # Constructor / Initialization Function
    def __init__(
        self,
        settings: argparse.ArgumentParser,              # Model Settings & Parametrizations
        embedNet: LabelEmbedding,                       # Trained T1 & T2 Conjoint Model
    ):

        # Class Variable Logging
        super().__init__()
        self.settings = settings
        # embedNet is expected to be wrapped in nn.DataParallel (hence `.module`)
        self.t2Net = embedNet.module.t2Net              # T2 Model Contained in Label Embedding Variable
        self.lr_decay_epochs = [150, 250, 350]          # Epochs for Learning Rate Decay

        # Model Initialization
        self.model = t3Net( dim_embedding = self.settings.dim_embedding,
                            num_labels = self.settings.num_labels)
        self.criterion = nn.MSELoss()

        # Existing Model Checkpoint Loading
        self.model_filepath = Path(f"{self.settings.save_folderpath}/V{self.settings.model_version}/T3 Net (V{self.settings.model_version}).pth")
        if self.settings.model_version != 0 and self.model_filepath.exists():

            # Checkpoint Fixing (strip the 'module.' prefix added by nn.DataParallel)
            checkpoint = torch.load(self.model_filepath); self.checkpoint_fix = dict()
            for sd, sd_value in checkpoint.items():
                if sd == 'ModelSD' or sd == 'OptimizerSD':
                    self.checkpoint_fix[sd] = OrderedDict()
                    for key, value in checkpoint[sd].items():
                        if key[0:7] == 'module.':
                            self.checkpoint_fix[sd][key[7:]] = value
                        else: self.checkpoint_fix[sd][key] = value
                else: self.checkpoint_fix[sd] = sd_value

            # Application of Checkpoint's State Dictionary
            self.model.load_state_dict(self.checkpoint_fix['ModelSD'])
            torch.set_rng_state(self.checkpoint_fix['RNG State'])
            del checkpoint
        self.model = nn.DataParallel(self.model.to(self.settings.device))

    # Optimizer Initialization Function
    def configure_optimizers(self):
        # Optimise the T3 model only. (The optimizer used to be built over T2's
        # parameters, which left T3 untrained and kept altering the trained T2 decoder.)
        self.optimizer = torch.optim.SGD(   self.model.parameters(),
                                            lr = self.settings.base_lr,
                                            momentum = 0.9,
                                            weight_decay = self.settings.weight_decay)
        if self.settings.model_version != 0 and self.model_filepath.exists():       # Optimizer State
            try: self.optimizer.load_state_dict(self.checkpoint_fix['OptimizerSD']) # Checkpoint Loading
            except ValueError:
                # Checkpoints from the earlier code hold the state of an optimizer over
                # T2's parameters, which does not match T3's parameter group.
                print(f"LitT3Net | Optimizer state in {self.model_filepath} does not match the T3 parameters; using a fresh optimizer state.")
        # ExponentialLR is stepped manually, only at the epochs in lr_decay_epochs
        self.lr_schedule = torch.optim.lr_scheduler.ExponentialLR(  self.optimizer,
                                                    gamma = self.settings.lr_decay)
        return self.optimizer

    # Foward Functionality
    def forward(self, y): return self.model(y)

    # --------------------------------------------------------------------------------------------

    # Train Set DataLoader Download
    def train_dataloader(self):
        TrainTrainLoader = v3DMUDI.loader(  Path(f"{self.settings.data_folderpath}"),
                                            dim = 2, version = self.settings.data_version,
                                            set_ = 'Train', mode_ = 'Train')
        self.train_batches = len(TrainTrainLoader)
        return TrainTrainLoader

    # Validation Set DataLoader Download
    def val_dataloader(self):
        TrainValLoader = v3DMUDI.loader(Path(f"{self.settings.data_folderpath}"),
                                        dim = 2, version = self.settings.data_version,
                                        set_ = 'Train', mode_ = 'Val')
        self.val_batches = len(TrainValLoader)
        return TrainValLoader

    # Test Set DataLoader Download
    def test_dataloader(self):
        TestValLoader = v3DMUDI.loader( Path(f"{self.settings.data_folderpath}"),
                                        dim = 2, version = self.settings.data_version,
                                        set_ = 'Test', mode_ = 'Val')
        self.test_batches = len(TestValLoader)
        return TestValLoader

    ##############################################################################################
    # ------------------------------------- Training Script --------------------------------------
    ##############################################################################################

    # Functionality called upon the Start of Training
    def on_train_start(self):

        # Model Training Mode Setup
        self.model.train()
        if self.settings.model_version == 0:
            # NOTE(review): LightningModule.current_epoch is a read-only property, so this
            # assignment raises AttributeError when model_version == 0. The intent is not
            # visible here -- confirm it and replace (e.g. via Trainer(max_epochs = ...)).
            self.current_epoch = self.settings.num_epochs - 1

        # TensorBoard Loggers Initialization
        self.train_loss_logger = TensorBoardLogger(f'{self.settings.save_folderpath}/V{self.settings.model_version}/T3 Net', 'Training Performance')

    # Functionality called upon the Start of Training Epoch
    def on_train_epoch_start(self):
        self.train_loss = 0
        # Keep the fixed T2 decoder in inference mode. It is a registered submodule, so
        # Lightning's train() calls would otherwise switch it (and any BatchNorm / Dropout
        # it may contain) back to training behaviour.
        self.t2Net.eval()

    # --------------------------------------------------------------------------------------------

    # Training Step / Batch Loop
    def training_step(self, batch, batch_idx):

        # Label Handling + Noise Addition
        X_batch, ygt_batch = batch
        ygt_batch = ygt_batch.type(torch.float).to(self.settings.device)
        gamma_batch = np.random.normal(0, 0.2, ygt_batch.shape)
        gamma_batch = torch.from_numpy(gamma_batch).type(torch.float).to(self.settings.device)
        ygt_noise_batch = torch.clamp(ygt_batch + gamma_batch, 0.0, 1.0)

        # Forward Pass
        h_noise_batch = self.model(ygt_noise_batch)              # T3 Model (y -> h)
        y_noise_batch = self.t2Net(h_noise_batch)                # T2 Model (h -> y)
        loss = self.criterion(y_noise_batch, ygt_noise_batch)    # Loss Computation
        del X_batch, ygt_batch, gamma_batch, ygt_noise_batch, y_noise_batch, h_noise_batch
        return loss

    # Functionality called upon the End of a Batch Training Step
    def on_train_batch_end(self, outputs, batch = None, batch_idx = None):
        # Lightning calls this hook as (outputs, batch, batch_idx); in Lightning 2.x
        # `outputs` is the training_step result as a dict {'loss': ...}. A bare tensor
        # is accepted as well. Backward / optimizer steps are handled by Lightning.
        loss = outputs['loss'] if isinstance(outputs, dict) else outputs
        self.train_loss = self.train_loss + loss.detach().cpu().item()

    # --------------------------------------------------------------------------------------------

    # Functionality called upon the End of a Training Epoch
    def on_train_epoch_end(self):

        # Learning Rate Decay
        if (self.trainer.current_epoch + 1) in self.lr_decay_epochs:
            self.lr_schedule.step()

        # TensorBoard Logger Update
        self.train_loss = self.train_loss / self.train_batches
        self.train_loss_logger.experiment.add_scalar("Training Loss", self.train_loss, self.current_epoch)

        # Model Checkpoint Saving (the version folder may not exist yet)
        self.model_filepath.parent.mkdir(parents = True, exist_ok = True)
        torch.save({'ModelSD': self.model.state_dict(),
                    'OptimizerSD': self.optimizer.state_dict(),
                    'Training Epochs': self.current_epoch,
                    'RNG State': torch.get_rng_state()},
                    self.model_filepath)

    ##############################################################################################
    # ------------------------------------ Validation Script -------------------------------------
    ##############################################################################################



In [6]:
# Generator & Discriminator Models Training
def train_gan(
    t3Net: nn.Module,                               # Trained T3 Model (labels -> embedding)
    train_set: DataLoader,                          # Training Set's Train DataLoader
    settings: argparse.Namespace,                   # Model Settings & Parametrizations
    train: bool = True,                             # Boolean Control Variable: False if the purpose...
):                                                  # is to just Load the Selected Model Version
    """Train (or just load) the CcGAN Generator & Discriminator pair.

    Builds both networks and their Adam optimizers and resumes them from the
    versioned checkpoints under ``settings.save_folderpath`` when both files exist
    and ``settings.model_version != 0``. It then wraps them in ``nn.DataParallel``.
    If ``train`` is True, it runs ``settings.num_epochs`` epochs of hard-vicinity
    training, checkpointing after every epoch. At the end it saves an example-image
    grid and a loss-curve figure.

    Args:
        t3Net: trained label-embedding model; it is never updated here (no optimizer,
            forward passes run under ``torch.no_grad``).
        train_set: DataLoader whose ``dataset[:]`` returns ``(X, y)``.
        settings: parsed settings (device, dim_z, dim_embedding, lr_ccgan, loss,
            kernel_eps, dis_update, gen_update, num_epochs, batch_size, ...).
        train: False to only load the selected model version.

    Returns:
        (dis, gen): the nn.DataParallel-wrapped Discriminator and Generator.

    Side effects:
        When training with ``settings.model_version == 0``, ``settings.num_epochs``
        and ``settings.batch_size`` are overwritten with 1.
    """
    device = settings.device

    # Models Architecture & Optimizer Initialization. The models go to the device
    # *before* any optimizer state is loaded, so Adam's buffers land on the same
    # device as the parameters (otherwise step() fails with a device mismatch).
    gen = Generator(dim_z = settings.dim_z, dim_embedding = settings.dim_embedding).to(device)
    dis = Discriminator().to(device)
    gen_optimizer = torch.optim.Adam(gen.parameters(), lr = settings.lr_ccgan, betas = (0.5, 0.999))
    dis_optimizer = torch.optim.Adam(dis.parameters(), lr = settings.lr_ccgan, betas = (0.5, 0.999))
    current_epoch = 0

    # Existing Model Checkpoint Loading
    model_folderpath = Path(f"{settings.save_folderpath}/V{settings.model_version}")
    gen_filepath = model_folderpath / f"Generator (V{settings.model_version}).pth"
    dis_filepath = model_folderpath / f"Discriminator (V{settings.model_version}).pth"
    if settings.model_version != 0 and gen_filepath.exists() and dis_filepath.exists():

        # map_location = 'cpu' lets GPU-saved checkpoints load anywhere; load_state_dict then
        # copies the tensors onto the models' device. NOTE: torch.load unpickles - trusted files only.
        gen_checkpoint = _strip_dataparallel_prefix(torch.load(gen_filepath, map_location = 'cpu'))
        dis_checkpoint = _strip_dataparallel_prefix(torch.load(dis_filepath, map_location = 'cpu'))
        gen.load_state_dict(gen_checkpoint['ModelSD'])
        gen_optimizer.load_state_dict(gen_checkpoint['OptimizerSD'])
        dis.load_state_dict(dis_checkpoint['ModelSD'])
        dis_optimizer.load_state_dict(dis_checkpoint['OptimizerSD'])
        torch.set_rng_state(dis_checkpoint['RNG State'])

        # 'Training Epochs' holds the index of the last *completed* epoch -> resume at the next one
        current_epoch = gen_checkpoint['Training Epochs'] + 1
        del gen_checkpoint, dis_checkpoint

    # Model Transfer to the Device & Multi-GPU Wrapping
    t3Net = t3Net.to(device); t3Net.eval()
    gen = nn.DataParallel(gen).to(device)
    dis = nn.DataParallel(dis).to(device)

    if not train:
        print(f"DOWNLOAD: Generator Model (V{settings.model_version})")
        print(f"DOWNLOAD: Discriminator Model (V{settings.model_version})")
        return dis, gen

    # --------------------------------------------------------------------------------------------

    assert settings.loss in ('vanilla', 'hinge'), f"ERROR: Loss Function not Supported!"
    model_folderpath.mkdir(parents = True, exist_ok = True)

    @torch.no_grad()
    def embed(labels: torch.Tensor) -> torch.Tensor:
        """T3 label embedding; no gradient is needed since t3Net is never optimized here."""
        return t3Net(labels)

    # Data Accessing Betterment
    nRow = nCol = 10
    X_train, y_train = train_set.dataset[:]
    y_train_np = np.asarray(y_train)                 # label search is done in NumPy
    nSamples, nLabels = y_train_np.shape
    if settings.model_version != 0:
        assert settings.batch_size == train_set.batch_size, "ERROR: Batch Size Value not Corresponding"

    # Training Image Selection between the 5th & 95th Label Percentile: row i of sel_label
    # interpolates every label linearly between its two percentiles
    start_label = np.quantile(y_train_np, 0.05, axis = 0)
    end_label = np.quantile(y_train_np, 0.95, axis = 0)
    sel_label = np.linspace(start_label, end_label, num = nRow)     # (nRow, nLabels)

    # Automated Label-Specific Kappa Computation & Gaussian Noise Generation Function
    kappa_list = np.array([_label_kappa(y_train_np[:, l]) for l in range(nLabels)])
    if nLabels >= 6: kappa_list[-1] = 0.0
    assert np.all(kappa_list >= 0), "ERROR: Kappa Paremeter wrongly Set!"

    def eps(samples: int) -> np.ndarray:
        """(samples, nLabels) label noise: difference of two N(0, kernel_eps) draws scaled by kappa / 2."""
        eps_pos = np.random.normal(0, settings.kernel_eps, (samples, nLabels))
        eps_neg = np.random.normal(0, settings.kernel_eps, (samples, nLabels))
        return (eps_pos - eps_neg) * (kappa_list / 2)

    def vicinity(label: np.ndarray) -> np.ndarray:
        """Indexes of training samples whose labels all lie within 2 * kappa of ``label``."""
        return np.flatnonzero(np.all(np.abs(y_train_np - label) <= (kappa_list * 2.0), axis = 1))

    def to_device(array) -> torch.Tensor:
        return torch.as_tensor(array).to(device = device, dtype = torch.float)

    # --------------------------------------------------------------------------------------------

    # Epoch Loop
    if settings.model_version == 0: settings.num_epochs = 1; settings.batch_size = 1
    batch_size = settings.batch_size
    dis_loss_table, gen_loss_table = [], []
    for epoch in range(current_epoch, current_epoch + settings.num_epochs):

        # Discriminator Training Process / Step Loop
        dis.train()
        with alive_bar( settings.dis_update, bar = 'blocks',
                        title = f'Epoch #{epoch} | Discriminator ',
                        force_tty = True) as dis_bar:
            for _ in range(settings.dis_update):

                # Random Draw of Batch of Target Labels (w/ Added Gaussian Noise)
                # [Known Limitation] Targets come from the overall dataset, not from the list of
                # unique labels, so repeats are possible (there aren't many unique values anyway)
                y_target = y_train_np[np.random.choice(nSamples, size = batch_size, replace = False)]
                y_vic = y_target + eps(batch_size)
                index_target = np.empty(batch_size, dtype = int)    # Vicinity-Labelled Real Image Indexes
                y_fake = np.empty((batch_size, nLabels))            # Fake Vicinity Labels

                for i in range(batch_size):

                    # Hard Vicinity-Labelled Real Image Searching: redraw the noise until at least
                    # 'nLabels' real neighbours exist, then pick one of them for this target
                    index_vic = vicinity(y_vic[i])
                    while len(index_vic) < nLabels:
                        y_vic[i] = y_target[i] + eps(1)[0]
                        index_vic = vicinity(y_vic[i])
                    index_target[i] = np.random.choice(index_vic)

                    # Fake Image Label Generation, uniformly within kappa of the noisy target
                    y_fake[i] = np.random.uniform(y_vic[i] - kappa_list, y_vic[i] + kappa_list, size = nLabels)

                # Hard Vicinity-Labelled Real Image Drawing & Fake Image Generation
                X_vic = to_device(X_train[index_target])
                y_target_t, y_fake_t = to_device(y_target), to_device(y_fake)
                z_fake = torch.randn(batch_size, settings.dim_z, dtype = torch.float).to(device)
                with torch.no_grad():                               # this step only updates the Discriminator
                    X_fake = gen(z_fake, embed(y_fake_t))

                # Forward Pass & Loss Computation
                out_target = dis(X_vic, embed(y_target_t))          # Real Sample Discriminator Output
                out_fake = dis(X_fake, embed(y_fake_t))             # Fake Sample Discriminator Output
                dis_loss = _gan_dis_loss(out_target, out_fake, settings.loss)

                # Backward Pass & Step Update
                dis_optimizer.zero_grad()
                dis_loss.backward()
                dis_optimizer.step()
                dis_loss_table.append(dis_loss.item())
                del X_vic, y_target_t, y_fake_t, z_fake, X_fake, out_target, out_fake
                dis_bar()

        # --------------------------------------------------------------------------------------------

        # Generator Training Process / Step Loop
        gen.train()
        with alive_bar( settings.gen_update, bar = 'blocks',
                        title = f'Epoch #{epoch} | Generator     ',
                        force_tty = True) as gen_bar:
            for _ in range(settings.gen_update):

                # Random Draw of Batch of Target Labels (w/ Added Gaussian Noise)
                y_target = y_train_np[np.random.choice(nSamples, size = batch_size, replace = False)]
                y_fake = to_device(y_target + eps(batch_size))

                # Fake Image Generation & Forward Pass
                z_fake = torch.randn(batch_size, settings.dim_z, dtype = torch.float).to(device)
                h_fake = embed(y_fake)
                X_fake = gen(z_fake, h_fake)
                out_fake = dis(X_fake, h_fake)                      # Fake Sample Discriminator Output
                gen_loss = _gan_gen_loss(out_fake, settings.loss)

                # Backward Pass & Step Update (zero_grad also clears the Discriminator
                # gradients that this backward pass is not meant to apply)
                gen_optimizer.zero_grad()
                gen_loss.backward()
                gen_optimizer.step()
                gen_loss_table.append(gen_loss.item())
                del y_fake, z_fake, h_fake, X_fake, out_fake
                gen_bar()

        # --------------------------------------------------------------------------------------------

        # Model Progress & State Dictionary Saving
        last_dis_loss = dis_loss_table[-1] if dis_loss_table else float('nan')
        last_gen_loss = gen_loss_table[-1] if gen_loss_table else float('nan')
        print(f"Epoch #{epoch} | Discriminator Train Loss: {round(last_dis_loss, 3)}")
        print(f"Epoch #{epoch} | Generator Train Loss: {round(last_gen_loss, 3)}")
        torch.save({'ModelSD': dis.state_dict(),
                    'OptimizerSD': dis_optimizer.state_dict(),
                    'Training Epochs': epoch,
                    'RNG State': torch.get_rng_state()},
                    dis_filepath)
        torch.save({'ModelSD': gen.state_dict(),
                    'OptimizerSD': gen_optimizer.state_dict(),
                    'Training Epochs': epoch,
                    'RNG State': torch.get_rng_state()},
                    gen_filepath)

    # --------------------------------------------------------------------------------------------

    # Training Performance Evaluation - Example Images Visualization
    # (row i of the label grid is sel_label[i], repeated nCol times)
    z_fix = torch.randn(nRow * nCol, settings.dim_z, dtype = torch.float).to(device)
    gen.eval()
    y_fix = to_device(np.repeat(sel_label, nCol, axis = 0))
    with torch.no_grad(): X_fix = gen(z_fix, embed(y_fix)).cpu()
    nGridRow, nGridCol = int(np.ceil(nRow / 2)), int(np.ceil(nCol / 2))
    fig, axs = plt.subplots(nGridRow, nGridCol, figsize = (15, 15)); fig.tight_layout()
    for i in range(nGridRow):
        for j in range(nGridCol):
            axs[i, j].imshow(X_fix[(i * nGridCol) + j, 0, :, :], cmap = 'gray')
            axs[i, j].axis('off')
    plt.savefig(model_folderpath / f"Example Images (V{settings.model_version}).png")
    del z_fix, X_fix, y_fix

    # Training Performance Evaluation - Loss Analysis
    fig, ax = plt.subplots(figsize = (10, 10))
    ax.plot(dis_loss_table, 'g', label = 'Discriminator')
    ax.plot(gen_loss_table, 'r', label = 'Generator')
    ax.legend(loc = 'upper right'); ax.set_title('GAN Loss'); ax.set_xticks([])
    plt.savefig(model_folderpath / f"GAN Loss (V{settings.model_version}).png")

    # Training Performance Evaluation - Other Analytics
    # TODO

    return dis, gen


def _strip_dataparallel_prefix(checkpoint: dict) -> dict:
    """Return a copy of ``checkpoint`` whose 'ModelSD' / 'OptimizerSD' keys lose the
    'module.' prefix that nn.DataParallel adds to state-dict keys."""
    prefix = 'module.'
    fixed = dict()
    for entry, value in checkpoint.items():
        if entry in ('ModelSD', 'OptimizerSD'):
            fixed[entry] = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, item)
                for key, item in value.items())
        else:
            fixed[entry] = value
    return fixed


def _label_kappa(values: np.ndarray) -> float:
    """Mean gap between consecutive unique values of one label column.

    Returns 0.0 when the column holds a single unique value: the mean of an empty
    diff would be NaN, and a NaN kappa makes the vicinity search loop forever.
    """
    unique_values = np.unique(values)               # np.unique already returns sorted values
    if unique_values.size < 2:
        return 0.0
    return float(np.mean(np.diff(unique_values)))


def _gan_dis_loss(out_real: torch.Tensor, out_fake: torch.Tensor, loss_type: str) -> torch.Tensor:
    """Scalar Discriminator loss ('vanilla' or 'hinge') from raw real & fake outputs."""
    if loss_type == 'vanilla':
        # -log(sigmoid(x)) == softplus(-x) and -log(1 - sigmoid(x)) == softplus(x),
        # written with softplus for numerical stability
        return F.softplus(-out_real).mean() + F.softplus(out_fake).mean()
    return F.relu(1.0 - out_real).mean() + F.relu(1.0 + out_fake).mean()


def _gan_gen_loss(out_fake: torch.Tensor, loss_type: str) -> torch.Tensor:
    """Scalar Generator loss ('vanilla' or 'hinge') from the Discriminator's fake outputs."""
    if loss_type == 'vanilla':
        return F.softplus(-out_fake).mean()         # == -mean(log(sigmoid(out_fake)))
    return -out_fake.mean()


---

# **Main** *Scripts*

---

In [5]:
"""
# Dataset Version Creation
sys.path.append(data_settings.main_folderpath + "/Dataset Reader")
from v3DMUDI import v3DMUDI
data = v3DMUDI(data_settings)
data.split(data_settings)
data.save()
"""

# --------------------------------------------------------------------------------------------

# Dataset Access
sys.path.append(f"{data_settings.main_folderpath}/Dataset Reader")
from v3DMUDI import v3DMUDI

# Dataset Version Creation
#data = v3DMUDI(data_settings)
#data.split(data_settings)
#data.save()

# --------------------------------------------------------------------------------------------

# Full 2D CcGAN Model Class Importing
sys.path.append(model_settings.model_folderpath)
from LabelEmbedding import LabelEmbedding, t3Net
from Generator import Generator
from Discriminator import Discriminator

# Full 2D CcGAN Model Training Importing
sys.path.append(model_settings.script_folderpath)
from train_embedNet import LitT12Net
from train_t3Net import LitT3Net
#from train_gan import Lit2DCcGAN

# --------------------------------------------------------------------------------------------

# Model Initialization & Training | T12 Net
"""
t12Net = LitT12Net(model_settings)
t12_trainer = pl.Trainer(       max_epochs = 3,#model_settings.num_epochs,
                                devices = 1 if torch.cuda.is_available() else None,
                                enable_progress_bar = True,
                                callbacks = [pl.callbacks.TQDMProgressBar(refresh_rate = 1)])
t12_trainer.fit(t12Net)
"""

# Model Initialization & Training | T3 Net
"""
t3Net = LitT3Net(model_settings, t12Net)
t3_trainer = pl.Trainer(max_epochs = model_settings.num_epochs,
                        devices = 1 if torch.cuda.is_available() else None,
                        enable_progress_bar = True,
                        callbacks = [pl.callbacks.TQDMProgressBar(refresh_rate = 1)])
t3_trainer.fit(t3Net)
"""

# Model Initialization & Training | 2D CcGAN

%load_ext tensorboard