In [None]:
import copy
from typing import Optional, List

from fastai.distributed import *
from fastai.vision.all import *

import torch
import torch.nn.functional as F
import torchvision.transforms as Tensor
from torch import nn, Tensor
from torchvision import transforms
import torchvision.transforms as T

from scipy.spatial import distance
import numpy as np
from PIL import Image
import requests

from models.utils.distance_loss import *
from models.utils.metrics import Accuracy

In [None]:
# --- Configuration -------------------------------------------------------
H = 256  # input height in pixels
W = 256  # input width in pixels
bs = 10  # batch size; used by the dataloaders, penalty_matrix and ARViT2D_Loss
gd = 16  # grid cell side length in pixels; presumably mirrors grid_l=16 in distance_matrix - confirm

# Defined here, not in the inference cell, because the penalty-matrix cell
# calls `.to(device)` before the model is ever run.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Preprocessing for the single sample image: resize to (H, W), then convert to
# a float tensor. NOTE(review): unlike the training dataloader, no ImageNet
# normalization is applied here - confirm that is intended.
transform = T.Compose([
    T.Resize((H, W)),
    T.ToTensor(),
])

In [None]:
def distance_matrix(width=256, height=256, grid_l=16):
    """Pairwise Manhattan (L1) distances between the cells of an image grid.

    The image is split into ``(height // grid_l) x (width // grid_l)`` cells
    numbered row-major (left to right, then top to bottom). Entry ``[i, j]``
    of the result is the city-block distance, counted in cells, between cell
    ``i`` and cell ``j``.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        grid_l: Side length of one grid cell in pixels.

    Returns:
        int64 tensor of shape ``(n, n)`` with
        ``n = (width // grid_l) * (height // grid_l)``.
    """
    n_cols = width // grid_l
    n_rows = height // grid_l
    # Row-major numbering: cell i sits at (i // n_cols, i % n_cols). Using the
    # column count for both the row and column index keeps every cell distinct
    # even when width != height.
    rows = torch.arange(n_rows).repeat_interleave(n_cols)
    cols = torch.arange(n_cols).repeat(n_rows)
    coords = torch.stack((rows, cols), dim=1)  # (n, 2)
    # Broadcast |c_i - c_j| over all pairs and sum the two axes -> L1 distance.
    return (coords[:, None, :] - coords[None, :, :]).abs().sum(dim=-1)

In [None]:
# Cell-to-cell L1 distance matrix, sized from the config constants so it stays
# in sync with the input resolution (H, W) and grid cell size (gd).
dist_matrix = distance_matrix(width=W, height=H, grid_l=gd)
dist_matrix.shape

In [None]:
def penalty_weights(dist_matrix, penalty_factor="2", alpha=4, beta=500, gamma=0.1):
    """Convert a pairwise distance matrix into per-pair penalty weights.

    For row ``i`` let ``m_i`` be the maximum of column ``i``. For a symmetric
    matrix such as the output of ``distance_matrix``, this equals the row
    maximum.

    * ``penalty_factor="1"``: ``(d + gamma) / (m_i + 1)``, a linear ramp
      scaled by the largest distance.
    * ``penalty_factor="2"``: with ``a = d - m_i / alpha``, returns
      ``a / sqrt(a**2 + beta)``. This is a smooth sign-like curve in (-1, 1).
      It is negative for pairs closer than ``m_i / alpha`` and positive beyond.

    Args:
        dist_matrix: Square ``(n, n)`` distance tensor (any ``n``).
        penalty_factor: ``"1"`` or ``"2"``; ints ``1``/``2`` are accepted too.
        alpha: Divisor of the maximum distance that sets the zero crossing (mode 2).
        beta: Softness of the curve (mode 2).
        gamma: Offset added to every distance (mode 1).

    Returns:
        Float tensor with the same shape as ``dist_matrix``.

    Raises:
        ValueError: If ``penalty_factor`` is not one of the supported modes.
    """
    mode = str(penalty_factor)
    n = dist_matrix.shape[0]
    # Column-wise maximum as an (n, 1) column so it broadcasts along each row.
    col_max = dist_matrix.max(dim=0).values.reshape(n, 1)
    if mode == "1":
        return (dist_matrix + gamma) / (col_max + 1)
    if mode == "2":
        shifted = dist_matrix - col_max / alpha
        return shifted / torch.sqrt(shifted * shifted + beta)
    raise ValueError(f"penalty_factor must be '1' or '2', got {penalty_factor!r}")

