In [1]:
!pip install --upgrade pip
!pip install tokenizers
!pip install datasets --upgrade evaluate
!pip install transformers
!pip install numpy torch matplotlib pandas scikit-learn tqdm pillow
!pip install datasets evaluate transformers
!pip install torchvision
!pip install setuptools
!pip install wandb
!pip show wandb
!pip install schedulefree
!pip install nbformat

Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m43.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.0.1
Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting f

In [2]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from evaluate import load
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer,
    get_scheduler,
    AutoImageProcessor
)

from torch.optim import AdamW, SGD
import wandb
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torchvision.transforms as transforms
from PIL import Image
import random
from tqdm.auto import tqdm
from sklearn.metrics import confusion_matrix
from schedulefree import AdamWScheduleFree
from torch.optim.lr_scheduler import CyclicLR, ExponentialLR, ReduceLROnPlateau
from transformers import TrainerCallback, EarlyStoppingCallback
from scipy.ndimage import zoom
import os
import re
from PIL import Image

In [3]:
class Attentionmapcallback(TrainerCallback):
    def __init__(self, feature_extractor, image, output_dir,total_epochs,steps_tune, device = 'cuda'):
        self.feature_extractor = feature_extractor
        self.image = image
        self.output_dir = output_dir
        self.total_epochs = total_epochs
        self.device = device
        self.steps_tune = steps_tune
        os.makedirs(self.output_dir, exist_ok=True)

    def plt_attn(self, attentions, epoch, step , layer = 0, head = 0):

      attn = attentions[layer][0,head]  #.cpu().detach().numpy()
      cls_attn = attn[0,1:]
      nm_patch = cls_attn.shape[0]
      gid_size = int(np.sqrt(nm_patch))
      ##cls_attn = cls_attn.reshape(gid_size,gid_size)

      if gid_size * gid_size != nm_patch:
        print(f"skipping attn epoch {epoch}")
        return

      cls_attn = cls_attn.cpu().detach().numpy().reshape(gid_size,gid_size)

      if isinstance(self.image, np.ndarray):
        img_np = self.image
      else:
        img_np = np.array(self.image)

      attn_resized = zoom(cls_attn, (img_np.shape[0] / gid_size, img_np.shape[1] / gid_size))

      plt.figure(figsize=(8, 8))


      plt.imshow(self.image)

      plt.imshow(attn_resized, cmap = 'jet',  alpha=0.3)

      plt.axis('off')
      plt.title(f"Epoch {epoch}, Layer {layer}, Head {head}")
      plt.savefig(os.path.join(self.output_dir, f"step_{step}_epoch_{epoch}_layer_{layer}_head_{head}.png"))
      plt.close()



    def on_step_end(self, args, state, control, **kwargs):
      epoch = state.epoch
      step = state.global_step
      if step in self.steps_tune:
        model = kwargs['model']
        model.eval()
        inputs = self.feature_extractor(images=self.image, return_tensors="pt").to(self.device)
        with torch.no_grad():
          outputs = model(**inputs, output_attentions=True)
          attentions = outputs.attentions
          #self.plt_attn(attentions, int(epoch), int(step))
          self.plt_attn(attentions, int(epoch), int(step) ,layer = -1, head = -1)



In [4]:
def stp_extrct(fname):
  match = re.search(r'step_(\d+)_epoch_(\d+)', fname)
  return int(match.group(1)) if match else flat('inf')

def make_log_fig(step_func,num_epoch, sch_name, opt_name):
  img_files = []
  # for i in range(num_epoch):
  #   for j in step_func:
  #     img_addr = './attention_maps/' + str(sch_name).split('.')[-1] + "_" + str(opt_name).split('.')[-1] + "/" + "step_" + str(j) + "_epoch_" + str(i) + "_layer_-1_head_-1.png"
  #     img_files.append(img_addr)
  img_files_names = os.listdir('./attention_maps/' + str(sch_name).split('.')[-1] + "_" + str(opt_name).split('.')[-1])
  img_files_names.sort(key=stp_extrct)
  for i in img_files_names:
    img_addr = './attention_maps/' + str(sch_name).split('.')[-1] + "_" + str(opt_name).split('.')[-1] + "/" + i
    img_files.append(img_addr)


  images = [Image.open(file) for file in img_files]
  min_height = min(img.height for img in images)
  images = [img.resize((img.width, min_height)) for img in images]
  total_width = sum(img.width for img in images)
  result = Image.new('RGB', (total_width, min_height))
  x_offset = 0
  for img in images:
      result.paste(img, (x_offset, 0))
      x_offset += img.width

  output_path = './attention_maps/' + str(sch_name).split('.')[-1] + "_" + str(opt_name).split('.')[-1] + "attn_images.png"

  result.save(output_path)
  logimg = "attn_image_" + str(sch_name).split('.')[-1] + "_" + str(opt_name).split('.')[-1]

  wandb.log({logimg: wandb.Image(output_path)})

