# Fine-Tuning Google Vision Transformer (ViT) for Custom Classification

This notebook demonstrates how to fine-tune a pre-trained Google Vision Transformer (ViT) model for image classification on a custom dataset. Each class corresponds to one reference image from the DAM dataset, identified by its file name (`id.jpg`), together with up to 8 associated 3D-augmented images (`id-X.png`, with X from 1 to 8). The model is trained on the full DAM dataset, without a held-out validation split, and evaluated on a separate test set whose labels are provided in a hand-made CSV file.

---

## Table of Contents

1. [Installation of Required Libraries](#installation)
2. [Importing Libraries](#importing-libraries)
3. [Setting Up Paths and Device](#paths-and-device)
4. [Loading and Preparing the Dataset](#loading-dataset)
5. [Preprocessing Images](#preprocessing)
6. [Preparing the Hugging Face Dataset](#hugging-face-dataset)
7. [Defining the Model](#defining-model)
8. [Training Arguments](#training-arguments)
9. [Metrics Calculation](#metrics)
10. [Training the Model](#training)
11. [Evaluating the Model](#evaluation)
12. [Saving the Fine-Tuned Model](#saving-model)
13. [Conclusion](#conclusion)

---

<a id="installation"></a>
## 1. Installation of Required Libraries

Ensure that all necessary libraries are installed. Run the following cell to install them.

```python
# Install necessary libraries
%pip install transformers datasets torch torchvision pandas scikit-learn Pillow numpy tqdm rembg matplotlib faiss-cpu

In [None]:
import os
import pandas as pd
import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
from evaluate import load

import faiss  # For similarity search (optional)

In [None]:
# Dataset, label and cache locations (relative to the working directory)
DAM_DIR = 'data/DAM'
TEST_DIR = 'data/test_image_headmind'
LABELS_CSV = 'labels/handmade_test_labels.csv'
TRELLIS_DIR = '.cache/TRELLIS'

# Glob patterns for the image file types to pick up
EXTENSIONS = ['*.jpg', '*.jpeg', '*.png']

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Collect the original DAM images (id.<ext>) for every supported extension.
# Driving the search from EXTENSIONS keeps it in sync with how the test
# images are discovered, instead of repeating the patterns by hand.
dam_original_images = sorted(
    path
    for ext in EXTENSIONS
    for path in glob(os.path.join(DAM_DIR, ext))
)

# Function to retrieve 3D augmented image paths for a given class ID
def get_3d_augmented_paths(class_id):
    """Return the 3D-augmented image paths that exist on disk for ``class_id``.

    Candidates are ``<TRELLIS_DIR>/<class_id>/<class_id>-<i>.png`` for ``i``
    from 1 to 8. Missing files are skipped; the result keeps ascending ``i``
    order and may be empty.
    """
    class_dir = os.path.join(TRELLIS_DIR, class_id)
    candidates = (
        os.path.join(class_dir, f"{class_id}-{view}.png")
        for view in range(1, 9)
    )
    return [path for path in candidates if os.path.exists(path)]

# Pair every original DAM image with its 3D-augmented views; all of them
# share the class id taken from the original's file name.
all_dam_images = []
labels = []

for original_path in dam_original_images:
    # Class id = file name of the original image without its extension
    stem, _suffix = os.path.splitext(os.path.basename(original_path))
    views = get_3d_augmented_paths(stem)

    # Original first, then its augmented views, each with the same label
    all_dam_images += [original_path, *views]
    labels += [stem] * (1 + len(views))

# One row per training image
dam_df = pd.DataFrame({'image_path': all_dam_images, 'label': labels})

# Test images: every supported extension inside TEST_DIR, sorted by path
test_images = sorted(
    path
    for ext in EXTENSIONS
    for path in glob(os.path.join(TEST_DIR, ext))
)
test_df = pd.DataFrame({'image_path': test_images})

print("DAM DataFrame:")
print(dam_df.head())

print("\nTest DataFrame:")
print(test_df.head())

In [None]:
from utils.preprocessing import preprocess_image
from sklearn.model_selection import train_test_split

# Extract class labels from DAM image filenames (using the id)
def extract_label(image_path):
    """Derive the class id from an image path.

    The id is the file name up to its first ``.``, further cut at the first
    ``-`` so that an augmented view ``<id>-<n>.png`` maps to ``<id>``.
    """
    stem = os.path.basename(image_path).partition('.')[0]
    return stem.partition('-')[0]

# Re-derive every label from its file name with extract_label, so originals
# (id.ext) and augmented views (id-X.png) end up with the same class id.
# NOTE(review): extract_label cuts at the first '.' and '-'; this assumes DAM
# ids contain neither character - confirm against the real file names.
dam_df['label'] = dam_df['image_path'].apply(extract_label)

# Every image of a class (the original plus its 3D-augmented views) stays in
# the training data: the full DAM set is used and no validation split is made.

# Convert the DataFrame to a Hugging Face Dataset
from datasets import Dataset as HFDataset

hf_dataset = HFDataset.from_pandas(dam_df)

# No train/validation split is performed here; a split would have to be added
# at this point if a held-out validation set is wanted.

In [None]:
# Image processor loaded from the same 'google/vit-base-patch16-224-in21k'
# checkpoint name that the model is later loaded from.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

def transform(example_batch):
    """Turn a batch of DAM rows into ViT inputs with class-id labels.

    Args:
        example_batch: mapping with parallel ``'image_path'`` and ``'label'``
            lists.

    Returns:
        The processor output (``pixel_values``) with an added ``'labels'``
        long tensor, restricted to the images ``preprocess_image`` could
        produce. An empty dict when no image in the batch was usable.

    Raises:
        KeyError: if a label is missing from the global ``label2id`` mapping.
    """
    # Preprocess each image exactly once and keep its label paired with it,
    # so images and labels cannot drift apart when an image is dropped.
    kept_images = []
    kept_labels = []
    for img_path, label in zip(example_batch['image_path'], example_batch['label']):
        image = preprocess_image(img_path, background_removal="rembg")
        if image is not None:
            kept_images.append(image)
            kept_labels.append(label)

    if not kept_images:
        return {}

    # Apply processor to images
    inputs = processor(images=kept_images, return_tensors='pt')

    # Use the dataset-wide ``label2id`` (the mapping the model is configured
    # with). A mapping built from the batch itself would assign ids that
    # change from batch to batch (a single-row batch would always get id 0).
    # NOTE: ``label2id`` is a notebook global defined in the model cell; it
    # must exist before the first item of the transformed dataset is read.
    inputs['labels'] = torch.tensor(
        [label2id[label] for label in kept_labels],
        dtype=torch.long,
    )

    return inputs

# Attach `transform` to the dataset via with_transform.
# NOTE(review): with_transform is expected to run the function lazily, on the
# rows being read, rather than preprocessing everything here - confirm against
# the installed `datasets` version.
prepared_dataset = hf_dataset.with_transform(transform)

In [None]:
# Class names are the sorted unique DAM labels; a name's position is its id
labels = sorted(dam_df['label'].unique())
num_classes = len(labels)
print(f"Number of classes: {num_classes}")

# Two-way mappings between class ids and class names for the model config
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

# Load the pre-trained ViT backbone with a classification head sized for
# this dataset, then move it to the selected device
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
    id2label=id2label,
    label2id=label2id,
)

model.to(device)

In [None]:
from transformers import TrainingArguments

# Hyperparameters and bookkeeping options handed to the Trainer.
training_args = TrainingArguments(
    output_dir='./vit-finetuned',  # where checkpoints are written
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=4,
    learning_rate=2e-4,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy='epoch',
    save_strategy='epoch',  # same cadence as evaluation
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',  # the key returned by compute_metrics
    push_to_hub=False,
    # Presumably required so the raw 'image_path' / 'label' columns still
    # reach the on-the-fly transform - verify before changing.
    remove_unused_columns=False,
)

In [None]:
import numpy as np
from evaluate import load

# Accuracy metric from the `evaluate` library
metric = load('accuracy')

def compute_metrics(p):
    """Score a prediction object: arg-max over axis 1, then accuracy."""
    logits, references = p.predictions, p.label_ids
    predicted_ids = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)

In [None]:
from transformers import Trainer

# Initialize the Trainer with the model, the training arguments and the
# lazily transformed DAM dataset.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_dataset,
    # The evaluation set is the training set itself, so the per-epoch accuracy
    # (also used to pick the best checkpoint) measures fit to the training
    # data, not generalization. Ideally, use a separate validation set.
    eval_dataset=prepared_dataset,
    compute_metrics=compute_metrics,
)

# Start training
trainer.train()

In [None]:
# Set this variable to True to enable result display
DISPLAY_RESULTS = True

# Load the hand-made test labels (columns used: 'image', 'reference')
labels_df = pd.read_csv(LABELS_CSV)

# Map each test image file name to its list of acceptable DAM references.
# A 'reference' cell may hold several ids separated by '/'; '?' marks an
# unknown reference and is dropped.
labels_dict = {}
for image_name, reference in zip(labels_df['image'], labels_df['reference']):
    candidates = (ref.strip() for ref in str(reference).split('/'))
    labels_dict[image_name.strip()] = [ref for ref in candidates if ref and ref != '?']


def _test_label(image_path):
    """Return the ground-truth DAM id for a test image.

    The first reference listed in the CSV is used. Images without a usable
    CSV entry fall back to the id derived from their file name, which is what
    every test image was labelled with before.
    """
    references = labels_dict.get(os.path.basename(image_path), [])
    return references[0] if references else extract_label(image_path)


# Label the test set from the CSV. Labelling from the test file name alone
# would evaluate against names that are not the DAM references in the CSV.
# NOTE: only the first reference is kept, so a prediction matching an
# alternative reference of the same image is still counted as wrong.
test_df['label'] = test_df['image_path'].apply(_test_label)

# Convert Test DataFrame to Hugging Face Dataset
test_hf_dataset = HFDataset.from_pandas(test_df)

# Define transformation for test dataset
def test_transform(example_batch):
    """Turn a batch of test rows into ViT inputs with class-id labels.

    Args:
        example_batch: mapping with parallel ``'image_path'`` and ``'label'``
            lists.

    Returns:
        The processor output (``pixel_values``) with an added ``'labels'``
        long tensor, restricted to the images ``preprocess_image`` could
        produce. An empty dict when no image in the batch was usable.
    """
    # Preprocess each image exactly once and keep its label paired with it;
    # running the background removal a second time just to filter the labels
    # doubles the preprocessing cost.
    kept_images = []
    kept_labels = []
    for img_path, label in zip(example_batch['image_path'], example_batch['label']):
        image = preprocess_image(img_path, background_removal="RMBG_2")
        if image is not None:
            kept_images.append(image)
            kept_labels.append(label)

    if not kept_images:
        return {}

    # Apply processor to images
    inputs = processor(images=kept_images, return_tensors='pt')

    # Map labels to IDs using the same mapping as training. Labels unknown to
    # the model get -100, the default ignore index of PyTorch's cross-entropy
    # loss; -1 is not a valid class index and is not ignored by the loss.
    # Such samples are skipped by the loss but still count as misses in the
    # accuracy metric.
    inputs['labels'] = torch.tensor(
        [label2id.get(label, -100) for label in kept_labels],
        dtype=torch.long,
    )

    return inputs

# Attach the test-time transform to the test dataset
prepared_test_dataset = test_hf_dataset.with_transform(test_transform)

# Run the trainer's evaluation on the test set and report the metrics
results = trainer.evaluate(eval_dataset=prepared_test_dataset)
print("Evaluation Results:", results, sep="\n")

def _open_rgb(path, description):
    """Load ``path`` as an RGB PIL image.

    On failure, print an error mentioning ``description`` and return ``None``.
    The file handle is closed as soon as the pixels have been converted.
    """
    try:
        with Image.open(path) as img:
            return img.convert('RGB')
    except Exception as e:  # display is best-effort: report and carry on
        print(f"Error loading {description} image {path}: {e}")
        return None


# If DISPLAY_RESULTS is True, display each test image with expected and predicted DAM images
if DISPLAY_RESULTS:
    from transformers import pipeline

    # Initialize the image classification pipeline
    classifier = pipeline("image-classification", model=model, image_processor=processor, device=device)

    # Map each label to its original DAM image (first match wins).
    # Augmented views are recognised by the '-' in their file name.
    label_to_original_path = {}
    for label, image_path in zip(dam_df['label'], dam_df['image_path']):
        if '-' not in os.path.basename(image_path):
            label_to_original_path.setdefault(label, image_path)

    # Iterate over each test image
    for test_image_path in tqdm(test_df['image_path'], total=len(test_df), desc="Displaying Results"):
        # Load the test image first so an unreadable file is skipped before
        # any prediction is attempted on it.
        test_image = _open_rgb(test_image_path, "test")
        if test_image is None:
            continue

        expected_labels = labels_dict.get(os.path.basename(test_image_path), [])

        # Top prediction for this image.
        # NOTE(review): the classifier is given the raw file, whereas training
        # and evaluation inputs go through preprocess_image - confirm this
        # difference is intended.
        prediction = classifier(test_image_path)[0]
        predicted_label = prediction['label']
        predicted_score = prediction['score']
        print(f"predicted_label: {predicted_label}")

        predicted_dam_path = label_to_original_path.get(predicted_label)
        predicted_image = None
        if predicted_dam_path:
            predicted_image = _open_rgb(predicted_dam_path, "predicted DAM")

        # Keep each expected image paired with its own label, so a reference
        # without a DAM image (or one that fails to load) cannot shift the
        # titles of the images that follow it.
        expected = []
        for label in expected_labels:
            path = label_to_original_path.get(label)
            if path is None:
                continue
            image = _open_rgb(path, "expected DAM")
            if image is not None:
                expected.append((label, image))

        # One row: test image, predicted image, then the expected image(s)
        num_cols = 2 + len(expected)
        plt.figure(figsize=(5 * num_cols, 5))

        # Display Test Image
        plt.subplot(1, num_cols, 1)
        plt.imshow(test_image)
        plt.title("Test Image")
        plt.axis('off')

        # Display Predicted Image (or a placeholder when it is unavailable)
        plt.subplot(1, num_cols, 2)
        if predicted_image is not None:
            plt.imshow(predicted_image)
            plt.title(f"Predicted: {predicted_label}\nScore: {predicted_score:.2f}")
        else:
            plt.text(0.5, 0.5, "Predicted Image Not Found", horizontalalignment='center', verticalalignment='center')
            plt.title("Predicted Image")
        plt.axis('off')

        # Display Expected Images
        for col, (label, image) in enumerate(expected, start=3):
            plt.subplot(1, num_cols, col)
            plt.imshow(image)
            plt.title(f"Expected: {label}")
            plt.axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
# Write the fine-tuned model and its image processor to the same directory
model_save_path = "./vit-finetuned"
for artifact in (model, processor):
    artifact.save_pretrained(model_save_path)

print(f"Model and processor saved to {model_save_path}")