In [1]:

import pandas as pd
import pickle as pkl
import scanpy as sc
import anndata as ad
import numpy as np
from sklearn.model_selection import train_test_split
import h5py


In [3]:


# Load the full "images" dataset of each segmentation file into memory.
# The files are only read here, so they are opened read-only ("r") rather than
# read-write ("r+"), and the context managers close the handles as soon as the
# arrays have been copied out ("[()]" reads the entire dataset as a NumPy array).
with h5py.File("../ZafrensData/zel024/microscopy/ZS13.segmentations.h5", "r") as ZS13:
    Z13_image_dataset = ZS13["images"][()]

with h5py.File("../ZafrensData/zel024/microscopy/ZS25.segmentations.h5", "r") as ZS25:
    Z25_image_dataset = ZS25["images"][()]

# The zel031 plates are currently disabled; re-enable with the same pattern:
#with h5py.File("../ZafrensData/zel031/microscopy/ZS26.segmentations.h5", "r") as ZS26:
#    Z26_image_dataset = ZS26["images"][()]
#with h5py.File("../ZafrensData/zel031/microscopy/ZS27.segmentations.h5", "r") as ZS27:
#    Z27_image_dataset = ZS27["images"][()]


In [None]:


# Stack the per-file image arrays along axis 0 (ZS13 first, then ZS25).
# Z26_image_dataset / Z27_image_dataset are currently left out of the stack.
all_images = np.concatenate([Z13_image_dataset, Z25_image_dataset], axis=0)

# Metadata tables that accompany the image files
# (presumably one row per entry along image axis 0 -- TODO confirm).
Z13_image_barcode = pd.read_csv(
    "../ZafrensData/zel024/microscopy/ZS13_dim_0_metadata.csv"
)
Z25_image_barcode = pd.read_csv(
    "../ZafrensData/zel024/microscopy/ZS25_dim_0_metadata.csv"
)
# Disabled zel031 counterparts:
#Z26_image_barcode = pd.read_csv("../ZafrensData/zel031/microscopy/ZS26_dim_0_metadata.csv")
#Z27_image_barcode = pd.read_csv("../ZafrensData/zel031/microscopy/ZS27_dim_0_metadata.csv")



In [None]:

# Metadata columns that together identify a sample.
columns_to_concat = ['physical_well_id', 'control_rx_id', 'bb1_id', 'bb2_id', 'bb3_id', 'bb4_id', 'censored']

# Add a 'sample' key to each metadata table: the string form of the columns
# above, joined row-wise with underscores.
for barcode_frame in (Z13_image_barcode, Z25_image_barcode):
    barcode_frame['sample'] = (
        barcode_frame[columns_to_concat].astype(str).agg('_'.join, axis=1)
    )

# Disabled zel031 handling (Z26 / Z27): bb3_id and bb4_id were set to 1 and the
# key was built from a shorter column list:
#   Z26_image_barcode['bb4_id'] = 1; Z26_image_barcode['bb3_id'] = 1
#   Z27_image_barcode['bb4_id'] = 1; Z27_image_barcode['bb3_id'] = 1
#   columns_to_concat = ['physical_well_id', 'control_rx_id', 'bb1_id', 'bb2_id', 'censored']
#   Z26_image_barcode['sample'] = Z26_image_barcode[columns_to_concat].astype(str).agg('_'.join, axis=1)
#   Z27_image_barcode['sample'] = Z27_image_barcode[columns_to_concat].astype(str).agg('_'.join, axis=1)

# Row-wise concatenation keeping only the columns common to all tables.
# reset_index() gives a fresh 0..N-1 index (the old per-file index is kept as
# an 'index' column); the row order matches the order used for all_images.
barcode_frames = [Z13_image_barcode, Z25_image_barcode]  # + Z26_image_barcode, Z27_image_barcode
merged_df = pd.concat(barcode_frames, axis=0, join="inner").reset_index()

In [None]:

# Keep only rows whose 'censored' flag is False.
merged_df_not_censored = merged_df.loc[merged_df['censored'] == False]

# One row per distinct 'sample' key (the first occurrence is kept).
df_unique = merged_df_not_censored.drop_duplicates(subset='sample')

