In [1]:
from fastai.vision.all import *
from fastai.distributed import *
from fastai.vision.gan import *
from fastai.metrics import *
from fastai.callback.tracker import SaveModelCallback, ReduceLROnPlateau
from fastai import torch_core
from torch import nn

In [2]:
from models.ARViT import ARViT
from models.unet import UNet
from models.SAM import SAM
from models.utils.fastai_gan import *
from losses.attention_loss import *
from losses.sam_loss import *

In [3]:
# Resolve the compute device first, and only pin the CUDA device when one is
# actually present. Calling torch.cuda.set_device unconditionally raises on a
# CPU-only machine, which would defeat the CPU fallback below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.set_device(device)

In [4]:
# Hyper-parameters shared by the data pipeline and the networks below.
bs = 10          # batch size (passed to the dataloaders and to SAM)
H = 256          # image height handed to Resize / SAM
W = 256          # image width handed to Resize / SAM
nclass = 10      # number of target classes given to SAM
grid_l = 16      # forwarded to SAM as `grid_l` -- NOTE(review): exact meaning defined in models.SAM
gm_l = 16        # forwarded to SAM as `gm_patch` -- NOTE(review): exact meaning defined in models.SAM

In [5]:
# Download (or reuse the local copy of) the Imagenette dataset.
path = untar_data(URLs.IMAGENETTE)

