# Aprendizaje Multietiqueta de Patrones Geométricos en Objetos de Herencia Cultural
# Kunisch Features from ResNet architectures
## Seminario de Tesis II, Primavera 2022
### Master of Data Science. Universidad de Chile.
#### Prof. guía: Benjamín Bustos - Prof. coguía: Iván Sipirán
#### Autor: Matías Vergara


## Imports

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms
import time
import os
import copy
import pandas as pd
import math
import random
import shutil

from torch.utils.data import Dataset
from PIL import Image

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np, scipy.io
import argparse
import json

## Mounting Google Drive

In [9]:
try:
    from google.colab import drive
    drive.mount('/content/drive')
    root_dir = 'drive/MyDrive/TesisMV/'
except:
    root_dir = '../'

## Dataset and model selection

In [18]:
#modify only this cell
USE_RN50 = True
SUBCHAPTERS = False
DS_FLAGS = ['blur']
              # 'ref': [invertX, invertY],
              # 'rot': [rotate90, rotate180, rotate270],
              # 'crop': [crop] * CROP_TIMES,
              # 'blur': [blur],
              # 'emboss': [emboss],
              # 'randaug': [randaug],
              # 'rain': [rain],
              # 'elastic': [elastic]
CROP_TIMES = 1
RANDOM_TIMES = 1
ELASTIC_TIMES = 1
GAUSBLUR_TIMES = 1
K = 4

In [23]:
# This cells builds the data_flags variable, that will be used
# to map the requested data treatment to folders
MAP_TIMES = {'crop': CROP_TIMES,
         'randaug': RANDOM_TIMES,
         'elastic': ELASTIC_TIMES,
         'gausblur': GAUSBLUR_TIMES
}

DS_FLAGS = sorted(DS_FLAGS)
data_flags = '_'.join(DS_FLAGS) if len(DS_FLAGS) > 0 else 'base'
MULTIPLE_TRANSF = ['crop', 'randaug', 'elastic', 'gausblur']
COPY_FLAGS = DS_FLAGS.copy()

for t in MULTIPLE_TRANSF:
    if t in DS_FLAGS:
        COPY_FLAGS.remove(t)
        COPY_FLAGS.append(t + str(MAP_TIMES[t]))
        data_flags = '_'.join(COPY_FLAGS)

subchapter_str = 'subchapters/' if SUBCHAPTERS else ''
Kfolds = {}

for i in range(0, K):
    print("Fold ", i)
    patterns_dir = os.path.join(root_dir, 'patterns', subchapter_str + data_flags, str(i))
    labels_dir = os.path.join(root_dir, 'labels', subchapter_str + data_flags, str(i))
    # models_path = folder_path + 'models/' + subchapter_str + (f'resnet50_{data_flags}.pth' if USE_RN50 else f'resnet18_{data_flags}.pth')
    # features_path = folder_path + 'features/' + subchapter_str + (f'resnet50_{data_flags}/' if USE_RN50 else f'resnet18_{data_flags}/')
    #rn = 50
    #ep = 65
    #models_path = folder_path + f"models/resnet{rn}_blur_each5/resnet{rn}_blur_e{ep}.pth"
    #features_path = folder_path + f"features/resnet{rn}_blur_each5/resnet{rn}_blur_e{ep}/"
    rn = 50 if USE_RN50 else 18
    models_path = os.path.join(root_dir, 'models', 'resnet', data_flags, f'resnet{rn}_K0.pth')
    features_dir = os.path.join(root_dir, 'features', 'resnet', data_flags, f'resnet{rn}_K{i}')

    if not (os.path.isdir(patterns_dir) and os.path.isdir(labels_dir)):
        print(patterns_dir)
        print(labels_dir)
        raise FileNotFoundError("""
        No existen directorios de datos para el conjunto de flags seleccionado. 
        Verifique que el dataset exista y, de lo contrario, llame a Split and Augmentation.
        """)
    if not (os.path.isfile(models_path)):
        print(models_path)
        raise FileNotFoundError(f"""
        No se encontró modelo para el conjunto de flags seleccionado. 
        Verifique que el modelo exista y, de lo contrario, llame a ResNet Retraining
        """)
                                
    Kfolds[i] = {
        'patterns_dir': patterns_dir,
        'labels_dir': labels_dir,
        'model_path': models_path,
        'features_dir': features_dir
    }
    
    print("--Pattern set encontrado en {}".format(patterns_dir))
    print("--Labels set encontrado en {}".format(labels_dir))
    print("--Modelo encontrado en {}".format(models_path))
    print("--Features a guardar en {}".format(features_dir))


