In [1]:
import torch
import torch.nn as nn
import torch.optim as opts
import numpy as np

from tensorflow.keras import datasets as dts
from plasma import gan_layers as layers, commons
from plasma.training import trainer, gan_callbacks as callbacks, data

In [2]:
# load_data() returns ((train_images, train_labels), (test_images, test_labels));
# only the training images are kept, labels and the test split are discarded.
x_train = dts.fashion_mnist.load_data()[0][0]

In [3]:
# Rescale pixel values from [0, 255] to [-1, 1], the same range as the
# generator's final Tanh activation.
# The integer-dtype guard makes this cell idempotent: the division turns the
# array into floats, so re-running the cell no longer rescales data that has
# already been normalised.
if np.issubdtype(x_train.dtype, np.integer):
    x_train = x_train / 127.5 - 1

x_train.shape

(60000, 28, 28)

In [4]:
# Mini-batch size.
# NOTE(review): no later cell visible here references this name — confirm
# whether it was meant to be passed to the trainer or can be removed.
batch_size = 32

In [5]:
class Data(data.StandardDataset):
    """Dataset view over the notebook-level ``x_train`` image array."""

    def get_len(self):
        """Return the number of images, i.e. the size of the first axis."""
        n_images = x_train.shape[0]
        return n_images

    def get_item(self, idx):
        """Return the entry at ``idx`` with a new axis inserted after the index.

        For an integer ``idx`` this turns a (28, 28) image into (1, 28, 28),
        i.e. it adds the single channel axis the first convolution expects.
        """
        sample = x_train[idx, np.newaxis]
        return sample

In [6]:
def _halve_resolution():
    """Create a fresh bilinear resampling layer with scale factor 0.5."""
    return nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=True)


# Discriminator: maps a 1-channel image to a single logit per sample.
# Spatial sizes below assume 28x28 inputs and that the 3x3 / padding=1
# ScaleCon layers preserve height and width — TODO confirm in plasma.
d_layers = [
    # 1x1 convolution lifting the single input channel to 64 feature maps.
    layers.ScaleCon(1, 64, kernel_size=1, padding=0),
    nn.LeakyReLU(0.2),
    layers.ScaleCon(64, 128),
    nn.LeakyReLU(0.2),
    _halve_resolution(),  # 28x28 -> 14x14
    layers.ScaleCon(128, 256),
    nn.LeakyReLU(0.2),
    _halve_resolution(),  # 14x14 -> 7x7
    # The next layer takes 257 channels, one more than the 256 produced above,
    # so MiniBatchStd presumably appends a single statistics channel.
    layers.MiniBatchStd(),
    layers.ScaleCon(257, 512),
    nn.LeakyReLU(0.2),
    # Unpadded 7x7 kernel over the 7x7 map leaves one spatial position.
    layers.ScaleCon(512, 512, kernel_size=7, padding=0),
    commons.Reshape(-1),
    layers.ScaleLinear(512, 1),
]
d = nn.Sequential(*d_layers).cuda(0)

d

Sequential(
  (0): ScaleCon(in_channels=1, out_channels=64, kernel_size=1, stride=1, padding=0, bias=True)
  (1): LeakyReLU(negative_slope=0.2)
  (2): ScaleCon(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)
  (3): LeakyReLU(negative_slope=0.2)
  (4): Upsample(scale_factor=0.5, mode=bilinear)
  (5): ScaleCon(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)
  (6): LeakyReLU(negative_slope=0.2)
  (7): Upsample(scale_factor=0.5, mode=bilinear)
  (8): MiniBatchStd()
  (9): ScaleCon(in_channels=257, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
  (10): LeakyReLU(negative_slope=0.2)
  (11): ScaleCon(in_channels=512, out_channels=512, kernel_size=7, stride=1, padding=0, bias=True)
  (12): Reshape(shape=(-1,))
  (13): ScaleLinear(in_channels=512, out_channels=1, bias=True)
)

In [7]:
class Generator(nn.Sequential):
    """Sequential generator that can draw its own latent noise.

    Calling the model without an input samples a standard-normal latent batch
    of shape ``(batch_size, 128)`` and feeds it through the wrapped layers.
    """

    def forward(self, x=None, batch_size=32):
        """Run the layers on ``x``, sampling latent noise when ``x`` is None.

        Args:
            x: Optional latent tensor. When given it is used as-is and
                ``batch_size`` is ignored.
            batch_size: Number of latent vectors to sample when ``x`` is None.

        Returns:
            The output of the wrapped sequential layers.
        """
        # Test explicitly for None: `x or ...` calls bool() on the tensor,
        # which raises for any tensor with more than one element and would
        # silently discard a single-element tensor equal to zero.
        if x is None:
            # Sample on the device the model lives on so the generator also
            # works after being moved; keep the original "cuda:0" only when
            # the model has no parameters to inspect.
            first_param = next(self.parameters(), None)
            device = first_param.device if first_param is not None else "cuda:0"
            x = torch.randn(batch_size, 128, device=device)

        return super().forward(x)

def _double_resolution():
    """Create a fresh bilinear up-sampling layer with scale factor 2."""
    return nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)