# Batch-level pipeline: fastai's default augmentations, then normalisation
# with the ImageNet per-channel mean / std.
transform = [
    *aug_transforms(),
    Normalize.from_stats([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]

data = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    # Fixed seed so the train/validation split (and therefore the reported
    # validation metrics) is the same on every run.
    splitter=RandomSplitter(seed=42),
    get_y=parent_label,
    # Resize takes the target size as ONE argument (an int or a (height, width)
    # tuple); its second positional parameter is the resize *method*.
    # `Resize(H, W)` therefore passed W as the method and only worked because
    # H == W and an unknown method behaves like the default crop.
    item_tfms=Resize((H, W)),
    batch_tfms=transform,
)
dloader = data.dataloaders(path, bs=bs)

In [6]:
# Network later passed to GANLearner as the generator: 3 channels in, 3 out.
gen = UNet(
    n_channels=3,
    n_classes=3,
    bilinear=False,
)

# Network later passed to GANLearner as the critic. It is built from the
# shared hyper-parameters so its expected input size matches the data pipeline.
crt = SAM(
    enc_Layers=6,
    nhead=8,
    nclass=nclass,
    bs=bs,
    hidden_dim=512,
    H=H,
    W=W,
    grid_l=grid_l,
    gm_patch=gm_l,
)

In [7]:
# Loss objects handed to GANLearner further down: one for the generator step
# and one for the critic step.
# NOTE(review): GeneratorLoss / CriticLoss are not imported by name; they
# presumably come from one of the project star imports (losses.* or
# models.utils.fastai_gan) -- confirm.
generator_loss = GeneratorLoss()
critic_loss = CriticLoss()

In [8]:
def Acc(preds, target):
    """Fraction of samples whose arg-max class score equals `target`.

    `preds` is either a sequence whose first element holds the class scores,
    or a 2-element value that is first run through the critic of the global
    `learner` to obtain such a sequence.

    NOTE(review): the 2-element case passes the whole `preds` to the critic,
    whereas the attention metrics pass only `preds[0]` -- confirm the critic
    accepts both.
    """
    if len(preds) == 2:
        critic_out = learner.gan_trainer.critic(preds)
    else:
        critic_out = preds

    # torch.max over dim 1 returns (values, indices); the indices are the classes.
    predicted = torch.max(critic_out[0], 1)[1]
    hits = (predicted == target).float()
    return hits.mean()

MSE = nn.MSELoss()


def Lrec(preds, target):
    """Mean-squared error between the two elements of a 2-element `preds`.

    Any `preds` whose length is not 2 yields the plain float 0.0. `target`
    is accepted for metric-signature compatibility and is not used.
    """
    if len(preds) != 2:
        return 0.000

    first, second = preds[0], preds[1]
    return MSE(first, second).float().mean()

# Single attention-loss instance shared by the per-layer metrics La1..La6.
# NOTE(review): Attention_loss presumably comes from the
# `losses.attention_loss` star import -- confirm.
LCA = Attention_loss()
def _attention_metric(preds, layer, weight=0.01):
    """Weighted attention loss for one layer, shared by La1..La6.

    Args:
        preds: either critic output (indexed as `preds[1][layer]` and
            `preds[3]`), or a 2-element value whose first element is run
            through the critic of the global `learner` to obtain that output.
        layer: index into the second element of the critic output.
        weight: scalar multiplied into the loss before averaging.

    Returns:
        The mean of `weight * LCA(critic_out[1][layer], critic_out[3])`.
    """
    if len(preds) == 2:
        critic_out = learner.gan_trainer.critic(preds[0])
    else:
        critic_out = preds
    att_loss = LCA(critic_out[1][layer], critic_out[3])
    return (weight * att_loss).float().mean()


# One thin wrapper per layer. They stay separate, individually named functions
# because each is listed by name in the learner's `metrics` and shows up as its
# own column in the training log. `target` is unused but required by the
# metric call signature.
def La1(preds, target, layer=0):
    """Attention-loss metric; defaults to layer index 0."""
    return _attention_metric(preds, layer)


def La2(preds, target, layer=1):
    """Attention-loss metric; defaults to layer index 1."""
    return _attention_metric(preds, layer)


def La3(preds, target, layer=2):
    """Attention-loss metric; defaults to layer index 2."""
    return _attention_metric(preds, layer)


def La4(preds, target, layer=3):
    """Attention-loss metric; defaults to layer index 3."""
    return _attention_metric(preds, layer)


def La5(preds, target, layer=4):
    """Attention-loss metric; defaults to layer index 4."""
    return _attention_metric(preds, layer)


def La6(preds, target, layer=5):
    """Attention-loss metric; defaults to layer index 5."""
    return _attention_metric(preds, layer)

c_entropy = nn.CrossEntropyLoss()


def CrossEnt(preds, target):
    """Cross-entropy between the class scores and `target`.

    The scores are `preds[0]`, unless `preds` has exactly two elements; in
    that case `preds` is first passed through the critic of the global
    `learner` and the first element of its output is used.

    NOTE(review): like `Acc`, the 2-element case hands the whole `preds` to
    the critic rather than `preds[0]` -- confirm this is intended.
    """
    if len(preds) == 2:
        scores = learner.gan_trainer.critic(preds)[0]
    else:
        scores = preds[0]
    return c_entropy(scores, target).float().mean()

# Single misdirection-loss instance used by the `Lm` metric.
# NOTE(review): Misdirection_loss is not imported by name; it presumably comes
# from one of the `losses.*` star imports -- confirm.
LM = Misdirection_loss()
def Lm(preds,target,layers=[0,1,2,3,4,5],gammas=[0.0002,0.0002,0.0002,0.0002,0.0002,0.0002]):
    """Gamma-weighted sum of the misdirection loss over the selected layers.

    Args:
        preds: critic output (indexed as `preds[1][layer]` and `preds[3]`), or
            a 2-element value whose first element is run through the critic of
            the global `learner` to obtain that output.
        target: unused; present for metric-signature compatibility.
        layers: layer indices to include.
        gammas: per-position weights; `gammas[i]` scales the term for `layers[i]`.

    Returns:
        The mean of the accumulated weighted loss.

    The list defaults are only read, never mutated, so sharing them across
    calls is harmless.
    """
    critic_out = learner.gan_trainer.critic(preds[0]) if len(preds) == 2 else preds

    total = 0.0
    for pos, layer in enumerate(layers):
        total = total + gammas[pos] * LM(critic_out[1][layer], critic_out[3])
    return total.float().mean()

In [9]:
# Monkey-patch: replace the batch step on the Learner class with the project's
# own implementation, so it applies to every Learner created afterwards.
# NOTE(review): `__do_one_batch` is not defined in this notebook; it presumably
# arrives via the `models.utils.fastai_gan` star import (which would require it
# to be listed in that module's __all__, given the leading underscores) -- confirm.
Learner._do_one_batch = __do_one_batch

In [10]:
# Assemble the GAN learner. The metric functions defined earlier look up this
# module-level `learner` at call time, so the name must stay `learner`.
learner = GANLearner(
    dloader,
    gen,
    crt,
    generator_loss,
    critic_loss,
    gen_first=False,
    metrics=[Acc, CrossEnt, Lm, Lrec, La1, La2, La3, La4, La5, La6],
)


In [11]:
# Train for 2 epochs with the one-cycle schedule; 5e-5 is passed as the
# learning-rate argument (the schedule's maximum LR in stock fastai).
learner.fit_one_cycle(2, 5e-5)

  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")


epoch,train_loss,valid_loss,Acc,CrossEnt,Lm,Lrec,La1,La2,La3,La4,La5,La6,gen_loss,crit_loss,time
0,3.01421,0.011771,0.420463,1.712225,0.013095,0.000388,0.084784,0.083472,0.089528,0.089856,0.095301,0.092446,0.011771,2.677783,04:23
1,2.246221,0.011511,0.590739,1.258763,0.012611,0.000158,0.081498,0.081027,0.083655,0.084983,0.089356,0.088154,0.011511,2.839994,04:20