In [5]:
# Set seed for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed()

# Initialize experiment tracking
def init_wandb(project_name, experiment_name, config):
    return wandb.init(
        # entity="dl_project_sp25",
        project=project_name,
        name=experiment_name,
        config=config,
        reinit=True
    )

# Load and prepare dataset
def prepare_dataset(dataset_name, image_processor, row_indx):
    """
    Load and prepare a dataset from Hugging Face for ViT fine-tuning
    """
    # Load the dataset
    print(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name)

    # Get label information
    if "label" in dataset["train"].features:
        labels = dataset["train"].features["label"].names
    elif "labels" in dataset["train"].features:
        labels = dataset["train"].features["labels"].names
    else:
        # Count unique labels and create labels list
        all_labels = dataset["train"][0]["label"] if "label" in dataset["train"][0] else dataset["train"][0]["labels"]
        num_labels = len(set(all_labels))
        labels = [str(i) for i in range(num_labels)]

    # Create label mappings
    label2id = {label: i for i, label in enumerate(labels)}
    id2label = {i: label for i, label in enumerate(labels)}

    # Set up image transformations based on the model's requirements
    normalize = transforms.Normalize(
        mean=image_processor.image_mean,
        std=image_processor.image_std
    )

    # Get the expected image size
    if "shortest_edge" in image_processor.size:
        size = image_processor.size["shortest_edge"]
    else:
        size = (image_processor.size["height"], image_processor.size["width"])

    # Define transforms for training data
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Define transforms for validation/test data
    val_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize,
    ])

    # Apply transformations to the dataset
    def preprocess_train(examples):
        examples["pixel_values"] = [
            train_transforms(image.convert("RGB"))
            for image in examples["image"]
        ]
        return examples

    def preprocess_val(examples):
        examples["pixel_values"] = [
            val_transforms(image.convert("RGB"))
            for image in examples["image"]
        ]
        return examples

    # taking the image out for attention_mp
    #####Adding the row index for the image you want to get the attention map for
    data_for_img = dataset["train"][row_indx]

    # Apply preprocessing to each split
    train_dataset = dataset["train"].map(
        preprocess_train, batched=True, remove_columns=["image"]
    )

    if "validation" in dataset:
        val_dataset = dataset["validation"].map(
            preprocess_val, batched=True, remove_columns=["image"]
        )

    else:
        # Create a validation split if none exists
        splits = train_dataset.train_test_split(test_size=0.2, seed=42)
        train_dataset = splits["train"]
        val_dataset = splits["test"]

    if "test" in dataset:
        test_dataset = dataset["test"].map(
            preprocess_val, batched=True, remove_columns=["image"]
        )
    else:
        # test_dataset = val_dataset    #split further rather than using validation as test dataset

        # Further split validation dataset to create a test dataset
        test_split = val_dataset.train_test_split(test_size=0.2, seed=42)
        val_dataset = test_split["train"]  # Update validation dataset
        test_dataset = test_split["test"]  # Create test dataset

    print(f"Dataset prepared with {len(train_dataset)} training, {len(val_dataset)} validation, and {len(test_dataset)} test examples")

    return train_dataset, val_dataset, test_dataset, id2label, label2id, data_for_img

# Define compute_metrics function for evaluation
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='weighted')
    acc = accuracy_score(labels, predictions)

    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

