In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
from torchsummary import summary

import numpy as np
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


# Define Model

## Discriminator

In [19]:
class Maxout(nn.Module):
    """Maxout layer: one affine map followed by a max over groups of ``k`` units.

    A single ``nn.Linear`` produces ``output_dim * k`` features. Each run of
    ``k`` consecutive features forms one group, and the layer emits the largest
    value of every group.

    Adapted from https://github.com/junhoseo0/pytorch-gan/blob/master/Maxout.py

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Number of maxout units (size of the last output dimension).
        k: Number of linear pieces competing inside each unit.
    """

    def __init__(self, input_dim=784, output_dim=784, k=3):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.k = k
        # One projection computes all k candidates for every output unit at once.
        self.linear = nn.Linear(input_dim, output_dim * k)

    def forward(self, x):
        """Return the per-unit maximum over the ``k`` linear pieces.

        Every leading dimension of ``x`` is merged into a single batch axis by
        the ``view(-1, ...)`` below, so the result is always 2-D with shape
        ``[N, output_dim]``.
        """
        # Remember the shape of the most recent input; the layer itself never reads it.
        self.input_size = x.size()
        candidates = self.linear(x)                                 # [..., output_dim * k]
        grouped = candidates.view(-1, self.output_dim, self.k)      # [N, output_dim, k]
        # .max(dim) yields (values, indices); only the values are wanted here.
        return grouped.max(dim=2).values

In [20]:
# Smoke test: 16 random 25-dimensional vectors should come out as a 16 x 10 tensor.
maxout = Maxout(input_dim=25, output_dim=10)
rand_num = torch.randn(16, 25)
result = maxout(rand_num)

print(rand_num.shape)   # shape going in
print(result.shape)     # shape coming out
print(result[0])        # maxout activations of the first sample

torch.Size([16, 25])
torch.Size([16, 10])
tensor([0.2566, 0.4292, 0.7110, 1.0207, 0.4392, 0.0957, 0.1424, 0.2280, 1.1508,
        1.0092], grad_fn=<SelectBackward0>)


In [21]:
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a flat sample.

    Three linear layers (``input_dim -> 1200 -> 1200 -> output_dim``), each one
    followed by a ReLU.

    Args:
        input_dim: Size of the latent vector fed to the first layer.
        output_dim: Size of the generated vector (784 by default).
    """

    def __init__(self, input_dim=100, output_dim=784):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        hidden = 1200
        self.linear_1 = nn.Linear(input_dim, hidden)
        self.linear_2 = nn.Linear(hidden, hidden)
        self.linear_3 = nn.Linear(hidden, output_dim)

    def forward(self, x):
        """Push ``x`` through the three Linear+ReLU stages and return the result."""
        # NOTE(review): the ReLU after the last layer leaves outputs in [0, inf).
        # If the targets are ToTensor() images, a sigmoid output may be what is
        # intended here; confirm before training.
        for layer in (self.linear_1, self.linear_2, self.linear_3):
            x = F.relu(layer(x))
        return x

In [22]:
class Discriminator(nn.Module):
    """Maxout discriminator that emits one sigmoid score per sample.

    Architecture: ``Maxout(input_dim -> 240, k=5) -> Dropout(0.8) ->
    Maxout(240 -> 240, k=5) -> Linear(240 -> output_dim) -> sigmoid``.

    Args:
        input_dim: Number of features per sample (784 by default).
        output_dim: Number of scores produced per sample (1 by default).

    Input:
        Any tensor whose elements can be regrouped into rows of ``input_dim``
        values, e.g. ``[N, input_dim]``, ``[N, 1, input_dim]`` or an image batch
        whose per-sample element count equals ``input_dim``.

    Returns:
        Tensor of shape ``[N, output_dim]`` with values in (0, 1).
    """

    def __init__(self, input_dim=784, output_dim=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.max_1 = Maxout(self.input_dim, 240, k=5)
        # Element-wise dropout. nn.Dropout1d is channel-wise and interprets a 2-D
        # [batch, features] tensor as an unbatched (C, L) input, which would zero
        # entire samples instead of individual features.
        # NOTE(review): p=0.8 drops 80% of the activations; confirm this is not
        # meant to be a keep-probability.
        self.drop_1 = nn.Dropout(0.8)

        self.max_2 = Maxout(240, 240, 5)

        # Honour output_dim (previously hard-coded to 1, which is still the default).
        self.linear = nn.Linear(240, self.output_dim)

    def forward(self, x):
        """Score a batch; see the class docstring for accepted input shapes."""
        # Collapse the input to [N, input_dim] so flat vectors and image-shaped
        # batches are both accepted. Inputs whose last dim is already input_dim
        # end up with the same rows the first Maxout's view(-1, ...) would form.
        x = x.reshape(-1, self.input_dim)
        x = self.max_1(x)
        x = self.drop_1(x)
        x = self.max_2(x)
        x = self.linear(x)
        return torch.sigmoid(x)

In [23]:
# MNIST training split, with each image converted to a tensor by ToTensor().
mnist_train = datasets.MNIST(
    root="../Datasets/MNIST_PyTorch/",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# Shuffled mini-batches of 16 samples.
mnist_loader = DataLoader(dataset=mnist_train, batch_size=16, shuffle=True)

# Train

In [32]:
# Size of the noise vector handed to the generator.
latent_dim = 100
# Build both networks (discriminator first, then generator) and move them to `device`.
D = Discriminator().to(device)
G = Generator(input_dim=latent_dim).to(device)

In [33]:
# One SGD optimiser per network, both with lr=0.1 and momentum=0.5.
optimizer_D = optim.SGD(params=D.parameters(), lr=0.1, momentum=0.5)
optimizer_G = optim.SGD(params=G.parameters(), lr=0.1, momentum=0.5)

In [34]:
# Number of passes over `mnist_loader`.
epochs = 50
# Binary cross-entropy computed on probabilities (not logits).
criterion = nn.BCELoss()

In [35]:
for epoch in range(epochs):
    # Per-epoch loss accumulators (nothing in this cell updates them yet).
    d_loss = 0.0
    g_loss = 0.0
    for i, data in enumerate(mnist_loader):
        # Sample a minibatch of real images and a matching minibatch of noise $z$.
        real = data[0].to(device)
        # Take the batch size from the real batch instead of repeating the
        # loader's batch_size, so a different or partial batch still lines up.
        # The noise is created directly on `device` to skip a host-to-device copy.
        z = torch.randn(real.size(0), 1, latent_dim, device=device)
        fake = G(z)



KeyboardInterrupt: 

In [36]:
# Inspect the shape of the last generated batch.
print(fake.size())

torch.Size([16, 1, 784])
