In [None]:
from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification, AutoTokenizer, AdamW
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from datasets import load_dataset
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load SNLI and keep only the training rows whose label is not -1
# (presumably the rows without a gold label -- confirm against the dataset card).
snli_dataset = load_dataset("snli")
full_train_dataset = snli_dataset["train"].filter(
    lambda example: example["label"] != -1
)

# Bag-of-words features for the bias model ("adversary"): each example is the
# premise and the hypothesis joined by one space, turned into token counts.
vectorizer = CountVectorizer(max_features=5000)
premises = full_train_dataset["premise"]
hypotheses = full_train_dataset["hypothesis"]
bow_features = vectorizer.fit_transform(
    ["{} {}".format(*text_pair) for text_pair in zip(premises, hypotheses)]
)

# Fit a logistic regression on those counts; it serves as the bias model.
filtered_labels = full_train_dataset["label"]
adversary = LogisticRegression(max_iter=1000).fit(bow_features, filtered_labels)

# Precompute the bias model's class probabilities for every training row.
# Note these are in-sample predictions: the rows scored here are the same rows
# the model was fitted on.
# shape: (num_samples, num_classes)
bias_predicted_probs = adversary.predict_proba(bow_features)

# Load the main model and its tokenizer from the same local checkpoint folder,
# and move the model to the selected device.
checkpoint_folder = "./New Folder With Items"
main_model = AutoModelForSequenceClassification.from_pretrained(checkpoint_folder)
main_model.to(device)
main_tokenizer = AutoTokenizer.from_pretrained(checkpoint_folder)

# Optimizer and loss function for the main model.
# torch.optim.AdamW is used instead of the AdamW class shipped with
# transformers, which is deprecated in favour of the PyTorch implementation.
# eps and weight_decay are pinned to the values the transformers class used by
# default (1e-6 and 0.0) so the optimisation settings stay the same; PyTorch's
# own defaults are 1e-8 and 0.01.
optimizer = torch.optim.AdamW(
    main_model.parameters(),
    lr=5e-5,
    eps=1e-6,
    weight_decay=0.0,
)
criterion = nn.CrossEntropyLoss()

# tokenize the full dataset with the main model's tokenizer
def main_tokenize_function(example):
    """Encode premise/hypothesis pairs with the main model's tokenizer.

    ``example`` must provide ``"premise"`` and ``"hypothesis"`` entries. They
    are passed to ``main_tokenizer`` as the first and second text argument,
    with truncation enabled and padding to a fixed length of 57.

    Returns whatever ``main_tokenizer`` returns for that call.
    """
    encoding_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 57,
    }
    first_text = example["premise"]
    second_text = example["hypothesis"]
    return main_tokenizer(first_text, second_text, **encoding_options)

# Tokenize the whole training split in batches.
tokenized_full_train = full_train_dataset.map(
    main_tokenize_function,
    batched=True,
)
# Return the model inputs and the label as torch tensors.
tokenized_full_train.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"],
)

# Attach the precomputed bias-model probabilities as an extra column; row i of
# bias_predicted_probs was computed from row i of full_train_dataset.
tokenized_full_train = tokenized_full_train.add_column(
    name="bias_probs",
    column=bias_predicted_probs.tolist(),
)
def adversarial_loss(predictions, labels, bias_probs, penalty_weight=0.5):
    """Return the classification loss plus a weighted MSE term against the bias model.

    Args:
        predictions: Main-model logits. Indexed as ``(batch, num_classes)`` by
            the shape check below.
        labels: Targets handed to the module-level ``criterion``.
        bias_probs: Bias-model class probabilities for the same examples, as a
            tensor or anything ``torch.as_tensor`` accepts (e.g. nested lists).
        penalty_weight: Multiplier applied to the MSE term. The default of 0.5
            is the value that used to be hard-coded.

    Returns:
        A scalar tensor: ``criterion(predictions, labels)`` plus
        ``penalty_weight`` times the mean squared error between
        ``predictions`` and the bias probabilities.
    """
    # as_tensor converts dtype/device without copy-constructing from an
    # existing tensor. torch.tensor(...) on the already-collated batch tensor
    # is what triggered a UserWarning on every step.
    adversary_target = torch.as_tensor(
        bias_probs, dtype=torch.float32, device=predictions.device
    )

    # If the bias model reports more classes than the main model, keep only
    # the leading columns so both tensors have the same width.
    num_classes = predictions.shape[1]
    if adversary_target.shape[1] != num_classes:
        adversary_target = adversary_target[:, :num_classes]

    classification_loss = criterion(predictions, labels)
    # NOTE(review): this compares raw logits with probabilities, and minimising
    # it pulls the logits towards the bias model's output -- confirm that this
    # is the intended direction for the penalty.
    adversarial_penalty = nn.functional.mse_loss(predictions, adversary_target)

    return classification_loss + penalty_weight * adversarial_penalty

