In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer, seed_everything
from torch import nn
from torchvision.utils import make_grid

from datasets.agri import AgriDataModule
from datasets.monet import MonetDataModule
from network.discriminators import CycleGANDiscriminator
from network.generators import CycleGANGenerator
from preprocessing.image_transform import ImageTransform
from systems.cycle_gan_system import CycleGANSystem

# Root data directory and the `domain` name handed to every AgriDataModule below.
data_dir, domain = './data', "domainB"

In [None]:
# Sanity Check: pull a single test batch and print the shapes of the source
# and target tensors before committing to a long training run.
batch_size = 8
transform = ImageTransform(img_size=256)

dm = AgriDataModule(data_dir, transform, batch_size, domain=domain)
dm.prepare_data()
dm.setup()

dataloader = dm.test_dataloader()
first_batch = next(iter(dataloader))
base, style = first_batch  # `base` / `style` are plotted by the next cell

print('Input Shape {}, {}'.format(base.shape, style.shape))

In [None]:
def show_batch(images, title, nrow=4, padding=2):
    """Plot a batch of images as a single grid figure.

    Parameters:
        images (Tensor) -- image batch accepted by ``make_grid``; assumed
                           (N, C, H, W) and normalized to [-1, 1], since it is
                           de-normalized with ``x * 0.5 + 0.5`` -- TODO confirm
                           against ImageTransform
        title (str)     -- figure title
        nrow (int)      -- images per grid row
        padding (int)   -- pixels of padding between grid cells
    """
    grid = make_grid(images, nrow=nrow, padding=padding)
    # (C, H, W) -> (H, W, C) for imshow; .cpu() so CUDA tensors also work.
    grid = grid.permute(1, 2, 0).detach().cpu().numpy()
    # Undo the [-1, 1] normalization and map to 8-bit pixels. Clipping keeps
    # imshow from warning about (and silently clipping) out-of-range values.
    pixels = ((grid * 0.5 + 0.5) * 255.0).clip(0, 255).astype('uint8')

    fig, ax = plt.subplots(figsize=(18, 8), facecolor='w')
    ax.imshow(pixels)
    ax.axis('off')
    ax.set_title(title)
    plt.show()


show_batch(base, 'Source Domain')
show_batch(style, 'Target Domain')

In [None]:
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (nn.Module)   -- network to be initialized; every submodule is visited via ``net.apply``
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal (kaiming ignores it);
                             also the std of the BatchNorm2d weight around 1.0.
    Returns:
        net -- the same module, so the call can be chained.
    Raises:
        NotImplementedError -- if ``init_type`` is not one of the supported methods.

    Conv*/Linear layers get their weight drawn by ``init_type`` and their bias zeroed.
    BatchNorm2d layers get weight ~ N(1.0, init_gain) and bias = 0. Layers without
    learnable affine parameters (e.g. ``BatchNorm2d(affine=False)``) are skipped.
    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    weight_initializers = {
        'normal': lambda w: nn.init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: nn.init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: nn.init.orthogonal_(w, gain=init_gain),
    }
    # Validate up front: otherwise a typo is silently ignored whenever the
    # network happens to contain no Conv/Linear layer.
    if init_type not in weight_initializers:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    init_weight = weight_initializers[init_type]

    def init_func(m):  # applied to every submodule by net.apply
        classname = m.__class__.__name__
        # Modules built without affine params register weight/bias as None.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            init_weight(weight)
            if bias is not None:
                nn.init.constant_(bias, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm's weight is not a matrix; only a normal distribution applies.
            if weight is not None:
                nn.init.normal_(weight, 1.0, init_gain)
            if bias is not None:
                nn.init.constant_(bias, 0.0)

    net.apply(init_func)
    return net

In [None]:
# Config  -----------------------------------------------------------------
transform = ImageTransform(img_size=256)
batch_size = 8
lr = {
    'G': 0.0002,
    'D': 0.0002
}
epoch = 180
seed = 42
reconstr_w = 10  # presumably the cycle-reconstruction loss weight -- confirm in CycleGANSystem
id_w = 2         # presumably the identity loss weight -- confirm in CycleGANSystem

# Seed python / numpy / torch RNGs BEFORE any network is constructed or
# initialised, so weight init and data shuffling are reproducible.
# (Previously `seed` was defined but never used.)
seed_everything(seed)

# DataModule  -----------------------------------------------------------------
dm = AgriDataModule(data_dir, transform, batch_size, domain=domain)
# Separate module with batch size 4, handed to the system as its visualization_dataset.
viz_set = AgriDataModule(data_dir, transform, 4, domain=domain)

G_basestyle = CycleGANGenerator(filter=32)
G_stylebase = CycleGANGenerator(filter=32)
D_base = CycleGANDiscriminator(filter=32)
D_style = CycleGANDiscriminator(filter=32)

# Init Weight  --------------------------------------------------------------
for net in [G_basestyle, G_stylebase, D_base, D_style]:
    init_weights(net, init_type='normal')

# LightningModule  --------------------------------------------------------------
model = CycleGANSystem(G_basestyle, G_stylebase, D_base, D_style, lr, transform, reconstr_w, id_w, visualization_dataset=viz_set)

# Trainer  --------------------------------------------------------------
trainer = Trainer(
    logger=False,
    max_epochs=epoch,
    gpus=1,
    checkpoint_callback=False,
    reload_dataloaders_every_epoch=True,
    num_sanity_val_steps=0,  # Skip Sanity Check
)


# Train
trainer.fit(model, datamodule=dm)