In [2]:
import math

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [3]:
def set_seed():
    seed = 42
    torch.manual_seed(seed)

set_seed()

In [4]:
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

X_test = torch.Tensor( testset.data ) / 255.0# - 0.5
y_test = torch.Tensor( testset.targets ).long()
X_train = torch.Tensor( trainset.data ) / 255.0# - 0.5
y_train = torch.Tensor( trainset.targets ).long()

# train_dataset = TensorDataset(X_train, y_train)

train_data = DataLoader(trainset, batch_size=256, shuffle=True, drop_last=True)
test_data = DataLoader(testset, batch_size=256, shuffle=True, drop_last=True)

In [5]:
n_expansion = 16
n_latent = 16
n_channels=1

In [6]:
class SinePositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super(SinePositionalEncoding, self).__init__()
        self.d_model = d_model
        
    def forward(self, digit):
        # Compute the positional encodings for the given digit
        position = torch.arange(0, 10).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2) * -(math.log(10000.0) / self.d_model))
        pe = torch.zeros(10, self.d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        # Get the positional encoding for the given digit
        pe_digit = pe[:, digit, :].unsqueeze(0)
        return pe_digit

temp = SinePositionalEncoding(5)

temp(None)


RuntimeError: The expanded size of the tensor (2) must match the existing size (3) at non-singleton dimension 1.  Target sizes: [10, 2].  Tensor sizes: [10, 3]

In [10]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.layers = nn.Sequential(
            nn.ConvTranspose2d(in_channels=n_latent, out_channels=n_expansion*8, kernel_size=4, stride=1, bias=True),
            nn.BatchNorm2d(n_expansion*8),
            nn.ReLU(),

            nn.ConvTranspose2d(in_channels=n_expansion*8, out_channels=n_expansion*4, kernel_size=4, stride=2, bias=True),
            nn.BatchNorm2d(n_expansion*4),
            nn.ReLU(),


            nn.ConvTranspose2d(in_channels=n_expansion*4, out_channels=n_expansion*2, kernel_size=4, stride=2, bias=True),
            nn.BatchNorm2d(n_expansion*2),
            nn.ReLU(),


            nn.ConvTranspose2d(in_channels=n_expansion*2, out_channels=n_expansion, kernel_size=4, stride=1, bias=True),
            nn.BatchNorm2d(n_expansion),
            nn.ReLU(),


            nn.ConvTranspose2d(in_channels=n_expansion, out_channels=1, kernel_size=4, stride=1, bias=True),
            nn.Tanh(),
        )

    def forward(self, X):
        return self.layers(X)
    
def generate_samples(batch_size=256):
    return torch.randn((batch_size, n_latent, 1, 1))

In [76]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=n_channels, out_channels=n_expansion, kernel_size=4, stride=1, padding=0, bias=True),

            nn.ReLU(),
            nn.Conv2d(in_channels=n_expansion, out_channels=n_expansion*2, kernel_size=4, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(n_expansion*2),
            nn.ReLU(),

            nn.Conv2d(in_channels=n_expansion*2, out_channels=n_expansion*4, kernel_size=4, stride=2, padding=0, bias=True),
            nn.BatchNorm2d(n_expansion*4),
            nn.ReLU(),

            nn.Conv2d(in_channels=n_expansion*4, out_channels=n_expansion*8, kernel_size=4, stride=2, padding=0, bias=True),
            nn.BatchNorm2d(n_expansion*8),
            nn.ReLU(),

            nn.Conv2d(in_channels=n_expansion*8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()  
        )

    def forward(self, X):
        return self.layers(X)

In [77]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [78]:
from sklearn.metrics import accuracy_score

In [144]:
def train_gan(D, G, lr, batch_size, tolerance, patience, lim):
    loss_function = nn.BCELoss()
    
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True)
    
    N_train = len(trainset)
    N_test = len(testset)

    N_batches_train = N_train // batch_size
    N_batches_test = N_test // batch_size

    optimizer_G = optim.Adam(G.parameters(), lr=lr)
    optimizer_D = optim.Adam(D.parameters(), lr=lr)

    for i in range(lim):

        # Train discriminator

        train_epoch_D_loss = 0
        train_epoch_D_correct = 0
        train_epoch_D_attempted = 0

        train_epoch_G_loss = 0
        
        for X_batch0, _ in train_loader:

            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            y_batch0 = torch.zeros(batch_size)

            # Do not record generator gradients during discriminator training
            with torch.no_grad():
                X_batch1 = G(generate_samples(batch_size=batch_size))
            y_batch1 = torch.ones(batch_size)

            X_batch = torch.cat([X_batch0, X_batch1])
            y_batch = torch.cat([y_batch0, y_batch1])

            # Note: D(X_batch) has dims (batch_size, 1, 1, 1)
            y_batch_probs = torch.squeeze(D(X_batch))
            y_batch_preds = torch.round(y_batch_probs)
            
            loss_batch_D = loss_function(y_batch_probs, y_batch)

            loss_batch_D.backward()
            optimizer_D.step()

            train_epoch_D_loss += loss_batch_D.item()

            train_epoch_D_attempted += y_batch.shape[0]
            train_epoch_D_correct += accuracy_score(y_batch.detach().numpy(), y_batch_preds.detach().numpy(), normalize=False)

            print(f"Epoch 1 D: {loss_batch_D:.3f} {accuracy_score(y_batch.detach().numpy(), y_batch_preds.detach().numpy())}")

            # Train generator

            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            X_batch = G(generate_samples(batch_size=batch_size*2))
            y_batch = torch.ones(batch_size*2)

            y_batch_probs = torch.squeeze(D(X_batch))
            y_batch_preds = torch.round(y_batch_probs)

            loss_batch_G = -loss_function(y_batch_probs, y_batch)

            loss_batch_G.backward()
            optimizer_G.step()

            train_epoch_G_loss += loss_batch_G.item()

            print(f"Epoch 1 G: {loss_batch_G:.3f} {accuracy_score(y_batch.detach().numpy(), y_batch_preds.detach().numpy())}")


