<a href="https://colab.research.google.com/github/juliobellano/CV_Notebooks/blob/main/HuggingPics2_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# HuggingPics 🤗🖼️

Fine-tune a Vision Transformer on a regression problem: predicting the correlation value of an image, in the style of "Guessthecorrelation.com".

In [None]:
%%capture

# Install the Python libraries this notebook imports below.
! pip install transformers pytorch-lightning --quiet
# Install git-lfs and make git remember credentials.
# NOTE(review): nothing in the visible cells pushes to a git remote — presumably
# left over from the Hugging Face Hub upload flow; confirm whether still needed.
! sudo apt -qq install git-lfs
! git config --global credential.helper store

In [None]:
import requests
import math
import time
import matplotlib.pyplot as plt
import shutil
import os
from getpass import getpass
from PIL import Image, UnidentifiedImageError
from requests.exceptions import HTTPError
from io import BytesIO
from pathlib import Path
import torch
import pandas as pd
import torchvision.transforms as transforms
import pytorch_lightning as pl
from huggingface_hub import HfApi, HfFolder, Repository, notebook_login
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy
from torchvision.datasets import ImageFolder
from transformers import ViTFeatureExtractor, ViTForImageClassification
from sklearn.model_selection import train_test_split

In [None]:
# Mount Google Drive so the dataset archive can be copied from it
# (this only works inside a Google Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the dataset archive from the mounted Drive onto the local runtime disk.
!cp /content/drive/MyDrive/input.zip /content/input.zip

In [None]:
# Extract the archive into /content/input; the relative `input.zip` path
# assumes the working directory is /content.
!unzip input.zip -d /content/input

In [None]:
# Directory containing the extracted images, relative to the working directory.
# NOTE(review): the visible cells below hard-code this path again instead of
# using `data_dir` — consider passing it through.
data_dir = Path('input/correlation_assignment/images')

## Init Dataset and Split into Training and Validation Sets


In [None]:
# Load the table that maps each image id to its target value.
responses_csv = 'input/correlation_assignment/responses.csv'
df = pd.read_csv(responses_csv)

# Keep ids and targets as parallel arrays; index i in one matches index i in the other.
image_id, labels = df['id'].values, df['corr'].values
print(len(image_id))

In [None]:
# Per-image preprocessing: convert the PIL image to a tensor, then normalise the
# single channel with mean 0.5 and std 0.5.
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
class ImageDataset(Dataset):
    """Map-style dataset that pairs image ids with their labels.

    Each image is read on access from ``<image_dir>/<image id>.png`` and
    converted to PIL mode ``'1'`` before the optional transform is applied.

    Args:
        image_ids: Indexable collection of image ids (file names without ".png").
        labels: Indexable collection of targets, aligned with ``image_ids``.
        image_dir: Directory that contains the PNG files.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, image_ids, labels, image_dir='input/correlation_assignment/images', transform=None):
        self.image_ids, self.labels = image_ids, labels
        self.image_dir, self.transform = image_dir, transform

    def __len__(self):
        """Return the number of samples (one per image id)."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        file_name = f"{self.image_ids[idx]}.png"
        picture = Image.open(os.path.join(self.image_dir, file_name)).convert('1')

        if self.transform:
            picture = self.transform(picture)

        return picture, self.labels[idx]

In [None]:
# Hold out 15% of the samples for validation; the fixed seed keeps the split
# reproducible between runs.
X_train, X_val, y_train, y_val = train_test_split(
    image_id,
    labels,
    test_size=0.15,
    random_state=42,
    shuffle=True,
)
print(f"length of training dataset {len(X_train)}")
print(f"length of val dataset {len(X_val)}")

# Both splits share the same preprocessing pipeline.
train_ds, val_ds = (
    ImageDataset(split_ids, split_labels, transform=data_transforms)
    for split_ids, split_labels in ((X_train, y_train), (X_val, y_val))
)

In [None]:
# Preview the first six training images in a 2x3 grid, each titled with its label.
preview_count = 6
plt.figure(figsize=(20, 10))

for position in range(preview_count):
    sample_id, sample_label = X_train[position], y_train[position]
    plt.subplot(2, 3, position + 1)
    plt.title(f'{sample_label}')
    plt.imshow(plt.imread(f'/content/input/correlation_assignment/images/{sample_id}.png'))
    plt.axis('off')

## Image Classification Collator

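To apply the model's preprocessing to our images, we'll use a custom collator class. We'll initialize it with an instance of `ViTFeatureExtractor` and pass the collator instance to `torch.utils.data.DataLoader`'s `collate_fn` kwarg. Because this is a regression task, the collator also packs the targets into a float tensor under the `labels` key.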

In [None]:
class ImageClassificationCollator:
    """Turn a list of ``(image_tensor, label)`` samples into model inputs.

    Every image tensor is mapped back from the [-1, 1] range to uint8 pixel
    values. Single-channel images are additionally converted to RGB PIL images.
    The feature extractor is then called on the whole list, and the labels are
    attached to its result as a float tensor under ``'labels'``.
    """

    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    @staticmethod
    def _to_extractor_input(tensor):
        """Return the uint8 form of ``tensor``, as an RGB PIL image if it has one channel."""
        # Undo the (x - 0.5) / 0.5 normalisation: [-1, 1] -> [0, 255].
        pixels = ((tensor + 1) * 127.5).byte()

        if pixels.shape[0] != 1:
            # Multi-channel tensors are passed through as uint8 tensors.
            return pixels

        # Let PIL expand the single channel to RGB instead of repeating it by hand.
        return Image.fromarray(pixels[0].numpy()).convert('RGB')

    def __call__(self, batch):
        extractor_inputs = [self._to_extractor_input(sample[0]) for sample in batch]
        encodings = self.feature_extractor(extractor_inputs, return_tensors='pt')
        encodings['labels'] = torch.tensor([sample[1] for sample in batch], dtype=torch.float)
        return encodings

## Init Feature Extractor, Model, Data Loaders


In [None]:
# Feature extractor and backbone are loaded from the same pretrained checkpoint.
pretrained_name = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(pretrained_name)
model = ViTForImageClassification.from_pretrained(
    pretrained_name,
    num_labels=1,
    problem_type='regression',
)

# Replace the head with a single-output linear layer followed by tanh, which
# bounds predictions to (-1, 1).
head_in_features = model.classifier.in_features
model.classifier = torch.nn.Sequential(
    torch.nn.Linear(head_in_features, 1),
    torch.nn.Tanh(),
)

# Train and validation loaders share every setting except shuffling.
collator = ImageClassificationCollator(feature_extractor)
loader_options = dict(batch_size=64, collate_fn=collator, num_workers=2)
train_loader = DataLoader(train_ds, shuffle=True, **loader_options)
val_loader = DataLoader(val_ds, **loader_options)

# Training

⚡ We'll use [PyTorch Lightning](https://pytorchlightning.ai/) to fine-tune our model.

In [None]:
class Classifier(pl.LightningModule):
    """Lightning wrapper that fine-tunes ``model`` with a mean-squared-error loss.

    Despite its name, the module is trained as a regressor: the loss is the MSE
    between the wrapped model's ``logits`` and the batch's ``'labels'``.

    Args:
        model: Network to train. Its output must expose ``.logits``; this class
            assumes they have a trailing dimension of size 1 (one value per sample).
        lr: Learning rate for the Adam optimizer.
        **kwargs: Names of additional ``__init__`` arguments to record as
            hyperparameters.
    """

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        self.mse = torch.nn.MSELoss()

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model."""
        # Defined as a real method rather than assigning ``self.model.forward``
        # onto the instance, so the wrapped module is invoked through its
        # normal ``__call__`` path.
        return self.model(*args, **kwargs)

    def _mse_loss(self, batch):
        """Run the model on ``batch`` and return the MSE against ``batch['labels']``."""
        outputs = self(**batch)
        # Drop only the trailing size-1 dimension. A bare ``squeeze()`` would
        # also remove the batch dimension when the last batch holds a single
        # sample, leaving a 0-d prediction that no longer matches the labels.
        preds = outputs.logits.squeeze(-1)
        return self.mse(preds, batch['labels'])

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._mse_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._mse_loss(batch)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [None]:
# Seed the random number generators before building the Lightning module and trainer.
pl.seed_everything(42)
classifier = Classifier(model, lr=2e-5)

# One epoch on a single GPU with 16-bit precision.
trainer_options = {
    'accelerator': 'gpu',
    'devices': 1,
    'precision': 16,
    'max_epochs': 1,
}
trainer = pl.Trainer(**trainer_options)
trainer.fit(classifier, train_loader, val_loader)

## Check if it Worked 😅

In [None]:
# Optionally restore previously saved weights before the sanity check below.
# Binding the bare string 'ViTtest.pt' to `model` would make the later
# `model(**val_batch)` call fail with "'str' object is not callable", so the
# file is loaded instead. If it is absent, `model` stays the network
# fine-tuned above.
saved_model_path = Path('ViTtest.pt')
if saved_model_path.exists():
    # NOTE(review): this notebook does not show how ViTtest.pt was written.
    # Both a pickled module and a state dict are handled — confirm which applies.
    # Recent PyTorch releases need weights_only=False to unpickle a full module.
    restored = torch.load(saved_model_path, map_location='cpu')
    if isinstance(restored, torch.nn.Module):
        model = restored
    else:
        model.load_state_dict(restored)

In [None]:
# Sanity check: run one validation batch through the model and compare the
# predictions with the true labels.
val_batch = next(iter(val_loader))

# Inference mode: disable dropout and skip gradient bookkeeping.
model.eval()
start_time = time.time()
with torch.no_grad():
    outputs = model(**val_batch)
# Use a separate name for the duration. Rebinding `time` would shadow the
# `time` module and break this cell the second time it is run.
elapsed = time.time() - start_time
print(f'Inference time for 1 batch: {elapsed}')

# Drop only the trailing size-1 dimension so a batch of one keeps its shape.
preds = outputs.logits.squeeze(-1)
print('Preds: ', preds)
print('Labels:', val_batch['labels'])

# NumPy copies of the same values for easier inspection.
model_output = preds.detach().cpu().numpy()
labels_actual = val_batch['labels'].cpu().numpy()
print('model output:', model_output)
print('Labels:', labels_actual)

In [None]:
# Write the trainer's current state to a Lightning checkpoint file in the
# working directory.
checkpoint_file = "full_model.ckpt"
trainer.save_checkpoint(checkpoint_file)