<a href="https://colab.research.google.com/github/jsl5710/greenland/blob/main/GREENLAND_Fine_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Step 1: Setup & Installation

In [1]:
# Install and upgrade necessary libraries
!pip install --quiet --upgrade pip
!pip install --quiet --upgrade transformers
!pip install --quiet --upgrade datasets
!pip install --quiet --upgrade wandb
!pip install --quiet git+https://github.com/huggingface/peft.git peft

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m92.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m93.7 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.[0m[31m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.1/16.1 MB[0m [31m114.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for peft (pyproject.toml) ... [?25l[?25h

# Step 2: Import Libraries

In [2]:
import os
import torch
import wandb
import pandas as pd
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)
from datasets import load_dataset, Dataset
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
from peft import get_peft_model, LoraConfig, TaskType, AutoPeftModelForSequenceClassification
from google.colab import drive
from requests.exceptions import HTTPError

# Step 3: Define Model Checkpoints

In [3]:
model_checkpoints = {
    "MBERT_uncased": {
        "path": "google-bert/bert-base-multilingual-uncased",
        "max_length": 512
    },
    # "XLM_100": {
    #     "path": "FacebookAI/xlm-mlm-100-1280",
    #     "max_length": 512
    # },
    # "XLM_17": {
    #     "path": "FacebookAI/xlm-mlm-17-1280",
    #     "max_length": 512
    # },
    # "XLM-RoBERTa_xxl": {
    #     "path": "facebook/xlm-roberta-xxl",
    #     "max_length": 512
    # },
    # "mDeBERTa_v3_base": {
    #     "path": "microsoft/mdeberta-v3-base",
    #     "max_length": 512
    # },
    # "S-BERT_LaBSE": {
    #     "path": "sentence-transformers/LaBSE",
    #     "max_length": 512
    # },
    # "S-BERT_distiluse": {
    #     "path": "sentence-transformers/distiluse-base-multilingual-cased",
    #     "max_length": 512
    # },
    # "XLM-R_bernice": {
    #     "path": "jhu-clsp/bernice",
    #     "max_length": 512
    # },
    # "XLM-T_twitter": {
    #     "path": "cardiffnlp/twitter-xlm-roberta-base",
    #     "max_length": 512
    # },
    # "XLM-E_align": {
    #     "path": "microsoft/xlm-align-base",
    #     "max_length": 512
    # },
    # "XLM-E_infoxlm_large": {
    #     "path": "microsoft/infoxlm-large",
    #     "max_length": 512
    # },
    # "XLM-V_base": {
    #     "path": "facebook/xlm-v-base",
    #     "max_length": 512
    # }
}


# model_checkpoints = {
#     "MBERT_uncased": "google-bert/bert-base-multilingual-uncased",
#     # "MBERT_cased": "google-bert/bert-base-multilingual-cased",
#     "XLM_100": "FacebookAI/xlm-mlm-100-1280",
#     "XLM_17": "FacebookAI/xlm-mlm-17-1280",
#     # "XLM-RoBERTa_large": "FacebookAI/xlm-roberta-large",
#     # "XLM-RoBERTa_base": "FacebookAI/xlm-roberta-base",
#     # "XLM-RoBERTa_xl": "facebook/xlm-roberta-xl",
#     "XLM-RoBERTa_xxl": "facebook/xlm-roberta-xxl",
#     "mDeBERTa_v3_base": "microsoft/mdeberta-v3-base",
#     # "M-distilBERT": "distilbert/distilbert-base-multilingual-cased",
#     "S-BERT_LaBSE": "sentence-transformers/LaBSE",
#     "S-BERT_distiluse": "sentence-transformers/distiluse-base-multilingual-cased",
#     "XLM-R_bernice": "jhu-clsp/bernice",
#     "XLM-T_twitter": "cardiffnlp/twitter-xlm-roberta-base",
#     "XLM-E_align": "microsoft/xlm-align-base",
#     # "XLM-E_infoxlm_base": "microsoft/infoxlm-base",
#     "XLM-E_infoxlm_large": "microsoft/infoxlm-large",
#     "XLM-V_base": "facebook/xlm-v-base"
# }

# model_checkpoints = {
    # "MBERT_uncased": "google-bert/bert-base-multilingual-uncased",
    # "XLM_100": "FacebookAI/xlm-mlm-100-1280",
    # "XLM_17": "FacebookAI/xlm-mlm-17-1280",
    # "XLM-RoBERTa_xxl": "facebook/xlm-roberta-xxl",
    # "mDeBERTa_v3_base": "microsoft/mdeberta-v3-base",
    # "S-BERT_LaBSE": "sentence-transformers/LaBSE",
    # "S-BERT_distiluse": "sentence-transformers/distiluse-base-multilingual-cased",
    # "XLM-R_bernice": "jhu-clsp/bernice",
    # "XLM-T_twitter": "cardiffnlp/twitter-xlm-roberta-base",
    # "XLM-E_align": "microsoft/xlm-align-base",
    #     "XLM-E_infoxlm_large": "microsoft/infoxlm-large",
    # "XLM-V_base": "facebook/xlm-v-base"
# }



# Step 4: Authenticate and Initialize

In [4]:
# Mount Google Drive
drive.mount('/content/drive', force_remount=True)

# Authenticate with Hugging Face
!huggingface-cli login --token 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

# Authenticate with W&B
wandb.login(key="1b5caf38a8b6ada0e6918798e9379b2ea764062d")
wandb.init(project="greenland")


Mounted at /content/drive
The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
The token `greenland` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `greenland`


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjasonsamlucas[0m ([33mpike[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Step 5: Define Save Paths and Ensure Directories Exist

In [5]:
# Define save locations
local_save_path = "/content/sample_data/best_models/"
drive_save_path = "/content/drive/MyDrive/GREENLAND/Modeling/Best_models/"
results_dir = "/content/drive/MyDrive/GREENLAND/Results/"

# Ensure save directories exist
os.makedirs(local_save_path, exist_ok=True)
os.makedirs(drive_save_path, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)


# Step 6: Load and Process the Dataset

In [6]:
# Load datasets from CSV files in Google Drive
train_df = pd.read_csv('/content/drive/MyDrive/GREENLAND/Datasets/Consolidated_Data/Experiment_Training_Splits/train_data.csv')
val_df = pd.read_csv('/content/drive/MyDrive/GREENLAND/Datasets/Consolidated_Data/Experiment_Training_Splits/val_data.csv')
test_df = pd.read_csv('/content/drive/MyDrive/GREENLAND/Datasets/Consolidated_Data/Experiment_Training_Splits/test_data.csv')

# Convert to Hugging Face Dataset format
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

# Combine datasets into a dictionary for easy access
dataset = {
    "train": train_dataset,
    "validation": val_dataset,
    "test": test_dataset
}

# Step 7: Define Dataset Processing Functions

In [7]:
def tokenize_datasets(model_name, dataset):
    model_info = model_checkpoints[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_info["path"])
    max_length = model_info["max_length"]

    print(f"Using max_length={max_length} for model {model_name}")

    def preprocess_function(examples):
        # Dynamically pad to save memory, but cap at max_length
        return tokenizer(
            examples["text"],
            truncation=True,
            padding=True,  # Changed from "max_length" to True for dynamic padding
            max_length=max_length,
            return_tensors=None  # Don't convert to tensors yet
        )

    # Apply tokenization with batching for speed
    tokenized_data = {
        split: data.map(
            preprocess_function,
            batched=True,
            batch_size=1000,  # Increased batch size
            num_proc=4,  # Use multiple processes
            remove_columns=data.column_names,
            desc=f"Tokenizing {split} set"
        )
        for split, data in dataset.items()
    }

    return tokenized_data

def analyze_text_lengths(dataset):
    """
    Analyze text lengths in the dataset without tokenization first
    """
    # Get raw text lengths
    lengths = [len(text.split()) for text in dataset["train"]["text"]]

    stats = {
        "average_length": sum(lengths)/len(lengths),
        "max_length": max(lengths),
        "median_length": sorted(lengths)[len(lengths)//2],
        "95th_percentile": sorted(lengths)[int(len(lengths)*0.95)],
        "length_distribution": {
            "< 128 words": sum(1 for l in lengths if l < 128),
            "128-256 words": sum(1 for l in lengths if 128 <= l < 256),
            "256-512 words": sum(1 for l in lengths if 256 <= l < 512),
            "> 512 words": sum(1 for l in lengths if l >= 512)
        }
    }

    # Calculate percentages for distribution
    total_samples = len(lengths)
    stats["length_distribution_percent"] = {
        k: (v/total_samples * 100) for k, v in stats["length_distribution"].items()
    }

    print("\nText Length Analysis (word-based):")
    print(f"Average length: {stats['average_length']:.1f} words")
    print(f"Median length: {stats['median_length']} words")
    print(f"Max length: {stats['max_length']} words")
    print(f"95th percentile: {stats['95th_percentile']} words")
    print("\nLength Distribution:")
    for category, count in stats["length_distribution"].items():
        percentage = stats["length_distribution_percent"][category]
        print(f"{category}: {count} texts ({percentage:.1f}%)")

    # Character-based analysis
    char_lengths = [len(text) for text in dataset["train"]["text"]]
    stats["char_stats"] = {
        "average_length": sum(char_lengths)/len(char_lengths),
        "max_length": max(char_lengths),
        "median_length": sorted(char_lengths)[len(char_lengths)//2],
        "95th_percentile": sorted(char_lengths)[int(len(char_lengths)*0.95)]
    }

    print("\nCharacter-based Analysis:")
    print(f"Average length: {stats['char_stats']['average_length']:.1f} characters")
    print(f"Median length: {stats['char_stats']['median_length']} characters")
    print(f"Max length: {stats['char_stats']['max_length']} characters")
    print(f"95th percentile: {stats['char_stats']['95th_percentile']} characters")

    return stats

# Step 8: Define Loss Functions

In [8]:
class WeightedBinaryCrossEntropyLoss(torch.nn.Module):
    def __init__(self, pos_weight):
        super(WeightedBinaryCrossEntropyLoss, self).__init__()
        self.register_buffer('pos_weight', pos_weight)  # Register as buffer for proper device management

    def forward(self, logits, labels):
        labels = labels.float()
        return torch.nn.functional.binary_cross_entropy_with_logits(
            logits[:, 1],
            labels,
            pos_weight=self.pos_weight
        )

class WeightedFocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, labels):
        # Apply softmax for multi-class probabilities
        probs = torch.softmax(logits, dim=1)[:, 1]  # Probability for positive class
        labels = labels.float()
        BCE_loss = torch.nn.functional.binary_cross_entropy(probs, labels, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

class SymmetricCrossEntropyLoss(torch.nn.Module):
    def __init__(self, alpha=0.1, beta=1.0):
        super(SymmetricCrossEntropyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, logits, labels):
        ce_loss = torch.nn.functional.cross_entropy(logits, labels)
        labels_one_hot = torch.nn.functional.one_hot(labels, num_classes=logits.size(-1))
        rce_loss = -((torch.softmax(logits, dim=1) * labels_one_hot).sum(dim=-1).log().mean())
        return self.alpha * ce_loss + self.beta * rce_loss

class SquaredBCEWithLogitsLoss(torch.nn.Module):
    def forward(self, logits, labels):
        labels = labels.float()
        probs = torch.sigmoid(logits[:, 1])
        return torch.mean((probs - labels) ** 2)

class SupervisedContrastiveCrossEntropyLoss(torch.nn.Module):
    def __init__(self, temperature=0.07, lam=0.5):
        super(SupervisedContrastiveCrossEntropyLoss, self).__init__()
        self.temperature = temperature
        self.lam = lam

    def forward(self, logits, labels):
        batch_size = logits.size(0)
        labels = labels.view(-1, 1)
        mask = (labels == labels.T).float()

        # Temperature scaling
        logits_scaled = logits / self.temperature

        # Calculate exp(logits)
        exp_logits = torch.exp(logits_scaled)

        # Calculate positive and negative terms
        pos_mask = mask
        neg_mask = 1 - mask

        pos = torch.exp(logits_scaled) * pos_mask
        neg = torch.exp(logits_scaled) * neg_mask

        # Calculate log sum, adding epsilon for numerical stability
        pos_sum = torch.log(pos.sum(1) + 1e-9)
        neg_sum = torch.log(neg.sum(1) + 1e-9)

        # Calculate final loss
        loss = -self.lam * pos_sum.mean() - (1 - self.lam) * neg_sum.mean()

        return loss

# Step 9: Loss Functions Factory

In [9]:
# Loss function dictionary
def get_loss_functions(device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    return {
        # "CrossEntropyLoss": torch.nn.CrossEntropyLoss().to(device),
        "BCEWithLogitsLoss": torch.nn.BCEWithLogitsLoss(reduction='mean').to(device),  # Added reduction='mean'
        "SquaredBCEWithLogitsLoss": SquaredBCEWithLogitsLoss().to(device),
        "WeightedBinaryCrossEntropy": WeightedBinaryCrossEntropyLoss(
            pos_weight=torch.tensor([3.0]).to(device)
        ),
        "WeightedFocalLoss": WeightedFocalLoss(
            alpha=0.25,
            gamma=2
        ).to(device),
        "SymmetricCrossEntropy": SymmetricCrossEntropyLoss(
            alpha=0.1,
            beta=1.0
        ).to(device),
        "SupervisedContrastiveCrossEntropyLoss": SupervisedContrastiveCrossEntropyLoss(
            temperature=0.07,
            lam=0.5
        ).to(device),
    }

# Step 10: Evaluation Metrics

In [10]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    try:
        roc_auc = roc_auc_score(labels, preds)
    except ValueError:
        roc_auc = 0  # Handle cases where there might be only one class

    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1_score(labels, preds, average='binary'),
        'precision': precision_score(labels, preds, average='binary'),
        'recall': recall_score(labels, preds, average='binary'),
        'roc_auc': roc_auc
    }


# Step 11: Custom Trainer

In [11]:
import logging
logging.basicConfig(level=logging.INFO)

class CustomTrainer(Trainer):
    def __init__(self, *args, loss_func=None, processing_class=None, **kwargs):
        super().__init__(*args, processing_class=processing_class, **kwargs)
        self.loss_func = loss_func

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        labels = labels.long()

        outputs = model(**inputs)
        logits = outputs.get("logits")

        if self.loss_func is not None:
            if isinstance(self.loss_func, torch.nn.BCEWithLogitsLoss):
                # Convert labels to one-hot encoded format
                batch_size = labels.size(0)
                labels_one_hot = torch.zeros(batch_size, 2, device=labels.device)
                labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
                loss = self.loss_func(logits, labels_one_hot)

            elif isinstance(self.loss_func, SquaredBCEWithLogitsLoss):
                # Convert labels for SquaredBCEWithLogitsLoss
                batch_size = labels.size(0)
                labels_one_hot = torch.zeros(batch_size, 2, device=labels.device)
                labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
                loss = self.loss_func(logits, labels_one_hot)

            elif isinstance(self.loss_func, WeightedBinaryCrossEntropyLoss):
                # Convert labels for WeightedBinaryCrossEntropyLoss
                batch_size = labels.size(0)
                labels_one_hot = torch.zeros(batch_size, 2, device=labels.device)
                labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
                loss = self.loss_func(logits, labels_one_hot)

            else:
                # For CrossEntropyLoss and other losses that expect class indices
                loss = self.loss_func(logits, labels)
        else:
            # Use default loss computation
            loss = outputs.get("loss")

        return (loss, outputs) if return_outputs else loss


# Step 12: Model Save/Load Functions

In [12]:
def save_model_with_fallback(trainer, model_name):
    try:
        trainer.push_to_hub(f"jslai/{model_name}")
        print(f"Model saved to Hugging Face Hub as jslai/{model_name}")
    except Exception as e:
        print(f"Failed to save to Hugging Face Hub: {e}")
        try:
            trainer.save_model(os.path.join(drive_save_path, model_name))
            print(f"Model saved to Google Drive at {drive_save_path}/{model_name}")
        except Exception as e:
            print(f"Failed to save to Google Drive: {e}")
            trainer.save_model(os.path.join(local_save_path, model_name))
            print(f"Model saved locally at {local_save_path}/{model_name}")

def load_best_model(model_name):
    try:
        print(f"Attempting to load {model_name} from Hugging Face Hub.")
        model = AutoModelForSequenceClassification.from_pretrained(f"jslai/{model_name}")
    except (OSError, HTTPError) as e:
        print(f"Failed to load {model_name} from Hugging Face Hub: {e}")
        try:
            google_drive_path = os.path.join(drive_save_path, model_name)
            if os.path.isdir(google_drive_path):
                print(f"Attempting to load {model_name} from Google Drive.")
                model = AutoModelForSequenceClassification.from_pretrained(google_drive_path)
            else:
                raise OSError(f"Directory {google_drive_path} does not exist on Google Drive.")
        except (OSError, HTTPError) as e:
            print(f"Failed to load {model_name} from Google Drive: {e}")
            try:
                local_path = os.path.join(local_save_path, model_name)
                if os.path.isdir(local_path):
                    print(f"Attempting to load {model_name} from local storage.")
                    model = AutoModelForSequenceClassification.from_pretrained(local_path)
                else:
                    raise OSError(f"Directory {local_path} does not exist in local storage.")
            except (OSError, HTTPError) as e:
                print(f"Failed to load {model_name} from local storage: {e}")
                raise FileNotFoundError(f"Model {model_name} could not be found in any location.")
    return model

# Step 13: Training Functions

In [14]:

def full_fine_tune_all_models(model_checkpoints, dataset, loss_functions=None):
    # Initialize device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Get loss functions if not provided
    if loss_functions is None:
        loss_functions = get_loss_functions(device)

    for model_name, model_info in model_checkpoints.items():
        # Analyze text lengths for this model
        # stats = analyze_text_lengths(model_name, dataset)

        # Tokenize dataset specific to model
        tokenized_data = tokenize_datasets(model_name, dataset)

        for loss_fn_name, loss_fn in loss_functions.items():
            try:
                print(f"\nTraining {model_name} with {loss_fn_name}")
                print(f"Using device: {device}")

                # Clear CUDA cache before loading new model
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Initialize model and tokenizer
                tokenizer = AutoTokenizer.from_pretrained(model_info["path"])
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_info["path"],
                    num_labels=2,
                    problem_type="single_label_classification"
                ).to(device)

                # Ensure loss function is on correct device
                loss_fn = loss_fn.to(device)

                # Define training arguments
                training_args = TrainingArguments(
                    output_dir=f"{local_save_path}/{model_name}_{loss_fn_name}_full_ft",
                    eval_strategy="epoch",
                    save_strategy="epoch",
                    learning_rate=2e-5,
                    per_device_train_batch_size=8, # Increased from 8
                    per_device_eval_batch_size=8,  # Increased from 8
                    num_train_epochs=3,
                    weight_decay=0.01,
                    load_best_model_at_end=True,
                    metric_for_best_model="f1",
                    logging_dir="./logs",
                    report_to="wandb",
                    logging_steps=100,
                    fp16=True,  # Enable mixed precision training
                    half_precision_backend="cuda_amp",  # Specify the backend
                    gradient_checkpointing=True,  # Enable gradient checkpointing
                    gradient_accumulation_steps=2,  # Added gradient accumulation
                    warmup_ratio=0.1,              # Added warmup
                    dataloader_num_workers=4,      # Added multiple workers
                    dataloader_pin_memory=True,    # Added pin memory
                    seed=42                        # Added for reproducibility
                )

                # Initialize trainer
                trainer = CustomTrainer(
                    model=model,
                    args=training_args,
                    train_dataset=tokenized_data["train"],
                    eval_dataset=tokenized_data["validation"],
                    processing_class=tokenizer,
                    data_collator=DataCollatorWithPadding(tokenizer,
                                                          padding=True,
                                                          pad_to_multiple_of=8),  # For better GPU utilization
                    compute_metrics=compute_metrics,
                    loss_func=loss_fn
                )

                # Train and save model
                trainer.train()
                save_model_with_fallback(trainer, f"{model_name}_{loss_fn_name}_full_ft")

                # Clean up to free memory
                del model, trainer, tokenizer
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error training {model_name} with {loss_fn_name}: {e}")
                continue

def peft_fine_tune_all_models(model_checkpoints, dataset, loss_functions=None):
    # Initialize device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Get loss functions if not provided
    if loss_functions is None:
        loss_functions = get_loss_functions(device)

    # Define LoRA configuration
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        r=16,  # attention heads
        lora_alpha=32,  # alpha scaling
        lora_dropout=0.1,  # dropout probability
        bias="none",  # bias parameters
        inference_mode=False,  # training mode
    )

    for model_name, model_info in model_checkpoints.items():
        # Tokenize dataset specific to model
        tokenized_data = tokenize_datasets(model_name, dataset)

        for loss_fn_name, loss_fn in loss_functions.items():
            try:
                print(f"\nFine-tuning {model_name} with PEFT (LoRA) using {loss_fn_name}")
                print(f"Using device: {device}")

                # Initialize tokenizer and base model
                tokenizer = AutoTokenizer.from_pretrained(model_info["path"])
                base_model = AutoModelForSequenceClassification.from_pretrained(
                    model_info["path"],
                    num_labels=2,
                    problem_type="single_label_classification"
                )

                # Get PEFT model
                model = get_peft_model(base_model, lora_config)
                model.print_trainable_parameters()

                # Move model to device
                model = model.to(device)

                # Ensure loss function is on correct device
                loss_fn = loss_fn.to(device)

                # Define training arguments
                training_args = TrainingArguments(
                    output_dir=f"{local_save_path}/{model_name}_{loss_fn_name}_peft_lora",
                    eval_strategy="epoch",
                    save_strategy="epoch",
                    learning_rate=2e-5,
                    per_device_train_batch_size=8,
                    per_device_eval_batch_size=8,
                    num_train_epochs=3,
                    weight_decay=0.01,
                    load_best_model_at_end=True,
                    metric_for_best_model="f1",
                    logging_dir="./logs",
                    report_to="wandb",
                    logging_steps=100,
                    gradient_checkpointing=True,
                    gradient_accumulation_steps=4,
                    fp16=True,
                    half_precision_backend="cuda_amp",  # Specify the backend
                    optim="adamw_torch"
                )

                # Initialize trainer
                trainer = CustomTrainer(
                    model=model,
                    args=training_args,
                    train_dataset=tokenized_data["train"],
                    eval_dataset=tokenized_data["validation"],
                    processing_class=tokenizer,
                    data_collator=DataCollatorWithPadding(tokenizer),
                    compute_metrics=compute_metrics,
                    loss_func=loss_fn
                )

                # Train the model
                trainer.train()

                # Save the model and adapter
                output_dir = f"{local_save_path}/{model_name}_{loss_fn_name}_peft_lora"
                save_model_with_fallback(trainer, output_dir)
                model.save_pretrained(f"{output_dir}/adapter")

                # Clear memory
                del model, base_model
                torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error training {model_name} with {loss_fn_name}: {e}")
                continue

# Step 14: Inference Function

In [15]:
def run_inference_and_save_results(model_checkpoints, test_df, results_dir):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    predictions_df_list = []

    # Get list of all trained models
    loss_functions = get_loss_functions(device)

    for model_name, model_info in model_checkpoints.items():
        # For each training method (full fine-tuning with different loss functions)
        for loss_fn_name in loss_functions.keys():
            try:
                # Full fine-tuning model
                full_ft_model_name = f"{model_name}_{loss_fn_name}_full_ft"
                model = load_best_model(full_ft_model_name).to(device)
                tokenizer = AutoTokenizer.from_pretrained(model_info["path"])

                inputs = tokenizer(
                    list(test_df["text"]),
                    truncation=True,
                    padding=True,
                    max_length=model_info["max_length"],
                    return_tensors="pt"
                ).to(device)

                with torch.no_grad():
                    outputs = model(**inputs)
                    preds = outputs.logits.argmax(dim=-1).cpu().numpy()

                result_df = test_df.copy()
                result_df["prediction"] = preds
                predictions_df_list.append((full_ft_model_name, result_df))

                # PEFT model
                peft_model_name = f"{model_name}_{loss_fn_name}_peft_lora"
                peft_model_path = os.path.join(local_save_path, peft_model_name, "adapter")

                if os.path.exists(peft_model_path):
                    model = load_best_model(peft_model_name).to(device)

                    with torch.no_grad():
                        outputs = model(**inputs)
                        preds = outputs.logits.argmax(dim=-1).cpu().numpy()

                    result_df = test_df.copy()
                    result_df["prediction"] = preds
                    predictions_df_list.append((peft_model_name, result_df))

                # Clear memory
                del model
                torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error during inference for {model_name}: {e}")
                continue

    # Save all predictions
    for model_name, result_df in predictions_df_list:
        result_file_path = os.path.join(results_dir, f"{model_name}_predictions.csv")
        result_df.to_csv(result_file_path, index=False)
        print(f"Saved predictions for {model_name} to {result_file_path}")


# Step 15: Main Execution

In [None]:
if __name__ == "__main__":
    # Initialize device and wandb
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize loss functions
    loss_functions = get_loss_functions(device)

    # Analyze dataset once
    print("Analyzing dataset text lengths...")
    dataset_stats = analyze_text_lengths(dataset)

    # Run full fine-tuning
    print("\nStarting Full Fine-Tuning with all models and loss functions...")
    full_fine_tune_all_models(model_checkpoints, dataset, loss_functions)

    # # Run PEFT fine-tuning
    # print("\nStarting PEFT Fine-Tuning with LoRA on all models...")
    # peft_fine_tune_all_models(model_checkpoints, dataset, loss_functions)

    # # Run inference and save results
    # print("\nRunning inference and saving predictions...")
    # run_inference_and_save_results(model_checkpoints, test_df, results_dir)

    print("\nExperiments completed!")
    wandb.finish()

Using device: cpu
Analyzing dataset text lengths...

Text Length Analysis (word-based):
Average length: 254.6 words
Median length: 138 words
Max length: 17608 words
95th percentile: 863 words

Length Distribution:
< 128 words: 222456 texts (48.0%)
128-256 words: 86918 texts (18.8%)
256-512 words: 87173 texts (18.8%)
> 512 words: 66793 texts (14.4%)

Character-based Analysis:
Average length: 1627.8 characters
Median length: 929 characters
Max length: 137662 characters
95th percentile: 5380 characters

Starting Full Fine-Tuning with all models and loss functions...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/872k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.72M [00:00<?, ?B/s]

Using max_length=512 for model MBERT_uncased


Tokenizing train set (num_proc=4):   0%|          | 0/463340 [00:00<?, ? examples/s]