<a href="https://colab.research.google.com/github/junruren/6.7960-2024-Fall/blob/main/6_7960_final_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The Llama model is access-controlled. Speak with Junru for access.

In [None]:
!pip install -q -U bitsandbytes accelerate transformers
!pip install -q boto3
!pip install -q datasets
!pip install -q rouge-score # For ROUGE
!pip install -q nltk # For BLEU
!pip install -q bert-score # For BERTScore

In [None]:
from google.colab import userdata
import boto3
from botocore.exceptions import NoCredentialsError, ClientError
import pandas as pd

from tqdm.auto import tqdm
import dataclasses
from typing import List
import matplotlib.pyplot as plt

from datasets import load_dataset
from rouge_score import rouge_scorer
import nltk
nltk.download('punkt')
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Optionally for BERTScore:
from bert_score import score as bert_score

import json
from datetime import datetime
from zoneinfo import ZoneInfo

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
# Run on the GPU when CUDA is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Checkpoint to load: LLaMA-2-7B (chat variant).
model_name = "meta-llama/Llama-2-7b-chat-hf"

# bitsandbytes configuration: 8-bit loading with llm_int8_threshold=6.0.
quant_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)

# Load the quantized LLM and keep it in eval mode.
model_llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
model_llm.eval()

# The refiner is sized from the LLM's own config rather than hard-coded numbers.
llm_config = model_llm.config
hidden_size = llm_config.hidden_size
vocab_size = llm_config.vocab_size

# Tokenizer that matches the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(model_name)

Using device: cuda


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
#-----------------------------------------
# Get Training Data
# https://huggingface.co/datasets/nvidia/OpenMathInstruct-2
# Download only the 1M training split
# Columns: [problem, generated_solution, expected_answer, problem_source]
#-----------------------------------------

# Stream the split so that only the examples we actually keep are downloaded.
dataset = load_dataset('nvidia/OpenMathInstruct-2', split='train_1M', streaming=True) # IterableDataset obj

TRAIN_DATA_SIZE = 70
VALIDATION_DATA_SIZE = 30
TEST_DATA_SIZE = 30

train_data, validation_data, test_data = [], [], []

# Fill the three splits in stream order: train first, then validation, then
# test. Stop reading as soon as the test split is full.
for example in dataset:
    data = {"question": example["problem"], "answer": example["expected_answer"]}

    if len(train_data) < TRAIN_DATA_SIZE:
        target_split = train_data
    elif len(validation_data) < VALIDATION_DATA_SIZE:
        target_split = validation_data
    else:
        target_split = test_data
    target_split.append(data)

    if len(test_data) >= TEST_DATA_SIZE:
        break

for split in (train_data, validation_data, test_data):
    print(len(split))

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

70000
30000
30000


In [None]:
#-----------------------------------------
# Construct Initial Prompt
# Define the "prompt for problem-solving prompt" P
#-----------------------------------------

def get_instruction_prompt():
    """Base instruction: ask the LLM to write a refined prompt for the question."""
    instruction = "Given the following question, produce a concise and refined prompt to elicit the correct answer to this question from a large language model. The refined prompt should not reveal the solution but should encourage thorough verification and clarity."
    return instruction

def get_context_enhancement_prompt():
    """Context sentence describing the persona the LLM should adopt."""
    context = "You are an intelligent and technologically sophisticated math professor trying to find the right instructions to tell a large language model to explain the correct solution to a math problem."
    return context

def get_instruction_enhancement_prompt():
    """Extra instruction: include the exact question and ask for checked reasoning."""
    enhancement = "In your prompt, incorporate the exact question and add additional instructions that guide the model to reason, carefully check its reasoning, and provide a correct and well-explained answer."
    return enhancement

def get_zero_shot_enhancement_prompt():
    """Instruction to append the "Let's think step by step." phrase."""
    zero_shot = "Add \"Let's think step by step.\" to the end of your prompt."
    return zero_shot

def get_question_suffix():
    """Marker placed at the very end of every prompt, right before the question text."""
    suffix = " Question: "
    return suffix

def get_prompt_1():
    """Prompt 1: base instruction only."""
    sections = [get_instruction_prompt()]
    return "\n".join(sections) + get_question_suffix()

