In [None]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np
from tqdm import tqdm
import evaluate

# Load the "snli" dataset; later cells read its "train", "validation" and
# "test" splits.
dataset = load_dataset("snli")
# Local directory that both the model and the tokenizer are loaded from.
checkpoint_folder = "./New Folder With Items"

# Sequence-classification model and tokenizer, both restored from the same
# checkpoint directory.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_folder)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_folder)

# tokenize
def tokenize(batch):
    """Tokenize the premise/hypothesis pairs of a batch of examples.

    Args:
        batch: Mapping with 'premise' and 'hypothesis' entries, as supplied by
            `dataset.map(..., batched=True)`.

    Returns:
        The result of calling the module-level `tokenizer` on the pairs, with
        truncation enabled and padding to a fixed length of 128.
    """
    premises = batch['premise']
    hypotheses = batch['hypothesis']
    # Fixed-length padding gives every example the same size, which the collate
    # functions in this file need because they combine examples with torch.stack.
    encoding_options = {"truncation": True, "padding": "max_length", "max_length": 128}
    return tokenizer(premises, hypotheses, **encoding_options)

# Tokenize every split in batches; the tokenizer outputs are stored alongside
# the original columns.
encoded_dataset = dataset.map(tokenize, batched=True)
# Rename "label" to "labels", the key the rest of this file reads
# (e.g. batch["labels"] in calculate_learning_dynamics).
encoded_dataset = encoded_dataset.rename_column("label", "labels")

# filter out examples with `-1` labels
encoded_dataset = encoded_dataset.filter(lambda example: example["labels"] != -1)

def collate_fn(batch):
    """Stack the tensor-convertible fields of a list of examples.

    Args:
        batch: Non-empty list of dict-like examples sharing the same keys.

    Returns:
        Dict mapping each kept key to a tensor whose first dimension indexes
        the examples. A key is kept only when its value in the FIRST example
        is an int, float, list or torch.Tensor; other fields (e.g. strings)
        are dropped.
    """
    stackable_types = (int, float, list, torch.Tensor)
    first_item = batch[0]

    collated = {}
    for key in first_item:
        # The first example alone decides whether a field is kept.
        if not isinstance(first_item[key], stackable_types):
            continue
        tensors = []
        for item in batch:
            value = item[key]
            # Existing tensors are used as-is; everything else is converted.
            tensors.append(value if isinstance(value, torch.Tensor) else torch.tensor(value))
        collated[key] = torch.stack(tensors)
    return collated

def get_data_loader(split, batch_size=8):
    """Build a DataLoader over one split of the module-level `encoded_dataset`.

    Args:
        split: Name of the split to read (e.g. "train").
        batch_size: Number of examples per batch.

    Returns:
        A DataLoader over the split in torch format, batched with `collate_fn`.
        No `shuffle` or `sampler` argument is passed here.
    """
    torch_view = encoded_dataset[split].with_format("torch")
    loader_options = {"batch_size": batch_size, "collate_fn": collate_fn}
    return DataLoader(torch_view, **loader_options)

