In [1]:
# Import necessary libraries
from openprompt.plms import T5TokenizerWrapper
from datasets import load_from_disk
from openprompt.pipeline_base import PromptDataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer
from openprompt.prompts import ManualTemplate, MixedTemplate
from openprompt import PromptForClassification
from openprompt.data_utils import FewShotSampler
from random import shuffle
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
import torch
from openprompt.prompts import ManualVerbalizer
from openprompt.data_utils import InputExample
from tqdm import tqdm
import json
dataset_path = "/lustre/work/client/users/minhos/cache/datasets/p3_cos_qa"
raw_dataset = load_from_disk(dataset_path)



t5_path = "/lustre/work/client/users/minhos/models_for_supercomputer/t5-base"
model = T5ForConditionalGeneration.from_pretrained(t5_path)
tokenizer = T5Tokenizer.from_pretrained(t5_path)


# Logging setup
log_file = "qa_multiple_choice_id_t5.json"
results = []

label_map = {"A":0, "B":1, "C":2, "D":3 , "E":4}

dataset = {}
for split in ['train', 'validation']:
    dataset[split] = []
    if split == 'train':
        raw_dataset[split] = raw_dataset[split].shuffle(seed=42).select(range(1000))
    else:
        raw_dataset[split] = raw_dataset[split].select(range(500))
    
    for idx, data in enumerate(raw_dataset[split]):
        label_text = data["targets_pretokenized"].strip()
        label_numeric = label_map.get(label_text, -1)
        input_example = InputExample(text_a=data['inputs_pretokenized'], guid=idx, label=label_numeric)
        dataset[split].append(input_example)
print(dataset['train'][0])
print(type(dataset['train'][0]))


sampler = FewShotSampler(num_examples_per_label=30)
fewshot_data = sampler(dataset['train'], seed=42)
def evaluate(prompt_model, dataloader):
    prompt_model.eval()  # Set the model to evaluation mode
    total, correct = 0, 0
    
    with torch.no_grad():
        for inputs in dataloader:
            logits = prompt_model(inputs)
            preds = torch.argmax(logits, dim=-1)
            labels = inputs['label']
            
            total += len(labels)
            correct += (preds == labels).sum().item()
        
    accuracy = correct / total
    return accuracy

# Hyperparameter search ranges
learning_rates = [0.005, 0.001, 0.0005] # 0.0005, 0.001, 0.005
warmup_steps = [10]

