In [1]:
import sys, os
# Emulate a command-line invocation inside the notebook: the whitespace-separated
# tokens of `cmd` replace sys.argv (the first token takes the place of the script name).
# Presumably HyperParams.from_args() in a later cell parses these flags — confirm.
cmd = '''run.ipynb
--subset biology
--tree_version bottom-up
--llm_max_concurrent_calls 20
--num_eval_samples 10
'''
sys.argv = cmd.split()

## Imports

In [2]:
# Switch the kernel's working directory to the project's src/ folder.
# NOTE(review): machine-specific absolute path; presumably needed so the project-local
# imports in the next cell resolve — adjust when running on another machine.
%cd /home/nilesh/work/lattice/release/src

/home/nilesh/work/lattice/release/src


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
#region Imports
import numpy as np
import pandas as pd
import pickle as pkl
import seaborn as sns
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import logging
from hyperparams import HyperParams
from tree_objects import SemanticNode, InferSample
from llm_apis import GenAIAPI
from prompts import get_traversal_prompt_response_constraint, get_reranking_prompt
from utils import (
    setup_logger, 
    compute_node_registry,
    get_all_leaf_nodes_with_path, 
    get_node_id, 
    post_process, 
    save_exp, 
    load_exp,
    init_wandb_logging,
    finish_wandb_logging,
    wandb_log_iteration_metrics,
    wandb_log_reranking_metrics,
    wandb_log_final_summary,
    visualize_sample,
)
#endregion

#region Setup
# Build the run's hyperparameter object (presumably parsed from sys.argv — confirm).
hp = HyperParams.from_args()
# NOTE(review): machine-specific absolute path; must be edited to run on another machine.
BASE_DIR = '/home/nilesh/work/lattice/release'
# Results are grouped per BRIGHT subset; the directory is created if it does not exist yet.
RESULTS_DIR = f'{BASE_DIR}/results/BRIGHT/{hp.SUBSET}/'
os.makedirs(RESULTS_DIR, exist_ok=True)
# The log file is named after str(hp), so the file name depends on the hyperparameter values.
logger = setup_logger('lattice_notebook', f"{RESULTS_DIR}/{hp}.log", logging.INFO)

# Initialize wandb logging
# mode_override="disabled" is passed explicitly; presumably this turns off remote logging — confirm.
run_name = init_wandb_logging(hp, RESULTS_DIR, mode_override="disabled")
logger.info(f"Initialized wandb run: {run_name}")
#endregion

2025-10-16 00:14:12,654 - lattice_notebook - INFO - Initialized wandb run: dummy-4fionv5r


Log file already exists: /home/nilesh/work/lattice/release/results/BRIGHT/biology/S=biology-TV=bottom-up-TPV=5-RInTP=-1-NumLC=10-PlTau=5.0-RCF=0.5-LlmApiB=genai-Llm=gemini-2.5-flash-NumI=20-NumES=10-MaxBS=2.log, appending to it.


## Data loading

In [4]:
#region Data loading
# Corpus documents and evaluation examples for the selected BRIGHT subset.
docs_df = pd.DataFrame(load_dataset('xlangai/BRIGHT', 'documents', split=hp.SUBSET))
examples_df = pd.DataFrame(load_dataset('xlangai/BRIGHT', 'examples', split=hp.SUBSET))
# Map document id -> document content. Zipping the two columns yields the same mapping as
# per-row .iloc lookups without materialising a row Series twice per document.
doc_id_to_content = dict(zip(docs_df['id'], docs_df['content']))

# Load the pre-built semantic tree for this subset / tree version.
# NOTE: pickle can execute arbitrary code while loading; only open tree files you trust.
tree_path = f'{BASE_DIR}/trees/BRIGHT/{hp.SUBSET}/tree-{hp.TREE_VERSION}.pkl'
with open(tree_path, 'rb') as tree_file:  # context manager closes the handle after loading
    tree_dict = pkl.load(tree_file)
