In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import os
import json

In [2]:
def _read_nonblank_lines(path):
    """Return the stripped, non-empty lines of the UTF-8 text file at ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def load_tomi_dataset(
    data_dir="data/tom/ToMi",
    train_file="train.txt",
    trace_file="train.trace"
):
    """Load ToMi stories, questions, answers and per-story trace labels.

    The train file is read as a flat sequence of non-blank lines. Every line
    that contains a tab is treated as a question line of the form
    ``question<TAB>answer[<TAB>extra...]``; all tab-free lines that precede it
    (since the previous question) form the story. Leading line numbers are
    kept verbatim in both the story and the question text.

    NOTE(review): this assumes tabs occur only on question lines, as in
    bAbI-style files — confirm against the actual ToMi dump. The previous
    implementation keyed on the literal prefixes "9 ", "10 " and "11 ", which
    mis-split any story whose question was not on exactly one of those lines
    (e.g. story line 9 of an 11-line sample was taken for a question).

    The i-th question is paired with the i-th non-blank line of the trace
    file; the text after the last comma of that line is stored as "belief".

    Args:
        data_dir: Directory containing both files.
        train_file: Name of the story/question file inside ``data_dir``.
        trace_file: Name of the trace file inside ``data_dir``.

    Returns:
        A list of dicts with keys "story", "question", "answer" and "belief".
        "belief" is None when the trace file has fewer lines than there are
        questions. Story lines after the final question are discarded.

    Raises:
        OSError: If either file cannot be opened.
    """
    trace_lines = _read_nonblank_lines(os.path.join(data_dir, trace_file))
    train_lines = _read_nonblank_lines(os.path.join(data_dir, train_file))

    dataset = []
    story_lines = []
    for line in train_lines:
        if "\t" not in line:
            # Plain narrative sentence: keep accumulating the current story.
            story_lines.append(line)
            continue

        # Question line. Only the first two fields are used; tolerate any
        # number of trailing fields instead of requiring exactly three.
        fields = line.split("\t")
        question, answer = fields[0], fields[1]

        # One trace line per emitted entry, consumed in order.
        trace_idx = len(dataset)
        if trace_idx < len(trace_lines):
            belief = trace_lines[trace_idx].split(",")[-1]
        else:
            belief = None

        dataset.append({
            "story": "\n".join(story_lines),
            "question": question,
            "answer": answer,
            "belief": belief
        })
        story_lines = []
    return dataset

In [3]:
import csv
def load_bigtom_dataset(csv_path="data/tom/BigToM/bigtom.csv"):
    """Expand each row of a semicolon-delimited BigToM CSV into two entries.

    Rows with fewer than 19 columns (including empty rows) are skipped.
    For every remaining row, column 0 is used as the story and column 5 as
    the belief question. Two dicts are emitted per row, in this order:

    * answer from column 8, "belief" True  (labelled in this notebook as the
      case where the character saw the key event);
    * answer from column 11, "belief" False (the character did not see it).

    Args:
        csv_path: Path to the CSV file, read as UTF-8.

    Returns:
        A list of dicts with keys "story", "question", "answer", "belief".
    """
    # (answer column index, belief flag) pairs, in emission order.
    answer_variants = ((8, True), (11, False))

    records = []
    with open(csv_path, newline='', encoding='utf-8') as handle:
        for fields in csv.reader(handle, delimiter=';'):
            # An empty row has length 0, so this also covers blank lines.
            if len(fields) < 19:
                continue

            narrative = fields[0]
            belief_question = fields[5]
            for answer_column, saw_event in answer_variants:
                records.append({
                    "story": narrative,
                    "question": belief_question,
                    "answer": fields[answer_column],
                    "belief": saw_event
                })
    return records

In [6]:
# Single source of truth for the local checkpoint directory, so the model and
# tokenizer are always loaded from the same place.
MODEL_PATH = "./models/mistral-7b"

# output_attentions=False is forwarded to from_pretrained as a config override.
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, output_attentions=False)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Switch to inference mode; kept as the last expression so the cell still
# displays the model summary.
model.eval()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

: 

: 

: 