In [145]:
lr=.001
batch_size=256
tolerance=.01
patience=10
lim=1

hyperparams = {
    'lr': lr,
    'batch_size': batch_size,
    'tolerance': tolerance,
    'patience': patience,
    'lim': lim,
}

In [146]:
# Create models
G = Generator()
D = Discriminator()

# Randomly initialize model weights
G.apply(weights_init)
D.apply(weights_init)

Discriminator(
  (layers): Sequential(
    (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(16, 32, kernel_size=(4, 4), stride=(1, 1))
    (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU()
    (5): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Conv2d(128, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

In [147]:
train_gan(D, G, **hyperparams)

Epoch 1 D: 0.848 0.412109375
Epoch 1 G: -0.553 0.751953125
Epoch 1 D: 0.531 0.689453125
Epoch 1 G: -1.174 0.4296875
Epoch 1 D: 0.533 0.751953125
Epoch 1 G: -1.317 0.5703125
Epoch 1 D: 0.517 0.78125
Epoch 1 G: -1.076 0.66796875
Epoch 1 D: 0.367 0.876953125
Epoch 1 G: -0.949 0.69921875
Epoch 1 D: 0.320 0.876953125
Epoch 1 G: -0.906 0.671875
Epoch 1 D: 0.172 0.94140625
Epoch 1 G: -0.995 0.65625
Epoch 1 D: 0.067 0.998046875
Epoch 1 G: -1.549 0.359375
Epoch 1 D: 0.052 0.998046875
Epoch 1 G: -2.319 0.125
Epoch 1 D: 0.059 0.99609375
Epoch 1 G: -2.907 0.078125
Epoch 1 D: 0.064 0.998046875
Epoch 1 G: -3.444 0.091796875
Epoch 1 D: 0.067 1.0
Epoch 1 G: -3.936 0.150390625
Epoch 1 D: 0.066 0.994140625
Epoch 1 G: -4.153 0.158203125
Epoch 1 D: 0.050 0.998046875
Epoch 1 G: -4.291 0.22265625
Epoch 1 D: 0.029 0.994140625
Epoch 1 G: -4.357 0.23828125
Epoch 1 D: 0.022 0.994140625
Epoch 1 G: -4.520 0.236328125
Epoch 1 D: 0.015 0.994140625
Epoch 1 G: -4.702 0.240234375
Epoch 1 D: 0.009 1.0
Epoch 1 G: -4.914

ValueError: Using a target size (torch.Size([512])) that is different to the input size (torch.Size([352])) is deprecated. Please ensure they have the same size.