# The surviving index labels are used as positions into the image stack; this
# relies on merged_df having a 0..N-1 index in the same row order as all_images.
selected_indices = df_unique.index.to_numpy()
selected_images = all_images[selected_indices]


In [None]:
# Sanity check: display the shape of the concatenated image stack.
all_images.shape

In [None]:

# Cell embeddings: read the CSV and convert straight to a NumPy array.
cell_embeddings = np.array(pd.read_csv("../data_for_training/cell_embeddings.csv"))

# NOTE(review): pickle.load executes arbitrary code from the file -- only load
# artefacts produced by this project.
with open("../data_for_training/matched_metadata.pkl", "rb") as metadata_fh:
    matched_metadata = pkl.load(metadata_fh)

with open("../data_for_training/smiles_embedding_matrix.pkl", "rb") as smiles_fh:
    smiles_embeddings = pkl.load(smiles_fh)
    

In [None]:

# Pre-built image array saved by an earlier step (trusted project artefact;
# pickle.load must not be used on untrusted files).
with open("../data_for_training/images_34K.pkl", "rb") as images_fh:
    images_34K = pkl.load(images_fh)


In [None]:

# Drop every metadata row that has a missing value in any column.
df_cleaned = matched_metadata.dropna()

# Stage 1: hold out 30% of the rows; the remaining 70% form the training set.
train, temp = train_test_split(df_cleaned, test_size=0.3, random_state=42)

# Stage 2: halve the hold-out into test and validation (about 15% each overall).
test, validation = train_test_split(temp, test_size=0.5, random_state=42)

# Report the size of each split.
for split_label, split_frame in (
    ("Train set:", train),
    ("Test set:", test),
    ("Validation set:", validation),
):
    print(split_label, len(split_frame))


#### Training sets

In [None]:

# Row positions of the training samples inside each source array.
train_cell_rows = list(train['cell_index'].astype('Int64'))
train_smiles_rows = list(train['smiles_index'].astype('Int64'))
train_image_rows = list(train['image_index'].astype('Int64'))

# Gather the matching rows for each modality.
train_cell_embeddings = cell_embeddings[train_cell_rows, :]
train_smiles_embeddings = smiles_embeddings[train_smiles_rows, :]
train_images = images_34K[train_image_rows, :]


#### Validation sets

In [None]:

# Row positions of the validation samples inside each source array.
valid_cell_rows = list(validation['cell_index'].astype('Int64'))
valid_smiles_rows = list(validation['smiles_index'].astype('Int64'))
valid_image_rows = list(validation['image_index'].astype('Int64'))

# Gather the matching rows for each modality.
valid_cell_embeddings = cell_embeddings[valid_cell_rows, :]
valid_smiles_embeddings = smiles_embeddings[valid_smiles_rows, :]
valid_images = images_34K[valid_image_rows, :]



#### Test sets


In [None]:

# Row positions of the test samples inside each source array.
test_cell_rows = list(test['cell_index'].astype('Int64'))
test_smiles_rows = list(test['smiles_index'].astype('Int64'))
test_image_rows = list(test['image_index'].astype('Int64'))

# Gather the matching rows for each modality.
test_cell_embeddings = cell_embeddings[test_cell_rows, :]
test_smiles_embeddings = smiles_embeddings[test_smiles_rows, :]
test_images = images_34K[test_image_rows, :]


In [None]:

# Persist every split array as its own pickle, in the order
# train -> validation -> test and cell -> smiles -> images within each split.
split_artifacts = [
    ("../data_for_training/train_cell_embeddings.pkl", train_cell_embeddings),
    ("../data_for_training/train_smiles_embeddings.pkl", train_smiles_embeddings),
    ("../data_for_training/train_images.pkl", train_images),
    ("../data_for_training/valid_cell_embeddings.pkl", valid_cell_embeddings),
    ("../data_for_training/valid_smiles_embeddings.pkl", valid_smiles_embeddings),
    ("../data_for_training/valid_images.pkl", valid_images),
    ("../data_for_training/test_cell_embeddings.pkl", test_cell_embeddings),
    ("../data_for_training/test_smiles_embeddings.pkl", test_smiles_embeddings),
    ("../data_for_training/test_images.pkl", test_images),
]

