In [3]:
from fastai.vision.all import *
from fastai.distributed import *
from fastai.vision.gan import *
from fastai.callback.tracker import SaveModelCallback
from fastai import torch_core
from fastai.metrics import *

from fastprogress import fastprogress
import torch
import argparse
from models.utils.gan_joiner import GAN
from models.utils.joiner2 import *
from models.utils.losses import *
from models.utils.metrics import *
from models.utils.misc import *
from models.unet import UNet
from models.utils.datasets import *
from models.unet import UNet

from torchvision import datasets, transforms, models
import torchvision.transforms as T
import fastai
from torch.nn.parallel import DistributedDataParallel

In [8]:
H = 320
W= 320
bs = 5
nclass = 4

beta = 0.000000
gamma = 0.0005
sigma = 1
#seed = 1234
#torch.manual_seed(seed)
#torch.cuda.manual_seed(seed)

In [9]:
path = './data/ImageNetRotation1k/'
#path = untar_data(URLs.IMAGENETTE_320)

transform = ([*aug_transforms(),Normalize.from_stats([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

data = DataBlock(blocks=(ImageBlock, CategoryBlock), 
                 get_items=get_image_files, 
                 splitter=RandomSplitter(),
                 get_y=parent_label,
                 item_tfms=Resize(H,W),
                 batch_tfms=transform)

dloader = data.dataloaders(path,bs=bs) 

In [10]:
gen = UNet(n_channels=3, n_classes=3, bilinear=False)
crt = Joiner(num_encoder_layers = 4, nhead=4, backbone = True, num_classes = nclass, bypass=False, hidden_dim=256, 
          batch_size=bs, image_h=H, image_w=W,grid_l=4,penalty_factor="2")

In [11]:
def _accumulate(self, learn):
        print("Loss Metric")
        print(learn.yb)
        #print(learn.loss_func)
        print(self.attr)
        print(getattr(learn.loss_func, self.attr, 0))
        bs = find_bs(learn.yb)
        print(learn.to_detach(getattr(learn.loss_func, self.attr, 0))*bs)
        self.total += learn.to_detach(getattr(learn.loss_func, self.attr, 0))*bs
        self.count += bs
LossMetric.accumulate = _accumulate

In [12]:
generator_loss = GeneratorLoss(beta, gamma,sigma)
critic_loss = CriticLoss(beta,sigma)

In [13]:
def _before_fit(self):
    opt_kwargs = { 'find_unused_parameters' : DistributedTrainer.fup } if DistributedTrainer.fup is not None else {}
    self.learn.model = DistributedDataParallel(
        nn.SyncBatchNorm.convert_sync_batchnorm(self.model) if self.sync_bn else self.model,
        device_ids=[self.cuda_id], output_device=self.cuda_id, find_unused_parameters=True, **opt_kwargs)
    self.old_dls = list(self.dls)
    self.learn.dls.loaders = [self._wrap_dl(dl) for dl in self.dls]
    if rank_distrib(): self.learn.logger=noop
before_fit = _before_fit

In [14]:
class _GANModule(Module):
    "Wrapper around a `generator` and a `critic` to create a GAN."
    def __init__(self, generator=None, critic=None, gen_mode=False):
        #print("Custom GAN Module")
        if generator is not None: self.generator=generator
        if critic    is not None: self.critic   =critic
        store_attr('gen_mode')

    def forward(self, *args):
        #print(*args)
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode=None):
        "Put the module in generator mode if `gen_mode`, in critic mode otherwise."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
GANModule = _GANModule

In [15]:
#GANModule(generator, critic, True)

In [16]:
class _GANLoss(GANModule):
    "Wrapper around `crit_loss_func` and `gen_loss_func`"
    def __init__(self, gen_loss_func, crit_loss_func, gan_model):
        super().__init__()
        store_attr('gen_loss_func,crit_loss_func,gan_model')

    def generator(self, output, target):
        "Evaluate the `output` with the critic then uses `self.gen_loss_func`"
        fake_pred = self.gan_model.critic(output)
        self.gen_loss = self.gen_loss_func(fake_pred, target)
        return self.gen_loss

    def critic(self, real_pred, input):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.crit_loss_func`."
        #print("GANLoss - Critic Loss")
        for param in self.gan_model.generator.parameters():
            param.requires_grad_(False)
        fake = self.gan_model.generator(real_pred[4])
        fake_pred = self.gan_model.critic(fake)
        self.crit_loss = self.crit_loss_func(real_pred, input) + self.crit_loss_func(fake_pred, input)
        return self.crit_loss
GANLoss = _GANLoss

In [17]:
class _FixedGANSwitcher(Callback):
    "Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator."
    run_after = GANTrainer
    def __init__(self, n_crit=1, n_gen=1): store_attr('n_crit,n_gen')
    def before_train(self): self.n_c,self.n_g = 0,0

    def after_batch(self):
        "Switch the model if necessary."
        #print("After Batch")
        if not self.training: return
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            n_iter,n_in,n_out = self.n_gen,self.n_c,self.n_g
        else:
            #print("After batch Else")
            self.n_c += 1
            n_iter,n_in,n_out = self.n_crit,self.n_g,self.n_c
        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)
        #print(target)
        #print(n_out)
        if target == n_out:
            self.learn.gan_trainer.switch()
            self.n_c,self.n_g = 0,0
FixedGANSwitcher = _FixedGANSwitcher

In [18]:
def _before_batch(self):
    "Clamp the weights with `self.clip` if it's not None, set the correct input/target."
    if self.training and self.clip is not None:
        for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
    if not self.gen_mode:
        (self.learn.xb,self.learn.yb) = (self.xb,self.yb)
GANTrainer.before_batch = _before_batch

In [19]:
def _switch(self, gen_mode=None):
    "Switch the model and loss function, if `gen_mode` is provided, in the desired mode."
    self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
    self._set_trainable()
    self.model.switch(gen_mode)
    self.loss_func.switch(gen_mode)
GANTrainer.switch = _switch

In [20]:
def __set_trainable(self):
    train_model = self.generator if     self.gen_mode else self.critic
    loss_model  = self.generator if not self.gen_mode else self.critic
    set_freeze_model(train_model, True)
    set_freeze_model(loss_model, False)
    if self.switch_eval:
        train_model.train()
        loss_model.eval()
GANTrainer._set_trainable = __set_trainable

In [21]:
def _set_freeze_model(m, rg):
    if type(m) == Joiner:
        m.paramsToUpdate()
    else:
        for p in m.parameters(): p.requires_grad_(rg)
set_freeze_model = _set_freeze_model

In [22]:
@delegates()
class _GANLearner(Learner):
    "A `Learner` suitable for GANs."
    def __init__(self, dls, generator, critic, gen_loss_func, crit_loss_func, switcher=None, gen_first=False,
                 switch_eval=True, show_img=True, clip=None, cbs=None, metrics=None, **kwargs):
        gan = GANModule(generator, critic)
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
        if switcher is None: switcher = FixedGANSwitcher(n_crit=5, n_gen=1)
        trainer = GANTrainer(clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)
        cbs = L(cbs) + L(trainer, switcher)
        metrics = L(metrics) + L(*LossMetrics('gen_loss,crit_loss'))
        super().__init__(dls, gan, loss_func=loss_func, cbs=cbs, metrics=metrics, **kwargs)

    @classmethod
    def from_learners(cls, gen_learn, crit_learn, switcher=None, weights_gen=None, **kwargs):
        "Create a GAN from `learn_gen` and `learn_crit`."
        losses = gan_loss_from_func(gen_learn.loss_func, crit_learn.loss_func, weights_gen=weights_gen)
        return cls(gen_learn.dls, gen_learn.model, crit_learn.model, *losses, switcher=switcher, **kwargs)

    @classmethod
    def wgan(cls, dls, generator, critic, switcher=None, clip=0.01, switch_eval=False, **kwargs):
        "Create a WGAN from `data`, `generator` and `critic`."
        return cls(dls, generator, critic, _tk_mean, _tk_diff, switcher=switcher, clip=clip, switch_eval=switch_eval, **kwargs)

GANLearner = _GANLearner

In [23]:
def _Accuracy(preds,target): 
    #print(preds[0].shape)
    #print(preds[1].shape)
    
    if len(preds) == 2:
        #print("Generator")
        fakePreds = learner.gan_trainer.critic(preds)
        _, pred = torch.max(fakePreds[0], 1)
        #print(target.shape)
        #print(pred.shape)
        
        return (pred == target).float().mean()
    else:
        print("Critic")
        _, pred = torch.max(preds[0], 1)
        print(target.shape)
        print(pred.shape)

        return (pred == target).float().mean()

In [24]:
def _Reconstruction_Loss(preds,target,sigma=1):
    if len(preds) == 2:
        MSE = nn.MSELoss()
        Lrec = sigma*MSE(preds[0],preds[1])
    else:
        Lrec = 0.000
    
    return Lrec

In [25]:
learner = GANLearner(dloader,gen,crt,generator_loss,critic_loss,gen_first=True, metrics=[_Accuracy,_Reconstruction_Loss])

In [27]:
model_dir = Path.home()/'Luiz/saved_models'
learner.export(model_dir/'critic.pkl')

In [54]:
learner.fit_one_cycle(1, 2e-7, wd=0.)

epoch,train_loss,valid_loss,_Accuracy,_Reconstruction_Loss,gen_loss,crit_loss,time
0,2.411385,0.226484,0.2,0.225765,0.226484,2.778801,00:53


Loss Metric
(TensorCategory([2, 2, 0, 3, 2], device='cuda:0'),)
gen_loss
TensorCategory(0.2323, device='cuda:0')
TensorCategory(1.1616)
Loss Metric
(TensorCategory([2, 2, 0, 3, 2], device='cuda:0'),)
crit_loss
TensorCategory(2.7788, device='cuda:0', grad_fn=<AliasBackward>)
TensorCategory(13.8940)
Loss Metric
(TensorCategory([3, 0, 0, 0, 2], device='cuda:0'),)
gen_loss
TensorCategory(0.2330, device='cuda:0')
TensorCategory(1.1649)
Loss Metric
(TensorCategory([3, 0, 0, 0, 2], device='cuda:0'),)
crit_loss
TensorCategory(2.7788, device='cuda:0', grad_fn=<AliasBackward>)
TensorCategory(13.8940)
Loss Metric
(TensorCategory([1, 3, 3, 3, 1], device='cuda:0'),)
gen_loss
TensorCategory(0.2282, device='cuda:0')
TensorCategory(1.1412)
Loss Metric
(TensorCategory([1, 3, 3, 3, 1], device='cuda:0'),)
crit_loss
TensorCategory(2.7788, device='cuda:0', grad_fn=<AliasBackward>)
TensorCategory(13.8940)
Loss Metric
(TensorCategory([1, 2, 3, 2, 1], device='cuda:0'),)
gen_loss
TensorCategory(0.2188, device=

Loss Metric
(TensorCategory([1, 0, 0, 2, 2], device='cuda:0'),)
gen_loss
TensorCategory(0.2217, device='cuda:0')
TensorCategory(1.1083)
Loss Metric
(TensorCategory([1, 0, 0, 2, 2], device='cuda:0'),)
crit_loss
TensorCategory(2.7788, device='cuda:0', grad_fn=<AliasBackward>)
TensorCategory(13.8940)
Loss Metric
(TensorCategory([3, 1, 3, 0, 1], device='cuda:0'),)
gen_loss
TensorCategory(0.2200, device='cuda:0')
TensorCategory(1.1001)
Loss Metric
(TensorCategory([3, 1, 3, 0, 1], device='cuda:0'),)
crit_loss
TensorCategory(2.7788, device='cuda:0', grad_fn=<AliasBackward>)
TensorCategory(13.8940)
Loss Metric
(TensorCategory([2, 2, 0, 2, 1], device='cuda:0'),)
gen_loss
TensorCategory(0.2338, device='cuda:0')
TensorCategory(1.1690)
Loss Metric
(TensorCategory([2, 2, 0, 2, 1], device='cuda:0'),)
crit_loss
TensorCategory(2.7788, device='cuda:0', grad_fn=<AliasBackward>)
TensorCategory(13.8940)
Loss Metric
(TensorCategory([1, 0, 0, 1, 1], device='cuda:0'),)
gen_loss
TensorCategory(0.2233, device=

In [None]:
GanLoss = GanLossWrapper(beta=0.000000, gamma=0.005,sigma=1)

gan = GAN(num_encoder_layers = 4, nhead=4, backbone = True, num_classes = nclass, bypass=False, hidden_dim=256, 
          batch_size=bs, image_h=H, image_w=W,grid_l=4,penalty_factor="2")

In [None]:
gan.generatorSwitcher()
gan.noiseSwitcher()
print("Noise mode:", gan.noise_mode)
print("Generator mode:", gan.generator_mode)
gan.paramsToUpdate()
gan.assertParams()

In [None]:
@patch
def load(self:Learner, file, with_opt=None, device=None, **kwargs):
    print("Model load")
    if device is None and hasattr(self.dls, 'device'): device = self.dls.device
    if with_opt is None: self.opt=None
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    load_model(file, self.model, self.opt, device=device, **kwargs)
    return self
#Learner.load = alt_load

In [None]:
critic_learner = Learner(dloader, gan, loss_func=GanLoss, metrics=[Reconstruction_Loss, Accuracy])
generator_learner = Learner(dloader, gan, loss_func=GanLoss, metrics=[Reconstruction_Loss, Accuracy])

In [None]:
critic_learner.model.generatorSwitcher()
critic_learner.model.noiseSwitcher()
critic_learner.model.paramsToUpdate()

In [None]:
print("Critic Noise mode:", critic_learner.model.noise_mode)
print("Critic Gen mode:", critic_learner.model.generator_mode)
print("Generator Noise mode:", generator_learner.model.noise_mode)
print("Generator Gen mode:", generator_learner.model.generator_mode)

In [None]:
assert critic_learner.model == generator_learner.model

In [None]:
print("Noise mode:", critic_learner.model.noise_mode)
print("Generator mode:", critic_learner.model.generator_mode)
critic_learner.lr_find()

In [None]:
generator_learner.model.generatorSwitcher()
generator_learner.model.noiseSwitcher()
generator_learner.model.paramsToUpdate()
print("Noise mode:", generator_learner.model.noise_mode)
print("Generator mode:", generator_learner.model.generator_mode)
generator_learner.model.assertParams()
generator_learner.lr_find()

In [None]:
epochs = 3
for e in range(epochs):
    
    print("Epoch", e+1)
    print("Generator training")
    assert critic_learner.model == generator_learner.model
    #Generator Training
    print("Noise mode:", generator_learner.model.noise_mode)
    print("Generator mode:", generator_learner.model.generator_mode)
    gan.paramsToUpdate()
    gan.assertParams()
    
    generator_learner.fit_one_cycle(1,0.001)
    
    print("Critit training without noised images")
    assert critic_learner.model == generator_learner.model
    gan.generatorSwitcher()
    gan.noiseSwitcher()
    gan.paramsToUpdate()
    gan.assertParams()
    print("Noise mode:", critic_learner.model.noise_mode)
    print("Generator mode:", critic_learner.model.generator_mode)
    
    critic_learner.fit_one_cycle(1,2e-6)
    
    print("Critit training with noised images")
    assert critic_learner.model == generator_learner.model
    #Critit training with noised images
    gan.noiseSwitcher()
    gan.assertParams()
    print("Noise mode:", critic_learner.model.noise_mode)
    print("Generator mode:", critic_learner.model.generator_mode)
    critic_learner.fit_one_cycle(1,2e-6)
    gan.generatorSwitcher()

In [None]:
def gan_create_opt(self):
    self.opt = [self.opt_func(self.splitter(self.model), lr=self.lr),self.opt_func(self.splitter(self.model), lr=self.lr)]
    if not self.wd_bn_bias:
        for p in self._bn_bias_state(True ): p['do_wd'] = False
    if self.train_bn:
        for p in self._bn_bias_state(False): p['force_train'] = True
Learner.create_opt = gan_create_opt

In [None]:
critic_learner.create_opt()

In [None]:
#critic_learner.model.model.encoder.encoder.layers[3].self_attn.out_proj.weight == generator_learner.model.model.encoder.encoder.layers[3].self_attn.out_proj.weight