# Main experiment pipeline
def run_vit_experiment(config):
    """
    Run a ViT experiment with the specified configuration
    """
    # Initialize wandb for experiment tracking
    # run = init_wandb("ViT-LR-Schedulers", config["experiment_name"], config)

    # Initialize wandb for experiment tracking with config logging
    run = wandb.init(
        project="ViT-LR-Schedulers",
        name=config["experiment_name"],
        group=f"{config['optimizer_name']}_experiments",  # Group by optimizer
        config={
            # Explicitly list all important hyperparameters
            "optimizer": config["optimizer_name"],
            "scheduler": config["scheduler_name"],
            "learning_rate": config["learning_rate"],
            "batch_size": config["batch_size"],
            "num_epochs": config["num_epochs"],
            "weight_decay": config["weight_decay"],
            "warmup_ratio": config.get("warmup_ratio", 0.0),
            "dataset": config["dataset_name"],
            "model": config["model_name"],
            "row_indx" : config["row_indx"],
            "attention_steps": config["attention_steps"]
        },
        tags=[tag for tag in [config["optimizer_name"], config["scheduler_name"]] if tag is not None],
        reinit=True
    )


    # Load the image processor for the model
    image_processor = AutoImageProcessor.from_pretrained(config["model_name"], use_fast=True)

    # Prepare the dataset
    train_dataset, val_dataset, test_dataset, id2label, label2id, data_for_img = prepare_dataset(
        config["dataset_name"], image_processor, config["row_indx"]
    )

    # Load the ViT model
    model = ViTForImageClassification.from_pretrained(
        config["model_name"],
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True
    )

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=f"./results/{config['experiment_name']}",
        per_device_train_batch_size=config["batch_size"],
        per_device_eval_batch_size=config["batch_size"],
        num_train_epochs=config["num_epochs"],
        weight_decay=config["weight_decay"],
        eval_strategy="steps",
        save_strategy="steps",
        logging_strategy="steps",  # Ensure logging is enabled
        logging_steps=10,          # Log every 10 steps (adjust as needed)
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        report_to="wandb",
        remove_unused_columns=False,
        learning_rate=config["learning_rate"],
    )

    # Setup optimizer
    # if config["optimizer_name"] == "AdamW":
    #     optimizer = AdamW(model.parameters(), lr=config["learning_rate"])
    # else:  # SGD
    #     optimizer = SGD(model.parameters(), lr=config["learning_rate"], momentum=0.9)

    # Full run, 10 epoch set up.
    if config["optimizer_name"] == "schedule_free_adamw":
        optimizer = AdamWScheduleFree(model.parameters(), lr=config["learning_rate"])
        scheduler = None  # No external scheduler
    elif config["optimizer_name"] == "AdamW":
        optimizer = AdamW(model.parameters(), lr=config["learning_rate"])
    elif config["optimizer_name"] == "SGD":
        optimizer = SGD(model.parameters(), lr=config["learning_rate"], momentum=0.9)
    elif config["optimizer_name"] == "RMSProp":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=config["learning_rate"])
    elif config["optimizer_name"] == "AdaGrad":
        optimizer = torch.optim.Adagrad(model.parameters(), lr=config["learning_rate"])
    else:
        raise ValueError(f"Optimizer {config['optimizer_name']} not supported")

    # Setup scheduler
    num_training_steps = len(train_dataset) // config["batch_size"] * config["num_epochs"]
    num_warmup_steps = int(num_training_steps * config["warmup_ratio"]) if "warmup_ratio" in config else 0

    scheduler_name = config["scheduler_name"]

    if scheduler_name is None:
        scheduler = None    # Handle the case for schedule_free

    elif scheduler_name == "linear":
        scheduler = get_scheduler(
            "linear",
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )
    elif scheduler_name == "cosine":
        scheduler = get_scheduler(
            "cosine",
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

    elif scheduler_name == "polynomial":
        scheduler = get_scheduler(
            "polynomial",
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            # power=config.get("poly_power", 1.0),
        )
    elif scheduler_name == "cyclic":
        # CyclicLR from torch.optim.lr_scheduler
        # Using step_size_up as 1/3 of training steps and step_size_down as 2/3
        step_size_up = num_training_steps // 3
        scheduler = CyclicLR(
            optimizer,
            base_lr=config["learning_rate"] / 10,  # Lower bound of cycle
            max_lr=config["learning_rate"],       # Upper bound of cycle
            step_size_up=step_size_up,
            step_size_down=step_size_up * 2,
            mode='triangular',                    # Three modes: triangular, triangular2, exp_range
            cycle_momentum=False                  # Don't cycle momentum
        )
    elif scheduler_name == "exponential":
        # ExponentialLR from torch.optim.lr_scheduler
        # gamma < 1.0 for decay, common values: 0.9, 0.95, 0.99
        scheduler = ExponentialLR(optimizer, gamma=0.95)

    # elif scheduler_name == "adaptive":
    #     # ReduceLROnPlateau - reduces LR when metric stops improving
    #     # This requires modification to the training loop to update based on validation performance
    #     scheduler = ReduceLROnPlateau(
    #         optimizer,
    #         mode='max',              # Since we want to maximize accuracy
    #         factor=0.5,              # Multiply LR by this factor when plateauing
    #         patience=2,              # Number of epochs with no improvement after which LR will be reduced
    #         threshold=0.01,          # Threshold for measuring improvement
    #         threshold_mode='rel',    # Interpret threshold as relative change
    #         min_lr=1e-6              # Lower bound on the learning rate
    #     )

    elif scheduler_name == "constant":
        scheduler = get_scheduler(
            "constant",
            optimizer=optimizer,
        )
    elif scheduler_name == "cosine_with_restarts":
        scheduler = get_scheduler(
            "cosine_with_restarts",
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )
    elif scheduler_name == "constant_with_warmup":
        scheduler = get_scheduler(
            "constant_with_warmup",
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
        )
    # add more experiments if required ...

    else:
        raise ValueError(f"Scheduler {scheduler_name} not supported")
    # print(config)
    # print(config.keys())
    # print(config.optimizer_name)
    # Initialize attention callback (self, feature_extractor, image, output_dir,total_epochs,steps_tune, device = 'cuda'):
    # print(config['optimizer_name'])
    # print(scheduler_name)
    # print(config["attention_steps"])
    # plt.imshow(data_for_img['image'])
    attn_map_clbck = Attentionmapcallback(image_processor,
                                  data_for_img['image'],
                                  './attention_maps/' + str(scheduler_name).split('.')[-1] + "_" + str(config['optimizer_name']).split('.')[-1],
                                  training_args.num_train_epochs,
                                  steps_tune = config["attention_steps"])

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        optimizers=(optimizer, scheduler),
        callbacks = [attn_map_clbck]

    )

    # # Loss validation curve in the training loop to log metrics to W&B
    # for epoch in range(config["num_epochs"]):
    #     print(f"Epoch {epoch + 1}/{config['num_epochs']}")

    #     # Train for one epoch
    #     trainer.train()

    #     # Evaluate on validation set
    #     eval_results = trainer.evaluate(val_dataset)

    #     print(trainer.state.log_history)
    #     if trainer.state.log_history and "loss" in trainer.state.log_history[-1]:
    #         train_loss = trainer.state.log_history[-1]["loss"]
    #     else:
    #         train_loss = None

    #     # Log training and validation metrics to W&B
    #     wandb.log({
    #         "epoch": epoch + 1,
    #         "train_loss": trainer.state.log_history[-1].get("loss", None),
    #         "val_loss": eval_results["eval_loss"],
    #         "val_accuracy": eval_results["eval_accuracy"],
    #     })

    # # Loss epoch curve in the training loop to log metrics to W&B
    # for epoch in range(config["num_epochs"]):
    #     print(f"Epoch {epoch + 1}/{config['num_epochs']}")

    #     # Train for one epoch
    #     trainer.train()

    #     # Evaluate on validation set
    #     eval_results = trainer.evaluate(val_dataset)

    #     # Extract training loss from the trainer's state
    #     if trainer.state.log_history and "loss" in trainer.state.log_history[-1]:
    #         train_loss = trainer.state.log_history[-1]["loss"]
    #     else:
    #         train_loss = None  # Handle missing loss gracefully

    #     # Log training and validation metrics to W&B
    #     wandb.log({
    #         "epoch": epoch + 1,
    #         "train_loss": train_loss,                  # Training loss
    #         "val_loss": eval_results["eval_loss"],    # Validation loss
    #         "val_accuracy": eval_results["eval_accuracy"],  # Validation accuracy
    #     })

    # Train the model
    print(f"Starting training for {config['experiment_name']}...")
    trainer.train()

    # Evaluate the model
    print(f"Evaluating {config['experiment_name']}...")
    eval_results = trainer.evaluate(test_dataset)


    # Log final metrics
    wandb.log({
        "final_accuracy": eval_results["eval_accuracy"],
        "final_f1": eval_results["eval_f1"],
        "final_precision": eval_results["eval_precision"],
        "final_recall": eval_results["eval_recall"],
    })

    # Compute confusion matrix for test set
    predictions, labels, _ = trainer.predict(test_dataset)
    predictions = np.argmax(predictions, axis=1)

    # Convert to lists
    labels = labels.tolist()
    predictions = predictions.tolist()

    # Log confusion matrix to W&B
    wandb.log({
        "confusion_matrix_test": wandb.plot.confusion_matrix(
            probs=None,
            y_true=labels,
            preds=predictions,
            class_names=[str(i) for i in range(len(np.unique(labels)))]
        )
    })



    # Save the model
    trainer.save_model(f"./saved_models/{config['experiment_name']}")

    # logging attention maps
    make_log_fig(config['attention_steps'], config['num_epochs'], config['scheduler_name'], config['optimizer_name'])

    # Finish wandb run
    # wandb.finish()

    return eval_results