In [None]:
# Penalty weights for the grid above, using the smooth ("2") mode.
pf = penalty_weights(dist_matrix, penalty_factor="2")
pf.shape

In [None]:
def penalty_matrix(bs, width=256, height=256, grid_l=16, penalty_factor="2", alpha=4, beta=500, gamma=0.1):
    """Build a batch of ``bs`` identical penalty-weight matrices.

    The grid distance matrix comes from ``distance_matrix`` and is converted
    by ``penalty_weights``. ``bs`` copies are then stacked along a new leading
    axis.

    Returns:
        Tensor of shape ``(bs, *weights.shape)``.
    """
    weights = penalty_weights(
        distance_matrix(width, height, grid_l),
        penalty_factor,
        alpha,
        beta,
        gamma,
    )
    # One copy per batch element; torch.stack materializes them as separate slices.
    return torch.stack([weights for _ in range(bs)], dim=0)

In [None]:
# One penalty matrix per batch element, built for the configured grid and
# moved to the compute device.
pm = penalty_matrix(bs, width=W, height=H, grid_l=gd).to(device)

In [None]:
#plt.subplot(221)
#plt.imshow(pm[0])#[136].reshape(16,16))

In [None]:
# Load the exported fastai Learner to obtain the model architecture, then
# overwrite its parameters with the best checkpoint.
# NOTE: load_learner and torch.load both unpickle - only load trusted files.
model_dir = Path.home()/'Luiz/saved_models/AROB'
net = load_learner(model_dir/'ARViT2D-Base_6layers.pkl', cpu=False)
weights_dir = model_dir/'best/ARViT2D-Base_6layers.pth'
model = net.model
# The .pth file is fed to load_state_dict, i.e. it is a raw state_dict, not a
# pickled Learner, so read it with torch.load. Loading to CPU is safe on any
# machine; load_state_dict copies the values onto the model's own device.
weights_dict = torch.load(weights_dir, map_location="cpu")
model.load_state_dict(weights_dict)

In [None]:
# Run one sample image through the model to inspect its auxiliary outputs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
sample_idx = 10
# The context manager closes the underlying file handle once the tensor is built.
with Image.open(f"sample_images/image{sample_idx}.jpeg") as im:
    # convert("RGB") guarantees 3 channels even for grayscale/CMYK JPEGs.
    img = transform(im.convert("RGB")).unsqueeze(0).to(device)
outputs, attn, sattn, gm = model(img)

In [None]:
# Replicate the single-image map `gm` 8 times along a new leading axis and view
# each copy as a 256 x 256 matrix (the reshape requires gm to hold 256*256 values).
gm2 = torch.stack([gm] * 8, dim=0).reshape(8, 256, 256)

In [None]:
# Trim the penalty batch so it lines up one-to-one with the gm2 copies for the
# element-wise product below. Fail loudly if there are too few penalty matrices,
# rather than hitting an obscure broadcast error later.
assert pm.shape[0] >= gm2.shape[0], (
    f"penalty batch ({pm.shape[0]}) is smaller than gm2 batch ({gm2.shape[0]})"
)
pm = pm[:gm2.shape[0]]
pm.shape

In [None]:
# Shape of the first element of `sattn` (third output of the model forward pass above).
sattn[0].shape

In [None]:
# Penalty-weighted map summed over the whole batch. It is floored at 1 so the
# log below is never taken of a value below 1 (result >= 0).
loss = pm * gm2
dist_loss = torch.sum(loss).clamp(min=1)
torch.log(dist_loss)

