# Aprendizaje Multietiqueta de Patrones Geométricos en Objetos de Herencia Cultural
# Kunisch Features from ResNet architectures
## Seminario de Tesis II, Primavera 2022
### Master of Data Science. Universidad de Chile.
#### Prof. guía: Benjamín Bustos - Prof. coguía: Iván Sipirán
#### Autor: Matías Vergara

El objetivo de este notebook es extraer features mediante ResNet

## Imports

In [2]:
root_dir = '../'

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
%matplotlib inline
from matplotlib import pyplot as plt
from torchvision import datasets, models, transforms
import time
import os
import copy
import pandas as pd
import math
import random
import shutil

from torch.utils.data import Dataset
from PIL import Image

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np, scipy.io
import argparse
import json

## Configuración de dispositivo

In [5]:
CUDA_ID = 0

device = torch.device(f'cuda:{CUDA_ID}' if torch.cuda.is_available() else 'cpu')
print(f"Usando device: {torch.cuda.get_device_name(device)}")

Usando device: NVIDIA GeForce GTX 1060


## Configuración de datos y modelo

In [6]:
# Flags para los datos sintéticos
# Cada flag está asociada a una o más funciones de data augmentation.
# Los datos deben existir previamente 
# (se generan a partir del notebook split and augmentation)
USE_RN50 = False
DS_FLAGS = ['ref']
              # 'ref': [invertX, invertY],
              # 'rot': [rotate90, rotate180, rotate270],
              # 'crop': [crop] * CROP_TIMES,
              # 'blur': [blur],
              # 'emboss': [emboss],
              # 'randaug': [randaug],
              # 'rain': [rain],
              # 'elastic': [elastic]
                
# Las flags crop, randaug, elastic y gausblur 
# se pueden aplicar más de una vez c/u. 
# (si no están en DS_FLAGS, serán ignoradas).
CROP_TIMES = 1
RANDOM_TIMES = 1
ELASTIC_TIMES = 1
GAUSBLUR_TIMES = 1
K = 4

In [7]:
# Esta celda construye la variable data_flags, que lee DS_FLAGS de 
# la celda anterior y mapea su contenido a distintas rutas de 
# patrones, etiquetas y outputs
MAP_TIMES = {'crop': CROP_TIMES,
         'randaug': RANDOM_TIMES,
         'elastic': ELASTIC_TIMES,
         'gausblur': GAUSBLUR_TIMES
}

DS_FLAGS = sorted(DS_FLAGS)
data_flags = '_'.join(DS_FLAGS) if len(DS_FLAGS) > 0 else 'base'
MULTIPLE_TRANSF = MAP_TIMES.keys()
COPY_FLAGS = DS_FLAGS.copy()

for t in MULTIPLE_TRANSF:
    if t in DS_FLAGS:
        COPY_FLAGS.remove(t)
        COPY_FLAGS.append(t + str(MAP_TIMES[t]))
        data_flags = '_'.join(COPY_FLAGS)

# Revisión de los folds y creación de diccionario con paths
Kfolds = {}

for i in range(0, K):
    print("Fold ", i)
    patterns_dir = os.path.join(root_dir, 'patterns',  data_flags, str(i))
    labels_dir = os.path.join(root_dir, 'labels', data_flags, str(i))
    rn = 50 if USE_RN50 else 18
    models_path = os.path.join(root_dir, 'models', 'resnet', data_flags, f'resnet{rn}_K0.pth')
    features_dir = os.path.join(root_dir, 'features', 'resnet', data_flags, f'resnet{rn}_K{i}')

    if not (os.path.isdir(patterns_dir) and os.path.isdir(labels_dir)):
        print(patterns_dir)
        print(labels_dir)
        raise FileNotFoundError("""
        No existen directorios de datos para el conjunto de flags seleccionado. 
        Verifique que el dataset exista y, de lo contrario, llame a Split and Augmentation.
        """)
    if not (os.path.isfile(models_path)):
        print(models_path)
        raise FileNotFoundError(f"""
        No se encontró modelo para el conjunto de flags seleccionado. 
        Verifique que el modelo exista y, de lo contrario, llame a ResNet Retraining
        """)
                                
    Kfolds[i] = {
        'patterns_dir': patterns_dir,
        'labels_dir': labels_dir,
        'model_path': models_path,
        'features_dir': features_dir
    }
    
    print("--Pattern set encontrado en {}".format(patterns_dir))
    print("--Labels set encontrado en {}".format(labels_dir))
    print("--Modelo encontrado en {}".format(models_path))
    print("--Features a guardar en {}".format(features_dir))