# Get experiment configurations for challenging datasets
def get_experiment_configs():
    # We'll use a more complex dataset from Hugging Face
    base_config = {
        "model_name": "google/vit-base-patch16-224-in21k",
        "dataset_name": "jbarat/plant_species",  # Any challenging dataset.
        "batch_size": 16,
        "num_epochs": 10, # let's keep smaller number to begin with.
        "weight_decay": 0.01,
        "attention_steps" : [5, 10, 15, 50,100,200,300,400,500],
        "row_indx" : 60
        # "optimizer_name": "AdamW",
    }

    # Optimizers with their corresponding learning rates
    optimizers = {
        # "schedule_free_adamw": 0.0002,    #Lets get this separately.
        "AdamW": 0.0002,
        "RMSProp": 0.0002,
        "AdaGrad": 0.0002,
        "SGD": 0.02,
    }

    # Schedulers to test
    schedulers = [
        "linear",
        "cosine",
        "polynomial",
        "cyclic",
        "exponential",
        # "adaptive",
        "constant",
        "cosine_with_restarts",
        "constant_with_warmup",

    ]



    # Different learning rate scheduler configurations
    configs = []

    # # Constant learning rate (baseline)
    # configs.append({
    #     **base_config,
    #     "experiment_name": "vit_constant_lr",
    #     "learning_rate": 2e-4,
    #     "scheduler_name": "constant",
    # })

    # # Add a schedule-free-only experiment
    # configs.append({
    #     **base_config,
    #     "experiment_name": "schedule_free_adamw_no_scheduler",
    #     "optimizer_name": "schedule_free_adamw",
    #     "learning_rate": 0.0002,
    #     "scheduler_name": None,  # Explicitly set to None
    # })


    for optimizer_name, learning_rate in optimizers.items():
        for scheduler_name in schedulers:
            config = {
                **base_config,
                "experiment_name": f"{optimizer_name}_{scheduler_name}",
                "optimizer_name": optimizer_name,
                "learning_rate": learning_rate,
                "scheduler_name": scheduler_name,
                "warmup_ratio": 0.1,  # Keep warmup ratio consistent
            }
            configs.append(config)


    # # Cosine with restarts
    # configs.append({
    #     **base_config,
    #     "experiment_name": "vit_cosine_restarts",
    #     "learning_rate": 2e-4,
    #     "scheduler_name": "cosine_with_restarts",
    #     "warmup_ratio": 0.1,
    # })

    # # Constant with warmup
    # configs.append({
    #     **base_config,
    #     "experiment_name": "vit_constant_warmup",
    #     "learning_rate": 2e-4,
    #     "scheduler_name": "constant_with_warmup",
    #     "warmup_ratio": 0.1,
    # })

    # # Linear decay
    # configs.append({
    #     **base_config,
    #     "experiment_name": "vit_linear_decay",
    #     "learning_rate": 5e-5,
    #     "scheduler_name": "linear",
    #     "warmup_ratio": 0.1,
    # })

    # # Cosine decay (commonly used with ViT)
    # configs.append({
    #     **base_config,
    #     "experiment_name": "vit_cosine_decay",
    #     "learning_rate": 5e-5,
    #     "scheduler_name": "cosine",
    #     "warmup_ratio": 0.1,
    # })


    # # Polynomial decay
    # configs.append({
    #     **base_config,
    #     "experiment_name": "vit_polynomial",
    #     "learning_rate": 5e-5,
    #     "scheduler_name": "polynomial",
    #     "warmup_ratio": 0.1,
    #     "poly_power": 2.0,
    # })


    # # Different learning rate experiments
    # for lr in [1e-5, 3e-5, 1e-4]:
    #     configs.append({
    #         **base_config,
    #         "experiment_name": f"vit_cosine_lr_{lr}",
    #         "learning_rate": lr,
    #         "scheduler_name": "cosine",
    #         "warmup_ratio": 0.1,
    #     })

    # # Different optimizer experiments
    # configs.append({
    #     **base_config,
    #     "experiment_name": "vit_sgd_cosine",
    #     "learning_rate": 0.01,  # Higher LR for SGD
    #     "scheduler_name": "cosine",
    #     "warmup_ratio": 0.1,
    #     "optimizer_name": "SGD",
    # })

    # here we can make changes to add new datasets to experiment.
    # or change batch_size to see the impact.
    # Other datasets to try (uncomment to use)
    #   Erik: We can use a data set as a strech. Maybe something less similar than plants for better contrasting comparison?
    # flowers dataset
    # configs.append({
    #     **base_config,
    #     "dataset_name": "huggan/flowers",
    #     "experiment_name": "vit_flowers_cosine",
    #     "learning_rate": 5e-5,
    #     "scheduler_name": "cosine",
    #     "warmup_ratio": 0.1,
    # })

    return configs

