<a href="https://colab.research.google.com/github/hepham/GANS/blob/main/FASHION_MNIST_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import pickle as pkl
import time
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Mount Google Drive (Colab-only module) at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Mini-batch size used by both data loaders.
mb_size = 100

# Fashion-MNIST preprocessing: convert each image to a tensor, then normalise
# its single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Only the training split asks for a download; the test split reads from the
# same root directory.
train_dataset = datasets.FashionMNIST(
    root='./fashion_mnist_data/',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.FashionMNIST(
    root='./fashion_mnist_data/',
    train=False,
    transform=transform,
    download=False,
)

# Training batches are shuffled; test batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=mb_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=mb_size,
    shuffle=False,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./fashion_mnist_data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./fashion_mnist_data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./fashion_mnist_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./fashion_mnist_data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./fashion_mnist_data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./fashion_mnist_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./fashion_mnist_data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./fashion_mnist_data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./fashion_mnist_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./fashion_mnist_data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./fashion_mnist_data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./fashion_mnist_data/FashionMNIST/raw



In [6]:
class Generator(nn.Module):
    """Fully connected generator: latent vector in, flattened image out.

    The hidden widths double at each step (256 -> 512 -> 1024) and the output
    layer is squashed with tanh.
    """

    def __init__(self, g_input_dim, g_output_dim):
        """Create the four linear layers.

        Args:
            g_input_dim: Size of the latent input vector.
            g_output_dim: Number of values in each generated sample.
        """
        super().__init__()
        base_width = 256
        # Layers are created in fc1..fc4 order so parameter names and the
        # order of random initialisation stay the same.
        self.fc1 = nn.Linear(g_input_dim, base_width)
        self.fc2 = nn.Linear(base_width, base_width * 2)
        self.fc3 = nn.Linear(base_width * 2, base_width * 4)
        self.fc4 = nn.Linear(base_width * 4, g_output_dim)

    def forward(self, x):
        """Map a batch of latent vectors to tanh-activated outputs."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2)
        return torch.tanh(self.fc4(hidden))
    
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened samples in (0, 1).

    Three LeakyReLU hidden layers (1024 -> 512 -> 256), each followed by
    dropout with p=0.3, feed a single sigmoid output unit.
    """

    def __init__(self, d_input_dim):
        """Create the four linear layers.

        Args:
            d_input_dim: Number of features in each flattened input sample.
        """
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        """Return the sigmoid score for each row of ``x``.

        Dropout is tied to ``self.training``: ``F.dropout`` defaults to
        ``training=True``, so without the explicit flag units would still be
        dropped after ``eval()`` and inference would be non-deterministic.
        """
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [7]:
# build network
z_dim = 100  # size of the latent noise vector fed to the generator
# Flattened image size (dim 1 * dim 2 of the raw image tensor). `.data`
# replaces the deprecated `train_data` alias.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)




In [8]:
# Display the generator's layer layout.
G


Generator(
  (fc1): Linear(in_features=100, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)

In [7]:
# Display the discriminator's layer layout.
D

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [9]:
# Binary cross-entropy, applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Each network gets its own Adam optimizer; both share one learning rate.
lr = 0.0002
G_optimizer = optim.Adam(params=G.parameters(), lr=lr)
D_optimizer = optim.Adam(params=D.parameters(), lr=lr)


In [10]:
def D_train(x):
    """Run one discriminator update on a batch of real images.

    Args:
        x: Batch of real images; it is flattened to rows of ``mnist_dim``
            values before being scored.

    Returns:
        float: Sum of the BCE loss on the real batch and on an equally sized
        batch of generated samples.
    """
    #=======================Train the discriminator=======================#
    D.zero_grad()

    # train discriminator on real
    x_real = x.view(-1, mnist_dim).to(device)
    # Size labels and noise from the actual batch so a final partial batch
    # does not cause a shape mismatch in the loss.
    batch_size = x_real.size(0)
    y_real = torch.ones(batch_size, 1, device=device)

    D_real_loss = criterion(D(x_real), y_real)

    # train discriminator on fake
    z = torch.randn(batch_size, z_dim, device=device)
    # Only D is optimised here, so the generator graph is not needed.
    with torch.no_grad():
        x_fake = G(z)
    y_fake = torch.zeros(batch_size, 1, device=device)

    D_fake_loss = criterion(D(x_fake), y_fake)

    # gradient backprop & optimize ONLY the discriminator
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [11]:
def G_train(x):
    """Run one generator update.

    Args:
        x: Unused; accepted so the call matches ``D_train``. The batch size
            comes from the global ``mb_size``.

    Returns:
        float: BCE loss of the discriminator's scores on generated samples
        against all-ones ("real") targets.
    """
    #=======================Train the generator=======================#
    G.zero_grad()

    # Plain tensors replace the deprecated Variable wrapper; noise and targets
    # are created directly on the target device.
    z = torch.randn(mb_size, z_dim, device=device)
    y = torch.ones(mb_size, 1, device=device)

    G_output = G(z)
    D_output = D(G_output)
    G_loss = criterion(D_output, y)

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [None]:
# Restore previously saved weights. map_location remaps the stored tensors
# onto the active device, so a checkpoint written on a GPU runtime also loads
# on a CPU-only one.
# NOTE(review): the training loop saves under .../model1/ while this cell
# reads from .../model/ -- confirm which directory is intended.
G.load_state_dict(torch.load('/content/drive/MyDrive/model/Generator.pt', map_location=device))
D.load_state_dict(torch.load('/content/drive/MyDrive/model/Dicriminator.pt', map_location=device))

<All keys matched successfully>

In [None]:
import os

n_epoch = 60
losses = []  # one (mean D loss, mean G loss) pair per epoch
# Smallest |G loss - D loss| seen so far; named so it does not shadow the
# builtin `min`.
best_gap = 10
checkpoint_dir = '/content/drive/MyDrive/model1'
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

start_time = time.time()
for epoch in range(0, n_epoch):
    D_losses, G_losses = [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    # Epoch means are computed once and reused for logging, checkpointing
    # and the loss history.
    dloss = torch.mean(torch.FloatTensor(D_losses)).item()
    gloss = torch.mean(torch.FloatTensor(G_losses)).item()
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (epoch, n_epoch, dloss, gloss))

    # Checkpoint whenever the two mean losses are closer than ever before.
    gap = abs(gloss - dloss)
    if gap < best_gap:
        best_gap = gap
        torch.save(G.state_dict(), os.path.join(checkpoint_dir, 'Generator.pt'))
        torch.save(D.state_dict(), os.path.join(checkpoint_dir, 'Dicriminator.pt'))
    losses.append((dloss, gloss))

# Taken after the loop so it is defined even when no epoch runs.
end_time = time.time()
print('Training process done! Time used: {} mins.'.format((end_time - start_time)/60))

[0/60]: loss_d: 0.793, loss_g: 2.438
[1/60]: loss_d: 0.401, loss_g: 4.653
[2/60]: loss_d: 0.582, loss_g: 3.053
[3/60]: loss_d: 0.462, loss_g: 3.219
[4/60]: loss_d: 0.489, loss_g: 3.052
[5/60]: loss_d: 0.535, loss_g: 2.752
[6/60]: loss_d: 0.710, loss_g: 2.213
[7/60]: loss_d: 0.678, loss_g: 2.257
[8/60]: loss_d: 0.758, loss_g: 2.052
[9/60]: loss_d: 0.791, loss_g: 1.888
[10/60]: loss_d: 0.829, loss_g: 1.829
[11/60]: loss_d: 0.814, loss_g: 1.833
[12/60]: loss_d: 0.820, loss_g: 1.869
[13/60]: loss_d: 0.840, loss_g: 1.846
[14/60]: loss_d: 0.861, loss_g: 1.710
[15/60]: loss_d: 0.860, loss_g: 1.696
[16/60]: loss_d: 0.851, loss_g: 1.721
[17/60]: loss_d: 0.895, loss_g: 1.634
[18/60]: loss_d: 0.883, loss_g: 1.597
[19/60]: loss_d: 0.880, loss_g: 1.615
[20/60]: loss_d: 0.883, loss_g: 1.634
[21/60]: loss_d: 0.910, loss_g: 1.531
[22/60]: loss_d: 0.946, loss_g: 1.419
[23/60]: loss_d: 0.905, loss_g: 1.551
[24/60]: loss_d: 0.957, loss_g: 1.451
[25/60]: loss_d: 0.952, loss_g: 1.456
[26/60]: loss_d: 0.993

In [None]:
def view_samples(epoch, samples):
    """Show images from ``samples[epoch]`` on a 4x4 grid of axes.

    Args:
        epoch: Index into ``samples`` selecting which batch to draw.
        samples: Indexable collection of image batches; each image is
            reshaped to 28x28 for display. At most 16 images are drawn
            because they are zipped with the 16 axes.
    """
    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        # matplotlib needs host memory: detach from autograd and copy to the
        # CPU so tensors living on a CUDA device can be drawn as well.
        img = img.detach().cpu()
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')

In [None]:
G.eval() # eval mode, set before any sampling
with torch.no_grad():
    # Full batch of generated images, saved below as an image grid.
    test_z = torch.randn(mb_size, z_dim, device=device)
    generated = G(test_z)

    # Smaller batch from uniform noise in [-1, 1) for inline display.
    sample_size=16
    rand_z = np.random.uniform(-1, 1, size=(sample_size, z_dim))
    # Move the noise to the generator's device; feeding a CPU tensor to a
    # model on CUDA raises a device-mismatch error.
    rand_z = torch.from_numpy(rand_z).float().to(device)

    # generated samples
    rand_images = G(rand_z)
    # Hand matplotlib a CPU copy so plotting also works on a GPU runtime.
    view_samples(0, [rand_images.cpu()])
    save_image(generated.view(generated.size(0), 1, 28, 28), '/content/drive/MyDrive/model/sample_' + '.png')


In [None]:
# Each row of `losses` is one (discriminator, generator) pair of mean losses,
# appended once per epoch by the training loop.
losses = np.array(losses)

# Draw on the explicit Axes rather than the implicit pyplot state, and label
# both axes so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(losses.T[0], label='Discriminator')
ax.plot(losses.T[1], label='Generator')
ax.set_title("Training Losses")
ax.set_xlabel("Epoch")
ax.set_ylabel("Mean loss")
ax.legend();