def get_prompt_2():
    """Prompt 2: base instruction + instruction enhancement."""
    sections = [
        get_instruction_prompt(),
        get_instruction_enhancement_prompt(),
    ]
    return "\n".join(sections) + get_question_suffix()

def get_prompt_3():
    """Prompt 3: context + base instruction + instruction enhancement."""
    sections = [
        get_context_enhancement_prompt(),
        get_instruction_prompt(),
        get_instruction_enhancement_prompt(),
    ]
    return "\n".join(sections) + get_question_suffix()

def get_prompt_4():
    """Prompt 4: prompt 3's sections plus the zero-shot enhancement."""
    sections = [
        get_context_enhancement_prompt(),
        get_instruction_prompt(),
        get_instruction_enhancement_prompt(),
        get_zero_shot_enhancement_prompt(),
    ]
    return "\n".join(sections) + get_question_suffix()


In [None]:
#-----------------------------------------
# The Refiner model:
# We'll define a simple model that takes the LLM hidden states
# and outputs a refined prompt as logits over tokens.
#-----------------------------------------
class SimpleRefiner(nn.Module):
    """
    Maps LLM hidden states to logits for a fixed-length refined prompt.

    The hidden states are averaged over the sequence dimension, passed through
    a two-layer MLP, and reshaped into one row of vocabulary logits per output
    position.

    Args:
        hidden_size: width of the incoming hidden states.
        vocab_size: number of logits produced per output position.
        seq_len: number of output positions (length of the refined prompt).
    """

    def __init__(self, hidden_size=4096, vocab_size=32000, seq_len=10):
        super().__init__()
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        # Average-pools the (transposed) sequence axis down to a single step.
        self.pool = nn.AdaptiveAvgPool1d(1)

        # hidden_size -> hidden_size -> one flat vector holding every position's logits.
        flat_logits_width = seq_len * vocab_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, flat_logits_width),
        )

    def forward(self, llm_hidden_states):
        """
        Args:
            llm_hidden_states: tensor laid out as [batch, seq_len, hidden_size].

        Returns:
            Tensor of shape [batch, self.seq_len, self.vocab_size].
        """
        # [batch, seq, hidden] -> [batch, hidden, seq] so the pool runs over seq.
        channels_first = llm_hidden_states.permute(0, 2, 1)
        # Pool output is [batch, hidden, 1]; drop the trailing singleton axis.
        pooled = self.pool(channels_first)[..., 0]

        flat_logits = self.mlp(pooled)  # [batch, seq_len * vocab_size]
        return flat_logits.view(-1, self.seq_len, self.vocab_size)

def text_to_tokens(text):
    """Encode `text` with the module-level tokenizer and move the id tensor to `device`."""
    encoded = tokenizer.encode(text, return_tensors='pt')
    return encoded.to(device)

def tokens_to_text(tokens):
    """Decode token ids with the module-level tokenizer, dropping special tokens."""
    decoded_text = tokenizer.decode(tokens, skip_special_tokens=True)
    return decoded_text

def generate_hidden_states(model, input_ids):
    """
    Run `model` on `input_ids` without tracking gradients and return the last
    entry of the `hidden_states` it reports.
    """
    with torch.no_grad():
        forward_result = model(input_ids, output_hidden_states=True)
    # The final element corresponds to the last layer.
    return forward_result.hidden_states[-1]

In [None]:
def make_final_input_text(refined_prompt_text, question):
    """
    Build the text the LLM sees when answering: the refined prompt, the
    question, and an instruction to put the final answer on its own line
    prefixed with "Answer: ", so the answer can be extracted programmatically.

    Args:
        refined_prompt_text: prompt text produced by the refiner (surrounding
            whitespace is stripped).
        question: the question text.

    Returns:
        (final_input_text, answer_parser), where answer_parser extracts the
        answer that the appended instruction asks the LLM to emit.
    """
    answer_prefix = 'Answer: '

    def answer_parser(output_text: str):
        """
        Extract the text immediately following "Answer: " on the last line of
        output_text that starts with that prefix (leading indentation allowed).

        Return: the extracted answer with surrounding whitespace removed, or
        None if there is no such line or the answer is empty.
        """
        candidate_lines = [line.lstrip() for line in output_text.split('\n')]
        answer_lines = [line for line in candidate_lines if line.startswith(answer_prefix)]
        if not answer_lines:
            return None
        # Slice off only the leading prefix: str.replace would also delete any
        # later "Answer: " occurrences that are part of the answer itself.
        answer = answer_lines[-1][len(answer_prefix):].strip()
        return answer if answer else None

    final_input_text = refined_prompt_text.strip() + ": " + question + " Please present your final answer at the end in a new line prefixed with \"Answer: \""
    return final_input_text, answer_parser