# Run experiments and visualize results
def run_all_experiments():
    configs = get_experiment_configs()
    results = []

    for config in configs:
        print(f"\n{'='*50}")
        print(f"Running experiment: {config['experiment_name']}")
        print(f"{'='*50}\n")

        eval_results = run_vit_experiment(config)
        results.append({
            "experiment": config['experiment_name'],
            "accuracy": eval_results["eval_accuracy"],
            "f1": eval_results["eval_f1"],
            "precision": eval_results["eval_precision"],
            "recall": eval_results["eval_recall"],
            "config": config
        })

    # Make sure to close the final run
    if wandb.run is not None:
        wandb.finish()

    return results

# Visualize and compare results
def visualize_results(results):
    # Create DataFrame for easier plotting
    df = pd.DataFrame([
        {
            "Experiment": result["experiment"],
            "Accuracy": result["accuracy"],
            "F1 Score": result["f1"],
            "Precision": result["precision"],
            "Recall": result["recall"],
            "Learning Rate": result["config"]["learning_rate"],
            "Scheduler": result["config"]["scheduler_name"],
            "Optimizer": result["config"]["optimizer_name"],
            "Dataset": result["config"]["dataset_name"]
        }
        for result in results
    ])

    # Plot accuracy comparison
    plt.figure(figsize=(14, 8))
    ax = plt.bar(df["Experiment"], df["Accuracy"], color='skyblue')
    plt.xlabel('Experiment')
    plt.ylabel('Accuracy')
    plt.title('Comparison of Model Accuracy Across Experiments')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig("accuracy_comparison.png")
    plt.close()

    # Plot all metrics for a more comprehensive comparison
    plt.figure(figsize=(16, 10))
    metrics = ["Accuracy", "F1 Score", "Precision", "Recall"]
    x = np.arange(len(df["Experiment"]))
    width = 0.2

    for i, metric in enumerate(metrics):
        plt.bar(x + i*width, df[metric], width=width, label=metric)

    plt.xlabel('Experiment')
    plt.ylabel('Score')
    plt.title('Comparison of Metrics Across Experiments')
    plt.xticks(x + width*1.5, df["Experiment"], rotation=45, ha='right')
    plt.legend()
    plt.tight_layout()
    plt.savefig("metrics_comparison.png")
    plt.close()

    # Plot results by scheduler type
    plt.figure(figsize=(14, 8))
    schedulers = df["Scheduler"].unique()
    for scheduler in schedulers:
        scheduler_data = df[df["Scheduler"] == scheduler]
        plt.plot(scheduler_data["Learning Rate"], scheduler_data["Accuracy"], 'o-', label=scheduler)

    plt.xlabel('Learning Rate')
    plt.ylabel('Accuracy')
    plt.title('Accuracy vs. Learning Rate by Scheduler Type')
    plt.xscale('log')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("scheduler_comparison.png")
    plt.close()

    # Create a table with results
    print("Results Summary:")
    print(df[["Experiment", "Accuracy", "F1 Score", "Precision", "Recall", "Scheduler", "Learning Rate", "Optimizer", "Dataset"]])

    # Save results to CSV
    df.to_csv("experiment_results.csv", index=False)

    return df

