# Imports

In [None]:
import os
import random
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import scipy
from google.colab import drive
from scipy.stats import entropy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.models import inception_v3
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


# Seed every random number generator the notebook draws from so runs repeat.
seed_value = 44
random.seed(seed_value)  # BalancedBatchSampler shuffles its index lists with `random`
torch.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
np.random.seed(seed_value)

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mount Google Drive; the training images are read from under /gdrive.
drive.mount('/gdrive')

# Pretrained InceptionV3 used only to compute the FID / IS metrics; it is put
# in eval mode and never handed to an optimiser.
inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
inception_model = inception_model.eval()

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


# Discriminator and Generator

In [None]:
class Discriminator(nn.Module):
    """Conditional WGAN critic.

    The class label is embedded into an extra ``img_size x img_size`` plane
    that is concatenated to the image as one more input channel. A stack of
    stride-2 convolutions then reduces the map to a raw, unbounded score (no
    sigmoid). For 128x128 inputs the output has shape (B, 1, 1, 1).

    Args:
        channels_img: number of channels in the input images.
        features_d: width of the first convolution; each later block doubles it.
        num_classes: number of distinct label values.
        img_size: side length of the square input images.
    """

    def __init__(self, channels_img, features_d, num_classes, img_size):
        super().__init__()
        self.img_size = img_size
        widths = [features_d * 2 ** k for k in range(5)]  # d, 2d, 4d, 8d, 16d
        layers = [
            nn.Conv2d(channels_img + 1, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self._block(c_in, c_out, 4, 2, 1))
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        self.disc = nn.Sequential(*layers)
        self.embed = nn.Embedding(num_classes, img_size * img_size)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv (no bias) -> affine InstanceNorm -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x, labels):
        """Score images ``x`` conditioned on integer ``labels``."""
        batch = labels.shape[0]
        label_plane = self.embed(labels).view(batch, 1, self.img_size, self.img_size)
        return self.disc(torch.cat((x, label_plane), dim=1))


class Generator(nn.Module):
    """Conditional DCGAN-style generator.

    The label embedding (``embed_size`` values) is appended to the noise as
    extra 1x1 channels. Six transposed convolutions then upsample the map:
    with (B, channels_noise, 1, 1) noise the sizes go
    1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128. A final ``tanh`` bounds the output
    to [-1, 1]. ``img_size`` is stored but does not influence the layer stack.

    Args:
        channels_noise: number of noise channels.
        channels_img: number of output image channels.
        features_g: base width; the first block uses ``32 * features_g``.
        num_classes: number of distinct label values.
        img_size: stored on the module only.
        embed_size: length of the label embedding.
    """

    def __init__(self, channels_noise, channels_img, features_g, num_classes,
                 img_size, embed_size):
        super().__init__()
        self.img_size = img_size
        widths = [features_g * m for m in (32, 16, 8, 4, 2)]
        layers = [self._block(channels_noise + embed_size, widths[0], 4, 1, 0)]
        layers += [self._block(a, b, 4, 2, 1) for a, b in zip(widths, widths[1:])]
        layers += [
            nn.ConvTranspose2d(widths[-1], channels_img, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)
        self.embed = nn.Embedding(num_classes, embed_size)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Transposed conv (no bias) -> BatchNorm -> ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x, labels):
        """Generate images from noise ``x`` conditioned on integer ``labels``."""
        label_code = self.embed(labels)[:, :, None, None]
        return self.net(torch.cat((x, label_code), dim=1))


