In [60]:
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from torch import nn as nn
from torchvision.datasets import Food101
from trak import TRAKer


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [61]:
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-18")
model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")

In [62]:
def preprocess_image(image):
    image = transforms.functional.pil_to_tensor(image)
    processed_image = processor.preprocess(image)["pixel_values"][0]
    return torch.from_numpy(processed_image)

train_dataset = Food101("data/food-101", split="train", transform=preprocess_image, download=True)
test_dataset = Food101("data/food-101", split="test", transform=preprocess_image, download=True)

In [63]:
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

for X, y in train_dl:
    break

X.shape, y.shape

(torch.Size([64, 3, 224, 224]), torch.Size([64]))

In [64]:
model

ResNetForImageClassification(
  (resnet): ResNetModel(
    (embedder): ResNetEmbeddings(
      (embedder): ResNetConvLayer(
        (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (encoder): ResNetEncoder(
      (stages): ModuleList(
        (0): ResNetStage(
          (layers): Sequential(
            (0): ResNetBasicLayer(
              (shortcut): Identity()
              (layer): Sequential(
                (0): ResNetConvLayer(
                  (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activation): ReLU()
           

In [65]:
num_classes = 101

model.classifier = nn.Sequential(
                    nn.Flatten(start_dim=1, end_dim=-1),
                    nn.Linear(in_features=512, out_features=num_classes))
for param in model.classifier.parameters():
        param.requires_grad = True

model.num_labels = 101

In [66]:
class EarlyStopping:
    def __init__(self, patience: int, min_delta: float):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.stop = False
        self.val_loss_min = 1e10
        self.best_parameters = None

    def __call__(self, model, val_loss):
        if val_loss < self.val_loss_min - self.min_delta:
            self.val_loss_min = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.stop = True
                self.best_parameters = model.state_dict()

    def get_best_model_parameters(self):
        return self.best_parameters

In [67]:
def train(
        model: torch.nn.Module,
        epochs: int,
        optimizer: torch.optim.Optimizer,
        early_stopping: EarlyStopping = None,
        criterion=None
    ):
    criterion = criterion or torch.nn.CrossEntropyLoss()
    model = model.to(device)

    for epoch in range(epochs):
        running_loss = 0.0

        model.train()

        train_correct = 0
        train_outputs = 0

        for i, data in enumerate(tqdm(train_dl), 0):
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs.logits, labels)

            loss.backward()
            optimizer.step()

            train_correct += (torch.argmax(outputs.logits, dim=-1) == labels).sum().item()
            train_outputs += outputs.logits.shape[0]

            running_loss += loss

        model.eval()
        total_correct = 0
        total_outputs = 0
        val_loss = 0.0

        with torch.no_grad():
            for i, data in enumerate(tqdm(test_dl), 0):
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                
                outputs = model(images)
                val_loss += criterion(outputs.logits, labels).item()
                correct = (torch.argmax(outputs.logits, dim=-1) == labels).sum().item()

                total_correct += correct
                total_outputs += outputs.logits.shape[0]

        print(f"[Epoch {epoch + 1}] Loss: {running_loss / i:.3f}, Train Acc: {train_correct/train_outputs:.3f}," +
              f"Valid loss: {val_loss/len(test_dl):.3f} Valid Acc: {total_correct/total_outputs:.3f}")
        
        if early_stopping:
            early_stopping(model, val_loss)
            if early_stopping.stop:
                print(f"Early stopping at epoch {epoch + 1}")
                break
        
    if early_stopping:
        model.load_state_dict(early_stopping.get_best_model_parameters())

In [68]:
num_epochs = 20
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
early_stopping = EarlyStopping(patience=3, min_delta=0.001)

train(model, num_epochs, optimizer, early_stopping=early_stopping)  

100%|██████████| 1184/1184 [31:55<00:00,  1.62s/it]
100%|██████████| 395/395 [08:26<00:00,  1.28s/it]


[Epoch 1] Loss: 6.890, Train Acc: 0.432,Valid loss: 1.939 Valid Acc: 0.504


100%|██████████| 1184/1184 [19:35<00:00,  1.01it/s]
100%|██████████| 395/395 [07:20<00:00,  1.12s/it]


[Epoch 2] Loss: 4.591, Train Acc: 0.599,Valid loss: 1.416 Valid Acc: 0.624


100%|██████████| 1184/1184 [18:44<00:00,  1.05it/s]
100%|██████████| 395/395 [05:46<00:00,  1.14it/s]


[Epoch 3] Loss: 3.471, Train Acc: 0.688,Valid loss: 1.366 Valid Acc: 0.642


100%|██████████| 1184/1184 [19:11<00:00,  1.03it/s]
100%|██████████| 395/395 [12:23<00:00,  1.88s/it]


[Epoch 4] Loss: 2.609, Train Acc: 0.758,Valid loss: 1.456 Valid Acc: 0.634


100%|██████████| 1184/1184 [25:38<00:00,  1.30s/it]
100%|██████████| 395/395 [07:28<00:00,  1.13s/it]


[Epoch 5] Loss: 1.895, Train Acc: 0.818,Valid loss: 1.464 Valid Acc: 0.645


100%|██████████| 1184/1184 [19:18<00:00,  1.02it/s]
100%|██████████| 395/395 [05:10<00:00,  1.27it/s]

Early stopping at epoch 6





In [69]:
torch.save(model.state_dict(), "model_finetuned_baseline.pth")

In [104]:
# TODO finetune the model on food101 dataset -> use TRAK -> finetune again on the base model -> look at the results

In [105]:
traker = TRAKer(model, task="image_classification", train_set_size=len(dataset))

ERROR:TRAK:Could not use CudaProjector.
Reason: No module named 'fast_jl'
ERROR:TRAK:Defaulting to BasicProjector.
INFO:STORE:No existing model IDs in C:\Users\kamil\OneDrive\Pulpit\przedmioty\semestr 8\automating science\trak-for-automating-science\trak_results.
INFO:STORE:No existing TRAK scores in C:\Users\kamil\OneDrive\Pulpit\przedmioty\semestr 8\automating science\trak-for-automating-science\trak_results.
