In [3]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m27.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [4]:
!pip install torchgan

Collecting torchgan
  Downloading torchgan-0.1.0-py3-none-any.whl.metadata (8.9 kB)
Collecting wget (from torchgan)
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading torchgan-0.1.0-py3-none-any.whl (71 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.7/71.7 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9656 sha256=a3eaa2ab1b1b63650ce735d76959e171d452d755013960fad276dd54362a83bf
  Stored in directory: /root/.cache/pip/wheels/8b/f1/7f/5c94f0a7a505ca1c81cd1d9208ae2064675d97582078e6c769
Successfully built wget
Installing collected packages: wget, torchgan
Successfully installed torchgan-0.1.0 wget-3.2


In [7]:
!pip3 install torch



In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchgan.trainer import Trainer
from torchgan.models import DCGANGenerator, DCGANDiscriminator
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from datasets import load_dataset
import matplotlib.pyplot as plt

# Load Dataset
train_dataset = load_dataset('Falah/Alzheimer_MRI', split='train')
test_dataset = load_dataset('Falah/Alzheimer_MRI', split='test')

# Data Transformations
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

# Custom Dataset
class AlzheimerMRIDataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset, transform):
        self.hf_dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        # Load image and label
        item = self.hf_dataset[idx]
        image = item["image"]
        label = item["label"]
        image = self.transform(image)  # Apply transformation
        return image, label


# Convert to PyTorch datasets
train_data = AlzheimerMRIDataset(train_dataset, transform)

# Create DataLoader
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)


In [24]:
class AlzheimerMRIDCGANDiscriminator(DCGANDiscriminator):
    def __init__(self, in_channels=3, step_channels=64, in_size=64):
        super(AlzheimerMRIDCGANDiscriminator, self).__init__(in_size, in_channels, step_channels)

        # Modify 1st layer
        self.model[0] = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)

    def forward(self, x, feature_matching=False):
        x = self.model(x)

        print(f"Shape after Conv2d: {x.shape}")

        x = x.view(x.size(0), -1)  # Flatten the tensor
        return self.disc(x)

In [25]:
# Set device for training
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

network = {
    "generator": {
        "name": DCGANGenerator,
        "args": {
            "out_channels": 3,
            "step_channels": 64,
        },
        "optimizer": {"name": optim.Adam, "args": {"lr": 3e-4, "betas": (0.5, 0.999), 'weight_decay': 1e-3}},
    },
    "discriminator": {
        "name": AlzheimerMRIDCGANDiscriminator,
        "args": {
            "in_channels": 3,
            "step_channels": 64,
        },
        "optimizer": {"name": optim.Adam, "args": {"lr": 3e-4, "betas": (0.5, 0.999), 'weight_decay': 1e-3}},
    },
}

from torchgan.losses import *
losses = [LeastSquaresGeneratorLoss(), LeastSquaresDiscriminatorLoss()]

trainer = Trainer(
    network, losses, sample_size=64, epochs=50, device=device
)

In [26]:
# Define loss function
criterion = nn.BCELoss()

# Initialize models
generator = network['generator']['name'](**network['generator']['args']).to(device)
discriminator = network['discriminator']['name'](**network['discriminator']['args']).to(device)

# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3)
optimizer_d = optim.Adam(discriminator.parameters(), lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3)

# Lists to record loss and accuracy during training
loss_list = {'generator_loss': [], 'discriminator_loss': []}
acc_list = {'discriminator_accuracy': [], 'generator_accuracy': []}

# Define training loop
def train():
    for epoch in range(50):  # Train for 50 epochs
        generator_loss = 0.0
        discriminator_loss = 0.0
        generator_accuracy = 0.0
        discriminator_accuracy = 0.0
        for i, (real_images, _) in enumerate(train_loader):  # Assume binary labels
            real_images = real_images.to(device)

            # Train Discriminator
            optimizer_d.zero_grad()

            # Real images labels
            real_labels = torch.ones(real_images.size(0), 1).to(device)
            fake_labels = torch.zeros(real_images.size(0), 1).to(device)

            # Forward pass real images
            outputs = discriminator(real_images)
            real_loss = criterion(outputs, real_labels)
            real_accuracy = ((outputs > 0.5).float() == real_labels).float().mean().item()

            # Generate fake images
            z = torch.randn(real_images.size(0), 100).to(device)  # Latent vector z
            fake_images = generator(z)

            # Forward pass fake images
            outputs = discriminator(fake_images.detach())
            fake_loss = criterion(outputs, fake_labels)
            fake_accuracy = ((outputs < 0.5).float() == fake_labels).float().mean().item()

            # Discriminator loss and accuracy
            d_loss = real_loss + fake_loss
            discriminator_accuracy = (real_accuracy + fake_accuracy) / 2

            # Backprop and optimize
            d_loss.backward()
            optimizer_d.step()

            # Train Generator
            optimizer_g.zero_grad()

            # Generator loss (try to fool discriminator)
            outputs = discriminator(fake_images)
            g_loss = criterion(outputs, real_labels)

            # Generator accuracy (how well generator fooled the discriminator)
            generator_accuracy = (outputs > 0.5).float().mean().item()

            # Backprop and optimize
            g_loss.backward()
            optimizer_g.step()

            # Accumulate losses and accuracy
            generator_loss += g_loss.item()
            discriminator_loss += d_loss.item()
            loss_list['generator_loss'].append(g_loss.item())
            loss_list['discriminator_loss'].append(d_loss.item())
            acc_list['generator_accuracy'].append(generator_accuracy)
            acc_list['discriminator_accuracy'].append(discriminator_accuracy)

        # Print loss and accuracy after each epoch
        print(f"Epoch [{epoch+1}/50], "
              f"Generator Loss: {generator_loss/len(train_loader)}, "
              f"Discriminator Loss: {discriminator_loss/len(train_loader)}, "
              f"Generator Accuracy: {generator_accuracy}, "
              f"Discriminator Accuracy: {discriminator_accuracy}")

# Start training
train()

Shape after Conv2d: torch.Size([64, 512, 8, 8])


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [64, 32768]

In [None]:
# After training, plot the loss and accuracy graphs
def plot_graphs():
    epochs = range(1, len(loss_list['generator_loss']) + 1)

    # Plot Generator Loss
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, loss_list['generator_loss'], label='Generator Loss')
    plt.plot(epochs, loss_list['discriminator_loss'], label='Discriminator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss vs Epochs')

    # Plot Generator and Discriminator Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, acc_list['generator_accuracy'], label='Generator Accuracy')
    plt.plot(epochs, acc_list['discriminator_accuracy'], label='Discriminator Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Accuracy vs Epochs')

    plt.tight_layout()
    plt.show()

# Plot graphs after training
plot_graphs()

In [None]:
# Generate and Save Sample Images
def generate_images(generator, num_images=16):
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_images, 100, 1, 1, device=device)
        generated_images = generator(noise)
        generated_images = (generated_images + 1) / 2  # Rescale to [0, 1]
    return generated_images.cpu()

# Visualize some generated images
import matplotlib.pyplot as plt

generated_images = generate_images(generator, num_images=16)
grid = torch.cat([generated_images[i] for i in range(16)], dim=2).squeeze(0).numpy()
plt.imshow(grid, cmap="gray")
plt.axis("off")
plt.show()