Fold  0
--Pattern set encontrado en ../patterns\blur\0
--Labels set encontrado en ../labels\blur\0
--Modelo encontrado en ../models\resnet\blur\resnet50_K0.pth
--Features a guardar en ../features\resnet\blur\resnet50_K0
Fold  1
--Pattern set encontrado en ../patterns\blur\1
--Labels set encontrado en ../labels\blur\1
--Modelo encontrado en ../models\resnet\blur\resnet50_K0.pth
--Features a guardar en ../features\resnet\blur\resnet50_K1
Fold  2
--Pattern set encontrado en ../patterns\blur\2
--Labels set encontrado en ../labels\blur\2
--Modelo encontrado en ../models\resnet\blur\resnet50_K0.pth
--Features a guardar en ../features\resnet\blur\resnet50_K2
Fold  3
--Pattern set encontrado en ../patterns\blur\3
--Labels set encontrado en ../labels\blur\3
--Modelo encontrado en ../models\resnet\blur\resnet50_K0.pth
--Features a guardar en ../features\resnet\blur\resnet50_K3


## Dataset loader

In [21]:
class PatternDataset(Dataset):
    def __init__(self, root_dir, transform=None, build_classification=False, name_cla='output.cla'):
        self.root_dir = root_dir
        self.transform = transform
        self.namefiles = []


        self.classes = sorted(os.listdir(self.root_dir))

        for cl in self.classes:
            for pat in os.listdir(os.path.join(self.root_dir, cl)):
                self.namefiles.append((pat, cl))

        print(f'Files:{len(self.namefiles)}')
        self.namefiles = sorted(self.namefiles, key = lambda x: x[0])

        if build_classification:
            dictClasses = dict()

            for cl in self.classes:
                dictClasses[cl] = []

            for index, (name, cl) in enumerate(self.namefiles):
                dictClasses[cl].append((name, index))

            with open(name_cla, 'w') as f:
                f.write('PSB 1\n')
                f.write(f'{len(self.classes)} {len(self.namefiles)}\n')
                f.write('\n')
                for cl in self.classes:
                    f.write(f'{cl} 0 {len(dictClasses[cl])}\n')
                    for item in dictClasses[cl]:
                        f.write(f'{item[1]}\n')
                    f.write('\n')

    def __len__(self):
        return len(self.namefiles)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        img_name = os.path.join(self.root_dir, self.namefiles[index][1], self.namefiles[index][0])
        image = Image.open(img_name).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return self.namefiles[index], image


## Funciones auxiliares

In [22]:
def imshow(inp, title = None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)

    plt.imshow(inp)
    plt.show()

def get_vector(model,layer, dim_embedding, x):

  my_embedding = torch.zeros(dim_embedding)

  def copy_data(m,i,o):
    my_embedding.copy_(o.data.squeeze())

  h = layer.register_forward_hook(copy_data)
  model(x)
  h.remove()

  return my_embedding

## Extraction

In [24]:
DEVICE = 0
# 0 3090 (o 1060 en local)
# 1 y 2 2080

In [25]:
random.seed(30)