def gradient_penalty(critic, labels, real, fake, device):
    """WGAN-GP gradient penalty for a conditional critic.

    Draws one interpolation coefficient per sample, scores the random
    interpolates between ``real`` and ``fake`` with ``critic``, and penalises
    the squared deviation of each sample's input-gradient L2 norm from 1.

    Args:
        critic: callable ``critic(images, labels)`` returning per-sample scores.
        labels: forwarded unchanged to ``critic``.
        real: real image batch of shape (B, C, H, W).
        fake: generated image batch with the same shape as ``real``.
        device: device the interpolation coefficients are moved to.

    Returns:
        Scalar tensor: batch mean of (||grad||_2 - 1)^2. It is built with
        ``create_graph=True`` so it can be back-propagated into the critic.
    """
    batch_size = real.shape[0]
    # One coefficient per sample; broadcasting spreads it over C, H, W.
    alpha = torch.rand((batch_size, 1, 1, 1)).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)
    # autograd.grad needs the interpolates in the graph. This is a no-op when
    # `fake` already carries grad history, and required when it does not
    # (e.g. a detached fake batch).
    interpolated_images.requires_grad_(True)

    mixed_scores = critic(interpolated_images, labels)

    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view): the gradient is not guaranteed to be contiguous.
    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [None]:
def initialize_weights(model):
    """Apply DCGAN-style initialisation to every submodule of ``model``.

    - Conv / transposed-conv weights are drawn from N(0, 0.02).
    - BatchNorm scales are drawn from N(1, 0.02) and shifts are set to 0, so
      normalised activations start near unit scale. A N(0, 0.02) scale would
      shrink every BatchNorm output roughly fifty-fold.
    - BatchNorm layers without affine parameters are skipped.
    - Conv biases and all other module types keep PyTorch's defaults.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Hyperparameters

In [None]:
# --- Optimisation ---
LEARNING_RATE = 0.0002   # shared by the generator and critic optimisers
BATCH_SIZE = 128         # indices per batch, split evenly between the two classes
CRITIC_ITERATIONS = 5    # critic updates per generator update
LAMBDA_GP = 10           # weight of the gradient-penalty term
NUM_EPOCHS = 100

# --- Data ---
IMAGE_SIZE = 128         # images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS_IMG = 1         # grayscale
NUM_CLASSES = 2

# --- Model widths ---
NOISE_DIM = 100          # latent noise channels fed to the generator
GEN_EMBEDDING = 100      # length of the generator's label embedding
FEATURES_DISC = 128
FEATURES_GEN = 128

# BalancedBatchSampler

In [None]:
class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    """Batch sampler that yields class-balanced lists of dataset indices.

    Every batch holds ``batch_size // num_classes`` indices from each class
    present in the dataset, grouped class by class in ascending label order.
    Each class's index list is reshuffled with the stdlib ``random`` module
    every time the sampler is iterated. Samples that cannot fill a complete
    batch are dropped, so all batches have the same length.

    Args:
        dataset: map-style dataset whose items are ``(sample, label)`` pairs.
            When it exposes ``targets`` and has no ``target_transform``,
            labels are read from ``targets`` instead of loading (and
            transforming) every sample.
        batch_size: indices per batch; ``None`` (default) falls back to the
            module-level ``BATCH_SIZE``.

    Raises:
        ValueError: if the dataset is empty, or ``batch_size`` is smaller than
            the number of classes.
    """

    def __init__(self, dataset, batch_size=None):
        self.dataset = dataset
        self.batch_size = BATCH_SIZE if batch_size is None else batch_size
        self.indices_by_class = defaultdict(list)

        for idx, label in enumerate(self._labels()):
            # Tensor / NumPy scalars hash by identity; use their Python value.
            if hasattr(label, "item"):
                label = label.item()
            self.indices_by_class[label].append(idx)

        if not self.indices_by_class:
            raise ValueError("BalancedBatchSampler needs a non-empty dataset")
        self.per_class = self.batch_size // len(self.indices_by_class)
        if self.per_class < 1:
            raise ValueError(
                f"batch_size={self.batch_size} is smaller than the number of "
                f"classes ({len(self.indices_by_class)})"
            )

        self.batches = self.generate_batches()
        self.num_batches = len(self.batches)

    def _labels(self):
        """Return an iterable over every sample's label, in index order."""
        targets = getattr(self.dataset, "targets", None)
        if targets is not None and getattr(self.dataset, "target_transform", None) is None:
            return targets
        # Fallback: iterate the dataset itself (this loads every sample).
        return (label for _, label in self.dataset)

    def shuffle_indices_by_class(self):
        """Shuffle each class's index list in place."""
        for key in self.indices_by_class:
            random.shuffle(self.indices_by_class[key])

    def generate_batches(self):
        """Reshuffle the classes and cut them into equally sized balanced batches."""
        self.shuffle_indices_by_class()
        classes = sorted(self.indices_by_class)
        per_class = self.per_class
        num_batches = min(len(self.indices_by_class[c]) for c in classes) // per_class
        return [
            [
                idx
                for c in classes
                for idx in self.indices_by_class[c][i * per_class:(i + 1) * per_class]
            ]
            for i in range(num_batches)
        ]

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        # Fresh shuffle on every pass over the sampler.
        self.batches = self.generate_batches()
        return iter(self.batches)

