In [1]:
from google.colab import drive
# Mount Google Drive so that files under the mount point are reachable from this notebook.
DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
!pip install traker



In [3]:
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from torch import nn as nn
from torchvision.datasets import Food101
from trak import TRAKer
from trak import modelout_functions
from collections.abc import Iterable

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [4]:
# Hub id of the pretrained checkpoint shared by the classifier and its image processor.
RESNET_CHECKPOINT = "microsoft/resnet-18"

model = AutoModelForImageClassification.from_pretrained(RESNET_CHECKPOINT)
processor = AutoImageProcessor.from_pretrained(RESNET_CHECKPOINT)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [5]:
def preprocess_image(image):
    """Turn one PIL image into the tensor the model consumes.

    The image is converted to a tensor, run through the module-level
    `processor`, and the first entry of the resulting ``"pixel_values"``
    is wrapped as a torch tensor.

    Args:
        image: A PIL image (as yielded by the dataset before transforms).

    Returns:
        torch.Tensor built from the processor's first ``pixel_values`` array.
    """
    raw_tensor = transforms.functional.pil_to_tensor(image)
    pixel_values = processor.preprocess(raw_tensor)["pixel_values"]
    # The processor returns one entry per input image; only one was passed in.
    return torch.from_numpy(pixel_values[0])

# Both Food-101 splits are read from the same local folder (no download) and
# share the same per-image preprocessing.
train_dataset, test_dataset = (
    Food101("data/food-101", split=split_name, transform=preprocess_image, download=False)
    for split_name in ("train", "test")
)

In [6]:
# Shuffled loaders over both splits, 64 images per batch.
BATCH_SIZE = 64
train_dl = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dl = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Pull a single batch to sanity-check the tensor shapes.
X, y = next(iter(train_dl))

X.shape, y.shape

(torch.Size([64, 3, 224, 224]), torch.Size([64]))

In [7]:
# Display the module tree of the loaded ResNet-18 to inspect its layers.
model

