In [None]:
# Load TensorBoard extension
%load_ext tensorboard

# Start TensorBoard
from datetime import datetime
import os

# Set up a log directory for TensorBoard
log_dir = os.path.join("logs", datetime.now().strftime("%Y%m%d-%H%M%S"))
os.makedirs(log_dir, exist_ok=True)


In [None]:
# Mount Google Drive at /content/drive. The dataset paths used by the
# preprocessing cell all live under "/content/drive/My Drive", so this cell
# has to run before any data is loaded.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import zscore
import os
import cv2
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np
def preProcessing(seed=42, resize_images=False):
    """Build the training and test DataLoaders for the OCTDL images on Drive.

    Images are read with ``ImageFolder`` from the already-resized copy of the
    dataset (``OUTPUT_PATH``), converted to single-channel tensors and scaled
    to [-1, 1] so that real samples share the value range of the generator's
    Tanh output.

    Args:
        seed: Seed for the train/test split, so the same images land in the
            test set on every run. Pass ``None`` for an unseeded split.
        resize_images: When True, first (re)generate the resized images under
            ``OUTPUT_PATH`` from the originals in ``INPUT_PATH``. This is a
            slow one-off step; the default False assumes it was already done.

    Returns:
        tuple: ``(train_dl, test_dl)`` DataLoaders over a random 70/30 split
        of the dataset (see ``TEST_TRAIN_SPLIT``).
    """
    # Local import: this cell only imports names from torch.utils.data.
    import torch

    LABELS_PATH = '/content/drive/My Drive/OCTDL_labels.csv'
    INPUT_PATH = '/content/drive/My Drive/OCTDL'
    OUTPUT_PATH = '/content/drive/My Drive/augmented_data'
    CLASSNAMES = ['AMD', 'DME', 'ERM', 'NO', 'RAO', 'RVO', 'VID']

    # hyperparameters - to be adjusted
    TEST_TRAIN_SPLIT = 0.3
    BATCH_SIZE = 32

    def compute_target_size():
        """Return (rows kept after outlier filtering, (width, height) target)."""
        df = pd.read_csv(LABELS_PATH)

        # 'image_hight' is the column name as spelled in the labels file.
        dimensions = pd.DataFrame(df, columns=['file_name', 'image_width', 'image_hight'])
        dimensions['z_width'] = zscore(dimensions['image_width'])
        dimensions['z_height'] = zscore(dimensions['image_hight'])

        # Drop images whose width or height is an outlier (|z| > 2.5).
        z_threshold = 2.5
        filtered_dimensions = dimensions[
            (abs(dimensions['z_width']) <= z_threshold) &
            (abs(dimensions['z_height']) <= z_threshold)
        ].copy()

        filtered_dimensions['aspect_ratio'] = (
            filtered_dimensions['image_width'] / filtered_dimensions['image_hight']
        )
        avg_aspect_ratio = filtered_dimensions['aspect_ratio'].mean()

        # Target: the smallest remaining height, with the width chosen to
        # match the mean aspect ratio. Plain Python ints keep cv2.resize happy.
        height = int(filtered_dimensions['image_hight'].min())
        width = int(round(height * avg_aspect_ratio))
        return filtered_dimensions, (width, height)

    def resize_and_save(classname, filtered_dimensions, target_size):
        """Resize every kept image of one class and write it to OUTPUT_PATH."""
        input_dir = os.path.join(INPUT_PATH, classname)
        output_dir = os.path.join(OUTPUT_PATH, classname)
        os.makedirs(output_dir, exist_ok=True)

        for file_stem in filtered_dimensions['file_name']:
            file_name = file_stem + '.jpg'
            # NOTE(review): substring match on the lower-cased class name;
            # presumably file names start with the class prefix - confirm
            # that short names such as 'no' cannot match other classes.
            if classname.lower() not in file_name:
                continue
            image_path = os.path.join(input_dir, file_name)

            # Read the image
            img = cv2.imread(image_path)
            if img is None:
                print(f"Error loading image {file_name}")
                continue

            # cv2.resize expects the size as (width, height).
            resized_img = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)
            cv2.imwrite(os.path.join(output_dir, file_name), resized_img)

    if resize_images:
        filtered_dimensions, target_size = compute_target_size()
        for classname in CLASSNAMES:
            resize_and_save(classname, filtered_dimensions, target_size)

    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # ToTensor yields [0, 1]; shift/scale to [-1, 1] to match the
        # generator's Tanh output range.
        transforms.Normalize((0.5,), (0.5,)),
    ])

    dataset = datasets.ImageFolder(OUTPUT_PATH, transform=transform)

    n_test = int(np.floor(TEST_TRAIN_SPLIT * len(dataset)))
    n_train = len(dataset) - n_test

    # A dedicated, seeded generator makes the split reproducible without
    # touching the global RNG state.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_ds, test_ds = random_split(dataset, [n_train, n_test], **split_kwargs)

    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # some useful info about the dataset
    print(f"Classes: {dataset.classes}")
    print(f"Number of training samples: {len(train_ds)}")
    print(f"Number of testing samples: {len(test_ds)}")
    return train_dl, test_dl