class OutputQualityChecker:
    """
    Counts how many LLM outputs contain an answer that the given parser can
    extract.

    Attributes:
        total_count: number of outputs passed to `check`.
        answer_found_count: number of those for which the parser returned a
            truthy answer.
    """

    def __init__(self):
        self.total_count = 0
        self.answer_found_count = 0

    def check(self, output_text, answer_parser):
        """
        Record one output. It counts as "answer found" only when both
        arguments are truthy and `answer_parser(output_text)` is truthy.
        """
        self.total_count += 1
        if output_text and answer_parser:
            answer = answer_parser(output_text)
            if answer:
                self.answer_found_count += 1

    def report_quality(self):
        """Print the counters and the percentage of outputs with a parseable answer."""
        print(f"Total count: {self.total_count}")
        print(f"Answer found count: {self.answer_found_count}")
        if self.total_count == 0:
            # Nothing was checked yet; avoid dividing by zero.
            print("Quality: n/a (no responses were checked)")
            return
        print(f"Quality: {self.answer_found_count * 100 / self.total_count}% of the responses can parse out the answer as we expected")

In [None]:
@dataclasses.dataclass
class TrainResult:
    """
    Record of one training run: its hyper-parameters plus the loss and
    accuracy curves collected along the way.
    """

    # Hyper-parameters of the run.
    num_epochs: int
    lr: float

    # Training loss values; `train_model_with_experiment` appends the mean
    # loss of each epoch.
    train_losses: List[float]

    # Training-set accuracies (meant to hold one value before training and one
    # after each epoch).
    train_accs: List[float]

    # Validation-set accuracies (one value before training and one after each
    # epoch, when validation data is supplied).
    val_accs: List[float]

def generate_refined_prompt(refiner: SimpleRefiner, question, prompt=None):
    """
    Produce the refiner's prompt text for one question, following the same
    steps as training (LLM hidden states -> refiner -> argmax -> decode).

    Args:
        refiner: the refiner module. It is switched to eval mode as a side
            effect of this call.
        question: question text appended to the meta-prompt.
        prompt: the meta-prompt placed before the question. When None (the
            default) the module-level `P` is used, which is what this function
            always did before the parameter existed; pass it explicitly to
            avoid depending on that global.

    Returns:
        The decoded argmax tokens of the refiner output, with surrounding
        whitespace stripped.
    """
    if prompt is None:
        prompt = P  # backward-compatible fallback to the notebook-level global

    input_ids = text_to_tokens(prompt + question)
    # Last-layer hidden states of the LLM, computed without gradients.
    llm_hidden_states = generate_hidden_states(model_llm, input_ids)

    # Pass through refiner
    refiner.eval()
    with torch.no_grad():
        refiner_logits = refiner(llm_hidden_states)  # [1, seq_len, vocab_size]
        refined_prompt_ids = torch.argmax(refiner_logits, dim=-1)  # [1, seq_len]
    return tokens_to_text(refined_prompt_ids[0]).strip()

def generate_answer(refined_prompt, question):
    """
    Ask the LLM to answer `question` using `refined_prompt` and return the
    extracted answer.

    The input is built with `make_final_input_text`, and up to 50 new tokens
    are generated. Only the newly generated tokens are decoded, so the echoed
    prompt never ends up in the returned text.

    Returns:
        The text after the last "Answer: " line of the completion when the
        parser finds one; otherwise the whole completion, stripped.
    """
    final_input_text, answer_parser = make_final_input_text(refined_prompt, question)
    final_input_ids = text_to_tokens(final_input_text)
    with torch.no_grad():
        gen_ids = model_llm.generate(final_input_ids, max_new_tokens=50)

    # generate() returns the prompt tokens followed by the continuation for a
    # causal LM; slice the prompt off by length instead of matching strings.
    prompt_length = final_input_ids.shape[-1]
    completion = tokenizer.decode(gen_ids[0][prompt_length:], skip_special_tokens=True)

    parsed_answer = answer_parser(completion)
    if parsed_answer is not None:
        return parsed_answer
    # Fallback: no "Answer: " line was produced, so return the raw completion.
    return completion.strip()