# The pickle holds either a plain dict (rebuilt into a SemanticNode) or an already-built node object.
semantic_root_node = SemanticNode().load_dict(tree_dict) if isinstance(tree_dict, dict) else tree_dict
node_registry = compute_node_registry(semantic_root_node)
# (leaf, path) pairs for every leaf; used to map each document id to its path in the tree.
all_leaf_nodes = get_all_leaf_nodes_with_path(semantic_root_node)
doc_id_to_path = {get_node_id(leaf.id, docs_df): path for leaf, path in all_leaf_nodes}
#endregion

In [5]:
# Inspect the first child of the root node as a quick sanity check that the tree loaded.
semantic_root_node.child[0]

ID: [0], Num children: 9, Description: This topic cluster provides a multi-disciplinary exploration of human and animal biology, behavior, and evolution. It covers fundamental life processes such as reproduction, development, and aging (senescence), as well as specific anatomical systems like the eye and joints. The collection delves into key evolutionary concepts like Cope's rule and speciation, and examines traits with both biological and cultural dimensions, including handedness, color perception, and kissing. It also provides a detailed overview of the genus *Homo*, focusing on humans and Neanderthals.

First 4 Children:

[0, [1], 9 children] This collection provides a comprehensive exploration of topics related to reproduction, development, and family dynamics in both humans and animals. It covers the biological and medical aspects of conception and birth, including twinning (monozygotic, dizygotic), superfecundation, parthenogenesis, and parturition. The node also delves into the

## Setup

In [20]:
# Fail early if the GOOGLE_API_KEY environment variable is missing
# (presumably required by the GenAI client created in the next cell — confirm).
# Only absence is checked; an empty value still passes.
assert os.environ.get('GOOGLE_API_KEY') is not None, "Please set the GOOGLE_API_KEY environment variable."

In [21]:
#region Setup LLM API and Eval Samples
# Build the LLM client for the configured backend. Timeout and retry budget are taken from
# the hyperparameters (llm_api_timeout / llm_api_max_retries) instead of being hard-coded,
# so overriding them on the command line actually takes effect; their current values are
# 120 and 4, identical to the previous literals.
if hp.LLM_API_BACKEND == 'genai':
    llm_api = GenAIAPI(
        hp.LLM,
        logger=logger,
        timeout=hp.LLM_API_TIMEOUT,
        max_retries=hp.LLM_API_MAX_RETRIES,
    )
else:
    raise ValueError(f'Unknown LM API backend: {hp.LLM_API_BACKEND}')

# Keyword arguments shared by every traversal batch call: concurrency cap plus a JSON
# response constrained by the schema of the selected traversal prompt version.
llm_api_kwargs = {
    'max_concurrent_calls': hp.LLM_MAX_CONCURRENT_CALLS,
    'response_mime_type': 'application/json',
    'response_schema': get_traversal_prompt_response_constraint(hp.TRAVERSAL_PROMPT_VERSION),
}

2025-10-16 00:03:54,945 - lattice_notebook - INFO - Initialized client for model: gemini-2.5-flash
2025-10-16 00:03:55,127 - lattice_notebook - INFO - Initialized Google GenAI client with model: gemini-2.5-flash


In [6]:
from utils import load_exp

if hp.LOAD_EXISTING:
  # Resume: restore previously saved samples and per-iteration metric frames.
  all_eval_samples, all_eval_metric_dfs = load_exp(RESULTS_DIR, hp, semantic_root_node, node_registry, logger)
  if len(all_eval_samples) > 0:
    logger.info(f'Loaded existing experiment with {len(all_eval_samples)} eval samples and {len(all_eval_metric_dfs)} eval metric dfs')
    # Refresh relevances on every restored sample, then report the current mean metrics.
    for sample in all_eval_samples:
      sample.update_relevances(sample.prediction_tree)
    metric_rows = [sample.compute_eval_metrics(k=10) for sample in all_eval_samples]
    eval_metric_df = pd.DataFrame(metric_rows)
    metric_summary = [f'{k}: {eval_metric_df[k].mean():.2f}' for k in eval_metric_df.columns]
    logger.info('; '.join(metric_summary))
