In [1]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import evaluate
import pandas as pd
import sys

# Optional: Increase recursion limit (use with caution)
sys.setrecursionlimit(5000)

# Load 1% of the dataset
dataset = load_dataset("ag_news", split="train[:1%]")

# Load model and tokenizer
model_name = "distilbert-base-uncased"  # Use a model suitable for text classification
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)

# Tokenization function
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

# Tokenize dataset
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# Split dataset into train and test (80% train, 20% test)
train_size = int(0.8 * len(tokenized_dataset))
train_dataset = tokenized_dataset.select(range(train_size))
test_dataset = tokenized_dataset.select(range(train_size, len(tokenized_dataset)))

# Define evaluation metrics
metric = evaluate.load("accuracy")

# Function to fine-tune and evaluate the model
def fine_tune_and_evaluate(model, train_dataset, eval_dataset):
    training_args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="epoch",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        weight_decay=0.01,
    )
    
    # Print predictions and labels for debugging
    def compute_metrics(p):
        preds = p.predictions.argmax(-1)
        labels = p.label_ids
        print(f"Predictions: {preds}, Labels: {labels}")
        return {"accuracy": metric.compute(predictions=preds, references=labels)}
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )
    
    trainer.train()
    results = trainer.evaluate()
    return results

# Dictionary to store results
results_summary = {}

# Zero-Shot Learning Evaluation
zero_shot_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)
zero_shot_results = fine_tune_and_evaluate(zero_shot_model, train_dataset, test_dataset)
results_summary["Zero-Shot (1% Data)"] = zero_shot_results["eval_accuracy"]

# Few-Shot Learning Evaluations with different shot sizes
for few_shot_size in [1, 5, 10]:  # Adjust few-shot sizes as needed
    few_shot_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)
    few_shot_train_dataset = train_dataset.shuffle(seed=42).select(range(few_shot_size))  # Select few-shot examples
    few_shot_results = fine_tune_and_evaluate(few_shot_model, few_shot_train_dataset, test_dataset)
    results_summary[f"Few-Shot ({few_shot_size})"] = few_shot_results["eval_accuracy"]

# Create a DataFrame to summarize results
results_df = pd.DataFrame(list(results_summary.items()), columns=["Method", "Accuracy"])

# Print the final results table
print(results_df)

# Save results to a CSV file
results_df.to_csv("results_summary.csv", index=False)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                 
 33%|███▎      | 120/360 [00:36<01:05,  3.64it/s]

