In [1]:

import pandas as pd
import pickle as pkl
import scanpy as sc
import anndata as ad
import numpy as np
from sklearn.model_selection import train_test_split
import h5py


## Import dataset

In [2]:

with open("../data_for_training/OneK_cellEmbed_ImgSize125/train_cell_embeddings.pkl", "rb") as file:
    train_cell_embeddings = pkl.load(file)

with open("../data_for_training/OneK_cellEmbed_ImgSize125/train_smiles_embeddings.pkl", "rb") as file:
    train_smiles_embeddings = pkl.load(file)

with open("../data_for_training/OneK_cellEmbed_ImgSize125/train_images.pkl", "rb") as file:
    train_images = pkl.load(file)

with open("../data_for_training/OneK_cellEmbed_ImgSize125/valid_cell_embeddings.pkl", "rb") as file:
    valid_cell_embeddings = pkl.load(file)

with open("../data_for_training/OneK_cellEmbed_ImgSize125/valid_smiles_embeddings.pkl", "rb") as file:
    valid_smiles_embeddings = pkl.load(file)

with open("../data_for_training/OneK_cellEmbed_ImgSize125/valid_images.pkl", "rb") as file:
    valid_images = pkl.load(file)

with open("../data_for_training/OneK_cellEmbed_ImgSize125/test_cell_embeddings.pkl", "rb") as file:
    test_cell_embeddings = pkl.load(file)

with open("../data_for_training/OneK_cellEmbed_ImgSize125/test_smiles_embeddings.pkl", "rb") as file:
    test_smiles_embeddings = pkl.load(file)

with open("../data_for_training/OneK_cellEmbed_ImgSize125/test_images.pkl", "rb") as file:
    test_images = pkl.load(file)


## GAN training

        self.gen = nn.Sequential(
            self._generator_block(input_dim, hidden_dim * 8, kernel_size=5, stride=5),   # (1x1) → (5x5)
            self._generator_block(hidden_dim * 8, hidden_dim * 4, kernel_size=4, stride=2, padding=1),   # (5x5) → (10x10)
            self._generator_block(hidden_dim * 4, hidden_dim * 2, kernel_size=3, stride=2, padding=0),   # (10x10) → (21x21)
            self._generator_block(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1),   # (21x21) → (42x42)
            self._generator_block(hidden_dim, hidden_dim // 2, kernel_size=3, stride=2, padding=1),   # (42x42) → (85x85)
            self._generator_block(hidden_dim // 2, hidden_dim // 4, kernel_size=4, stride=2, padding=1),   # (85x85) → (170x170)
            self._generator_block(hidden_dim // 4, hidden_dim // 8, kernel_size=3, stride=2, padding=1),   # (170x170) → (255x255)
            self._generator_block(hidden_dim // 8, hidden_dim // 16, kernel_size=4, stride=1, padding=0),  # (255x255) → (272x272)
            self._generator_block(hidden_dim // 16, hidden_dim // 32, kernel_size=4, stride=1, padding=0),  # (272x272) → (300x300)
            self._generator_block(hidden_dim // 32, image_channel, kernel_size=4, stride=1, padding=0, final_layer=True),  # (300x300) → (340x340)
        )

In [79]:
import torch
from torch import nn

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, input_dim=1769, image_channel=1, hidden_dim=256):
        super(Generator, self).__init__()
        self.input_dim = input_dim

        self.gen = nn.Sequential(
            self._generator_block(input_dim, hidden_dim * 8, kernel_size=5, stride=5),   # (1x1) → (5x5)
            self._generator_block(hidden_dim * 8, hidden_dim * 4, kernel_size=5, stride=3, padding=1),   # (5x5) → (10x10)
            self._generator_block(hidden_dim * 4, hidden_dim * 2, kernel_size=5, stride=2, padding=1),   # (10x10) → (21x21)
            self._generator_block(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1),   # (21x21) → (42x42)
            self._generator_block(hidden_dim, hidden_dim // 2, kernel_size=3, stride=2, padding=0, final_layer=True)   # (42x42) → (85x85)
            #self._generator_block(hidden_dim // 2, hidden_dim // 4, kernel_size=4, stride=2, padding=1),   # (85x85) → (170x170)
            #self._generator_block(hidden_dim // 4, hidden_dim // 8, kernel_size=3, stride=2, padding=1),   # (170x170) → (255x255)
            #self._generator_block(hidden_dim // 8, hidden_dim // 16, kernel_size=4, stride=1, padding=0),  # (255x255) → (272x272)
            #self._generator_block(hidden_dim // 16, hidden_dim // 32, kernel_size=4, stride=1, padding=0),  # (272x272) → (300x300)
            #self._generator_block(hidden_dim // 32, image_channel, kernel_size=4, stride=1, padding=0),  # (300x300) → (340x340)
        )

    def _generator_block(self, input_channels, output_channels, kernel_size=3, stride=2, padding=0, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, padding),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, padding),
                nn.Tanh(),
            )

    def forward(self, noise):
        x = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(x)


def create_noise_vector(n_samples, input_dim, device="cpu"):
    return torch.randn(n_samples, input_dim, device=device)

torch.Size([32, 896, 125, 125])
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, image_channel=1, hidden_dim=512):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self._discriminator_block(image_channel, hidden_dim),
            self._discriminator_block(hidden_dim, hidden_dim * 2),
            self._discriminator_block(hidden_dim * 2, 1, final_layer=True),
        )

    def _discriminator_block(
        self,
        input_channels,
        output_channels,
        kernel_size=4,
        stride=2,
        final_layer=False,
    ):

        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)


gen = Generator(input_dim=generator_input_dim).to(device)
noise_and_labels = concat_vectors(fake_noise, one_hot_labels)

fake = gen(noise_and_labels)

disc = Discriminator(image_channel=discriminator_image_channel).to(device)
print(disc)
print(fake.shape)
print(disc(fake_image_and_labels).shape)

Discriminator(
  (disc): Sequential(
    (0): Sequential(
      (0): Conv2d(893, 512, kernel_size=(4, 4), stride=(2, 2))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2))
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(1024, 1, kernel_size=(4, 4), stride=(2, 2))
    )
  )
)
torch.Size([32, 128, 125, 125])


