<a href="https://www.kaggle.com/code/ibombonato/vit-transformers-inference-from-wandb-checkpoint?scriptVersionId=91059505" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Inference using checkpoint from Wandb + Lightning

This notebook is a companion to the ViT Transformers starter notebook:

[ViT Transformers - Sorghum 100 cultivar - Starter](https://www.kaggle.com/code/ibombonato/vit-transformers-sorghum-100-cultivar-starter)

Now that the model checkpoint is saved on Wandb, we run inference from a checkpoint loaded from Wandb. This lets us train and fine-tune models outside the Kaggle environment, and it **will help us build an ensemble of models** later.

**If it helps you in some manner, please upvote the dataset and the notebooks :D**

![image.png](attachment:40f45dd0-4c58-4034-9b51-a5cd34520e23.png)

### Load libs and minimal setup

In [1]:
!pip install -q timm
!pip install -q --upgrade wandb



In [2]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path

In [3]:
#Confirm that a GPU is available
!nvidia-smi

Wed Mar 23 18:39:12 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.04   Driver Version: 450.119.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

Find your checkpoint on the Wandb site and set the variable `CHECKPOINT_WANDB` to its artifact name:

https://wandb.ai/ibombonato/kaggle-sorghum-100-cultivar/artifacts/model/

![image.png](attachment:772d198e-6f4c-40ca-9c12-45959379a818.png)

In [4]:
ORIGIN_FOLDER = "../input/sorghum-100-cultivar-512x512-png-imagefolder/images"
MODEL_NAME = 'google/vit-base-patch16-224-in21k'

# CHANGE TO YOUR CHECKPOINT HERE!
CHECKPOINT_WANDB = 'ibombonato/kaggle-sorghum-100-cultivar/model-ul10jk3n:v46'

Load the artifact from Wandb using the name in `CHECKPOINT_WANDB`:

In [5]:
from kaggle_secrets import UserSecretsClient
import wandb

user_secrets = UserSecretsClient()

wandb.login(key=user_secrets.get_secret("WANDB_API_KEY"))

run = wandb.init()
artifact = run.use_artifact(CHECKPOINT_WANDB, type='model')
artifact_dir = artifact.download()

[34m[1mwandb[0m: W&B API key is configured (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mibombonato[0m (use `wandb login --relogin` to force relogin)


[34m[1mwandb[0m: Downloading large artifact model-ul10jk3n:v46, 982.96MB. 1 files... Done. 0:0:0


In [6]:
train_raw = pd.read_csv("../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv")

In [7]:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision.datasets import ImageFolder
from transformers import AutoFeatureExtractor, ViTForImageClassification
from timm.data import ImageDataset

## Recreating train classes and models

Train notebook for reference:

[ViT Transformers - Sorghum 100 cultivar - Starter](https://www.kaggle.com/code/ibombonato/vit-transformers-sorghum-100-cultivar-starter)

Since PyTorch converts targets to integers, we map ids to labels and labels to ids so that we can look up the class names later.

In [8]:
all_ds = ImageFolder(Path(ORIGIN_FOLDER, "train"))

label2id = {}
id2label = {}

for i, class_name in enumerate(all_ds.classes):
    label2id[class_name] = str(i)
    id2label[str(i)] = class_name
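
`ImageFolder` numbers the classes after sorting the class folder names, so the same mapping can be rebuilt from the names alone. A minimal sketch of that round trip, using three cultivar names that appear in the outputs of this notebook (a hypothetical subset; the real training folder has one folder per cultivar):

```python
# Hypothetical subset of class folder names, for illustration only.
class_names = ["PI_221548", "PI_152860", "PI_273465"]

# ImageFolder sorts folder names before numbering them, so sort here too.
label2id = {}
id2label = {}
for i, class_name in enumerate(sorted(class_names)):
    label2id[class_name] = str(i)
    id2label[str(i)] = class_name

# A predicted index (for example from argmax) maps back to a name via str().
predicted_index = 2
predicted_name = id2label[str(predicted_index)]
```

The keys of `id2label` are strings, which is why the prediction step has to call `str()` on each numeric prediction before the lookup.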

In [9]:
class ImageClassificationCollator:
    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor
 
    def __call__(self, batch):
        encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
        return encodings

In [10]:
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

collator = ImageClassificationCollator(feature_extractor)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
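class Classifier(pl.LightningModule):
    """Lightning wrapper around a Hugging Face image classification model."""

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        # Save `lr` and any extra keyword arguments as hyperparameters
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        # Delegate forward() to the wrapped model, so self(**batch) returns its outputs
        self.forward = self.model.forward
        # Note: newer torchmetrics releases may require a `task` argument here
        self.val_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        # The batch contains `labels`, so the model computes the loss itself
        outputs = self(**batch)
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("val_loss", outputs.loss)
        acc = self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("val_acc", acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)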

## Loading model weights from checkpoint

In [12]:
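pl.seed_everything(42)
classifier = Classifier(model)

# By default, Wandb downloads the artifact to ./artifacts/<artifact name>:<version>/
CHECKPOINT_FILE = Path("./artifacts", Path(CHECKPOINT_WANDB).name, "model.ckpt")

# Rebuild the Lightning module around `model` and load the fine-tuned weights
model = Classifier.load_from_checkpoint(
    CHECKPOINT_FILE,
    model=model,
)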
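
The checkpoint path above is assembled by hand from the artifact name. As a sketch of a more defensive alternative, the folder returned by `artifact.download()` (`artifact_dir` in the download cell) could be searched for the checkpoint instead. The helper name `find_checkpoint` is hypothetical, and the demo runs on a temporary folder that only mimics the artifact layout:

```python
from pathlib import Path
import tempfile


def find_checkpoint(artifact_dir):
    """Return the single .ckpt file under artifact_dir, or raise if there is not exactly one."""
    candidates = sorted(Path(artifact_dir).rglob("*.ckpt"))
    if len(candidates) != 1:
        raise FileNotFoundError(
            f"Expected exactly one .ckpt under {artifact_dir}, found {len(candidates)}"
        )
    return candidates[0]


# Demo on a temporary folder shaped like ./artifacts/<name>/model.ckpt
with tempfile.TemporaryDirectory() as tmp:
    fake_dir = Path(tmp, "artifacts", "model-example-v0")
    fake_dir.mkdir(parents=True)
    (fake_dir / "model.ckpt").touch()
    found_name = find_checkpoint(fake_dir).name
```

In this notebook the call would be `find_checkpoint(artifact_dir)`, assuming the artifact holds a single checkpoint file.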

## Make predictions

Now we will load and make predictions on the test set.

In [13]:
TEST_FOLDER = "../input/sorghum-100-cultivar-512x512-png-imagefolder/images/test"

test_ds = ImageDataset(Path(TEST_FOLDER))
test_dl = DataLoader(test_ds, batch_size=32, collate_fn=collator, num_workers=2)

In [14]:
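model.cuda()
model.eval()  # switch off dropout for inference

def batch_predictions(dl, ds, id2label):
    """Predict a class name for every image in `dl`; returns (names, filenames)."""
    predictions = []
    for batch in tqdm(dl):
        image = batch['pixel_values'].cuda()
        with torch.no_grad():
            outputs = model(image)
            # argmax of the logits equals argmax of the softmax probabilities
            preds = outputs.logits.argmax(1).cpu().numpy()
            predictions.append(preds)

    # Flatten the per-batch arrays and map numeric ids back to class names
    all_preds = []
    for batch in predictions:
        for prediction in batch:
            all_preds.append(id2label[str(prediction)])

    return all_preds, ds.filenames()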
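
The function above keeps only the top class per image. For the ensemble mentioned in the introduction, one common approach (an assumption here, not something this notebook implements) is to keep the per-class probabilities of each model and average them before taking the argmax. A minimal sketch with plain Python lists and made-up logits, so it runs without a GPU or a checkpoint:

```python
import math


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(per_model_logits):
    """Average class probabilities across models and return the argmax index.

    per_model_logits: one list of logits per model, all for the same image.
    """
    probs = [softmax(logits) for logits in per_model_logits]
    n_classes = len(probs[0])
    mean_probs = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    return max(range(n_classes), key=mean_probs.__getitem__)


# Two hypothetical models disagree on one image; the more confident one wins.
model_a = [2.0, 0.5, 0.1]  # favours class 0
model_b = [0.1, 0.2, 3.5]  # strongly favours class 2
winner = ensemble_predict([model_a, model_b])
```

With real models, the same averaging would be applied to the softmax of `outputs.logits` from each checkpoint, and the winning index mapped through `id2label` as before.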

In [15]:
batch_preds, batch_filenames = batch_predictions(test_dl, test_ds, id2label)
df_preds = pd.DataFrame({'filename': batch_filenames, "cultivar": batch_preds})
df_preds.head()

100%|██████████| 739/739 [05:57<00:00,  2.06it/s]


| | filename | cultivar |
|---|---|---|
| 0 | 88028.png | PI_221548 |
| 1 | 181578.png | PI_273465 |
| 2 | 320611.png | PI_152860 |
| 3 | 350168.png | PI_155760 |
| 4 | 492639.png | PI_197542 |


## Submission

At the moment, either the test set or the sample_submission file is broken, so it is not possible to submit. As soon as the organizers fix it, I will update this notebook with the submission.


In [16]:
test_df = pd.read_csv("../input/sorghum-id-fgvc-9/sample_submission.csv")

submission_df = pd.merge(test_df[['filename']], df_preds, how='inner', on='filename')

submission_df.to_csv("submission.csv", index=False)

submission_df.head()

| | filename | cultivar |
|---|---|---|
| 0 | 1000005362.png | PI_218112 |
| 1 | 1000099707.png | PI_329333 |
| 2 | 1000135300.png | PI_92270 |
| 3 | 1000136796.png | PI_329256 |
| 4 | 1000292439.png | PI_155516 |
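
Because the merge uses `how='inner'`, a filename that appears in only one of the two tables is dropped without any warning. A small sketch of a coverage check that could be run before writing the CSV; the helper name `check_coverage` and the filenames are hypothetical:

```python
def check_coverage(expected_filenames, predicted_filenames):
    """Return (missing, unexpected) filename lists, each sorted."""
    expected = set(expected_filenames)
    predicted = set(predicted_filenames)
    return sorted(expected - predicted), sorted(predicted - expected)


# Hypothetical filenames, for illustration only.
expected = ["a.png", "b.png", "c.png"]
predicted = ["a.png", "c.png", "d.png"]
missing, unexpected = check_coverage(expected, predicted)
```

In this notebook the call would be `check_coverage(test_df['filename'], df_preds['filename'])`; a non-empty `missing` list means some rows of the sample submission have no prediction.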


## If it helps you in some manner, please upvote the dataset and the notebook :D