Predictions: [2 2 3 3 3 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 2 0 0 2 2 2 2 3 0 1 0 0 2 0 3 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 2 0 3 0 3 3 1 1 3 3 3 3 0 3 1 1 2 3 3 3 3 3 3 1 1 0 0
 0 2 3 1 3 1 2 2 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 1 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 3 2 0 2 2 3 1 1 1 0 1 0 0 1 3 3 3 3 0 3 3 3 1 2 3 2 3 3 1 0 0 1 0
 0 2 0 1 1 3 2 3 3 1 3 3 3 3 1 3 0 3 1 1 0 3 3 3 3 1 0 0 3 3 3 3 1 2 2 3 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 2 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

                                                 
 67%|██████▋   | 240/360 [01:12<00:33,  3.55it/s]

Predictions: [2 2 3 3 3 2 3 3 2 2 3 3 3 2 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 2 3 2 2
 0 0 3 3 1 1 0 0 2 0 0 2 2 2 2 3 0 1 0 0 2 0 0 2 3 2 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 0 1 2 3 0 1 0 2 0 0 0 3 3 1 1 0 3 0 3 0 3 1 1 2 3 3 3 3 3 3 1 1 0 0
 0 2 3 1 2 0 2 2 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 1 2 3 3 3 3 2 3 3 3 3 3
 3 1 1 1 2 2 0 2 2 3 1 1 1 0 1 0 0 1 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 2 0 1 1 3 2 3 3 1 3 3 3 3 1 3 0 3 1 1 0 3 2 3 3 1 0 0 3 3 3 3 1 2 2 3 0
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 2 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

                                                 
100%|██████████| 360/360 [01:49<00:00,  3.30it/s]


Predictions: [2 2 3 3 3 2 3 3 2 2 3 3 3 2 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 2 3 2 2
 3 3 3 3 1 1 0 0 2 0 0 2 2 2 2 3 0 1 0 0 2 0 3 2 3 2 1 0 0 2 2 3 3 0 0 0 0
 2 2 3 1 1 2 3 0 1 1 2 0 3 0 3 3 1 1 0 3 3 3 0 3 1 1 2 3 3 3 3 3 3 1 1 0 0
 0 2 3 1 3 0 2 2 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 1 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 2 0 2 2 3 1 1 1 0 1 0 3 1 3 3 3 3 0 3 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 2 0 1 1 3 2 3 3 1 3 3 3 3 1 3 0 3 1 1 0 3 2 3 3 1 0 0 3 3 3 3 1 2 2 3 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 2 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

100%|██████████| 30/30 [00:02<00:00, 12.61it/s]


Predictions: [2 2 3 3 3 2 3 3 2 2 3 3 3 2 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 2 3 2 2
 3 3 3 3 1 1 0 0 2 0 0 2 2 2 2 3 0 1 0 0 2 0 3 2 3 2 1 0 0 2 2 3 3 0 0 0 0
 2 2 3 1 1 2 3 0 1 1 2 0 3 0 3 3 1 1 0 3 3 3 0 3 1 1 2 3 3 3 3 3 3 1 1 0 0
 0 2 3 1 3 0 2 2 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 1 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 2 0 2 2 3 1 1 1 0 1 0 3 1 3 3 3 3 0 3 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 2 0 1 1 3 2 3 3 1 3 3 3 3 1 3 0 3 1 1 0 3 2 3 3 1 0 0 3 3 3 3 1 2 2 3 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 2 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                             
 67%|██████▋   | 2/3 [00:04<00:02,  2.29s/it]  

Predictions: [0 0 0 0 1 2 0 0 0 0 0 0 1 1 1 0 0 0 0 0 1 0 2 2 0 2 0 0 0 0 0 0 0 1 0 0 0
 0 2 2 1 0 1 0 0 1 1 0 0 0 0 0 0 2 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 2 0 0
 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0
 0 0 0 0 0 1 0 0 0 2 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 1 2 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 1 0 1 0 1 0 0 1 0 2 0 0 1 0 1 2 0 0 0 1 1
 1 1 0 0 0 0 0 1 0 0 1 0 2 0 1 0 0 0 1 1 0 0 1 0 0 0 0 0 2 1 0 0 0 0 0 2 0
 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

                                             
100%|██████████| 3/3 [00:07<00:00,  2.40s/it]  

Predictions: [2 0 0 2 2 2 2 2 0 0 0 2 1 2 2 2 0 2 0 2 2 2 2 2 2 2 0 0 2 0 2 2 2 1 2 0 0
 2 2 2 2 2 2 2 2 1 1 2 0 0 0 2 2 2 2 2 2 2 2 1 2 2 2 0 2 0 2 2 0 2 2 2 2 0
 2 2 2 1 2 0 2 2 0 0 0 2 2 2 2 2 0 2 2 2 2 2 2 2 0 2 2 2 2 2 2 2 2 2 1 2 0
 2 0 2 2 0 2 2 2 2 2 2 0 0 0 2 2 2 0 0 2 0 2 2 2 2 0 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 0 0 2 2 2 0 2 0 2 2 2 2 2 2 2 2 2 2 0 2 2 2 2 0 2 2 2 2 2 2 0 2 2 2
 2 2 2 2 2 2 0 1 2 2 2 0 2 2 2 0 0 0 2 1 2 0 2 0 2 0 2 2 2 2 0 2 2 2 2 2 2
 2 2 0 2 2 2 2 2 2 2 2 1 2 2 1 2 2 2], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

                                             
100%|██████████| 3/3 [00:10<00:00,  3.48s/it]  


Predictions: [2 2 2 2 2 2 2 2 0 2 2 2 2 2 2 2 0 2 0 2 2 2 2 2 2 2 0 2 2 2 2 2 2 2 2 2 0
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 0 2 2 2 2 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 0 2 2 2 2 0 2 2 2 2 2 2 0 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 0 2 2 2 2 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 2 2 2 2 2 2 0 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

100%|██████████| 30/30 [00:02<00:00, 12.78it/s]


Predictions: [2 2 2 2 2 2 2 2 0 2 2 2 2 2 2 2 0 2 0 2 2 2 2 2 2 2 0 2 2 2 2 2 2 2 2 2 0
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 0 2 2 2 2 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 0 2 2 2 2 0 2 2 2 2 2 2 0 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 0 2 2 2 2 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 2 2 2 2 2 2 0 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                             
 33%|███▎      | 1/3 [00:03<00:01,  1.19it/s]  

Predictions: [3 0 3 3 3 3 0 3 3 3 0 3 3 1 3 3 3 3 3 3 3 3 3 3 3 3 0 0 3 3 3 3 3 1 3 3 3
 0 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 0 3 3 3 3 3 0 3 0 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 0 0 0 3 3 1 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 0
 3 3 3 3 3 3 3 3 3 3 0 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 0 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 0 3 3 3 3 3 3 3 3 3 3 3 0 3 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

                                             
 67%|██████▋   | 2/3 [00:06<00:01,  1.93s/it]  

Predictions: [3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

                                             
100%|██████████| 3/3 [00:09<00:00,  3.15s/it]  


Predictions: [3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

100%|██████████| 30/30 [00:02<00:00, 12.52it/s]


Predictions: [3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                             
 33%|███▎      | 2/6 [00:03<00:02,  1.56it/s]  

Predictions: [3 0 3 3 3 3 0 3 0 0 0 3 3 0 3 3 3 0 3 3 3 3 3 3 3 3 0 0 0 0 3 3 3 0 3 3 3
 0 0 3 3 3 3 3 0 3 3 3 3 3 3 3 3 3 0 0 3 3 0 3 0 0 3 3 3 0 0 0 0 3 3 0 0 0
 3 3 3 3 0 0 3 3 0 3 3 0 0 3 3 3 3 0 3 3 0 3 3 0 3 3 3 0 0 0 0 0 3 3 3 3 0
 3 3 3 3 0 3 0 3 3 3 3 3 0 0 3 3 3 3 0 3 3 0 3 0 3 0 3 3 3 3 3 3 0 3 0 3 0
 0 3 0 0 3 3 0 3 0 3 0 3 0 3 3 3 3 3 3 3 3 3 3 3 3 3 0 0 3 0 3 3 0 0 3 3 3
 3 3 3 3 0 3 3 3 3 0 3 0 3 0 3 3 3 0 3 3 0 3 3 3 3 3 3 3 3 3 3 3 0 3 3 3 0
 3 3 0 3 0 3 3 3 0 0 3 3 0 3 0 3 0 0], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

                                             
 67%|██████▋   | 4/6 [00:06<00:02,  1.02s/it]  

Predictions: [3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 0 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

                                             
100%|██████████| 6/6 [00:10<00:00,  1.74s/it]  


Predictions: [3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
{'

100%|██████████| 30/30 [00:02<00:00, 12.33it/s]

Predictions: [3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3], Labels: [2 2 3 3 2 2 3 3 2 2 3 3 3 3 3 3 3 1 1 1 0 0 0 0 0 3 3 3 3 0 3 3 3 3 3 2 2
 3 3 3 3 1 1 0 0 0 0 2 2 2 2 3 3 0 0 0 2 2 3 0 2 3 3 1 0 0 2 2 3 3 1 0 0 0
 2 2 3 1 1 2 3 1 1 1 0 0 0 0 3 3 1 1 0 3 3 3 0 3 1 1 2 2 3 3 3 3 3 1 1 0 0
 0 2 3 1 0 0 0 0 0 0 3 1 0 0 3 3 3 1 1 2 2 3 1 0 0 0 3 3 3 3 3 3 3 3 3 3 3
 3 1 1 1 2 0 0 0 2 3 1 1 1 0 0 0 0 0 3 3 3 3 0 2 3 3 1 2 2 2 3 3 1 0 0 1 0
 0 3 0 0 1 3 2 3 3 1 3 3 3 3 1 3 0 0 1 1 0 3 3 3 3 1 0 0 3 3 3 3 0 0 2 2 3
 3 1 1 1 0 0 0 3 0 0 1 1 0 2 2 3 3 3]
  




NameError: name 'trainer' is not defined

In [8]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import evaluate
import pandas as pd
import psutil
import time
import sys

# Optional: Increase recursion limit (use with caution)
sys.setrecursionlimit(5000)

# Load 1% of the dataset
dataset = load_dataset("ag_news", split="train[:10%]")

# Load model and tokenizer
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)

# Tokenization function
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

# Tokenize dataset
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# Split dataset into train and test (80% train, 20% test)
train_size = int(0.8 * len(tokenized_dataset))
train_dataset = tokenized_dataset.select(range(train_size))
test_dataset = tokenized_dataset.select(range(train_size, len(tokenized_dataset)))

# Load metrics
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def get_memory_usage():
    process = psutil.Process()
    return process.memory_info().rss / (1024 * 1024)  # Convert to MB

# Function to fine-tune and evaluate the model
def fine_tune_and_evaluate(model, train_dataset, eval_dataset):
    training_args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="epoch",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        weight_decay=0.01,
    )
    
    def compute_metrics(p):
        preds = p.predictions.argmax(-1)
        labels = p.label_ids
        accuracy = accuracy_metric.compute(predictions=preds, references=labels)['accuracy']
        f1 = f1_metric.compute(predictions=preds, references=labels, average='weighted')['f1']
        return {"accuracy": accuracy, "f1": f1}
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )
    
    start_time = time.time()
    memory_before = get_memory_usage()
    
    trainer.train()
    
    training_time = time.time() - start_time
    results = trainer.evaluate()
    
    memory_after = get_memory_usage()
    memory_usage = memory_after - memory_before

    return {
        "accuracy": results["eval_accuracy"],
        "f1": results["eval_f1"],
        "training_time": training_time,
        "memory_usage": memory_usage,
    }

# Dictionary to store results
results_summary = {}

# Zero-Shot Learning Evaluation
zero_shot_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)
zero_shot_results = fine_tune_and_evaluate(zero_shot_model, train_dataset, test_dataset)
results_summary["Zero-Shot (1% Data)"] = zero_shot_results

# Few-Shot Learning Evaluations with different shot sizes
for few_shot_size in [1, 5, 10, 50]:  # Adjust few-shot sizes as needed
    few_shot_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)
    few_shot_train_dataset = train_dataset.shuffle(seed=42).select(range(few_shot_size))  # Select few-shot examples
    few_shot_results = fine_tune_and_evaluate(few_shot_model, few_shot_train_dataset, test_dataset)
    results_summary[f"Few-Shot ({few_shot_size})"] = few_shot_results

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
 33%|███▎      | 120/360 [02:05<04:10,  1.04s/it]
                                                  

 14%|█▍        | 500/3600 [02:20<14:23,  3.59it/s]
[A

{'loss': 0.4822, 'grad_norm': 0.4884715676307678, 'learning_rate': 4.305555555555556e-05, 'epoch': 0.42}


                                                   

 28%|██▊       | 1000/3600 [04:42<12:17,  3.53it/s]
[A

{'loss': 0.3821, 'grad_norm': 8.839018821716309, 'learning_rate': 3.611111111111111e-05, 'epoch': 0.83}


                                                   

[A[A                                          
 33%|███▎      | 1200/3600 [06:03<10:57,  3.65it/s]
[A

{'eval_loss': 0.3473186790943146, 'eval_accuracy': 0.90875, 'eval_f1': 0.9086682625236364, 'eval_runtime': 23.639, 'eval_samples_per_second': 101.527, 'eval_steps_per_second': 12.691, 'epoch': 1.0}


                                                     

 42%|████▏     | 1500/3600 [07:26<09:39,  3.62it/s]
[A

{'loss': 0.292, 'grad_norm': 0.1605522632598877, 'learning_rate': 2.916666666666667e-05, 'epoch': 1.25}


                                                   

 56%|█████▌    | 2000/3600 [09:47<07:22,  3.62it/s]
[A

{'loss': 0.1971, 'grad_norm': 0.45863035321235657, 'learning_rate': 2.2222222222222223e-05, 'epoch': 1.67}


                                                   

[A[A                                          
 67%|██████▋   | 2400/3600 [12:04<05:36,  3.57it/s]
[A

{'eval_loss': 0.3589439392089844, 'eval_accuracy': 0.90625, 'eval_f1': 0.9057354982210138, 'eval_runtime': 24.32, 'eval_samples_per_second': 98.684, 'eval_steps_per_second': 12.336, 'epoch': 2.0}


                                                     

 69%|██████▉   | 2500/3600 [12:32<05:05,  3.60it/s]
[A

{'loss': 0.198, 'grad_norm': 23.372535705566406, 'learning_rate': 1.527777777777778e-05, 'epoch': 2.08}


                                                   

 83%|████████▎ | 3000/3600 [14:52<02:47,  3.59it/s]
[A

{'loss': 0.1236, 'grad_norm': 11.23639965057373, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.5}


                                                   

 97%|█████████▋| 3500/3600 [17:14<00:28,  3.56it/s]
[A

{'loss': 0.1191, 'grad_norm': 0.39663928747177124, 'learning_rate': 1.388888888888889e-06, 'epoch': 2.92}


                                                   

[A[A                                          
100%|██████████| 3600/3600 [18:09<00:00,  3.58it/s]
                                                   

100%|██████████| 3600/3600 [18:09<00:00,  3.58it/s]
100%|██████████| 3600/3600 [18:09<00:00,  3.30it/s]


{'eval_loss': 0.35999736189842224, 'eval_accuracy': 0.9229166666666667, 'eval_f1': 0.9226888188503383, 'eval_runtime': 24.4736, 'eval_samples_per_second': 98.065, 'eval_steps_per_second': 12.258, 'epoch': 3.0}
{'train_runtime': 1089.3205, 'train_samples_per_second': 26.439, 'train_steps_per_second': 3.305, 'train_loss': 0.2515468692779541, 'epoch': 3.0}


100%|██████████| 300/300 [00:24<00:00, 12.32it/s]
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                     
[A                                              

 33%|███▎      | 1/3 [00:24<00:48, 24.33s/it]   
 67%|██████▋   | 2/3 [00:24<00:12, 12.22s/it]

{'eval_loss': 1.398409128189087, 'eval_accuracy': 0.19583333333333333, 'eval_f1': 0.10809213584667975, 'eval_runtime': 24.2306, 'eval_samples_per_second': 99.048, 'eval_steps_per_second': 12.381, 'epoch': 1.0}


                                             
[A                                              

 67%|██████▋   | 2/3 [00:48<00:12, 12.22s/it]   
100%|██████████| 3/3 [00:48<00:00, 17.27s/it]

{'eval_loss': 1.3940852880477905, 'eval_accuracy': 0.2520833333333333, 'eval_f1': 0.17297814931533642, 'eval_runtime': 24.2329, 'eval_samples_per_second': 99.039, 'eval_steps_per_second': 12.38, 'epoch': 2.0}


                                             
[A                                              

100%|██████████| 3/3 [01:13<00:00, 17.27s/it]   
                                             

100%|██████████| 3/3 [01:13<00:00, 17.27s/it]   
100%|██████████| 3/3 [01:13<00:00, 24.64s/it]


{'eval_loss': 1.392171859741211, 'eval_accuracy': 0.27291666666666664, 'eval_f1': 0.13745657182161553, 'eval_runtime': 24.3969, 'eval_samples_per_second': 98.373, 'eval_steps_per_second': 12.297, 'epoch': 3.0}
{'train_runtime': 73.9183, 'train_samples_per_second': 0.041, 'train_steps_per_second': 0.041, 'train_loss': 1.3554191589355469, 'epoch': 3.0}


100%|██████████| 300/300 [00:24<00:00, 12.46it/s]
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                             
[A                                              

 33%|███▎      | 1/3 [00:25<00:00,  4.73it/s]   
[A

{'eval_loss': 1.391747236251831, 'eval_accuracy': 0.22791666666666666, 'eval_f1': 0.1149667602922741, 'eval_runtime': 24.7858, 'eval_samples_per_second': 96.83, 'eval_steps_per_second': 12.104, 'epoch': 1.0}


                                             
[A                                              

 67%|██████▋   | 2/3 [00:50<00:14, 14.79s/it]   
[A

{'eval_loss': 1.381982445716858, 'eval_accuracy': 0.26208333333333333, 'eval_f1': 0.18470287271073765, 'eval_runtime': 24.882, 'eval_samples_per_second': 96.455, 'eval_steps_per_second': 12.057, 'epoch': 2.0}


                                             

[A[A                                          
100%|██████████| 3/3 [01:15<00:00, 19.49s/it]    
                                             

100%|██████████| 3/3 [01:15<00:00, 19.49s/it]   
100%|██████████| 3/3 [01:15<00:00, 25.28s/it]


{'eval_loss': 1.3770618438720703, 'eval_accuracy': 0.32875, 'eval_f1': 0.25914819892847807, 'eval_runtime': 24.8431, 'eval_samples_per_second': 96.606, 'eval_steps_per_second': 12.076, 'epoch': 3.0}
{'train_runtime': 75.8454, 'train_samples_per_second': 0.198, 'train_steps_per_second': 0.04, 'train_loss': 1.3426005045572917, 'epoch': 3.0}


100%|██████████| 300/300 [00:24<00:00, 12.13it/s]
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
 33%|███▎      | 2/6 [00:00<00:00,  5.29it/s]
[A                                              

                                                
 33%|███▎      | 2/6 [00:24<00:00,  5.29it/s]

{'eval_loss': 1.3786132335662842, 'eval_accuracy': 0.29041666666666666, 'eval_f1': 0.2278929465077241, 'eval_runtime': 24.4987, 'eval_samples_per_second': 97.964, 'eval_steps_per_second': 12.246, 'epoch': 1.0}


 67%|██████▋   | 4/6 [00:25<00:13,  6.97s/it]
[A                                              

                                                
 67%|██████▋   | 4/6 [00:50<00:13,  6.97s/it]

{'eval_loss': 1.3601192235946655, 'eval_accuracy': 0.3408333333333333, 'eval_f1': 0.2434442013138334, 'eval_runtime': 24.6825, 'eval_samples_per_second': 97.235, 'eval_steps_per_second': 12.154, 'epoch': 2.0}


100%|██████████| 6/6 [00:50<00:00,  8.93s/it]

                                                
[A                                              
100%|██████████| 6/6 [01:15<00:00,  8.93s/it]

                                                
100%|██████████| 6/6 [01:15<00:00, 12.65s/it]


{'eval_loss': 1.3519377708435059, 'eval_accuracy': 0.32375, 'eval_f1': 0.2164901052475208, 'eval_runtime': 24.7046, 'eval_samples_per_second': 97.148, 'eval_steps_per_second': 12.144, 'epoch': 3.0}
{'train_runtime': 75.9021, 'train_samples_per_second': 0.395, 'train_steps_per_second': 0.079, 'train_loss': 1.3189056714375813, 'epoch': 3.0}


100%|██████████| 300/300 [00:24<00:00, 12.23it/s]
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                              
[A                                              

 33%|███▎      | 7/21 [00:26<00:03,  4.26it/s]  
[A

{'eval_loss': 1.3262349367141724, 'eval_accuracy': 0.3283333333333333, 'eval_f1': 0.2316293969960446, 'eval_runtime': 24.8891, 'eval_samples_per_second': 96.428, 'eval_steps_per_second': 12.053, 'epoch': 1.0}


                                               
[A                                              

 67%|██████▋   | 14/21 [00:53<00:07,  1.11s/it] 
[A

{'eval_loss': 1.2044686079025269, 'eval_accuracy': 0.6029166666666667, 'eval_f1': 0.5847078027954136, 'eval_runtime': 24.8127, 'eval_samples_per_second': 96.725, 'eval_steps_per_second': 12.091, 'epoch': 2.0}


                                               
[A                                              

100%|██████████| 21/21 [01:21<00:00,  1.18s/it] 
                                               

100%|██████████| 21/21 [01:21<00:00,  1.18s/it] 
100%|██████████| 21/21 [01:21<00:00,  3.86s/it]


{'eval_loss': 1.1472645998001099, 'eval_accuracy': 0.6654166666666667, 'eval_f1': 0.661081531987111, 'eval_runtime': 24.8236, 'eval_samples_per_second': 96.682, 'eval_steps_per_second': 12.085, 'epoch': 3.0}
{'train_runtime': 81.1087, 'train_samples_per_second': 1.849, 'train_steps_per_second': 0.259, 'train_loss': 1.1700435820079984, 'epoch': 3.0}


100%|██████████| 300/300 [00:24<00:00, 12.16it/s]


In [15]:
# Function to evaluate the model without fine-tuning
def evaluate_model(model, eval_dataset):
    training_args = TrainingArguments(
        output_dir="./results",
        per_device_eval_batch_size=8,
        evaluation_strategy="epoch",
        logging_dir="./logs",  # Optional: specify logging directory
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        eval_dataset=eval_dataset,
        compute_metrics=lambda p: {
            "accuracy": accuracy_metric.compute(predictions=p.predictions.argmax(-1), references=p.label_ids)['accuracy'],
            "f1": f1_metric.compute(predictions=p.predictions.argmax(-1), references=p.label_ids, average='weighted')['f1'],
        },
    )
    
    results = trainer.evaluate()
    return {
        "accuracy": results["eval_accuracy"],
        "f1": results["eval_f1"],
        "training_time": None,  # No training time for zero-shot
        "memory_usage": None,    # No memory usage for zero-shot
    }

# Load the model again for zero-shot evaluation
zero_shot_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)

# Evaluate the model on the test dataset without fine-tuning
zero_shot_results = evaluate_model(zero_shot_model, test_dataset)
results_summary["Zero-Shot (After Initial Training)"] = zero_shot_results

# Print the results
print(results_summary)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A

[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A




{'Zero-Shot (1% Data)': {'accuracy': 0.9229166666666667, 'f1': 0.9226888188503383, 'training_time': 1089.4382178783417, 'memory_usage': 161.046875}, 'Few-Shot (1)': {'accuracy': 0.27291666666666664, 'f1': 0.13745657182161553, 'training_time': 74.0447449684143, 'memory_usage': 7.234375}, 'Few-Shot (5)': {'accuracy': 0.32875, 'f1': 0.25914819892847807, 'training_time': 75.96170377731323, 'memory_usage': 10.109375}, 'Few-Shot (10)': {'accuracy': 0.32375, 'f1': 0.2164901052475208, 'training_time': 76.01416873931885, 'memory_usage': 18.375}, 'Few-Shot (50)': {'accuracy': 0.6654166666666667, 'f1': 0.661081531987111, 'training_time': 81.24914598464966, 'memory_usage': 7.640625}, 'Zero-Shot (After Initial Training)': {'accuracy': 0.17833333333333334, 'f1': 0.17282327816844673, 'training_time': None, 'memory_usage': None}}





In [16]:
results_summary

{'Zero-Shot (1% Data)': {'accuracy': 0.9229166666666667,
  'f1': 0.9226888188503383,
  'training_time': 1089.4382178783417,
  'memory_usage': 161.046875},
 'Few-Shot (1)': {'accuracy': 0.27291666666666664,
  'f1': 0.13745657182161553,
  'training_time': 74.0447449684143,
  'memory_usage': 7.234375},
 'Few-Shot (5)': {'accuracy': 0.32875,
  'f1': 0.25914819892847807,
  'training_time': 75.96170377731323,
  'memory_usage': 10.109375},
 'Few-Shot (10)': {'accuracy': 0.32375,
  'f1': 0.2164901052475208,
  'training_time': 76.01416873931885,
  'memory_usage': 18.375},
 'Few-Shot (50)': {'accuracy': 0.6654166666666667,
  'f1': 0.661081531987111,
  'training_time': 81.24914598464966,
  'memory_usage': 7.640625},
 'Zero-Shot (After Initial Training)': {'accuracy': 0.17833333333333334,
  'f1': 0.17282327816844673,
  'training_time': None,
  'memory_usage': None}}

In [17]:
# Create a DataFrame from the results_summary
results_df = pd.DataFrame.from_dict(results_summary, orient='index')

# Reset index to have Method as a column
results_df.reset_index(inplace=True)
results_df.rename(columns={'index': 'Method'}, inplace=True)

# Export the DataFrame to a CSV file
results_df.to_csv("results_summary.csv", index=False)

# Print the DataFrame to verify
print(results_df)

                               Method  accuracy        f1  training_time  \
0                 Zero-Shot (1% Data)  0.922917  0.922689    1089.438218   
1                        Few-Shot (1)  0.272917  0.137457      74.044745   
2                        Few-Shot (5)  0.328750  0.259148      75.961704   
3                       Few-Shot (10)  0.323750  0.216490      76.014169   
4                       Few-Shot (50)  0.665417  0.661082      81.249146   
5  Zero-Shot (After Initial Training)  0.178333  0.172823            NaN   

   memory_usage  
0    161.046875  
1      7.234375  
2     10.109375  
3     18.375000  
4      7.640625  
5           NaN  