Fold  0
--Pattern set encontrado en ../patterns\ref\0
--Labels set encontrado en ../labels\ref\0
--Modelo encontrado en ../models\resnet\ref\resnet18_K0.pth
--Features a guardar en ../features\resnet\ref\resnet18_K0
Fold  1
--Pattern set encontrado en ../patterns\ref\1
--Labels set encontrado en ../labels\ref\1
--Modelo encontrado en ../models\resnet\ref\resnet18_K0.pth
--Features a guardar en ../features\resnet\ref\resnet18_K1
Fold  2
--Pattern set encontrado en ../patterns\ref\2
--Labels set encontrado en ../labels\ref\2
--Modelo encontrado en ../models\resnet\ref\resnet18_K0.pth
--Features a guardar en ../features\resnet\ref\resnet18_K2
Fold  3
--Pattern set encontrado en ../patterns\ref\3
--Labels set encontrado en ../labels\ref\3
--Modelo encontrado en ../models\resnet\ref\resnet18_K0.pth
--Features a guardar en ../features\resnet\ref\resnet18_K3


## Dataset loader

In [8]:
class PatternDataset(Dataset):
    def __init__(self, root_dir, transform=None, build_classification=False, name_cla='output.cla'):
        self.root_dir = root_dir
        self.transform = transform
        self.namefiles = []
        
        self.classes = sorted(os.listdir(self.root_dir))

        for cl in self.classes:
            for pat in os.listdir(os.path.join(self.root_dir, cl)):
                self.namefiles.append((pat, cl))

        print(f'Files:{len(self.namefiles)}')
        #self.namefiles = sorted(self.namefiles, key = lambda x: x[0])
        
        if build_classification:
            dictClasses = dict()

            for cl in self.classes:
                dictClasses[cl] = []

            for index, (name, cl) in enumerate(self.namefiles):
                #print(index, name, cl)
                dictClasses[cl].append((name, index))

            with open(name_cla, 'w') as f:
                f.write('PSB 1\n')
                f.write(f'{len(self.classes)} {len(self.namefiles)}\n')
                f.write('\n')
                for cl in self.classes:
                    f.write(f'{cl} 0 {len(dictClasses[cl])}\n')
                    for item in dictClasses[cl]:
                        f.write(f'{item[1]}\n')
                    f.write('\n')

    def __len__(self):
        return len(self.namefiles)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()

        img_name = os.path.join(self.root_dir, self.namefiles[index][1], self.namefiles[index][0])
        image = Image.open(img_name).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return self.namefiles[index], image


## Funciones auxiliares

In [9]:
def imshow(inp, title = None):
    inp = inp.cpu().detach()
    inp = np.squeeze(inp)
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)

    plt.imshow(inp)
    plt.show()

def get_vector(model,layer, dim_embedding, x):

  my_embedding = torch.zeros(dim_embedding)

  def copy_data(m,i,o):
    my_embedding.copy_(o.data.squeeze())

  h = layer.register_forward_hook(copy_data)
  model(x)
  h.remove()

  return my_embedding

## Extracción de features

In [73]:
import random
random.seed(30)