def evaluate(refiner: SimpleRefiner, data):
    """
    Exact-match accuracy of the refiner + LLM pipeline on `data`.

    Each item of `data` is a dict with "question" and "answer" keys. A
    prediction counts as correct when it equals the reference answer after
    stripping whitespace and lower-casing both sides.

    Returns:
        Fraction of correct predictions, or 0.0 when `data` is empty.
    """
    num_examples = len(data)
    num_correct = 0
    for item in data:
        question_text = item["question"]
        reference = item["answer"]
        refined_prompt = generate_refined_prompt(refiner, question_text)
        prediction = generate_answer(refined_prompt, question_text)
        matches = prediction.strip().lower() == reference.strip().lower()
        if matches:
            num_correct += 1

    if num_examples > 0:
        return num_correct / num_examples
    return 0.0


def learning_curve(result: TrainResult, *, title: str = 'Learning Curve'):
    r"""
    Plot the training loss, training accuracy, and validation accuracy versus
    epochs taken.

    The loss is drawn on the left y-axis and the accuracies on a twin right
    y-axis fixed to [0, 1]. The loss axis starts at 0 and extends to at least
    3, growing when the recorded losses are larger so the curve is never
    clipped.

    Args:
        result: the training record to plot.
        title: figure title.
    """
    import math

    fig, ax_loss = plt.subplots(figsize=(8, 5))
    ax_loss.set_title(title, fontsize=16)
    ax_loss.set_xlabel('Epoch', fontsize=12)

    # Spread the recorded losses evenly over [0, num_epochs).
    num_losses = len(result.train_losses)
    loss_x = [i / num_losses * result.num_epochs for i in range(num_losses)]
    l_trloss = ax_loss.plot(
        loss_x,
        result.train_losses,
        label='Train loss',
        color='C0',
    )
    # Keep the historical 0..3 window as a minimum, but widen it (with 5%
    # headroom) when losses exceed 3; NaN/inf values are ignored here.
    finite_losses = [loss for loss in result.train_losses if math.isfinite(loss)]
    loss_top = max([3.0] + [1.05 * loss for loss in finite_losses])
    ax_loss.set_ylim(0, loss_top)
    ax_loss.set_ylabel('Train loss', color='C0', fontsize=12)
    ax_loss.tick_params(axis='y', labelcolor='C0')

    ax_acc = ax_loss.twinx()
    l_tracc = ax_acc.plot(result.train_accs, label='Train acc', color='C1', linestyle='--')
    if len(result.val_accs):
        l_valacc = ax_acc.plot(result.val_accs, label='Val acc', color='C1')
    else:
        l_valacc = []
    ax_acc.set_ylim(0, 1)
    ax_acc.set_ylabel('Accuracies', color='C1', fontsize=12)
    ax_acc.tick_params(axis='y', labelcolor='C1')

    # One legend for the curves of both axes.
    lines = l_trloss + l_tracc + l_valacc
    ax_loss.legend(lines, [l.get_label() for l in lines], loc='upper left', fontsize=13)