# Mini-batches of 16 training examples, reshuffled on every pass.
train_dataloader = DataLoader(tokenized_full_train, batch_size=16, shuffle=True)

# Training loop: one optimizer step per batch on the combined loss returned by
# adversarial_loss; the mean batch loss is printed after every epoch.
num_epochs = 3
for epoch in range(num_epochs):
    main_model.train()
    total_loss = 0

    for batch in train_dataloader:
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)
        # Left as delivered by the DataLoader; adversarial_loss moves it to
        # the device of the logits.
        bias_probs = batch["bias_probs"]

        # Forward pass without `labels`: only the logits are needed here.
        # Passing labels made the model compute its own loss on every step,
        # and that loss was never used.
        outputs = main_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        loss = adversarial_loss(logits, labels, bias_probs)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

# Drop rows whose label is -1, exactly as is done for the training split.
# Left in, such rows are fed to the loss as an invalid class index and always
# count as errors in accuracy, because an argmax prediction is never -1.
test_dataset = snli_dataset["test"].filter(lambda x: x["label"] != -1)

def test_tokenize_function(example):
    """Tokenize test examples exactly like the training examples.

    Delegates to ``main_tokenize_function`` instead of repeating its tokenizer
    settings, so the two preprocessing paths cannot drift apart.

    ``example`` must provide ``"premise"`` and ``"hypothesis"`` entries; the
    return value is whatever ``main_tokenize_function`` returns.
    """
    return main_tokenize_function(example)

# Tokenize the test split in batches, then return the model inputs and the
# label as torch tensors.
tokenized_test_dataset = test_dataset.map(
    test_tokenize_function,
    batched=True,
)
tokenized_test_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"],
)


 # define a function to compute accuracy
def compute_metrics(eval_pred):
    """Compute accuracy from a ``(logits, labels)`` pair.

    The predicted class of each example is the argmax of its logits over the
    last axis. The result is a dict with the single key ``"accuracy"``.
    """
    raw_scores, gold_labels = eval_pred
    predicted_classes = np.argmax(raw_scores, axis=-1)
    return {"accuracy": accuracy_score(gold_labels, predicted_classes)}

# Arguments for the Trainer that is used to evaluate on the test set.
# NOTE(review): newer transformers releases spell `evaluation_strategy` as
# `eval_strategy` -- confirm which one the installed version expects.
main_training_args = TrainingArguments(
    output_dir="./main_model",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    evaluation_strategy="epoch",
    save_strategy="epoch",
)

# Wrap the already fine-tuned model; accuracy is reported via compute_metrics.
trainer_settings = {
    "model": main_model,
    "args": main_training_args,
    "train_dataset": tokenized_full_train,
    "eval_dataset": tokenized_test_dataset,
    "compute_metrics": compute_metrics,
}
main_trainer = Trainer(**trainer_settings)

# Run evaluation on the tokenized test set and show the raw metrics dict.
test_results = main_trainer.evaluate(eval_dataset=tokenized_test_dataset)
print("Evaluation results:", test_results)

# Report accuracy under whichever key is present, preferring "eval_accuracy";
# print a plain error marker when neither key exists.
has_prefixed_accuracy = "eval_accuracy" in test_results
has_plain_accuracy = "accuracy" in test_results
if not (has_prefixed_accuracy or has_plain_accuracy):
    print("Error.")
elif has_prefixed_accuracy:
    print(f"Final accuracy on the test set: {test_results['eval_accuracy']:.4f}")
else:
    print(f"Final accuracy on test set: {test_results['accuracy']:.4f}")

Filter:   0%|          | 0/550152 [00:00<?, ? examples/s]



Map:   0%|          | 0/549367 [00:00<?, ? examples/s]

  adversary_target = torch.tensor(bias_probs).float().to(predictions.device)


Epoch 1/3, Loss: 0.7190
Epoch 2/3, Loss: 0.7069
Epoch 3/3, Loss: 0.6979




  0%|          | 0/1250 [00:00<?, ?it/s]

Evaluation Results: {'eval_loss': 0.5151589512825012, 'eval_model_preparation_time': 0.0028, 'eval_accuracy': 0.8662, 'eval_runtime': 65.804, 'eval_samples_per_second': 151.967, 'eval_steps_per_second': 18.996}
Final Accuracy on the Test Set: 0.8662