for i in range(0, K):
    patterns_dir = Kfolds[i]['patterns_dir']
    labels_dir = Kfolds[i]['labels_dir']
    model_path = Kfolds[i]['model_path']
    features_dir = Kfolds[i]['features_dir']
    
    output_train = os.path.join(features_dir, "augmented_train_df.json")
    output_val = os.path.join(features_dir, "val_df.json")
    output_test = os.path.join(features_dir, "test_df.json")

    train_df = pd.read_json(os.path.join(labels_dir, "augmented_train_df.json"), orient='index')
    val_df = pd.read_json(os.path.join(labels_dir, "val_df.json"), orient='index')
    test_df = pd.read_json(os.path.join(labels_dir, "test_df.json"), orient='index')

    train_pts = train_df.index.values
    val_pts = val_df.index.values
    test_pts = test_df.index.values

    device = ('cuda:0' if torch.cuda.is_available() else None)
    if device is None:
        raise Exception("La GPU solicitada no está disponible")

    my_transform = transforms.Compose([ transforms.Resize(224),
                                        #transforms.CenterCrop(224),
                         transforms.ToTensor(),
                         transforms.Normalize(mean=[0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
                        ])

    dataTrain = PatternDataset(root_dir=os.path.join(patterns_dir, 'train'), transform=my_transform)
    dataVal = PatternDataset(root_dir=os.path.join(patterns_dir, 'val'), transform=my_transform)
    dataTest = PatternDataset(root_dir=os.path.join(patterns_dir, 'test'), transform=my_transform)

    loaderTrain = DataLoader(dataTrain)
    loaderVal = DataLoader(dataVal)
    loaderTest = DataLoader(dataTest)

    if USE_RN50:
        model = models.resnet50(pretrained = True)
    else:
        model = models.resnet18(pretrained = True)
    dim = model.fc.in_features

    output_dim = 20 if SUBCHAPTERS else 6
    model.fc = nn.Linear(dim, output_dim)

    model = model.to(device)

    try:
        model.load_state_dict(torch.load(models_path))
    except RuntimeError as e:
        print('Ignoring "' + str(e) + '"')

    layer = model._modules.get('avgpool')

    model.eval()

    #features = []
    features_train = {}
    features_val = {}
    features_test = {}


    for name, img in loaderTrain:
      feat = get_vector(model, layer, dim, img.to(device))
      namefile = name[0][0]
      code, rest = namefile.split('.')
      features_train[code] = feat.numpy().tolist()
      #features.append(feat.numpy())

    for name, img in loaderVal:
      feat = get_vector(model, layer, dim, img.to(device))
      namefile = name[0][0]
      code, rest = namefile.split('.')
      features_val[code] = feat.numpy().tolist()
      #features.append(feat.numpy())

    for name, img in loaderTest:
      feat = get_vector(model, layer, dim, img.to(device))
      namefile = name[0][0]
      code, rest = namefile.split('.')
      features_test[code] = feat.numpy().tolist()
    #features = np.vstack(features)
    #print(features.shape)

    os.makedirs(features_dir, exist_ok=True)

    features_train_df = pd.DataFrame.from_dict(features_train, orient='index')
    features_val_df = pd.DataFrame.from_dict(features_val, orient='index')
    features_test_df = pd.DataFrame.from_dict(features_test, orient='index')

    features_train_df.to_json(output_train, orient='index')
    features_val_df.to_json(output_val, orient='index')
    features_test_df.to_json(output_test, orient='index')

    display(features_train_df)

Files:1008
Files:78
Files:194


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
10a,0.093543,0.348628,0.171098,0.329167,0.175677,0.058282,0.283390,0.037341,0.046214,0.001295,...,0.039145,0.039882,0.006907,0.027918,0.080345,0.016563,0.018120,0.043452,0.060444,0.002079
10a_blur,0.091480,0.320457,0.146162,0.328436,0.153196,0.067046,0.278890,0.044845,0.046699,0.000926,...,0.054381,0.054279,0.004835,0.027717,0.077647,0.013934,0.017158,0.042639,0.078073,0.002629
10b,0.069442,0.324692,0.145514,0.369813,0.095713,0.087898,0.260913,0.148231,0.043951,0.139633,...,0.152153,0.278249,0.047254,0.057533,0.086026,0.011156,0.100416,0.032219,0.142593,0.007896
10b_blur,0.064725,0.303267,0.141621,0.364327,0.104214,0.108745,0.258155,0.194874,0.042243,0.202844,...,0.198267,0.398158,0.058904,0.082508,0.075959,0.014638,0.130843,0.031133,0.183328,0.007138
10c,0.063454,0.497912,0.249252,0.065314,0.031488,0.072026,0.324419,0.041397,0.031247,0.137406,...,0.048290,0.132753,0.072855,0.052035,0.154804,0.003347,0.126461,0.043257,0.075543,0.002115
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9b_blur,0.101147,0.189822,0.079654,0.155821,0.106691,0.378813,0.223935,0.475222,0.068199,0.635979,...,0.491361,1.520536,0.221814,0.367487,0.043188,0.043173,0.351386,0.046060,0.605617,0.033140
9c,0.154298,0.376077,0.179643,0.053473,0.085490,0.580383,0.356863,0.432185,0.113946,0.534076,...,0.393091,1.792329,0.202303,0.466538,0.035965,0.020529,0.353905,0.073029,0.776988,0.036277
9c_blur,0.147101,0.368076,0.173734,0.046606,0.092645,0.609050,0.354745,0.415079,0.115079,0.522099,...,0.383552,1.843882,0.203935,0.495830,0.034736,0.021447,0.359334,0.071737,0.821086,0.033539
9d,0.095113,0.174041,0.091930,0.021109,0.049420,0.401161,0.230249,0.306228,0.098975,0.598888,...,0.343681,1.413953,0.268906,0.373673,0.036702,0.027089,0.387893,0.051760,0.640358,0.027142


Files:1008
Files:78
Files:194


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
10c,0.063454,0.497912,0.249252,0.065314,0.031488,0.072026,0.324419,0.041397,0.031247,0.137406,...,0.048290,0.132753,0.072855,0.052035,0.154804,0.003347,0.126461,0.043257,0.075543,0.002115
10c_blur,0.057371,0.465704,0.214624,0.065697,0.031990,0.078725,0.313281,0.048767,0.031350,0.151528,...,0.054668,0.158669,0.079237,0.057573,0.152382,0.003550,0.133361,0.043165,0.087062,0.001909
10d,0.071531,0.320793,0.139458,0.145462,0.039113,0.129930,0.290078,0.115839,0.041396,0.195424,...,0.138473,0.333107,0.101847,0.104421,0.132106,0.011607,0.103305,0.043055,0.173428,0.010620
10d_blur,0.060004,0.277608,0.118968,0.123613,0.036181,0.122241,0.262344,0.105442,0.033479,0.251395,...,0.164768,0.295900,0.138806,0.097134,0.121996,0.010719,0.129154,0.036798,0.197809,0.011002
11a,0.071619,0.507139,0.263094,0.042244,0.029847,0.083769,0.454336,0.176309,0.066955,0.101543,...,0.005890,0.000000,0.006178,0.022858,0.088802,0.003125,0.161583,0.062744,0.021123,0.000029
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9b_blur,0.115453,0.204059,0.091607,0.157751,0.090004,0.376595,0.240741,0.477214,0.077845,0.611987,...,0.491067,1.505052,0.205166,0.362888,0.042410,0.040814,0.331266,0.053792,0.589591,0.031049
9c,0.154298,0.376077,0.179643,0.053473,0.085490,0.580383,0.356863,0.432185,0.113946,0.534076,...,0.393091,1.792329,0.202303,0.466538,0.035965,0.020529,0.353905,0.073029,0.776988,0.036277
9c_blur,0.143651,0.357433,0.168742,0.041100,0.096947,0.622166,0.348449,0.416576,0.113355,0.529945,...,0.387073,1.889593,0.207370,0.510393,0.034813,0.022634,0.361864,0.070757,0.847610,0.034380
9d,0.095113,0.174041,0.091930,0.021109,0.049420,0.401161,0.230249,0.306228,0.098975,0.598888,...,0.343681,1.413953,0.268906,0.373673,0.036702,0.027089,0.387893,0.051760,0.640358,0.027142


Files:1008
Files:78
Files:194


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
10a,0.093543,0.348628,0.171098,0.329167,0.175677,0.058282,0.283390,0.037341,0.046214,0.001295,...,0.039145,0.039882,0.006907,0.027918,0.080345,0.016563,0.018120,0.043452,0.060444,0.002079
10a_blur,0.094180,0.315949,0.145740,0.327025,0.149733,0.068750,0.282146,0.045048,0.049081,0.000971,...,0.059072,0.057658,0.003879,0.027660,0.077486,0.014414,0.016173,0.043979,0.082746,0.002938
10d,0.071531,0.320793,0.139458,0.145462,0.039113,0.129930,0.290078,0.115839,0.041396,0.195424,...,0.138473,0.333107,0.101847,0.104421,0.132106,0.011607,0.103305,0.043055,0.173428,0.010620
10d_blur,0.061044,0.300474,0.124623,0.141486,0.037581,0.132050,0.276329,0.112355,0.037082,0.204165,...,0.143859,0.324835,0.109270,0.103000,0.131731,0.011030,0.098215,0.039527,0.179486,0.010105
10e,0.150976,0.511532,0.294358,0.040990,0.030825,0.117814,0.554213,0.144390,0.109978,0.120161,...,0.050571,0.000416,0.026723,0.028624,0.068219,0.003754,0.085308,0.075878,0.042042,0.004458
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9a_blur,0.046878,0.197632,0.161722,0.088945,0.059041,0.211574,0.215369,0.343269,0.047311,0.612756,...,0.429109,0.988154,0.239265,0.227270,0.045766,0.014799,0.314182,0.031282,0.368519,0.049790
9b,0.137744,0.242332,0.118246,0.168992,0.070164,0.384724,0.268882,0.467853,0.095913,0.551952,...,0.468065,1.485050,0.184823,0.363987,0.039774,0.034105,0.289622,0.063914,0.578362,0.031783
9b_blur,0.128136,0.220922,0.102040,0.160426,0.079825,0.385871,0.255774,0.479550,0.086161,0.591691,...,0.488345,1.505256,0.194825,0.366735,0.040755,0.038488,0.315867,0.061480,0.589053,0.029882
9e,0.096564,0.156842,0.095601,0.025917,0.106395,0.422722,0.164068,0.636789,0.069676,0.892272,...,0.632427,2.007805,0.318327,0.442797,0.038842,0.036151,0.478321,0.038314,0.791987,0.028513


Files:1008
Files:78
Files:194


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
10a,0.093543,0.348628,0.171098,0.329167,0.175677,0.058282,0.283390,0.037341,0.046214,0.001295,...,0.039145,0.039882,0.006907,0.027918,0.080345,0.016563,0.018120,0.043452,0.060444,0.002079
10a_blur,0.090935,0.324797,0.152496,0.327198,0.155951,0.063335,0.280689,0.043367,0.046998,0.001052,...,0.052599,0.053126,0.004821,0.027952,0.078567,0.014082,0.016661,0.042222,0.076228,0.002637
10d,0.071531,0.320793,0.139458,0.145462,0.039113,0.129930,0.290078,0.115839,0.041396,0.195424,...,0.138473,0.333107,0.101847,0.104421,0.132106,0.011607,0.103305,0.043055,0.173428,0.010620
10d_blur,0.060004,0.277608,0.118968,0.123613,0.036181,0.122241,0.262344,0.105442,0.033479,0.251395,...,0.164768,0.295900,0.138806,0.097134,0.121996,0.010719,0.129154,0.036798,0.197809,0.011002
10e,0.150976,0.511532,0.294358,0.040990,0.030825,0.117814,0.554213,0.144390,0.109978,0.120161,...,0.050571,0.000416,0.026723,0.028624,0.068219,0.003754,0.085308,0.075878,0.042042,0.004458
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9c_blur,0.143651,0.357433,0.168742,0.041100,0.096947,0.622166,0.348449,0.416576,0.113355,0.529945,...,0.387073,1.889593,0.207370,0.510393,0.034813,0.022634,0.361864,0.070757,0.847610,0.034380
9d,0.095113,0.174041,0.091930,0.021109,0.049420,0.401161,0.230249,0.306228,0.098975,0.598888,...,0.343681,1.413953,0.268906,0.373673,0.036702,0.027089,0.387893,0.051760,0.640358,0.027142
9d_blur,0.091476,0.171753,0.088472,0.022720,0.047098,0.437656,0.218932,0.352733,0.090297,0.696122,...,0.412202,1.551421,0.292673,0.406713,0.035335,0.024025,0.449227,0.048842,0.724628,0.028800
9e,0.096564,0.156842,0.095601,0.025917,0.106395,0.422722,0.164068,0.636789,0.069676,0.892272,...,0.632427,2.007805,0.318327,0.442797,0.038842,0.036151,0.478321,0.038314,0.791987,0.028513
