In [2]:
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
from tqdm import tqdm

from compare_activations_to_concepts import (
    load_concept_names,
)


def load_metadata(eval_set_dir: Path) -> Dict[str, Any]:
    """Parse ``metadata.json`` inside *eval_set_dir* and return its contents.

    Args:
        eval_set_dir: Evaluation-set directory that holds ``metadata.json``.

    Returns:
        Whatever ``json.load`` produces for that file (a dict for this project's metadata).
    """
    metadata_path = eval_set_dir / "metadata.json"
    with open(metadata_path, "r") as handle:
        contents = json.load(handle)
    return contents


def calculate_metrics(
    tp: np.ndarray,
    fp: np.ndarray,
    fn: np.ndarray,
    tp_per_domain: np.ndarray,
    positive_labels: np.ndarray,
    concept_names: List[str],
    threshold_percents: List[float],
) -> pd.DataFrame:
    """
    Calculate precision, recall, and F1 for every concept/feature/threshold cell with TP != 0.

    Args:
        tp: True-positive counts, indexed ``[concept, feature, threshold]``.
        fp: False-positive counts, same shape as ``tp``.
        fn: False-negative counts, same shape as ``tp``.
        tp_per_domain: Per-domain true-positive counts, same shape as ``tp``;
            copied into the output unchanged.
        positive_labels: Total positive labels per concept. Currently unused
            (recall is computed from ``tp`` and ``fn``); kept for interface
            compatibility.
        concept_names: Names for the leading ``len(concept_names)`` entries of axis 0.
        threshold_percents: Labels for the leading ``len(threshold_percents)``
            entries of axis 2.

    Returns:
        DataFrame with columns ``concept, feature, threshold_pct, precision, recall,
        f1, tp, fp, tp_per_domain`` — one row per cell whose ``tp`` is non-zero,
        ordered by concept, then feature, then threshold.

    Raises:
        ValueError: If more concept names or threshold labels are supplied than
            the count arrays have entries on the corresponding axis.
    """
    tp = np.asarray(tp)
    fp = np.asarray(fp)
    fn = np.asarray(fn)
    tp_per_domain = np.asarray(tp_per_domain)

    n_concepts = len(concept_names)
    n_thresholds = len(threshold_percents)
    if n_concepts > tp.shape[0] or n_thresholds > tp.shape[2]:
        raise ValueError(
            f"Got {n_concepts} concept names and {n_thresholds} threshold labels "
            f"for count arrays of shape {tp.shape}"
        )

    # Only the first len(concept_names) concepts and len(threshold_percents)
    # thresholds are reported. np.nonzero returns indices in C (row-major)
    # order, i.e. the same concept -> feature -> threshold order as a nested loop.
    concept_idx, feature_idx, threshold_idx = np.nonzero(
        tp[:n_concepts, :, :n_thresholds]
    )

    # The window starts at index 0 on every axis, so window indices are also
    # valid indices into the full arrays.
    curr_tp = tp[concept_idx, feature_idx, threshold_idx]
    curr_fp = fp[concept_idx, feature_idx, threshold_idx]
    curr_fn = fn[concept_idx, feature_idx, threshold_idx]

    precision = curr_tp / (curr_tp + curr_fp)
    recall = curr_tp / (curr_tp + curr_fn)

    # Vectorized harmonic mean: 2PR / (P + R), falling back to 0 when P + R is
    # not positive (same rule as calculate_f1).
    denominator = precision + recall
    f1 = np.divide(
        2 * precision * recall,
        denominator,
        out=np.zeros_like(denominator, dtype=float),
        where=denominator > 0,
    )

    return pd.DataFrame(
        {
            "concept": np.asarray(concept_names, dtype=object)[concept_idx],
            "feature": feature_idx,
            "threshold_pct": np.asarray(threshold_percents)[threshold_idx],
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "tp": curr_tp,
            "fp": curr_fp,
            "tp_per_domain": tp_per_domain[concept_idx, feature_idx, threshold_idx],
        }
    )


def calculate_f1(precision: float, recall: float) -> float:
    """Return the harmonic mean of *precision* and *recall*.

    When ``precision + recall`` is not strictly positive the F1 score is
    undefined, and 0 is returned instead.
    """
    denominator = precision + recall
    if not denominator > 0:
        return 0
    return 2 * precision * recall / denominator

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def _sum_shard_counts(eval_res_dir: Path, shards: List[Any]) -> Dict[str, np.ndarray]:
    """Sum the ``tp``/``fp``/``fn``/``tp_per_domain`` arrays over every shard's ``.npz`` file.

    Each shard is read from ``eval_res_dir / f"shard_{shard}_counts.npz"``. Totals
    are float64 arrays shaped like the first shard's arrays.
    """
    totals: Dict[str, np.ndarray] = {}
    for shard in tqdm(shards, desc="Combining shard counts"):
        # `with` closes the file handle that np.load keeps open for .npz archives.
        with np.load(eval_res_dir / f"shard_{shard}_counts.npz") as shard_data:
            for key in ("tp", "fp", "fn", "tp_per_domain"):
                counts = shard_data[key]
                if key not in totals:
                    totals[key] = np.zeros(counts.shape)
                totals[key] += counts
    return totals