else:
  # Fresh run: one InferSample per evaluation example, capped at NUM_EVAL_SAMPLES.
  all_eval_samples, all_eval_metric_dfs = [], []
  num_samples = min(examples_df.shape[0], hp.NUM_EVAL_SAMPLES)
  for i in range(num_samples):
    example = examples_df.iloc[i]
    # Gold documents are expressed as their paths in the semantic tree.
    gold_paths = [doc_id_to_path[docid] for docid in example['gold_ids']]
    sample = InferSample(
        semantic_root_node,
        node_registry,
        hp=hp,
        logger=logger,
        query=example['query'][:hp.MAX_QUERY_CHAR_LEN],
        gold_paths=gold_paths,
        excluded_ids_set=set(example['excluded_ids']),
        )
    all_eval_samples.append(sample)
  # Sanity check: no freshly created sample has its prediction tree flagged as excluded.
  assert not any([sample.prediction_tree.excluded for sample in tqdm(all_eval_samples)])
# Record the full hyperparameter set in the log, one "name:<TAB>value" line per entry.
hyperparam_lines = [f'{k}:\t{v}' for k, v in vars(hp).items()]
logger.info('Hyperparams:\n'+'\n'.join(hyperparam_lines))
#endregion

  0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:14:34,711 - lattice_notebook - INFO - Hyperparams:
subset:	biology
tree_version:	bottom-up
traversal_prompt_version:	5
reasoning_in_traversal_prompt:	-1
max_query_char_len:	None
max_doc_desc_char_len:	None
max_prompt_proto_size:	None
search_with_path_relevance:	True
num_leaf_calib:	10
pl_tau:	5.0
relevance_chain_factor:	0.5
llm_api_backend:	genai
llm:	gemini-2.5-flash
llm_max_concurrent_calls:	20
llm_api_timeout:	120
llm_api_max_retries:	4
num_iters:	20
num_eval_samples:	10
max_beam_size:	2
rerank:	False
load_existing:	False
num_threads:	32
suffix:	


## Retrieval loop

In [23]:
for i in tqdm(range(len(all_eval_metric_dfs), hp.NUM_ITERS)):
    logger.info(f'-------------------- Iter {i} --------------------')
    
    inputs = [sample.get_step_prompts() for sample in all_eval_samples]
    indptr = np.cumsum([0, *[len(x) for x in inputs]])
    flat_inputs = [y for x in inputs for y in x]
    flat_prompts, flat_slates = list(zip(*flat_inputs))
    slates = [flat_slates[indptr[j]:indptr[j+1]] for j in range(len(inputs))]

    flat_responses = await llm_api.run_batch(flat_prompts, **llm_api_kwargs)
    flat_response_jsons = [post_process(output, return_json=True) for output in tqdm(flat_responses)]
    response_jsons = [flat_response_jsons[indptr[j]:indptr[j+1]] for j in range(len(inputs))]

    for sample, sample_slates, sample_response_jsons in tqdm(zip(all_eval_samples, slates, response_jsons), total=len(all_eval_samples), desc='Updating samples'):
      sample.update(sample_slates, sample_response_jsons)

    eval_metric_df = pd.DataFrame([sample.compute_eval_metrics(k=10) for sample in all_eval_samples])
    all_eval_metric_dfs.append(eval_metric_df)
    
    # Log metrics
    wandb_log_iteration_metrics(eval_metric_df, i)
    logger.info('; '.join([f'{k}: {eval_metric_df[k].mean():.2f}' for k in eval_metric_df.columns]))
    # save_exp(RESULTS_DIR, hp, llm_api, all_eval_samples, all_eval_metric_dfs, allow_overwrite=True)  
    logger.info('-'*50)

  0%|          | 0/20 [00:00<?, ?it/s]

