In [196]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image

In [197]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [198]:
# Custom dataset for Fashion MNIST
def load_mnist_images(filepath):
    with open(filepath, 'rb') as f:
        f.read(16)  # Skip the magic number and dimensions information
        data = np.frombuffer(f.read(), dtype=np.uint8)
        return data.reshape(-1, 28, 28)

In [199]:
class CustomFashionDataset(Dataset):
    def __init__(self, images_filepath, transform=None):
        self.images = load_mnist_images(images_filepath)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        image = Image.fromarray(image, mode='L')  # Convert numpy array to PIL Image
        if self.transform:
            image = self.transform(image)
        return image

In [200]:
# Hyperparameters
latent_dim = 128
batch_size = 128
learning_rate = 0.0002
num_epochs = 20
channels_img = 1


In [201]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [202]:
data_directory = "/project/jacobcha/nk643/gans/data"
train_images_filepath = os.path.join(data_directory, 'train-images-idx3-ubyte')
dataset = CustomFashionDataset(train_images_filepath, transform=transform)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [224]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 7 * 7 * 256, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 5, 2, padding=2, output_padding=1, bias=False),  # Upsample to 14x14
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, 2, padding=2, output_padding=1, bias=False),  # Upsample to 28x28
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, channels_img, 5, 1, padding=2, bias=False),  # Output 28x28
            nn.Tanh()
        )

    def forward(self, x):
        x = x.view(x.size(0), 256, 7, 7)  # Reshape to 3D tensor
        return self.model(x)


In [225]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels_img, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(6272, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.model(x).squeeze(1)
        print(x.shape) 
        return x


In [226]:
# Instantiate models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [227]:
# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [228]:
# Loss function
criterion = nn.BCELoss()

In [229]:
# Training loop
for epoch in range(num_epochs):
    for batch_idx, real in enumerate(train_loader):
        real = real.to(device)
        batch_size = real.size(0)

        # Discriminator Training
        noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        
        fake = generator(noise)
        labels_real = torch.ones(batch_size, device=device)
        labels_fake = torch.zeros(batch_size, device=device)

        outputs_real = discriminator(real)
        loss_real = criterion(outputs_real, labels_real)

        outputs_fake = discriminator(fake.detach())
        loss_fake = criterion(outputs_fake, labels_fake)

        loss_D = (loss_real + loss_fake) / 2
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # Generator Training
        outputs_fake = discriminator(fake)
        loss_G = criterion(outputs_fake, labels_real)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(train_loader)} \
                  Loss D: {loss_D:.4f}, loss G: {loss_G:.4f}")

    # Generate and save images
    with torch.no_grad():
        fake_images = generator(torch.randn(batch_size, latent_dim, device=device)).detach().cpu()
    save_image(fake_images, f"/project/jacobcha/nk643/gans/output/dcgan/epoch_{epoch}.png", nrow=12)

RuntimeError: shape '[128, 256, 7, 7]' is invalid for input of size 16384

In [223]:
# Save models
torch.save(generator.state_dict(), '/project/jacobcha/nk643/gans/checkpoints/dcgan/generator.pth')
torch.save(discriminator.state_dict(), '/project/jacobcha/nk643/gans/checkpoints/dcgan/discriminator.pth')