In [1]:
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torch.utils.data import DataLoader
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import random

In [2]:
seed = 47
random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fb51c0c8b90>

In [3]:
device = torch.device('cpu') if not torch.cuda.is_available() else torch.device("cuda")

In [4]:
fmnist = FashionMNIST(root="./data", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))

In [5]:
print(*fmnist.classes)

T-shirt/top Trouser Pullover Dress Coat Sandal Shirt Sneaker Bag Ankle boot


In [6]:
ix = (fmnist.targets == fmnist.class_to_idx["Sneaker"]) | (fmnist.targets == fmnist.class_to_idx["Shirt"])

In [7]:
fmnist.data = fmnist.data[ix]
fmnist.targets = fmnist.targets[ix]

In [8]:
fmnist.data.shape, fmnist.targets

(torch.Size([12000, 28, 28]), tensor([7, 7, 6,  ..., 6, 6, 7]))

In [9]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d( 100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d( 256, 128, 4, 2, 2, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d( 128, 1, 4, 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, X):
        return self.network(X)
    

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.network = nn.Sequential(
            nn.Conv2d(1, 128, 5),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 3, 2, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 512, 3, 2, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Flatten(),
            nn.Linear(512*5*5, 1),
            nn.Sigmoid()
        )

    def forward(self, X):
        return self.network(X)

In [10]:
criterion = nn.BCELoss()
fixed_noise = torch.randn(64, 100, 1, 1, device=device)

gen = Generator()
gen.to(device)
dis = Discriminator()
dis.to(device)


optimizerD = torch.optim.Adam(dis.parameters())
optimizerG = torch.optim.Adam(gen.parameters())

writer = SummaryWriter("./runs/exp2")

dataloader = DataLoader(fmnist, batch_size=64, shuffle=True)
n_epochs = 20
k = 3

In [11]:
for epoch in range(n_epochs):
    print(f"Epoch: [{epoch+1}/{n_epochs}]")
    for i, (x, y) in enumerate(dataloader, 1):
        
        optimizerD.zero_grad()
        
        real_x = x.to(device)
        
        b_size = real_x.size(0)
        label = torch.full((b_size,1), 1, dtype=torch.float, device=device)
        
        output = dis(real_x)
        err_real_D = criterion(output, label)
        err_real_D.backward()
        
        noise = torch.randn(b_size, 100, 1, 1, device=device)
        fake_x = gen(noise)
        label.fill_(0)
        output = dis(fake_x.detach())
        err_fake_D = criterion(output, label)
        err_fake_D.backward()
        
        err_D = err_real_D + err_fake_D
        optimizerD.step()
        
        optimizerG.zero_grad()
        
        label.fill_(1)
        for j in range(k):
            noise = torch.randn(b_size, 100, 1, 1, device=device)
            fake_x = gen(noise)
            output = dis(fake_x)
            err_G = criterion(output, label)
            err_G.backward()

            optimizerG.step()
        
        if i%10 == 0:
            print("\rStep: [%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f"%(i, len(dataloader), err_D.item(), err_G.item()), end="")
            writer.add_scalar('Loss/Discriminator', err_D.item(), epoch*len(dataloader)+i)
            writer.add_scalar('Loss/Generator', err_G.item(), epoch*len(dataloader)+i)
            
            if i%30 == 0:
                
                with torch.no_grad():
                    fake_imgs = gen(fixed_noise).detach().cpu()
                
                grid = torchvision.utils.make_grid(fake_imgs)
                writer.add_image('Generated Image', grid, epoch*len(dataloader)+i)
                
                if epoch == 0:
                    grid = torchvision.utils.make_grid(real_x)
                    writer.add_image('Real Image', grid, epoch*len(dataloader)+i)

    print()
    
    torch.save(gen.state_dict(), "./generator.pth")
    torch.save(dis.state_dict(), "./discriminator.pth")
    writer.close()

Epoch: [1/20]
Step: [180/188]	Loss_D: 0.0007	Loss_G: 8.0011
Epoch: [2/20]
Step: [180/188]	Loss_D: 0.0206	Loss_G: 7.05365
Epoch: [3/20]
Step: [180/188]	Loss_D: 0.5378	Loss_G: 7.9691
Epoch: [4/20]
Step: [180/188]	Loss_D: 0.7044	Loss_G: 5.3533
Epoch: [5/20]
Step: [180/188]	Loss_D: 1.2569	Loss_G: 3.8079
Epoch: [6/20]
Step: [180/188]	Loss_D: 0.6442	Loss_G: 5.3688
Epoch: [7/20]
Step: [180/188]	Loss_D: 2.2146	Loss_G: 2.3744
Epoch: [8/20]
Step: [180/188]	Loss_D: 0.3025	Loss_G: 3.3341
Epoch: [9/20]
Step: [180/188]	Loss_D: 0.6280	Loss_G: 3.6359
Epoch: [10/20]
Step: [180/188]	Loss_D: 1.1844	Loss_G: 3.8985
Epoch: [11/20]
Step: [180/188]	Loss_D: 1.7398	Loss_G: 2.7523
Epoch: [12/20]
Step: [180/188]	Loss_D: 0.7605	Loss_G: 3.3881
Epoch: [13/20]
Step: [180/188]	Loss_D: 1.7704	Loss_G: 0.9365
Epoch: [14/20]
Step: [180/188]	Loss_D: 0.7932	Loss_G: 2.5801
Epoch: [15/20]
Step: [180/188]	Loss_D: 0.5246	Loss_G: 2.8098
Epoch: [16/20]
Step: [180/188]	Loss_D: 1.1541	Loss_G: 4.2631
Epoch: [17/20]
Step: [180/188]	L