# track predictions and confidence scores across multiple epochs
def calculate_learning_dynamics(data_loader, model, num_epochs=3):
    """Record per-example prediction confidence and correctness over repeated passes.

    The model is only evaluated (eval mode, gradients disabled); no weights are
    updated between passes.
    NOTE(review): because nothing is trained between passes, every pass should
    produce the same values for a deterministic model. Confirm whether more
    than one pass is ever useful here.

    Args:
        data_loader: Iterable of batches. Each batch is a dict of tensors that
            contains a "labels" entry plus the keyword inputs of the model.
        model: torch module whose forward pass returns an object exposing
            `.logits`. It is moved to CUDA when available, otherwise CPU.
        num_epochs: Number of passes over `data_loader`.

    Returns:
        Tuple `(confidences, correctness)` of NumPy arrays with one row per
        pass and one entry per example, in the order the loader yields them.
        `confidences` holds the maximum softmax probability; `correctness`
        holds whether the arg-max class equals the label.

    Raises:
        KeyError: If a batch has no "labels" entry.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Remember the current mode so this helper does not silently leave the
    # caller's model switched to eval mode.
    was_training = model.training
    model.eval()
    model.to(device)

    confidences, correctness = [], []

    try:
        for epoch in range(num_epochs):
            epoch_confidences, epoch_correctness = [], []

            for batch in tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                inputs = {k: v.to(device) for k, v in batch.items()}
                # Keep the labels out of the forward call: only the logits are
                # needed, so there is no reason to have the model compute a loss.
                labels = inputs.pop("labels")

                with torch.no_grad():
                    logits = model(**inputs).logits
                    probs = torch.softmax(logits, dim=-1)

                    # store confidence and correctness
                    max_probs, preds = torch.max(probs, dim=-1)
                    epoch_confidences.extend(max_probs.cpu().numpy())
                    epoch_correctness.extend((preds == labels).cpu().numpy())

            confidences.append(epoch_confidences)
            correctness.append(epoch_correctness)
    finally:
        model.train(was_training)

    return np.array(confidences), np.array(correctness)

# get learning dynamics
# One pass (num_epochs=1) over the training split in batches of 8; the returned
# arrays therefore have a single row.
train_loader = get_data_loader("train", batch_size=8)
confidences, correctness = calculate_learning_dynamics(train_loader, model, num_epochs=1)

# classify examples into easy, ambiguous, and hard-to-learn
def classify_examples(confidences, correctness, easy_confidence=0.8, hard_confidence=0.5,
                      hard_consistency=0.5):
    """Bucket examples into easy / ambiguous / hard from their learning dynamics.

    Args:
        confidences: Array-like of shape (num_passes, num_examples) with the
            model's confidence per pass. A 1-D input is treated as one pass.
        correctness: Array-like of the same shape holding 1/True where the
            prediction was correct and 0/False otherwise.
        easy_confidence: Mean confidence strictly above which an always-correct
            example counts as easy. Also the inclusive upper bound of the
            ambiguous band.
        hard_confidence: Mean confidence strictly below which an example may
            count as hard. Also the inclusive lower bound of the ambiguous band.
        hard_consistency: Fraction of correct passes strictly below which a
            low-confidence example counts as hard.

    Returns:
        Tuple `(easy_indices, ambiguous_indices, hard_indices)` of integer
        NumPy arrays with example positions.

    Note:
        The buckets are not exhaustive: for instance a confident but wrong
        example, or a mid-confidence example that is always correct, lands in
        none of them.
    """
    # atleast_2d lets a single pass be given as a flat sequence; asarray lets
    # callers pass plain lists.
    confidences = np.atleast_2d(np.asarray(confidences))
    correctness = np.atleast_2d(np.asarray(correctness))

    avg_confidence = confidences.mean(axis=0)
    consistency = correctness.mean(axis=0)

    is_easy = (avg_confidence > easy_confidence) & (consistency == 1)
    is_hard = (avg_confidence < hard_confidence) & (consistency < hard_consistency)
    is_ambiguous = (
        (avg_confidence >= hard_confidence)
        & (avg_confidence <= easy_confidence)
        & (consistency < 1)
    )

    easy_indices = np.where(is_easy)[0]
    hard_indices = np.where(is_hard)[0]
    ambiguous_indices = np.where(is_ambiguous)[0]

    return easy_indices, ambiguous_indices, hard_indices

# Positions (NumPy integer arrays) of the training examples in each bucket;
# later cells use them to index encoded_dataset["train"].
easy_indices, ambiguous_indices, hard_indices = classify_examples(confidences, correctness)

Epoch 1/1: 100%|██████████| 68671/68671 [1:42:32<00:00, 11.16it/s]  


TypeError: Wrong key type: '112' of type '<class 'numpy.int64'>'. Expected one of int, slice, range, str or Iterable.

In [None]:
import random

def create_oversampled_dataset(dataset, hard_indices, duplication_factor=1, sample_fraction=0.5):
    """Build a training set from a random sample plus repeated hard examples.

    Args:
        dataset: Source examples. Either an object with `__len__` and a
            `select(indices)` method (such as a `datasets.Dataset`), or a
            plain indexable sequence of dict rows.
        hard_indices: Positions in `dataset` of the hard-to-learn examples.
            NumPy integers are accepted.
        duplication_factor: How many copies of the hard examples to append.
        sample_fraction: Fraction of `dataset` drawn (without replacement) as
            the random base sample.

    Returns:
        The base sample followed by the hard examples repeated
        `duplication_factor` times. When `dataset` provides `select`, its
        result is returned; otherwise a new `Dataset` is built from the rows.

    Raises:
        ValueError: From `random.sample` if the requested sample size is
            larger than the dataset or negative.
    """
    # sample a subset of all examples; sampling positions instead of rows
    # avoids loading every row into memory just to discard half of them
    total_examples = len(dataset)
    sample_size = int(sample_fraction * total_examples)
    sampled_indices = random.sample(range(total_examples), sample_size)

    # duplicate each hard example by dupe factor; int() turns NumPy integers
    # into built-in ints, which are valid keys for single-row access
    oversampled_hard_indices = [int(i) for i in hard_indices] * duplication_factor

    # combine the sampled examples with oversampled hard examples
    combined_indices = sampled_indices + oversampled_hard_indices

    select = getattr(dataset, "select", None)
    if callable(select):
        return select(combined_indices)

    # Fallback for plain sequences of dict rows: gather the rows and build a
    # column-wise Dataset from them.
    combined_data = [dataset[i] for i in combined_indices]
    return Dataset.from_dict({k: [d[k] for d in combined_data] for k in combined_data[0]})

# Seed Python's RNG so the random sample drawn inside
# create_oversampled_dataset is the same on every run of the notebook.
random.seed(42)
oversampled_train_dataset = create_oversampled_dataset(encoded_dataset["train"], hard_indices)

# Two epochs with batches of 8 per device; evaluation and checkpointing both
# happen once per epoch.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=2,
    per_device_train_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    learning_rate=5e-5
)

# Accuracy metric used by compute_metrics below.
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Score a set of predictions with the module-level `metric`.

    Args:
        eval_pred: Pair `(logits, labels)`.

    Returns:
        The result of `metric.compute` for the arg-max class of each row of
        logits against `labels`.
    """
    scores, references = eval_pred
    # The predicted class is the position of the largest logit on the last axis.
    predicted_classes = np.argmax(scores, axis=-1)
    accuracy = metric.compute(predictions=predicted_classes, references=references)
    return accuracy

