In [40]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator,Generator,initialize_weights

In [41]:
# Pick the GPU when one is usable. `is_available` must be *called*: the bare
# function object is always truthy, which would select "cuda" even on CPU-only
# machines and make every later `.to(device)` fail.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [42]:
# Seed torch's global RNG (all devices) so noise sampling, DataLoader shuffling
# and any torch-based weight init are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

LEARNING_RATE = 2e-4   # Adam step size, shared by generator and discriminator
BATCH_SIZE = 128
IMAGE_SIZE = 64        # MNIST digits are resized to this size before training
CHANNELS_IMG = 1       # MNIST images have a single (grayscale) channel
Z_DIMS = 100           # length of the latent noise vector fed to the generator
NUM_EPOCHS = 10
FEATURES_DISC = 64     # feature width passed to Discriminator (see model.py)
FEATURES_GEN = 64      # feature width passed to Generator (see model.py)


In [43]:
transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
        [0.5 for _ in range(CHANNELS_IMG)],[0.5 for _ in range(CHANNELS_IMG)])
    ]
)

In [44]:
dataset = datasets.MNIST(root="dataset/",train = True,transform = transforms , download = True)

In [45]:
# Shuffled mini-batches of MNIST plus the two networks, placed on `device`.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(Z_DIMS, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)

# Apply the project's weight-init scheme from model.py; order kept as
# discriminator first, then generator.
for net in (disc, gen):
    initialize_weights(net)

In [46]:
# Both networks use Adam with beta1 = 0.5 (the DCGAN paper's setting).
ADAM_BETAS = (0.5, 0.999)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

# BCELoss expects probabilities in [0, 1]; assumes model.Discriminator ends
# in a sigmoid -- TODO confirm against model.py.
criterion = nn.BCELoss()

In [47]:
# 32 latent vectors kept constant for the whole run, so the logged "fake"
# grid shows how the generator's output for the same inputs evolves.
fixed_noise = torch.randn(32, Z_DIMS, 1, 1).to(device)

# One writer per log directory, so real and generated grids show up as
# separate runs in TensorBoard.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")

In [48]:
# TensorBoard global step; the training cell bumps it once per logged batch.
step = 0

In [49]:
gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        # Size the noise from the actual batch: the final batch can be smaller
        # than BATCH_SIZE (469 batches over the train split, the last holding
        # 96 images). Using BATCH_SIZE there would generate extra fakes and
        # unbalance the real/fake halves of the discriminator loss.
        cur_batch_size = real.shape[0]
        noise = torch.randn(cur_batch_size, Z_DIMS, 1, 1).to(device)
        fake = gen(noise)

        # --- Discriminator step: BCE with real -> 1, fake -> 0. ---
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() stops this loss from back-propagating into the generator.
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        # Also clears gradients the previous generator step left on disc params.
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # --- Generator step: non-saturating loss, label fakes as real (1). ---
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx % 100 == 0:
            # Implicit string concatenation: a backslash continuation inside the
            # literal would embed the next line's indentation in the message.
            print(
                f"EPOCH[{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D : {loss_disc.item():.4f}  Loss G : {loss_gen.item():.4f}"
            )
            with torch.no_grad():
                # Samples from the fixed latent vectors, for a comparable grid
                # across steps.
                fake = gen(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("fake", img_grid_fake, global_step=step)

            step += 1

# Write pending TensorBoard events to disk. Flush rather than close, so the
# writers stay usable if this cell is re-run.
writer_real.flush()
writer_fake.flush()
        

EPOCH[0/10] Batch 0/469                  Loss D : 0.6996  Loss G : 0.7675
EPOCH[0/10] Batch 100/469                  Loss D : 0.0141  Loss G : 4.2005
EPOCH[0/10] Batch 200/469                  Loss D : 0.5644  Loss G : 1.0788
EPOCH[0/10] Batch 300/469                  Loss D : 0.4615  Loss G : 0.7331
EPOCH[0/10] Batch 400/469                  Loss D : 0.5493  Loss G : 1.3310
EPOCH[1/10] Batch 0/469                  Loss D : 0.5570  Loss G : 0.9117
EPOCH[1/10] Batch 100/469                  Loss D : 0.9003  Loss G : 0.4713
EPOCH[1/10] Batch 200/469                  Loss D : 0.6087  Loss G : 1.2770
EPOCH[1/10] Batch 300/469                  Loss D : 0.5750  Loss G : 0.4402
EPOCH[1/10] Batch 400/469                  Loss D : 0.5538  Loss G : 0.7252
EPOCH[2/10] Batch 0/469                  Loss D : 0.5300  Loss G : 0.8303
EPOCH[2/10] Batch 100/469                  Loss D : 0.5705  Loss G : 1.0856
EPOCH[2/10] Batch 200/469                  Loss D : 0.4916  Loss G : 1.4102
EPOCH[2/10] Batch 