# FID (Fréchet Inception Distance) & IS (Inception Score)

In [None]:
def get_inception_features(images, model):
    """Run ``images`` through ``model`` at Inception's 299x299 input size.

    Args:
        images: image batch of shape (B, C, H, W). Single-channel batches are
            replicated to three channels; other channel counts pass through
            unchanged.
        model: network to evaluate.

    Returns:
        The model output, or its first element when the model returns a tuple.

    NOTE(review): this returns whatever ``model`` outputs. For torchvision's
    inception_v3 in eval mode that is presumably the 1000-way classifier
    output rather than the 2048-d pooled features used by reference FID
    implementations. Confirm before comparing FID values with published ones.
    """
    images = F.interpolate(images, size=(299, 299))
    if images.shape[1] == 1:
        # Grayscale -> 3 channels; repeating an RGB batch would yield 9 channels.
        images = images.repeat(1, 3, 1, 1)

    outputs = model(images)
    if isinstance(outputs, tuple):
        return outputs[0]
    return outputs


def compute_fid(real_features, fake_features):
    """Fréchet distance between Gaussians fitted to two feature sets.

    FID = ||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 Tr((S1 S2)^(1/2))

    Args:
        real_features: array-like of shape (N1, D), one feature row per sample.
        fake_features: array-like of shape (N2, D).

    Returns:
        The distance as a float.

    Raises:
        ValueError: if the matrix square root has a significant imaginary part.
    """
    # Import explicitly: a bare `import scipy` is not guaranteed to load linalg.
    from scipy import linalg

    real_features = np.asarray(real_features, dtype=np.float64)
    fake_features = np.asarray(fake_features, dtype=np.float64)
    mu1, mu2 = real_features.mean(axis=0), fake_features.mean(axis=0)
    # atleast_2d keeps D == 1 working: np.cov returns a 0-d array there.
    sigma1 = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake_features, rowvar=False))
    diff = mu1 - mu2

    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Near-singular product: regularise both covariances and retry.
        offset = np.eye(sigma1.shape[0]) * 1e-6
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real
    tr_covmean = np.trace(covmean)
    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


