# ResNet Fine-Tuning Demo

*Based on: https://huggingface.co/blog/fine-tune-vit*

Before running, download the [TrashNet dataset](https://github.com/garythung/trashnet/raw/master/data/dataset-resized.zip) and extract it into a folder named `dataset` next to this notebook (the loading cell reads it with `data_dir="dataset"`).

In [None]:
%pip install huggingface
%pip install datasets
%pip install pillow
%pip install transformers
%pip install scikit-learn
%pip install transformers[torch]
%pip install tensorboardX
# PyTorch + CUDA should be installed manually

## Dataset Processing

In [1]:
from datasets import load_dataset

# Load everything under ./dataset through the "imagefolder" builder as one
# split, then carve off 20% as a held-out set. The fixed seed keeps the split
# identical between runs.
ds = load_dataset(
    "imagefolder",
    data_dir="dataset",
    split="train",
).train_test_split(test_size=0.2, seed=512)

  from .autonotebook import tqdm as notebook_tqdm
Resolving data files: 100%|██████████| 2527/2527 [00:00<00:00, 75841.54it/s]


In [2]:
from transformers import AutoImageProcessor
from torchvision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, ToTensor, Normalize, Resize, CenterCrop, ConvertImageDtype, InterpolationMode, ColorJitter
import torch

# Checkpoint identifier shared by the image processor (here) and the model
# loaded in a later cell.
model_name_or_path = 'microsoft/resnet-50'
# The processor provides the target edge length and the normalisation
# mean/std that the transform pipelines below read.
processor = AutoImageProcessor.from_pretrained(model_name_or_path)

# Training pipeline: random crop / flip / colour jitter for augmentation,
# followed by tensor conversion and normalisation with the processor's
# statistics.
_train_crop_size = processor.size['shortest_edge']
_train_steps = [
    RandomResizedCrop(_train_crop_size, interpolation=InterpolationMode.BICUBIC, antialias=None),
    RandomHorizontalFlip(),
    ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    ToTensor(),
    ConvertImageDtype(torch.float),
    Normalize(mean=processor.image_mean, std=processor.image_std),
]
_train_transform = Compose(_train_steps)

# Evaluation pipeline: no randomness. Resize, take the central crop, then
# apply the same tensor conversion and normalisation as the training pipeline.
_val_edge = processor.size['shortest_edge']
_val_steps = [
    Resize(_val_edge, interpolation=InterpolationMode.BICUBIC, antialias=None),
    CenterCrop(_val_edge),
    ToTensor(),
    ConvertImageDtype(torch.float),
    Normalize(mean=processor.image_mean, std=processor.image_std),
]
_val_transform = Compose(_val_steps)

def train_transform(example_batch):
    """Add augmented ``pixel_values`` to a batch of examples.

    Each entry of ``example_batch['image']`` is converted to RGB and passed
    through ``_train_transform``. The batch is modified in place and returned.
    """
    augmented = []
    for picture in example_batch['image']:
        augmented.append(_train_transform(picture.convert('RGB')))
    example_batch['pixel_values'] = augmented
    return example_batch

def val_transform(example_batch):
    """Add evaluation-time ``pixel_values`` to a batch of examples.

    Each entry of ``example_batch['image']`` is converted to RGB and passed
    through ``_val_transform``. The batch is modified in place and returned.
    """
    rgb_images = (picture.convert('RGB') for picture in example_batch['image'])
    processed = []
    processed.extend(_val_transform(rgb) for rgb in rgb_images)
    example_batch['pixel_values'] = processed
    return example_batch

# Register the per-split transforms: augmentation for the training split,
# deterministic preprocessing for the held-out split.
for _split_name, _split_transform in (('train', train_transform), ('test', val_transform)):
    ds[_split_name].set_transform(_split_transform)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


## Training

In [3]:
import torch

def collate_fn(batch):
    """Combine a list of examples into one batch dictionary.

    Args:
        batch: Sequence of dicts, each with a ``'pixel_values'`` tensor and a
            ``'label'`` value.

    Returns:
        dict: ``'pixel_values'`` holds the per-example tensors stacked along a
        new leading dimension; ``'labels'`` holds the labels as a tensor.
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample['pixel_values'])
        targets.append(sample['label'])
    return {
        'pixel_values': torch.stack(images),
        'labels': torch.tensor(targets),
    }

In [4]:
import numpy as np


def compute_metrics(p):
    """Compute top-1 accuracy for an evaluation pass.

    Accuracy is calculated directly with numpy instead of
    ``datasets.load_metric``, which is deprecated and absent from newer
    ``datasets`` releases.

    Args:
        p: Object exposing ``predictions`` (per-class scores with classes on
            axis 1) and ``label_ids`` (reference class indices).

    Returns:
        dict: ``{"accuracy": value}`` where ``value`` is the fraction of rows
        whose arg-max prediction equals the reference label, as a Python float.
    """
    predicted = np.argmax(p.predictions, axis=1)
    references = np.asarray(p.label_ids)
    return {"accuracy": float(np.mean(predicted == references))}

  metric = load_metric("accuracy")


In [5]:
from transformers import ResNetForImageClassification

# Class names as recorded on the dataset's label feature.
labels = ds['train'].features['label'].names

# Two-way mapping between stringified class indices and class names, stored
# on the model configuration.
id2label = {}
label2id = {}
for class_index, class_name in enumerate(labels):
    index_key = str(class_index)
    id2label[index_key] = class_name
    label2id[class_name] = index_key

model = ResNetForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    # The classifier is sized by num_labels, which need not match the
    # checkpoint's head; allow the mismatched weights to be re-initialised.
    ignore_mismatched_sizes=True,
)

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([6, 2048]) in the model instantiated
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([6]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
from transformers import TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
import os

# Single definition of the run directory; used for checkpoint discovery and
# as the Trainer's output_dir.
OUTPUT_DIR = "./resnet-50-trash"

# Look for a checkpoint only when the run directory really is a directory;
# otherwise start from scratch. get_last_checkpoint's result is used as-is.
last_checkpoint = get_last_checkpoint(OUTPUT_DIR) if os.path.isdir(OUTPUT_DIR) else None

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=16,
    num_train_epochs=32,
    learning_rate=1e-4,
    fp16=True,
    # Evaluate and save on the same 100-step cadence so that the best
    # checkpoint can be restored at the end of training.
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=5,
    load_best_model_at_end=True,
    logging_steps=20,
    report_to='tensorboard',
    # The dataset transforms read the 'image' column, so it must not be
    # dropped before they run.
    remove_unused_columns=False,
    push_to_hub=False,
    resume_from_checkpoint=last_checkpoint,
)

In [11]:
from transformers import Trainer

# Wire the model, hyperparameters, data splits, batching and metric function
# into a single Trainer instance.
trainer = Trainer(
    # what to train and how
    model=model,
    args=training_args,
    # data
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    data_collator=collate_fn,
    # evaluation and preprocessing objects
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

In [12]:
# Train (passing along `last_checkpoint`, which may be None), then persist the
# model, the training metrics and the trainer state, in that order.
train_results = trainer.train(resume_from_checkpoint=last_checkpoint)
trainer.save_model()

train_metrics = train_results.metrics
for _report in (trainer.log_metrics, trainer.save_metrics):
    _report("train", train_metrics)

trainer.save_state()

 99%|█████████▉| 4020/4064 [00:03<00:00, 7162.89it/s]

{'loss': 0.2744, 'learning_rate': 1.3041338582677166e-06, 'epoch': 31.65}


 99%|█████████▉| 4040/4064 [00:06<00:00, 7162.89it/s]

{'loss': 0.3559, 'learning_rate': 8.366141732283466e-07, 'epoch': 31.81}


100%|█████████▉| 4060/4064 [00:08<00:00, 7162.89it/s]

{'loss': 0.2279, 'learning_rate': 3.44488188976378e-07, 'epoch': 31.97}


100%|██████████| 4064/4064 [00:09<00:00, 436.60it/s] 

{'train_runtime': 9.3203, 'train_samples_per_second': 6938.838, 'train_steps_per_second': 436.038, 'train_loss': 0.0046543464592591986, 'epoch': 32.0}
***** train metrics *****
  epoch                    =       32.0
  train_loss               =     0.0047
  train_runtime            = 0:00:09.32
  train_samples_per_second =   6938.838
  train_steps_per_second   =    436.038





In [13]:
# Score the held-out split, then print and save the resulting metrics.
metrics = trainer.evaluate(eval_dataset=ds['test'])
trainer.log_metrics(split="eval", metrics=metrics)
trainer.save_metrics(split="eval", metrics=metrics)

100%|██████████| 64/64 [00:01<00:00, 33.68it/s]

***** eval metrics *****
  epoch                   =       32.0
  eval_accuracy           =      0.915
  eval_loss               =     0.2639
  eval_runtime            = 0:00:02.20
  eval_samples_per_second =    229.345
  eval_steps_per_second   =     29.008



