In [None]:
from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score

from transformers import ElectraConfig, ElectraForSequenceClassification, ElectraTokenizer

 # custom config of electra with fewer layers, heads, and a smaller hidden size
# Configuration for a deliberately small ELECTRA classifier: fewer layers and
# heads and a narrower hidden size than the released checkpoints. vocab_size is
# kept at 30522 so token ids from the pre-trained ELECTRA vocabulary stay in
# range, and num_labels=3 gives one output per SNLI class.
custom_config = ElectraConfig(
    hidden_size=64,
    num_attention_heads=2,
    intermediate_size=256,
    num_hidden_layers=6,
    max_position_embeddings=64,
    vocab_size=30522,
    num_labels=3,
)

# Instantiated from the config only (no from_pretrained), so no pre-trained
# weights are loaded into this model.
small_electra_model = ElectraForSequenceClassification(custom_config)

# make a tokenizer from a pre-trained electra model (same vocabulary)
tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")

print(small_electra_model)

checkpoint_folder = "./New Folder With Items"

bias_model = small_electra_model
# NOTE(review): the bias model is fed by the tokenizer stored in
# checkpoint_folder, not by `tokenizer` above. Its ids must stay below
# vocab_size=30522 -- confirm both tokenizers share a vocabulary.
bias_tokenizer = AutoTokenizer.from_pretrained(checkpoint_folder)

# load snli
# SNLI marks pairs that have no gold label with -1. That value is not a valid
# class index for a 3-way classifier, so drop those rows from every split
# before any training or evaluation happens.
snli_dataset = load_dataset("snli").filter(lambda example: example["label"] != -1)
full_train_dataset = snli_dataset["train"]

# tokenize training set
def bias_tokenize_function(example):
    """Encode premise/hypothesis pairs with the bias model's tokenizer.

    Args:
        example: Mapping that provides "premise" and "hypothesis" entries
            (single strings or, when mapped in batches, lists of strings).

    Returns:
        The tokenizer output for the pair(s), truncated and padded to a
        fixed length of 57 tokens.
    """
    premise = example["premise"]
    hypothesis = example["hypothesis"]
    return bias_tokenizer(
        premise,
        hypothesis,
        max_length=57,
        padding="max_length",
        truncation=True,
    )

# Tokenize the whole training split. With batched=True the mapping function
# receives a batch of rows per call rather than one row at a time.
tokenized_full_train = full_train_dataset.map(bias_tokenize_function, batched=True)

# define training args
# Evaluate and save a checkpoint once per epoch, for 3 epochs, writing to
# ./bias_model. Every option not listed keeps the TrainingArguments default.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` -- confirm against the installed version.
bias_training_args = TrainingArguments(
    output_dir="./bias_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=16, 
)

# handle non-contiguous tensors
class ContiguousTrainer(Trainer):
    """Trainer that makes every model parameter contiguous before saving."""

    def save_model(self, output_dir=None, **kwargs):
        """Rewrite non-contiguous parameters in place, then save as usual.

        Args:
            output_dir: Forwarded unchanged to ``Trainer.save_model``.
            **kwargs: Forwarded unchanged to ``Trainer.save_model``.
        """
        strided = (
            weight
            for weight in self.model.parameters()
            if not weight.is_contiguous()
        )
        for weight in strided:
            # Swap the underlying storage for a contiguous copy.
            weight.data = weight.data.contiguous()
        super().save_model(output_dir, **kwargs)

# trainer for bias model with contiguous saving
bias_trainer = ContiguousTrainer(
    model=bias_model,
    args=bias_training_args,
    train_dataset=tokenized_full_train,
    # The training split is also passed as the evaluation set, so the
    # per-epoch eval metrics are measured on data used for training.
    eval_dataset=tokenized_full_train
)

bias_trainer.train()