In [None]:

def gradient_penalty(critic, real, fake, device="cpu"):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    For every sample a point is drawn uniformly on the line between the real
    and the fake image, the critic is evaluated there, and the squared
    deviation of the input-gradient's L2 norm from 1 is averaged over the batch.

    Args:
        critic: Callable mapping an image batch to one score per sample.
        real: Batch of real images; the first dimension is the batch.
        fake: Batch of generated images with the same shape as ``real``.
        device: Kept for backward compatibility. The interpolation weights are
            created on ``real.device``, so a missing or stale value can no
            longer cause a device mismatch.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``, differentiable with
        respect to the critic's parameters.
    """
    # Local import so the function does not depend on notebook cell order.
    import torch

    batch_size = real.shape[0]
    # One mixing weight per sample, broadcast over all remaining dimensions.
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device)

    # Detach and explicitly enable grad: the penalty only needs d(score)/d(input),
    # and this works whether or not `fake` is still attached to the generator.
    interpolated_images = (real * alpha + fake * (1 - alpha)).detach()
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images.
    # create_graph=True keeps the result differentiable for the critic update.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view): the returned gradient is not guaranteed contiguous.
    gradient = gradient.reshape(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [None]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Convolutional critic that reduces a 199x546 (HxW) image to one score.

    Args:
        channels_img: Number of channels of the input images.
        features_d: Base channel width; deeper layers use multiples of it.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        base = features_d

        # Stem: plain conv + LeakyReLU, no normalisation.
        # Spatial size (HxW): 199x546 -> 100x273
        stem = [
            nn.Conv2d(channels_img, base, kernel_size=[3, 4], stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # (in multiplier, out multiplier, kernel_size, stride, padding)
        block_specs = [
            (1, 2, 4, 2, 1),             # -> 50x136
            (2, 4, 4, 2, 1),             # -> 25x68
            (4, 8, 4, 2, 1),             # -> 12x34
            (8, 16, 4, 2, 1),            # -> 6x17
            (16, 32, [3, 4], [1, 2], 1), # -> 6x8 (height kept, width halved)
            (32, 64, 4, 2, 1),           # -> 3x4
        ]
        body = [
            self._block(base * m_in, base * m_out, kernel, stride, pad)
            for m_in, m_out, kernel, stride, pad in block_specs
        ]

        # Head: a kernel covering the whole 3x4 map collapses it to 1x1.
        head = nn.Conv2d(base * 64, 1, kernel_size=[3, 4], stride=[3, 4], padding=0)

        self.disc = nn.Sequential(*stem, *body, head)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a conv (no bias) -> InstanceNorm (affine) -> LeakyReLU(0.2) stage."""
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Run the critic on a batch of images and return its raw scores."""
        return self.disc(x)


class Generator(nn.Module):
    """Transposed-conv generator that grows 1x1 noise into a 199x546 (HxW) image.

    Args:
        channels_noise: Number of channels of the N x channels_noise x 1 x 1 input.
        channels_img: Number of channels of the generated images.
        features_g: Base channel width; earlier layers use multiples of it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        base = features_g

        # (in_channels, out_channels, kernel_size, stride, padding)
        block_specs = [
            (channels_noise, base * 128, [3, 4], 1, 0),          # 1x1 -> 3x4
            (base * 128, base * 64, [1, 4], [1, 2], [0, 1]),     # -> 3x8
            (base * 64, base * 32, [4, 5], 2, 1),                # -> 6x17
            (base * 32, base * 16, 4, 2, 1),                     # -> 12x34
            (base * 16, base * 8, [5, 4], 2, 1),                 # -> 25x68
            (base * 8, base * 4, 4, 2, 1),                       # -> 50x136
            (base * 4, base * 2, [4, 5], 2, 1),                  # -> 100x273
        ]
        stages = [self._block(*spec) for spec in block_specs]

        # Output layer: last upsampling to 199x546, squashed to [-1, 1] by Tanh.
        to_image = nn.ConvTranspose2d(
            base * 2, channels_img, kernel_size=[3, 4], stride=2, padding=1
        )

        self.net = nn.Sequential(*stages, to_image, nn.Tanh())

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a transposed conv (no bias) -> BatchNorm -> ReLU stage."""
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Map a batch of noise tensors to a batch of generated images."""
        return self.net(x)


