In [1]:

import pandas as pd
import pickle as pkl
import scanpy as sc
import anndata as ad
import numpy as np
from sklearn.model_selection import train_test_split
import h5py


## Import dataset

In [2]:

with open("../data_for_training/train_cell_embeddings.pkl", "rb") as file:
    train_cell_embeddings = pkl.load(file)

with open("../data_for_training/train_smiles_embeddings.pkl", "rb") as file:
    train_smiles_embeddings = pkl.load(file)

with open("../data_for_training/train_images.pkl", "rb") as file:
    train_images = pkl.load(file)

with open("../data_for_training/valid_cell_embeddings.pkl", "rb") as file:
    valid_cell_embeddings = pkl.load(file)

with open("../data_for_training/valid_smiles_embeddings.pkl", "rb") as file:
    valid_smiles_embeddings = pkl.load(file)

with open("../data_for_training/valid_images.pkl", "rb") as file:
    valid_images = pkl.load(file)

with open("../data_for_training/test_cell_embeddings.pkl", "rb") as file:
    test_cell_embeddings = pkl.load(file)

with open("../data_for_training/test_smiles_embeddings.pkl", "rb") as file:
    test_smiles_embeddings = pkl.load(file)

with open("../data_for_training/test_images.pkl", "rb") as file:
    test_images = pkl.load(file)


