In [83]:
import torch

from add_thin.metrics import MMD, lengths_distribution_wasserstein_distance
from add_thin.evaluate_utils import get_task, get_run_data

In [88]:
# Run id and paths used by the evaluation cells below.
RUN_ID = "xg0yfv2c"  # passed to get_run_data together with WANDB_DIR

# Project root; should include the data folder. The trailing slash is kept so
# the value can be used directly as a path prefix.
PROJECT_ROOT = "/raid/ai23mtech11004/TTPflow/TPP_flow_matching-finaldone/"

# W&B directory, derived from PROJECT_ROOT so the two paths cannot drift apart.
WANDB_DIR = f"{PROJECT_ROOT}outputs/wandb/wandb"

# Example run directory below WANDB_DIR:
#   outputs/wandb/wandb/run-20240911_113136-7nn6ltny
# NOTE(review): that directory name ends in "7nn6ltny", which is not RUN_ID —
# confirm which run is meant to be evaluated.


In [89]:
def sample_model(task, tmax, n=4000):
    """
    Unconditionally draw n event sequences from Add Thin.

    Parameters:
    -----------
        task: Object exposing ``model`` (with a ``sample`` method) and ``device``.
        tmax: Value moved to ``task.device`` via ``.to`` and handed to the
            sampler as the ``tmax`` keyword.
        n (int): Number of sequences to draw.

    Returns:
    -----------
        The result of ``to_time_list()`` on the sampled batch; its length is
        asserted to equal n.
    """
    # Sampling is inference only, so gradient tracking is switched off.
    with torch.no_grad():
        draw = task.model.sample
        tmax_on_device = tmax.to(task.device)
        generated = draw(n, tmax=tmax_on_device)
        samples = generated.to_time_list()

    # Guard against the sampler returning fewer (or more) sequences than asked.
    assert len(samples) == n, "not enough samples"
    return samples

In [90]:
# Look up data name, seed and run path for the configured run id
data_name, seed, run_path = get_run_data(RUN_ID, WANDB_DIR)

# Restore the task and its datamodule from that run
task, datamodule = get_task(run_path, density=True, data_root=PROJECT_ROOT)

# Gather the test sequences. The batch is set to the full test set, but every
# batch the loader yields is collected to be safe.
test_sequences = []
for batch in datamodule.test_dataloader():
    test_sequences.extend(batch.to_time_list())

# Sample event sequences from trained model
samples = sample_model(task, datamodule.tmax, n=4000)

# Both metrics take tmax as a plain Python number, so convert it once
tmax_value = datamodule.tmax.detach().cpu().item()

# Evaluate metrics against test dataset (only the first entry of MMD's result is kept)
mmd = MMD(samples, test_sequences, tmax_value)[0]
wasserstein = lengths_distribution_wasserstein_distance(
    samples, test_sequences, tmax_value, datamodule.n_max
)

# Print rounded results for data and seed
print("ADD and Thin density evaluation:")
print("================================")
print(
    f"{data_name} (Seed: {seed}): MMD: {mmd:.3f}, Wasserstein: {wasserstein:.3f}"
)

AttributeError: 'AddThin' object has no attribute 'history'

In [77]:
import os
from omegaconf import OmegaConf

def get_task(path, density=True, data_root="/path/to/data"):
    """
    Restore a trained task and its datamodule from a run directory.

    The run directory must contain ``checkpoints/best.ckpt`` and
    ``config_hydra.yaml``; both are checked before anything is loaded.

    NOTE(review): this function uses ``set_seed``, ``instantiate_datamodule``,
    ``instantiate_model``, ``DensityEstimation``, ``Forecasting`` and ``torch``
    from the notebook namespace; the imports right above this definition only
    provide ``os`` and ``OmegaConf`` — confirm the others are in scope.

    Parameters:
    -----------
        path (str): Path to the model directory.
        density (bool): If True the checkpoint is loaded as ``DensityEstimation``,
            otherwise as ``Forecasting``.
        data_root (str): Path to the data directory; ``config.data.root`` is
            joined onto it.

    Returns:
    -----------
        task (Task): The task object, moved to CUDA when available (CPU
            otherwise) and switched to eval mode.
        datamodule (DataModule): The datamodule object, after ``prepare_data()``.

    Raises:
    -----------
        FileNotFoundError: If the checkpoint or the config file does not exist.
    """
    print(f"Attempting to load task from path: {path}")

    model_path = os.path.join(path, "checkpoints", "best.ckpt")
    config_path = os.path.join(path, "config_hydra.yaml")

    # Fail early, with the exact path in the message, when a file is missing.
    print(f"Checking for model checkpoint at: {model_path}")
    checkpoint_present = os.path.exists(model_path)
    if not checkpoint_present:
        raise FileNotFoundError(f"Checkpoint file not found: {model_path}")

    print(f"Checking for config file at: {config_path}")
    config_present = os.path.exists(config_path)
    if not config_present:
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # Read the run config, resolve its interpolations in place, then seed from it.
    with open(config_path) as stream:
        config = OmegaConf.load(stream)
    OmegaConf.resolve(config)
    set_seed(config)

    # Point the data config below data_root and build the datamodule.
    config.data.root = os.path.join(data_root, config.data.root)
    print(f"Data root set to: {config.data.root}")
    datamodule = instantiate_datamodule(config.data, config.task.name)
    datamodule.prepare_data()

    # Build the model, then restore the task around it from the checkpoint.
    model = instantiate_model(config.model, datamodule)
    try:
        task_cls = DensityEstimation if density else Forecasting
        task = task_cls.load_from_checkpoint(model_path, model=model)
    except Exception as e:
        # Report the failure, then let the original exception propagate.
        print(f"Error loading checkpoint: {str(e)}")
        raise

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    task.to(device)
    task.eval()

    return task, datamodule

# Example usage: try to load one specific run and report how it went.
# NOTE(review): failures are printed but not re-raised, so after an error
# `task` and `datamodule` are not assigned by this cell.
try:
    PROJECT_ROOT = "/raid/ai23mtech11004/TTPflow/TPP_flow_matching-finaldone/"
    run_path = "/raid/ai23mtech11004/TTPflow/TPP_flow_matching-finaldone/outputs/wandb/wandb/run-20240911_113136-7nn6ltny/files/"

    print(f"Run path: {run_path}")
    print(f"Project root: {PROJECT_ROOT}")

    task, datamodule = get_task(run_path, density=True, data_root=PROJECT_ROOT)
    print("Task and datamodule loaded successfully")
except Exception as e:
    # Missing files get their own message; everything else is reported generically.
    if isinstance(e, FileNotFoundError):
        print(f"File not found error: {str(e)}")
    else:
        print(f"An unexpected error occurred: {str(e)}")

Run path: /raid/ai23mtech11004/TTPflow/TPP_flow_matching-finaldone/outputs/wandb/wandb/run-20240911_113136-7nn6ltny/files/
Project root: /raid/ai23mtech11004/TTPflow/TPP_flow_matching-finaldone/
Attempting to load task from path: /raid/ai23mtech11004/TTPflow/TPP_flow_matching-finaldone/outputs/wandb/wandb/run-20240911_113136-7nn6ltny/files/
Checking for model checkpoint at: /raid/ai23mtech11004/TTPflow/TPP_flow_matching-finaldone/outputs/wandb/wandb/run-20240911_113136-7nn6ltny/files/checkpoints/best.ckpt
File not found error: Checkpoint file not found: /raid/ai23mtech11004/TTPflow/TPP_flow_matching-finaldone/outputs/wandb/wandb/run-20240911_113136-7nn6ltny/files/checkpoints/best.ckpt