# Function to run a single experiment (useful for testing)
def run_single_experiment(experiment_index=0):
    configs = get_experiment_configs()
    if experiment_index >= len(configs):
        print(f"Invalid experiment index. Choose between 0 and {len(configs)-1}")
        return

    config = configs[experiment_index]
    print(f"Running single experiment: {config['experiment_name']}")
    eval_results = run_vit_experiment(config)

    print(f"\nResults for {config['experiment_name']}:")
    print(f"Accuracy: {eval_results['eval_accuracy']:.4f}")
    print(f"F1 Score: {eval_results['eval_f1']:.4f}")
    print(f"Precision: {eval_results['eval_precision']:.4f}")
    print(f"Recall: {eval_results['eval_recall']:.4f}")

    return eval_results


In [6]:
# # Uncomment to use: For 3 epoch sweep with only cosine scheduler to get best LR:
# def run_optimizer_sweep():
#     # Initialize W&B run first, then access config
#     with wandb.init() as run:
#         print(f"W&B initialized: {run.name}")

#         # Get config from sweep
#         config = wandb.config

#         # Set experiment name based on sweep parameters
#         custom_name = f"vit_{config.optimizer_name}_{config.learning_rate}"
#         # Update the run name after initialization
#         wandb.run.name = custom_name
#         wandb.run.save()

#         print(f"Running experiment: {custom_name}")

#         # Load model and processor
#         model_name = "google/vit-base-patch16-224-in21k"
#         dataset_name = "jbarat/plant_species"

#         # Load the image processor
#         image_processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

#         # Prepare dataset
#         train_dataset, val_dataset, test_dataset, id2label, label2id, data_for_img = prepare_dataset(
#             dataset_name, image_processor, config.row_indx
#         )

#         # Load the ViT model
#         model = ViTForImageClassification.from_pretrained(
#             model_name,
#             num_labels=len(id2label),
#             id2label=id2label,
#             label2id=label2id,
#             ignore_mismatched_sizes=True
#         )

