In [None]:
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from tqdm import tqdm

from CGAN import Discriminator, Generator

In [None]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Optimisation
LEARNING_RATE = 0.0002
BATCH_SIZE = 64
EPOCHS = 100
BETAS = (0.5, 0.999)  # Adam (beta1, beta2)

# Data / model shapes
NUM_CLASSES = 10   # number of conditioning labels (MNIST digits)
IMG_SIZE = 32      # the dataloader resizes MNIST images to this size
IMG_CHANNELS = 1
IMG_SHAPE = (IMG_CHANNELS, IMG_SIZE, IMG_SIZE)
Z_DIM = 100        # latent noise dimension fed to the generator
EMBED_SIZE = 100   # NOTE(review): not referenced anywhere in this notebook — remove or pass to the models

# Seed torch's RNGs (all devices) so weight init, latent noise and DataLoader
# shuffling repeat across runs. Bit-exact CUDA runs may additionally need
# torch.use_deterministic_algorithms(True).
SEED = 42
torch.manual_seed(SEED)


In [None]:
# MNIST training split: resize to IMG_SIZE, convert to a tensor in [0, 1],
# then Normalize(0.5, 0.5) per channel, which maps pixels into [-1, 1].
mnist_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * IMG_CHANNELS, [0.5] * IMG_CHANNELS),
])

mnist_train = datasets.MNIST(
    root='../datasets',
    train=True,
    download=True,
    transform=mnist_transform,
)

dataloader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Instantiate the conditional generator and discriminator and move both to DEVICE.
# Construction order (generator first) is kept so parameter init consumes the RNG identically.
generator, discriminator = (
    Generator(NUM_CLASSES, Z_DIM, IMG_SHAPE).to(DEVICE),
    Discriminator(NUM_CLASSES, IMG_SHAPE).to(DEVICE),
)

In [None]:
# One Adam optimiser per network, sharing the same hyper-parameters.
adam_kwargs = dict(lr=LEARNING_RATE, betas=BETAS)
optimizer_g = torch.optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_d = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)

In [None]:
# Least-squares adversarial objective: the discriminator's score is regressed
# with mean-squared error onto 1 ("valid") or 0 ("fake") targets.
adversarial_loss = nn.MSELoss(reduction='mean')
adversarial_loss = adversarial_loss.to(DEVICE)

In [None]:
# Train the conditional GAN: one generator step, then one discriminator step per batch.
# NOTE(review): targets are shaped (batch, 1); assumes Discriminator returns
# (batch, 1) scores — TODO confirm against CGAN.Discriminator.

# A sampling cell may have switched the generator to eval mode or moved it to
# another device; put both networks back in training mode on DEVICE first.
generator.to(DEVICE).train()
discriminator.to(DEVICE).train()

for epoch in range(EPOCHS):
    with tqdm(dataloader, unit='batch') as progress:
        progress.set_description(f'Epoch {epoch}')
        for imgs, labels in progress:
            real_img = imgs.to(DEVICE)
            real_label = labels.to(DEVICE)
            batch_size = real_img.size(0)  # the last batch can be smaller than BATCH_SIZE

            # Adversarial targets: 1 for real images, 0 for generated ones.
            valid = torch.ones(batch_size, 1, device=DEVICE)
            fake = torch.zeros(batch_size, 1, device=DEVICE)

            # train generator: make the discriminator score conditioned fakes as real
            optimizer_g.zero_grad()
            # Standard-normal latent noise, the same distribution the sampling cell
            # draws from (torch.randn); previously torch.rand (uniform [0, 1)) was used here.
            z = torch.randn(batch_size, Z_DIM, device=DEVICE)
            fake_label = torch.randint(0, NUM_CLASSES, (batch_size,), device=DEVICE)
            gen_img = generator(z, fake_label)
            g_loss = adversarial_loss(discriminator(gen_img, fake_label), valid)
            g_loss.backward()
            optimizer_g.step()

            # train discriminator (zero_grad also clears grads left by g_loss.backward())
            optimizer_d.zero_grad()
            d_real_loss = adversarial_loss(discriminator(real_img, real_label), valid)
            # detach() so the discriminator update does not backprop into the generator
            d_fake_loss = adversarial_loss(discriminator(gen_img.detach(), fake_label), fake)
            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            optimizer_d.step()

            progress.set_postfix(D_loss=d_loss.item(), G_loss=g_loss.item())
            

In [None]:
# Draw one sample of TARGET_DIGIT from the trained generator and display it.
TARGET_DIGIT = 9  # conditioning label; must be in range(NUM_CLASSES)
assert 0 <= TARGET_DIGIT < NUM_CLASSES, f'TARGET_DIGIT must be in [0, {NUM_CLASSES})'

# Sample on DEVICE without autograd. The generator stays on its device, and its
# previous train/eval mode is restored, so re-running the training cell still works.
was_training = generator.training
generator.eval()
with torch.no_grad():
    z_sample = torch.randn(1, Z_DIM, device=DEVICE)
    label_sample = torch.tensor([TARGET_DIGIT], dtype=torch.int64, device=DEVICE)
    sample_img = generator(z_sample, label_sample)
generator.train(was_training)

# NOTE(review): assumes the generator returns (1, IMG_CHANNELS, H, W); squeeze()
# drops the singleton dims to give an (H, W) array for imshow when IMG_CHANNELS == 1.
sample_np = sample_img.cpu().numpy().squeeze()

fig, ax = plt.subplots()
ax.imshow(sample_np, cmap='gray')
ax.set_title(f'Generated digit (label={TARGET_DIGIT})')
ax.axis('off');

In [None]:
# generator.train()
# generator.to(DEVICE)