In [None]:
def inception_score(imgs, inception_model, splits=10):
    """Inception Score of an image batch: exp(E_x[KL(p(y|x) || p(y))]).

    Class posteriors p(y|x) are the softmax of ``inception_model``'s output on
    the images resized to 299x299; single-channel images are replicated to
    three channels. The batch is cut into ``splits`` equal chunks, and samples
    beyond ``splits * (n // splits)`` are not used. The score is computed per
    chunk against that chunk's marginal p(y).

    Args:
        imgs: image tensor of shape (N, C, H, W).
        inception_model: classifier returning logits, or a tuple whose first
            element is the logits.
        splits: requested number of chunks; reduced to N when N < splits so
            that no chunk is empty.

    Returns:
        (mean, std) of the per-chunk scores.

    Raises:
        ValueError: if ``imgs`` is empty or ``splits`` < 1.
    """
    # rel_entr(0, q) == 0, so zero probabilities don't turn the KL into nan.
    from scipy.special import rel_entr

    def get_preds(imgs):
        with torch.no_grad():
            imgs = torch.nn.functional.interpolate(imgs, size=(299, 299))
            if imgs.shape[1] == 1:
                imgs = imgs.repeat(1, 3, 1, 1)
            outputs = inception_model(imgs)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            probs = torch.nn.functional.softmax(outputs, dim=1)
            return probs.cpu().numpy().astype(np.float64)

    if splits < 1:
        raise ValueError("splits must be >= 1")
    if len(imgs) == 0:
        raise ValueError("inception_score needs at least one image")

    preds = get_preds(imgs)
    n = len(preds)
    splits = min(splits, n)
    chunk = n // splits

    scores = []
    for i in range(splits):
        part = preds[chunk * i:chunk * (i + 1), :]
        marginal = part.mean(axis=0, keepdims=True)
        kl = rel_entr(part, marginal).sum(axis=1).mean()
        scores.append(np.exp(kl))

    return np.mean(scores), np.std(scores)

# Training

In [None]:
# Grayscale, resize to IMAGE_SIZE x IMAGE_SIZE, convert to a tensor, then
# Normalize(mean=0.5, std=0.5), which maps [0, 1] values to [-1, 1], the
# range of the generator's tanh output.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
])

data_dir = '/gdrive/MyDrive/chest_xray/train'

# ImageFolder expects one sub-folder per class under data_dir; batches come
# from the class-balanced sampler instead of the default shuffling.
train_data = datasets.ImageFolder(data_dir, transform=transform)
sampler = BalancedBatchSampler(train_data)
dataloader = DataLoader(train_data, batch_sampler=sampler)

In [None]:
# Load the TensorBoard notebook extension and show the dashboard for ./logs,
# where the training cell's SummaryWriters (logs/real, logs/fake) write images.
%load_ext tensorboard
%tensorboard --logdir logs

In [None]:
# Build the conditional generator and critic, apply the custom weight
# initialisation, and give each network its own RMSprop optimiser.
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMAGE_SIZE, GEN_EMBEDDING).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_DISC, NUM_CLASSES, IMAGE_SIZE).to(device)
initialize_weights(gen)
initialize_weights(critic)

opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# Fixed latent batch reused for every TensorBoard preview, so previews are
# comparable across steps.
fixed_noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)

writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0  # TensorBoard global step; advanced once per preview (every 10 batches)

gen.train()
critic.train()

# Per-epoch averages of the per-batch values collected in the loop below.
D_losses = []    # critic loss
G_losses = []    # generator loss
Dx_values = []   # mean critic score on real images
DGz_values = []  # mean critic score on the generator-step fake batch
FID_values = []  # mean of per-batch FID
IS_values = []   # mean of per-batch Inception Score (each already a mean over splits)

