# GAN for Alzheimer's Data Augmentation

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image

from pre_processing.preprocessing import CustomDataset

# Dataset

In [28]:
# Side length every image is resized to; the GAN models must produce/consume this size.
IMG_SIZE = 128
# Fixed seed for the train/val/test split so the partition is identical on every run
# (otherwise images can migrate between splits across kernel restarts).
SPLIT_SEED = 42

transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),  # Resize image to IMG_SIZE x IMG_SIZE
    transforms.ToTensor(),  # float tensor with values in [0, 1]
    # Map [0, 1] -> [-1, 1] so real images share the value range of the
    # generator's Tanh output; otherwise the discriminator can tell real from
    # fake purely by intensity range.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load the entire dataset
dataset = CustomDataset(csv_file='dataset.csv',
                        transform=transform)

# Split the dataset into train (70%), validation (15%) and test (remainder) sets
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [29]:
# Report how many samples landed in each split (one line per split).
print("\n".join(
    f"{split_name} Samples: {len(split_part)}"
    for split_name, split_part in (
        ("Training", train_dataset),
        ("Validation", val_dataset),
        ("Test", test_dataset),
    )
))

Training Samples: 23788
Validation Samples: 5097
Test Samples: 5099


In [30]:
# Dataset sanity check for RGB concatenation: every image should come out of the
# transform as a 3-channel 128x128 tensor.
import matplotlib.pyplot as plt
sample_image = train_dataset[0][0]
print(sample_image.size())
assert tuple(sample_image.shape) == (3, 128, 128), (
    f"unexpected image shape {tuple(sample_image.shape)}; expected (3, 128, 128)"
)

torch.Size([3, 128, 128])


# Data Loader

In [35]:
batch_size = 32  # You can adjust this according to your needs
# Worker processes for loading; set to 0 if worker start-up is slow or hangs
# (e.g. spawn-based multiprocessing from a notebook on macOS).
num_workers = 4

# Train the GAN on the training split only, so validation/test images stay
# unseen by the generator.
data_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    # Keep workers alive across epochs instead of re-spawning them (and
    # re-importing torch in each) at the start of every epoch.
    persistent_workers=num_workers > 0,
    # Page-locked host memory speeds up host->GPU copies; only useful with CUDA.
    pin_memory=torch.cuda.is_available(),
)

# Generator Model

In [31]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise z to an image batch.

    Input:  noise of shape (N, latent_dim, 1, 1).
    Output: images of shape (N, img_channels, img_size, img_size), squashed to
            [-1, 1] by the final Tanh, so real images shown to the
            discriminator should be normalised to the same range.

    A kernel-4 / stride-1 / padding-0 transposed conv first projects the 1x1
    input to 4x4. Each following kernel-4 / stride-2 / padding-1 transposed
    conv doubles the side length and halves the channel count, ending with
    `feature_maps` channels before the last layer projects to `img_channels`.

    Args:
        latent_dim: number of channels of the input noise (size of z).
        img_channels: number of channels of the generated image.
        img_size: side length of the generated image; must be a power of two
            >= 8. Defaults to 128 to match the 128x128 training images
            (img_size=64 reproduces the classic 64x64 DCGAN layer plan).
        feature_maps: channel count of the last hidden layer.

    Raises:
        ValueError: if img_size is not a power of two >= 8.
    """

    def __init__(self, latent_dim=100, img_channels=3, img_size=128, feature_maps=64):
        super(Generator, self).__init__()
        if img_size < 8 or img_size & (img_size - 1):
            raise ValueError(f"img_size must be a power of two >= 8, got {img_size}")
        # Stride-2 upsampling stages needed to grow 4x4 into img_size x img_size
        # (4 * 2**n_up == img_size).
        n_up = img_size.bit_length() - 3

        channels = feature_maps * 2 ** (n_up - 1)
        layers = [
            nn.ConvTranspose2d(latent_dim, channels, 4, 1, 0, bias=False),  # 1x1 -> 4x4
            nn.BatchNorm2d(channels),
            nn.ReLU(True),
        ]
        # Hidden upsampling stages: double spatial size, halve channels.
        for _ in range(n_up - 1):
            layers += [
                nn.ConvTranspose2d(channels, channels // 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels // 2),
                nn.ReLU(True),
            ]
            channels //= 2
        # Final upsampling stage straight to image channels; Tanh bounds output to [-1, 1].
        layers += [
            nn.ConvTranspose2d(channels, img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise `z` of shape (N, latent_dim, 1, 1) to a batch of images."""
        return self.model(z)

# Discriminator Model