ResNetForImageClassification(
  (resnet): ResNetModel(
    (embedder): ResNetEmbeddings(
      (embedder): ResNetConvLayer(
        (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (encoder): ResNetEncoder(
      (stages): ModuleList(
        (0): ResNetStage(
          (layers): Sequential(
            (0): ResNetBasicLayer(
              (shortcut): Identity()
              (layer): Sequential(
                (0): ResNetConvLayer(
                  (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                  (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (activation): ReLU()
           

In [8]:
# Number of target classes for the new classification head.
num_classes = 101

# Swap the pretrained head for a fresh flatten + linear layer sized for
# `num_classes` outputs.
new_head = nn.Sequential(
    nn.Flatten(start_dim=1, end_dim=-1),
    nn.Linear(in_features=512, out_features=num_classes),
)
model.classifier = new_head

# Make sure every parameter of the new head is trainable.
for head_param in model.classifier.parameters():
    head_param.requires_grad = True

model.num_labels = num_classes

In [None]:
class EarlyStopping:
    """Track validation loss and signal when training should stop.

    Each call compares the new validation loss with the best one seen so far.
    An improvement of more than `min_delta` resets the patience counter and
    stores a detached copy of the model's parameters; otherwise the counter is
    incremented and, once it reaches `patience`, `stop` is set to True.

    Attributes:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Minimum decrease in loss that counts as an improvement.
        counter: Current number of consecutive non-improving calls.
        stop: True once `counter` has reached `patience`.
        val_loss_min: Best (lowest) validation loss seen so far.
        best_parameters: Copy of the state dict taken at the best loss, or
            None if no call has registered an improvement yet.
    """

    def __init__(self, patience: int, min_delta: float):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.stop = False
        self.val_loss_min = 1e10
        self.best_parameters = None

    def __call__(self, model, val_loss):
        """Record `val_loss` for `model` and update the stopping state."""
        if val_loss < self.val_loss_min - self.min_delta:
            self.val_loss_min = val_loss
            self.counter = 0
            # Snapshot the weights *now*, while they are the best ones.
            # `state_dict()` hands back references to the live tensors, so
            # they must be cloned or later optimizer steps would overwrite
            # the snapshot in place.
            self.best_parameters = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.stop = True

    def get_best_model_parameters(self):
        """Return the state-dict snapshot taken at the lowest validation loss.

        Returns None if no improvement has been recorded yet.
        """
        return self.best_parameters

In [None]:
def train(
        model: torch.nn.Module,
        epochs: int,
        optimizer: torch.optim.Optimizer,
        early_stopping: EarlyStopping = None,
        criterion=None
    ):
    """Fine-tune `model` on `train_dl` and evaluate it on `test_dl` each epoch.

    The loaders `train_dl` / `test_dl` and the target `device` are read from
    the notebook's global scope. The model's forward pass is expected to
    return an object exposing `.logits`.

    Args:
        model: Network to train; it is moved to `device`.
        epochs: Maximum number of passes over `train_dl`.
        optimizer: Optimizer already bound to the model's parameters.
        early_stopping: Optional `EarlyStopping` instance. It is called once
            per epoch with the mean validation loss (the same value that is
            printed); when it sets `stop`, training ends early.
        criterion: Loss function taking (logits, labels); defaults to
            `torch.nn.CrossEntropyLoss()`.

    Returns:
        None. If `early_stopping` is given and holds stored parameters, they
        are loaded back into `model` before returning.
    """
    if criterion is None:
        criterion = torch.nn.CrossEntropyLoss()
    model = model.to(device)

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        train_correct = 0
        train_outputs = 0

        for images, labels in tqdm(train_dl):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs.logits, labels)

            loss.backward()
            optimizer.step()

            train_correct += (torch.argmax(outputs.logits, dim=-1) == labels).sum().item()
            train_outputs += outputs.logits.shape[0]

            # Accumulate a plain float so no tensors are kept alive across steps.
            running_loss += loss.item()

        # ---- validation pass ----
        model.eval()
        total_correct = 0
        total_outputs = 0
        val_loss = 0.0

        with torch.no_grad():
            for images, labels in tqdm(test_dl):
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                val_loss += criterion(outputs.logits, labels).item()
                total_correct += (torch.argmax(outputs.logits, dim=-1) == labels).sum().item()
                total_outputs += outputs.logits.shape[0]

        # Average per batch over the loader each sum came from; `max(..., 1)`
        # only guards against an empty loader.
        mean_train_loss = running_loss / max(len(train_dl), 1)
        mean_val_loss = val_loss / max(len(test_dl), 1)
        train_acc = train_correct / max(train_outputs, 1)
        val_acc = total_correct / max(total_outputs, 1)

        print(f"[Epoch {epoch + 1}] Loss: {mean_train_loss:.3f}, Train Acc: {train_acc:.3f}," +
              f"Valid loss: {mean_val_loss:.3f} Valid Acc: {val_acc:.3f}")

        if early_stopping is not None:
            # Monitor the mean loss so `min_delta` is on the same scale as the
            # value reported above.
            early_stopping(model, mean_val_loss)
            if early_stopping.stop:
                print(f"Early stopping at epoch {epoch + 1}")
                break

    if early_stopping is not None:
        best_parameters = early_stopping.get_best_model_parameters()
        # Nothing may have been stored (e.g. training ran to the last epoch
        # without a snapshot); keep the current weights in that case.
        if best_parameters is not None:
            model.load_state_dict(best_parameters)

In [None]:
# Training budget and stopping rule for the baseline fine-tuning run.
num_epochs = 20
early_stopping = EarlyStopping(min_delta=0.001, patience=3)

# Adam over every model parameter.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

train(
    model=model,
    epochs=num_epochs,
    optimizer=optimizer,
    early_stopping=early_stopping,
)

In [None]:
# Persist only the fine-tuned weights (the state dict), not the whole module.
finetuned_weights = model.state_dict()
torch.save(finetuned_weights, "model_finetuned_baseline.pth")

In [None]:
# TODO: fine-tune the model on the Food-101 dataset -> run TRAK -> fine-tune the base model again -> compare the results

In [7]:
# Fine-tuned weights stored on the mounted Drive (file: model_finetuned_baseline.pth).
finetuned_path = '/content/drive/MyDrive/automating_science/model_finetuned_baseline.pth'

# Deserialize the state dict directly onto the active device.
checkpoint = torch.load(f=finetuned_path, map_location=device)

In [8]:
# Rebuild the same architecture that was fine-tuned (pretrained backbone plus a
# flatten + linear head with `num_classes` outputs), then restore its weights.
num_classes = 101

model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")

restored_head = nn.Sequential(
    nn.Flatten(start_dim=1, end_dim=-1),
    nn.Linear(in_features=512, out_features=num_classes),
)
model.classifier = restored_head
model.num_labels = num_classes

# Kept as the final expression so the key-matching report is displayed.
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [9]:
# Move the restored model onto the selected device.
model = model.to(device)

In [10]:
# Loaders that walk each split in dataset order (no shuffling), 64 images per batch.
train_dl_no_shuffle, test_dl_no_shuffle = (
    torch.utils.data.DataLoader(split_dataset, batch_size=64, shuffle=False)
    for split_dataset in (train_dataset, test_dataset)
)

In [15]:
class ResNetOutput(modelout_functions.AbstractModelOutput):
    """TRAK model-output function for a classifier whose forward pass returns `.logits`.

    `get_output` takes no `self`, so it is declared a static method; that keeps
    it callable both on the class itself (which is how this notebook hands the
    task to `TRAKer`) and on an instance. `get_out_to_loss_grad` reads
    `self.softmax` and `self.loss_temperature`, so it needs an instance
    (`ResNetOutput()`), whose attributes are set up in `__init__`.
    """

    def __init__(self, temperature: float = 1.0) -> None:
        """Create the softmax layer and store the temperature used to scale logits.

        Args:
            temperature: Divisor applied to the logits before the softmax in
                `get_out_to_loss_grad`.
        """
        super().__init__()
        self.softmax = torch.nn.Softmax(-1)
        self.loss_temperature = temperature

    @staticmethod
    def get_output(
            model: torch.nn.Module,
            weights: dict,
            buffers: Iterable[torch.Tensor],
            image: torch.Tensor,
            label: torch.Tensor
    ):
        """Compute the margin of the correct class for a single example.

        Args:
            model: Module evaluated via `torch.func.functional_call`.
            weights: Mapping of parameter name to tensor; entries are moved to
                `device` in place.
            buffers: Buffers passed alongside `weights` to the functional call.
            image: One input example; a leading batch dimension is added here.
            label: Label of that example; a leading batch dimension is added here.

        Returns:
            Scalar tensor: the correct-class logit minus the log-sum-exp of the
            remaining logits, summed over the (size-one) batch.
        """
        for key in weights:
            weights[key] = weights[key].to(device)

        output = torch.func.functional_call(model, (weights, buffers), image.unsqueeze(0))
        logits = output.logits  # the model wraps its logits in an output object
        bindex = torch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
        logits_correct = logits[bindex, label.unsqueeze(0)]

        # Mask the correct class with -inf on a copy so that the log-sum-exp
        # below runs over the other classes only.
        cloned_logits = logits.clone()
        cloned_logits[bindex, label.unsqueeze(0)] = torch.tensor(
            -torch.inf, device=logits.device, dtype=logits.dtype
        )

        margins = logits_correct - cloned_logits.logsumexp(dim=-1)
        return margins.sum()

    def get_out_to_loss_grad(self, model, weights, buffers, batch):
        """Return `1 - p` for each example, where `p` is the softmax probability of its label.

        Args:
            model: Module evaluated via `torch.func.functional_call`.
            weights: Mapping of parameter name to tensor; entries are moved to
                `device` in place.
            buffers: Buffers passed alongside `weights` to the functional call.
            batch: Pair `(images, labels)`.

        Returns:
            Detached tensor with one trailing singleton dimension added, holding
            `1 - p` for every example in the batch.
        """
        for key in weights:
            weights[key] = weights[key].to(device)

        images, labels = batch
        output = torch.func.functional_call(model, (weights, buffers), images)
        logits = output.logits  # the model wraps its logits in an output object

        # Probability assigned to each example's own label.
        ps = self.softmax(logits / self.loss_temperature)[torch.arange(logits.size(0)), labels]
        return (1 - ps).clone().detach().unsqueeze(-1)

In [16]:
# TRAK needs the number of training examples up front.
num_train_examples = len(train_dl_no_shuffle.dataset)

traker = TRAKer(
    model=model,
    task=ResNetOutput,
    train_set_size=num_train_examples,
)

ERROR:TRAK:Could not use CudaProjector.
Reason: No module named 'fast_jl'
ERROR:TRAK:Defaulting to BasicProjector.
INFO:STORE:Existing model IDs in /content/trak_results: [0]
INFO:STORE:No model IDs in /content/trak_results have been finalized.
INFO:STORE:No existing TRAK scores in /content/trak_results.


In [17]:
# Register the fine-tuned weights with TRAK under model id 0.
model_id = 0
finetuned_state = torch.load(finetuned_path, map_location=device)
traker.load_checkpoint(finetuned_state, model_id=model_id)

In [19]:
# Featurize the training set batch by batch, in dataset order.
# NOTE(review): the recorded run of this loop ended in a CUDA out-of-memory
# error; a smaller batch size for `train_dl_no_shuffle` may be needed -- confirm.
for data in train_dl_no_shuffle:
    # Use the notebook-wide `device` instead of a hard-coded `.cuda()` so the
    # batch lands on the same device as the model and checkpoint.
    data = [tensor.to(device) for tensor in data]

    traker.featurize(batch=data, num_samples=data[0].shape[0])

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


OutOfMemoryError: CUDA out of memory. Tried to allocate 576.00 MiB. GPU 