In [None]:
%%capture
!pip install numpy tqdm matplotlib torch torchvision

In [None]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from PIL import Image
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as tv_data
import torchvision.transforms as transforms
import torchvision.utils as vutils
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Parameters
images_per_class = 4000  # Number of images sampled from each CIFAR-10 class
resize_to = (64, 64)  # Target (height, width) of every image
output_dir = "./cifar10_resized_shuffled"  # Output folder for the exported PNGs
seed = 42  # Single seed shared by every random number generator used below

# Seed all RNGs up front so the sampled subset, the weight initialisation and
# the noise vectors are repeatable after "Restart Kernel -> Run All".
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Pre-processing applied to every image when it is read from the dataset.
transform = transforms.Compose([
    transforms.Resize(resize_to),  # Resize to the configured target size
    transforms.ToTensor(),         # PIL image -> float tensor with values in [0, 1]
    # No Normalize step here: pixel values stay in [0, 1], which is what the
    # PNG export through ToPILImage further down relies on.
])

In [None]:
# CIFAR-10 training split (downloaded into ./data when missing); every sample
# is passed through `transform` when it is accessed.
cifar10 = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# Group dataset indices by class label.
# The labels are read from `cifar10.targets` instead of iterating the dataset:
# indexing the dataset runs the full transform (resize + tensor conversion) on
# every image, which is wasted work when only the label is needed.
class_indices = {i: [] for i in range(10)}  # class id -> list of dataset indices
for idx, label in enumerate(cifar10.targets):
    class_indices[label].append(idx)

In [None]:
# Draw `images_per_class` distinct indices from each class, visiting the
# classes in dictionary order, and collect them in one flat list.
selected_indices = []
for indices in class_indices.values():
    selected_indices.extend(
        np.random.choice(indices, size=images_per_class, replace=False)
    )


In [None]:
# Shuffle the list in place so the classes are interleaved.
np.random.shuffle(selected_indices)

# View of CIFAR-10 restricted to the sampled indices (in shuffled order).
subset_dataset = Subset(dataset=cifar10, indices=selected_indices)

In [None]:
# Create a DataLoader
batch_size = 128
dataloader = DataLoader(
    subset_dataset, batch_size=batch_size, shuffle=True, num_workers=2
)


In [None]:
# Make sure every folder written to later on exists (no error if already there).
for folder in (output_dir, './images', './weights'):
    os.makedirs(folder, exist_ok=True)

In [None]:
from torchvision.transforms import ToPILImage

# Export every sampled image as an unlabeled PNG named image_<position>.png,
# where <position> is the image's index inside `subset_dataset`.
to_pil = ToPILImage()  # build the converter once instead of once per image
for idx in range(len(subset_dataset)):
    img, _ = subset_dataset[idx]  # (tensor, label) -- the label is discarded
    to_pil(img).save(os.path.join(output_dir, f"image_{idx}.png"))

# Report the real number of files written rather than a hard-coded figure.
print(f"Shuffled, resized images saved in {output_dir} ({len(subset_dataset):,} images total).")

In [None]:
# Preview a handful of the exported images.
# Note: the listing is sorted lexicographically, so the order is
# image_0, image_1, image_10, image_100, ... rather than numeric.
image_files = sorted(os.listdir(output_dir))