# get predictions on train
bias_predictions = bias_trainer.predict(tokenized_full_train).predictions
# Index of the largest score along axis 1 is taken as the predicted class id
# (assumes predictions is a 2-D (examples, classes) array -- TODO confirm).
predicted_labels = np.argmax(bias_predictions, axis=1)
# Element-wise gold label minus predicted class id; 0 wherever they agree.
# NOTE(review): these differences can be negative, yet main_tokenize_function
# stores them under "labels" for a sequence classifier -- confirm this is the
# intended training target.
residuals = full_train_dataset["label"] - predicted_labels

# Same path as the earlier checkpoint_folder assignment in this file.
checkpoint_folder = "./New Folder With Items"
# The main model and its tokenizer are both restored from that folder.
main_model = AutoModelForSequenceClassification.from_pretrained(checkpoint_folder)
main_tokenizer = AutoTokenizer.from_pretrained(checkpoint_folder)

def main_tokenize_function(example, index):
    """Encode one premise/hypothesis pair and attach its residual as the label.

    Args:
        example: Mapping that provides "premise" and "hypothesis" entries.
        index: Position used to look up the matching entry in ``residuals``.

    Returns:
        The tokenizer output (truncated and padded to 57 tokens) with an
        added "labels" entry holding ``residuals[index]``.
    """
    encoded = main_tokenizer(
        example["premise"],
        example["hypothesis"],
        max_length=57,
        padding="max_length",
        truncation=True,
    )
    encoded["labels"] = residuals[index]
    return encoded

# apply tokenization and residual labels for residuals dataset
# with_indices=True passes each row's index as the second argument, which
# main_tokenize_function uses to look up that row's residual. batched is not
# set here, unlike the other map() calls in this file.
tokenized_residuals = full_train_dataset.map(main_tokenize_function, with_indices=True)

# Expose only these three columns, as torch tensors.
# NOTE(review): token_type_ids (if the tokenizer emits them) is not in this
# list -- confirm that dropping it is intentional for sentence-pair input.
tokenized_residuals.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

# Same schedule as bias_training_args, with a separate output directory.
main_training_args = TrainingArguments(
    output_dir="./main_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=16, 
)

# Held-out SNLI test split used for the final evaluation.
test_dataset = snli_dataset["test"]

def test_tokenize_function(example):
    """Encode premise/hypothesis pairs with the main model's tokenizer.

    Args:
        example: Mapping that provides "premise" and "hypothesis" entries
            (single strings or, when mapped in batches, lists of strings).

    Returns:
        The tokenizer output for the pair(s), truncated and padded to a
        fixed length of 57 tokens.
    """
    premise = example["premise"]
    hypothesis = example["hypothesis"]
    return main_tokenizer(
        premise,
        hypothesis,
        max_length=57,
        padding="max_length",
        truncation=True,
    )

# apply tokenization to test set
# Tokenize the test split in batches with the main model's tokenizer.
tokenized_test_dataset = test_dataset.map(test_tokenize_function, batched=True)

# Expose only these columns, as torch tensors. The gold column here is "label"
# (singular), whereas the residual dataset stores its targets under "labels".
tokenized_test_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

# define a function to compute accuracy
def compute_metrics(eval_pred):
    """Compute accuracy from a (logits, labels) pair.

    Args:
        eval_pred: Two-item iterable that unpacks into the model scores and
            the reference labels.

    Returns:
        A dict with one key, "accuracy", holding the result of
        ``accuracy_score`` for the arg-max predictions.
    """
    scores, references = eval_pred
    # The highest score along the last axis is the predicted class.
    guesses = np.argmax(scores, axis=-1)
    return {"accuracy": accuracy_score(references, guesses)}

# Trainer for the main model; compute_metrics adds accuracy to whatever
# evaluate() reports. The residual dataset is registered as both the train and
# the default eval set, but no train() call is made in this section.
main_trainer = Trainer(
    compute_metrics=compute_metrics,
    eval_dataset=tokenized_residuals,
    train_dataset=tokenized_residuals,
    args=main_training_args,
    model=main_model,
)