In [32]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring how likely each image is real.

    Input:  images of shape (N, img_channels, img_size, img_size).
    Output: Sigmoid probabilities of shape (N, 1, 1, 1); the training loop
            flattens them with .view(-1) to compare against N labels.

    Each kernel-4 / stride-2 / padding-1 conv halves the side length and
    doubles the channel count until the map is 4x4; a final kernel-4 /
    stride-1 / padding-0 conv reduces that to a single 1x1 score. The first
    stage has no BatchNorm.

    Args:
        img_channels: number of channels of the input image.
        img_size: side length of the input image; must be a power of two
            >= 8. Defaults to 128 to match the 128x128 training images
            (img_size=64 reproduces the classic 64x64 DCGAN layer plan).
        feature_maps: channel count of the first conv stage.

    Raises:
        ValueError: if img_size is not a power of two >= 8.
    """

    def __init__(self, img_channels=3, img_size=128, feature_maps=64):
        super(Discriminator, self).__init__()
        if img_size < 8 or img_size & (img_size - 1):
            raise ValueError(f"img_size must be a power of two >= 8, got {img_size}")
        # Stride-2 downsampling stages needed to shrink img_size x img_size to 4x4.
        n_down = img_size.bit_length() - 3

        layers = [
            nn.Conv2d(img_channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        channels = feature_maps
        # Remaining downsampling stages: halve spatial size, double channels.
        for _ in range(n_down - 1):
            layers += [
                nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels *= 2
        layers += [
            nn.Conv2d(channels, 1, 4, 1, 0, bias=False),  # 4x4 -> 1x1 score
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return per-image real/fake probabilities of shape (N, 1, 1, 1)."""
        return self.model(img)

# Create Model

In [33]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG so model initialisation is reproducible across runs.
torch.manual_seed(42)


def weights_init(module):
    """Apply DCGAN weight initialisation to a single module (use with `.apply`).

    Conv / transposed-conv weights are drawn from N(0, 0.02); BatchNorm scales
    from N(1, 0.02) with zero shifts. Other modules are left untouched. The
    conv layers in this notebook are built with bias=False, so no conv bias
    is initialised here.
    """
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, 0.0, 0.02)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, 1.0, 0.02)
        nn.init.zeros_(module.bias)


generator = Generator().to(device)
discriminator = Discriminator().to(device)
generator.apply(weights_init)
discriminator.apply(weights_init)

# Define Optimizer and Loss Scheme

In [34]:
# Adam Optimizer
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Binary Cross Entropy Loss Function (Classification Task)
criterion = nn.BCELoss()

# Training

In [37]:
num_epochs = 50
latent_dim = 100  # Size of the latent vector (i.e., size of generator input)
log_every = 50  # print losses every `log_every` batches
num_batches = len(data_loader)

for epoch in range(num_epochs):
    # Put both models in training mode (BatchNorm uses batch statistics); a
    # later cell may have switched them to eval() before re-running this one.
    generator.train()
    discriminator.train()

    for i, data in enumerate(data_loader):
        # data[0] holds the batch of real images
        real_images = data[0].to(device)
        b_size = real_images.size(0)

        # --- Discriminator step: BCE pushes D(real) -> 1 and D(fake) -> 0 ---
        optimizer_D.zero_grad()
        label = torch.full((b_size,), 1., dtype=torch.float, device=device)
        output = discriminator(real_images).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # Fake batch; detach() so the discriminator loss does not backprop into G.
        noise = torch.randn(b_size, latent_dim, 1, 1, device=device)
        fake_images = generator(noise)
        label.fill_(0.)
        output = discriminator(fake_images.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake  # for logging; gradients already accumulated by both backward() calls
        optimizer_D.step()

        # --- Generator step: non-saturating loss, push D(G(z)) -> 1 ---
        optimizer_G.zero_grad()
        label.fill_(1.)  # Fake labels are real for generator cost
        output = discriminator(fake_images).view(-1)
        errG = criterion(output, label)
        # Also writes gradients into D's parameters; they are cleared by
        # optimizer_D.zero_grad() at the start of the next iteration.
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizer_G.step()

        if i % log_every == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{num_batches}] Loss_D: {errD.item()}, Loss_G: {errG.item()} D(x): {D_x}, D(G(z)): {D_G_z1}/{D_G_z2}')

    # Save/checkpoint generator model at the end of each epoch
    # torch.save(generator.state_dict(), 'generator_epoch_{}.pth'.format(epoch))

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/benrandoing/opt/anaconda3/envs/229Final/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/benrandoing/opt/anaconda3/envs/229Final/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
  File "/Users/benrandoing/opt/anaconda3/envs/229Final/lib/python3.9/site-packages/torch/__init__.py", line 1474, in <module>
    from torch import quantization as quantization
  File "/Users/benrandoing/opt/anaconda3/envs/229Final/lib/python3.9/site-packages/torch/quantization/__init__.py", line 2, in <module>
    from .observer import *  # noqa: F403
  File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
  File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
  File "<frozen importlib._bootstr

KeyboardInterrupt: 

# Validation