<a href="https://colab.research.google.com/github/nathan-barry/ai2-cartography-reimplementation/blob/main/removed_half_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets

In [None]:
# Show which NVIDIA GPU (if any) this runtime has attached.
!nvidia-smi

In [None]:
import numpy as np
import torch
import time
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from transformers import ElectraForSequenceClassification, ElectraTokenizerFast, AdamW
from datasets import load_dataset


In [None]:
# Load all SNLI splits (train / validation / test are used below) and show a summary.
snli_dataset = load_dataset(path="snli")
print(snli_dataset)

In [None]:
# Mount Google Drive: the data-map CSVs and the model checkpoints used by this
# notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import pandas as pd

# Read the data-map subsets (hardest first, then easiest) that are excluded from
# training further down; each CSV must provide an "index" column.
hardest_df, easiest_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/content/drive/MyDrive/data_arrays/hardest_examples.csv',
        '/content/drive/MyDrive/data_arrays/easiest_examples.csv',
    )
)

In [None]:
# Inspect the hardest examples; their "index" column drives the training-set filter.
hardest_df

In [None]:
# Inspect the easiest examples; their "index" column drives the training-set filter.
easiest_df

In [None]:
# Fast tokenizer for the same ELECTRA-small checkpoint the classifier is built from.
checkpoint_name = "google/electra-small-discriminator"
tokenizer = ElectraTokenizerFast.from_pretrained(checkpoint_name)
print(tokenizer)

In [None]:
# Preprocess the dataset
def preprocess_data(batch, indices):
    """Tokenize a batch of premise/hypothesis pairs for `Dataset.map`.

    Meant to be used with `map(..., batched=True, with_indices=True)`.

    Args:
        batch: dict of column lists; reads 'premise', 'hypothesis' and 'label'.
        indices: row positions of this batch within the dataset `map` is
            called on (which may already have been filtered, so these are not
            necessarily positions in the original SNLI split).

    Returns:
        The tokenizer output (pairs truncated/padded to 128 tokens) extended
        with 'labels' (int64 tensor of the batch labels) and 'index'.
    """
    features = tokenizer(
        batch['premise'],
        batch['hypothesis'],
        max_length=128,
        padding='max_length',
        truncation=True,
    )
    features['labels'] = torch.tensor(batch['label'], dtype=torch.long)
    features['index'] = indices
    return features

In [None]:
def remove_unlabeled(example):
  """Filter predicate: keep an example unless its label is -1 (unlabeled)."""
  return not example['label'] == -1

In [None]:
# Row positions (the CSVs' "index" column) selected by the data map for removal.
# NOTE(review): `not_in_indices` receives the row position within the dataset
# `.filter` is called on (the label-filtered train split). This assumes the CSVs
# recorded indices against that same row numbering — TODO confirm.
hardest_indices = set(hardest_df["index"].tolist())
easiest_indices = set(easiest_df["index"].tolist())


def not_in_indices(example, idx):
    """Filter predicate: True when row `idx` is in neither the hardest nor the easiest set.

    `example` is unused; it is accepted because `Dataset.filter(..., with_indices=True)`
    passes both the example and its index.
    """
    is_excluded = idx in hardest_indices or idx in easiest_indices
    return not is_excluded

In [None]:
# Drop the -1 (unlabeled) examples from the train and validation splits.
filtered_train_dataset, filtered_val_dataset = (
    snli_dataset[split_name].filter(remove_unlabeled)
    for split_name in ("train", "validation")
)
print(filtered_train_dataset)

In [None]:
# Remove both the hardest and the easiest data-map examples from the train split.
# Rebuild from the raw split instead of re-filtering `filtered_train_dataset`:
# `filter(..., with_indices=True)` sees row positions, which shift after every
# filter, so re-running this cell on its own output would drop additional rows.
filtered_train_dataset = (
    snli_dataset["train"]
    .filter(remove_unlabeled)
    .filter(not_in_indices, with_indices=True)
)
print(filtered_train_dataset)

In [None]:
# Tokenize both splits in batches; `with_indices=True` supplies each row's
# position to `preprocess_data`, which stores it in the "index" column.
map_kwargs = dict(with_indices=True, batched=True)
train_dataset = filtered_train_dataset.map(preprocess_data, **map_kwargs)
val_dataset = filtered_val_dataset.map(preprocess_data, **map_kwargs)
print(train_dataset)

In [None]:
# Expose only these columns, as PyTorch tensors, when indexing the datasets.
torch_columns = ["input_ids", "attention_mask", "labels", "index"]
for split_ds in (train_dataset, val_dataset):
    split_ds.set_format(type="torch", columns=torch_columns)
print(train_dataset)

In [None]:
# Batch the tensors: training order is reshuffled each epoch; validation order is fixed.
BATCH_SIZE = 32
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Sanity check: list the keys of the first training batch.
first_batch = next(iter(train_dataloader))
print(first_batch.keys())

In [None]:
# Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (CPU and CUDA) before anything stochastic, so any
# random initialization done by `from_pretrained` and the DataLoader shuffle
# order drawn later in training are reproducible.
SEED = 42
torch.manual_seed(SEED)

model = ElectraForSequenceClassification.from_pretrained("google/electra-small-discriminator", num_labels=3).to(device)

# Optimizer: torch's AdamW replaces the deprecated `transformers.AdamW`, which
# newer transformers releases no longer ship. eps and weight_decay are pinned
# to the transformers version's defaults (1e-6, 0.0) so the update rule matches.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

In [None]:
# Training loop
num_epochs = 6
validate_after_batches = 5000


def _evaluate(eval_model, dataloader):
  """Score `eval_model` on `dataloader`; return (mean per-batch loss, accuracy).

  Switches the model to eval mode and runs under `torch.no_grad()`; the caller
  must call `.train()` again before resuming optimization. Counts are
  accumulated with `.item()` so they stay Python numbers, not device tensors.
  """
  eval_model.eval()
  total_loss, total_correct, total_examples = 0.0, 0, 0
  with torch.no_grad():
    for val_batch in dataloader:
      input_ids = val_batch["input_ids"].to(device)
      attention_mask = val_batch["attention_mask"].to(device)
      labels = val_batch["labels"].to(device)
      outputs = eval_model(input_ids, attention_mask=attention_mask, labels=labels)
      total_loss += outputs.loss.item()
      total_correct += (outputs.logits.argmax(dim=-1) == labels).sum().item()
      total_examples += labels.size(0)
  return total_loss / len(dataloader), total_correct / total_examples


for epoch in range(num_epochs):
  start_time = time.time()
  print(f"Epoch {epoch+1}/{num_epochs}")
  model.train()

  for idx, batch in enumerate(train_dataloader):
    # Lightweight progress indicator every 1000 batches.
    if (idx + 1) % 1000 == 0:
      print(idx+1)

    # One optimization step on the current batch.
    optimizer.zero_grad()
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs.loss
    loss.backward()
    optimizer.step()

    # Validate every `validate_after_batches` batches.
    if (idx + 1) % validate_after_batches == 0:
      val_loss, val_acc = _evaluate(model, val_dataloader)
      print(f"After {idx + 1} batches: Val Loss = {val_loss}, Val Accurracy = {val_acc:.4f}, Time = {time.time() - start_time}")
      model.train()  # Switch back to training mode

  # Save this epoch's checkpoint to Google Drive.
  model.save_pretrained(f"/content/drive/MyDrive/model_checkpoints/half_checkpoint_epoch_{epoch+1}")

  end_time = time.time()
  print(f"Time taken for epoch {epoch + 1}: {end_time - start_time:.2f} seconds")



In [None]:
# Drop unlabeled (-1) test pairs before tokenizing, as done for train/validation:
# the 3-way classifier's argmax can never equal -1, so keeping them would count
# every such pair as an error and understate test accuracy.
test_dataset = (
    snli_dataset['test']
    .filter(remove_unlabeled)
    .map(preprocess_data, with_indices=True, batched=True)
)
test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels", "index"])
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [None]:
model.eval()  # Set the model to evaluation mode

total_correct = 0
total_examples = 0

with torch.no_grad():
    for batch in test_dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids, attention_mask=attention_mask).logits
        predictions = logits.argmax(dim=-1)

        # Score only labeled examples: a -1 label has no gold answer, and
        # counting it would turn every such row into a guaranteed miss.
        labeled = labels != -1
        total_correct += (predictions[labeled] == labels[labeled]).sum().item()
        total_examples += labeled.sum().item()

# Guard against an empty (or fully unlabeled) test set instead of dividing by zero.
test_accuracy = total_correct / total_examples if total_examples else float("nan")
print(f"Test accuracy: {test_accuracy:.4f}")


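Test accuracy after 3 epochs: 0.8680 (this figure was presumably computed with the unlabeled `-1` test examples counted as errors, so it likely understates true accuracy slightly; note that the training cell above is configured for 6 epochs).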