def initialize_weights(model):
    """Re-initialise conv and batch-norm weights of ``model`` from N(0, 0.02).

    Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d`` returned
    by ``model.modules()`` has its ``weight`` overwritten in place, following
    the zero-centred, std 0.02 initialisation of the DCGAN paper. Biases and
    all other layer types (e.g. ``nn.InstanceNorm2d``, ``nn.Linear``) are left
    untouched.

    Layers of these types that have no learnable weight, such as
    ``nn.BatchNorm2d(affine=False)``, are skipped instead of raising.

    Args:
        model: Any ``nn.Module``; it is modified in place.
    """
    target_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for m in model.modules():
        if not isinstance(m, target_types):
            continue
        weight = getattr(m, "weight", None)
        if weight is None:
            # e.g. BatchNorm2d(affine=False): nothing to initialise.
            continue
        nn.init.normal_(weight.data, 0.0, 0.02)





In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from google.colab import files

# Hyperparameters etc.
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Learning rate shared by the generator and critic optimizers below.
LEARNING_RATE = 1e-4
# NOTE: not read anywhere in this cell; the DataLoaders returned by
# preProcessing() are built with that function's own local BATCH_SIZE.
BATCH_SIZE = 32
# Channel count of the images (1 = grayscale, matching the preprocessing transform).
CHANNELS_IMG = 1
# Number of channels of the N x Z_DIM x 1 x 1 noise fed to the generator.
Z_DIM = 100
NUM_EPOCHS = 200
# Base channel widths passed to the critic and the generator.
FEATURES_CRITIC = 16
FEATURES_GEN = 16
# Number of critic updates performed for every generator update.
CRITIC_ITERATIONS = 5
# Weight of the gradient-penalty term in the critic loss.
LAMBDA_GP = 10


# Only train_dl is consumed by the training loop in this cell.
train_dl, test_dl = preProcessing()

# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

# initialize the optimizers (one Adam instance per network)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# for tensorboard plotting
# The same 32 noise vectors are reused at every logging step so the logged
# fake images are comparable over the course of training.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/GAN_OCT/real")
writer_fake = SummaryWriter(f"logs/GAN_OCT/fake")
# TensorBoard notebook magic pointed at ./logs; requires the earlier
# `%load_ext tensorboard` cell to have been run.
%tensorboard --logdir logs --host 0.0.0.0 --port 6006
# Counter used as global_step for the TensorBoard images; it advances once
# per logging event, not once per batch.
step = 0

# Put both networks in training mode before the loop starts.
gen.train()
critic.train()

# Training loop: for every real batch the critic is updated CRITIC_ITERATIONS
# times, then the generator is updated once.
for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(train_dl):
        real = real.to(device)
        # The last batch of an epoch may be smaller, so size the noise from the data.
        cur_batch_size = real.shape[0]
        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
        # The same real batch is reused for every critic iteration; only the
        # noise (and therefore the fake batch) is resampled.
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            # `fake` is not detached, so it stays attached to the generator's graph.
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            # retain_graph=True keeps the generator graph behind `fake` alive,
            # because the generator update below backpropagates through the
            # `fake` produced in the last critic iteration.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]

        # Re-score the last fake batch with the freshly updated critic.
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        # Clears the gradients that the critic's backward passes accumulated
        # on the generator parameters (fake was not detached above).
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        # Fires on batch indices 44, 88, ... (never on batch 0).
        if batch_idx % 44 == 0 and batch_idx > 0:
            # NOTE: the backslash continuation sits inside the string literal,
            # so the next line's leading spaces are part of the printed text.
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx+1}/{len(train_dl)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                # gen is still in train mode here (no gen.eval() call).
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

                # The same file name is overwritten at every logging step and
                # a download is requested each time this branch runs.
                output_path = "fake_image.png"
                torchvision.utils.save_image(img_grid_fake, output_path)
                files.download(output_path)

            # Advance the TensorBoard step only when something was logged.
            step += 1
