In [24]:
import math

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [87]:
def set_seed():
    seed = 42
    torch.manual_seed(seed)

set_seed()

In [25]:
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

X_test = torch.Tensor( testset.data ) / 255.0# - 0.5
y_test = torch.Tensor( testset.targets ).long()
X_train = torch.Tensor( trainset.data ) / 255.0# - 0.5
y_train = torch.Tensor( trainset.targets ).long()

# train_dataset = TensorDataset(X_train, y_train)

train_data = DataLoader(trainset, batch_size=256, shuffle=True, drop_last=True)
test_data = DataLoader(testset, batch_size=256, shuffle=True, drop_last=True)

In [77]:
n_expansion = 16
n_latent = 16
n_channels=1

In [34]:
class SinePositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super(SinePositionalEncoding, self).__init__()
        self.d_model = d_model
        
    def forward(self, digit):
        # Compute the positional encodings for the given digit
        position = torch.arange(0, 10).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2) * -(math.log(10000.0) / self.d_model))
        pe = torch.zeros(10, self.d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        # Get the positional encoding for the given digit
        pe_digit = pe[:, digit, :].unsqueeze(0)
        return pe_digit

temp = SinePositionalEncoding(5)

temp(None)


RuntimeError: The expanded size of the tensor (2) must match the existing size (3) at non-singleton dimension 1.  Target sizes: [10, 2].  Tensor sizes: [10, 3]

In [73]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.layers = nn.Sequential(
            nn.ConvTranspose2d(in_channels=n_latent, out_channels=n_expansion*8, kernel_size=4, stride=1, bias=True),
            nn.BatchNorm2d(n_expansion*8),
            nn.ReLU(),

            nn.ConvTranspose2d(in_channels=n_expansion*8, out_channels=n_expansion*4, kernel_size=4, stride=2, bias=True),
            nn.BatchNorm2d(n_expansion*4),
            nn.ReLU(),


            nn.ConvTranspose2d(in_channels=n_expansion*4, out_channels=n_expansion*2, kernel_size=4, stride=2, bias=True),
            nn.BatchNorm2d(n_expansion*2),
            nn.ReLU(),


            nn.ConvTranspose2d(in_channels=n_expansion*2, out_channels=n_expansion, kernel_size=4, stride=1, bias=True),
            nn.BatchNorm2d(n_expansion),
            nn.ReLU(),


            nn.ConvTranspose2d(in_channels=n_expansion, out_channels=1, kernel_size=4, stride=1, bias=True),
            nn.Tanh(),
        )

    def forward(self, X):
        return self.layers(X)
    
def generate_samples(batch_size=256):
    return torch.randn((batch_size, n_latent, 1, 1))

In [79]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=n_channels, out_channels=n_expansion, kernel_size=4, stride=1, padding=0, bias=True),

            nn.ReLU(),
            nn.Conv2d(in_channels=n_expansion, out_channels=n_expansion*2, kernel_size=4, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(n_expansion*2),
            nn.ReLU(),

            nn.Conv2d(in_channels=n_expansion*2, out_channels=n_expansion*4, kernel_size=4, stride=2, padding=0, bias=True),
            nn.BatchNorm2d(n_expansion*4),
            nn.ReLU(),

            nn.Conv2d(in_channels=n_expansion*4, out_channels=n_expansion*8, kernel_size=4, stride=2, padding=0, bias=True),
            nn.BatchNorm2d(n_expansion*8),
            nn.ReLU(),

            nn.Conv2d(in_channels=n_expansion*8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()  
        )

    def forward(self, X):
        return self.layers(X)

In [80]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [83]:
G = Generator()
D = Discriminator()

In [85]:
# Randomly initialize model weights
G.apply(weights_init)
D.apply(weights_init)





In [None]:
def train_gan(D, G, lr, batch_size, tolerance, patience, lim):
    loss_function = nn.BCELoss()
    
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)


    optimizer_G = optim.Adam(G.parameters(), lr=lr)
    optimizer_D = optim.Adam(D.parameters(), lr=lr)

    for i in range(lim):
        for X_batch, y_batch in train_data:
            pass

In [86]:
loss_function = nn.BCELoss()

demo_batch = generate_samples(batch_size=16)