2025-10-16 00:04:00,614 - lattice_notebook - INFO - -------------------- Iter 0 --------------------
2025-10-16 00:04:00,629 - lattice_notebook - INFO - Running a batch of 10 prompts...
2025-10-16 00:04:00,631 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.


Processing batch: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:21<00:00,  2.12s/it, errors=0, active=0, completed=10, 429s=0, 503s=0]
2025-10-16 00:04:21,805 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:04:21,807 - lattice_notebook - INFO - Total Duration: 21.17 seconds
2025-10-16 00:04:21,808 - lattice_notebook - INFO - Total Prompts: 10
2025-10-16 00:04:21,808 - lattice_notebook - INFO - Successful: 10 (100.0%)
2025-10-16 00:04:21,809 - lattice_notebook - INFO - Failed: 0
2025-10-16 00:04:21,810 - lattice_notebook - INFO - Total Error Occurrences: 0
2025-10-16 00:04:21,814 - lattice_notebook - INFO - 
RETRY STATISTICS:
2025-10-16 00:04:21,815 - lattice_notebook - INFO -   1 attempt(s): 10 requests (100.0%)
2025-10-16 00:04:21,816 - lattice_notebook - INFO -   Average attempts per successful request: 1.00
2025-10-16 00:04:21,817 - lattice_notebook - INFO - 
THROUGHPUT:
2025-10-16 00:04:21,818 - lattice_notebook - INFO - 

  0%|          | 0/10 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:04:26,319 - lattice_notebook - INFO - nDCG@10: 0.00; Recall@10: 0.00; Recall@100: 0.00; Recall@all: 0.00; Coverage: 0.00
2025-10-16 00:04:26,321 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:04:26,323 - lattice_notebook - INFO - -------------------- Iter 1 --------------------
2025-10-16 00:04:26,334 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:04:26,336 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch: 100%|████████████████████████████████████████████████████████████████| 20/20 [00:35<00:00,  1.76s/it, errors=0, active=0, completed=20, 429s=0, 503s=0]
2025-10-16 00:05:01,641 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:05:01,644 - lattice_notebook - INFO - Total Duration: 35.30 seconds
2025-10-16 00:05:01,644 - lattice_notebook - INFO - Total Prompts: 20
2025-10-16 00:05:01,645 - lattice_notebook - INFO - Successful: 20 (100.0%)


  0%|          | 0/20 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:05:04,650 - lattice_notebook - INFO - nDCG@10: 0.00; Recall@10: 0.00; Recall@100: 0.00; Recall@all: 0.00; Coverage: 0.00
2025-10-16 00:05:04,653 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:05:04,655 - lattice_notebook - INFO - -------------------- Iter 2 --------------------
2025-10-16 00:05:04,667 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:05:04,668 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch: 100%|████████████████████████████████████████████████████████████████| 20/20 [00:35<00:00,  1.76s/it, errors=0, active=0, completed=20, 429s=0, 503s=0]
2025-10-16 00:05:39,863 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:05:39,867 - lattice_notebook - INFO - Total Duration: 35.19 seconds
2025-10-16 00:05:39,868 - lattice_notebook - INFO - Total Prompts: 20
2025-10-16 00:05:39,868 - lattice_notebook - INFO - Successful: 20 (100.0%)


  0%|          | 0/20 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:05:42,935 - lattice_notebook - INFO - nDCG@10: 0.00; Recall@10: 0.00; Recall@100: 0.00; Recall@all: 0.00; Coverage: 0.00
