In [None]:
import os
# Expose all eight GPUs unless the caller already chose a device set.
if "CUDA_VISIBLE_DEVICES" not in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"

In [None]:
%%capture
! pip install inflect
! pip install datasets==2.16.0
! pip install hf_xet

In [None]:
import re
from datetime import datetime, date, timedelta
from collections import defaultdict
import inflect
from datasets import load_dataset

In [None]:
from datasets import load_dataset

# TempReason level-1 questions. The file on the Hub is `test_l1.json`; it is only
# registered under the split name "validation" so later cells can refer to it.
VAL_L1 = "https://huggingface.co/datasets/tonytan48/TempReason/resolve/main/test_l1.json"

ds = load_dataset("json", data_files={"validation": VAL_L1})

# Print the split's summary (its repr), not individual examples.
val = ds["validation"]
print(val)

# Shared inflect engine used by the question formatters below.
inflect_obj = inflect.engine()

In [None]:
# Parses TempReason L1 questions such as "... 10 years and 7 months after Jul, 1699".
# Named groups:
#   num1/unit1 and the optional num2/unit2 -> the offset (years and/or months)
#   direction                              -> "after" or "before"
#   month/year                             -> the base date
#   offset_str/date_str                    -> the original surface text, reused verbatim
# The base date must look like "<letters>, <4-digit year>"; IGNORECASE lets e.g.
# "After" or "Years" match as well.
pattern = re.compile(r'''
    (?P<offset_str>
        (?P<num1>\d+)\s*(?P<unit1>years?|months?)     # first offset
        (?:\s*and\s*(?P<num2>\d+)\s*(?P<unit2>years?|months?))?  # optional second offset
    )
    \s*
    (?P<direction>after|before)\s*
    (?P<date_str>
        (?P<month>[A-Za-z]+),\s*(?P<year>\d{4}) # base date
    )
''', re.IGNORECASE | re.VERBOSE)

In [None]:
def parse_question(q: str):
    """Split a TempReason L1 question into offset, direction and base date.

    Uses the module-level ``pattern`` regex.

    Args:
        q: question text, e.g. "What is the time 10 years and 7 months after Jul, 1699?".

    Returns:
        dict with keys
          - "offset": {"year": int, "month": int}, summed over the matched parts
          - "base_date": ``datetime.date`` on day 1 of the base month
          - "direction": "after" or "before" (lower-cased)
          - "original_offset_str" / "original_date_str": matched text, stripped

    Raises:
        ValueError: if ``pattern`` does not match, or the first three letters of the
            month are not accepted by ``strptime``'s ``%b`` directive.
    """
    match = pattern.search(q)
    if not match:
        raise ValueError(f"Could not parse question: {q}")
    groups = match.groupdict()

    # Add up both (number, unit) pairs; the second pair is optional.
    offset = {"year": 0, "month": 0}
    for idx in (1, 2):
        amount_txt = groups.get(f"num{idx}")
        unit_txt = groups.get(f"unit{idx}")
        if not (amount_txt and unit_txt):
            continue
        amount = int(amount_txt)
        unit_key = unit_txt.lower().rstrip('s')  # "years" -> "year", "months" -> "month"
        if unit_key in offset:
            offset[unit_key] += amount

    # The question has no day, so day 1 is used by convention.
    month_index = datetime.strptime(groups["month"][:3], "%b").month
    base_date = date(int(groups["year"]), month_index, 1)

    return {
        "offset": offset,
        "base_date": base_date,
        "direction": groups["direction"].lower(),
        "original_offset_str": groups["offset_str"].strip(),  # reused verbatim by the "original" formatter
        "original_date_str": groups["date_str"].strip(),
    }

In [None]:
def format_offset_original(offset_info):
    """Return the offset exactly as it was written in the source question."""
    original_text = offset_info["original_offset_str"]
    return original_text

def _total_months(offset_info):
    """Collapse the parsed {"year", "month"} offset into a single month count."""
    offset = offset_info["offset"]
    return offset["year"] * 12 + offset["month"]


def format_offset_total_months_numeric(offset_info):
    """Render the offset as a month count in digits, e.g. "127 months"."""
    total_months = _total_months(offset_info)
    return f"{total_months} {inflect_obj.plural('month', total_months)}"


def format_offset_total_months_word(offset_info):
    """Render the offset as a month count in words, e.g. "one hundred and twenty-seven months".

    inflect inserts commas for values >= 1000 ("one thousand, ..."). They are
    removed, matching how format_base_date_words renders years.
    """
    total_months = _total_months(offset_info)
    month_word = inflect_obj.number_to_words(total_months).replace(",", "")
    return f"{month_word} {inflect_obj.plural('month', total_months)}"