for lr in learning_rates:
    for warmup in warmup_steps:

        model = T5ForConditionalGeneration.from_pretrained(t5_path)
        tokenizer = T5Tokenizer.from_pretrained(t5_path)

        template = ManualTemplate(
            tokenizer=tokenizer,
            text='{"placeholder":"text_a"} Which option is correct? {"mask"}',
        )
        verbalizer = ManualVerbalizer(
            tokenizer=tokenizer,
            num_classes=5,
            label_words=[
                ["A", "a", "Option A", "first choice"],
                ["B", "b", "Option B", "second choice"],
                ["C", "c", "Option C", "third choice"],
                ["D", "d", "Option D", "fourth choice"],
                ["E", "e", "Option E", "fifth choice"]
            ]
        )
        wrapped_example = template.wrap_one_example(dataset['train'][0])
        prompt_model = PromptForClassification(
            plm=model,
            template=template,
            verbalizer=verbalizer,
            freeze_plm=False,
        )
        train_dataloader = PromptDataLoader(
            dataset = fewshot_data,
            template=template,
            tokenizer=tokenizer,
            tokenizer_wrapper_class=T5TokenizerWrapper,
            decoder_max_length=3, max_seq_length=480,
            batch_size=5)


        validation_dataloader = PromptDataLoader(
            dataset = dataset["validation"],
            template=template,
            tokenizer=tokenizer,
            tokenizer_wrapper_class=T5TokenizerWrapper,
            decoder_max_length=3, max_seq_length=480,
            batch_size=20
        )

        loss_func = torch.nn.CrossEntropyLoss()
        no_decay = ['bias', 'LayerNorm.weight']
        # it's always good practice to set no decay to biase and LayerNorm parameters
        optimizer_grouped_parameters = [
            {'params': [p for n, p in prompt_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in prompt_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        from tqdm import tqdm
        optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
        prompt_model.train()
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup, num_training_steps=1000)
        for epoch in range(10):
            total_loss = 0
            pbar = tqdm(train_dataloader, desc="Training")
            for step, inputs in enumerate(train_dataloader):
                logits = prompt_model(inputs)
                labels = inputs['label']
                loss = loss_func(logits, labels)
                loss.backward()
                total_loss += loss.item()
                optimizer.step()
                optimizer.zero_grad()
                pbar.set_postfix({"loss": total_loss / (step + 1)})
                if step %100 ==1:
                    print("Epoch {}, average loss: {}".format(epoch+1, total_loss/(step+1)), flush=True)

       
    
            # Validation after each epoch
        val_accuracy = evaluate(prompt_model, validation_dataloader)
        print(f"Validation Accuracy after Epoch {epoch + 1}: {val_accuracy:.4f}")
        # Log results
        result = {
            "learning_rate": lr,
            "warmup_steps": warmup,
            "final_loss": total_loss / (10 * len(train_dataloader)),
            "accuracy": val_accuracy
        }
        results.append(result)
        
        # Save results to JSON
        with open(log_file, "w") as f:
            json.dump(results, f, indent=4)
            
print("Tuning complete. Results saved to", log_file)

  from .autonotebook import tqdm as notebook_tqdm
  return torch.load(checkpoint_file, map_location="cpu")


{
  "guid": 0,
  "label": 3,
  "meta": {},
  "text_a": "Pick the option in line with common sense to answer the question.\nQuestion: It was the only way out of town, the police parked their vehicles and drew their guns to create a what?\nOptions:\n\nA. war\n\nB. sporting goods store\n\nC. military base\n\nD. roadblock\n\nE. fun\n\n",
  "text_b": "",
  "tgt_text": null
}

<class 'openprompt.data_utils.utils.InputExample'>


tokenizing: 150it [00:00, 1147.18it/s]
tokenizing: 500it [00:00, 1422.81it/s]
Training:   0%|          | 0/30 [00:13<?, ?it/s, loss=1.59]

Epoch 1, average loss: 1.5899412631988525


Training:   0%|          | 0/30 [01:24<?, ?it/s, loss=1.67]
Training:   0%|          | 0/30 [01:24<?, ?it/s, loss=1.67]

Training:   0%|          | 0/30 [00:02<?, ?it/s, loss=1.71][A
Training:   0%|          | 0/30 [00:05<?, ?it/s, loss=1.7] [A

Epoch 2, average loss: 1.7012322545051575



Training:   0%|          | 0/30 [00:09<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:11<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:14<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:16<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:18<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:20<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:22<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:23<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:25<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:27<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:29<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:30<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:32<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:34<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:35<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:37<?, ?it/s, loss=

Epoch 3, average loss: 1.6779685020446777


Training:   0%|          | 0/30 [00:40<?, ?it/s, loss=1.68]
Training:   0%|          | 0/30 [00:40<?, ?it/s, loss=1.68]

Training:   0%|          | 0/30 [00:01<?, ?it/s, loss=1.72][A
Training:   0%|          | 0/30 [00:02<?, ?it/s, loss=1.57][A

Epoch 4, average loss: 1.5705686807632446



Training:   0%|          | 0/30 [00:04<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:05<?, ?it/s, loss=1.59][A
Training:   0%|          | 0/30 [00:06<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:08<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:09<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:11<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:12<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:13<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:15<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:17<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:18<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:20<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:21<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:22<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:24<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:25<?, ?it/s, loss=

Epoch 5, average loss: 1.6333264708518982


Training:   0%|          | 0/30 [01:08<?, ?it/s, loss=1.68]
Training:   0%|          | 0/30 [01:08<?, ?it/s, loss=1.68]

Training:   0%|          | 0/30 [00:02<?, ?it/s, loss=1.8][A
Training:   0%|          | 0/30 [00:04<?, ?it/s, loss=1.68][A

Epoch 6, average loss: 1.6783061027526855



Training:   0%|          | 0/30 [00:07<?, ?it/s, loss=1.75][A
Training:   0%|          | 0/30 [00:09<?, ?it/s, loss=1.71][A
Training:   0%|          | 0/30 [00:11<?, ?it/s, loss=1.72][A
Training:   0%|          | 0/30 [00:13<?, ?it/s, loss=1.73][A
Training:   0%|          | 0/30 [00:15<?, ?it/s, loss=1.74][A
Training:   0%|          | 0/30 [00:17<?, ?it/s, loss=1.73][A
Training:   0%|          | 0/30 [00:19<?, ?it/s, loss=1.7] [A
Training:   0%|          | 0/30 [00:20<?, ?it/s, loss=1.71][A
Training:   0%|          | 0/30 [00:22<?, ?it/s, loss=1.7] [A
Training:   0%|          | 0/30 [00:24<?, ?it/s, loss=1.69][A
Training:   0%|          | 0/30 [00:27<?, ?it/s, loss=1.69][A
Training:   0%|          | 0/30 [00:28<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:30<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:32<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:34<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:35<?, ?it/s, loss=

Epoch 7, average loss: 1.4973579049110413


Training:   0%|          | 0/30 [00:45<?, ?it/s, loss=1.67]
Training:   0%|          | 0/30 [00:45<?, ?it/s, loss=1.67]

Training:   0%|          | 0/30 [00:03<?, ?it/s, loss=1.78][A
Training:   0%|          | 0/30 [00:06<?, ?it/s, loss=1.68][A

Epoch 8, average loss: 1.6807405948638916



Training:   0%|          | 0/30 [00:08<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:09<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:10<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:12<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:13<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:14<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:16<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:17<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:19<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:20<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:21<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:23<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:24<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:25<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:27<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:28<?, ?it/s, loss=

Epoch 9, average loss: 1.6087524890899658


Training:   0%|          | 0/30 [00:42<?, ?it/s, loss=1.66]
Training:   0%|          | 0/30 [00:42<?, ?it/s, loss=1.66]

Training:   0%|          | 0/30 [00:01<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:02<?, ?it/s, loss=1.58][A

Epoch 10, average loss: 1.5773894786834717



Training:   0%|          | 0/30 [00:04<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:05<?, ?it/s, loss=1.6] [A
Training:   0%|          | 0/30 [00:06<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:08<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:09<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:10<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:12<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:13<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:15<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:16<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:17<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:19<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:20<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:21<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:23<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:24<?, ?it/s, loss=

Validation Accuracy after Epoch 10: 0.1960


  return torch.load(checkpoint_file, map_location="cpu")
tokenizing: 150it [00:00, 1262.92it/s]
tokenizing: 500it [00:00, 1383.29it/s]
Training:   0%|          | 0/30 [01:38<?, ?it/s, loss=1.7]
Training:   0%|          | 0/30 [00:05<?, ?it/s, loss=1.65]

Epoch 1, average loss: 1.6454914212226868


Training:   0%|          | 0/30 [01:04<?, ?it/s, loss=1.68]
Training:   0%|          | 0/30 [01:04<?, ?it/s, loss=1.68]

Training:   0%|          | 0/30 [00:01<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:03<?, ?it/s, loss=1.6] [A

Epoch 2, average loss: 1.5967612266540527



Training:   0%|          | 0/30 [00:05<?, ?it/s, loss=1.59][A
Training:   0%|          | 0/30 [00:07<?, ?it/s, loss=1.57][A
Training:   0%|          | 0/30 [00:09<?, ?it/s, loss=1.57][A
Training:   0%|          | 0/30 [00:11<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:12<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:14<?, ?it/s, loss=1.6] [A
Training:   0%|          | 0/30 [00:15<?, ?it/s, loss=1.59][A
Training:   0%|          | 0/30 [00:17<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:19<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:21<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:22<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:24<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:25<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:26<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:28<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:30<?, ?it/s, loss=

Epoch 3, average loss: 1.5686908960342407


Training:   0%|          | 0/30 [01:09<?, ?it/s, loss=1.66]
Training:   0%|          | 0/30 [01:09<?, ?it/s, loss=1.66]

Training:   0%|          | 0/30 [00:06<?, ?it/s, loss=1.89][A
Training:   0%|          | 0/30 [00:12<?, ?it/s, loss=1.67][A

Epoch 4, average loss: 1.6714516282081604



Training:   0%|          | 0/30 [00:19<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:25<?, ?it/s, loss=1.58][A
Training:   0%|          | 0/30 [00:31<?, ?it/s, loss=1.6] [A
Training:   0%|          | 0/30 [00:37<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:43<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:49<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:55<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:59<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [01:05<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [01:11<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [01:17<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [01:23<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [01:28<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [01:33<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [01:39<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [01:44<?, ?it/s, loss=

Epoch 5, average loss: 1.5941259264945984


Training:   0%|          | 0/30 [02:21<?, ?it/s, loss=1.66]
Training:   0%|          | 0/30 [02:21<?, ?it/s, loss=1.66]

Training:   0%|          | 0/30 [00:05<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:09<?, ?it/s, loss=1.55][A

Epoch 6, average loss: 1.5457215309143066



Training:   0%|          | 0/30 [00:11<?, ?it/s, loss=1.54][A
Training:   0%|          | 0/30 [00:13<?, ?it/s, loss=1.56][A
Training:   0%|          | 0/30 [00:14<?, ?it/s, loss=1.57][A
Training:   0%|          | 0/30 [00:16<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:17<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:19<?, ?it/s, loss=1.6] [A
Training:   0%|          | 0/30 [00:20<?, ?it/s, loss=1.6][A
Training:   0%|          | 0/30 [00:21<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:23<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:24<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:25<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:27<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:28<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:30<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:31<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:33<?, ?it/s, loss=1

Epoch 7, average loss: 1.6848066449165344


Training:   0%|          | 0/30 [04:19<?, ?it/s, loss=1.7] 
Training:   0%|          | 0/30 [04:19<?, ?it/s, loss=1.7]

Training:   0%|          | 0/30 [00:18<?, ?it/s, loss=1.73][A
Training:   0%|          | 0/30 [00:38<?, ?it/s, loss=1.64][A

Epoch 8, average loss: 1.6356967687606812



Training:   0%|          | 0/30 [00:54<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [01:11<?, ?it/s, loss=1.6] [A
Training:   0%|          | 0/30 [01:28<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [01:44<?, ?it/s, loss=1.69][A
Training:   0%|          | 0/30 [02:00<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [02:17<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [02:33<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [02:47<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [03:02<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [03:17<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [03:34<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [03:50<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [04:07<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [04:23<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [04:39<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [04:54<?, ?it/s, loss=

Epoch 9, average loss: 1.6163816452026367


Training:   0%|          | 0/30 [04:22<?, ?it/s, loss=1.65]
Training:   0%|          | 0/30 [04:22<?, ?it/s, loss=1.65]

Training:   0%|          | 0/30 [00:08<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:16<?, ?it/s, loss=1.59][A

Epoch 10, average loss: 1.5904827117919922



Training:   0%|          | 0/30 [00:24<?, ?it/s, loss=1.56][A
Training:   0%|          | 0/30 [00:31<?, ?it/s, loss=1.6] [A
Training:   0%|          | 0/30 [00:38<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:45<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:52<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:59<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [01:07<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [01:14<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [01:22<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [01:29<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [01:37<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [01:44<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [01:51<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [01:58<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [02:06<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [02:14<?, ?it/s, loss=

Validation Accuracy after Epoch 10: 0.1960


  return torch.load(checkpoint_file, map_location="cpu")
tokenizing: 150it [00:00, 1291.60it/s]
tokenizing: 500it [00:00, 1325.34it/s]
Training:   0%|          | 0/30 [04:21<?, ?it/s, loss=1.65]
Training:   0%|          | 0/30 [00:07<?, ?it/s, loss=1.63]

Epoch 1, average loss: 1.6316540837287903


Training:   0%|          | 0/30 [01:21<?, ?it/s, loss=1.7] 
Training:   0%|          | 0/30 [01:21<?, ?it/s, loss=1.7]

Training:   0%|          | 0/30 [00:02<?, ?it/s, loss=1.78][A
Training:   0%|          | 0/30 [00:05<?, ?it/s, loss=1.59][A

Epoch 2, average loss: 1.5870569944381714



Training:   0%|          | 0/30 [00:08<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:11<?, ?it/s, loss=1.6] [A
Training:   0%|          | 0/30 [00:13<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:16<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:18<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:20<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:23<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:25<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:28<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:30<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:33<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:35<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:37<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:39<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:41<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:42<?, ?it/s, loss=

Epoch 3, average loss: 1.652706503868103


Training:   0%|          | 0/30 [01:05<?, ?it/s, loss=1.68]
Training:   0%|          | 0/30 [01:05<?, ?it/s, loss=1.68]

Training:   0%|          | 0/30 [00:02<?, ?it/s, loss=1.73][A
Training:   0%|          | 0/30 [00:04<?, ?it/s, loss=1.7] [A

Epoch 4, average loss: 1.6951058506965637



Training:   0%|          | 0/30 [00:06<?, ?it/s, loss=1.7][A
Training:   0%|          | 0/30 [00:08<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:11<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:13<?, ?it/s, loss=1.71][A
Training:   0%|          | 0/30 [00:15<?, ?it/s, loss=1.71][A
Training:   0%|          | 0/30 [00:17<?, ?it/s, loss=1.69][A
Training:   0%|          | 0/30 [00:19<?, ?it/s, loss=1.69][A
Training:   0%|          | 0/30 [00:21<?, ?it/s, loss=1.7] [A
Training:   0%|          | 0/30 [00:22<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:25<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:27<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:29<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:31<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:33<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:36<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:38<?, ?it/s, loss=1

Epoch 5, average loss: 1.6203705072402954


Training:   0%|          | 0/30 [00:58<?, ?it/s, loss=1.67]
Training:   0%|          | 0/30 [00:58<?, ?it/s, loss=1.67]

Training:   0%|          | 0/30 [00:01<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:04<?, ?it/s, loss=1.66][A

Epoch 6, average loss: 1.6562522053718567



Training:   0%|          | 0/30 [00:06<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [00:07<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:09<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:12<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:15<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:17<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:18<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:20<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [00:22<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:24<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:25<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:27<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:29<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:33<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:40<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:47<?, ?it/s, loss=

Epoch 7, average loss: 1.695737361907959


Training:   0%|          | 0/30 [02:18<?, ?it/s, loss=1.69]
Training:   0%|          | 0/30 [02:18<?, ?it/s, loss=1.69]

Training:   0%|          | 0/30 [00:03<?, ?it/s, loss=1.8][A
Training:   0%|          | 0/30 [00:07<?, ?it/s, loss=1.69][A

Epoch 8, average loss: 1.6851468086242676



Training:   0%|          | 0/30 [00:10<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:14<?, ?it/s, loss=1.59][A
Training:   0%|          | 0/30 [00:19<?, ?it/s, loss=1.6] [A
Training:   0%|          | 0/30 [00:24<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [00:29<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [00:33<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [00:38<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:42<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [00:46<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:50<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:54<?, ?it/s, loss=1.62][A
Training:   0%|          | 0/30 [00:58<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [01:02<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [01:06<?, ?it/s, loss=1.61][A
Training:   0%|          | 0/30 [01:10<?, ?it/s, loss=1.63][A
Training:   0%|          | 0/30 [01:13<?, ?it/s, loss=

Epoch 9, average loss: 1.5478777885437012


Training:   0%|          | 0/30 [07:25<?, ?it/s, loss=1.65]
Training:   0%|          | 0/30 [07:25<?, ?it/s, loss=1.65]

Training:   0%|          | 0/30 [00:14<?, ?it/s, loss=1.68][A
Training:   0%|          | 0/30 [00:26<?, ?it/s, loss=1.52][A

Epoch 10, average loss: 1.5236400961875916



Training:   0%|          | 0/30 [00:39<?, ?it/s, loss=1.55][A
Training:   0%|          | 0/30 [00:52<?, ?it/s, loss=1.58][A
Training:   0%|          | 0/30 [01:05<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [01:19<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [01:29<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [01:40<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [01:51<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [02:01<?, ?it/s, loss=1.67][A
Training:   0%|          | 0/30 [02:11<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [02:21<?, ?it/s, loss=1.65][A
Training:   0%|          | 0/30 [02:31<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [02:41<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [02:49<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [02:56<?, ?it/s, loss=1.64][A
Training:   0%|          | 0/30 [03:04<?, ?it/s, loss=1.66][A
Training:   0%|          | 0/30 [03:12<?, ?it/s, loss=

Validation Accuracy after Epoch 10: 0.1960
Tuning complete. Results saved to qa_multiple_choice_id_t5.json