for artifact_path, artifact in split_artifacts:
    with open(artifact_path, "wb") as file:
        pkl.dump(artifact, file)


## GAN training

In [None]:
import torch
from torch import nn


class Generator(nn.Module):
    """Transposed-convolution generator mapping a flat vector to an image.

    The input vector is treated as an ``input_dim``-channel 1x1 feature map and
    passed through four transposed-convolution stages. With the kernel/stride
    settings used here (no padding) the spatial size grows 1 -> 3 -> 6 -> 13 -> 28.

    Args:
        input_dim: Length of the input vector (channels of the 1x1 input map).
        image_channel: Number of channels in the generated image.
        hidden_dim: Base channel width; the stages use 4x, 2x and 1x this value.
    """

    def __init__(self, input_dim=10, image_channel=1, hidden_dim=64):
        super().__init__()
        self.input_dim = input_dim

        wide_channels = hidden_dim * 4
        mid_channels = hidden_dim * 2
        stages = [
            self._generator_block(input_dim, wide_channels),
            self._generator_block(wide_channels, mid_channels, kernel_size=4, stride=1),
            self._generator_block(mid_channels, hidden_dim),
            self._generator_block(hidden_dim, image_channel, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*stages)

    def _generator_block(
        self,
        input_channels,
        output_channels,
        kernel_size=3,
        stride=2,
        final_layer=False,
    ):
        """Build one upsampling stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        stage is ConvTranspose2d -> Tanh, so outputs lie in [-1, 1].
        """
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Reshape ``noise`` to (batch, input_dim, 1, 1) and run the stages."""
        batch_size = len(noise)
        feature_map = noise.view(batch_size, self.input_dim, 1, 1)
        return self.gen(feature_map)


def create_noise_vector(n_samples, input_dim, device="cpu"):
    """Draw an (n_samples, input_dim) tensor of standard-normal noise on ``device``."""
    noise_shape = (n_samples, input_dim)
    return torch.randn(noise_shape, device=device)


class Discriminator(nn.Module):
    """Convolutional critic that maps an image batch to one raw score per image.

    Three strided convolutions (kernel 4, stride 2, no padding) reduce a 28x28
    input to 13x13, 5x5 and finally 1x1. No sigmoid is applied, so the output
    is a logit.

    Args:
        image_channel: Number of channels of the input images.
        hidden_dim: Channel width of the first stage; the second uses 2x this.
    """

    def __init__(self, image_channel=1, hidden_dim=64):
        super().__init__()
        stages = [
            self._discriminator_block(image_channel, hidden_dim),
            self._discriminator_block(hidden_dim, hidden_dim * 2),
            self._discriminator_block(hidden_dim * 2, 1, final_layer=True),
        ]
        self.disc = nn.Sequential(*stages)

    def _discriminator_block(
        self,
        input_channels,
        output_channels,
        kernel_size=4,
        stride=2,
        final_layer=False,
    ):
        """Build one downsampling stage.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final
        stage is a bare Conv2d.
        """
        downsample = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(downsample)
        return nn.Sequential(
            downsample,
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        """Score ``image`` and flatten everything after the batch dimension."""
        scores = self.disc(image)
        return scores.view(len(scores), -1)

In [None]:
import torch
from torch import nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torch.nn.functional as F

# Seed PyTorch's global random number generator so that results which depend on
# it (e.g. weight initialisation, sampled noise) are repeatable between runs.
torch.manual_seed(0)  # Set for our testing purposes, please do not change!


def plot_images_from_tensor(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    """
    Draw the first ``num_images`` images of a batch as a single grid.

    The batch is shifted and scaled with ``(x + 1) / 2`` (mapping values in
    [-1, 1], e.g. Tanh outputs, onto [0, 1]), detached from the autograd graph,
    moved to the CPU, tiled with ``torchvision.utils.make_grid`` and drawn with
    ``plt.imshow``.

    Args:
        image_tensor (torch.Tensor): Image batch, expected in the shape
            (batch_size, channels, height, width).
        num_images (int, optional): Maximum number of images taken from the
            start of the batch. Default is 25.
        size (tuple, optional): Nominal (channels, height, width) of one image.
            Accepted for interface compatibility but not used by this function.
            Default is (1, 28, 28).
        nrow (int, optional): Number of images per grid row. Default is 5.
        show (bool, optional): Call ``plt.show()`` after drawing. Default is True.

    Returns:
        None. The grid is drawn on the current matplotlib axes.
    """
    # Map [-1, 1] -> [0, 1], then drop the graph and bring the data to the CPU.
    rescaled = (image_tensor + 1) / 2
    cpu_images = rescaled.detach().cpu()

    # Tile the leading images into one (C, H, W) grid tensor.
    grid = make_grid(cpu_images[:num_images], nrow=nrow)

    # matplotlib wants channels last, so reorder (C, H, W) -> (H, W, C).
    channels_last = grid.permute(1, 2, 0)
    plt.imshow(channels_last.squeeze())

    if show:
        plt.show()

    



""" The reason for doing "image_grid.permute(1, 2, 0)"

PyTorch modules processing image data expect tensors in the format C × H × W.

Whereas Pillow and Matplotlib expect image arrays in the format H × W × C

so to use them with matplotlib you need to reshape it
to put the channels as the last dimension:

Instead of permute(), I could also have converted to NumPy and used transpose(), like below
"np.transpose(npimg, (1, 2, 0))"

------------------

Tensor.detach() is used to detach a tensor from the current computational graph. It returns a new tensor that doesn't require a gradient.

When we don't need a tensor to be traced for the gradient computation, we detach the tensor from the current computational graph.

We also need to detach a tensor that requires grad before converting it to a NumPy array (e.g. for plotting); calling .cpu() then moves it from GPU to CPU.

"""


def weights_init(m):
    """
    Re-initialise a single layer in place; intended for ``model.apply(weights_init)``.

    Conv2d / ConvTranspose2d weights and BatchNorm2d weights are drawn from
    N(mean=0.0, std=0.02); BatchNorm2d biases are set to 0. Convolution biases
    and every other module type are left untouched.

    Args:
        m (torch.nn.Module): Module instance.

    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.constant_(m.bias, 0)


def ohe_vector_from_labels(labels, n_classes):
    """One-hot encode integer ``labels`` into a trailing axis of length ``n_classes``."""
    one_hot = F.one_hot(labels, n_classes)
    return one_hot


"""
x = torch.tensor([4, 3, 2, 1, 0])
F.one_hot(x, num_classes=6)

# Expected result
# tensor([[0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 1, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 1, 0, 0, 0, 0],
#         [1, 0, 0, 0, 0, 0]])
"""


""" Concatenation of Multiple Tensor with `torch.cat()` - RULE - To concatenate WITH torch.cat(), where the list of tensors are concatenated across the specified dimensions, requires 2 conditions to be satisfied

1. All tensors need to have the same number of dimensions and
2. All dimensions except the one that they are concatenated on, need to have the same size. """


def concat_vectors(x, y):
    """
    Join two tensors along dimension 1 after casting both to float32.

    Args:
        x (torch.Tensor): Left-hand tensor.
        y (torch.Tensor): Right-hand tensor; must match ``x`` in every
            dimension except dimension 1.

    Returns:
        torch.Tensor: float32 tensor holding ``x`` followed by ``y`` along dim 1.

    """
    as_floats = [x.float(), y.float()]
    return torch.cat(as_floats, dim=1)

def calculate_input_dim(z_dim, mnist_shape, n_classes):
    """
    Work out the conditional-GAN input sizes once class labels are appended.

    Args:
        z_dim (int): Length of the random noise vector (latent space).
        mnist_shape (tuple): Image shape as (channels, height, width),
            e.g. (1, 28, 28) for MNIST.
        n_classes (int): Number of classes in the dataset, e.g. 10.

    Returns:
        tuple: ``(generator_input_dim, discriminator_image_channel)`` where the
        first is ``z_dim + n_classes`` and the second is
        ``mnist_shape[0] + n_classes``.
    """
    # mnist_shape[0] is the channel count (1 for grayscale images).
    image_channels = mnist_shape[0]
    return z_dim + n_classes, image_channels + n_classes

In [None]:
import torch
import torch.nn as nn

#from utils import *

####################################################
def test_weights_init():
    """Check that ``weights_init`` re-initialises conv and batch-norm layers.

    Every Conv2d, ConvTranspose2d and BatchNorm2d weight must look like a
    sample from N(0, 0.02), and every BatchNorm2d bias must be exactly zero.

    Raises:
        AssertionError: If any layer's statistics fall outside the tolerances.
    """
    # 64-channel layers give every checked tensor at least 64 elements, so the
    # sample mean/std are stable enough for the tolerances below to hold on any
    # seed (a 3-channel BatchNorm would fail at random).
    model = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3),
        nn.BatchNorm2d(64),
        nn.ConvTranspose2d(64, 64, kernel_size=3),
        nn.BatchNorm2d(64),
    )

    # Initialize the model weights
    model.apply(weights_init)

    expected_mean = torch.tensor(0.0)
    expected_std = torch.tensor(0.02)

    for module in model.modules():
        # ConvTranspose2d is not a subclass of Conv2d, so it is listed explicitly.
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            assert torch.allclose(module.weight.mean(), expected_mean, rtol=0.0, atol=0.01)
            # atol is half the target, so all-zero weights cannot pass.
            assert torch.allclose(module.weight.std(), expected_std, rtol=0.0, atol=0.01)
        elif isinstance(module, nn.BatchNorm2d):
            assert torch.allclose(module.weight.mean(), expected_mean, rtol=0.0, atol=0.02)
            assert torch.allclose(module.weight.std(), expected_std, rtol=0.0, atol=0.01)
            assert torch.allclose(module.bias, torch.zeros_like(module.bias))

    print("Unit test passed!")