In [None]:
def format_base_date_original(offset_info):
    """Return the base date text exactly as it appeared in the question."""
    original_text = offset_info["original_date_str"]
    return original_text

def format_base_date_iso(offset_info):   # YYYY-MM
    """Render the base date as "YYYY-MM" (month zero-padded, year as-is)."""
    base = offset_info["base_date"]
    year, month = base.year, base.month
    return f"{year}-{month:02d}"

def format_base_date_ordinal(offset_info):   # 'Nth month of YYYY'
    """Render the base date as "<ordinal> month of <year>", e.g. "7th month of 1699"."""
    base = offset_info["base_date"]
    return f"{inflect_obj.ordinal(base.month)} month of {base.year}"

def format_base_date_words(offset_info):  # 'Nth month of YYYY (words)'
    """Render the base date fully in words.

    Example: "seventh month of one thousand six hundred and ninety-nine".

    ``inflect.ordinal`` applied to an int only appends a suffix (7 -> "7th").
    The result is therefore spelled out with ``number_to_words`` ("7th" -> "seventh");
    otherwise this "full words" variant would still contain digits. inflect also
    inserts commas for numbers >= 1000 ("one thousand, six hundred ..."); these
    are removed from the year.
    """
    base = offset_info["base_date"]
    ordinal_month = inflect_obj.number_to_words(inflect_obj.ordinal(base.month))  # e.g. 'second'
    year_words = inflect_obj.number_to_words(base.year).replace(",", "")
    return f"{ordinal_month} month of {year_words}"

In [None]:
# Reload the level-1 file as a single Dataset (not a DatasetDict) for the generation loop.
ds = load_dataset("json", data_files={"validation": VAL_L1})["validation"]
print(f"Loaded {len(ds)} examples.")

In [None]:
# Peek at the first raw example to check its fields before generating variants.
ds[0]

In [None]:
# Name -> formatter for each way of rendering the offset; the name becomes part of
# the variant key and of the output file name.
offset_formatters = dict(
    original=format_offset_original,
    total_months_numeric=format_offset_total_months_numeric,
    total_months_word=format_offset_total_months_word,
)

# Name -> formatter for each way of rendering the base date.
base_date_formatters = dict(
    original=format_base_date_original,
    iso=format_base_date_iso,
    ordinal_month=format_base_date_ordinal,
    full_words=format_base_date_words,
)

In [None]:
# keys will be tuples like ('original', 'iso')
modified_datasets = defaultdict(list)

# process and generate new questions
print("Generating modified datasets")
processed_count = 0
error_count = 0
for example in ds:
    original_question = example['question']
    try:
        parsed_info = parse_question(original_question)
        direction = parsed_info["direction"]

        # Render each base-date form once per example, not once per offset form.
        formatted_dates = {name: func(parsed_info) for name, func in base_date_formatters.items()}

        # Build every variant for this example before publishing any of them. A
        # formatter failing midway must not leave the example in some datasets and
        # not in others; that would misalign the per-variant files.
        variants = []
        for offset_name, offset_func in offset_formatters.items():
            formatted_offset = offset_func(parsed_info)
            for date_name, formatted_base_date in formatted_dates.items():
                new_question = f"What is the time {formatted_offset} {direction} {formatted_base_date} in ISO format(YYYY-MM)?"

                # copy keeps the other fields (e.g. text_answers)
                new_example = example.copy()
                new_example['question'] = new_question
                new_example['original_question'] = original_question
                variants.append(((offset_name, date_name), new_example))

    except ValueError as e:
        print(f"Skipping due to parsing error: {e}")
        error_count += 1
        continue
    except Exception as e:
        print(f"Skipping due to unexpected error: {e} for question: {original_question}")
        error_count += 1
        continue

    for dataset_key, new_example in variants:
        modified_datasets[dataset_key].append(new_example)
    processed_count += 1


expected_variations = len(offset_formatters) * len(base_date_formatters)
print("\nProcessing complete.")
print(f"Successfully processed: {processed_count}")
print(f"Errors/Skipped: {error_count}")
print(f"Generated {len(modified_datasets)} dataset variations (expected {expected_variations}).")

In [None]:
print("Example Generated Questions")