for epoch in range(1, NUM_EPOCHS+1):
    epoch_D_losses = []
    epoch_G_losses = []
    epoch_Dx = []
    epoch_DGz = []
    epoch_FID = []
    epoch_IS = []

    for batch_idx, (real, labels) in enumerate(dataloader, start=1):
        real = real.to(device)
        labels = labels.to(device)

        # Critic: CRITIC_ITERATIONS WGAN-GP steps, each on a fresh fake batch,
        # minimising -(mean critic(real) - mean critic(fake)) + LAMBDA_GP * GP.
        # NOTE(review): the noise has BATCH_SIZE rows while `labels` has the
        # sampler's batch size; they agree only while BalancedBatchSampler
        # yields exactly BATCH_SIZE indices, so real.shape[0] would be safer.
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
            fake = gen(noise, labels)
            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)

            gp = gradient_penalty(critic, labels, real, fake, device=device)

            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP*gp
            critic.zero_grad()
            # retain_graph=True keeps the generator graph behind `fake` alive,
            # because the generator step below back-propagates through the
            # last `fake`. `fake` is not detached, so this backward also writes
            # gradients into the generator; gen.zero_grad() below discards them
            # before the generator step.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Generator: minimise -mean critic score on the last critic-loop fake batch.
        critic_fake = critic(fake, labels).reshape(-1)
        loss_gen = -torch.mean(critic_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        epoch_D_losses.append(loss_critic.item())
        epoch_Dx.append(critic_real.mean().item())

        epoch_G_losses.append(loss_gen.item())
        epoch_DGz.append(critic_fake.mean().item())

        # FID of this batch's `fake` vs `real`, and IS of `fake`.
        # NOTE(review): Inception runs at 299x299 on every batch, which is
        # costly, and a single-batch FID is presumably a noisy estimate;
        # consider computing these metrics once per epoch on a larger sample.
        with torch.no_grad():
            real_features = get_inception_features(real, inception_model).cpu().numpy()
            fake_features = get_inception_features(fake, inception_model).cpu().numpy()

            fid = compute_fid(real_features, fake_features)
            epoch_FID.append(fid)

            mean_is, std_is = inception_score(fake, inception_model)
            epoch_IS.append(mean_is)

        # Every 10 batches: print the losses and log 3x3 real / fake grids.
        if batch_idx % 10 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                # Overwrites `fake` with a fixed-noise preview (using the current
                # batch's labels). NOTE(review): gen is still in train mode, so
                # this pass presumably also updates its BatchNorm running stats.
                fake = gen(fixed_noise, labels)

                img_grid_real = torchvision.utils.make_grid(real[:9], normalize=True, nrow=3)
                img_grid_fake = torchvision.utils.make_grid(fake[:9], normalize=True, nrow=3)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

    # Epoch averages; these divisions raise ZeroDivisionError if the
    # dataloader yielded no batches.
    D_losses.append(sum(epoch_D_losses)/len(epoch_D_losses))
    G_losses.append(sum(epoch_G_losses)/len(epoch_G_losses))
    Dx_values.append(sum(epoch_Dx)/len(epoch_Dx))
    DGz_values.append(sum(epoch_DGz)/len(epoch_DGz))

    FID_values.append(sum(epoch_FID)/len(epoch_FID))
    IS_values.append(sum(epoch_IS)/len(epoch_IS))


    # Save a 3x3 grid of the most recent `fake`: the fixed-noise preview if the
    # epoch's last batch was a logging batch, otherwise the last critic-loop
    # fake batch.
    fig, axs = plt.subplots(3, 3, figsize=(6,6))
    for ax, img in zip(axs.ravel(), fake):
        ax.imshow(img[0].cpu().detach().numpy(), cmap='gray')
        ax.axis('off')

    if not os.path.exists('./generated'):
        os.makedirs('./generated')
    plt.savefig(f'generated/epoch_{epoch}.png')
    plt.close()

# Only the generator's weights are saved, once, after the final epoch.
torch.save(gen.state_dict(), 'generator_model.pth')

In [None]:
# Write each per-epoch metric list to results/<name>.txt, one value per line.
scores = [D_losses, G_losses, Dx_values, DGz_values, FID_values, IS_values]
# Names are listed explicitly, in the same order as `scores`. Recovering them
# by object identity in globals() is fragile: any alias of a list, such as a
# loop variable, matches too.
score_names = ["D_losses", "G_losses", "Dx_values", "DGz_values", "FID_values", "IS_values"]

os.makedirs('./results', exist_ok=True)


def variable_name(var, namespace=None):
    """Return the first name in ``namespace`` bound to the object ``var``.

    ``namespace`` defaults to this module's globals. The function is kept for
    backward compatibility; the export loop below does not rely on it.

    Raises:
        IndexError: if no name in ``namespace`` refers to ``var``.
    """
    if namespace is None:
        namespace = globals()
    for name, value in namespace.items():
        if value is var:
            return name
    raise IndexError("object is not bound to any name in the namespace")


for name, score in zip(score_names, scores):
    with open(f"results/{name}.txt", "w") as file:
        file.writelines(f"{val}\n" for val in score)