# Run the unit test
# test_weights_init()

####################################################
def test_concat_vectors():
    """Check that ``concat_vectors`` joins two tensors along dim 1 as float32.

    Raises:
        AssertionError: If the result has the wrong type, shape, dtype or values.
    """
    # Create sample input tensors (integer dtype on purpose: the function casts)
    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    y = torch.tensor([[7, 8, 9], [10, 11, 12]])

    # Perform concatenation
    combined = concat_vectors(x, y)

    # Check the output type, shape and dtype
    assert isinstance(combined, torch.Tensor)
    assert combined.shape == (2, 6)  # Expected shape after concatenation
    assert combined.dtype == torch.float32

    # The expected tensor must be float as well: torch.allclose raises a
    # RuntimeError when comparing a float tensor with an integer one.
    expected_combined = torch.tensor(
        [[1.0, 2.0, 3.0, 7.0, 8.0, 9.0], [4.0, 5.0, 6.0, 10.0, 11.0, 12.0]]
    )
    assert torch.equal(combined, expected_combined)

    print("Unit test passed!")

# Run the unit test
# test_concat_vectors()

####################################################
def test_calculate_input_dim():
    """Check ``calculate_input_dim`` on the MNIST configuration.

    Raises:
        AssertionError: If either returned value has the wrong type or value.
    """
    # Sample configuration: 100-d noise, grayscale 28x28 images, 10 classes.
    z_dim, mnist_shape, n_classes = 100, (1, 28, 28), 10

    generator_input_dim, discriminator_image_channel = calculate_input_dim(
        z_dim, mnist_shape, n_classes
    )

    # Each output must be an int equal to its base size plus the class count.
    checks = (
        (generator_input_dim, z_dim + n_classes),
        (discriminator_image_channel, mnist_shape[0] + n_classes),
    )
    for actual, expected in checks:
        assert isinstance(actual, int)
        assert actual == expected

    print("Unit test passed!")

# Run the unit test
# test_calculate_input_dim()


In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

#from conditional_gan import *
#from utils import *


# Image geometry as (channels, height, width) and number of class labels.
mnist_shape = (1, 28, 28)
n_classes = 10

# Loss on raw logits (the sigmoid is applied inside the loss).
criterion = nn.BCEWithLogitsLoss()

# Training hyper-parameters.
n_epochs = 200
z_dim = 64          # length of the noise vector
display_step = 500
batch_size = 128
lr = 0.0002

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Preprocessing pipeline: convert the input to a tensor, then normalise with
# mean 0.5 and std 0.5 (maps [0, 1] values onto [-1, 1]).
to_tensor_step = transforms.ToTensor()
normalize_step = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([to_tensor_step, normalize_step])