# (variant key, header) for the variants worth eyeballing.
for showcase_key, showcase_header in (
    (('original', 'iso'), "\nFormat: Offset=Original, Base Date=ISO"),
    (('total_months_numeric', 'ordinal_month'), "\nFormat: Offset=Total Months Numeric, Base Date=Ordinal Month"),
    (('total_months_word', 'full_words'), "\nFormat: Offset=Total Months Word, Base Date=Full Words"),
):
    # .get() rather than [] so the defaultdict does not grow an empty entry for a missing key.
    showcase_rows = modified_datasets.get(showcase_key)
    if not showcase_rows:
        continue
    first_row = showcase_rows[0]
    print(showcase_header)
    print(f" Original: {first_row['original_question']}")
    print(f" Modified: {first_row['question']}")
    print(f" Answer: {first_row['text_answers']}")

In [None]:
import json
import os

In [None]:
output_dir = "temp_reason_modified_datasets"
os.makedirs(output_dir, exist_ok=True)

# One JSON Lines file per (offset format, date format) variant.
for variant_key, rows in modified_datasets.items():
    offset_fmt, date_fmt = variant_key
    filepath = os.path.join(output_dir, f"{offset_fmt}_offset_{date_fmt}_date.jsonl")
    print(f"Saving {filepath}")
    with open(filepath, 'w', encoding='utf-8') as out_file:
        for row in rows:
            json.dump(row, out_file)
            out_file.write('\n')
print("Finished saving.")

