In [10]:
from fastai.vision.all import *
from fastai.distributed import *
from fastai.metrics import error_rate
from fastai.callback.tracker import SaveModelCallback
import argparse

from torchvision import datasets, transforms, models
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import os
import copy
import torchvision.transforms as T
import torch
from torch.nn.parallel import DistributedDataParallel
from torchvision.transforms.functional import *

from PIL import Image
import requests

import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from torch import nn

import argparse
from models.utils.joiner3 import ImageNetJoiner
from models.utils.new_losses import *
from models.utils.metrics import *
from models.utils.dataLoader import *
from models.utils.datasets import *

In [11]:
# Pin this notebook to one GPU when CUDA is present, otherwise fall back to CPU.
# Checking availability first keeps torch.cuda.set_device from raising on
# CPU-only machines, so the CPU fallback is actually reachable.
GPU_ID = 3
if torch.cuda.is_available():
    torch.cuda.set_device(GPU_ID)
    device = torch.device(f"cuda:{GPU_ID}")
else:
    device = torch.device("cpu")

In [12]:
# Image size constants (height, width) and the batch size passed to the DataLoaders.
H = W = 256
bs = 5

In [13]:
# Download (or reuse the cached copy of) the fastai FLOWERS dataset and make
# displayed paths relative to its root.
path = untar_data(URLs.FLOWERS)
Path.BASE_PATH = path
# NOTE(review): the label CSV is read relative to the notebook's working
# directory; it must provide the 'name' and 'class' columns used by get_x/get_y.
df = pd.read_csv('data/flowers.csv')

