In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
# optimizers
from torch.optim import Adam
import os
from torchvision.utils import save_image

In [2]:
class Args:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __repr__(self):
        return str(self.__dict__)

args = Args(
    n_epochs=200,
    batch_size=64,
    lr=0.0002,
    b1=0.5,
    b2=0.999,
    latent_dim=100,
    img_size=28,
    channels=1, 
)

img_shape = (args.channels, args.img_size, args.img_size)

print(img_shape)

def get_device():
    # m1 chip
    return torch.device("mps")

device = get_device()
print(device)


(1, 28, 28)
mps


In [3]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(
            *block(args.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self,z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

    


In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def block(in_feat, out_feat):
            layers = [nn.Linear(in_feat, out_feat)]
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(int(np.prod(img_shape)), 512),
            *block(512, 256),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0),-1)
        validity = self.model(img_flat)
        return validity
    

In [5]:
adversarial_loss = nn.BCELoss().to(device)

generator = Generator().to(device)
discriminator = Discriminator().to(device)

optim_G = Adam(generator.parameters(), lr=args.lr, betas=(args.b1, args.b2))
optim_D = Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2))


In [6]:
def get_data_loader():
    transform = transforms.Compose([transforms.Resize(args.img_size), transforms.ToTensor()])
    dataset = datasets.MNIST("../../data/mnist", train=True, download=True, transform=transform)
    return DataLoader(
        dataset, 
        batch_size=args.batch_size, 
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )

data_loader = get_data_loader()


In [7]:
try:  
    os.makedirs("../../images", exist_ok=True)
    
    Tensor = torch.FloatTensor
    for epoch in range(args.n_epochs):
        total_g_loss = 0
        total_d_loss = 0
        batch_count = 0
        for i, (imgs, _) in enumerate(data_loader):
            valid = Tensor(imgs.size(0), 1).fill_(1.0).to(device)
            fake = Tensor(imgs.size(0), 1).fill_(0.0).to(device)

            real_imgs = imgs.to(device)

            optim_G.zero_grad()

            z = torch.randn(
                imgs.size(0),
                args.latent_dim,
                ).to(device)
            
            gen_imgs = generator(z)

            d_gen_imgs = discriminator(gen_imgs)
            # print(valid.shape)
            # print(d_gen_imgs.shape)


            g_loss = adversarial_loss(d_gen_imgs, valid)

            g_loss.backward()
            optim_G.step()

            optim_D.zero_grad()
            real_loss = adversarial_loss(
                discriminator(real_imgs),
                valid
            )
            fake_loss = adversarial_loss(
                discriminator(gen_imgs.detach()),
                fake
            )

            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optim_D.step()

            total_g_loss += g_loss.item()
            total_d_loss += d_loss.item()
            batch_count += 1        
        
        save_image(gen_imgs.data, f"../../images/{epoch+1}.png")
        avg_g_loss = total_g_loss / len(data_loader)
        avg_d_loss = total_d_loss / len(data_loader)
        print(f"[Epoch {epoch+1}/{args.n_epochs}] [Avg D loss: {avg_d_loss:.4f}] [Avg G loss: {avg_g_loss:.4f}]")

except Exception as e:
    print("Error in training")
    print("z", z.shape)
    print("d_gen_imgs", d_gen_imgs.shape)
    print("valid", valid.shape)
    print("fake", fake.shape)
    raise e



[Epoch 1/200] [Avg D loss: 0.5453] [Avg G loss: 0.9380]
[Epoch 2/200] [Avg D loss: 0.4652] [Avg G loss: 1.2251]
[Epoch 3/200] [Avg D loss: 0.4145] [Avg G loss: 1.4570]
[Epoch 4/200] [Avg D loss: 0.3982] [Avg G loss: 1.5423]
[Epoch 5/200] [Avg D loss: 0.3748] [Avg G loss: 1.6742]
[Epoch 6/200] [Avg D loss: 0.3599] [Avg G loss: 1.7638]
[Epoch 7/200] [Avg D loss: 0.3402] [Avg G loss: 1.8677]
[Epoch 8/200] [Avg D loss: 0.3238] [Avg G loss: 1.9632]
[Epoch 9/200] [Avg D loss: 0.3050] [Avg G loss: 2.0596]
[Epoch 10/200] [Avg D loss: 0.2819] [Avg G loss: 2.2200]
[Epoch 11/200] [Avg D loss: 0.2707] [Avg G loss: 2.3404]
[Epoch 12/200] [Avg D loss: 0.2665] [Avg G loss: 2.3764]
[Epoch 13/200] [Avg D loss: 0.2523] [Avg G loss: 2.4558]
[Epoch 14/200] [Avg D loss: 0.2424] [Avg G loss: 2.5214]
[Epoch 15/200] [Avg D loss: 0.2375] [Avg G loss: 2.6359]
[Epoch 16/200] [Avg D loss: 0.2384] [Avg G loss: 2.6574]
[Epoch 17/200] [Avg D loss: 0.2242] [Avg G loss: 2.7053]
[Epoch 18/200] [Avg D loss: 0.2149] [Avg