In [1]:
# Standard library imports
import os
import json
import pandas as pd
import timm
import matplotlib.pyplot as plt
from fastai.vision.all import *
import torch

In [None]:
import timm

# List all available models in timm
# Check CVT model
available_models = timm.list_models()
print(available_models)

In [2]:
import json

class VisionTransformerTrainer:
    def __init__(self, csv_path, train_dir, model_name='vit_base_patch16_224', image_size=224, batch_size=8, valid_pct=0.2):
        self.csv_path = csv_path
        self.train_dir = train_dir
        self.model_name = model_name
        self.image_size = image_size
        self.batch_size = batch_size
        self.valid_pct = valid_pct
        self.learn = None
        self.lr = None  # Placeholder for the learning rate if specified manually
        self.metrics = {}  # To store metrics (loss and accuracy)

        self._prepare_data()
        self._create_dataloaders()

    def _prepare_data(self):
        self.train_data = pd.read_csv(self.csv_path)
        self.train_data['image_path'] = self.train_data['id'].apply(lambda x: os.path.join(self.train_dir, f"{x}.jpg"))
        self.train_data['stable_height'] = self.train_data['stable_height'].astype(str)

    def _create_dataloaders(self):
        dblock = DataBlock(
            blocks=(ImageBlock, CategoryBlock),
            get_x=ColReader('image_path'),
            get_y=ColReader('stable_height'),
            splitter=RandomSplitter(valid_pct=self.valid_pct),
            item_tfms=Resize(self.image_size),
            batch_tfms=aug_transforms()
        )
        self.dls = dblock.dataloaders(self.train_data, bs=self.batch_size, num_workers=0)

    def initialize_model(self, pretrained=True):
        self.model = timm.create_model(self.model_name, pretrained=pretrained, num_classes=self.dls.c)
        self.learn = vision_learner(self.dls, self.model_name, metrics=accuracy, pretrained=pretrained)
        
        # Check if CUDA is available
        if torch.cuda.is_available():
            self.learn.model = self.learn.model.cuda()
            print("CUDA is available. Using GPU for training.")
        else:
            print("CUDA is not available. Using CPU for training.")

    def set_learning_rate(self, learning_rate):
        """
        Sets the learning rate manually.
        
        Args:
            learning_rate (float): The learning rate to use for training.
        """
        self.lr = learning_rate
        print(f"Learning rate set to: {self.lr}")

    def fine_tune(self, epochs=5):
        if self.learn is None:
            raise Exception("Model has not been initialized. Please run initialize_model() first.")
        if self.lr is None:
            raise Exception("Learning rate not set. Please set it using set_learning_rate() first.")

        self.learn.fine_tune(epochs, base_lr=self.lr)
        print("Training complete.")

        # Collecting the train loss
        self.metrics['train_loss'] = [float(loss) for loss in self.learn.recorder.losses]
        
        # Collect validation losses and accuracy, ensuring to handle potential float values
        self.metrics['valid_loss'] = []
        self.metrics['accuracy'] = []

        for values in self.learn.recorder.values:
            try:
                # Ensure 'values' is treated as a list
                if isinstance(values, (list, tuple)):
                    valid_loss, accuracy = float(values[0]), float(values[1])
                    self.metrics['valid_loss'].append(valid_loss)
                    self.metrics['accuracy'].append(accuracy)
                else:
                    print(f"Unexpected value format in metrics: {values}")
            except ValueError:
                print(f"Skipping invalid entry in metrics: {values}")

        torch.cuda.empty_cache()

    def save_model(self, model_name=None):
        if self.learn is None:
            raise Exception("Model has not been initialized. Please run initialize_model() first.")
        
        if model_name is None:
            model_name = self.model_name.replace('/', '_')
        self.learn.save(model_name)
        print(f"Model saved as {model_name}")

    def load_model(self, model_name=None):
        if self.learn is None:
            raise Exception("Model has not been initialized. Please run initialize_model() first.")
        
        if model_name is None:
            model_name = self.model_name.replace('/', '_')
        self.learn.load(model_name)
        print(f"Model loaded from {model_name}")

    def predict(self, image_path):
        if self.learn is None:
            raise Exception("Model has not been initialized. Please run initialize_model() first.")
        
        img = PILImage.create(image_path)
        pred, pred_idx, probs = self.learn.predict(img)
        return {'id': os.path.basename(image_path).split('.')[0], 'predicted_stable_height': pred, 'probability': probs[pred_idx].item()}

    def plot_metrics(self):
        if self.learn is None or not hasattr(self.learn, 'recorder'):
            raise Exception("No training data found. Train the model first using fine_tune().")
        
        self.learn.recorder.plot_loss()
        plt.title(f"Loss and Accuracy for Model: {self.model_name}")
        plt.show()
        
    def save_metrics(self, file_path='training_metrics.json'):
        """
        Save collected metrics to a JSON file.
        """
        metrics_data = {
            'model_name': self.model_name,
            'learning_rate': self.lr,
            'train_loss': [float(loss) for loss in self.metrics['train_loss']],
            'valid_loss': [[float(value) for value in values] for values in self.metrics['valid_loss']],
            'accuracy': [float(acc) for acc in self.metrics['accuracy']],
        }

        with open(file_path, 'w') as json_file:
            json.dump(metrics_data, json_file, indent=4)
        
        print(f"Training metrics saved to {file_path}")


