<a href="https://www.kaggle.com/code/ibombonato/vit-transformers-sorghum-100-starter-0-491?scriptVersionId=91060985" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Simple baseline with ViT Transformers + Hugging Face + Lightning

This notebook shows how to use Hugging Face Transformers, PyTorch and PyTorch Lightning to train an image classifier based on the Vision Transformer (ViT) architecture.

It is inspired by the [HuggingPics](https://github.com/nateraw/huggingpics) project. It uses a resized (512x512) copy of the competition images, reorganized into one folder per label:

https://www.kaggle.com/ibombonato/sorghum-100-cultivar-512x512-png-imagefolder


TO DO:
- ~~Make inference faster~~
- Find best learning rate
- ~~Add augmentation~~
- ~~Better split strategy~~
- Add CrossValidation
- ~~Make Wandb Work (I am getting an error right now with the logger in `self.experiment.config.update(params, allow_val_change=True)`)~~
- Load from checkpoint?

**If it helps you in some manner, please upvote the dataset and the notebook :D**

### Load libs and minimal setup

In [None]:
!pip install -q timm
!pip install -q --upgrade wandb wandb[service]

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from tqdm.auto import tqdm
from sklearn.model_selection import ShuffleSplit
from PIL import Image, UnidentifiedImageError
from pathlib import Path

tqdm.pandas()

In [None]:
# Confirm that a GPU is available
!nvidia-smi

In [None]:
ORIGIN_FOLDER = "../input/sorghum-100-cultivar-512x512-png-imagefolder/images"
USE_WANDB = True
EPOCHS = 50
MODEL_NAME = 'google/vit-base-patch16-224-in21k'
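As a rough sanity check of what `MODEL_NAME` implies: `vit-base-patch16-224` splits each 224x224 input into non-overlapping 16x16 patches, and ViT prepends one `[CLS]` token whose output is used for classification. The helper below is a hypothetical illustration of that arithmetic (assuming square inputs), not part of the Transformers API.

```python
def vit_sequence_length(image_size: int, patch_size: int, cls_tokens: int = 1) -> int:
    """Number of tokens a ViT encoder processes for a square image.

    Hypothetical helper: assumes non-overlapping square patches plus
    `cls_tokens` extra tokens (one [CLS] token in the standard ViT).
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + cls_tokens

# 224 / 16 = 14 patches per side -> 14 * 14 = 196 patches, plus 1 [CLS] token
print(vit_sequence_length(224, 16))  # 197
```

This is why the 512x512 images are resized or cropped to the feature extractor's size before reaching the model.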

In [None]:
train_raw = pd.read_csv("../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv")

In [None]:
import matplotlib.pyplot as plt
import shutil
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy
from torchvision.datasets import ImageFolder
from transformers import AutoFeatureExtractor, ViTForImageClassification
from torchvision.transforms import ToTensor
import torchvision
from torchvision.io import read_image
import random
from timm.data import ImageDataset
from sklearn.model_selection import StratifiedShuffleSplit

## Loading the data

Since we are using a [dataset](https://www.kaggle.com/ibombonato/sorghum-100-cultivar-512x512-png-imagefolder) that has all images grouped into one folder per label, we can use `ImageFolder` from `torchvision.datasets` to load it and simplify the process.
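`ImageFolder` treats each subfolder of the root as one class and assigns integer indices in sorted order of the folder names. The stdlib-only sketch below mimics that convention on a temporary directory. `folder_class_mapping` and `folder_samples` are hypothetical helpers, not torchvision functions, and they only check a few image extensions.

```python
import os
import tempfile
from pathlib import Path

def folder_class_mapping(root):
    """One class per subfolder; indices follow the sorted folder names."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}

def folder_samples(root, class_to_idx, extensions=(".png", ".jpg", ".jpeg")):
    """(path, class_index) pairs, similar in spirit to ImageFolder.imgs."""
    samples = []
    for name, idx in sorted(class_to_idx.items()):
        for path in sorted(Path(root, name).iterdir()):
            if path.suffix.lower() in extensions:
                samples.append((str(path), idx))
    return samples

with tempfile.TemporaryDirectory() as root:
    # hypothetical folder names and empty placeholder files, for illustration only
    for folder, n_images in [("cultivar_b", 2), ("cultivar_a", 1)]:
        Path(root, folder).mkdir()
        for i in range(n_images):
            Path(root, folder, f"{i}.png").touch()
    classes, class_to_idx = folder_class_mapping(root)
    samples = folder_samples(root, class_to_idx)

print(classes)                          # ['cultivar_a', 'cultivar_b']
print([label for _, label in samples])  # [0, 1, 1]
```

Because the indices come from sorted folder names, the same folders always map to the same integer targets.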

~~We will use `random_split` from Pytorch to split the Images into train and validation sets.~~

In [None]:
all_ds = ImageFolder(Path(ORIGIN_FOLDER, "train"))

Let's add some transformations (data augmentation) to the training images. This helps the model generalize better and reduces overfitting. The validation images are only resized.
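To make the idea concrete, here is a NumPy-only sketch of two of the augmentations used in the next cell: a random horizontal flip and a brightness jitter. It is a simplified approximation of what `RandomHorizontalFlip` and the brightness part of `ColorJitter` do. It assumes an HxWxC `uint8` array rather than a PIL image, and the helper names are hypothetical.

```python
import numpy as np

def random_hflip(img, rng, p=0.5):
    """Flip an HxWxC image left-right with probability p."""
    return img[:, ::-1, :] if rng.random() < p else img

def random_brightness(img, rng, brightness=0.5):
    """Multiply pixels by a factor drawn uniformly from [1 - brightness, 1 + brightness]."""
    factor = rng.uniform(1 - brightness, 1 + brightness)
    # clip back to the valid uint8 range after scaling
    return np.clip(img.astype(np.float32) * factor, 0, 255).astype(np.uint8)

rng = np.random.default_rng(42)
toy_image = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)  # toy 4x6 RGB image
toy_augmented = random_brightness(random_hflip(toy_image, rng), rng)
print(toy_augmented.shape, toy_augmented.dtype)  # (4, 6, 3) uint8
```

Each call draws fresh random values, so the model sees a slightly different version of every training image on each epoch.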

In [None]:
from torchvision import transforms
# For training, we add some augmentation. Networks are too powerful and would overfit.

feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

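train_transform = transforms.Compose(
    [
        # degrees=0.75: rotate by at most +/-0.75 degrees (no translation, scale or shear)
        transforms.RandomAffine(0.75),
        # brightness factor in [0.5, 1.5], contrast factor in [0.75, 1.25]
        transforms.ColorJitter(brightness=0.5, contrast=0.25),
        # apply autocontrast with probability 0.25
        transforms.RandomAutocontrast(0.25),
        # degrees=0.15: rotate by at most +/-0.15 degrees
        transforms.RandomRotation(0.15),
        # crop 10-100% of the image area with aspect ratio 1:2 to 2:1, then resize to the model input size
        transforms.RandomResizedCrop(feature_extractor.size, scale=(0.1, 1), ratio=(0.5, 2)),
        # flip left-right with probability 0.5
        transforms.RandomHorizontalFlip(),
    ]
)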

val_transform = transforms.Compose(
    [
        transforms.Resize(feature_extractor.size),
    ]
)

We will use `StratifiedShuffleSplit` from `sklearn` to split the images into train and validation sets in a stratified way, so the label proportions are kept in both sets.
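The idea behind stratification can be sketched in plain Python: group sample indices by label, shuffle each group, and send the same fraction of every group to validation. This is a simplified stand-in for `StratifiedShuffleSplit`, which allocates counts more carefully when classes do not divide evenly. `stratified_split` is a hypothetical helper.

```python
import random
from collections import Counter, defaultdict

def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, val_idx) with roughly the same label proportions in each part."""
    rng_py = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    train_idx, val_idx = [], []
    for indices in by_label.values():
        rng_py.shuffle(indices)                  # shuffle within each class
        n_val = round(len(indices) * test_size)  # same fraction from every class
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)

toy_labels = [0] * 50 + [1] * 10  # imbalanced toy labels
toy_train_idx, toy_val_idx = stratified_split(toy_labels, test_size=0.2)
print(Counter(toy_labels[i] for i in toy_val_idx))  # Counter({0: 10, 1: 2})
```

With a plain random split, the minority class could end up over- or under-represented in validation by chance. Stratifying keeps the 5:1 ratio in both sets.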

We also need a custom `Subset` class so that we can apply different transforms to the train and validation sets.

In [None]:
# https://stackoverflow.com/questions/51782021/how-to-use-different-data-augmentation-for-subsets-in-pytorch
# Subset with transform, so we can have a train and val transform
class Subset(Dataset):
    r"""
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
        transform (Transformation): Vision Transforms to apply in the image
    """
    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __getitem__(self, idx):
        im, labels = self.dataset[self.indices[idx]]
        if self.transform:
            im = self.transform(im)
        return im, labels

    def __len__(self):
        return len(self.indices)

In [None]:
def train_test_dataset(dataset, test_split, train_transform, val_transform):
    X = dataset.imgs
    y = dataset.targets
    sss = StratifiedShuffleSplit(n_splits=1, test_size=test_split, random_state=42)
    train_idx, val_idx = next(sss.split(X, y))
    
    train_ds = Subset(dataset, train_idx, train_transform)
    val_ds = Subset(dataset, val_idx, val_transform)
    return train_ds, val_ds

train_ds, val_ds = train_test_dataset(all_ds, 0.2, train_transform, val_transform)

In [None]:
# plot a random image from train set
img, label = train_ds[random.randrange(len(train_ds))]  # randrange excludes len(), avoiding an IndexError
plt.imshow(img, cmap="gray")

In [None]:
# plot a random image from validation set
img, label = val_ds[random.randrange(len(val_ds))]  # randrange excludes len(), avoiding an IndexError
plt.imshow(img, cmap="gray")

Since `ImageFolder` stores each class as an integer target, we build `label2id` and `id2label` dictionaries so we can map predictions back to class names later.

In [None]:
label2id = {}
id2label = {}

for i, class_name in enumerate(all_ds.classes):
    label2id[class_name] = str(i)
    id2label[str(i)] = class_name

In [None]:
class ImageClassificationCollator:
    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor
 
    def __call__(self, batch):
        encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
        return encodings

In [None]:
#feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

collator = ImageClassificationCollator(feature_extractor)
train_loader = DataLoader(train_ds, batch_size=8, collate_fn=collator, num_workers=2, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=16, collate_fn=collator, num_workers=2)

In [None]:
class Classifier(pl.LightningModule):

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        self.forward = self.model.forward
        self.val_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log(f"train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log(f"val_loss", outputs.loss)
        acc = self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log(f"val_acc", acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

Setting up the logger with Wandb

In [None]:
# FOR THIS TO WORK, YOU NEED TO SET YOUR API KEY IN THE KAGGLE SECRETS ENVIRONMENT!
import os
from pytorch_lightning.loggers import WandbLogger
from kaggle_secrets import UserSecretsClient
import wandb

if USE_WANDB:
    project_name = "kaggle-sorghum-100-cultivar"
    user_secrets = UserSecretsClient()
    wandb.require(experiment="service")
    wandb.login(key=user_secrets.get_secret("WANDB_API_KEY"))
       
    model_logger = WandbLogger(project=project_name, log_model='all', config={"epochs": EPOCHS})
else:
    model_logger=None

In [None]:
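if USE_WANDB:
    # model_logger is None when Wandb is disabled, so only touch the run when it exists
    model_logger.experiment.config.update({"epochs": EPOCHS}, allow_val_change=True)
    display(model_logger.experiment.config)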

Train the model

In [None]:
pl.seed_everything(42)
classifier = Classifier(model, lr=2e-5)
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=EPOCHS, logger=model_logger)

In [None]:
trainer.fit(classifier, train_loader, val_loader)

In [None]:
trainer.save_checkpoint(f"cultivar_baseline_epoch_{EPOCHS}_vit_transformer.ckpt")

In [None]:
if USE_WANDB: wandb.finish()

# Make predictions

Now we will make predictions on the test set.

After some adjustments, I was able to score the test set in batches, which reduced inference time **from 4 hours to 6 minutes :D**
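The bookkeeping behind batched scoring can be sketched with a toy linear "model" in NumPy, standing in for the ViT. Scoring in batches and flattening the per-batch `argmax` gives the same labels as scoring one image at a time, but needs far fewer forward passes. Because softmax is monotonic, taking `argmax` directly on the logits yields the same class as `softmax(1).argmax(1)`. The names `toy_id2label` and `cultivar_i` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_features = rng.normal(size=(100, 8))  # stand-in for 100 preprocessed images
toy_weights = rng.normal(size=(8, 5))     # stand-in for a trained 5-class model
# hypothetical class names, with string keys as in the notebook's id2label
toy_id2label = {str(i): f"cultivar_{i}" for i in range(5)}

def predict_one_by_one(x):
    # one "forward pass" per image
    return [toy_id2label[str(int(np.argmax(row @ toy_weights)))] for row in x]

def predict_batched(x, batch_size=32):
    # one "forward pass" per batch, argmax over the class axis, then flatten
    preds = []
    for start in range(0, len(x), batch_size):
        logits = x[start:start + batch_size] @ toy_weights
        preds.extend(toy_id2label[str(int(p))] for p in logits.argmax(axis=1))
    return preds

print(predict_batched(toy_features) == predict_one_by_one(toy_features))
```

On a GPU, each batched forward pass processes many images at once, which is where most of the time savings come from.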

~~A possible improvement here is to score the test set in batches instead of one image at a time.~~

The old, one-image-at-a-time code is kept below for reference.

In [None]:
# OLD CODE, JUST FOR REFERENCE, DO NOT USE IT!
# It takes over 4 hours to do inference on all test images :-/

# def pred_image(img):
    
#     if not Path(img).exists(): return ''
    
#     im = Image.open(img)
#     # Transform our image and pass it through the model
#     inputs = feature_extractor(im, return_tensors='pt')
#     with torch.no_grad():
#         output = model(**inputs)

#     # Predicted Class probabilities
#     proba = output.logits.softmax(1)

#     # Predicted Classes
#     preds = proba.argmax(1)

#     return model.config.id2label[str(preds.item())]

# model.eval()

# TEST_FOLDER = "../input/sorghum-id-fgvc-9/test"

# test_df = pd.read_csv("../input/sorghum-id-fgvc-9/sample_submission.csv")

# test_df['cultivar'] = test_df.filename.progress_apply(lambda x: pred_image(f"{TEST_FOLDER}/{x}"))

# test_df.to_csv("submission.csv", index = False)


In [None]:
TEST_FOLDER = "../input/sorghum-100-cultivar-512x512-png-imagefolder/images/test"

test_ds = ImageDataset(Path(TEST_FOLDER))
test_dl = DataLoader(test_ds, batch_size=32, collate_fn=collator, num_workers=2)

In [None]:
# plot a random image from test set
img, label = test_ds[random.randrange(len(test_ds))]  # randrange excludes len(), avoiding an IndexError
plt.imshow(img, cmap="gray")

In [None]:
model.cuda()
model.eval()

def batch_predictions(dl, ds, id2label):
    predictions = []
    for batch in tqdm(dl):
        image = batch['pixel_values'].cuda()
        with torch.no_grad():
            outputs = model(image)
            preds = outputs.logits.softmax(1).argmax(1).detach().cpu().numpy()
            predictions.append(preds)
        
    all_preds = []
    for batch in predictions:
        for prediction in batch:
            all_preds.append(id2label[str(prediction)])

    return all_preds, ds.filenames()

In [None]:
batch_preds, batch_filenames = batch_predictions(test_dl, test_ds, id2label)
df_preds = pd.DataFrame({'filename': batch_filenames, "cultivar": batch_preds})
df_preds.head()

# Submission

At the moment, either the test set or the `sample_submission.csv` file is broken, so it is not possible to submit. As soon as the organizers fix it, I will update the notebook with the submission.


In [None]:
test_df = pd.read_csv("../input/sorghum-id-fgvc-9/sample_submission.csv")

submission_df = pd.merge(test_df[['filename']], df_preds, how='inner', on='filename')

submission_df.to_csv("submission.csv", index = False)

submission_df.head()
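One caveat of `how='inner'` is that any filename from `sample_submission.csv` without a prediction is silently dropped. The sketch below uses toy data (made-up filenames and cultivar values). It shows how a left merge with `indicator=True`, a standard `pandas.merge` parameter, can flag those rows before the file is written.

```python
import pandas as pd

# toy stand-ins for sample_submission.csv and the model predictions
toy_test_df = pd.DataFrame({"filename": ["c.png", "a.png", "b.png"]})
toy_preds = pd.DataFrame({"filename": ["a.png", "b.png"], "cultivar": ["x", "y"]})

# left merge keeps every expected filename; _merge == "left_only" marks missing predictions
check = pd.merge(toy_test_df[["filename"]], toy_preds, how="left", on="filename", indicator=True)
missing = check.loc[check["_merge"] == "left_only", "filename"].tolist()
print(missing)  # ['c.png']

# inner merge (as in the notebook) keeps only matched rows, in the left frame's order
toy_submission = pd.merge(toy_test_df[["filename"]], toy_preds, how="inner", on="filename")
print(toy_submission["filename"].tolist())  # ['a.png', 'b.png']
```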
