Para ello ya escribimos un Dataset particular que lee el CSV y dado el índice de un ejemplo, calcula la posición de la cabeza, carga la imágen y retorna el par (imagen, posición) que para nuestro problema constituirán el par (x,y) (feature o variable independiente, target o variable dependiente):

In [1]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import torchvision.models
from torch import nn

import numpy as np
from matplotlib import pyplot
import math
import pickle
import gzip
from pathlib import Path
import requests
from tqdm import tqdm

In [2]:
resnet_model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.DEFAULT)
resnet_model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [3]:
resnet_model.fc = nn.Linear(in_features=512, out_features=10, bias=True)

In [4]:
resnet_model.fc

Linear(in_features=512, out_features=10, bias=True)

In [5]:
DATA_PATH = Path("../data")
PATH = DATA_PATH / "mnist"
FILENAME = "mnist.pkl.gz"

# Los arrays de las imagenes fueron guardados en un archivo formato pickle, que se utiliza para persistir variable en Python
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

In [6]:
x_train.shape

(50000, 784)

In [7]:
def add_channels(x):
    x=x.reshape(-1,28,28)
    x=np.stack([x,x,x],axis=1)
    return x
x_train=add_channels(x_train)
x_valid=add_channels(x_valid)
x_train.shape, x_valid.shape

((50000, 3, 28, 28), (10000, 3, 28, 28))

In [8]:
def accuracy(probs, target):
    class_predictions = torch.argmax(probs, dim=1)
    return (class_predictions == target).float().mean()

In [9]:
train_dataset = TensorDataset(torch.tensor(x_train), torch.tensor(y_train))
valid_dataset = TensorDataset(torch.tensor(x_valid), torch.tensor(y_valid))

In [10]:
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=True)

In [11]:
x,y=next(iter(train_dataloader))
x.shape, y.shape

(torch.Size([32, 3, 28, 28]), torch.Size([32]))

In [12]:
z=resnet_model(x)
z.shape

torch.Size([32, 10])

In [13]:
loss_fn = nn.CrossEntropyLoss()

In [14]:
all_weights = list(resnet_model.parameters())

In [15]:
def freeze(ws, unfreeze=False):
    for w in ws:
        w.requires_grad=unfreeze
freeze(all_weights)

In [16]:
resnet_model.fc.weight.requires_grad

False

In [17]:
freeze(resnet_model.fc.parameters(),unfreeze=True)

In [18]:
resnet_model.fc.weight.requires_grad

True

In [19]:
optimizer_only_fc = torch.optim.Adam(resnet_model.fc.parameters(), lr=0.001)

In [20]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
resnet_model = resnet_model.to(mps_device)

In [21]:
def train(n_epochs, opt):
    for idx_epoch in range(n_epochs):
        # Loop de entrenamiento
        loss_train_sum = 0
        n_batches_train = 0
        resnet_model.train()
        for x_train_batch, y_train_batch in tqdm(train_dataloader):
            x_train_batch = x_train_batch.to(mps_device)
            y_train_batch = y_train_batch.to(mps_device)
            predictions = resnet_model(x_train_batch)
            loss = loss_fn(predictions, y_train_batch)
            loss_train_sum += loss.item()
            n_batches_train += 1
            opt.zero_grad()
            loss.backward()
            opt.step()
    
        # Evaluamos los datos en validación
        loss_validation_sum = 0
        accuracy_sum = 0
        n_batches_valid = 0
        resnet_model.eval()
        for x_valid_batch, y_valid_batch in tqdm(valid_dataloader):
            x_valid_batch = x_valid_batch.to(mps_device)
            y_valid_batch = y_valid_batch.to(mps_device)
            predictions = resnet_model(x_valid_batch)
            loss = loss_fn(predictions, y_valid_batch)
            loss_validation_sum += loss.item()
            accuracy_sum += accuracy(predictions, y_valid_batch).item()
            n_batches_valid += 1
        
        # Imprimimos el loss en train y validación y la métrica (siempre en validación)
        accuracy_validation = accuracy_sum / n_batches_valid
        loss_validation = loss_validation_sum / n_batches_valid
        train_validation = loss_train_sum / n_batches_train
        print(f'epoch {idx_epoch} | train loss {loss_validation} | validation loss {train_validation} | accuracy {accuracy_validation}')

In [22]:
train(3,optimizer_only_fc)

100%|███████████████████████| 1563/1563 [00:21<00:00, 74.21it/s]
100%|█████████████████████████| 313/313 [00:03<00:00, 87.26it/s]


epoch 0 | train loss 0.7218840053668037 | validation loss 0.9062333932803063 | accuracy 0.7687699680511182


100%|███████████████████████| 1563/1563 [00:19<00:00, 80.26it/s]
100%|█████████████████████████| 313/313 [00:03<00:00, 88.47it/s]


epoch 1 | train loss 0.7194936341180588 | validation loss 0.732578944650813 | accuracy 0.7655750798722045


100%|███████████████████████| 1563/1563 [00:20<00:00, 78.02it/s]
100%|█████████████████████████| 313/313 [00:03<00:00, 91.33it/s]

epoch 2 | train loss 0.6767344293883815 | validation loss 0.7192259843789532 | accuracy 0.7822484025559105





In [23]:
freeze(resnet_model.parameters(),unfreeze=True)

In [24]:
resnet_model.conv1.weight.requires_grad

True

In [25]:
optimizer = torch.optim.Adam(resnet_model.parameters(), lr=0.001)

In [26]:
train(3,optimizer)

100%|███████████████████████| 1563/1563 [01:32<00:00, 16.88it/s]
100%|█████████████████████████| 313/313 [00:03<00:00, 85.94it/s]


epoch 0 | train loss 0.08062260638633832 | validation loss 0.15287834833905975 | accuracy 0.981729233226837


100%|███████████████████████| 1563/1563 [01:31<00:00, 17.05it/s]
100%|█████████████████████████| 313/313 [00:03<00:00, 92.30it/s]


epoch 1 | train loss 0.09859171823357431 | validation loss 0.11210157837327993 | accuracy 0.9720447284345048


100%|███████████████████████| 1563/1563 [01:31<00:00, 17.17it/s]
100%|█████████████████████████| 313/313 [00:03<00:00, 89.84it/s]

epoch 2 | train loss 0.049777310070828695 | validation loss 0.05895828608668815 | accuracy 0.9848242811501597