In [None]:
def train_model_with_experiment(name, P, train_data,
                                num_epochs, lr=1e-4,
                                validation_data=None, # Running validation during each train epoch is very costly
    ):
    """
    Train a fresh SimpleRefiner for one experiment.

    Args:
        name: experiment name, used only in log messages.
        P: the meta-prompt that is placed before each question.
        train_data: list of dicts with "question" and "answer" keys.
        num_epochs: number of passes over `train_data`.
        lr: learning rate for AdamW.
        validation_data: optional list in the same format as `train_data`.
            When given, exact-match accuracy is measured before training and
            after every epoch (slow).

    Returns:
        (refiner, result): the refiner module and a TrainResult holding the
        mean training loss per epoch and any validation accuracies.

    NOTE(review): `evaluate` goes through `generate_refined_prompt`, which reads
    the module-level `P` by default rather than this function's `P` argument.
    The notebook currently assigns the global before calling, so they agree.

    NOTE(review): the refiner output reaches the LLM only through
    argmax -> decode -> re-tokenize (integer token ids). That path is not
    differentiable, so as written `loss.backward()` cannot deliver gradients
    to the refiner's parameters. A differentiable relaxation (e.g. feeding
    soft prompt embeddings) is needed for the refiner to actually learn.
    """
    output_quality_checker = OutputQualityChecker()

    refiner = SimpleRefiner(hidden_size=hidden_size, vocab_size=vocab_size, seq_len=10).to(device).half()
    optimizer = AdamW(refiner.parameters(), lr)

    result: TrainResult = TrainResult(num_epochs, lr, train_losses=[], train_accs=[], val_accs=[])
    #result.train_accs.append(evaluate(refiner, train_data)) # <-- time consuming
    if validation_data:
        result.val_accs.append(evaluate(refiner, validation_data)) # <-- time consuming

    for epoch in tqdm(range(num_epochs)):
        loss_values: List[float] = []
        for example in tqdm(train_data, desc=f'Training @ epoch {epoch}'):
            question = example["question"]
            answer = example["answer"]

            # Step 1: Prepare input for LLM: "P + Q"
            input_ids = text_to_tokens(P + question)

            # Step 2: Run LLM to get hidden states (no gradients through the LLM here)
            llm_hidden_states = generate_hidden_states(model_llm, input_ids)  # [1, seq_len, hidden_size]

            # Step 3: Run refiner. evaluate() leaves it in eval mode, so switch
            # back to train mode on every step.
            refiner.train()
            refiner_logits = refiner(llm_hidden_states)  # [1, seq_len=10, vocab_size]

            # Step 4: Convert refiner_logits to tokens (non-differentiable)
            refined_prompt_ids = torch.argmax(refiner_logits, dim=-1)  # [1, 10]
            refined_prompt_text = tokens_to_text(refined_prompt_ids[0])

            # Step 5: Concatenate refined prompt and question. The question
            # (not the answer) goes into the prompt; the answer is appended
            # only in Step 6 as the target.
            final_input_text, answer_parser = make_final_input_text(refined_prompt_text, question)
            final_input_ids = text_to_tokens(final_input_text)

            # Sanity check: track how often the LLM's output with the refined
            # prompt contains a parseable "Answer: " line.
            with torch.no_grad():
                gen_refined_ids = model_llm.generate(final_input_ids, max_new_tokens=50)
                gen_refined_text = tokenizer.decode(gen_refined_ids[0], skip_special_tokens=True)
                # print("LLM output after refined prompt:\n", gen_refined_text, "\n")
                output_quality_checker.check(gen_refined_text, answer_parser)

            # Step 6: Compute loss with the known answer
            combined_input = final_input_text + " " + answer
            combined_ids = text_to_tokens(combined_input)

            # Compute LM loss (labels cover the whole sequence, prompt included)
            outputs = model_llm(combined_ids, labels=combined_ids)
            loss = outputs.loss
            loss_values.append(loss.item())

            # Step 7: Backprop into refiner
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Mean loss of the epoch; NaN when there was nothing to train on.
        epoch_loss = sum(loss_values) / len(loss_values) if loss_values else float("nan")
        result.train_losses.append(epoch_loss)
        #result.train_accs.append(evaluate(refiner, train_data)) # <-- time consuming
        if validation_data:
            result.val_accs.append(evaluate(refiner, validation_data)) # <-- time consuming
        print(f"Epoch = {epoch:> 2d}    Train loss = {result.train_losses[-1]:.4f}")

    print(f"Training done for experiment {name}.")
    #output_quality_checker.report_quality()

    return (
        refiner, # trained model
        result
    )

# Evaluation of Prompt Refiner Model with Llama2

When evaluating a question-answering (QA) system, we want to measure how closely the model’s generated answers match the ground truth. Here are a few metrics and why they matter:

1. **Exact Match (EM):**  
   *What it measures:*  
   The percentage of answers that match the reference answer exactly, word-for-word.  
   *Significance:*  
   This is a strict metric that gives a clear sense of how often the model gets the answer perfectly right without any extra words or differences in phrasing. It’s easy to understand but can be overly harsh on answers that are correct in meaning but differ slightly in formatting.

