## Synthetic Image Generation Model For Class "MITOTIC"
To achieve our primary objective of creating a balanced multi-label dataset for human pluripotent stem cells (hPSCs), we adapt the base model by incorporating prior knowledge and class weights through transfer learning performed on each label in the dataset. As a result, 5 different GAN models are trained according to the 5 labels of our main dataset. Pre-trained parameters are loaded to initialize these 5 models which are subsequently retrained using randomize input images corresponding to their respective labels. Each model generates 500 synthetic images using the fine-tuned weights and biases. We hypothesize that fine-tuning will result in generating images that align with the distinct characteristics of each class. This notebook is dedicated to generate synthetic images for the class "Mitotoic". It starts with importing necessary libraries. Second step is defining the same generator and discriminator classes as the base model with the same architectures. Then next step is defining the data transformation and the directory that contains the original images. Then we create a new directory to save the generated images. After that we initializing the Generator and we load the pre-trained weights. Then we creatw a new model with the current architecture. Then we transfer weights from the pretrained model to the new model. To finsih we define a method to generate images and save those images based on their classes with the specified output directory.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageEnhance
import os
from torchvision.utils import save_image

In [2]:
class Generator(nn.Module):
    def __init__(self, latent_dim=100, image_channels=3, image_size=64):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.image_channels = image_channels
        self.image_size = image_size

        self.main = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, image_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

In [3]:
class Discriminator(nn.Module):
    def __init__(self, image_channels=3, image_size=64):
        super(Discriminator, self).__init__()
        self.image_channels = image_channels
        self.image_size = image_size

        self.main = nn.Sequential(
            nn.Conv2d(image_channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, 4, 1, 0, bias=False)
        )

    def forward(self, x):
        output = self.main(x)
        return output.view(-1, 1).squeeze(1)

In [4]:
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

input_image_directory = "/Users/isikgurhan/Desktop/data-jpg/iPSC_Morphologies/mitotic"

output_image_directory = "/Users/isikgurhan/Desktop/data-jpg/iPSC_Morphologies/generated_images_mitotic"
os.makedirs(output_image_directory, exist_ok=True)

generator = Generator(latent_dim=100, image_channels=3, image_size=64)

pretrained_generator_path = "/Users/isikgurhan/best_gan_weights.pth"
pretrained_dict = torch.load(pretrained_generator_path)

generator = Generator(latent_dim=100, image_channels=3, image_size=64)

model_dict = generator.state_dict()
for name, param in pretrained_dict.items():
    if name in model_dict:
        if "main" in name:
            if "num_batches_tracked" not in name:
                print(f"Transferring weights for layer: {name}")
                print(f"Original weight shape: {param.shape}")

                if "fc" in name:
                    model_fc_weight_shape = model_dict[name].shape
                    if model_fc_weight_shape != param.shape:
                        print(f"FC layer shape mismatch. Pretrained shape: {param.shape}, Model shape: {model_fc_weight_shape}")
                    else:
                        model_dict[name].copy_(param)
                elif "main.0" in name:
                    target_shape = model_dict[name].shape
                    param_reshaped = param.view(target_shape)
                    model_dict[name].copy_(param_reshaped)
                elif "main.3" in name:
                    target_shape = model_dict[name].shape
                    param_reshaped = param.view(target_shape)
                    model_dict[name].copy_(param_reshaped)
                elif "main.6" in name:
                    target_shape = model_dict[name].shape
                    param_reshaped = param.view(target_shape)
                    model_dict[name].copy_(param_reshaped)
                else:
                    model_dict[name].copy_(param)
    else:
        print(f"Ignoring unexpected key: {name}")

Ignoring unexpected key: generator
Ignoring unexpected key: discriminator
Ignoring unexpected key: optimizer_G
Ignoring unexpected key: optimizer_D


In [5]:
def generate_images(generator, num_images, latent_dim, output_dir):
    generator.eval()
    with torch.no_grad():
        for i in range(num_images):
            z = torch.randn(1, latent_dim, 1, 1)
            fake_image = generator(z)
            save_image(fake_image, os.path.join(output_dir, f"image_mitotoic{i+1}.png"))

num_generated_images = 500
generate_images(generator, num_generated_images, latent_dim=100, output_dir=output_image_directory)
