# Access to Data and Model

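Before you continue, please download the data and the model from https://figshare.com/s/3cb512c318285fba5051. Unpack the archive, then move the _model_ into the `models/` directory and the _data_ into the `data/` directory (this notebook loads them from `../models/` and `../data/`).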

In [4]:
import torch
from pathlib import Path
from fim.models.hawkes import FIMHawkes
from utils import load_data_from_dir, prepare_batch_for_model, plot_intensity_comparison, _move_to_device

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Locations of the downloaded checkpoint and dataset, relative to this notebook.
model_checkpoint = Path("../models/")
dataset_dir = Path("../data/")

print(f"Loading model from: {model_checkpoint}")
print(f"Loading dataset from: {dataset_dir}")

# Load model with proper weight loading (bypasses transformers issues)
model = FIMHawkes.load_model(model_checkpoint)
model.eval()  # inference mode (disables dropout etc.)
model.to(device)

data = load_data_from_dir(dataset_dir)
if not data:
    raise ValueError("No data loaded. Exiting.")

print(f"Loaded data with keys: {list(data.keys())}")

# Which sample to visualise, and which path within that sample.
sample_idx = 10
path_idx = 0

# Infer the number of samples from the leading dimension of the first tensor
# entry in `data` (None if no entry is a tensor, in which case no check is done).
sample_count = next(
    (value.shape[0] for value in data.values() if torch.is_tensor(value)),
    None,
)

if sample_count is not None and sample_idx >= sample_count:
    print(f"Warning: sample_idx {sample_idx} >= number of samples {sample_count}. Using sample 0.")
    sample_idx = 0

print(f"Using sample index: {sample_idx}")

# Slice out the chosen sample; tensors and non-tensor containers (e.g. lists
# of tensors) are both indexed the same way.
single_sample_data = {key: value[sample_idx] for key, value in data.items()}

try:
    model_data = prepare_batch_for_model(single_sample_data, path_idx, num_points_between_events=10)
except ValueError as e:
    # Chain the cause so the original traceback from the helper is preserved.
    raise ValueError(f"Error preparing batch: {e}") from e

print("Model input shapes:")
for key, value in model_data.items():
    if torch.is_tensor(value):
        print(f"  {key}: {value.shape}")

print(f"Using path index: {path_idx}")

# Ensure model inputs are on the same device as the model
model_data = _move_to_device(model_data, device)

with torch.no_grad():
    model_output = model(model_data)

print(f"Model output keys: {list(model_output.keys())}")

plot_intensity_comparison(model_output, model_data, path_idx=path_idx)