# Score the model on the tokenized test split (overrides the default eval set).
test_results = main_trainer.evaluate(eval_dataset=tokenized_test_dataset)

# Dump the full metrics dict so the available keys are visible.
print("Evaluation results:", test_results)

# Report accuracy, preferring the "eval_"-prefixed key over the bare one.
if not ("eval_accuracy" in test_results or "accuracy" in test_results):
    print("Error.")
elif "eval_accuracy" in test_results:
    print(f"Final accuracy on the test set: {test_results['eval_accuracy']:.4f}")
else:
    print(f"Final accuracy on test set: {test_results['accuracy']:.4f}")



ElectraForSequenceClassification(
  (electra): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(64, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (embeddings_project): Linear(in_features=128, out_features=64, bias=True)
    (encoder): ElectraEncoder(
      (layer): ModuleList(
        (0-5): 6 x ElectraLayer(
          (attention): ElectraAttention(
            (self): ElectraSelfAttention(
              (query): Linear(in_features=64, out_features=64, bias=True)
              (key): Linear(in_features=64, out_features=64, bias=True)
              (value): Linear(in_features=64, out_features=64, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ElectraSelfOutput(
              (dense): Linear(in_fe

Map:   0%|          | 0/550152 [00:00<?, ? examples/s]



  0%|          | 0/103155 [00:00<?, ?it/s]

{'loss': 1.0897, 'grad_norm': 0.929936408996582, 'learning_rate': 4.97576462604818e-05, 'epoch': 0.01}
{'loss': 1.0378, 'grad_norm': 2.1871800422668457, 'learning_rate': 4.9515292520963604e-05, 'epoch': 0.03}
{'loss': 1.0109, 'grad_norm': 4.704246520996094, 'learning_rate': 4.92729387814454e-05, 'epoch': 0.04}
{'loss': 0.9949, 'grad_norm': 6.175472259521484, 'learning_rate': 4.90305850419272e-05, 'epoch': 0.06}
{'loss': 0.9745, 'grad_norm': 11.171287536621094, 'learning_rate': 4.8788231302409e-05, 'epoch': 0.07}
{'loss': 0.9751, 'grad_norm': 7.448033332824707, 'learning_rate': 4.8545877562890794e-05, 'epoch': 0.09}
{'loss': 0.9606, 'grad_norm': 3.906587600708008, 'learning_rate': 4.83035238233726e-05, 'epoch': 0.1}
{'loss': 0.9662, 'grad_norm': 9.827710151672363, 'learning_rate': 4.8061170083854396e-05, 'epoch': 0.12}
{'loss': 0.9336, 'grad_norm': 5.347815036773682, 'learning_rate': 4.78188163443362e-05, 'epoch': 0.13}
{'loss': 0.9208, 'grad_norm': 5.368433475494385, 'learning_rate': 4

  0%|          | 0/68769 [00:00<?, ?it/s]

{'eval_loss': 0.7322214841842651, 'eval_runtime': 6976.6707, 'eval_samples_per_second': 78.856, 'eval_steps_per_second': 9.857, 'epoch': 1.0}
{'loss': 0.7686, 'grad_norm': 7.144468307495117, 'learning_rate': 3.3277591973244146e-05, 'epoch': 1.0}
{'loss': 0.7512, 'grad_norm': 5.24981689453125, 'learning_rate': 3.3035238233725944e-05, 'epoch': 1.02}
{'loss': 0.7625, 'grad_norm': 3.674412250518799, 'learning_rate': 3.279288449420775e-05, 'epoch': 1.03}
{'loss': 0.7627, 'grad_norm': 10.069609642028809, 'learning_rate': 3.2550530754689546e-05, 'epoch': 1.05}
{'loss': 0.756, 'grad_norm': 10.032224655151367, 'learning_rate': 3.230817701517135e-05, 'epoch': 1.06}
{'loss': 0.7584, 'grad_norm': 7.1385817527771, 'learning_rate': 3.206582327565315e-05, 'epoch': 1.08}
{'loss': 0.746, 'grad_norm': 6.116687297821045, 'learning_rate': 3.1823469536134945e-05, 'epoch': 1.09}
{'loss': 0.7562, 'grad_norm': 6.517622947692871, 'learning_rate': 3.158111579661674e-05, 'epoch': 1.11}
{'loss': 0.7554, 'grad_nor

  0%|          | 0/68769 [00:00<?, ?it/s]

{'eval_loss': 0.6695483922958374, 'eval_runtime': 6326.6387, 'eval_samples_per_second': 86.958, 'eval_steps_per_second': 10.87, 'epoch': 2.0}
{'loss': 0.7016, 'grad_norm': 10.1483736038208, 'learning_rate': 1.6555183946488294e-05, 'epoch': 2.01}
{'loss': 0.7018, 'grad_norm': 8.871072769165039, 'learning_rate': 1.6312830206970095e-05, 'epoch': 2.02}
{'loss': 0.6949, 'grad_norm': 6.341709136962891, 'learning_rate': 1.6070476467451895e-05, 'epoch': 2.04}
{'loss': 0.6953, 'grad_norm': 7.744751453399658, 'learning_rate': 1.5828122727933693e-05, 'epoch': 2.05}
{'loss': 0.6901, 'grad_norm': 9.605010032653809, 'learning_rate': 1.558576898841549e-05, 'epoch': 2.06}
{'loss': 0.6817, 'grad_norm': 8.296448707580566, 'learning_rate': 1.534341524889729e-05, 'epoch': 2.08}
{'loss': 0.6981, 'grad_norm': 9.939855575561523, 'learning_rate': 1.510106150937909e-05, 'epoch': 2.09}
{'loss': 0.702, 'grad_norm': 8.375449180603027, 'learning_rate': 1.4858707769860888e-05, 'epoch': 2.11}
{'loss': 0.6896, 'grad_

  0%|          | 0/68769 [00:00<?, ?it/s]

{'eval_loss': 0.6456310153007507, 'eval_runtime': 4209.0155, 'eval_samples_per_second': 130.708, 'eval_steps_per_second': 16.339, 'epoch': 3.0}
{'train_runtime': 65574.5397, 'train_samples_per_second': 25.169, 'train_steps_per_second': 1.573, 'train_loss': 0.7528186898127268, 'epoch': 3.0}


  0%|          | 0/68769 [00:00<?, ?it/s]

Map:   0%|          | 0/550152 [00:00<?, ? examples/s]



  0%|          | 0/103155 [00:00<?, ?it/s]

{'loss': 0.4936, 'grad_norm': 2.8753273487091064, 'learning_rate': 4.97576462604818e-05, 'epoch': 0.01}
{'loss': 0.4351, 'grad_norm': 5.618480682373047, 'learning_rate': 4.9515292520963604e-05, 'epoch': 0.03}
{'loss': 0.4334, 'grad_norm': 3.557499885559082, 'learning_rate': 4.92729387814454e-05, 'epoch': 0.04}
{'loss': 0.4162, 'grad_norm': 5.431368350982666, 'learning_rate': 4.90305850419272e-05, 'epoch': 0.06}
{'loss': 0.4095, 'grad_norm': 4.089376449584961, 'learning_rate': 4.8788231302409e-05, 'epoch': 0.07}
{'loss': 0.3998, 'grad_norm': 6.50526762008667, 'learning_rate': 4.8545877562890794e-05, 'epoch': 0.09}
{'loss': 0.409, 'grad_norm': 4.539760589599609, 'learning_rate': 4.83035238233726e-05, 'epoch': 0.1}
{'loss': 0.4004, 'grad_norm': 16.888830184936523, 'learning_rate': 4.8061170083854396e-05, 'epoch': 0.12}
{'loss': 0.3869, 'grad_norm': 6.269852638244629, 'learning_rate': 4.78188163443362e-05, 'epoch': 0.13}
{'loss': 0.3842, 'grad_norm': 5.592806816101074, 'learning_rate': 4.7

  0%|          | 0/68769 [00:00<?, ?it/s]

{'eval_loss': 0.2930378317832947, 'eval_runtime': 1854.236, 'eval_samples_per_second': 296.7, 'eval_steps_per_second': 37.088, 'epoch': 1.0}
{'loss': 0.3286, 'grad_norm': 9.273524284362793, 'learning_rate': 3.3277591973244146e-05, 'epoch': 1.0}
{'loss': 0.3128, 'grad_norm': 10.432160377502441, 'learning_rate': 3.3035238233725944e-05, 'epoch': 1.02}
{'loss': 0.2992, 'grad_norm': 7.834573745727539, 'learning_rate': 3.279288449420775e-05, 'epoch': 1.03}
{'loss': 0.3198, 'grad_norm': 3.936295747756958, 'learning_rate': 3.2550530754689546e-05, 'epoch': 1.05}
{'loss': 0.3162, 'grad_norm': 4.256681442260742, 'learning_rate': 3.230817701517135e-05, 'epoch': 1.06}
{'loss': 0.3081, 'grad_norm': 11.89797306060791, 'learning_rate': 3.206582327565315e-05, 'epoch': 1.08}
{'loss': 0.313, 'grad_norm': 5.215085506439209, 'learning_rate': 3.1823469536134945e-05, 'epoch': 1.09}
{'loss': 0.3095, 'grad_norm': 9.46875286102295, 'learning_rate': 3.158111579661674e-05, 'epoch': 1.11}
{'loss': 0.3059, 'grad_no

  0%|          | 0/68769 [00:00<?, ?it/s]

{'eval_loss': 0.2353406399488449, 'eval_runtime': 14060.4365, 'eval_samples_per_second': 39.128, 'eval_steps_per_second': 4.891, 'epoch': 2.0}
{'loss': 0.2776, 'grad_norm': 13.653707504272461, 'learning_rate': 1.6555183946488294e-05, 'epoch': 2.01}
{'loss': 0.2614, 'grad_norm': 5.901821136474609, 'learning_rate': 1.6312830206970095e-05, 'epoch': 2.02}
{'loss': 0.2515, 'grad_norm': 7.043242931365967, 'learning_rate': 1.6070476467451895e-05, 'epoch': 2.04}
{'loss': 0.2627, 'grad_norm': 9.907173156738281, 'learning_rate': 1.5828122727933693e-05, 'epoch': 2.05}
{'loss': 0.2696, 'grad_norm': 9.449553489685059, 'learning_rate': 1.558576898841549e-05, 'epoch': 2.06}
{'loss': 0.2645, 'grad_norm': 10.251653671264648, 'learning_rate': 1.534341524889729e-05, 'epoch': 2.08}
{'loss': 0.2587, 'grad_norm': 6.704514980316162, 'learning_rate': 1.510106150937909e-05, 'epoch': 2.09}
{'loss': 0.2673, 'grad_norm': 6.611218452453613, 'learning_rate': 1.4858707769860888e-05, 'epoch': 2.11}
{'loss': 0.2657, '

  0%|          | 0/68769 [00:00<?, ?it/s]

{'eval_loss': 0.20425069332122803, 'eval_runtime': 1442.6806, 'eval_samples_per_second': 381.34, 'eval_steps_per_second': 47.668, 'epoch': 3.0}
{'train_runtime': 66631.681, 'train_samples_per_second': 24.77, 'train_steps_per_second': 1.548, 'train_loss': 0.30845974834398315, 'epoch': 3.0}


  0%|          | 0/1250 [00:00<?, ?it/s]

Evaluation Results: {'eval_loss': 3.4126429557800293, 'eval_model_preparation_time': 0.0023, 'eval_accuracy': 0.4032, 'eval_runtime': 31.2216, 'eval_samples_per_second': 320.291, 'eval_steps_per_second': 40.036}
Final Accuracy on the Test Set: 0.4032