2. **F1 Score (Token Overlap):**  
   *What it measures:*  
   The harmonic mean of precision and recall of word overlap between the predicted answer and the reference.  
   *Significance:*  
   F1 captures partial correctness. If the model includes most of the correct information but adds some extra words, it still gets partial credit. This metric balances being lenient and strict, encouraging more complete answers without penalizing minor wording differences as harshly as Exact Match.

3. **ROUGE (Recall-Oriented Understudy for Gisting Evaluation):**  
   *What it measures:*  
   Overlap between the generated answer and the reference answer, measured over n-grams (short sequences of words; ROUGE-1 uses single words) or over the longest common subsequence (ROUGE-L).  
   *Significance:*  
   ROUGE was popularized in summarization tasks, but for QA it can reward answers that capture the key content. It’s less strict than EM and can show how similar the answers are at a lexical level.

4. **BLEU (Bilingual Evaluation Understudy):**  
   *What it measures:*  
   Similar to ROUGE, but originally designed for machine translation. It calculates n-gram overlaps from the perspective of precision.  
   *Significance:*  
   BLEU provides another lens on content similarity. While often used for translation, it can provide a secondary check for QA. However, it’s less common for QA than EM/F1.

5. **BERTScore:**  
   *What it measures:*  
   Uses contextual embeddings (from a model like BERT) to measure semantic similarity between predicted and reference answers.  
   *Significance:*  
   BERTScore goes beyond exact word matches and tries to measure if the model’s answer “means” the same thing as the human answer. This helps when the correct answer can be phrased in many valid ways.

For a QA system, the most common metrics in the research literature are typically Exact Match and F1 (token-level), as they are well-established, simple to interpret, and widely accepted. ROUGE and BLEU can provide additional insights, and BERTScore can be helpful if semantic equivalence matters more than exact wording.