In [None]:
# Side-by-side view of the first penalty matrix and the first gm2 copy.
fig, (ax_pm, ax_gm) = plt.subplots(1, 2, figsize=(10, 5))
ax_pm.imshow(pm[0].detach().cpu().numpy())
ax_pm.set(title="Penalty matrix pm[0]", xlabel="grid cell j", ylabel="grid cell i")
ax_gm.imshow(gm2[0].detach().cpu().numpy())
ax_gm.set_title("gm2[0] (model output gm)")
plt.show()

In [None]:
# Training loss for the Learner below. ARViT2D_Loss is presumably provided by
# `from models.utils.distance_loss import *`. It is built with the batch size `bs`
# and layer=1 (the meaning of `layer` is defined inside ARViT2D_Loss - TODO confirm).
ARViT_Loss = ARViT2D_Loss(bs,layer=1)

In [None]:
path1 = untar_data(URLs.IMAGENETTE)


def data_loader(path, seed=None):
    """Build fastai DataLoaders for an ImageNet-style folder tree.

    Labels come from each image's parent folder name. ``RandomSplitter``
    holds out 20% of the items for validation.

    Args:
        path: Root folder searched recursively for image files.
        seed: Optional seed for the train/valid split. ``None`` keeps the
            split random.

    Returns:
        fastai ``DataLoaders`` with batch size ``bs`` and images resized to
        ``(H, W)``.
    """
    # Augmentations followed by ImageNet mean/std normalization.
    batch_tfms = [
        *aug_transforms(),
        Normalize.from_stats([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]

    data = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        splitter=RandomSplitter(valid_pct=0.2, seed=seed),
        get_y=parent_label,
        # Resize takes the size as ONE argument (an int or an (h, w) tuple);
        # Resize(H, W) would pass W as the resize *method* and ignore the width.
        item_tfms=Resize((H, W)),
        batch_tfms=batch_tfms,
    )
    return data.dataloaders(path, bs=bs)


dloader = data_loader(path1)

In [None]:
def model_head(model, n_classes, in_features=516, freeze_backbone=False):
    """Replace ``model.head`` with a fresh ``nn.Linear(in_features, n_classes)``.

    Args:
        model: Module whose ``head`` attribute is replaced in place.
        n_classes: Number of output classes.
        in_features: Input width of the new head. The default 516 matches the
            previously hard-coded value; presumably it is the feature size
            feeding the head - confirm.
        freeze_backbone: If True, only ``head.weight``/``head.bias`` keep
            ``requires_grad=True``; every other parameter is frozen.

    Returns:
        The same ``model`` (modified in place), for convenience.
    """
    # Create the head on the device of the existing weights so the model is
    # not left split across CPU and GPU.
    first_param = next(model.parameters(), None)
    head = nn.Linear(in_features, n_classes)
    if first_param is not None:
        head = head.to(first_param.device)
    model.head = head
    # NOTE(review): earlier experiments also toggled model.noise_mode /
    # model.generator_mode here; their effect is defined by the model class.

    if freeze_backbone:
        trainable = {"head.weight", "head.bias"}
        for name, param in model.named_parameters():
            param.requires_grad = name in trainable
    return model
# Size the classifier to the number of classes found by the dataloaders
# (10 for Imagenette) instead of hard-coding it.
model_head(model, len(dloader.vocab))

In [None]:
# Wrap data, model, custom loss and metric into a fastai Learner for training.
learner = Learner(
    dloader,
    model,
    loss_func=ARViT_Loss,
    metrics=[Accuracy],
)

In [None]:
# Learning-rate range test, run before choosing the max LR for fit_one_cycle below.
learner.lr_find()

In [None]:
# One epoch of one-cycle training with a peak learning rate of 2e-4.
learner.fit_one_cycle(n_epoch=1, lr_max=2e-4)