In [10]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

from data import ImageDataset
from models import Generator, Discriminator
from training_pipeline import worker_init_fn, set_seed

# One seed for the whole notebook: applied globally via set_seed here and
# also passed explicitly to split_dataset when the data is split.
random_seed: int = 42
set_seed(random_seed)

In [6]:
from data import split_dataset

train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15
# Ratios that do not sum to 1 would silently drop (or over-allocate) files.
assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9, "split ratios must sum to 1"

source_path = "data/cats/Data"
destination_path = "data/cats/split_data"

# The split only needs to happen once. Skip it when the destination already
# exists, so "Restart & Run All" reuses the existing train/val/test folders.
# Set `force_resplit = True` to regenerate them deliberately.
# NOTE(review): whether split_dataset clears an existing destination before
# writing is not visible here — confirm before forcing a re-split.
force_resplit = False
if force_resplit or not os.path.isdir(destination_path):
    split_dataset(source_path, destination_path, train_ratio, val_ratio, test_ratio, random_seed)
else:
    print(f"Reusing existing split in {destination_path!r}")

data\cats\split_data\train
data\cats\split_data\val
data\cats\split_data\test


In [11]:
img_size = 64

# Resize to a fixed square resolution. ToTensor yields values in [0, 1], and
# Normalize(0.5, 0.5) maps them to [-1, 1]. The single-element mean/std
# broadcasts over every channel, so this works whatever the channel count.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # Normalize to [-1, 1]
])

train_path = "data/cats/split_data/train"
val_path = "data/cats/split_data/val"
test_path = "data/cats/split_data/test"

train_dataset = ImageDataset(train_path, transform=transform)
val_dataset = ImageDataset(val_path, transform=transform)
test_dataset = ImageDataset(test_path, transform=transform)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 2048
n_workers = 4
# prefetch_factor / persistent_workers are only meaningful with worker processes.
prefetch_factor = 4 if n_workers > 0 else None
persistent_workers = n_workers > 0

# Settings shared by all three loaders, so they cannot drift apart.
# Pinned host memory only helps when batches are copied to a CUDA device;
# on a CPU-only machine it just triggers a PyTorch warning.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=n_workers,
    pin_memory=device.type == "cuda",
    prefetch_factor=prefetch_factor,
    persistent_workers=persistent_workers,
    worker_init_fn=worker_init_fn,
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

## DCGAN

In [12]:
# DCGAN training hyperparameters.
latent_dim = 100  # length of the noise vector z fed to the generator
num_epochs = 100  # full passes over train_loader
lr = 2e-4         # Adam learning rate (value used in the DCGAN paper)
beta1 = 0.5       # Adam beta1, lowered from the 0.9 default as in the DCGAN paper

In [13]:
netG = Generator(latent_dim).to(device)
netD = Discriminator().to(device)

# BCELoss expects probabilities, so this assumes Discriminator ends in a
# sigmoid — TODO confirm (otherwise switch to BCEWithLogitsLoss).
criterion = nn.BCELoss()

optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))

# A fixed batch of latent vectors, reused every epoch so the saved sample
# grids show the generator's progress on identical inputs. The sample count
# is its own setting. It previously borrowed `img_size`, which equals 64 only
# by coincidence, so changing the image resolution changed the grid size.
n_fixed_samples = 64
fixed_noise = torch.randn(n_fixed_samples, latent_dim, device=device)

In [14]:
# Sample grids from `fixed_noise` are written here once per epoch. Create the
# folder up front so save_image does not fail on a fresh checkout.
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)

for epoch in range(num_epochs):
    for i, real_imgs in enumerate(train_loader):
        # non_blocking lets the host->device copy overlap compute when the
        # loader uses pinned memory.
        real_imgs = real_imgs.to(device, non_blocking=True)
        b_size = real_imgs.size(0)

        # BCE targets, sized per batch because the last batch may be smaller.
        real_labels = torch.ones(b_size, device=device)
        fake_labels = torch.zeros(b_size, device=device)

        ## Update Discriminator: minimise -log D(x) - log(1 - D(G(z))) ##
        netD.zero_grad()

        # Flatten D's output so it always matches the (b_size,) target shape
        # BCELoss requires (a no-op if D already returns a flat vector).
        output_real = netD(real_imgs).view(-1)
        loss_real = criterion(output_real, real_labels)

        noise = torch.randn(b_size, latent_dim, device=device)
        fake_imgs = netG(noise)
        # detach(): the discriminator step must not backpropagate into G.
        output_fake = netD(fake_imgs.detach()).view(-1)
        loss_fake = criterion(output_fake, fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizerD.step()

        ## Update Generator: non-saturating loss, minimise -log D(G(z)) ##
        netG.zero_grad()
        # Re-score the same fakes with the freshly updated D. This also
        # accumulates grads in netD, which netD.zero_grad() clears next step.
        output = netD(fake_imgs).view(-1)
        loss_G = criterion(output, real_labels)  # Try to fool the discriminator
        loss_G.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Step [{i}] Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}")

    # Save a sample grid generated from the fixed noise batch.
    # NOTE(review): netG stays in train mode here, so any BatchNorm layers use
    # the batch statistics of `fixed_noise` and update their running averages.
    # Confirm this is intended; the Generator definition is not visible here.
    with torch.no_grad():
        fake = netG(fixed_noise).cpu()
    save_image(fake, f"{output_dir}/fake_epoch_{epoch+1:03d}.png", normalize=True)

Epoch [1/100] Step [0] Loss_D: 1.3608 Loss_G: 2.5621
Epoch [2/100] Step [0] Loss_D: 0.0799 Loss_G: 6.8304
Epoch [3/100] Step [0] Loss_D: 0.0501 Loss_G: 9.3847
Epoch [4/100] Step [0] Loss_D: 0.1133 Loss_G: 16.7177
Epoch [5/100] Step [0] Loss_D: 0.3299 Loss_G: 9.5806
Epoch [6/100] Step [0] Loss_D: 0.7714 Loss_G: 11.0277
Epoch [7/100] Step [0] Loss_D: 0.8257 Loss_G: 9.2877
Epoch [8/100] Step [0] Loss_D: 0.1186 Loss_G: 4.1836
Epoch [9/100] Step [0] Loss_D: 1.0004 Loss_G: 3.9597
Epoch [10/100] Step [0] Loss_D: 0.5586 Loss_G: 4.1883
Epoch [11/100] Step [0] Loss_D: 0.5469 Loss_G: 2.7442
Epoch [12/100] Step [0] Loss_D: 1.1111 Loss_G: 1.5368
Epoch [13/100] Step [0] Loss_D: 0.7199 Loss_G: 2.5120
Epoch [14/100] Step [0] Loss_D: 1.9684 Loss_G: 7.4636
Epoch [15/100] Step [0] Loss_D: 0.8633 Loss_G: 3.4412
Epoch [16/100] Step [0] Loss_D: 0.7176 Loss_G: 2.1885
Epoch [17/100] Step [0] Loss_D: 0.5947 Loss_G: 2.4761
Epoch [18/100] Step [0] Loss_D: 0.8008 Loss_G: 3.0216
Epoch [19/100] Step [0] Loss_D: 0.9