In [1]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torchvision
import os
from torchvision.utils import save_image

In [2]:
# Define dataset path
data_root = '../data/source_2'

# Define transformations for data augmentation (only for the minority class)
transform_aug = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize to standard size
    transforms.RandomHorizontalFlip(),  # Random horizontal flip
    transforms.RandomRotation(10),  # Random rotation by 10 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # Random color jitter
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),  # Random translation
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize between -1 and 1
])

# Define basic transformation for the majority class
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize
])

# Load the images from both classes
dataset_A = datasets.ImageFolder(root=os.path.join(data_root, "Non Demented"), transform=transform)  # No Alzheimer's
dataset_B = datasets.ImageFolder(root=os.path.join(data_root, "Very mild Dementia"), transform=transform_aug)  # Very Mild with augmentation

# Number of samples in each class
num_A = len(dataset_A)  # No Alzheimer's class
num_B = len(dataset_B)  # Very Mild class

# Calculate how many times to repeat the minority class
repeat_factor = num_A // num_B  # To make dataset_B equal size to dataset_A

# Augment dataset_B (Very Mild) until it reaches the size of dataset_A
oversampled_B = torch.utils.data.ConcatDataset([dataset_B] * repeat_factor)

# Create DataLoader for both datasets
batch_size = 16
dataloader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)
dataloader_B = DataLoader(oversampled_B, batch_size=batch_size, shuffle=True)

# Example: Load a batch
images_A, _ = next(iter(dataloader_A))
images_B, _ = next(iter(dataloader_B))

print(f"Loaded batch from No Alzheimer's: {images_A.shape}")
print(f"Loaded batch from Very Mild (augmented): {images_B.shape}")


Loaded batch from No Alzheimer's: torch.Size([16, 3, 256, 256])
Loaded batch from Very Mild (augmented): torch.Size([16, 3, 256, 256])


In [3]:
print(len(dataset_A))
len(oversampled_B)

67222


54900

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torchvision.utils import save_image

# Define the Generator (ResNet-based)
class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, num_blocks=9):
        super(ResnetGenerator, self).__init__()

        model = [
            nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=3, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True)
        ]

        # Downsampling
        model += [
            nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(ngf * 4),
            nn.ReLU(True)
        ]

        # Residual Blocks
        for _ in range(num_blocks):
            model += [ResnetBlock(ngf * 4)]

        # Upsampling
        model += [
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.InstanceNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True)
        ]

        model += [
            nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=3, bias=False),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# Define Residual Block
class ResnetBlock(nn.Module):
    def __init__(self, dim):
        super(ResnetBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim)
        )

    def forward(self, x):
        return x + self.block(x)

# Define PatchGAN Discriminator
class PatchGANDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64):
        super(PatchGANDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x):
        return self.model(x)

# Initialize networks
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G = ResnetGenerator(input_nc=3, output_nc=3).to(device)
D_B = PatchGANDiscriminator(input_nc=3).to(device)

# Loss functions
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()  # Cycle consistency loss
criterion_identity = nn.L1Loss()  # Identity loss

# Optimizers
optimizer_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training loop
num_epochs = 50
for epoch in range(1, num_epochs):
    print(f'epochs - {epoch}')
    for i, (real_A, _) in enumerate(dataloader_A):
        print(i)
        real_B, _ = next(iter(dataloader_B))  # Load a batch from Very Mild Dementia

        real_A, real_B = real_A.to(device), real_B.to(device)

        # ----------- Train Generator G -----------
        optimizer_G.zero_grad()

        fake_B = G(real_A)
        pred_fake = D_B(fake_B)
        loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))  # Adversarial loss

        # Cycle loss (optional)
        loss_cycle = criterion_cycle(G(fake_B), real_A) * 10.0

        # Identity loss (optional)
        loss_identity = criterion_identity(G(real_B), real_B) * 5.0

        loss_G = loss_GAN + loss_cycle + loss_identity
        loss_G.backward()
        optimizer_G.step()

        # ----------- Train Discriminator D_B -----------
        optimizer_D_B.zero_grad()

        pred_real = D_B(real_B)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = D_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        if i % 100 == 0:
            print(f"Epoch {epoch}/{num_epochs}, Batch {i}, Loss G: {loss_G.item()}, Loss D_B: {loss_D_B.item()}")

            # Save some sample images
            save_image(fake_B, f"results/fake_B_{epoch}_{i}.png")
            save_image(real_A, f"results/real_A_{epoch}_{i}.png")

print("Training complete.")

epochs - 1
0
Epoch 1/50, Batch 0, Loss G: 12.73147201538086, Loss D_B: 0.8458019495010376
1
2
3
4
5
6
7
8