2025-10-16 00:05:42,937 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:05:42,939 - lattice_notebook - INFO - -------------------- Iter 3 --------------------
2025-10-16 00:05:42,947 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:05:42,948 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch: 100%|████████████████████████████████████████████████████████████████| 20/20 [00:27<00:00,  1.37s/it, errors=0, active=0, completed=20, 429s=0, 503s=0]
2025-10-16 00:06:10,435 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:06:10,437 - lattice_notebook - INFO - Total Duration: 27.49 seconds
2025-10-16 00:06:10,439 - lattice_notebook - INFO - Total Prompts: 20
2025-10-16 00:06:10,440 - lattice_notebook - INFO - Successful: 20 (100.0%)


  0%|          | 0/20 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:06:13,840 - lattice_notebook - INFO - nDCG@10: 37.55; Recall@10: 29.43; Recall@100: 31.54; Recall@all: 31.54; Coverage: 8.20
2025-10-16 00:06:13,843 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:06:13,845 - lattice_notebook - INFO - -------------------- Iter 4 --------------------
2025-10-16 00:06:13,963 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:06:13,964 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch: 100%|████████████████████████████████████████████████████████████████| 20/20 [01:04<00:00,  3.23s/it, errors=0, active=0, completed=20, 429s=0, 503s=0]
2025-10-16 00:07:18,506 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:07:18,509 - lattice_notebook - INFO - Total Duration: 64.54 seconds
2025-10-16 00:07:18,510 - lattice_notebook - INFO - Total Prompts: 20
2025-10-16 00:07:18,512 - lattice_notebook - INFO - Successful: 20 (100.

  0%|          | 0/20 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:07:21,732 - lattice_notebook - INFO - nDCG@10: 45.40; Recall@10: 41.93; Recall@100: 50.00; Recall@all: 50.00; Coverage: 26.90
2025-10-16 00:07:21,733 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:07:21,735 - lattice_notebook - INFO - -------------------- Iter 5 --------------------
2025-10-16 00:07:21,810 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:07:21,811 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch: 100%|████████████████████████████████████████████████████████████████| 20/20 [01:02<00:00,  3.15s/it, errors=0, active=0, completed=20, 429s=0, 503s=0]
2025-10-16 00:08:24,725 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:08:24,728 - lattice_notebook - INFO - Total Duration: 62.91 seconds
2025-10-16 00:08:24,729 - lattice_notebook - INFO - Total Prompts: 20
2025-10-16 00:08:24,730 - lattice_notebook - INFO - Successful: 20 (100

  0%|          | 0/20 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:08:27,947 - lattice_notebook - INFO - nDCG@10: 45.75; Recall@10: 44.82; Recall@100: 55.00; Recall@all: 55.00; Coverage: 46.00
2025-10-16 00:08:27,949 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:08:27,951 - lattice_notebook - INFO - -------------------- Iter 6 --------------------
2025-10-16 00:08:28,037 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:08:28,038 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch: 100%|████████████████████████████████████████████████████████████████| 20/20 [00:41<00:00,  2.08s/it, errors=0, active=0, completed=20, 429s=0, 503s=0]
2025-10-16 00:09:09,573 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:09:09,575 - lattice_notebook - INFO - Total Duration: 41.53 seconds
2025-10-16 00:09:09,577 - lattice_notebook - INFO - Total Prompts: 20
2025-10-16 00:09:09,578 - lattice_notebook - INFO - Successful: 20 (100

  0%|          | 0/20 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:09:12,981 - lattice_notebook - INFO - nDCG@10: 56.35; Recall@10: 53.82; Recall@100: 64.00; Recall@all: 64.00; Coverage: 59.00
2025-10-16 00:09:12,983 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:09:12,985 - lattice_notebook - INFO - -------------------- Iter 7 --------------------
2025-10-16 00:09:13,053 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:09:13,054 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch: 100%|████████████████████████████████████████████████████████████████| 20/20 [01:11<00:00,  3.60s/it, errors=0, active=0, completed=20, 429s=0, 503s=0]
2025-10-16 00:10:24,978 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:10:24,980 - lattice_notebook - INFO - Total Duration: 71.92 seconds
2025-10-16 00:10:24,981 - lattice_notebook - INFO - Total Prompts: 20
2025-10-16 00:10:24,982 - lattice_notebook - INFO - Successful: 20 (100

  0%|          | 0/20 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:10:29,047 - lattice_notebook - INFO - nDCG@10: 69.41; Recall@10: 67.16; Recall@100: 79.00; Recall@all: 79.00; Coverage: 81.30
2025-10-16 00:10:29,049 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:10:29,051 - lattice_notebook - INFO - -------------------- Iter 8 --------------------
2025-10-16 00:10:29,156 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:10:29,157 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch: 100%|████████████████████████████████████████████████████████████████| 20/20 [00:48<00:00,  2.42s/it, errors=0, active=0, completed=20, 429s=0, 503s=0]
2025-10-16 00:11:17,503 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:11:17,505 - lattice_notebook - INFO - Total Duration: 48.34 seconds
2025-10-16 00:11:17,506 - lattice_notebook - INFO - Total Prompts: 20
2025-10-16 00:11:17,507 - lattice_notebook - INFO - Successful: 20 (100

  0%|          | 0/20 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:11:20,976 - lattice_notebook - INFO - nDCG@10: 71.50; Recall@10: 67.68; Recall@100: 79.00; Recall@all: 79.00; Coverage: 102.00
2025-10-16 00:11:20,977 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:11:20,980 - lattice_notebook - INFO - -------------------- Iter 9 --------------------
2025-10-16 00:11:21,061 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:11:21,062 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch: 100%|████████████████████████████████████████████████████████████████| 20/20 [00:37<00:00,  1.87s/it, errors=0, active=0, completed=20, 429s=0, 503s=0]
2025-10-16 00:11:58,386 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:11:58,390 - lattice_notebook - INFO - Total Duration: 37.32 seconds
2025-10-16 00:11:58,391 - lattice_notebook - INFO - Total Prompts: 20
2025-10-16 00:11:58,392 - lattice_notebook - INFO - Successful: 20 (10

  0%|          | 0/20 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:12:01,734 - lattice_notebook - INFO - nDCG@10: 71.35; Recall@10: 67.68; Recall@100: 79.00; Recall@all: 79.00; Coverage: 121.80
2025-10-16 00:12:01,735 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:12:01,737 - lattice_notebook - INFO - -------------------- Iter 10 --------------------
2025-10-16 00:12:01,822 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:12:01,823 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch: 100%|████████████████████████████████████████████████████████████████| 20/20 [00:38<00:00,  1.92s/it, errors=0, active=0, completed=20, 429s=0, 503s=0]
2025-10-16 00:12:40,216 - lattice_notebook - INFO - BATCH PROCESSING SUMMARY REPORT
2025-10-16 00:12:40,219 - lattice_notebook - INFO - Total Duration: 38.39 seconds
2025-10-16 00:12:40,220 - lattice_notebook - INFO - Total Prompts: 20
2025-10-16 00:12:40,222 - lattice_notebook - INFO - Successful: 20 (1

  0%|          | 0/20 [00:00<?, ?it/s]

Updating samples:   0%|          | 0/10 [00:00<?, ?it/s]

2025-10-16 00:12:44,166 - lattice_notebook - INFO - nDCG@10: 73.07; Recall@10: 69.68; Recall@100: 83.00; Recall@all: 83.00; Coverage: 131.50
2025-10-16 00:12:44,169 - lattice_notebook - INFO - --------------------------------------------------
2025-10-16 00:12:44,171 - lattice_notebook - INFO - -------------------- Iter 11 --------------------
2025-10-16 00:12:44,268 - lattice_notebook - INFO - Running a batch of 20 prompts...
2025-10-16 00:12:44,269 - lattice_notebook - INFO - Concurrency limited to 20 parallel calls.
Processing batch:  30%|███████████████████▌                                             | 6/20 [00:16<00:39,  2.82s/it, errors=0, active=14, completed=6, 429s=0, 503s=0]


CancelledError: 

In [24]:
# Save the current samples and metric history under RESULTS_DIR; allow_overwrite=True is
# passed, so presumably an existing save for the same hyperparameters is replaced — confirm.
save_exp(RESULTS_DIR, hp, llm_api, all_eval_samples, all_eval_metric_dfs, allow_overwrite=True)

## Load results

In [7]:
# Reload samples and metric history for the current hyperparameters.
# Note: this rebinds all_eval_samples / all_eval_metric_dfs, replacing any in-memory state.
all_eval_samples, all_eval_metric_dfs = load_exp(RESULTS_DIR, hp, semantic_root_node, node_registry, logger)

In [8]:
# Show the per-sample metrics of the most recent recorded iteration.
all_eval_metric_dfs[-1]

Unnamed: 0,nDCG@10,Recall@10,Recall@100,Recall@all,Coverage
0,69.921482,60.0,80.0,80.0,79
1,99.307832,100.0,100.0,100.0,202
2,71.493945,66.666667,100.0,100.0,142
3,100.0,100.0,100.0,100.0,73
4,0.0,0.0,0.0,0.0,41
5,95.502366,100.0,100.0,100.0,167
6,77.283783,36.842105,100.0,100.0,92
7,76.135694,100.0,100.0,100.0,187
8,100.0,100.0,100.0,100.0,129
9,41.039157,33.333333,50.0,50.0,203


## Debug / visualize

In [9]:
from utils import visualize_sample

In [10]:
# Render one sample to an HTML file (path is relative to the current working directory).
# Alternative: pick a random sample instead of a fixed one.
# i = np.random.randint(len(all_eval_samples))
i = 0 # change this to visualize a different sample
step = hp.NUM_ITERS # or any step <= NUM_ITERS, tree will be shown up to this step (iteration)
visualize_sample(all_eval_samples[i], save_path='../visualize_sample.html', max_step=step)

Saved plot HTML to ../visualize_sample.html


## Additional Reranking (optional)

In [263]:
from prompts import get_reranking_prompt

def get_sample_rerank_prompt(sample):
    """Build the reranking prompt for one sample.

    The sample's top predictions (requested with a limit of 100, scored by
    `combined_relevance`) are reduced to their descriptions and passed, together
    with the sample's query, to `get_reranking_prompt` with `topk=10`.

    Returns whatever `get_reranking_prompt` returns, unchanged.
    """
    top_predictions = sample.get_top_predictions(100, rel_fn=lambda x: x.combined_relevance)
    candidate_descs = [node.desc for node, _ in top_predictions]
    return get_reranking_prompt(sample.query, candidate_descs, hp=hp, logger=logger, topk=10)

def process_sample_rerank_response(sample, response):
    """Record an LLM reranking of a sample's top predictions as reciprocal ranks.

    `response` is parsed with `post_process(..., return_json=True)` and its
    'ranking' entry is read as a sequence of positions into the sample's top
    predictions (requested with a limit of 100, scored by `combined_relevance`).
    The prediction named at position r (0-based) of the ranking gets 1/(r+1)
    appended to its `inverse_rank` list. The list is created on first use, and a
    pre-existing float value is first wrapped into a one-element list.

    Raises:
        IndexError: if any ranking entry is not an integer in
            [0, number of top predictions). The whole ranking is validated
            before any prediction is touched, so a malformed response never
            leaves the sample half-updated, and a negative index can no longer
            silently wrap around to the end of the candidate list.
    """
    ranking = post_process(response, return_json=True)['ranking']
    top_preds = [x for x, _ in sample.get_top_predictions(100, rel_fn=lambda x: x.combined_relevance)]

    # The indices come from model output, so validate all of them before mutating anything.
    num_candidates = len(top_preds)
    for idx in ranking:
        if not isinstance(idx, int) or not 0 <= idx < num_candidates:
            raise IndexError(
                f'Rerank index {idx!r} is not a valid position among {num_candidates} candidates'
            )

    for rank, idx in enumerate(ranking):
        pred = top_preds[idx]
        if hasattr(pred, 'inverse_rank'):
            # A bare float is normalised to a list so further ranks can be appended.
            if isinstance(pred.inverse_rank, float):
                pred.inverse_rank = [pred.inverse_rank]
            pred.inverse_rank.append(1/(rank+1))
        else:
            pred.inverse_rank = [1/(rank+1)]

In [270]:
eval_samples = all_eval_samples
all_rerank_prompts, all_rerank_constraints = list(zip(*[get_sample_rerank_prompt(sample) for sample in eval_samples]))
all_rerank_responses = await llm_api.run_batch(all_rerank_prompts, max_concurrent_calls=10, response_mime_type='application/json', response_schema=all_rerank_constraints[0])

2025-09-24 22:38:52,445 - lattice_runner - INFO - Running a batch of 108 prompts...
2025-09-24 22:38:52,447 - lattice_runner - INFO - Concurrency limited to 10 parallel calls.
Processing batch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 108/108 [07:53<00:00,  4.38s/it, errors=0, active=0, completed=108, 429s=0, 503s=0]
2025-09-24 22:46:45,822 - lattice_runner - INFO - BATCH PROCESSING SUMMARY REPORT
2025-09-24 22:46:45,824 - lattice_runner - INFO - Total Duration: 473.37 seconds
2025-09-24 22:46:45,827 - lattice_runner - INFO - Total Prompts: 108
2025-09-24 22:46:45,828 - lattice_runner - INFO - Successful: 108 (100.0%)
2025-09-24 22:46:45,828 - lattice_runner - INFO - Failed: 0
2025-09-24 22:46:45,829 - lattice_runner - INFO - Total Error Occurrences: 0
2025-09-24 22:46:45,830 - lattice_runner - INFO - 
RETRY STATISTICS:
2025-09-24 22:46:45,831 - lattice_runner - INFO -   1 attempt(s): 108 requests (100.0%)
2025-09-24 22:46:45,

In [271]:
# Apply each rerank response to its sample. Failures are logged (with the first 100
# characters of the query) and skipped, so one bad response does not stop the others.
for sample, response in zip(eval_samples, all_rerank_responses):
    try:
        process_sample_rerank_response(sample, response)
    except Exception as e:
        logger.error(f'Error processing rerank response for sample with query: {sample.query[:100]}.. Error: {e}')

In [272]:
def rerank_rel_fn(x):
    """Relevance key after reranking: (mean reciprocal rank, combined relevance).

    Predictions without an `inverse_rank` attribute get 0 as the first
    component; `combined_relevance` is the second component of the tuple.
    """
    mean_inverse_rank = np.mean(x.inverse_rank) if hasattr(x, 'inverse_rank') else 0
    return (mean_inverse_rank, x.combined_relevance)

# Re-score every sample with the reranked relevance key and log the mean of each metric.
rerank_metric_rows = [sample.compute_eval_metrics(k=10, rel_fn=rerank_rel_fn) for sample in eval_samples]
rerank_eval_metric_df = pd.DataFrame(rerank_metric_rows)
rerank_summary = [f'{k}: {rerank_eval_metric_df[k].mean():.2f}' for k in rerank_eval_metric_df.columns]
logger.info('After reranking: '+'; '.join(rerank_summary))

2025-09-24 22:46:47,717 - lattice_runner - INFO - After reranking: nDCG@10: 42.29; Recall@10: 48.40; Recall@100: 66.76; Recall@all: 70.16; Coverage: 261.72