In [None]:
def test(test_name, model_llm, tokenizer, refiner, P, test_data):
    """
    The full test pipeline: for every example, build a refined prompt, let the
    LLM answer, and score the answer against the reference.

    Args:
        test_name: label printed with the results.
        model_llm: the LLM used to generate answers.
        tokenizer: tokenizer used to decode the generated tokens.
        refiner: the refiner module.
        P: the meta-prompt of the experiment.
            NOTE(review): this argument is not forwarded;
            `generate_refined_prompt` reads the module-level `P` by default,
            so callers must keep the two in sync.
        test_data: List of dicts with keys "question" and "answer" e.g. [{"question": "Q?", "answer": "A"}, ...]

    Returns:
        Dict with keys "exact_match", "f1", "rouge1", "rougeL", "bleu" and
        "bert_f1" (all averaged over the examples; all 0.0 for empty input).

    The LLM input is built with `make_final_input_text`, the same format used
    in training and in `generate_answer`. The predicted answer is the parsed
    "Answer: " line when present, otherwise the raw completion.
    """
    from collections import Counter

    def compute_f1(pred, ref):
        # Simple whitespace-token F1 with multiset overlap
        pred_tokens = pred.split()
        ref_tokens = ref.split()
        if len(pred_tokens) == 0 or len(ref_tokens) == 0:
            return 0.0
        # Counter intersection keeps min(count in pred, count in ref) per token
        num_common = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
        if num_common == 0:
            return 0.0
        precision = num_common / len(pred_tokens)
        recall = num_common / len(ref_tokens)
        return 2 * (precision * recall) / (precision + recall)

    num_examples = len(test_data)
    if num_examples == 0:
        # Nothing to score: avoid dividing by zero and skip loading BERTScore.
        print(f"Evaluation Results: {test_name}")
        print("No examples to evaluate.")
        return {
            "exact_match": 0.0,
            "f1": 0.0,
            "rouge1": 0.0,
            "rougeL": 0.0,
            "bleu": 0.0,
            "bert_f1": 0.0
        }

    # Initialize metrics
    exact_matches = 0
    f1_scores = []
    rouge_scorer_obj = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge1_scores = []
    rougeL_scores = []
    bleu_scores = []
    bleu_smoothing = SmoothingFunction().method1  # built once, reused per example
    # For BERTScore (computed in one batch after the loop):
    bert_references = []
    bert_candidates = []

    for example in test_data:
        question = example["question"]
        true_answer = example["answer"]

        # Generate refined prompt
        refined_prompt = generate_refined_prompt(refiner, question)

        # Construct final input in the same format the training loop uses
        final_input_text, answer_parser = make_final_input_text(refined_prompt, question)
        final_input_ids = text_to_tokens(final_input_text)

        # Generate answer from LLM
        with torch.no_grad():
            gen_ids = model_llm.generate(final_input_ids, max_new_tokens=50)

        # generate() returns the prompt tokens followed by the continuation;
        # decode only the continuation so the echoed prompt is never scored.
        prompt_length = final_input_ids.shape[-1]
        completion = tokenizer.decode(gen_ids[0][prompt_length:], skip_special_tokens=True).strip()

        parsed_answer = answer_parser(completion)
        predicted_answer = parsed_answer if parsed_answer is not None else completion

        # Compute Exact Match
        if predicted_answer.strip().lower() == true_answer.strip().lower():
            exact_matches += 1

        # Compute F1
        f1_scores.append(compute_f1(predicted_answer, true_answer))

        # Compute ROUGE
        rouge_result = rouge_scorer_obj.score(true_answer, predicted_answer)
        rouge1_scores.append(rouge_result['rouge1'].fmeasure)
        rougeL_scores.append(rouge_result['rougeL'].fmeasure)

        # Compute BLEU
        # BLEU uses tokenized, reference as a list of lists
        ref_tokens = [true_answer.split()]
        pred_tokens = predicted_answer.split()
        bleu_scores.append(sentence_bleu(ref_tokens, pred_tokens, smoothing_function=bleu_smoothing))

        # For BERTScore:
        bert_references.append(true_answer)
        bert_candidates.append(predicted_answer)

    # Aggregate results
    exact_match_score = exact_matches / num_examples
    avg_f1 = sum(f1_scores) / num_examples
    avg_rouge1 = sum(rouge1_scores) / num_examples
    avg_rougeL = sum(rougeL_scores) / num_examples
    avg_bleu = sum(bleu_scores) / num_examples

    # For BERTScore: only the F1 tensor is used. The precision/recall outputs
    # are discarded under throwaway names so the `P` argument is not clobbered.
    _bert_precision, _bert_recall, bert_f1_per_example = bert_score(bert_candidates, bert_references, lang='en')
    bert_f1 = torch.mean(bert_f1_per_example).item()

    # Print results
    print(f"Evaluation Results: {test_name}")
    print(f"Exact Match: {exact_match_score:.4f}")
    print(f"F1 Score: {avg_f1:.4f}")
    print(f"ROUGE-1: {avg_rouge1:.4f}")
    print(f"ROUGE-L: {avg_rougeL:.4f}")
    print(f"BLEU: {avg_bleu:.4f}")
    print(f"BERTScore-F1: {bert_f1:.4f}")

    return {
        "exact_match": exact_match_score,
        "f1": avg_f1,
        "rouge1": avg_rouge1,
        "rougeL": avg_rougeL,
        "bleu": avg_bleu,
        "bert_f1": bert_f1
    }


In [None]:
def write_dict_to_json_with_timestamp(
    data_dict: dict,
    result_name,
    results_dir: str = "/content/drive/MyDrive/MS/MIT LGO/2024 Fall/6.7960 Deep Learning/6.7960 Final Project/Results",
) -> str:
    """
    Write `data_dict` as indented JSON to `<results_dir>/<result_name>_<timestamp>.json`.

    Args:
        data_dict: JSON-serializable dictionary to save.
        result_name: prefix of the file name.
        results_dir: directory to write into. Defaults to the project's Google
            Drive results folder; the directory must already exist.

    Returns:
        The path of the file that was written.

    Raises:
        OSError: if the file cannot be opened for writing (for example when
            `results_dir` does not exist).
    """
    # Timestamp in US Eastern time (America/New_York), formatted YYYYMMDD_HHMMSS
    eastern_zone = ZoneInfo("America/New_York")
    timestamp = datetime.now(eastern_zone).strftime("%Y%m%d_%H%M%S")

    # Construct the filename
    filename = f"{results_dir}/{result_name}_{timestamp}.json"

    # Write the dictionary to the JSON file
    with open(filename, "w") as f:
        json.dump(data_dict, f, indent=2)

    # Return the filename for reference
    return filename