# Fine-tune `model` on the oversampled training set, evaluating on the
# validation split. `model` is the object loaded at the top of the notebook,
# so training updates it in place.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=oversampled_train_dataset,
    eval_dataset=encoded_dataset["validation"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

trainer.train()

# Final evaluation with the trainer's own eval_dataset (the validation split).
results = trainer.evaluate()
print(f"Final Accuracy on the Validation Set: {results['eval_accuracy']:.4f}")

  0%|          | 0/69192 [00:00<?, ?it/s]

{'loss': 0.3767, 'grad_norm': 0.25721991062164307, 'learning_rate': 4.963868655335877e-05, 'epoch': 0.01}
{'loss': 0.3737, 'grad_norm': 32.28351974487305, 'learning_rate': 4.927737310671754e-05, 'epoch': 0.03}
{'loss': 0.3743, 'grad_norm': 9.496427536010742, 'learning_rate': 4.8916059660076315e-05, 'epoch': 0.04}
{'loss': 0.3721, 'grad_norm': 35.85519027709961, 'learning_rate': 4.855474621343508e-05, 'epoch': 0.06}
{'loss': 0.3882, 'grad_norm': 12.20327377319336, 'learning_rate': 4.819343276679385e-05, 'epoch': 0.07}
{'loss': 0.3948, 'grad_norm': 20.38226890563965, 'learning_rate': 4.783211932015262e-05, 'epoch': 0.09}
{'loss': 0.3815, 'grad_norm': 12.41571044921875, 'learning_rate': 4.7470805873511394e-05, 'epoch': 0.1}
{'loss': 0.3973, 'grad_norm': 20.384620666503906, 'learning_rate': 4.710949242687016e-05, 'epoch': 0.12}
{'loss': 0.3772, 'grad_norm': 22.44713020324707, 'learning_rate': 4.6748178980228926e-05, 'epoch': 0.13}
{'loss': 0.393, 'grad_norm': 14.110427856445312, 'learning_

  0%|          | 0/1231 [00:00<?, ?it/s]

{'eval_loss': 0.39658471941947937, 'eval_accuracy': 0.8834586466165414, 'eval_runtime': 51.7517, 'eval_samples_per_second': 190.177, 'eval_steps_per_second': 23.787, 'epoch': 1.0}
{'loss': 0.3326, 'grad_norm': 6.218863487243652, 'learning_rate': 2.4708058735113888e-05, 'epoch': 1.01}
{'loss': 0.354, 'grad_norm': 4.712692737579346, 'learning_rate': 2.4346745288472657e-05, 'epoch': 1.03}
{'loss': 0.3278, 'grad_norm': 37.56756591796875, 'learning_rate': 2.3985431841831427e-05, 'epoch': 1.04}
{'loss': 0.3546, 'grad_norm': 36.81509780883789, 'learning_rate': 2.3624118395190197e-05, 'epoch': 1.06}
{'loss': 0.3216, 'grad_norm': 20.366846084594727, 'learning_rate': 2.3262804948548966e-05, 'epoch': 1.07}
{'loss': 0.3307, 'grad_norm': 0.9261724352836609, 'learning_rate': 2.2901491501907736e-05, 'epoch': 1.08}
{'loss': 0.351, 'grad_norm': 15.74485969543457, 'learning_rate': 2.2540178055266506e-05, 'epoch': 1.1}
{'loss': 0.3668, 'grad_norm': 39.46297073364258, 'learning_rate': 2.2178864608625275e-

  0%|          | 0/1231 [00:00<?, ?it/s]

{'eval_loss': 0.44166556000709534, 'eval_accuracy': 0.8892501524080472, 'eval_runtime': 59.8171, 'eval_samples_per_second': 164.535, 'eval_steps_per_second': 20.579, 'epoch': 2.0}
{'train_runtime': 13238.9109, 'train_samples_per_second': 41.811, 'train_steps_per_second': 5.226, 'train_loss': 0.35926475556878934, 'epoch': 2.0}


  0%|          | 0/1231 [00:00<?, ?it/s]

Final Accuracy on the Validation Set: 0.8893


In [None]:
test_dataset = encoded_dataset["test"]

# Evaluate the first trainer's model on the test split instead of its default
# validation split.
test_results = trainer.evaluate(eval_dataset=test_dataset)
print(f"Final accuracy on test: {test_results['eval_accuracy']:.4f}")

  0%|          | 0/1228 [00:00<?, ?it/s]

Final Accuracy on the Test Set: 0.8881


In [None]:
# create subsets of the training dataset
from torch.utils.data import Subset
from transformers import TrainingArguments

# Cast to built-in ints rather than list(...): list() on a NumPy array keeps
# numpy.int64 elements, and `datasets` rejects those as keys for single-row
# access ("Wrong key type ... numpy.int64").
easy_indices = [int(i) for i in easy_indices]
ambiguous_indices = [int(i) for i in ambiguous_indices]
hard_indices = [int(i) for i in hard_indices]

# take 50% of each subset
easy_dataset = Subset(encoded_dataset["train"], random.sample(easy_indices, len(easy_indices) // 2))
ambiguous_dataset = Subset(encoded_dataset["train"], random.sample(ambiguous_indices, len(ambiguous_indices) // 2))
hard_dataset = Subset(encoded_dataset["train"], random.sample(hard_indices, len(hard_indices) // 2))

# num_train_epochs is not set here; curriculum_learning_train assigns it
# before each phase.
# NOTE(review): output_dir and logging_dir are the same as in `training_args`,
# so both experiments write to the same folders - confirm that is intended.
training_args_2 = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    learning_rate=5e-5
)

def data_collator_2(batch):
    """Stack every field of a list of examples into batch tensors.

    Args:
        batch: Non-empty list of dict-like examples sharing the keys of the
            first one. Every field must be a tensor or convertible by
            `torch.tensor`.

    Returns:
        Dict mapping each key of the first example to the stacked tensor of
        that field across the batch.
    """
    def _as_tensors(values):
        # Only the first value is inspected: if it is already a tensor the
        # whole list is taken as-is, otherwise every value is converted.
        if isinstance(values[0], torch.Tensor):
            return values
        return [torch.tensor(v) for v in values]

    return {
        key: torch.stack(_as_tensors([item[key] for item in batch]))
        for key in batch[0]
    }

# train in stages
def curriculum_learning_train(trainer, phases):
    """Run `trainer.train()` once per phase, in the order given.

    Args:
        trainer: Object with a `train_dataset` attribute, an `args` object
            carrying `num_train_epochs`, and a `train()` method. Both
            attributes are overwritten for each phase and keep the last
            phase's values afterwards.
        phases: Iterable of dicts with a 'dataset' entry (must support `len`)
            and a 'num_epochs' entry.
    """
    i = 0
    for phase in phases:
        i += 1
        print(f"Starting Phase {i} with {len(phase['dataset'])} examples for {phase['num_epochs']} epochs.")
        # Point the shared trainer at this phase's epoch budget and data
        # before launching it.
        trainer.args.num_train_epochs = phase['num_epochs']
        trainer.train_dataset = phase['dataset']
        trainer.train()


def compute_metrics_2(eval_pred):
    """Score predictions for the curriculum-learning trainer.

    The computation is the same as `compute_metrics`, so this delegates to it
    instead of keeping a second copy of the code.

    Args:
        eval_pred: Pair `(logits, labels)`.

    Returns:
        Whatever `compute_metrics` returns for `eval_pred`.
    """
    return compute_metrics(eval_pred)

# NOTE(review): `model` is the same object that was passed to the first
# Trainer and trained by it, so this experiment starts from that state rather
# than from the checkpoint on disk - confirm this is intended.
trainer_2 = Trainer(
    model=model,
    args=training_args_2,
    train_dataset=easy_dataset, 
    eval_dataset=encoded_dataset["validation"],
    data_collator=data_collator_2,
    compute_metrics=compute_metrics_2
)

# Curriculum order: easy first, then ambiguous, then hard.
phases = [
    {"dataset": easy_dataset, "num_epochs": 2}, # easy examples for 2 epochs
    {"dataset": ambiguous_dataset, "num_epochs": 2}, # ambiguous examples for 2 epochs
    {"dataset": hard_dataset, "num_epochs": 3} # hard examples for 3 epochs
]

curriculum_learning_train(trainer_2, phases)

# Evaluate on the validation split once all phases have run.
results_2 = trainer_2.evaluate()

print(f"Final Accuracy on the Validation Set: {results_2['eval_accuracy']:.4f}")


Starting Phase 1 with 245774 examples for 2 epochs.


  0%|          | 0/61444 [00:00<?, ?it/s]

{'loss': 0.1901, 'grad_norm': 0.022779373452067375, 'learning_rate': 4.959312544756201e-05, 'epoch': 0.02}
{'loss': 0.1664, 'grad_norm': 0.22065651416778564, 'learning_rate': 4.918625089512402e-05, 'epoch': 0.03}
{'loss': 0.1653, 'grad_norm': 150.581787109375, 'learning_rate': 4.877937634268602e-05, 'epoch': 0.05}
{'loss': 0.1884, 'grad_norm': 0.08925431221723557, 'learning_rate': 4.837250179024803e-05, 'epoch': 0.07}
{'loss': 0.1884, 'grad_norm': 0.10277227312326431, 'learning_rate': 4.796562723781004e-05, 'epoch': 0.08}
{'loss': 0.1977, 'grad_norm': 60.49163055419922, 'learning_rate': 4.755875268537205e-05, 'epoch': 0.1}
{'loss': 0.1624, 'grad_norm': 65.51551055908203, 'learning_rate': 4.715187813293406e-05, 'epoch': 0.11}
{'loss': 0.1833, 'grad_norm': 0.14181557297706604, 'learning_rate': 4.674500358049607e-05, 'epoch': 0.13}
{'loss': 0.1914, 'grad_norm': 64.29419708251953, 'learning_rate': 4.633812902805807e-05, 'epoch': 0.15}
{'loss': 0.1763, 'grad_norm': 0.4182526767253876, 'lear

  0%|          | 0/1231 [00:00<?, ?it/s]

{'eval_loss': 0.6968719959259033, 'eval_accuracy': 0.8908758382442593, 'eval_runtime': 52.117, 'eval_samples_per_second': 188.844, 'eval_steps_per_second': 23.62, 'epoch': 1.0}
{'loss': 0.1521, 'grad_norm': 0.041422486305236816, 'learning_rate': 2.4773777748844475e-05, 'epoch': 1.01}
{'loss': 0.1122, 'grad_norm': 0.03492776304483414, 'learning_rate': 2.4366903196406483e-05, 'epoch': 1.03}
{'loss': 0.1026, 'grad_norm': 16.27977180480957, 'learning_rate': 2.3960028643968495e-05, 'epoch': 1.04}
{'loss': 0.109, 'grad_norm': 0.11146517843008041, 'learning_rate': 2.35531540915305e-05, 'epoch': 1.06}
{'loss': 0.1032, 'grad_norm': 103.19365692138672, 'learning_rate': 2.3146279539092508e-05, 'epoch': 1.07}
{'loss': 0.0938, 'grad_norm': 55.30762481689453, 'learning_rate': 2.2739404986654516e-05, 'epoch': 1.09}
{'loss': 0.1208, 'grad_norm': 0.020430075004696846, 'learning_rate': 2.233253043421652e-05, 'epoch': 1.11}
{'loss': 0.1103, 'grad_norm': 0.014562027528882027, 'learning_rate': 2.1925655881

  0%|          | 0/1231 [00:00<?, ?it/s]

{'eval_loss': 0.8658453226089478, 'eval_accuracy': 0.8937207884576306, 'eval_runtime': 59.8275, 'eval_samples_per_second': 164.506, 'eval_steps_per_second': 20.576, 'epoch': 2.0}
{'train_runtime': 12790.1592, 'train_samples_per_second': 38.432, 'train_steps_per_second': 4.804, 'train_loss': 0.13611223402607528, 'epoch': 2.0}
Starting Phase 2 with 4705 examples for 2 epochs.


  0%|          | 0/1178 [00:00<?, ?it/s]

{'loss': 1.1396, 'grad_norm': 8.577925682067871, 'learning_rate': 2.8777589134125638e-05, 'epoch': 0.85}


  0%|          | 0/1231 [00:00<?, ?it/s]

{'eval_loss': 0.7110785841941833, 'eval_accuracy': 0.5989636252794147, 'eval_runtime': 55.0249, 'eval_samples_per_second': 178.864, 'eval_steps_per_second': 22.372, 'epoch': 1.0}
{'loss': 0.8931, 'grad_norm': 18.122976303100586, 'learning_rate': 7.5551782682512745e-06, 'epoch': 1.7}


  0%|          | 0/1231 [00:00<?, ?it/s]

{'eval_loss': 0.5401347875595093, 'eval_accuracy': 0.7590936801463117, 'eval_runtime': 51.562, 'eval_samples_per_second': 190.877, 'eval_steps_per_second': 23.874, 'epoch': 2.0}
{'train_runtime': 335.4837, 'train_samples_per_second': 28.049, 'train_steps_per_second': 3.511, 'train_loss': 0.9913968278920509, 'epoch': 2.0}
Starting Phase 3 with 1040 examples for 3 epochs.


  0%|          | 0/390 [00:00<?, ?it/s]

  0%|          | 0/1231 [00:00<?, ?it/s]

{'eval_loss': 0.6359837055206299, 'eval_accuracy': 0.7309489941068888, 'eval_runtime': 51.6632, 'eval_samples_per_second': 190.503, 'eval_steps_per_second': 23.827, 'epoch': 1.0}


  0%|          | 0/1231 [00:00<?, ?it/s]

{'eval_loss': 0.470461368560791, 'eval_accuracy': 0.8090835196098354, 'eval_runtime': 52.0351, 'eval_samples_per_second': 189.142, 'eval_steps_per_second': 23.657, 'epoch': 2.0}


  0%|          | 0/1231 [00:00<?, ?it/s]

{'eval_loss': 0.47851526737213135, 'eval_accuracy': 0.8182280024385288, 'eval_runtime': 49.4559, 'eval_samples_per_second': 199.006, 'eval_steps_per_second': 24.891, 'epoch': 3.0}
{'train_runtime': 227.5855, 'train_samples_per_second': 13.709, 'train_steps_per_second': 1.714, 'train_loss': 0.9826552953475561, 'epoch': 3.0}


  0%|          | 0/1231 [00:00<?, ?it/s]

Final Accuracy on the Validation Set: 0.8182


In [None]:
test_dataset = encoded_dataset["test"]

# Evaluate the model held by trainer_2 on the test split instead of its
# default validation split.
test_results_3 = trainer_2.evaluate(eval_dataset=test_dataset)
print(f"Final Accuracy on the Test Set: {test_results_3['eval_accuracy']:.4f}")


  0%|          | 0/1228 [00:00<?, ?it/s]

Final Accuracy on the Test Set: 0.8196


In [None]:
# Cast every index to a built-in int before the subsets are indexed one row
# at a time by view_examples.
easy_indices = [int(i) for i in easy_indices]
ambiguous_indices = [int(i) for i in ambiguous_indices]
hard_indices = [int(i) for i in hard_indices]

# Rebind the *_dataset names to subsets built from the complete index lists
# (no random sampling here).
easy_dataset = Subset(encoded_dataset["train"], easy_indices)
ambiguous_dataset = Subset(encoded_dataset["train"], ambiguous_indices)
hard_dataset = Subset(encoded_dataset["train"], hard_indices)

def view_examples(dataset, num_examples=3):
    """Print a few randomly chosen examples from `dataset`.

    Args:
        dataset: Indexable collection supporting `len` and integer indexing.
        num_examples: Maximum number of examples to print. If the dataset
            holds fewer, all of them are printed.

    Raises:
        ValueError: From `random.sample` if `num_examples` is negative.
    """
    # Clamp the request so a small bucket prints what it has instead of
    # making random.sample raise "Sample larger than population".
    sample_size = min(num_examples, len(dataset))
    sampled_indices = random.sample(range(len(dataset)), sample_size)
    for i, idx in enumerate(sampled_indices, 1):
        example = dataset[idx]
        print(f"Example {i}: {example}")
    print("\n")

# Show a few random rows from each difficulty bucket, easiest first.
subset_reports = (
    ("Examples from the easy subset:", easy_dataset),
    ("Examples from the ambiguous subset:", ambiguous_dataset),
    ("Examples from the hard subset:", hard_dataset),
)
for heading, subset in subset_reports:
    print(heading)
    view_examples(subset)

Examples from the easy subset:
Example 1: {'premise': 'A person walks on an empty sidewalk wearing a heavy coat.', 'hypothesis': 'A crowd of people walk on the sidewalk.', 'labels': 2, 'input_ids': [101, 1037, 2711, 7365, 2006, 2019, 4064, 11996, 4147, 1037, 3082, 5435, 1012, 102, 1037, 4306, 1997, 2111, 3328, 2006, 1996, 11996, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,