def combine_metrics_across_shards(
    eval_res_dir: Path,
    eval_set_dir: Path,
    threshold_percents: Optional[List[float]] = None,
) -> None:
    """
    Combine per-shard confusion counts, compute metrics, and save them as CSV.

    Reads the concept names from ``eval_set_dir / "gene_concepts_columns.txt"``
    and the shard list and positive-label counts from ``eval_set_dir / "metadata.json"``.
    It sums the count arrays of every shard listed in ``metadata["shard_source"]``
    and writes the output of ``calculate_metrics`` to
    ``eval_set_dir / "concept_f1_scores.csv"``.

    Args:
        eval_res_dir: Directory containing the ``shard_{shard}_counts.npz`` files.
        eval_set_dir: Directory containing the evaluation-set metadata; the CSV
            is written here.
        threshold_percents: Labels for the threshold axis of the count arrays.
            Defaults to ``[0, 0.15, 0.5, 0.6, 0.8]``.

    Raises:
        ValueError: If ``metadata["shard_source"]`` lists no shards.
    """
    # None sentinel instead of a mutable list default.
    if threshold_percents is None:
        threshold_percents = [0, 0.15, 0.5, 0.6, 0.8]

    concept_names = load_concept_names(eval_set_dir / "gene_concepts_columns.txt")

    metadata = load_metadata(eval_set_dir)
    positive_labels = np.array(metadata["n_positive_gene_per_concept"])

    # Every shard recorded in the metadata is combined.
    shards_to_eval = metadata["shard_source"]
    if not shards_to_eval:
        raise ValueError(
            f"No shards listed under 'shard_source' in {eval_set_dir / 'metadata.json'}"
        )

    totals = _sum_shard_counts(eval_res_dir, shards_to_eval)

    print("Calculating F1 scores...")
    metrics_df = calculate_metrics(
        totals["tp"],
        totals["fp"],
        totals["fn"],
        totals["tp_per_domain"],
        # NOTE(review): summing over axis 0 presumably collapses a per-shard axis — confirm.
        positive_labels.sum(axis=0),
        concept_names,
        threshold_percents,
    )

    output_path = eval_set_dir / "concept_f1_scores.csv"
    metrics_df.to_csv(output_path, index=False)
    print(f"Metrics saved to {output_path}")

In [5]:
# Combine shard counts for the validation split of `output`.
eval_res_dir = Path('/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/output')
eval_set_dir = eval_res_dir / 'valid'
combine_metrics_across_shards(eval_res_dir=eval_res_dir, eval_set_dir=eval_set_dir)

Combining shard counts: 100%|██████████| 3/3 [00:12<00:00,  4.29s/it]


Calculating F1 scores...
Metrics saved to /maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/output/valid/concept_f1_scores.csv


In [6]:
# Combine shard counts for the test split of `output`.
eval_res_dir = Path('/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/output')
eval_set_dir = eval_res_dir / 'test'
combine_metrics_across_shards(eval_res_dir=eval_res_dir, eval_set_dir=eval_set_dir)

Combining shard counts: 100%|██████████| 2/2 [00:09<00:00,  4.94s/it]


Calculating F1 scores...
Metrics saved to /maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/output/test/concept_f1_scores.csv


In [26]:
# NOTE(review): this path points at Projects/FM_temp/interGFM, not the interpTFM
# outputs written by the cells above — confirm this is intended.
f1_scores_csv = '/maiziezhou_lab2/yunfei/Projects/FM_temp/interGFM/output/test/concept_f1_scores.csv'
df = pd.read_csv(f1_scores_csv)

# Repeat for activations (`output_acts` directories)

In [7]:
# Combine shard counts for the validation split of `output_acts`.
eval_res_dir = Path('/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/output_acts')
eval_set_dir = eval_res_dir / 'valid'
combine_metrics_across_shards(eval_res_dir=eval_res_dir, eval_set_dir=eval_set_dir)

Combining shard counts: 100%|██████████| 3/3 [00:02<00:00,  1.43it/s]


Calculating F1 scores...
Metrics saved to /maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/output_acts/valid/concept_f1_scores.csv


In [4]:
# Combine shard counts for the test split of `output_acts`.
eval_res_dir = Path('/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/output_acts')
eval_set_dir = eval_res_dir / 'test'
combine_metrics_across_shards(eval_res_dir=eval_res_dir, eval_set_dir=eval_set_dir)

Combining shard counts: 100%|██████████| 2/2 [00:01<00:00,  1.25it/s]


Calculating F1 scores...
Metrics saved to /maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/output_acts/test/concept_f1_scores.csv
