In [1]:
import torch

from catalyst import dl

from src.callbacks.gan import (
    CycleGANLoss,
    GANLoss,
    IdenticalGANLoss,
    PrepareGeneratorPhase,
    GeneratorOptimizerCallback,
    PrepareDiscriminatorPhase,
    DiscriminatorLoss,
    DiscriminatorOptimizerCallback
)
from src.callbacks.distillation import (
    HiddenStateLoss,
    TeacherStudentLoss,
)
from src.callbacks.visualization import LogImageCallback
from src.dataset import UnpairedDataset
from src.modules.generator import Generator
from src.modules.discriminator import NLayerDiscriminator, PixelDiscriminator
from src.runner import DistillRunner
from src.modules.loss import LSGanLoss

from torchvision import transforms as T

from PIL import Image



In [2]:
train_ds = UnpairedDataset(
    "./datasets/monet2photo/trainA_preprocessed",
    "./datasets/monet2photo/trainB_preprocessed",
    transforms=T.Compose([
        T.Resize((300,300)),
        T.RandomCrop((256, 256)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),

    ])

)
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True
)

In [3]:
tr = transforms=T.Compose([
    T.Resize((256,256)),
    T.ToTensor(),
])


mipt_photo = tr(Image.open("./datasets/mipt.jpg"))
zinger_photo = tr(Image.open("./datasets/vk.jpg"))


In [4]:
from itertools import chain

model = {
    "generator_ab": Generator(3, 3, n_blocks=9),
    "generator_ba": Generator(3, 3, n_blocks=9),
    "generator_t": Generator(3, 3, n_blocks=3),
    "discriminator_a": PixelDiscriminator(3),
    "discriminator_b": PixelDiscriminator(3),
}
optimizer = {
    "generator": torch.optim.Adam(
        chain(
            model["generator_ab"].parameters(),
            model["generator_t"].parameters(),
        ),
        lr=0.0002
    ),
    "discriminator": torch.optim.Adam(
        chain(
            model["discriminator_a"].parameters(),
            model["discriminator_b"].parameters()
        ),
        lr=0.0002
    )
}
callbacks = [
    PrepareGeneratorPhase(),
    GANLoss(),
    CycleGANLoss(),
    IdenticalGANLoss(),
    GeneratorOptimizerCallback(
        keys=[
            "gan_loss", 
            "cycle_loss", 
            "identical_loss", 
            "hidden_state_loss",
            "ts_difference",
        ],
        weights=[1, 10, 5, 1, 10],
    ),
    PrepareDiscriminatorPhase(),
    DiscriminatorLoss(),
    DiscriminatorOptimizerCallback(),
    HiddenStateLoss(),
    TeacherStudentLoss(),
    LogImageCallback(model_key="generator_t"),
    LogImageCallback(key="mipt", img=mipt_photo, model_key="generator_t"),
    LogImageCallback(key="vk", img=zinger_photo, model_key="generator_t"),
]

criterion = {
    "gan": LSGanLoss(),
    "cycle": torch.nn.L1Loss(),
    "identical": torch.nn.L1Loss(),
    "hidden_state_loss": torch.nn.MSELoss(),
    "teacher_student": torch.nn.L1Loss(),
}


In [5]:
runner = DistillRunner(buffer_size=50, student_key="generator_t")

In [6]:
runner.train(
    model=model,
    optimizer=optimizer,
    loaders={"train": train_dl},
    callbacks=callbacks,
    criterion=criterion,
    num_epochs=100,
    verbose=True,
    main_metric="identical_loss"
)


Attention, there is only one dataloader - train



Early exiting                                                                                                                                                                                                                                                                
1/100 * Epoch (train):   0% 2/6287 [00:58<35:04:55, 20.09s/it, cycle_loss=0.680, discriminator_a_loss=0.334, discriminator_b_loss=0.350, discriminator_loss=0.684, gan_loss=1.467, generator_loss=26.479, hidden_state_loss=5.603, identical_loss=0.779, ts_difference=0.311]

In [7]:
runner.batch_metrics

defaultdict(None,
            {'gan_loss': tensor(2.0220, grad_fn=<AddBackward0>),
             'cycle_loss': tensor(0.9028, grad_fn=<AddBackward0>),
             'identical_loss': tensor(0.9646, grad_fn=<AddBackward0>)})