In [None]:
from fastai.vision.all import *
from fastai.distributed import *
from fastai.vision.gan import *
from fastai.callback.tracker import SaveModelCallback
from fastai import torch_core

from fastprogress import fastprogress
import torch
import argparse
from models.utils.gan_joiner import GAN
from models.utils.losses import *
from models.utils.metrics import *
from models.utils.misc import *
from models.unet import UNet
from models.utils.datasets import *

from torchvision import datasets, transforms, models
import torchvision.transforms as T

In [None]:
# Image height/width handed to both Resize and the GAN, batch size, and class count.
H = W = 320
bs = 5
nclass = 10
# Seeding is currently disabled; uncomment these lines for reproducible runs.
# seed = 1234
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)

In [None]:
# Alternative local dataset: path = './data/ImageNetRotation1k/'
path = untar_data(URLs.IMAGENETTE_320)

# Batch-level augmentation followed by normalisation with ImageNet channel mean/std.
transform = [*aug_transforms(), Normalize.from_stats([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]

# Items are image files; each label is the name of the file's parent folder.
# `Resize` takes a single `size` argument (an int or an (h, w) tuple). Its second positional
# parameter is the resize *method*, so `Resize(H, W)` silently passed W as the method.
# The size is therefore given as a tuple here.
# NOTE(review): RandomSplitter() is called without a seed, so the train/valid split
# changes between runs.
data = DataBlock(blocks=(ImageBlock, CategoryBlock),
                 get_items=get_image_files,
                 splitter=RandomSplitter(),
                 get_y=parent_label,
                 item_tfms=Resize((H, W)),
                 batch_tfms=transform)

dloader = data.dataloaders(path, bs=bs)

In [None]:
# Composite loss. The meaning of beta/gamma/sigma is defined by GanLossWrapper
# (presumably from models.utils.losses); beta is currently switched off (0.0).
GanLoss = GanLossWrapper(beta=0.0, gamma=0.005, sigma=1)

# Joint generator/critic model. It is sized by the same H, W, bs and nclass as the data pipeline.
# NOTE(review): penalty_factor is passed as the *string* "2" - confirm GAN expects a string.
gan = GAN(
    num_encoder_layers=4,
    nhead=4,
    backbone=True,
    num_classes=nclass,
    bypass=False,
    hidden_dim=256,
    batch_size=bs,
    image_h=H,
    image_w=W,
    grid_l=4,
    penalty_factor="2",
)

In [None]:
# Switch the generator and noise modes of the shared `gan` model, show the resulting modes,
# then call paramsToUpdate/assertParams.
# NOTE(review): the *Switcher methods appear to be toggles (the training loop calls each one
# twice per epoch). paramsToUpdate presumably sets which parameters are trainable for the
# current mode - confirm in models.utils.gan_joiner.
gan.generatorSwitcher()
gan.noiseSwitcher()
print("Noise mode:", gan.noise_mode)
print("Generator mode:", gan.generator_mode)
gan.paramsToUpdate()
gan.assertParams()

In [None]:
@patch
def load(self:Learner, file, with_opt=None, device=None, **kwargs):
    """Load model weights, and optionally optimizer state, from `self.path/self.model_dir/file`.

    `@patch` installs this as `Learner.load` for every Learner in the session.

    Args:
        file: checkpoint name, resolved with `join_path_file(..., ext='.pth')`.
        with_opt: controls the optimizer.
            None: drop the current optimizer (`self.opt = None`) and load weights only.
            Other falsy values: keep the current optimizer untouched and load weights only.
            Truthy: also restore the saved optimizer state into `self.opt`.
        device: map location. Defaults to `self.dls.device` when the DataLoaders expose one.
        **kwargs: forwarded unchanged to `load_model`.

    Returns:
        self, so calls can be chained.
    """
    print("Model load")
    if device is None and hasattr(self.dls, 'device'): device = self.dls.device
    if with_opt is None: self.opt = None
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    # Forward `with_opt` explicitly. `load_model` defaults it to True, so before this change
    # `with_opt=False` still restored optimizer state, and `with_opt=None` tried to load it
    # into `self.opt=None` (which only produced a warning).
    load_model(file, self.model, self.opt, with_opt=bool(with_opt), device=device, **kwargs)
    return self
# `@patch` already installs this function as `Learner.load`; no manual assignment is needed.

In [None]:
# Two Learners wrap the *same* `gan` module and DataLoaders. critic_learner drives the critic
# phases of the training loop, and generator_learner drives the generator phase.
critic_learner = Learner(
    dloader, gan,
    loss_func=GanLoss,
    metrics=[Reconstruction_Loss, Accuracy],
)
generator_learner = Learner(
    dloader, gan,
    loss_func=GanLoss,
    metrics=[Reconstruction_Loss, Accuracy],
)

In [None]:
# Apply both switcher calls and the parameter refresh to the critic learner's model.
# NOTE(review): `critic_learner.model` is the same `gan` object that an earlier cell already
# switched. If the *Switcher methods are toggles, this second pair returns both modes to their
# initial values - confirm that this is the intended critic configuration.
critic_learner.model.generatorSwitcher()
critic_learner.model.noiseSwitcher()
critic_learner.model.paramsToUpdate()

In [None]:
# Show both learners' modes. They should print identical values because both learners wrap
# the same `gan` module.
print("Critic Noise mode:", critic_learner.model.noise_mode)
print("Critic Gen mode:", critic_learner.model.generator_mode)
print("Generator Noise mode:", generator_learner.model.noise_mode)
print("Generator Gen mode:", generator_learner.model.generator_mode)

In [None]:
# Both learners must drive the very same module object (identity, not just equality).
assert critic_learner.model is generator_learner.model, "critic and generator learners must share one GAN module"

In [None]:
# Run the learning-rate finder for the critic learner in its current mode.
# The modes are printed first so the LR curve can be matched to a configuration.
print("Noise mode:", critic_learner.model.noise_mode)
print("Generator mode:", critic_learner.model.generator_mode)
critic_learner.lr_find()

In [None]:
# Switch the shared model into the configuration used for generator training, refresh and
# check the trainable parameters, then run the LR finder for the generator learner.
# NOTE(review): the training loop below starts each epoch with generator training without
# switching first, so it relies on the mode set here (assuming the *Switcher methods toggle).
generator_learner.model.generatorSwitcher()
generator_learner.model.noiseSwitcher()
generator_learner.model.paramsToUpdate()
print("Noise mode:", generator_learner.model.noise_mode)
print("Generator mode:", generator_learner.model.generator_mode)
generator_learner.model.assertParams()
generator_learner.lr_find()

In [None]:
epochs = 3
generator_lr = 1e-3  # max LR for the one-cycle generator phase
critic_lr = 2e-6     # max LR for both one-cycle critic phases


def _show_modes(model):
    """Print the model's current noise and generator modes."""
    print("Noise mode:", model.noise_mode)
    print("Generator mode:", model.generator_mode)


# Each epoch runs three one-epoch phases on the shared `gan` model:
# generator, then critic on clean images, then critic on noised images.
for e in range(epochs):
    print("Epoch", e+1)

    # Phase 1: generator training. paramsToUpdate is re-run here because the previous
    # epoch ended with a generatorSwitcher() call.
    print("Generator training")
    assert critic_learner.model is generator_learner.model
    _show_modes(generator_learner.model)
    gan.paramsToUpdate()
    gan.assertParams()
    generator_learner.fit_one_cycle(1, generator_lr)

    # Phase 2: critic training without noised images.
    print("Critic training without noised images")
    assert critic_learner.model is generator_learner.model
    gan.generatorSwitcher()
    gan.noiseSwitcher()
    gan.paramsToUpdate()
    gan.assertParams()
    _show_modes(critic_learner.model)
    critic_learner.fit_one_cycle(1, critic_lr)

    # Phase 3: critic training with noised images (only the noise mode changes).
    print("Critic training with noised images")
    assert critic_learner.model is generator_learner.model
    gan.noiseSwitcher()
    gan.assertParams()
    _show_modes(critic_learner.model)
    critic_learner.fit_one_cycle(1, critic_lr)

    # Undo phase 2's generator switch. The noise switch was already called twice this epoch,
    # so (assuming the *Switcher methods toggle) the next epoch starts in the same state.
    gan.generatorSwitcher()

In [None]:
def gan_create_opt(self):
    """Replacement for `Learner.create_opt` that builds two optimizers instead of one.

    Both optimizers are created with `self.opt_func(self.splitter(self.model), lr=self.lr)`,
    and `self.opt` ends up as the list `[opt_0, opt_1]`. The batchnorm/bias flags are applied
    to *each* optimizer:
    - `do_wd=False` unless `self.wd_bn_bias` is set;
    - `force_train=True` when `self.train_bn` is set.
    """
    opts = [self.opt_func(self.splitter(self.model), lr=self.lr) for _ in range(2)]
    for opt in opts:
        # `_bn_bias_state` looks up per-parameter state on `self.opt`. Calling it while
        # `self.opt` was a plain list (no `.state`) raised AttributeError, so point
        # `self.opt` at one optimizer at a time instead.
        self.opt = opt
        if not self.wd_bn_bias:
            for p in self._bn_bias_state(True ): p['do_wd'] = False
        if self.train_bn:
            for p in self._bn_bias_state(False): p['force_train'] = True
    self.opt = opts

# NOTE(review): this replaces `create_opt` on the Learner class, so it affects both learners
# above. Any later code that (re)creates an optimizer through it will get a list in
# `learn.opt` instead of a single optimizer - confirm this is intended.
Learner.create_opt = gan_create_opt

In [None]:
# Build the critic learner's optimizer through the `create_opt` installed on `Learner` in the
# previous cell. That patched version assigns a list of two optimizers to `critic_learner.opt`.
critic_learner.create_opt()

In [None]:
#critic_learner.model.model.encoder.encoder.layers[3].self_attn.out_proj.weight == generator_learner.model.model.encoder.encoder.layers[3].self_attn.out_proj.weight