# Generator: maps a 128-dimensional latent vector to a 1-channel image in [-1, 1].
# Spatial sizes below assume the 3x3 / padding=1 ScaleCon layers preserve
# height and width — TODO confirm in plasma.
g_layers = [
    # Project the latent vector to 7 * 7 * 512 features.
    layers.ScaleLinear(128, 7 * 7 * 512),
    nn.LeakyReLU(0.2),
    # Presumably reshapes each sample to (512, 7, 7) — verify commons.Reshape.
    commons.Reshape(-1, 7, 7),
    _double_resolution(),  # 7x7 -> 14x14
    layers.ScaleCon(512, 256),
    nn.LeakyReLU(0.2),
    _double_resolution(),  # 14x14 -> 28x28
    layers.ScaleCon(256, 128),
    nn.LeakyReLU(0.2),
    # 1x1 convolution collapsing the features to a single image channel.
    layers.ScaleCon(128, 1, kernel_size=1, padding=0),
    # Tanh bounds the output to [-1, 1].
    nn.Tanh(),
]
g = Generator(*g_layers).cuda(0)

g

Generator(
  (0): ScaleLinear(in_channels=128, out_channels=25088, bias=True)
  (1): LeakyReLU(negative_slope=0.2)
  (2): Reshape(shape=(-1, 7, 7))
  (3): Upsample(scale_factor=2.0, mode=bilinear)
  (4): ScaleCon(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)
  (5): LeakyReLU(negative_slope=0.2)
  (6): Upsample(scale_factor=2.0, mode=bilinear)
  (7): ScaleCon(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)
  (8): LeakyReLU(negative_slope=0.2)
  (9): ScaleCon(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)
  (10): Tanh()
)

In [8]:
# Both networks use Adam with identical hyper-parameters.
adam_kwargs = {"lr": 1e-3, "betas": (0, 0.99)}

d_opt = opts.Adam(d.parameters(), **adam_kwargs)
g_opt = opts.Adam(g.parameters(), **adam_kwargs)

In [9]:
# Dataset instance over the normalised training images; passed to trainer.fit.
train = Data()

In [10]:
# Resolve the module through a private alias rather than the notebook-level
# name `trainer`: this cell rebinds `trainer` to the GANTrainer instance, so
# the original `trainer.GANTrainer(...)` lookup would fail as soon as the cell
# was executed a second time.
from plasma.training import trainer as _trainer_module

# NOTE(review): r1=10 is presumably the R1 gradient-penalty weight — confirm
# against plasma's GANTrainer.
trainer = _trainer_module.GANTrainer(
    d, g, d_opt, g_opt, nn.BCEWithLogitsLoss(), r1=10, x_device="cuda:0"
)

In [11]:
# Callbacks run by the trainer during fitting.
gen_image_callback = callbacks.GenImage()
cbs = [gen_image_callback]

In [None]:
# Train for 100 epochs on the Fashion-MNIST dataset object, running the
# callbacks collected in `cbs`.
trainer.fit(train, epochs=100, callbacks=cbs)

epoch  1 / 100


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))