Basado en: https://zablo.net/blog/post/pytorch-resnet-mnist-jupyter-notebook-2021/

TODO 
- probar otros modelos (resnet50, etc)
- poner nuestra propia CELoss

In [29]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision.models import resnet18
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch import nn
from torch.utils.data import DataLoader, default_collate
import torch.nn.functional as F

In [40]:
# For producing bias
def my_collate(batch):
    modified_batch = []
    for item in batch:
        image, label = item
        if label == 1 or label == 2: # only train in these numbers, but test on all!
            modified_batch.append(item)
    return default_collate(modified_batch)

train_ds = MNIST("mnist", train=True, download=True, transform=ToTensor())
test_ds = MNIST("mnist", train=False, download=True, transform=ToTensor())
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn = my_collate)
test_dl = DataLoader(test_ds, batch_size=64)

In [38]:
print(len(train_dl), len(test_dl))

938 157


Training (with **Pytorch_Lightning**)

In [None]:
!pip3 install pytorch_lightning
import pytorch_lightning as pl

In [13]:
# Cross-Entropy Custom
def my_cross_entropy(x, y):
    log_prob = -1.0 * F.log_softmax(x, 1)
    loss = log_prob.gather(1, y.unsqueeze(1))
    loss = loss.mean()
    return loss

class ResNetMNIST(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # self.loss = myCELoss()
        
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_no):
        x, y = batch
        logits = self(x)
        loss = my_cross_entropy(logits, y)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.RMSprop(self.parameters(), lr=0.005)

In [39]:
model = ResNetMNIST()
trainer = pl.Trainer(max_epochs=5, devices=1, accelerator="gpu")
trainer.fit(model, train_dl)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.701    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


In [41]:
trainer.save_checkpoint("resnet18_mnist.pt")

In [42]:
def get_prediction(x, model: pl.LightningModule):
    model.freeze() # prepares model for predicting
    probabilities = torch.softmax(model(x), dim=1)
    predicted_class = torch.argmax(probabilities, dim=1)
    return predicted_class, probabilities

inference_model = ResNetMNIST.load_from_checkpoint("resnet18_mnist.pt", map_location="cuda")

In [43]:
from tqdm.autonotebook import tqdm

true_y, pred_y = [], []
for batch in tqdm(iter(test_dl), total=len(test_dl)):
  x, y = batch
  true_y.extend(y)
  preds, probs = get_prediction(x, inference_model)
  pred_y.extend(preds.cpu())

  0%|          | 0/157 [00:00<?, ?it/s]

In [44]:
from sklearn.metrics import classification_report
print(classification_report(true_y, pred_y, digits=3))

              precision    recall  f1-score   support

           0      0.000     0.000     0.000       980
           1      0.392     0.996     0.563      1135
           2      0.145     0.999     0.253      1032
           3      0.000     0.000     0.000      1010
           4      0.000     0.000     0.000       982
           5      0.000     0.000     0.000       892
           6      0.000     0.000     0.000       958
           7      0.000     0.000     0.000      1028
           8      0.000     0.000     0.000       974
           9      0.000     0.000     0.000      1009

    accuracy                          0.216     10000
   macro avg      0.054     0.200     0.082     10000
weighted avg      0.059     0.216     0.090     10000



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