## GAN training

        self.gen = nn.Sequential(
            self._generator_block(input_dim, hidden_dim * 8, kernel_size=5, stride=5),   # (1x1) → (5x5)
            self._generator_block(hidden_dim * 8, hidden_dim * 4, kernel_size=4, stride=2, padding=1),   # (5x5) → (10x10)
            self._generator_block(hidden_dim * 4, hidden_dim * 2, kernel_size=3, stride=2, padding=0),   # (10x10) → (21x21)
            self._generator_block(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1),   # (21x21) → (42x42)
            self._generator_block(hidden_dim, hidden_dim // 2, kernel_size=3, stride=2, padding=1),   # (42x42) → (85x85)
            self._generator_block(hidden_dim // 2, hidden_dim // 4, kernel_size=4, stride=2, padding=1),   # (85x85) → (170x170)
            self._generator_block(hidden_dim // 4, hidden_dim // 8, kernel_size=3, stride=2, padding=1),   # (170x170) → (255x255)
            self._generator_block(hidden_dim // 8, hidden_dim // 16, kernel_size=4, stride=1, padding=0),  # (255x255) → (272x272)
            self._generator_block(hidden_dim // 16, hidden_dim // 32, kernel_size=4, stride=1, padding=0),  # (272x272) → (300x300)
            self._generator_block(hidden_dim // 32, image_channel, kernel_size=4, stride=1, padding=0, final_layer=True),  # (300x300) → (340x340)
        )

In [130]:
import torch
from torch import nn

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, input_dim=2769, image_channel=1, hidden_dim=256):
        super(Generator, self).__init__()
        self.input_dim = input_dim

        self.gen = nn.Sequential(
            self._generator_block(input_dim, hidden_dim * 8, kernel_size=5, stride=5),   # (1x1) → (5x5)
            self._generator_block(hidden_dim * 8, hidden_dim * 4, kernel_size=4, stride=2, padding=1),   # (5x5) → (10x10)
            self._generator_block(hidden_dim * 4, hidden_dim * 2, kernel_size=3, stride=2, padding=0),   # (10x10) → (21x21)
            self._generator_block(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1),   # (21x21) → (42x42)
            self._generator_block(hidden_dim, hidden_dim // 2, kernel_size=3, stride=2, padding=1),   # (42x42) → (85x85)
            self._generator_block(hidden_dim // 2, hidden_dim // 4, kernel_size=4, stride=2, padding=1),   # (85x85) → (170x170)
            self._generator_block(hidden_dim // 4, hidden_dim // 8, kernel_size=3, stride=2, padding=1),   # (170x170) → (255x255)
            self._generator_block(hidden_dim // 8, hidden_dim // 16, kernel_size=4, stride=1, padding=0),  # (255x255) → (272x272)
            self._generator_block(hidden_dim // 16, hidden_dim // 32, kernel_size=4, stride=1, padding=0),  # (272x272) → (300x300)
            self._generator_block(hidden_dim // 32, image_channel, kernel_size=4, stride=1, padding=0, final_layer=True),  # (300x300) → (340x340)
        )

    def _generator_block(self, input_channels, output_channels, kernel_size=3, stride=2, padding=0, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, padding),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, padding),
                nn.Tanh(),
            )

    def forward(self, noise):
        x = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(x)


def create_noise_vector(n_samples, input_dim, device="cpu"):
    return torch.randn(n_samples, input_dim, device=device)


class Discriminator(nn.Module):
    def __init__(self, image_channel=1, hidden_dim=64):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self._discriminator_block(image_channel, hidden_dim),
            self._discriminator_block(hidden_dim, hidden_dim * 2),
            self._discriminator_block(hidden_dim * 2, 1, final_layer=True),
        )

    def _discriminator_block(
        self,
        input_channels,
        output_channels,
        kernel_size=4,
        stride=2,
        final_layer=False,
    ):

        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)



In [131]:
import torch
from torch import nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torch.nn.functional as F

torch.manual_seed(0)  # Set for our testing purposes, please do not change!


def plot_images_from_tensor(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    """
    Plots a grid of images from a given tensor.

    The function first scales the image tensor to the range [0, 1]. It then detaches the tensor from the computation
    graph and moves it to the CPU if it's not already there. After that, it creates a grid of images and plots the grid.

    Args:
        image_tensor (torch.Tensor): A 4D tensor containing the images.
            The tensor is expected to be in the shape (batch_size, channels, height, width).
        num_images (int, optional): The number of images to include in the grid. Default is 25.
        size (tuple, optional): The size of a single image in the form of (channels, height, width). Default is (1, 28, 28).
        nrow (int, optional): Number of images displayed in each row of the grid. The final grid size is (num_images // nrow, nrow). Default is 5.
        show (bool, optional): Determines if the plot should be shown. Default is True.

    Returns:
        None. The function outputs a plot of a grid of images.
    """

    # Normalize the image tensor to [0, 1]
    image_tensor = (image_tensor + 1) / 2

    # Detach the tensor from its computation graph and move it to the CPU
    image_unflat = image_tensor.detach().cpu()

    # Create a grid of images using the make_grid function from torchvision.utils
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)

    # Plot the grid of images
    # The permute() function is used to rearrange the dimensions of the grid for plotting
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())

    # Show the plot if the 'show' parameter is True
    if show:
        plt.show()

    



""" The reason for doing "image_grid.permute(1, 2, 0)"

PyTorch modules processing image data expect tensors in the format C × H × W.

Whereas PILLow and Matplotlib expect image arrays in the format H × W × C

so to use them with matplotlib you need to reshape it
to put the channels as the last dimension:

I could have used permute() method as well like below
"np.transpose(npimg, (1, 2, 0))"

------------------

Tensor.detach() is used to detach a tensor from the current computational graph. It returns a new tensor that doesn't require a gradient.

When we don't need a tensor to be traced for the gradient computation, we detach the tensor from the current computational graph.

We also need to detach a tensor when we need to move the tensor from GPU to CPU.

"""


def weights_init(m):
    """
    Initialize the weights of convolutional and batch normalization layers.

    Args:
        m (torch.nn.Module): Module instance.

    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)


def ohe_vector_from_labels(labels, n_classes):
    return F.one_hot(labels, num_classes=n_classes)


"""
x = torch.tensor([4, 3, 2, 1, 0])
F.one_hot(x, num_classes=6)

# Expected result
# tensor([[0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 1, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 1, 0, 0, 0, 0],
#         [1, 0, 0, 0, 0, 0]])
"""


""" Concatenation of Multiple Tensor with `torch.cat()` - RULE - To concatenate WITH torch.cat(), where the list of tensors are concatenated across the specified dimensions, requires 2 conditions to be satisfied

1. All tensors need to have the same number of dimensions and
2. All dimensions except the one that they are concatenated on, need to have the same size. """


def concat_vectors(x, y):
    """
    Concatenate two tensors along the second dimension.

    Args:
        x (torch.Tensor): First input tensor.
        y (torch.Tensor): Second input tensor.

    Returns:
        torch.Tensor: Concatenated tensor.

    """
    combined = torch.cat((x.float(), y.float()), 1)
    return combined

def calculate_input_dim(z_dim, mnist_shape, n_classes):
    """
    Calculate the input dimensions for the generator and discriminator networks.

    Args:
        z_dim (int): Dimension of the random noise vector (latent space).
        mnist_shape (tuple): Shape of the MNIST images, e.g., (1, 28, 28).
        n_classes (int): Number of classes in the dataset.

    Returns:
        tuple: Tuple containing the generator input dimension and discriminator image channel.

    mnist_shape = (1, 28, 28)
    n_classes = 10"""
    generator_input_dim = z_dim + n_classes

    # mnist_shape[0] is 1 as its grayscale images
    discriminator_image_channel = mnist_shape[0] + n_classes

    return generator_input_dim, discriminator_image_channel

In [132]:
import torch
import torch.nn as nn

#from utils import *

####################################################
def test_weights_init():
    # Create a sample model with Conv2d and BatchNorm2d layers
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3),
        nn.BatchNorm2d(16),
        nn.ConvTranspose2d(16, 3, kernel_size=3),
        nn.BatchNorm2d(3)
    )

    # Initialize the model weights
    model.apply(weights_init)

    # Check the weights of Conv2d layers
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            assert torch.allclose(module.weight.mean(), torch.tensor(0.0), atol=0.02)
            assert torch.allclose(module.weight.std(), torch.tensor(0.02), atol=0.02)

    # Check the weights of BatchNorm2d layers
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            assert torch.allclose(module.weight.mean(), torch.tensor(0.0), atol=0.02)
            assert torch.allclose(module.weight.std(), torch.tensor(0.02), atol=0.02)
            assert torch.allclose(module.bias, torch.tensor(0.0))

    print("Unit test passed!")

# Run the unit test
# test_weights_init()

####################################################
def test_concat_vectors():
    # Create sample input tensors
    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    y = torch.tensor([[7, 8, 9], [10, 11, 12]])

    # Perform concatenation
    combined = concat_vectors(x, y)

    # Check the output type and shape
    assert isinstance(combined, torch.Tensor)
    assert combined.shape == (2, 6)  # Expected shape after concatenation

    # Check the values in the concatenated tensor
    expected_combined = torch.tensor([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]])
    assert torch.allclose(combined, expected_combined)

    print("Unit test passed!")

# Run the unit test
# test_concat_vectors()

####################################################
def test_calculate_input_dim():
    # Set up sample inputs
    z_dim = 100
    mnist_shape = (1, 28, 28)
    n_classes = 10

    # Calculate input dimensions
    generator_input_dim, discriminator_image_channel = calculate_input_dim(z_dim, mnist_shape, n_classes)

    # Check the output types and values
    assert isinstance(generator_input_dim, int)
    assert generator_input_dim == z_dim + n_classes

    assert isinstance(discriminator_image_channel, int)
    assert discriminator_image_channel == mnist_shape[0] + n_classes

    print("Unit test passed!")

# Run the unit test
# test_calculate_input_dim()


In [133]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

#from conditional_gan import *
#from utils import *

mnist_shape = (1, 340, 340)
n_classes = 768


criterion = nn.BCEWithLogitsLoss()
n_epochs = 1
z_dim = 2001
display_step = 500
batch_size = 128
lr = 0.0002
device = "cuda"

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)


In [134]:

#dataloader = DataLoader(
#    MNIST(
#        "/home/dennis00/scRNA_GAN/MNIST", download=True, transform=transform
#    ),
#    batch_size=batch_size,
#    shuffle=True,
#)

generator_input_dim, discriminator_image_channel = calculate_input_dim(
    z_dim, mnist_shape, n_classes
)

gen = Generator(input_dim=generator_input_dim).to(device)

gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)

disc = Discriminator(image_channel=discriminator_image_channel).to(device)

disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)


def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)


gen  = gen.apply(weights_init)

disc = disc.apply(weights_init)


cur_step = 0
generator_losses = []
discriminator_losses = []

noise_and_labels = False
fake = False

fake_image_and_labels = False
real_image_and_labels = False
disc_fake_pred = False
disc_real_pred = False


In [135]:
gen

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(2769, 2048, kernel_size=(5, 5), stride=(5, 5))
      (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): ConvTranspose2d(2048, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (4): Sequential(
      (0): ConvTranspose

In [136]:

class MultiModalDataset(Dataset):
    def __init__(self, cell_embeddings, smiles_embeddings, images):
        self.cell_embeddings = torch.tensor(cell_embeddings, dtype=torch.float32)
        self.smiles_embeddings = torch.tensor(smiles_embeddings, dtype=torch.float32)
        self.images = torch.tensor(images, dtype=torch.float32)

    def __len__(self):
        return len(self.cell_embeddings)

    def __getitem__(self, idx):
        return (
            self.cell_embeddings[idx], 
            self.smiles_embeddings[idx], 
            self.images[idx]
        )

# Example usage for train set
train_dataset = MultiModalDataset(train_cell_embeddings, train_smiles_embeddings, train_images)

# Same for validation and test sets
val_dataset = MultiModalDataset(valid_cell_embeddings, valid_smiles_embeddings, valid_images)
test_dataset = MultiModalDataset(test_cell_embeddings, test_smiles_embeddings, test_images)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [149]:
for epoch in range(n_epochs):
    for cell, smile, real in tqdm(train_loader):
        # Move data to GPU if needed
        cell, smile, real = cell.to(device), smile.to(device), real.to(device)

        cur_batch_size = len(cell)
        
        disc_opt.zero_grad()

        one_hot_labels = smile
        
        fake_noise = create_noise_vector(cur_batch_size, z_dim, device=device)
        
        noise_and_labels = concat_vectors(fake_noise, one_hot_labels)

        fake = gen(noise_and_labels)

        # Make sure that enough images were generated
        assert len(fake) == len(real)
        

  0%|          | 0/163 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [140]:
fake.shape

torch.Size([128, 1, 340, 340])

In [141]:
noise_and_labels.shape

torch.Size([128, 2769])

In [142]:
real.shape

torch.Size([128, 340, 340])

In [150]:
one_hot_labels

tensor([[ 2.4593, -0.2537,  0.0481,  ..., -0.0569, -0.2366,  0.7164],
        [ 2.1339, -0.1549,  0.3125,  ..., -0.0910, -0.1731,  0.6211],
        [ 2.5093,  0.9491,  1.0351,  ..., -0.1222, -0.4565,  0.9770],
        ...,
        [ 2.1316,  0.1801, -0.2926,  ..., -0.7508, -0.4878,  1.0630],
        [ 1.0609,  0.1547, -0.3923,  ..., -0.3137, -0.0152,  0.7931],
        [ 2.2265,  0.2990,  0.4958,  ...,  0.5109,  0.1003,  0.2555]],
       device='cuda:0')

In [153]:
image_one_hot_labels = one_hot_labels[:, :, None, None]

In [157]:
image_one_hot_labels = image_one_hot_labels.repeat(
            1, 1, 340, 340
        )

OutOfMemoryError: CUDA out of memory. Tried to allocate 42.33 GiB. GPU 0 has a total capacity of 31.74 GiB of which 1.06 GiB is free. Including non-PyTorch memory, this process has 30.08 GiB memory in use. Process 432941 has 610.00 MiB memory in use. Of the allocated memory 19.02 GiB is allocated by PyTorch, and 10.70 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [156]:
image_one_hot_labels.shape

torch.Size([128, 768, 1, 1])

In [None]:

#dataloader = DataLoader(
#    MNIST(
#        "/home/dennis00/scRNA_GAN/MNIST", download=True, transform=transform
#    ),
#    batch_size=batch_size,
#    shuffle=True,
#)

In [151]:
z_dim = 64
for epoch in range(n_epochs):
    # Dataloader returns the batches and the labels
    for real, labels in tqdm(dataloader):
        cur_batch_size = len(real)
        # Flatten the batch of real images from the dataset
        real = real.to(device)

        # create one hot encoded vectors from labels and n_classes
        one_hot_labels = ohe_vector_from_labels(labels.to(device), n_classes)
        print("one_hot_labels ", one_hot_labels.size())  # => torch.Size([128, 10])

        """ The above ([128, 10]) need to be converted to ([128, 10, 28, 28])

        Because, Concatenation of Multiple Tensor with `torch.cat()` - RULE - To concatenate WITH torch.cat(), where the list of tensors are concatenated across the specified dimensions, requires 2 conditions to be satisfied

         1. All tensors need to have the same number of dimensions and
         2. All dimensions except the one that they are concatenated on, need to have the same size.

        To do that, first I am adding extra dimension with 'None'
        the easiest way to add extra dimensions to an array is by using the keyword None,
        when indexing at the position to add the extra dimension.
        Note, in below with keyword None, I am only adding extra dummy empty dimension

        a = torch.rand(1, 2)
        ic(a) # => tensor([[0.1749, 0.6387]])
        ic(a[None, :]) # => tensor([[[0.1749, 0.6387]]])

        a = torch.rand([1,2,3,4])
        ic(a.shape) # => torch.Size([1, 2, 3, 4])
        ic(a[None, :].shape) # => torch.Size([1, 1, 2, 3, 4])
        """
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        print(
            "image_one_hot_labels.size ", image_one_hot_labels.size()
        )  # => torch.Size([128, 10, 1, 1])

        image_one_hot_labels = image_one_hot_labels.repeat(
            1, 1, mnist_shape[1], mnist_shape[2]
        )
        print(
            "image_one_hot_labels.size ", image_one_hot_labels.size()
        )  # => torch.Size([128, 10, 28, 28])

        #########################
        #  Train Discriminator
        #########################
        # Zero out the discriminator gradients
        disc_opt.zero_grad()
        # Get noise corresponding to the current batch_size
        fake_noise = create_noise_vector(cur_batch_size, z_dim, device=device)

        # Now we can get the images from the generator
        # Steps: 1) Combine the noise vectors and the one-hot labels for the generator
        #        2) Generate the conditioned fake images

        noise_and_labels = concat_vectors(fake_noise, one_hot_labels)
        fake = gen(noise_and_labels)

NameError: name 'dataloader' is not defined

In [None]:

for epoch in range(n_epochs):
    # Dataloader returns the batches and the labels
    for real, labels in tqdm(dataloader):
        cur_batch_size = len(real)
        # Flatten the batch of real images from the dataset
        real = real.to(device)

        # create one hot encoded vectors from labels and n_classes
        one_hot_labels = ohe_vector_from_labels(labels.to(device), n_classes)
        print("one_hot_labels ", one_hot_labels.size())  # => torch.Size([128, 10])

        """ The above ([128, 10]) need to be converted to ([128, 10, 28, 28])

        Because, Concatenation of Multiple Tensor with `torch.cat()` - RULE - To concatenate WITH torch.cat(), where the list of tensors are concatenated across the specified dimensions, requires 2 conditions to be satisfied

         1. All tensors need to have the same number of dimensions and
         2. All dimensions except the one that they are concatenated on, need to have the same size.

        To do that, first I am adding extra dimension with 'None'
        the easiest way to add extra dimensions to an array is by using the keyword None,
        when indexing at the position to add the extra dimension.
        Note, in below with keyword None, I am only adding extra dummy empty dimension

        a = torch.rand(1, 2)
        ic(a) # => tensor([[0.1749, 0.6387]])
        ic(a[None, :]) # => tensor([[[0.1749, 0.6387]]])

        a = torch.rand([1,2,3,4])
        ic(a.shape) # => torch.Size([1, 2, 3, 4])
        ic(a[None, :].shape) # => torch.Size([1, 1, 2, 3, 4])
        """
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        print(
            "image_one_hot_labels.size ", image_one_hot_labels.size()
        )  # => torch.Size([128, 10, 1, 1])

        image_one_hot_labels = image_one_hot_labels.repeat(
            1, 1, mnist_shape[1], mnist_shape[2]
        )
        print(
            "image_one_hot_labels.size ", image_one_hot_labels.size()
        )  # => torch.Size([128, 10, 28, 28])

        #########################
        #  Train Discriminator
        #########################
        # Zero out the discriminator gradients
        disc_opt.zero_grad()
        # Get noise corresponding to the current batch_size
        fake_noise = create_noise_vector(cur_batch_size, z_dim, device=device)

        # Now we can get the images from the generator
        # Steps: 1) Combine the noise vectors and the one-hot labels for the generator
        #        2) Generate the conditioned fake images

        noise_and_labels = concat_vectors(fake_noise, one_hot_labels)
        fake = gen(noise_and_labels)

        # Make sure that enough images were generated
        assert len(fake) == len(real)

        # Now we can get the predictions from the discriminator
        # Steps: 1) Create the input for the discriminator
        #           a) Combine the fake images with image_one_hot_labels,
        #              remember to detach the generator (.detach()) so we do not backpropagate
        #              through it
        #           b) Combine the real images with image_one_hot_labels
        #        2) Get the discriminator's prediction on the fakes as disc_fake_pred
        #        3) Get the discriminator's prediction on the reals as disc_real_pred

        # Combine the fake images with image_one_hot_labels
        fake_image_and_labels = concat_vectors(fake, image_one_hot_labels)

        # Combine the real images with image_one_hot_labels
        real_image_and_labels = concat_vectors(real, image_one_hot_labels)

        # Get the discriminator's prediction on the reals and fakes
        disc_fake_pred = disc(fake_image_and_labels.detach())
        disc_real_pred = disc(real_image_and_labels)

        # Make sure that enough predictions were made
        assert len(disc_real_pred) == len(real)
        # Make sure that the inputs are different
        assert torch.any(fake_image_and_labels != real_image_and_labels)

        # Calculate Discriminator Loss on fakes and reals
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))

        # Get average Discriminator Loss
        disc_loss = (disc_fake_loss + disc_real_loss) / 2

        # Backpropagate and update weights
        disc_loss.backward(retain_graph=True)
        disc_opt.step()

        # Keep track of the average discriminator loss
        discriminator_losses += [disc_loss.item()]

        #########################
        #  Train Generators
        #########################

        gen_opt.zero_grad()

        fake_image_and_labels = concat_vectors(fake, image_one_hot_labels)
        # This will error if we didn't concatenate wer labels to wer image correctly
        disc_fake_pred = disc(fake_image_and_labels)

        """ Now calculate Generator Loss and note that, here, unlike the disc_loss, with
        disc_fake_pred, I am passing a vector containing its elements as 1 with torch.ones_like
        Because, Generator wants to fool the Discriminator by telling it that all these fake images are actually real, i.e. with value of 1
        """
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))

        # Backpropagate and update weights
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the generator losses
        generator_losses += [gen_loss.item()]

        ##################################
        #  Log Progress and Visualization
        #  for each display_step = 50
        ##################################
        if cur_step % display_step == 0 and cur_step > 0:
            # Calculate Generator Mean loss for the latest display_steps (i.e. latest 50 steps)
            # list[-x:]   # last x items in the array
            gen_mean = sum(generator_losses[-display_step:]) / display_step

            # Calculate Discriminator Mean loss for the latest display_steps (i.e. latest 50 steps)
            disc_mean = sum(discriminator_losses[-display_step:]) / display_step
            print(
                f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}"
            )

            # Plot both the real images and fake generated images
            plot_images_from_tensor(fake)
            plot_images_from_tensor(real)

            step_bins = 20
            x_axis = sorted(
                [i * step_bins for i in range(len(generator_losses) // step_bins)]
                * step_bins
            )
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples])
                .view(-1, step_bins)
                .mean(1),
                label="Generator Loss",
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(discriminator_losses[:num_examples])
                .view(-1, step_bins)
                .mean(1),
                label="Discriminator Loss",
            )
            plt.legend()
            plt.show()
        elif cur_step == 0:
            print("Let Long Training Continue")
        cur_step += 1