In [None]:
import os
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, T5ForConditionalGeneration, T5Tokenizer
from huggingface_hub import login
import re
from collections import defaultdict

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
print(f"Number of available GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    # Per-device name plus memory currently allocated / reserved, in GiB.
    for gpu_idx in range(torch.cuda.device_count()):
        print(f"GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
        allocated_gb = torch.cuda.memory_allocated(gpu_idx) / 1024**3
        reserved_gb = torch.cuda.memory_reserved(gpu_idx) / 1024**3
        print(f"Memory allocated: {allocated_gb:.2f} GB")
        print(f"Memory cached: {reserved_gb:.2f} GB")


In [None]:
import torch
from transformers import pipeline
from huggingface_hub import login
# Authenticate with the Hugging Face Hub without embedding a secret in the notebook.
# huggingface_hub picks up HF_TOKEN from the environment by itself; otherwise
# prompt for a token interactively.
# NOTE(review): an access token used to be hardcoded here. Treat it as leaked and revoke it.
if not os.environ.get("HF_TOKEN"):
    login()
model_id = "meta-llama/Llama-3.2-3B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    # No pad token defined: reuse EOS so prompts of different lengths can be padded.
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Decoder-only models must be left-padded when several prompts are generated together.
tokenizer.padding_side = "left"
pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # let accelerate place the weights on the available devices
    tokenizer=tokenizer,
)


In [None]:
@torch.no_grad()
def generate_batch_response_llama(prompts_batch, max_new_tokens=128):
    prompt_messages = [[
        {"role": "system", "content": "You are a Helpful assistant"},
        {"role": "user", "content": "Answer the following question: \n" + prompt + "\nExplain how you arrive at the result briefly. Then, on the next line, output **only** the final date in YYYY-MM format, with no extra words"}] for prompt in prompts_batch]

    outputs = pipe(
      prompt_messages,
      max_new_tokens=max_new_tokens,
      )
    generated_texts = []

    for conv in outputs:
        updated_chat = conv[0]["generated_text"][-1]['content'].split('\n')[-1]
        generated_texts.append(updated_chat)
    return generated_texts


In [None]:
# Directory for per-dataset prediction files (the evaluation step also writes its metrics here).
results_output_dir = "./model_predictions_on_4_datasets_BATCHED"
os.makedirs(results_output_dir, exist_ok=True)

# Generated variant files (from the save cell) to run through the model.
selected_dataset_files = [
    "original_offset_iso_date.jsonl",
    "total_months_numeric_offset_ordinal_month_date.jsonl",
    "total_months_word_offset_full_words_date.jsonl",
    "original_offset_original_date.jsonl",
]

In [None]:
def create_batches(data_list, batch_size_val):
    """Yield consecutive slices of ``data_list`` with at most ``batch_size_val`` items each.

    The last slice is shorter when the length is not a multiple of the batch size.
    """
    yield from (
        data_list[offset:offset + batch_size_val]
        for offset in range(0, len(data_list), batch_size_val)
    )

In [None]:
import torch
# Release PyTorch's cached-but-unused GPU memory before starting inference.
torch.cuda.empty_cache()

In [None]:
generated_datasets_dir = "temp_reason_modified_datasets"
batch_size = 8
max_gen_tokens = 128

for dataset_filename in selected_dataset_files:
    input_filepath = os.path.join(generated_datasets_dir, dataset_filename)
    print(f"\nProcessing dataset: {dataset_filename}")

    with open(input_filepath, 'r', encoding='utf-8') as f:
        current_dataset_items = [json.loads(line) for line in f if line.strip()]
    total_batches = (len(current_dataset_items) + batch_size - 1) // batch_size
    results_for_this_dataset = []

    for batch_idx, batch_items in enumerate(create_batches(current_dataset_items, batch_size)):
        if (batch_idx + 1) % 10 == 0:
            print(f"  Processing Batch {batch_idx + 1}/{total_batches}")

        # Prompt with the REWRITTEN question: it is the only thing that differs between
        # the variant files. 'original_question' is the untouched TempReason text and is
        # identical across all of them, so prompting with it made the comparison meaningless.
        prompts_batch = [item['question'] for item in batch_items]

        try:
            llama_preds_batch = generate_batch_response_llama(prompts_batch, max_new_tokens=max_gen_tokens)
        except Exception as e:
            print(f"    Error Llama on BATCH {batch_idx + 1}: {e}")
            llama_preds_batch = [f"Error Llama: {e}"] * len(prompts_batch)

        for i, original_item in enumerate(batch_items):
            prediction_item = original_item.copy()
            prediction_item['llama_prediction'] = llama_preds_batch[i]
            results_for_this_dataset.append(prediction_item)

    if results_for_this_dataset:
        output_filename = f"predictions_BATCHED_{dataset_filename}"
        output_filepath = os.path.join(results_output_dir, output_filename)

        with open(output_filepath, 'w', encoding='utf-8') as f:
            for res_item in results_for_this_dataset:
                json.dump(res_item, f)
                f.write('\n')

print("\nAll selected datasets processed and results saved.")

In [None]:
# Key under which the inference loop stores each model prediction.
answer_column_name = 'llama_prediction'

In [None]:
import numpy as np
import calendar

# Build a lookup from month names / abbreviations → zero‑padded month number
_month_lookup = {}
for month_idx in range(1, 13):
    month_num_str = f"{month_idx:02d}"
    month_forms = [
        calendar.month_name[month_idx].lower(),
        calendar.month_abbr[month_idx].lower().rstrip('.')
    ]
    for form in month_forms:
      for prefix_len in range(3, len(form) + 1):
            _month_lookup[form[:prefix_len]] = month_num_str

_year_re = re.compile(r"(\d{4})")

def _normalize_text(txt: str) -> str:
    """
    Canonicalise various date strings to ISO 'YYYY-MM' where possible,
    otherwise fallback to lowercased / whitespace-collapsed text.

    Examples:
        "Mar, 1789"   -> "1789-03"
        "march 1789"  -> "1789-03"
        "1789-03-12"  -> "1789-03"
        "1789-03"     -> "1789-03"
    """
    if not txt or not isinstance(txt, str):
        return ""

    s = " ".join(txt.strip().lower().split()).replace("*", "")

    # 1) ISO patterns: YYYY-MM or YYYY-MM-DD
    m_iso = re.match(r"^(?P<year>\d{4})-(?P<month>\d{2})(?:-\d{2})?$", s)
    if m_iso:
        return f"{m_iso.group('year')}-{m_iso.group('month')}"

    # 2) Month name patterns
    month_pattern = "|".join(re.escape(month) for month in _month_lookup.keys())

    # Pattern: YYYY month_name [YYYY]?
    pattern = rf"^(?P<year1>\d{{4}})\s+(?P<month_name>{month_pattern})[\.,]?\s*(?P<year2>\d{{4}})?$"
    m_name = re.match(pattern, s)
    if m_name:
        month_str = m_name.group("month_name")
        year_str = m_name.group("year1")
        month_num = _month_lookup.get(month_str)
        if month_num:
            return f"{year_str}-{month_num}"

    # Pattern: month_name YYYY
    m_month_year = re.match(rf"^(?P<month_name>{month_pattern})[\.,]?\s+(?P<year>\d{{4}})$", s)
    if m_month_year:
        month_num = _month_lookup.get(m_month_year.group("month_name"))
        if month_num:
            return f"{m_month_year.group('year')}-{month_num}"

    # 3) If no conversion matched, return the cleaned text
    return s

# Extract the first year found in the text
def _extract_year(txt: str):
    m = _year_re.search(txt)
    return int(m.group(1)) if m else None


def _reference_year(question: str):
    """
    Extract the YYYY that appears *last* in the question –
    this is the base date in all L1 questions like '... after Jul, 1699'.
    """
    years = _year_re.findall(question)
    return int(years[-1]) if years else None


def _gold_answer(text_answers):
    """Return the first gold answer from a ``text_answers`` field as a string ("" if none).

    Accepts a {"text": [...]} dict, a plain list/tuple, or a bare value.
    """
    if isinstance(text_answers, dict):
        text_answers = text_answers.get("text", [])
    if isinstance(text_answers, (list, tuple)):
        text_answers = text_answers[0] if text_answers else ""
    return "" if text_answers is None else str(text_answers)


def evaluate_predictions(results_for_this_dataset, answer_column_name, dataset_filename, results_output_dir, verbose=True):
    """Score predictions against the gold answers and write the metrics to JSON.

    Metrics:
      - exact_match: fraction of items whose normalised gold answer occurs inside the
        normalised prediction (substring match). An empty gold never counts as a hit.
      - mae_year: mean |predicted year - gold year| over items where both contain a
        4-digit year.
      - trend_accuracy: over those same items, the fraction where prediction and gold
        lie on the same side of the question's reference year.
        NOTE(review): items whose gold year equals the reference year can never count
        as correct but are still in the denominator.

    Args:
        results_for_this_dataset: list of dicts with "text_answers", "original_question"
            and the prediction under ``answer_column_name``.
        answer_column_name: key holding the model prediction.
        dataset_filename: dataset name, used in the log header and the metrics file name.
        results_output_dir: directory that receives ``metrics_<dataset>.json`` (created if missing).
        verbose: print the normalised (prediction, gold) pair for every miss.

    Returns:
        dict of the metrics written to disk.
    """
    g_em, g_abs_err, g_trend_ok, count_year = 0, 0, 0, 0
    total_examples = len(results_for_this_dataset)

    for item in results_for_this_dataset:
        gold = _gold_answer(item["text_answers"])
        pred = item.get(answer_column_name, "")
        if not isinstance(pred, str):
            pred = "" if pred is None else str(pred)
        question = item["original_question"]

        # Exact match (lenient: gold must appear inside the prediction). An empty gold
        # would be a substring of every prediction, so it never counts.
        gold_norm = _normalize_text(gold)
        pred_norm = _normalize_text(pred)
        if gold_norm and gold_norm in pred_norm:
            g_em += 1
        elif verbose:
            print(pred_norm, gold_norm)

        # Year-based metrics
        year_gold = _extract_year(gold)
        year_pred = _extract_year(pred)
        year_ref = _reference_year(question)

        if year_gold is not None and year_pred is not None:
            g_abs_err += abs(year_pred - year_gold)

            # Trend: sign with respect to the reference (base) year
            if year_ref is not None:
                gold_sign = np.sign(year_gold - year_ref)
                pred_sign = np.sign(year_pred - year_ref)
                if gold_sign == pred_sign and gold_sign != 0:
                    g_trend_ok += 1
            count_year += 1

    exact_match = g_em / total_examples if total_examples else 0.0
    mae = g_abs_err / count_year if count_year else 0.0
    trend_accuracy = g_trend_ok / count_year if count_year else 0.0

    print(f"=== Evaluation for {dataset_filename} ===")
    print(f"  Exact Match        : {exact_match:.4f}")
    print(f"  Mean Absolute Error: {mae:.4f}")
    print(f"  Trend Accuracy     : {trend_accuracy:.4f}")

    # Save metrics to JSON file
    os.makedirs(results_output_dir, exist_ok=True)
    metrics_path = os.path.join(results_output_dir,
                                f"metrics_{dataset_filename.replace('.jsonl', '.json')}")
    metrics = {
        "dataset": dataset_filename,
        "num_examples": total_examples,
        "num_with_year": count_year,
        "exact_match": exact_match,
        "mae_year": mae,
        "trend_accuracy": trend_accuracy,
    }

    with open(metrics_path, 'w', encoding='utf-8') as mf:
        json.dump(metrics, mf, indent=2)

    return metrics

In [None]:
# Evaluate EVERY selected dataset from its saved predictions file. Previously this cell
# read the `results_for_this_dataset` / `dataset_filename` left over from the inference
# loop, so only the last dataset was ever scored.
all_metrics = {}
for dataset_filename in selected_dataset_files:
    predictions_path = os.path.join(results_output_dir, f"predictions_BATCHED_{dataset_filename}")
    if not os.path.exists(predictions_path):
        print(f"No predictions found for {dataset_filename} at {predictions_path}; skipping.")
        continue
    with open(predictions_path, 'r', encoding='utf-8') as f:
        results_for_this_dataset = [json.loads(line) for line in f if line.strip()]

    metrics = evaluate_predictions(
        results_for_this_dataset=results_for_this_dataset,
        answer_column_name=answer_column_name,
        dataset_filename=dataset_filename,
        results_output_dir=results_output_dir
    )
    all_metrics[dataset_filename] = metrics
    print(f"\nReturned metrics: {metrics}")