n_preview = 10  # how many images to show
n_cols = 5      # images per row
preview_files = image_files[:n_preview]
n_rows = max(1, -(-len(preview_files) // n_cols))  # ceiling division

# Size the canvas to the grid actually used instead of a mostly empty 5x5 grid.
plt.figure(figsize=(2 * n_cols, 2 * n_rows))
for i, img_file in enumerate(preview_files):
    # The context manager closes the file handle as soon as the image is drawn.
    with Image.open(os.path.join(output_dir, img_file)) as img:
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(img)
    plt.axis('off')
plt.show()

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    current_device = 'cuda'
else:
    current_device = 'cpu'
print(current_device)

In [None]:
# Model hyper-parameters
nz = 100   # length of the latent noise vector fed to the generator
ngpu = 1   # number of GPUs; in the code shown here it is only stored on the models

# Initialises the weights of the neural network for stable training
def weights_normal_init(m):
    """Apply DCGAN-style initialisation to a single module (use with ``model.apply``).

    Args:
        m: Any ``nn.Module``. Layers are matched on their class name, so every
            class whose name contains ``Conv`` (including ``ConvTranspose2d``)
            or ``BatchNorm`` is affected; all other modules are left untouched.

    Returns:
        None. The module's parameters are modified in place.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        # Small zero-centred weights (std 0.02) to avoid exploding/vanishing gradients.
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif 'BatchNorm' in classname:
        # The BatchNorm scale must be centred on 1, not 0: a scale near zero
        # multiplies every normalised activation by ~0 and throws the signal away.
        if m.weight is not None:  # None when the layer was built with affine=False
            nn.init.normal_(m.weight, mean=1.0, std=0.02)
        # Start the shift at zero so it is learned without any prior offset.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
class Disc_model(nn.Module):
    """DCGAN-style discriminator built from strided convolutions.

    The number of input channels is read from the module-level name
    ``num_channels`` when an instance is constructed, so that name has to be
    defined before the class is instantiated. For 64x64 inputs the five
    convolutions shrink the spatial size 64 -> 32 -> 16 -> 8 -> 4 -> 1, which
    leaves one sigmoid score per image.
    """

    def __init__(self, ngpu):
        """Build the layer stack; ``ngpu`` is only stored on the instance."""
        super().__init__()
        self.ngpu = ngpu

        # First stage: 4x4 kernel, stride 2, padding 1, no batch norm.
        layers = [
            nn.Conv2d(num_channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),  # negative slope 0.2
        ]
        # Three identical down-sampling stages that double the channel count.
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Final 4x4 convolution without padding, squashed by a sigmoid.
        layers += [
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the sigmoid scores of ``input`` flattened to a 1-D tensor."""
        return self.main(input).view(-1)

In [None]:
class Gen_model(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    The latent size and the number of output channels are read from the
    module-level names ``nz`` and ``num_channels`` when an instance is
    constructed, so both have to be defined before the class is instantiated.
    A ``(N, nz, 1, 1)`` input is up-sampled 1 -> 4 -> 8 -> 16 -> 32 -> 64 and
    the final Tanh keeps the output values in [-1, 1].
    """

    def __init__(self, ngpu):
        """Build the layer stack; ``ngpu`` is only stored on the instance."""
        super().__init__()
        self.ngpu = ngpu

        # Project the noise vector onto a 4x4 feature map with 512 channels.
        layers = [
            nn.ConvTranspose2d(nz, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        ]
        # Three identical up-sampling stages that halve the channel count.
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Last up-sampling step straight to image channels, squashed by Tanh.
        layers += [
            nn.ConvTranspose2d(64, num_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map a batch of noise vectors to a batch of generated images."""
        return self.main(input)

In [None]:
# Number of passes over the dataloader.
niter = 25

# Target values handed to the BCE loss for real and generated batches.
real_label = 1
fake_label = 0

# One fixed batch of 128 latent vectors, re-used so that samples rendered at
# different points of training are comparable.
fixed_noise = torch.randn(128, nz, 1, 1).to(current_device)

# Containers intended for generator / discriminator loss values.
g_loss = []
d_loss = []

In [None]:
# Image channel count; read by Gen_model / Disc_model when they are constructed.
num_channels = 3

# Generator first, then discriminator: each is moved to the target device and
# re-initialised layer by layer through `apply` (which returns the module).
model_Gen = Gen_model(ngpu).to(current_device).apply(weights_normal_init)
model_Disc = Disc_model(ngpu).to(current_device).apply(weights_normal_init)

# Binary cross-entropy, computed on the discriminator's sigmoid output.
loss_func = nn.BCELoss()

In [None]:
# One Adam optimizer per network, both with lr = 2e-4 and betas = (0.5, 0.999).
optimizerD = optim.Adam(
    params=model_Disc.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)
optimizerG = optim.Adam(
    params=model_Gen.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)


In [None]:
# Adversarial training loop.
#
# Every step (1) updates the discriminator on one real and one generated batch
# and (2) updates the generator so that its samples are scored as real.
# Losses are recorded per step in `d_loss` / `g_loss`.

# Start from empty histories so re-running this cell does not append twice.
g_loss.clear()
d_loss.clear()

for epoch in tqdm(range(niter), total=niter):
    for i, data in enumerate(dataloader, 0):
        # ---- (1) Discriminator update --------------------------------------
        model_Disc.zero_grad()

        # ToTensor leaves pixels in [0, 1] while the generator ends in Tanh and
        # therefore produces [-1, 1]. Rescale the real batch to the same range,
        # otherwise the discriminator can separate real from fake by value
        # range alone. (Remove this rescale if a Normalize((0.5, ...), (0.5, ...))
        # step is ever added to the dataset transform.)
        real_images = data[0].to(current_device) * 2.0 - 1.0

        # Size of *this* batch (the last one may be smaller). Kept in its own
        # name so the `batch_size` configuration value is not overwritten.
        n_samples = real_images.size(0)
        label = torch.full((n_samples,), float(real_label), device=current_device)

        output = model_Disc(real_images)  # Discriminator output on real images
        disc_error_real = loss_func(output, label)
        disc_error_real.backward()  # disc loss for real images
        D_x = output.mean().item()

        noise = torch.randn(n_samples, nz, 1, 1, device=current_device)  # create noise
        fake = model_Gen(noise)  # Fake images
        label.fill_(fake_label)  # Fill with 0
        # detach(): this pass must not send gradients into the generator.
        output = model_Disc(fake.detach())
        disc_error_fake = loss_func(output, label)  # disc loss for fake images
        disc_error_fake.backward()
        D_G_z1 = output.mean().item()
        disc_error = disc_error_real + disc_error_fake
        optimizerD.step()

        # ---- (2) Generator update ------------------------------------------
        model_Gen.zero_grad()
        label.fill_(real_label)  # fill with 1: the generator wants "real" verdicts
        output = model_Disc(fake)  # re-score the fakes with the updated discriminator
        gen_error = loss_func(output, label)
        gen_error.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Keep the loss history for later inspection / plotting.
        d_loss.append(disc_error.item())
        g_loss.append(gen_error.item())

        if i % 100 == 0:  # log and save images every 100 steps
            # Logging only every 100 steps keeps the cell output readable.
            print(f'[{epoch}/{niter}][{i}/{len(dataloader)}] Loss_D: {disc_error.item():.4f} Loss_G: {gen_error.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')
            print('saving the output')
            vutils.save_image(real_images, './images/real_samples.png', normalize=True)
            # Sampling for inspection needs no autograd graph.
            with torch.no_grad():
                fixed_fake = model_Gen(fixed_noise)
            vutils.save_image(fixed_fake, './images/fake_samples_epoch_%03d.png' % (epoch), normalize=True)

    # Display a 5x5 grid of generated samples every 2 epochs
    if epoch % 2 == 0:
        with torch.no_grad():
            fixed_fake = model_Gen(fixed_noise).cpu()
        # normalize=True maps the Tanh range to [0, 1]; imshow would otherwise
        # clip every negative value to black.
        grid = vutils.make_grid(fixed_fake[:25], nrow=5, normalize=True)

        # Display the grid using matplotlib
        plt.figure(figsize=(10, 10))
        plt.axis("off")
        plt.title(f"Epoch {epoch}: Generated Images")
        plt.imshow(grid.permute(1, 2, 0))  # Change dimensions from (C, H, W) to (H, W, C)
        plt.show()
        plt.close()  # release the figure so figures do not pile up across epochs

    # Checkpoint both networks at the end of every epoch.
    torch.save(model_Gen.state_dict(), 'weights/model_Gen_epoch_%d.pth' % (epoch))
    torch.save(model_Disc.state_dict(), 'weights/model_Disc_epoch_%d.pth' % (epoch))
            