#         # Define training arguments
#         training_args = TrainingArguments(
#             output_dir=f"./results/{custom_name}",
#             per_device_train_batch_size=config.batch_size,
#             per_device_eval_batch_size=config.batch_size,
#             num_train_epochs=config.num_epochs,
#             weight_decay=0.01,
#             eval_strategy="steps",
#             save_strategy="steps",
#             logging_strategy="steps",
#             logging_steps=10,
#             load_best_model_at_end=True,
#             metric_for_best_model="accuracy",
#             push_to_hub=False,
#             report_to="wandb",
#             remove_unused_columns=False,
#             learning_rate=config.learning_rate,
#         )
#         # Set up optimizer based on config
#         if config.optimizer_name == "schedule_free_adamw":
#             optimizer = AdamWScheduleFree(
#                 model.parameters(),
#                 lr=config.learning_rate,  # Learning rate
#                 # warmup_steps=500  # Optional: Adjust based on your dataset
#             )
#         elif config.optimizer_name == "AdamW":
#             optimizer = AdamW(model.parameters(), lr=config.learning_rate)
#         elif config.optimizer_name == "SGD":
#             optimizer = SGD(model.parameters(), lr=config.learning_rate, momentum=0.9)
#         elif config.optimizer_name == "RMSProp":
#             optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate)
#         elif config.optimizer_name == "AdaGrad":
#             optimizer = torch.optim.Adagrad(model.parameters(), lr=config.learning_rate)
#         else:
#             optimizer = AdamW(model.parameters(), lr=config.learning_rate)

#         # Setup scheduler
#         num_training_steps = len(train_dataset) // config.batch_size * config.num_epochs
#         num_warmup_steps = int(num_training_steps * 0.1)  # 10% warmup

#         scheduler = get_scheduler(
#             config.scheduler_name,
#             optimizer=optimizer,
#             num_warmup_steps=num_warmup_steps,
#             num_training_steps=num_training_steps
#         )


#         # Initialize attention callback (self, feature_extractor, image, output_dir,total_epochs,steps_tune, device = 'cuda'):
#         attn_map_clbck = Attentionmapcallback(image_processor,
#                                       data_for_img['image'],
#                                       './attention_maps/' + str(config.scheduler_name).split('.')[-1] + "_" + str(config.optimizer_name).split('.')[-1],
#                                       training_args.num_train_epochs,
#                                       steps_tune = config.attention_steps)


#         # Initialize Trainer
#         trainer = Trainer(
#             model=model,
#             args=training_args,
#             train_dataset=train_dataset,
#             eval_dataset=val_dataset,
#             compute_metrics=compute_metrics,
#             optimizers=(optimizer, scheduler),
#             callbacks = [attn_map_clbck]
#         )

#         # Train the model
#         print(f"Starting training...")
#         # optimizer.train()  # Switch optimizer to training mode only for schedule_free
#         trainer.train()

#         # Evaluate on validation dataset
#         print(f"Evaluating on validation set...")
#         # optimizer.eval()  # Switch optimizer to evaluation mode only for schedule_free
#         eval_results = trainer.evaluate(val_dataset)

#         # Log validation metrics
#         run.log({
#             "val_accuracy": eval_results["eval_accuracy"],
#             "val_f1": eval_results["eval_f1"],
#             "val_precision": eval_results["eval_precision"],
#             "val_recall": eval_results["eval_recall"],
#             "val_loss": eval_results["eval_loss"]
#         })

#         # Evaluate on test dataset
#         print(f"Evaluating on test set...")
#         test_results = trainer.evaluate(test_dataset)

#         # Log test metrics
#         run.log({
#             "test_accuracy": test_results["eval_accuracy"],
#             "test_f1": test_results["eval_f1"],
#             "test_precision": test_results["eval_precision"],
#             "test_recall": test_results["eval_recall"],
#             "test_loss": test_results["eval_loss"]
#         })

#         # Compute confusion matrix for test set
#         predictions, labels, _ = trainer.predict(test_dataset)
#         predictions = np.argmax(predictions, axis=1)

#         # Log confusion matrix
#         run.log({
#             "confusion_matrix": wandb.plot.confusion_matrix(
#                 probs=None,
#                 y_true=labels.tolist(),
#                  preds=predictions.tolist(),
#                 class_names=[id2label[i] for i in range(len(id2label))]
#             )
#         })

#         # Save the model
#         model_path = f"./saved_models/{custom_name}"
#         trainer.save_model(model_path)
#         print(f"Model saved to {model_path}")

#         # Save the attn_fig
#         make_log_fig(config.attention_steps, config.num_epochs, config.scheduler_name, config.optimizer_name)

In [None]:

