# In [None]:
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    get_scheduler,
)
import torch
from torch.optim import AdamW
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import os

# CUDA debugging aids. CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so
# errors surface at the launching line, at the cost of slower training.
# NOTE(review): TORCH_USE_CUDA_DSA appears to be a PyTorch build-time option;
# setting it at runtime presumably has no effect — confirm before relying on it.
os.environ.update({
    'TORCH_USE_CUDA_DSA': '1',
    'CUDA_LAUNCH_BLOCKING': '1',
})

# Pretrained checkpoint name; the same name selects both tokenizer and model weights.
checkpoint = "albert-large-v2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# CoLA is a single-sentence task: text comes from the "sentence" column and
# there is no second text column.
sentence1_key, sentence2_key = "sentence", None

# GLUE CoLA; its "train" and "validation" splits are used for fine-tuning and evaluation.
dataset = load_dataset("glue", "cola")

# Define tokenization function
def tokenize_function(examples):
    """Tokenize a batch of examples for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batched mapping from column name to a list of values.
            ``sentence1_key`` is always read; ``sentence2_key`` is read too when
            it is not None, so sentence-pair tasks are encoded as pairs.

    Returns:
        The tokenizer's batch encoding, truncated to the model's maximum length
        but NOT padded. Padding is left to ``DataCollatorWithPadding``, which
        pads each batch only to its own longest sequence. This is much cheaper
        than ``padding="max_length"``, which pads every example to the
        tokenizer's ``model_max_length``.
    """
    texts = (examples[sentence1_key],)
    if sentence2_key is not None:
        texts += (examples[sentence2_key],)
    return tokenizer(*texts, truncation=True)

# Tokenize every split in a single batched pass.
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Pads the examples in each batch to a common length with the tokenizer's pad token.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Class count comes from the training split's label feature.
label_feature = dataset['train'].features['label']
num_labels = label_feature.num_classes

# Pretrained checkpoint with a sequence-classification head sized to num_labels.
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=num_labels,
)

# Define metric computation function
def compute_metrics(eval_preds):
    """Compute evaluation metrics for the Trainer.

    Args:
        eval_preds: ``(logits, labels)`` pair, e.g. a ``transformers.EvalPrediction``.
            If ``logits`` is a tuple of arrays (models that return extra
            outputs), its first element is taken as the class logits.

    Returns:
        dict with:
        - ``accuracy``;
        - binary ``precision`` / ``recall`` / ``f1`` (positive class = label 1;
          0 when the metric's denominator is zero);
        - ``matthews_correlation``, the official GLUE metric for CoLA.
    """
    from sklearn.metrics import matthews_corrcoef

    logits, labels = eval_preds
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return {
        'accuracy': accuracy_score(labels, predictions),
        'precision': precision_score(labels, predictions, average='binary', zero_division=0),
        'recall': recall_score(labels, predictions, average='binary', zero_division=0),
        'f1': f1_score(labels, predictions, average='binary', zero_division=0),
        'matthews_correlation': matthews_corrcoef(labels, predictions),
    }

# Define training arguments
training_args = TrainingArguments(
    output_dir="Albert-Large-cola",
    # Evaluate and checkpoint once per epoch. load_best_model_at_end needs the
    # evaluation and save strategies to match.
    # NOTE(review): newer transformers releases rename `evaluation_strategy` to
    # `eval_strategy`; update if a deprecation warning appears.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=5,
    # 5 epochs with one save per epoch: every checkpoint is kept.
    save_total_limit=5,
    logging_dir="logs-cola",
    logging_strategy="steps",
    logging_steps=5,  # Log training loss every 5 steps
    # NOTE: the Trainer is given a custom optimizer built with its own learning
    # rate, so this value is not the one applied during training.
    learning_rate=2e-5,
    load_best_model_at_end=True,
    # CoLA's official GLUE metric is Matthews correlation; consider selecting on
    # it instead if compute_metrics reports it.
    metric_for_best_model="accuracy",
)

# Define custom optimizer and scheduler
def create_optimizer_and_scheduler(model, *, lr=1e-5, warmup_ratio=0.1):
    """Build an AdamW optimizer and a linear warmup/decay learning-rate schedule.

    The schedule length is derived from the module-level ``tokenized_dataset``
    and ``training_args``. The goal is to cover every optimizer step the Trainer
    will take.

    Args:
        model: Model whose parameters are optimized.
        lr: Peak AdamW learning rate. When the result is passed to
            ``Trainer(optimizers=...)``, this value, not
            ``training_args.learning_rate``, is used for training.
        warmup_ratio: Fraction of total optimizer steps spent warming up
            linearly before the linear decay.

    Returns:
        ``(optimizer, scheduler)`` tuple.
    """
    import math

    optimizer = AdamW(model.parameters(), lr=lr)

    if training_args.max_steps > 0:
        # An explicit step budget overrides the epoch count.
        num_training_steps = training_args.max_steps
    else:
        # Batches per epoch must match the training DataLoader.
        # - Unless dataloader_drop_last is set, the final partial batch is kept,
        #   so the count is ceil(N / batch); floor would undercount and let the
        #   LR reach 0 before training ends.
        # - The global batch grows with the number of devices.
        global_batch_size = training_args.train_batch_size * training_args.world_size
        num_examples = len(tokenized_dataset["train"])
        if training_args.dataloader_drop_last:
            batches_per_epoch = num_examples // global_batch_size
        else:
            batches_per_epoch = math.ceil(num_examples / global_batch_size)
        # Gradient accumulation merges several batches into one optimizer step.
        steps_per_epoch = max(
            1, math.ceil(batches_per_epoch / training_args.gradient_accumulation_steps)
        )
        num_training_steps = math.ceil(training_args.num_train_epochs * steps_per_epoch)

    num_warmup_steps = int(warmup_ratio * num_training_steps)
    scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    return optimizer, scheduler

# AdamW optimizer plus linear warmup/decay schedule handed to the Trainer.
optimizer, scheduler = create_optimizer_and_scheduler(model)

# Validation split: evaluated every epoch during training and again for the final report.
eval_dataset = tokenized_dataset["validation"]

# Passing `optimizers` makes the Trainer use this optimizer/scheduler pair
# instead of building its own.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, scheduler),
)

# Fine-tune; load_best_model_at_end=True restores the best checkpoint afterwards.
trainer.train()

# Evaluate the model (the best checkpoint, restored via load_best_model_at_end)
# on the validation split.
predictions = trainer.predict(eval_dataset)
logits = predictions.predictions
# Some models return extra outputs alongside the logits; keep only the logits.
if isinstance(logits, tuple):
    logits = logits[0]
y_true = predictions.label_ids
y_pred = np.argmax(logits, axis=-1)

# Human-readable class names from the dataset's label feature. Passing an
# explicit `labels` list keeps the report rows and matrix axes aligned with
# these names even if a class never appears in y_true or y_pred.
label_names = dataset['train'].features['label'].names
label_ids = list(range(len(label_names)))

# Classification Report
print("\nClassification Report for CoLA:")
print(classification_report(y_true, y_pred, labels=label_ids, target_names=label_names))

# Confusion Matrix
conf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=label_names,
    yticklabels=label_names,
)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix for CoLA')
plt.show()