In [3]:
# List of Vision Transformer models to test from timm
model_names = [  
    'vit_base_patch16_224',     # Vanilla ViT model  
    'beit_base_patch16_224',    # BEiT model
    'cvt-21-224x224',           # CvT model
    'deit_base_patch16_224'     # DeiT model
]

# Paths to your dataset
csv_path = 'COMP90086_2024_Project_train/train.csv'
train_dir = 'COMP90086_2024_Project_train/train'

# Define the learning rate to be used for all models
learning_rate = 3e-4  # You can adjust this value as needed

# Loop through each Vision Transformer model and train
for model_name in model_names:
    print(f"\nTraining with model: {model_name}")
    
    try:
        # Initialize the Vision Transformer Trainer for the current model
        trainer = VisionTransformerTrainer(csv_path, train_dir, model_name=model_name, image_size=224, batch_size=8)
        
        # Initialize the model
        trainer.initialize_model(pretrained=True)
        
        # Set the learning rate manually
        trainer.set_learning_rate(learning_rate=learning_rate)
        
        # Fine-tune the model
        trainer.fine_tune(epochs=5)  # Adjust the number of epochs as needed
        
        # Save the trained model
        trainer.save_model(model_name=model_name)
        
        # Save the training metrics to a JSON file
        trainer.save_metrics(file_path=f'{model_name}_training_metrics.json')
        
        # Plot the training metrics
        trainer.plot_metrics()
        
    except Exception as e:
        print(f"An error occurred while training model {model_name}: {e}")



Training with model: vit_base_patch16_224
CUDA is available. Using GPU for training.
Learning rate set to: 0.0003


epoch,train_loss,valid_loss,accuracy,time
0,2.539186,1.94039,0.239583,03:53


  x = F.scaled_dot_product_attention(


epoch,train_loss,valid_loss,accuracy,time
0,2.216609,1.711342,0.269531,05:31
1,2.043823,1.628476,0.302734,04:53
2,1.913306,1.516157,0.355469,03:45
3,1.854125,1.474757,0.361979,03:54
4,1.74471,1.463529,0.36849,03:58


Training complete.
Model saved as vit_base_patch16_224
An error occurred while training model vit_base_patch16_224: 'float' object is not iterable

Training with model: beit_base_patch16_224
CUDA is available. Using GPU for training.
Learning rate set to: 0.0003


epoch,train_loss,valid_loss,accuracy,time
0,2.433183,1.788025,0.260417,03:16


epoch,train_loss,valid_loss,accuracy,time
0,2.213568,1.694043,0.295573,05:30
1,1.998419,1.575691,0.327474,04:39
2,1.937005,1.487411,0.342448,06:02
3,1.810161,1.455377,0.353516,05:56
4,1.757939,1.450001,0.354818,05:20


Training complete.
Model saved as beit_base_patch16_224
An error occurred while training model beit_base_patch16_224: 'float' object is not iterable

Training with model: cvt-21-224x224
An error occurred while training model cvt-21-224x224: Unknown model (cvt-21-224x224)

Training with model: deit_base_patch16_224
CUDA is available. Using GPU for training.
Learning rate set to: 0.0003


epoch,train_loss,valid_loss,accuracy,time
0,2.519091,1.854182,0.245443,03:00


epoch,train_loss,valid_loss,accuracy,time
0,2.385565,1.725879,0.27474,17:01
1,2.079112,1.706737,0.279297,13:00
2,1.908475,1.673879,0.289062,05:35
3,1.932471,1.615718,0.28776,05:34
4,1.896061,1.620193,0.289062,05:41


Training complete.
Model saved as deit_base_patch16_224
An error occurred while training model deit_base_patch16_224: 'float' object is not iterable
