In [1]:
import torch
import torch.nn as nn
import torch.optim as opts
import numpy as np

from tensorflow.keras import datasets as dts
from torch_modules import commons as commons
from torch_modules.training.trainer import GANTrainer
from torch_modules.training import callbacks, data

In [2]:
# Only the training images are needed; the training labels and the test split
# are discarded. Index into the returned ((x_train, y_train), (x_test, y_test))
# tuple rather than unpacking into `_`: rebinding `_` in IPython shadows the
# last-output variable and stops IPython from updating `_`, `__`, `___`.
x_train = dts.fashion_mnist.load_data()[0][0]

In [3]:
# Training configuration.
# NOTE(review): batch_size is not passed to Data or GANTrainer in the visible
# cells, and Generator samples 32 latent vectors on its own — keep them in sync.
batch_size = 32
SEED = 0

# Seed the RNGs before any model is built so that weight initialisation and the
# generator's latent noise are reproducible under Restart & Run All.
# numpy is seeded too in case the project's data/trainer code draws from it.
# (cuDNN kernels may still introduce run-to-run nondeterminism on GPU.)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
class Data(data.StandardDataset):
    """Training images from the module-level ``x_train`` array, scaled to [-1, 1].

    NOTE(review): reads the global ``x_train`` loaded in an earlier cell rather
    than taking it as an argument; that cell must run first.
    """

    def get_len(self):
        """Return the number of training images (first axis of ``x_train``)."""
        return x_train.shape[0]

    def get_item(self, idx):
        """Return image ``idx`` with an inserted leading axis, scaled to [-1, 1].

        For an integer ``idx``, ``x_train[idx, None]`` adds a size-1 axis in
        front of the image (used as the channel axis by the 1-channel Conv2d).
        ``/ 127.5 - 1`` maps pixel values assumed to lie in 0..255 onto -1..1,
        the output range of the generator's final Tanh.
        """
        # Dividing a numpy array by a Python float yields float64; cast to
        # float32 so samples match the float32 weights of the torch layers.
        return (x_train[idx, None] / 127.5 - 1).astype(np.float32)

In [5]:
# Discriminator: three Conv(3x3) -> BatchNorm -> LeakyReLU(0.2) blocks, then a
# Linear layer emitting one unbounded score per sample (no sigmoid here).
# The two stride-2 convolutions halve the spatial size twice, so a 28x28 input
# ends as 256 x 7 x 7 — the size the final Linear expects.
# NOTE(review): commons.Reshape is project code; presumably it flattens each
# sample to 7*7*256 while keeping the batch dimension — confirm.
d = nn.Sequential(
    *[
        layer
        for in_ch, out_ch, stride in ((1, 64, 1), (64, 128, 2), (128, 256, 2))
        for layer in (
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2),
        )
    ],
    commons.Reshape(7 * 7 * 256),
    nn.Linear(7 * 7 * 256, 1),
).cuda(0)

In [6]:
class Generator(nn.Sequential):
    """``nn.Sequential`` that samples its own latent noise when called with no input.

    ``g()`` draws a standard-normal batch of shape
    ``(noise_batch_size, latent_dim)``; ``g(z)`` runs the given batch ``z``
    through the layers unchanged.
    """

    # Shape of the noise sampled when forward() gets no input. The defaults are
    # the values that were previously hard-coded (32 samples, 128 latent dims);
    # override on the class or on an instance to change them.
    noise_batch_size = 32
    latent_dim = 128

    def forward(self, x=None):
        """Run the layer stack on ``x``, sampling ``x`` from N(0, 1) when it is None.

        Args:
            x: Optional latent batch. When None, a tensor of shape
                ``(noise_batch_size, latent_dim)`` is sampled on the same
                device as the module's parameters.

        Returns:
            The output of the wrapped ``nn.Sequential`` layers.
        """
        # `x or ...` would call Tensor.__bool__, which raises for multi-element
        # tensors and would replace a single-element zero tensor with noise.
        if x is None:
            # Follow the model's device instead of a hard-coded "cuda:0".
            first_param = next(self.parameters(), None)
            device = first_param.device if first_param is not None else None
            x = torch.randn(self.noise_batch_size, self.latent_dim, device=device)
        return super().forward(x)

# Generator: project 128-dim noise to a 256 x 7 x 7 feature map, upsample it twice
# with stride-2 transposed convolutions (7 -> 14 -> 28 with kernel 3, padding 1,
# output_padding 1), then map 64 channels to one channel squashed by Tanh.
# NOTE(review): commons.Reshape is project code; presumably it reshapes each
# sample to (256, 7, 7) while keeping the batch dimension — confirm.
g = Generator(
    nn.Linear(128, 7 * 7 * 256),
    nn.BatchNorm1d(7 * 7 * 256),
    nn.ReLU(),
    commons.Reshape(256, 7, 7),
    *[
        layer
        for in_ch, out_ch in ((256, 128), (128, 64))
        for layer in (
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )
    ],
    nn.Conv2d(64, 1, 1),
    nn.Tanh(),
).cuda(0)

In [7]:
# A single Adam optimizer with two parameter groups, in order: discriminator
# parameters (group 0), then generator parameters (group 1).
# NOTE(review): presumably GANTrainer relies on this group order to update D and
# G separately — confirm against torch_modules.training.trainer.
opt = opts.Adam(
    [{"params": model.parameters()} for model in (d, g)],
    lr=1e-3,
    betas=(0, 0.999),
)

In [8]:
# Training dataset; Data reads the module-level x_train array loaded earlier.
train = Data()

In [9]:
# NOTE(review): GANTrainer is project code; presumably it takes
# (discriminator, generator, optimizer) in this order — confirm.
trainer = GANTrainer(d, g, opt)

In [10]:
# Full training run of 100 epochs. The saved output shows this cell was
# interrupted during epoch 1 (1875 steps per epoch, presumably 60000 images in
# batches of 32), so no trained result is recorded in this notebook.
trainer.fit(train, epochs=100)

Epoch:  1 / 100


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))




KeyboardInterrupt: 