In [None]:
# Enable IPython's autoreload extension in mode 2, so modules imported by this
# notebook (such as the local `src` package) are re-imported before each cell
# runs and source edits take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Table: Dataset stats

In [None]:
import sys
sys.path.append('..')

from src import data, models, functional
import torch
import os
import pandas as pd
import json

In [None]:
# ------------------------- run configuration -------------------------
# Edit these values to target a different model or output folder.
model_name = "gptj"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
results_dir = "../results/known_samples/"
# ----------------------------------------------------------------------
# Create the output folder up front; exist_ok=True makes re-running this
# cell harmless when the folder already exists.
os.makedirs(results_dir, exist_ok=True)

In [None]:

# Load the model chosen by `model_name` onto `device`.
# NOTE(review): fp16=True presumably requests half-precision weights — confirm
# against src/models.load_model.
mt = models.load_model(
    model_name,
    device=device,
    fp16=True,
)

In [None]:
# Load the full dataset, then pass it through functional.filter_dataset_samples
# together with the loaded model. Later cells treat the samples that remain in
# `filtered` as the "known" ones.
# NOTE(review): presumably these are samples the model answers correctly —
# confirm against src/functional.
dataset = data.load_dataset()

filter_options = {
    "n_icl_lm": functional.DEFAULT_N_ICL_LM,
    "n_trials": 3,
    "batch_size": 1,
}
filtered = functional.filter_dataset_samples(mt=mt, dataset=dataset, **filter_options)

In [None]:
# Index the full and the filtered datasets by relation name so they can be
# compared relation by relation.
relations_by_name = {r.name: r for r in dataset.relations}
filtered_by_name = {r.name: r for r in filtered.relations}

# For every relation record the number of distinct samples that survived
# filtering ("known"), the number of distinct samples overall ("total"), and
# the subject/object pairs of the surviving samples.
samples_known = {}
for name, relation in relations_by_name.items():
    relation_samples = set(relation.samples)
    filtered_relation = filtered_by_name.get(name)
    # A relation that is absent from `filtered` contributes zero known samples.
    filtered_samples = (
        set(filtered_relation.samples) if filtered_relation is not None else set()
    )
    samples_known[name] = {
        "known": len(filtered_samples),
        "total": len(relation_samples),
        # Set iteration order is arbitrary and may differ between runs, so the
        # samples are sorted to keep the JSON file reproducible and diffable.
        # str() keeps the sort key comparable even if a field is not a string.
        "known_samples": [
            {
                "subject": sample.subject,
                "object": sample.object,
            }
            for sample in sorted(
                filtered_samples,
                key=lambda s: (str(s.subject), str(s.object)),
            )
        ],
    }

# os.path.join copes with `results_dir` with or without a trailing separator,
# so the output path never contains a doubled "/".
with open(os.path.join(results_dir, f"{model_name}.json"), "w") as f:
    json.dump(samples_known, f, indent=4)

In [None]:
# Build one row per relation: its name, its total sample count, and its
# known-sample count stored under a column named after the current model.
df_json = []
for relation_name, counts in samples_known.items():
    df_json.append(
        {
            "relation": relation_name,
            "total": counts["total"],
            model_name: counts["known"],
        }
    )

# Write the table to an Excel file inside `results_dir`, omitting the index.
df = pd.DataFrame(df_json)
excel_path = f"{results_dir}/{model_name}.xlsx"
df.to_excel(excel_path, index=False)

In [None]:
# Show the attribute dictionary of the first dataset entry's `properties`
# object, i.e. the keys that are available for grouping.
vars(dataset[0].properties)

In [None]:
from typing import Literal

# Relation property used to group the counts; change it to one of the other
# listed literals to aggregate along a different property.
property_key: Literal["relation_type", "fn_type", "disambiguating", "symmetric"] = "disambiguating"

# Sum known / total distinct-sample counts over all relations that share the
# same value of `property_key`.
category_wise = {}
for relation_name, relation in relations_by_name.items():
    group_value = vars(relation.properties)[property_key]
    total_count = len(set(relation.samples))
    known_relation = filtered_by_name.get(relation_name)
    # Relations missing from the filtered dataset have no known samples.
    known_count = len(set(known_relation.samples)) if known_relation is not None else 0

    # setdefault creates the zeroed bucket the first time a value is seen.
    bucket = category_wise.setdefault(group_value, {"known": 0, "total": 0})
    bucket["known"] += known_count
    bucket["total"] += total_count

In [None]:
# One row per distinct value of `property_key`, with the summed total count
# and the summed known count under a column named after the current model.
df_json = []
for group_value, counts in category_wise.items():
    df_json.append(
        {
            property_key: group_value,
            "total": counts["total"],
            model_name: counts["known"],
        }
    )

# Leave the frame as the cell's last expression so the notebook renders it.
pd.DataFrame(df_json)