# Main execution
if __name__ == "__main__":
    print("Starting ViT experiments with different learning rate schedulers...")
    os.environ["WANDB_PROJECT"] = "ViT-LR-Schedulers"

    # Option 1: Run all experiments (time-consuming)
    results = run_all_experiments()
    results_df = visualize_results(results)

    # Option 2: Run a single experiment for testing
    # run_single_experiment(0)  # Try the baseline experiment first

    #option 3: Optimizer sweep:
    # Define sweep configuration
    # sweep_config = {
    #     "method": "grid",  # we can use "grid", "random", or "bayes"
    #     "metric": {
    #         "name": "val_accuracy",  # Metric to optimize
    #         "goal": "maximize"       # Goal: maximize or minimize
    #     },
    #     "parameters": {
    #         "optimizer_name": {
    #             "values": ["schedule_free_adamw","AdamW", "SGD", "RMSProp", "AdaGrad"]  # Optimizers to test
    #         },
    #         "learning_rate": {
    #             "values": [2e-5, 2e-4, 2e-3, 2e-2, 2e-1]  # Learning rates to test
    #         },
    #         "batch_size": {
    #             "values": [16]
    #         },
    #         "num_epochs": {
    #             "values": [3]
    #         },
    #         "scheduler_name": {
    #             "values": ["cosine"]
    #         }
    #     }
    # }

    # # Initialize the sweep
    # sweep_id = wandb.sweep(sweep_config, project="ViT-Optimizer-Sweep")
    # wandb.agent(sweep_id, function=run_optimizer_sweep)


    print("Experiments completed!")

Starting ViT experiments with different learning rate schedulers...

Running experiment: AdamW_linear



[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjraisinghani3[0m ([33mdl_project_sp25[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Loading dataset: jbarat/plant_species


README.md:   0%|          | 0.00/800 [00:00<?, ?B/s]

(…)-00000-of-00001-15efca0bf2e6a460.parquet:   0%|          | 0.00/82.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Dataset prepared with 640 training, 128 validation, and 32 test examples


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting training for AdamW_linear...




Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
10,2.0191,1.858745,0.546875,0.51483,0.527753,0.546875
20,1.6946,1.44201,0.6875,0.651616,0.747841,0.6875
30,1.356,1.130157,0.773438,0.77151,0.784385,0.773438
40,1.044,0.990746,0.734375,0.73433,0.765696,0.734375
50,0.6824,0.890501,0.757812,0.758126,0.785176,0.757812
60,0.6275,0.780365,0.789062,0.788324,0.8184,0.789062
70,0.529,0.732681,0.773438,0.773331,0.801166,0.773438
80,0.5094,0.660391,0.796875,0.796551,0.810506,0.796875
90,0.336,0.640551,0.8125,0.815499,0.853582,0.8125
100,0.2943,0.643465,0.804688,0.801936,0.821672,0.804688


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Evaluating AdamW_linear...



Running experiment: AdamW_cosine



0,1
eval/accuracy,▁▅▇▆▆▇▇▇████▆
eval/f1,▁▄▇▆▇▇▇█████▆
eval/loss,█▆▄▃▃▂▂▁▁▁▁▁▂
eval/precision,▁▆▆▆▇▇▇▇█▇██▇
eval/recall,▁▅▇▆▆▇▇▇████▆
eval/runtime,▅▅▅▆▅▅▅▆▆▆▇█▁
eval/samples_per_second,███▇███▇▇▇▃▁▃
eval/steps_per_second,███▇███▇▇▇▃▁▃
final_accuracy,▁
final_f1,▁

0,1
eval/accuracy,0.75
eval/f1,0.7224
eval/loss,0.74217
eval/precision,0.82465
eval/recall,0.75
eval/runtime,4.2365
eval/samples_per_second,7.553
eval/steps_per_second,0.472
final_accuracy,0.75
final_f1,0.7224


Loading dataset: jbarat/plant_species
Dataset prepared with 640 training, 128 validation, and 32 test examples


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting training for AdamW_cosine...


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
10,2.0238,1.832424,0.671875,0.661673,0.663536,0.671875
20,1.6395,1.36641,0.742188,0.714578,0.796885,0.742188
30,1.2955,1.087709,0.789062,0.790965,0.801439,0.789062
40,0.9854,0.941136,0.78125,0.781186,0.820797,0.78125


In [None]:
# Identify best of learning rate sweep:
# import pandas as pd

# # Load the exported CSV file
# df = pd.read_csv("wandb_export.csv")

# # Group by optimizer and find the best learning rate for each
# best_lr_per_optimizer = (
#     df.groupby("optimizer_name")
#     .apply(lambda group: group.loc[group["val_accuracy"].idxmax()])
#     [["optimizer_name", "learning_rate", "val_accuracy"]]
# )

# print(best_lr_per_optimizer)