# Batch-level transforms: fastai's default augmentations, then normalisation
# with the ImageNet per-channel mean/std.
transform = [*aug_transforms(),
             Normalize.from_stats([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]


def get_x(r):
    "Image path for a CSV row: the row's 'name' joined onto the dataset root."
    return path/r['name']


def get_y(r):
    "Label for a CSV row, taken from its 'class' column."
    return r['class']


dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   n_inp=1,
                   splitter=RandomSplitter(seed=42),
                   get_x=get_x,
                   get_y=get_y,
                   # Use the size constants instead of repeating the literal 256.
                   item_tfms=Resize((H, W)),
                   batch_tfms=transform)

dloader = dblock.dataloaders(df, bs=bs)

In [14]:
def paths(layer = 1, dataset = 'FLOWERS', base_model=False, best_model=False, pf="3",gm16=False):
    """Return ``(model_path, weights_path)`` for a pretrained backbone and its fine-tuned weights.

    Args:
        layer: Layer index embedded in the file names. When ``base_model`` is
            true it is replaced by the literal ``'BASE_MODEL'``.
        dataset: Dataset tag embedded in the fine-tuned weight file name.
        base_model: Use the base backbone pickle instead of a per-layer one.
        best_model: Select the ``__BestModel`` variant of the weight file.
        pf: Penalty-factor tag. It is formatted with ``str()``, so ``3`` and
            ``"3"`` produce the same paths.
        gm16: Select the ``_gm16`` file variants. For the per-layer backbone
            this also switches the beta tag from ``5e-6`` to ``5e-7``.

    Returns:
        Tuple of two ``pathlib.Path`` objects under ``~/Luiz``.
    """
    model_dir = Path.home()/'Luiz/saved_models'
    weight_dir = Path.home()/'Luiz/gan_attention/models/finetuned'
    gm_suffix = '_gm16' if gm16 else ''

    if base_model:
        model_name = 'GramImageNet_Rotation_16x16grid_epochs-90_BaseModel.pkl'
        # The weight file names use this tag instead of a layer number.
        layer = 'BASE_MODEL'
    else:
        beta = '5e-7' if gm16 else '5e-6'
        model_name = (f'GramImageNet_Rotation_16x16grid_epochs-90-beta-{beta}'
                      f'_PenaltyFactor{pf}_Layer{layer}{gm_suffix}.pkl')

    best_tag = '__BestModel' if best_model else ''
    weights_name = (f'LAYER_{layer}_LOSS_{pf}{best_tag}__FineTuned__{dataset}'
                    f'__BestWeights{gm_suffix}.pth')

    return model_dir/model_name, weight_dir/weights_name

In [15]:
# Built once at definition time instead of on every get_gm call.
# The dot before 'jpeg' is escaped so it only matches a literal '.'.
_GM_ID_LABELLER = RegexLabeller(pat=r'image(.*?)\.jpeg')


def get_gm(r):
    """Load the precomputed Gram-matrix tensor for image ``r`` together with its class label.

    ``r`` is presumably a path like ``.../<label>/image<ID>.jpeg``. Its parent
    folder name is converted with ``int`` to form the label. ``<ID>`` selects
    the file ``<save_path>/gramm/<label>/gm<ID>.pt``.

    NOTE(review): ``save_path`` is a module-level global that is not defined in
    this notebook; it must be set before calling. It may be a ``str`` or a
    ``pathlib.Path``.

    Returns:
        ``(gm, TensorCategory(int(label)))``.
    """
    label = parent_label(r)
    gm_id = _GM_ID_LABELLER(r.name)
    # torch.load unpickles the file: only use it on trusted data.
    gm = torch.load(Path(save_path)/"gramm"/str(label)/f"gm{gm_id}.pt")
    return gm, TensorCategory(int(label))

def load_model(model_path):
    """Load a fastai ``Learner`` exported to ``model_path`` and return its ``.model``.

    ``cpu=False`` means the learner is not forced onto the CPU.
    NOTE: ``load_learner`` unpickles the file, which can execute arbitrary code;
    only load trusted exports.
    """
    return load_learner(model_path, cpu=False).model

def model_head(model, n_classes, in_features=512*16*16):
    """Replace ``model.head`` with a fresh linear classifier and freeze everything else.

    The model is modified in place; nothing is returned. Afterwards only the
    new head's weight and bias have ``requires_grad=True``.

    Args:
        model: ``nn.Module`` whose ``head`` attribute is replaced.
        n_classes: Output size of the new head.
        in_features: Input size of the new head. The default ``512*16*16``
            presumably matches the backbone's flattened 512-channel 16x16
            feature map (TODO confirm against the backbone).
    """
    model.head = nn.Linear(in_features, n_classes)

    # Freeze the whole model, then unfreeze just the freshly created head.
    for p in model.parameters():
        p.requires_grad = False
    for p in model.head.parameters():
        p.requires_grad = True

In [16]:
def model_test(n_class=102, layer = 1, dataset = 'FLOWERS', base_model=False, best_model=False, epochs=1, lr=1e-9, pf="3",gm16=False, dls=None):
    """Rebuild a saved backbone with a new head, load fine-tuned weights, and run ``fit_one_cycle``.

    Args:
        n_class: Number of output classes for the new head.
        layer, dataset, base_model, best_model, pf, gm16: Forwarded to ``paths``
            to select the backbone pickle and the fine-tuned ``.pth`` weights.
        epochs: Number of ``fit_one_cycle`` epochs.
        lr: Maximum learning rate for ``fit_one_cycle``. The tiny default
            presumably keeps the loaded head weights essentially unchanged, so
            the reported metrics reflect the loaded weights.
        dls: DataLoaders to train/validate on; defaults to the notebook-level
            ``dloader``.

    NOTE(review): ``fit_one_cycle`` runs the model in training mode, so any
    BatchNorm running statistics in the frozen backbone are still updated.
    For a pure evaluation, ``learner.validate()`` may be preferable.
    """
    model_path, weights_path = paths(layer, dataset, base_model, best_model, pf, gm16)
    print("path model:", model_path)
    model = load_model(model_path)
    model_head(model, n_class)

    print("path weights:", weights_path)
    # The .pth file is handed straight to load_state_dict, so it is a plain
    # state dict. Load it with torch.load rather than load_learner, which is
    # meant for exported Learners. Mapping to CPU is fine because
    # load_state_dict copies each tensor onto the parameter's own device.
    weight_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(weight_dict)
    print("weights loaded")

    if dls is None:
        dls = dloader
    learner = Learner(dls, model, loss_func=SingleLabelCriticLoss(), metrics=[_Accuracy])
    learner.fit_one_cycle(epochs, lr)

In [15]:
# Base backbone with the non-BestModel gm16 weights (layer is ignored for the base model).
model_test(gm16=True, best_model=False, base_model=True, layer=1)

path model: /home/atsumilab/Luiz/saved_models/GramImageNet_Rotation_16x16grid_epochs-90_BaseModel.pkl
path model: /home/atsumilab/Luiz/gan_attention/models/finetuned/LAYER_BASE_MODEL_LOSS_3__FineTuned__FLOWERS__BestWeights_gm16.pth
weights loaded


epoch,train_loss,valid_loss,_Accuracy,time
0,0.096783,0.575391,0.854001,00:35


In [16]:
# Base backbone with the BestModel gm16 weights (layer is ignored for the base model).
model_test(gm16=True, best_model=True, base_model=True, layer=1)

path model: /home/atsumilab/Luiz/saved_models/GramImageNet_Rotation_16x16grid_epochs-90_BaseModel.pkl
path model: /home/atsumilab/Luiz/gan_attention/models/finetuned/LAYER_BASE_MODEL_LOSS_3__BestModel__FineTuned__FLOWERS__BestWeights_gm16.pth
weights loaded


epoch,train_loss,valid_loss,_Accuracy,time
0,0.134443,0.580133,0.855834,00:34


In [17]:
# Layer-1 backbone, non-BestModel gm16 weights.
model_test(gm16=True, best_model=False, base_model=False, layer=1)

path model: /home/atsumilab/Luiz/saved_models/GramImageNet_Rotation_16x16grid_epochs-90-beta-5e-7_PenaltyFactor3_Layer1_gm16.pkl
path model: /home/atsumilab/Luiz/gan_attention/models/finetuned/LAYER_1_LOSS_3__FineTuned__FLOWERS__BestWeights_gm16.pth
weights loaded


epoch,train_loss,valid_loss,_Accuracy,time
0,0.090136,0.531896,0.857666,00:35


In [18]:
# Layer-1 backbone, BestModel gm16 weights.
model_test(gm16=True, best_model=True, base_model=False, layer=1)

path model: /home/atsumilab/Luiz/saved_models/GramImageNet_Rotation_16x16grid_epochs-90-beta-5e-7_PenaltyFactor3_Layer1_gm16.pkl
path model: /home/atsumilab/Luiz/gan_attention/models/finetuned/LAYER_1_LOSS_3__BestModel__FineTuned__FLOWERS__BestWeights_gm16.pth
weights loaded


epoch,train_loss,valid_loss,_Accuracy,time
0,0.109773,0.531371,0.858277,00:35


In [18]:
# Layer-2 backbone, non-BestModel gm16 weights.
model_test(gm16=True, best_model=False, base_model=False, layer=2)

path model: /home/atsumilab/Luiz/saved_models/GramImageNet_Rotation_16x16grid_epochs-90-beta-5e-7_PenaltyFactor3_Layer2_gm16.pkl
path model: /home/atsumilab/Luiz/gan_attention/models/finetuned/LAYER_2_LOSS_3__FineTuned__FLOWERS__BestWeights_gm16.pth
weights loaded


epoch,train_loss,valid_loss,_Accuracy,time
0,0.088686,0.584204,0.852779,00:38


In [19]:
# Layer-2 backbone, BestModel gm16 weights.
model_test(gm16=True, best_model=True, base_model=False, layer=2)

path model: /home/atsumilab/Luiz/saved_models/GramImageNet_Rotation_16x16grid_epochs-90-beta-5e-7_PenaltyFactor3_Layer2_gm16.pkl
path model: /home/atsumilab/Luiz/gan_attention/models/finetuned/LAYER_2_LOSS_3__BestModel__FineTuned__FLOWERS__BestWeights_gm16.pth
weights loaded


epoch,train_loss,valid_loss,_Accuracy,time
0,0.089912,0.57037,0.850947,01:04


In [20]:
# Layer-3 backbone, non-BestModel weights, non-gm16 variant.
model_test(gm16=False, best_model=False, base_model=False, layer=3)

path model: /home/atsumilab/Luiz/saved_models/GramImageNet_Rotation_16x16grid_epochs-90-beta-5e-6_PenaltyFactor3_Layer3.pkl
path model: /home/atsumilab/Luiz/gan_attention/models/finetuned/LAYER_3_LOSS_3__FineTuned__FLOWERS__BestWeights.pth
weights loaded


epoch,train_loss,valid_loss,_Accuracy,time
0,0.121223,0.556345,0.852779,00:35


In [21]:
# Layer-3 backbone, BestModel weights, non-gm16 variant.
model_test(gm16=False, best_model=True, base_model=False, layer=3)

path model: /home/atsumilab/Luiz/saved_models/GramImageNet_Rotation_16x16grid_epochs-90-beta-5e-6_PenaltyFactor3_Layer3.pkl
path model: /home/atsumilab/Luiz/gan_attention/models/finetuned/LAYER_3_LOSS_3__BestModel__FineTuned__FLOWERS__BestWeights.pth
weights loaded


epoch,train_loss,valid_loss,_Accuracy,time
0,0.099216,0.556308,0.855834,00:35
