In [1]:
%matplotlib inline
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
import cv2
import numpy as np
from transformers import DefaultDataCollator
from transformers import TFViTForImageClassification, create_optimizer
from IPython.display import HTML
from transformers import BitImageProcessor, BitForImageClassification
import torch
from datasets import load_dataset

2023-12-18 16:29:29.206883: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-18 16:29:29.228542: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-18 16:29:29.228567: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-18 16:29:29.229034: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-18 16:29:29.232026: I tensorflow/core/platform/cpu_feature_guar

# Import Data

In [2]:
train_folder = ImageFolder(root="train")
val_folder = ImageFolder(root="train")
test_set = ImageFolder(root="test")
train_set, val_set = torch.utils.data.random_split(train_folder, [0.8,0.2])
val_set.dataset = val_folder

## Define data augmentations
The images are augmented every time they are called. So for every epoch, an image is augmented in a different way.

In [3]:
transform_train = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.RandomRotation((-35, 35)),
    transforms.CenterCrop(size=(900, 900)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    transforms.Resize((384, 384)),
    
    
])
transform_val = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.CenterCrop(size=(900, 900)),
    transforms.ToTensor(),
    transforms.Resize((384, 384)),
])

train_set.dataset.transform=transform_train
val_set.dataset.transform=transform_val
test_set.transform=transform_val

# Model
The model used is google's Big Transfer model published in 2020 which is based on a ResNetV2.

In [4]:
from transformers import CvtForImageClassification

# create optimizer wight weigh decay
model_name = "microsoft/cvt-21-384-22k"
model = CvtForImageClassification.from_pretrained(
    model_name,
    num_labels=len(train_folder.class_to_idx),
    id2label=train_folder.class_to_idx,
    label2id={v: k for k, v in train_folder.class_to_idx.items()},
    ignore_mismatched_sizes=True
)

Some weights of CvtForImageClassification were not initialized from the model checkpoint at microsoft/cvt-21-384-22k and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 384]) in the checkpoint and torch.Size([5, 384]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Train

In [5]:
from transformers import TrainingArguments

epochs = 30
bs = 16

training_args = TrainingArguments(
  output_dir="./vit-base-beans",
  per_device_train_batch_size=bs,
  evaluation_strategy="steps",
  num_train_epochs=epochs,
  fp16=True,
  save_steps=100,
  eval_steps=50,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='wandb',
  load_best_model_at_end=True,
)

In [6]:
from transformers import Trainer, EarlyStoppingCallback, AutoImageProcessor

processor = AutoImageProcessor.from_pretrained(model_name)
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x[0] for x in batch]),
        'labels': torch.tensor([x[1] for x in batch])
    }

from datasets import load_metric

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_set,
    eval_dataset=val_set,
    tokenizer=processor,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=5)]
)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
  metric = load_metric("accuracy")


In [7]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

[34m[1mwandb[0m: Currently logged in as: [33mjmettner[0m ([33mgmllm[0m). Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss




KeyboardInterrupt: 

In [None]:
metrics = trainer.evaluate(test_set)
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)