TypeError: conv2d() received an invalid combination of arguments - got (bool, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, tuple of ints padding = 0, tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!bool!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, str padding = "valid", tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!bool!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)


In [80]:
fake_image_and_labels.shape

AttributeError: 'bool' object has no attribute 'shape'

In [82]:
gen

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(1769, 2048, kernel_size=(5, 5), stride=(5, 5))
      (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): ConvTranspose2d(2048, 1024, kernel_size=(5, 5), stride=(3, 3), padding=(1, 1))
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (4): Sequential(
      (0

In [36]:
fake_image_and_labels.shape

torch.Size([32, 896, 125, 125])

In [37]:
896 - 769

127

In [83]:
import torch
from torch import nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torch.nn.functional as F

torch.manual_seed(0)  # Set for our testing purposes, please do not change!


def plot_images_from_tensor(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    """
    Plots a grid of images from a given tensor.

    The function first scales the image tensor to the range [0, 1]. It then detaches the tensor from the computation
    graph and moves it to the CPU if it's not already there. After that, it creates a grid of images and plots the grid.

    Args:
        image_tensor (torch.Tensor): A 4D tensor containing the images.
            The tensor is expected to be in the shape (batch_size, channels, height, width).
        num_images (int, optional): The number of images to include in the grid. Default is 25.
        size (tuple, optional): The size of a single image in the form of (channels, height, width). Default is (1, 28, 28).
        nrow (int, optional): Number of images displayed in each row of the grid. The final grid size is (num_images // nrow, nrow). Default is 5.
        show (bool, optional): Determines if the plot should be shown. Default is True.

    Returns:
        None. The function outputs a plot of a grid of images.
    """

    # Normalize the image tensor to [0, 1]
    image_tensor = (image_tensor + 1) / 2

    # Detach the tensor from its computation graph and move it to the CPU
    image_unflat = image_tensor.detach().cpu()

    # Create a grid of images using the make_grid function from torchvision.utils
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)

    # Plot the grid of images
    # The permute() function is used to rearrange the dimensions of the grid for plotting
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())

    # Show the plot if the 'show' parameter is True
    if show:
        plt.show()

    



""" The reason for doing "image_grid.permute(1, 2, 0)"

PyTorch modules processing image data expect tensors in the format C × H × W.

Whereas PILLow and Matplotlib expect image arrays in the format H × W × C

so to use them with matplotlib you need to reshape it
to put the channels as the last dimension:

I could have used permute() method as well like below
"np.transpose(npimg, (1, 2, 0))"

------------------

Tensor.detach() is used to detach a tensor from the current computational graph. It returns a new tensor that doesn't require a gradient.

When we don't need a tensor to be traced for the gradient computation, we detach the tensor from the current computational graph.

We also need to detach a tensor when we need to move the tensor from GPU to CPU.

"""


def weights_init(m):
    """
    Initialize the weights of convolutional and batch normalization layers.

    Args:
        m (torch.nn.Module): Module instance.

    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)


def ohe_vector_from_labels(labels, n_classes):
    return F.one_hot(labels, num_classes=n_classes)


"""
x = torch.tensor([4, 3, 2, 1, 0])
F.one_hot(x, num_classes=6)

# Expected result
# tensor([[0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 1, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 1, 0, 0, 0, 0],
#         [1, 0, 0, 0, 0, 0]])
"""


""" Concatenation of Multiple Tensor with `torch.cat()` - RULE - To concatenate WITH torch.cat(), where the list of tensors are concatenated across the specified dimensions, requires 2 conditions to be satisfied

1. All tensors need to have the same number of dimensions and
2. All dimensions except the one that they are concatenated on, need to have the same size. """


def concat_vectors(x, y):
    """
    Concatenate two tensors along the second dimension.

    Args:
        x (torch.Tensor): First input tensor.
        y (torch.Tensor): Second input tensor.

    Returns:
        torch.Tensor: Concatenated tensor.

    """
    combined = torch.cat((x.float(), y.float()), 1)
    return combined

def calculate_input_dim(z_dim, mnist_shape, n_classes):
    """
    Calculate the input dimensions for the generator and discriminator networks.

    Args:
        z_dim (int): Dimension of the random noise vector (latent space).
        mnist_shape (tuple): Shape of the MNIST images, e.g., (1, 28, 28).
        n_classes (int): Number of classes in the dataset.

    Returns:
        tuple: Tuple containing the generator input dimension and discriminator image channel.

    mnist_shape = (1, 28, 28)
    n_classes = 10"""
    generator_input_dim = z_dim + n_classes

    # mnist_shape[0] is 1 as its grayscale images
    discriminator_image_channel = 128 + n_classes

    return generator_input_dim, discriminator_image_channel

In [84]:
import torch
import torch.nn as nn

#from utils import *

####################################################
def test_weights_init():
    # Create a sample model with Conv2d and BatchNorm2d layers
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3),
        nn.BatchNorm2d(16),
        nn.ConvTranspose2d(16, 3, kernel_size=3),
        nn.BatchNorm2d(3)
    )

    # Initialize the model weights
    model.apply(weights_init)

    # Check the weights of Conv2d layers
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            assert torch.allclose(module.weight.mean(), torch.tensor(0.0), atol=0.02)
            assert torch.allclose(module.weight.std(), torch.tensor(0.02), atol=0.02)

    # Check the weights of BatchNorm2d layers
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            assert torch.allclose(module.weight.mean(), torch.tensor(0.0), atol=0.02)
            assert torch.allclose(module.weight.std(), torch.tensor(0.02), atol=0.02)
            assert torch.allclose(module.bias, torch.tensor(0.0))

    print("Unit test passed!")

# Run the unit test
# test_weights_init()

####################################################
def test_concat_vectors():
    # Create sample input tensors
    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    y = torch.tensor([[7, 8, 9], [10, 11, 12]])

    # Perform concatenation
    combined = concat_vectors(x, y)

    # Check the output type and shape
    assert isinstance(combined, torch.Tensor)
    assert combined.shape == (2, 6)  # Expected shape after concatenation

    # Check the values in the concatenated tensor
    expected_combined = torch.tensor([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]])
    assert torch.allclose(combined, expected_combined)

    print("Unit test passed!")

# Run the unit test
# test_concat_vectors()

####################################################
def test_calculate_input_dim():
    # Set up sample inputs
    z_dim = 100
    mnist_shape = (1, 28, 28)
    n_classes = 10

    # Calculate input dimensions
    generator_input_dim, discriminator_image_channel = calculate_input_dim(z_dim, mnist_shape, n_classes)

    # Check the output types and values
    assert isinstance(generator_input_dim, int)
    assert generator_input_dim == z_dim + n_classes

    assert isinstance(discriminator_image_channel, int)
    assert discriminator_image_channel == mnist_shape[0] + n_classes

    print("Unit test passed!")

# Run the unit test
# test_calculate_input_dim()


In [85]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

#from conditional_gan import *
#from utils import *

mnist_shape = (1, 125, 125)
n_classes = 768


criterion = nn.BCEWithLogitsLoss()
n_epochs = 1
z_dim = 1001
display_step = 500
batch_size = 32
lr = 0.0002
device = "cuda"

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)


In [86]:
mnist_shape

(1, 125, 125)

In [87]:

#dataloader = DataLoader(
#    MNIST(
#        "/home/dennis00/scRNA_GAN/MNIST", download=True, transform=transform
#    ),
#    batch_size=batch_size,
#    shuffle=True,
#)

generator_input_dim, discriminator_image_channel = calculate_input_dim(
    z_dim, mnist_shape, n_classes
)

gen = Generator(input_dim=generator_input_dim).to(device)

gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)

disc = Discriminator(image_channel=discriminator_image_channel).to(device)

disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)


def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)


gen  = gen.apply(weights_init)

disc = disc.apply(weights_init)


cur_step = 0
generator_losses = []
discriminator_losses = []

noise_and_labels = False
fake = False

fake_image_and_labels = False
real_image_and_labels = False
disc_fake_pred = False
disc_real_pred = False


In [89]:
mnist_shape[1]

125

In [90]:
n_classes

768

In [91]:
discriminator_image_channel

896

In [23]:

class MultiModalDataset(Dataset):
    def __init__(self, cell_embeddings, smiles_embeddings, images):
        self.cell_embeddings = torch.tensor(cell_embeddings, dtype=torch.float32)
        self.smiles_embeddings = torch.tensor(smiles_embeddings, dtype=torch.float32)
        self.images = torch.tensor(images, dtype=torch.float32)

    def __len__(self):
        return len(self.cell_embeddings)

    def __getitem__(self, idx):
        return (
            self.cell_embeddings[idx], 
            self.smiles_embeddings[idx], 
            self.images[idx]
        )

# Example usage for train set
train_dataset = MultiModalDataset(train_cell_embeddings, train_smiles_embeddings, train_images)

# Same for validation and test sets
val_dataset = MultiModalDataset(valid_cell_embeddings, valid_smiles_embeddings, valid_images)
test_dataset = MultiModalDataset(test_cell_embeddings, test_smiles_embeddings, test_images)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [92]:
for epoch in range(n_epochs):
    for cell, smile, real in tqdm(train_loader):
        # Move data to GPU if needed
        cell, smile, real = cell.to(device), smile.to(device), real.to(device)

        real = real.unsqueeze(1)  # Adds a channel dimension at index 1

        cur_batch_size = len(cell)
        
        disc_opt.zero_grad()

        one_hot_labels = smile
        print("one_hot_labels ", one_hot_labels.size())
        fake_noise = create_noise_vector(cur_batch_size, z_dim, device=device)
        
        noise_and_labels = concat_vectors(fake_noise, one_hot_labels)

        fake = gen(noise_and_labels)

        image_one_hot_labels = one_hot_labels[:, :, None, None]
        print(
            "image_one_hot_labels.size ", image_one_hot_labels.size()
        )  # => torch.Size([128, 10, 1, 1])

        image_one_hot_labels = image_one_hot_labels.repeat(
            1, 1, real.shape[2], real.shape[2]
        )

        # Make sure that enough images were generated
        assert len(fake) == len(real)
        
        fake_image_and_labels = concat_vectors(fake, image_one_hot_labels)

        # Combine the real images with image_one_hot_labels
        real_image_and_labels = concat_vectors(real, image_one_hot_labels)

        # Get the discriminator's prediction on the reals and fakes
        disc_fake_pred = disc(fake_image_and_labels.detach())
        disc_real_pred = disc(real_image_and_labels)

        # Make sure that enough predictions were made
        assert len(disc_real_pred) == len(real)
        # Make sure that the inputs are different
        assert torch.any(fake_image_and_labels != real_image_and_labels)
        

  0%|          | 0/649 [00:00<?, ?it/s]

one_hot_labels  torch.Size([32, 768])
image_one_hot_labels.size  torch.Size([32, 768, 1, 1])


RuntimeError: Given groups=1, weight of size [512, 896, 4, 4], expected input[32, 769, 125, 125] to have 896 channels, but got 769 channels instead

In [93]:
fake_image_and_labels.shape

torch.Size([32, 896, 125, 125])

In [95]:
real_image_and_labels.shape

torch.Size([32, 769, 125, 125])

In [96]:
real.shape

torch.Size([32, 1, 125, 125])

In [97]:
image_one_hot_labels.shape

torch.Size([32, 768, 125, 125])

In [45]:
896 - 768

128

In [42]:
fake.shape

torch.Size([32, 128, 125, 125])

In [43]:
image_one_hot_labels.shape

torch.Size([32, 768, 125, 125])

In [44]:
gen(noise_and_labels).shape

torch.Size([32, 128, 125, 125])

In [None]:
real.shape

In [None]:
image_one_hot_labels.shape