In [1]:
#Name : Manahil Sarwar
#Section : AI-K
#Roll No: 21I-0293

In [10]:
#Loading Libraries
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
from torch.nn.utils import spectral_norm
import torch.optim as optim
from torchvision.utils import save_image
import os
from tqdm import tqdm

In [4]:
#Data Preprocessing
#Define transformations
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    ])

#Load CIFAR-10 training and test datasets
train_dataset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
test_dataset=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)

#Function to filter dataset for cats and dogs
def filter_cats_dogs(dataset):
    targets=np.array(dataset.targets)
    mask=(targets==3)|(targets==5)
    indices=np.where(mask)[0]
    return Subset(dataset,indices)

#Filter datasets
train_subset=filter_cats_dogs(train_dataset)
test_subset=filter_cats_dogs(test_dataset)

#DataLoaders
batch_size=128
train_loader=DataLoader(train_subset,batch_size=batch_size,shuffle=True,num_workers=2)
test_loader=DataLoader(test_subset,batch_size=batch_size,shuffle=False,num_workers=2)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [01:20<00:00, 2120566.14it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
#Generator Class
class Generator(nn.Module):
    def __init__(self,nz=100,ngf=64,nc=3):
        super(Generator,self).__init__()
        self.main=nn.Sequential(
            nn.ConvTranspose2d(nz,ngf*4,4,1,0,bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            nn.Dropout(0.3),
            nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            nn.Dropout(0.3),
            nn.ConvTranspose2d(ngf*2,ngf,4,2,1,bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.Dropout(0.3),
            nn.ConvTranspose2d(ngf,nc,4,2,1,bias=False),
            nn.Tanh()
        )
    def forward(self, input):
        return self.main(input)

In [14]:
#Discriminator Class
class MinibatchDiscrimination(nn.Module):
    def __init__(self,in_features,num_kernels,kernel_dim=10):
        super(MinibatchDiscrimination,self).__init__()
        self.num_kernels=num_kernels
        self.T=nn.Parameter(torch.randn(in_features,num_kernels*kernel_dim))

    def forward(self,x):
        M=x.mm(self.T)
        M=M.view(-1,self.num_kernels,M.size(1)//self.num_kernels)
        out=M.unsqueeze(0)-M.unsqueeze(1)
        out=torch.exp(-torch.abs(out).sum(3)) 
        out=out.sum(1)
        return torch.cat([x,out],dim=1)


class Discriminator(nn.Module):
    def __init__(self,nc=3,ndf=64):
        super(Discriminator,self).__init__()
        self.shared_conv=nn.Sequential(
            spectral_norm(nn.Conv2d(nc,ndf,4,2,1,bias=False)),
            nn.LeakyReLU(0.2,inplace=True),
            spectral_norm(nn.Conv2d(ndf,ndf*2,4,2,1,bias=False)),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2,inplace=True),
            spectral_norm(nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False)),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Flatten(),
        )
        
        self.minibatch_discrimination=MinibatchDiscrimination(in_features=8192,num_kernels=10)
        self.fc=nn.Sequential(
            nn.Linear(8202,512),  
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(512,1),
            nn.Sigmoid()
        )

    def forward(self,real_img,fake_img):
        real_feat=self.shared_conv(real_img)
        fake_feat=self.shared_conv(fake_img)
        combined=torch.cat((real_feat,fake_feat),dim=1)
        combined=self.minibatch_discrimination(combined)
        similarity=self.fc(combined)
        return similarity

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x000001C742769DA0>
Traceback (most recent call last):
  File "C:\Users\HP\AppData\Roaming\Python\Python312\site-packages\torch\utils\data\dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "C:\Users\HP\AppData\Roaming\Python\Python312\site-packages\torch\utils\data\dataloader.py", line 1435, in _shutdown_workers
    if self._persistent_workers or self._workers_status[worker_id]:
                                   ^^^^^^^^^^^^^^^^^^^^
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_workers_status'


In [16]:
#Training the Gan
criterion=nn.BCELoss()
#Labels
real_label=1.0
fake_label=0.0
#Check device
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

#Parameters
ndf=64
num_epochs=25
lr=0.0002
beta1=0.5
nz=128
#Initialize generator and discriminator
ngf=128
netG=Generator(nz,ngf,nc=3).to(device)
netD=Discriminator(nc=3,ndf=ndf).to(device)

#Initialize weights
def weights_init(m):
    classname=m.__class__.__name__
    if classname.find('Conv')!=-1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm')!=-1:
        nn.init.normal_(m.weight.data,1.0,0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)

#Optimizers
optimizerD=optim.Adam(netD.parameters(),lr=lr,betas=(beta1,0.999))
optimizerG=optim.Adam(netG.parameters(),lr=lr,betas=(beta1,0.999))


#Fixed noise for generating samples
fixed_noise=torch.randn(64,nz,1,1,device=device)
#Directory to save generated images
os.makedirs('C:/Users/HP/Downloads/output_images',exist_ok=True)
#Training Loop
for epoch in range(num_epochs):
    for i, data in enumerate(tqdm(train_loader),0):
        #Update Discriminator
        netD.zero_grad()
        #Format batch
        real_images,_=data
        batch_size_curr=real_images.size(0)
        real_images=real_images.to(device)
        real_images=real_images+0.05 * torch.randn_like(real_images).to(device)

        # Real pairs:(Real,Real)
        real_pairs=(real_images,real_images)
        label_real=torch.full((batch_size_curr,1),0.9,dtype=torch.float,device=device)

        output_real=netD(*real_pairs)
        lossD_real=criterion(output_real,label_real)
        lossD_real.backward()

        #Fake pairs: (Real,Fake)
        noise=torch.randn(batch_size_curr,nz,1,1,device=device)
        fake_images=netG(noise)

        fake_images=fake_images+0.05*torch.randn_like(fake_images).to(device)
        fake_pairs=(real_images,fake_images.detach())
        label_fake=torch.full((batch_size_curr,1),fake_label,dtype=torch.float,device=device)

        output_fake=netD(*fake_pairs)
        lossD_fake=criterion(output_fake,label_fake)
        lossD_fake.backward()

        lossD=lossD_real+lossD_fake
        optimizerD.step()

        #Update Generator
        netG.zero_grad()
        #Generator wants D to output real labels for fake pairs
        label_gen=torch.full((batch_size_curr,1),real_label,dtype=torch.float,device=device)
        fake_pairs_for_G=(real_images,fake_images)
        output_gen=netD(*fake_pairs_for_G)
        lossG=criterion(output_gen,label_gen)
        lossG.backward()
        optimizerG.step()

        #Print statistics
        if i%100==0:
            print(f'Epoch [{epoch+1}/{num_epochs}] Batch {i}/{len(train_loader)} \
                  Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}')

    #Save generated images every epoch
    with torch.no_grad():
        fake=netG(fixed_noise).detach().cpu()
    fake=fake*0.5+0.5
    save_image(fake,f'C:/Users/HP/Downloads/output_images/fake_epoch_{epoch+1}.png', nrow=8)

Using device: cpu


  1%|▏         | 1/79 [00:17<22:09, 17.04s/it]

Epoch [1/25] Batch 0/79                   Loss D: 1.4082, Loss G: 3.0292


100%|██████████| 79/79 [02:23<00:00,  1.81s/it]
  1%|▏         | 1/79 [00:15<20:12, 15.55s/it]

Epoch [2/25] Batch 0/79                   Loss D: 0.4443, Loss G: 10.5247


100%|██████████| 79/79 [02:07<00:00,  1.61s/it]
  1%|▏         | 1/79 [00:13<17:19, 13.33s/it]

Epoch [3/25] Batch 0/79                   Loss D: 0.5045, Loss G: 6.5007


100%|██████████| 79/79 [02:04<00:00,  1.57s/it]
  1%|▏         | 1/79 [00:13<17:14, 13.26s/it]

Epoch [4/25] Batch 0/79                   Loss D: 0.3513, Loss G: 8.0415


100%|██████████| 79/79 [02:03<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:12<16:52, 12.99s/it]

Epoch [5/25] Batch 0/79                   Loss D: 0.3500, Loss G: 8.7940


100%|██████████| 79/79 [02:03<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:13<17:08, 13.19s/it]

Epoch [6/25] Batch 0/79                   Loss D: 0.5031, Loss G: 8.4168


100%|██████████| 79/79 [02:03<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:13<16:56, 13.04s/it]

Epoch [7/25] Batch 0/79                   Loss D: 0.4539, Loss G: 10.2448


100%|██████████| 79/79 [02:02<00:00,  1.55s/it]
  1%|▏         | 1/79 [00:13<16:55, 13.02s/it]

Epoch [8/25] Batch 0/79                   Loss D: 0.3625, Loss G: 10.6832


100%|██████████| 79/79 [02:03<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:13<16:58, 13.06s/it]

Epoch [9/25] Batch 0/79                   Loss D: 0.3394, Loss G: 13.9798


100%|██████████| 79/79 [02:03<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:13<17:27, 13.43s/it]

Epoch [10/25] Batch 0/79                   Loss D: 0.3352, Loss G: 14.3523


100%|██████████| 79/79 [02:02<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:13<17:31, 13.48s/it]

Epoch [11/25] Batch 0/79                   Loss D: 0.3347, Loss G: 16.8262


100%|██████████| 79/79 [02:03<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:13<17:12, 13.24s/it]

Epoch [12/25] Batch 0/79                   Loss D: 0.3288, Loss G: 17.4582


100%|██████████| 79/79 [02:04<00:00,  1.57s/it]
  1%|▏         | 1/79 [00:13<17:02, 13.10s/it]

Epoch [13/25] Batch 0/79                   Loss D: 0.3337, Loss G: 17.4468


100%|██████████| 79/79 [02:04<00:00,  1.57s/it]
  1%|▏         | 1/79 [00:13<17:02, 13.11s/it]

Epoch [14/25] Batch 0/79                   Loss D: 0.3365, Loss G: 19.9715


100%|██████████| 79/79 [02:04<00:00,  1.57s/it]
  1%|▏         | 1/79 [00:13<17:09, 13.20s/it]

Epoch [15/25] Batch 0/79                   Loss D: 0.3295, Loss G: 18.8562


100%|██████████| 79/79 [02:03<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:13<17:05, 13.15s/it]

Epoch [16/25] Batch 0/79                   Loss D: 0.3302, Loss G: 19.7295


100%|██████████| 79/79 [02:02<00:00,  1.55s/it]
  1%|▏         | 1/79 [00:13<17:15, 13.28s/it]

Epoch [17/25] Batch 0/79                   Loss D: 0.3304, Loss G: 14.9642


100%|██████████| 79/79 [02:03<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:13<17:31, 13.49s/it]

Epoch [18/25] Batch 0/79                   Loss D: 0.3274, Loss G: 18.3146


100%|██████████| 79/79 [02:04<00:00,  1.58s/it]
  1%|▏         | 1/79 [00:13<17:15, 13.27s/it]

Epoch [19/25] Batch 0/79                   Loss D: 0.3277, Loss G: 20.0148


100%|██████████| 79/79 [02:02<00:00,  1.55s/it]
  1%|▏         | 1/79 [00:13<17:01, 13.10s/it]

Epoch [20/25] Batch 0/79                   Loss D: 0.3336, Loss G: 19.0958


100%|██████████| 79/79 [02:03<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:13<17:08, 13.18s/it]

Epoch [21/25] Batch 0/79                   Loss D: 0.3267, Loss G: 22.4314


100%|██████████| 79/79 [02:03<00:00,  1.57s/it]
  1%|▏         | 1/79 [00:13<17:22, 13.36s/it]

Epoch [22/25] Batch 0/79                   Loss D: 0.3261, Loss G: 19.9452


100%|██████████| 79/79 [02:06<00:00,  1.60s/it]
  1%|▏         | 1/79 [00:13<17:09, 13.20s/it]

Epoch [23/25] Batch 0/79                   Loss D: 0.3269, Loss G: 19.6826


100%|██████████| 79/79 [02:03<00:00,  1.56s/it]
  1%|▏         | 1/79 [00:13<17:12, 13.24s/it]

Epoch [24/25] Batch 0/79                   Loss D: 0.3262, Loss G: 20.0438


100%|██████████| 79/79 [02:04<00:00,  1.58s/it]
  1%|▏         | 1/79 [00:13<17:09, 13.20s/it]

Epoch [25/25] Batch 0/79                   Loss D: 0.3275, Loss G: 20.6584


100%|██████████| 79/79 [02:02<00:00,  1.55s/it]


In [17]:
#Save the models
torch.save(netG.state_dict(),'C:/Users/HP/Downloads/netG.pth')
torch.save(netD.state_dict(),'C:/Users/HP/Downloads/netD.pth')

print("Training complete and models saved.")

Training complete and models saved.


In [19]:

with torch.no_grad():
    netG.eval()
    fake=netG(fixed_noise).detach().cpu()
    save_image(fake,f'C:/Users/HP/Downloads/output_images/testoutput.png',normalize=True)