for i in range(0, K):
    patterns_dir = Kfolds[i]['patterns_dir']
    labels_dir = Kfolds[i]['labels_dir']
    model_path = Kfolds[i ]['model_path']
    features_dir = Kfolds[i]['features_dir']
    
    output_train = os.path.join(features_dir, "augmented_train_df.json")
    output_val = os.path.join(features_dir, "val_df.json")
    output_test = os.path.join(features_dir, "test_df.json")

    train_df = pd.read_json(os.path.join(labels_dir, "augmented_train_df.json"), orient='index')
    val_df = pd.read_json(os.path.join(labels_dir, "val_df.json"), orient='index')
    test_df = pd.read_json(os.path.join(labels_dir, "test_df.json"), orient='index')


    my_transform = transforms.Compose([ transforms.Resize(224),
                                        #transforms.CenterCrop(224),
                         transforms.ToTensor(),
                         transforms.Normalize(mean=[0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
                        ])

    dataTrain = PatternDataset(root_dir=os.path.join(patterns_dir, 'train'), transform=my_transform)
    dataVal = PatternDataset(root_dir=os.path.join(patterns_dir, 'val'), transform=my_transform)
    dataTest = PatternDataset(root_dir=os.path.join(patterns_dir, 'test'), transform=my_transform)

    loaderTrain = DataLoader(dataTrain, shuffle="False")
    loaderVal = DataLoader(dataVal, shuffle="False")
    loaderTest = DataLoader(dataTest, shuffle="False")

    if USE_RN50:
        model = models.resnet50(pretrained = True)
    else:
        model = models.resnet18(pretrained = True)
    dim = model.fc.in_features

    output_dim = 6
    model.fc = nn.Linear(dim, output_dim)

    model = model.to(device)
    
    # Carga del modelo en models_path
    try:
        model.load_state_dict(torch.load(models_path))
    except RuntimeError as e:
        print('Ignoring "' + str(e) + '"')

    layer = model._modules.get('avgpool')

    model.eval()

    features_train = {}
    features_val = {}
    features_test = {}


    for name, img in loaderTrain:
      feat = get_vector(model, layer, dim, img.to(device))
      namefile = name[0][0]
      code, rest = namefile.split('.')
      #print(code)
      #imshow(img)
      features_train[code] = feat.numpy().tolist()
      #features.append(feat.numpy())

    for name, img in loaderVal:
      feat = get_vector(model, layer, dim, img.to(device))
      namefile = name[0][0]
      code, rest = namefile.split('.')
      #print(code)
      #imshow(img) 
      features_val[code] = feat.numpy().tolist()
      #features.append(feat.numpy())

    for name, img in loaderTest:
      feat = get_vector(model, layer, dim, img.to(device))
      namefile = name[0][0]
      code, rest = namefile.split('.')
      #print(code)
      #imshow(img)
      features_test[code] = feat.numpy().tolist()
    #features = np.vstack(features)
    #print(features.shape)

    os.makedirs(features_dir, exist_ok=True)

    features_train_df = pd.DataFrame.from_dict(features_train, orient='index')
    features_val_df = pd.DataFrame.from_dict(features_val, orient='index')
    features_test_df = pd.DataFrame.from_dict(features_test, orient='index')

    features_train_df.to_json(output_train, orient='index')
    features_val_df.to_json(output_val, orient='index')
    features_test_df.to_json(output_test, orient='index')



Files:1512
Files:78
Files:194


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,502,503,504,505,506,507,508,509,510,511
32f_invY,0.216411,2.545912,0.274050,0.030606,0.207415,1.782455,1.839407,3.205649,0.184709,0.739667,...,1.175827,0.626792,0.099139,0.000000,0.441330,0.923945,0.819985,0.595752,0.000000,0.001602
47c,0.484535,2.479028,0.119957,0.391916,0.212415,0.729945,1.093478,2.766302,0.020138,1.099458,...,0.370584,0.797214,0.074316,0.308721,1.637857,0.273490,0.244051,0.231993,0.264718,0.264572
81j_invY,0.167374,1.162126,4.310850,4.277111,4.212051,0.632702,2.099672,0.174170,3.903870,0.824926,...,4.278620,3.275423,2.448764,2.637787,0.138432,5.192579,4.957649,3.598578,4.689170,3.537819
88f_invY,0.257723,0.480580,1.525188,1.474157,1.195313,0.662575,1.416506,0.075061,2.424763,0.428542,...,1.497035,0.349340,0.874630,0.724427,0.312005,2.375319,3.028110,1.092815,1.722909,1.284860
34c,0.549399,2.581027,0.383845,0.043184,0.314326,1.680813,2.404783,3.103560,0.078637,0.467551,...,0.791483,0.719377,0.337735,0.000000,0.463197,0.543999,0.685246,0.392184,0.037586,0.014642
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
40g_invY,0.051304,0.738227,0.446983,0.164002,0.039714,1.051752,0.344923,2.924332,0.269679,0.158704,...,0.005367,0.179380,0.092836,0.062302,1.787000,0.000000,0.017328,0.311029,0.093528,0.408777
78a_invY,0.666827,2.223170,3.663968,2.954900,3.337783,1.446215,1.498405,1.445929,2.544292,0.260688,...,2.950609,3.393563,2.691127,2.364173,0.210872,2.765521,2.706462,2.799938,3.982060,3.402439
29f_invX,0.869446,2.458507,0.186262,0.078984,0.325661,2.017623,2.520950,4.009588,0.027564,0.513463,...,1.108680,0.663844,0.389342,0.000925,0.748815,0.822696,0.747806,0.587260,0.016361,0.027530
49a_invY,0.088067,1.347415,0.149512,0.313623,0.100394,0.064559,0.311103,0.921756,0.005480,0.500713,...,0.107551,0.395049,0.064779,0.649875,2.115431,0.024865,0.014013,0.066650,0.137145,0.157873


Files:1512
Files:78
Files:194


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,502,503,504,505,506,507,508,509,510,511
82a_invY,1.072594,1.427161,4.033115,3.399431,4.051223,1.064157,2.565021,0.346391,3.541313,0.510676,...,3.743491,3.246346,3.398809,2.454486,0.102613,4.287403,4.343206,3.469133,4.418277,3.625623
8c,0.001493,0.199711,1.726705,1.145872,0.553033,0.235243,0.574435,0.235280,1.516833,0.540164,...,0.936664,0.338214,0.241967,0.493356,1.128581,0.940068,1.544513,0.773725,0.998768,0.590011
16f_invX,0.846576,5.367227,1.033123,0.037234,1.123403,2.744650,3.948873,4.434862,0.000000,0.677928,...,2.244979,2.141965,0.625196,0.000000,0.403097,1.249281,0.740420,1.453088,0.000000,0.000000
39c_invY,0.012398,0.191240,0.785449,0.196365,0.010843,1.037975,0.080920,2.721721,0.222794,0.397486,...,0.022338,0.028560,0.004613,0.030188,2.597660,0.000000,0.000000,0.301180,0.227530,0.530134
70g,1.974855,0.067050,0.035863,1.040721,0.295270,0.165782,0.908374,0.129319,0.223767,0.037907,...,0.109910,0.070594,1.275713,2.946160,1.895627,0.186038,0.560121,0.111808,0.814316,1.247456
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8a_invY,0.073693,0.077472,0.381961,0.199609,0.062655,0.153989,0.309793,0.080791,0.071790,0.475557,...,0.197192,0.049841,0.042692,0.064621,1.419424,0.045911,0.074147,0.247505,0.010024,0.011854
26a_invX,0.055551,1.324776,0.503448,0.005300,0.093732,1.252469,1.177673,1.268006,0.173990,0.550809,...,1.216022,0.315240,0.038877,0.002292,0.590155,0.846484,0.771976,0.298893,0.003486,0.005490
78g_invX,1.880470,0.103724,0.419560,1.131212,0.522444,0.430493,0.422485,0.057715,0.685858,1.003886,...,0.877881,0.750161,1.150021,0.845831,0.347843,0.897877,0.459540,1.415521,0.531890,0.996927
25c_invX,0.476791,4.396452,0.393892,0.000000,0.265091,2.394427,3.025063,4.723989,0.000000,1.023598,...,1.631694,1.360861,0.105560,0.000165,0.307730,1.075669,0.808615,0.833923,0.000000,0.000000


Files:1512
Files:78
Files:194


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,502,503,504,505,506,507,508,509,510,511
88c_invX,0.195959,0.164719,1.775942,0.654523,0.618389,1.268099,0.634282,0.629411,1.494126,0.350196,...,0.805918,0.401608,0.839713,0.528752,0.332427,0.686633,1.214921,1.296892,1.247925,1.393699
26c,0.568114,2.732448,0.133458,0.000000,0.206960,2.318360,2.103534,3.722894,0.014717,1.021454,...,1.437675,0.701021,0.222803,0.000000,0.510308,1.046449,0.639693,0.887953,0.000000,0.000000
44o,0.077733,0.042295,0.532441,0.513983,0.009748,0.911816,0.194794,1.505745,0.481902,0.593585,...,0.001312,0.005070,0.096643,0.430190,2.297871,0.000000,0.016692,0.380887,0.666745,1.103290
36d_invY,0.060210,0.593921,0.400025,0.122684,0.079162,0.705904,0.306907,2.839922,0.213602,0.128790,...,0.000288,0.067545,0.010191,0.009808,1.612692,0.000000,0.024584,0.103930,0.062541,0.188538
44g_invY,0.019854,0.018246,0.308339,0.737510,0.000586,0.800345,0.071201,1.753907,0.527728,0.437807,...,0.050939,0.042971,0.048171,0.401792,1.941099,0.077835,0.027707,0.204953,0.437603,0.919789
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
94c,1.290805,0.325356,0.282210,0.047970,0.643613,0.745732,2.110984,0.242463,0.012401,0.117520,...,0.223279,0.150560,0.849458,0.389930,0.477408,0.167774,0.344012,0.076085,0.031008,0.105726
91g_invX,1.465635,0.946826,0.589266,0.740385,0.996050,0.626214,1.808609,1.352582,0.608689,0.220137,...,0.385490,0.702968,1.059101,0.968509,0.399511,0.521361,0.940603,0.283190,1.504369,1.386683
7c_invX,0.085936,0.966380,0.683972,0.635693,0.066633,0.555428,0.667358,1.488305,0.230282,0.364355,...,0.336890,0.216069,0.009221,0.239407,1.194931,0.320945,0.487119,0.044892,0.186002,0.173820
45d_invX,1.658123,0.005410,0.805805,0.778512,0.359044,1.600987,1.804917,1.491762,1.515018,0.247283,...,0.113331,0.020131,1.147446,1.619789,0.784828,0.127292,1.311321,0.235353,1.774647,2.395244


Files:1512
Files:78
Files:194


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,502,503,504,505,506,507,508,509,510,511
1a_invY,0.000000,0.444001,0.469720,0.342201,0.128196,0.098404,0.474952,1.164770,0.609645,0.418495,...,0.065258,0.042943,0.000594,0.084639,0.738033,0.107680,1.132334,0.019410,0.315884,0.004962
76j_invY,0.960298,0.614653,3.714630,3.540163,3.252182,0.779323,1.059946,0.354960,3.365393,0.459956,...,2.940090,2.690079,3.230767,2.802829,0.147816,3.265899,3.315222,3.185577,4.274198,3.966321
42c,0.000000,1.260223,0.597104,0.000000,0.065283,1.490760,0.613478,4.135476,0.096593,0.299561,...,0.132280,0.219124,0.000000,0.058742,1.785876,0.000000,0.006349,0.114979,0.016855,0.199870
80e,0.784384,2.206610,5.592750,5.115158,4.876832,1.154241,1.113116,0.744520,4.340173,0.512180,...,4.884574,4.842781,4.164837,3.840490,0.113775,4.819702,4.184783,4.869024,5.981846,5.028915
85d_invY,0.183677,0.817187,1.670355,1.441479,1.149208,0.626146,0.875047,0.213847,1.808618,0.667842,...,1.808463,1.007956,1.010991,0.684361,0.356242,2.066576,2.102584,1.562619,1.302665,0.982233
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14d_invY,0.061475,2.638050,0.657559,0.031611,0.139960,1.411546,1.177368,3.739347,0.145887,0.632488,...,0.353763,0.537504,0.046699,0.120497,1.374792,0.012663,0.043879,0.600542,0.018930,0.132879
67h_invX,1.128176,0.114723,0.231848,1.491001,0.000000,0.012045,0.043484,0.243195,0.000000,1.191511,...,0.000000,0.000843,0.000000,1.563635,2.355037,0.000000,0.000000,0.000000,0.656316,0.821345
11f_invY,0.440269,3.150263,0.195141,0.049571,0.142829,2.160075,2.234122,3.944328,0.076729,0.999343,...,1.542302,0.775421,0.123188,0.000000,0.463660,1.201396,0.909048,0.770520,0.000000,0.000000
82j_invY,0.267983,1.495764,3.343214,2.978197,3.156504,0.847442,1.661190,0.284818,2.479091,0.615175,...,3.633921,3.023523,2.206385,1.715760,0.173127,3.873477,3.389818,3.245658,2.959706,2.272059