# Smoke test: write a throwaway JSON file and display the returned path.
write_dict_to_json_with_timestamp({"test":0}, "test")

'/content/drive/MyDrive/MS/MIT LGO/2024 Fall/6.7960 Deep Learning/6.7960 Final Project/Results/test_20241209_220100.json'

In [None]:
# FOUNDATIONAL EXPERIMENT
# Trains a fresh refiner for each epoch budget below and saves both the
# training record and the evaluation metrics as timestamped JSON files.

INITIAL_PROMPT = "Given the following question, produce a concise and refined prompt by incorporating the exact question and adding additional instructions that guide the model to reason step-by-step, carefully check its reasoning, and provide a correct and well-explained answer. The refined prompt should not reveal the solution but should encourage thorough verification and clarity."
epochs = [5, 10, 15, 20]

# `P` must be a notebook-level name: generate_refined_prompt() reads it as a
# global when building the refiner input.
P = INITIAL_PROMPT

for num_epochs in epochs:
    print(f'Running epoch {num_epochs}...')

    experiment_name = f"Foundational-num_epochs_=_{num_epochs}"

    # Train, then persist the training record.
    refiner, train_result = train_model_with_experiment(experiment_name, P, train_data, num_epochs)
    write_dict_to_json_with_timestamp(train_result.__dict__, "train_" + experiment_name)

    # Score on the validation split, then persist the metrics.
    test_result = test(experiment_name, model_llm, tokenizer, refiner, P, validation_data)
    write_dict_to_json_with_timestamp(test_result, "test_" + experiment_name)

Running epoch 1...


  0%|          | 0/1 [00:00<?, ?it/s]

Training @ epoch 0:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch =  0    Train loss = 5.4159
Training done for experiment Foundational - num_epochs = 1.


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Evaluation Results: Foundational - num_epochs = 1
Exact Match: 0.0000
F1 Score: 0.0000
ROUGE-1: 0.0196
ROUGE-L: 0.0196
BLEU: 0.0000
BERTScore-F1: 0.7882
Running epoch 5...


  0%|          | 0/5 [00:00<?, ?it/s]

Training @ epoch 0:   0%|          | 0/3 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Initialize experiments list.
# Later cells append one dict per experiment with "name" and "prompt" keys.
# Re-run this cell before re-running those cells to avoid duplicate entries.
experiments = []

In [None]:
# ZERO-SHOT PROMPT EXPERIMENTS
# Register the four prompt variants, in order, on the shared experiments list.
experiments.extend([
    {"name": "Zero-Shot-Prompt1", "prompt": get_prompt_1()},
    {"name": "Zero-Shot-Prompt2", "prompt": get_prompt_2()},
    {"name": "Zero-Shot-Prompt3", "prompt": get_prompt_3()},
    {"name": "Zero-Shot-Prompt4", "prompt": get_prompt_4()},
])

In [None]:
# RUN EXPERIMENTS
#
# Each entry of `experiments` is a dict with:
#     name   - label used in logs and result file names
#     prompt - the meta-prompt placed before every question

EXPERIMENT_EPOCHS = 1

for experiment in experiments:
    # `P` is bound at notebook level on purpose: generate_refined_prompt()
    # reads it as a global.
    experiment_name, P = experiment["name"], experiment["prompt"]

    print(f'Running experiment {experiment_name}...')

    refiner, train_result = train_model_with_experiment(experiment_name, P, train_data, EXPERIMENT_EPOCHS)

    # NOTE(review): the metrics of this test_data run are only printed; the
    # returned dict is discarded and not written to disk. Confirm whether it
    # should be saved as well.
    test(experiment_name, model_llm, tokenizer, refiner, P, test_data)
    write_dict_to_json_with_timestamp(train_result.__dict__, "train_" + experiment_name)

    # The metrics saved under the "test_" prefix come from validation_data.
    test_result = test(experiment_name, model_llm, tokenizer, refiner, P, validation_data)
    write_dict_to_json_with_timestamp(test_result, "test_" + experiment_name)