# Measurement Encoding Analysis

In [None]:
import sys
import os

# Put `<parent of the working directory>/src` on the import path so project modules
# such as `models` and `file_paths` can be imported below.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
src_dir = os.path.join(parent_dir, "src")
sys.path.append(src_dir)

In [None]:
from models.clmbr_t_base import get_tokenizer
from femr.models.tokenizer import FEMRTokenizer
from datasets import Dataset
from femr.ontology import Ontology
import polars as pl
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import pandas as pd
from tabulate import tabulate
import json
from file_paths import MAPPING_DIR, ATHENA_PATH, MIMIC_MEDS_DIR

In [None]:
# Ontology built from ATHENA_PATH (presumably an Athena vocabulary export — confirm);
# its `get_all_parents` is used further below to look up ancestor codes for
# measurements the tokenizer cannot match directly.
ontology = Ontology(ATHENA_PATH)

## Analyze Tokenizer
We compute statistics and plots to better understand which information the tokenizer can encode as tokens.

In [None]:
# Tokenizer returned by `models.clmbr_t_base.get_tokenizer`; its `string_lookup`,
# `code_lookup`, `numeric_lookup` and `dictionary["vocab"]` are analysed below.
# NOTE(review): the meaning of the `None` argument is defined in `get_tokenizer` — confirm.
tokenizer = get_tokenizer(None)

In [None]:
# Keys of `string_lookup` are (code, text) pairs. Collect the distinct codes that have
# at least one text-value token, and count how many codes share each text value.
simplified_text_keys = {code for code, text in tokenizer.string_lookup}
text_counts = defaultdict(int)
for code, text in tokenizer.string_lookup:
    text_counts[text] += 1

print(len(simplified_text_keys))
print(len(tokenizer.string_lookup))

Calculate the number of tokens per encoding type (text, code, numeric).

In [None]:
# Number of lookup entries per encoding type. Each numeric code maps to a list of
# value bins, and every bin is its own token.
num_text_tokens = len(tokenizer.string_lookup)
num_code_tokens = len(tokenizer.code_lookup)
num_numeric_tokens = sum(map(len, tokenizer.numeric_lookup.values()))

for label, count in (
    ("Text Tokens", num_text_tokens),
    ("Code Tokens", num_code_tokens),
    ("Numeric Tokens", num_numeric_tokens),
):
    print(f"{label}: {count}")
print(f"Total tokens: {num_text_tokens + num_code_tokens + num_numeric_tokens}")

To find out why tokens are missing, we dig deeper into the vocab dictionary and find that a large number of vocab entries are not covered by the counts above.

In [None]:
# Distinct values of the "type" field across all raw vocabulary entries.
types = {entry["type"] for entry in tokenizer.dictionary["vocab"]}
types

In [None]:
# Tally vocabulary entries per token type. A type outside these four keys raises
# KeyError, which surfaces unexpected vocab content instead of hiding it.
token_stats = dict.fromkeys(("code", "text", "numeric", "unused"), 0)
for entry in tokenizer.dictionary["vocab"]:
    token_stats[entry["type"]] += 1

token_stats

In [None]:
total_tokens = sum(token_stats.values())

# Bar labels are the token types with a capitalized first letter; each bar is
# annotated with its share of the whole vocabulary.
token_stats_capitalized = {}
percentages_capitalized = {}
for token_type, type_count in token_stats.items():
    bar_label = token_type.capitalize()
    token_stats_capitalized[bar_label] = type_count
    percentages_capitalized[bar_label] = (type_count / total_tokens) * 100

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(
    token_stats_capitalized.keys(),
    token_stats_capitalized.values(),
    color=['skyblue', 'lightgreen', 'lightcoral', 'lightsalmon'],
)

ax.set_xlabel('Measurement Encoding Type', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.grid(axis='y', linestyle='--', alpha=0.7)

# Percentage label centred above each bar.
for bar, percent in zip(bars, percentages_capitalized.values()):
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height(),
        f'{percent:.1f}%',
        ha='center',
        va='bottom',
        fontsize=10,
    )

plt.tight_layout()
plt.savefig(os.path.join(MAPPING_DIR, "tokenizer_type_dist.png"))
plt.show()


In [None]:
# One row per text value with the number of (code, text) string-lookup entries using it.
text_count_df = pd.DataFrame(list(text_counts.items()), columns=['Text', 'Count'])

# Most frequent text values first.
df_sorted = text_count_df.sort_values(by='Count', ascending=False)
print(df_sorted)

# Export the 10 most frequent text values as a LaTeX table.
latex_table = tabulate(df_sorted.head(10), headers='keys', tablefmt='latex', showindex=False)
latex_path = os.path.join(MAPPING_DIR, 'top_10_text_counts_table.tex')
# Explicit UTF-8 so free-text values with non-ASCII characters do not depend on the
# platform's default encoding.
with open(latex_path, 'w', encoding='utf-8') as file:
    file.write(latex_table)

# Report the path that was actually written.
print(f"LaTeX table has been saved to '{latex_path}'")

In [None]:
def get_tokenizer_match(measurement):
    """Look up the token that the global ``tokenizer`` assigns to ``measurement``.

    Which lookup is used depends on what the measurement carries:

    * a ``numeric_value``: the code's numeric bins, matching ``start <= value < end``;
    * otherwise a ``text_value``: the ``(code, text_value)`` string lookup;
    * otherwise: the bare code in the code lookup.

    Returns:
        ``(encoding_type, token_id)`` where ``encoding_type`` is ``"numeric"``,
        ``"text"`` or ``"code"``, or ``(None, None)`` when there is no token.
    """
    code = measurement["code"]

    numeric_value = measurement.get("numeric_value")
    if numeric_value is not None:
        for start, end, token_id in tokenizer.numeric_lookup.get(code, []):
            if start <= numeric_value < end:
                return "numeric", token_id
        # A numeric measurement never falls back to the text or code lookups.
        return None, None

    text_value = measurement.get("text_value")
    if text_value is not None:
        encoding_type = "text"
        token_id = tokenizer.string_lookup.get((code, text_value))
    else:
        encoding_type = "code"
        token_id = tokenizer.code_lookup.get(code)

    if token_id is None:
        return None, None
    return encoding_type, token_id

In [None]:
# Per-code mapping metadata; `extract_code_stats` falls back to each entry's
# "parent_codes" when neither the code nor its ontology ancestors have a token.
with open("/home/niclas/Dokumente/thesis-daten/mapping-metadata/metadata.json") as metadata_file:
    data_metadata = json.load(metadata_file)
code_metadata = data_metadata["code_metadata"]


In [None]:
def combine_stats(stats1: dict, stats2: dict) -> dict:
    """Return a new dict whose values are the per-key sums of ``stats1`` and ``stats2``.

    Keys present in only one input are carried over unchanged. Key order is that of
    ``stats1`` followed by keys new in ``stats2``. Neither input is modified.
    """
    merged = {}
    for source in (stats1, stats2):
        for key, value in source.items():
            merged[key] = merged.get(key, 0) + value
    return merged

def add_sets(set1, set2):
    """Return the union of ``set1`` and ``set2`` as a new set; both inputs stay untouched."""
    return set1.union(set2)

def combine_extended_stats(stats1: dict, stats2: dict) -> dict:
    """Merge two results of ``extract_code_stats`` into a new dict.

    The four token-id sets are unioned via ``add_sets`` and the ``"total_counts"``
    counters are summed via ``combine_stats``.
    """
    merged = {
        key: add_sets(stats1[key], stats2[key])
        for key in ("parent_tokens", "direct_tokens", "mapping_tokens", "mapping_parent_tokens")
    }
    merged["total_counts"] = combine_stats(stats1["total_counts"], stats2["total_counts"])
    return merged

In [None]:
def _first_tokenizer_match(measurement, candidate_codes, skip=()):
    """Return ``(match_type, token_id)`` for the first candidate code the tokenizer knows.

    Each candidate is tried on a copy of ``measurement`` with its ``"code"`` replaced,
    so the caller's dict is never modified. Codes in ``skip`` are not tried.
    Returns ``(None, None)`` if no candidate matches.
    """
    for candidate in candidate_codes:
        if candidate in skip:
            continue
        match_type, token_id = get_tokenizer_match({**measurement, "code": candidate})
        if match_type is not None:
            return match_type, token_id
    return None, None


def extract_code_stats(batch):
    """Count how the measurements in ``batch`` can be matched to tokenizer tokens.

    Each measurement is resolved by the first strategy that succeeds:

    1. ``direct``: its own code (``get_tokenizer_match``).
    2. ``parent``: an ontology ancestor of the code (``ontology.get_all_parents``).
    3. ``mapping``: one of the ``"parent_codes"`` listed for the code in ``code_metadata``.
    4. ``mapping_parent``: an ontology ancestor of one of those mapped codes.

    Measurements matched by none of them, including codes absent from
    ``code_metadata``, are counted under ``"no_match"``. Every measurement is counted
    in ``"total"`` and in exactly one strategy/no-match bucket. The batch is not modified.

    Args:
        batch: A batch dict with an ``"events"`` column; each event holds a
            ``"measurements"`` list of dicts with at least a ``"code"`` key.

    Returns:
        A dict with the matched token-id sets per strategy (``"direct_tokens"``,
        ``"parent_tokens"``, ``"mapping_tokens"``, ``"mapping_parent_tokens"``) and
        ``"total_counts"``: counters keyed ``"total"``,
        ``"<strategy>_<encoding type>"`` and ``"no_match"``.
    """
    direct_ids = set()
    parent_ids = set()
    extended_ids = set()
    extended_parent_ids = set()
    stats = defaultdict(int)
    for events in batch["events"]:
        for event in events:
            for measurement in event["measurements"]:
                code = measurement["code"]
                stats["total"] += 1

                # 1. The measurement's own code.
                match_type, token_id = get_tokenizer_match(measurement)
                if match_type is not None:
                    direct_ids.add(token_id)
                    stats["direct_" + match_type] += 1
                    continue

                # 2. An ontology ancestor (the code itself already failed above).
                match_type, token_id = _first_tokenizer_match(
                    measurement, ontology.get_all_parents(code), skip={code}
                )
                if match_type is not None:
                    parent_ids.add(token_id)
                    stats["parent_" + match_type] += 1
                    continue

                # 3. A code from the mapping metadata; codes without metadata get no
                # candidates here and end up in "no_match" below.
                mapping = code_metadata.get(code)
                mapping_codes = [] if mapping is None else mapping.get("parent_codes", [])
                match_type, token_id = _first_tokenizer_match(measurement, mapping_codes)
                if match_type is not None:
                    extended_ids.add(token_id)
                    stats["mapping_" + match_type] += 1
                    continue

                # 4. An ontology ancestor of any mapped code. Stop at the first hit
                # across all mapped codes so the measurement is counted only once.
                mapping_ancestors = (
                    ancestor
                    for mapped_code in mapping_codes
                    for ancestor in ontology.get_all_parents(mapped_code)
                )
                match_type, token_id = _first_tokenizer_match(
                    measurement, mapping_ancestors, skip={code}
                )
                if match_type is not None:
                    extended_parent_ids.add(token_id)
                    stats["mapping_parent_" + match_type] += 1
                    continue

                stats["no_match"] += 1
    return {
        "parent_tokens": parent_ids,
        "direct_tokens": direct_ids,
        "mapping_tokens": extended_ids,
        "mapping_parent_tokens": extended_parent_ids,
        "total_counts": dict(stats),
    }

In [None]:
from datasets import Dataset
from femr.hf_utils import aggregate_over_dataset

# Load the parquet shards under .../mimic_meds_2.2/data (hard-coded local path).
# TODO(review): consider deriving this from `file_paths` (e.g. MIMIC_MEDS_DIR, which is
# imported above but appears unused) — confirm what that constant points to first.
dataset = Dataset.from_parquet("/home/niclas/Dokumente/thesis-daten/mimic_meds_2.2/data/*")

In [None]:
# Map `extract_code_stats` over the dataset and merge per-batch results with
# `combine_extended_stats`.
# NOTE(review): 25 and 6 are presumably the batch size and number of worker processes
# of `femr.hf_utils.aggregate_over_dataset` — confirm.
code_stats = aggregate_over_dataset(dataset, extract_code_stats, combine_extended_stats, 25, 6)

In [None]:
# Token ids per encoding type, used to classify matched tokens further below.
code_tokens = set(tokenizer.code_lookup.values())
# Each numeric lookup entry is a list of (start, end, token_id) bins. A comprehension
# keeps the loop variables local, so the builtin `id` is not shadowed at notebook scope.
numeric_tokens = {
    token_id
    for bins in tokenizer.numeric_lookup.values()
    for _start, _end, token_id in bins
}
text_tokens = set(tokenizer.string_lookup.values())

In [None]:
def get_encoding_type_distributions(tokens):
    """Count how many of ``tokens`` are code, numeric and text tokens.

    Each token is classified by the first of the module-level sets ``code_tokens``,
    ``numeric_tokens``, ``text_tokens`` (checked in that order) that contains it.
    Tokens found in none of them are ignored.

    Returns:
        A ``(text_count, numeric_count, code_count)`` tuple.
    """
    counts = {"code": 0, "numeric": 0, "text": 0}
    for token in tokens:
        for kind, members in (("code", code_tokens), ("numeric", numeric_tokens), ("text", text_tokens)):
            if token in members:
                counts[kind] += 1
                break
    return counts["text"], counts["numeric"], counts["code"]

In [None]:
# Categories and counts; the order matches the (text, numeric, code) tuples returned by
# `get_encoding_type_distributions`.
categories = ['Text', 'Numeric', 'Code']
# NOTE(review): hard-coded number of available tokens per encoding type — confirm these
# still match the tokenizer analysed here (e.g. against `text_tokens`, `numeric_tokens`,
# `code_tokens`).
possible_matches = [2961, 11183, 25667]
direct_matches = list(get_encoding_type_distributions(code_stats["direct_tokens"]))
parent_matches = list(get_encoding_type_distributions(code_stats["parent_tokens"]))
mapping_matches = list(get_encoding_type_distributions(code_stats["mapping_tokens"]))
mapping_parent_matches = list(get_encoding_type_distributions(code_stats["mapping_parent_tokens"]))

# The same token can be reached by several strategies (e.g. directly for one measurement
# and as an ancestor for another), so adding the per-strategy counts would count it more
# than once. Count the distinct tokens matched by any strategy instead.
all_matched_tokens = set().union(
    code_stats["direct_tokens"],
    code_stats["parent_tokens"],
    code_stats["mapping_tokens"],
    code_stats["mapping_parent_tokens"],
)
sum_matches = list(get_encoding_type_distributions(all_matched_tokens))

x = np.arange(len(categories))  # Label locations
width = 0.15  # Width of the bars

fig, ax = plt.subplots(figsize=(10, 6))

# Modern color palette
colors = {
    'sum': '#FFAA66',  # light orange for distinct matched tokens
    'possible': '#A0A0A0',  # gray for possible matches
    'direct': '#1f77b4',  # blue for direct matches
    'parent': '#4A90E2',  # lighter blue for parent matches
    'mapping': '#2ca02c',  # green for mapping matches
    'mapping_parent': '#8DCC85'  # lighter green for mapping parent matches
}

# Distinct matched tokens, overlaid by the translucent possible-matches bar at the same x.
bars_sums = ax.bar(x, sum_matches, width, label='Total Matches', color=colors['sum'])
bars_possible = ax.bar(x, possible_matches, width, label='Possible Matches', color=colors['possible'], alpha=0.3)

# Value label above each total bar.
for bar, matched in zip(bars_sums, sum_matches):
    ax.text(bar.get_x() + bar.get_width() / 2.0, bar.get_height(), f'{matched}', ha='center', va='bottom', color='black', fontsize=10, fontweight='bold')

# One bar per matching strategy, to the right of the total/possible bars.
ax.bar(x + width, direct_matches, width, label='Direct Matches', color=colors['direct'])
ax.bar(x + 2 * width, parent_matches, width, label='Parent Matches', color=colors['parent'])
ax.bar(x + 3 * width, mapping_matches, width, label='Mapping Matches', color=colors['mapping'])
ax.bar(x + 4 * width, mapping_parent_matches, width, label='Mapping Parent Matches', color=colors['mapping_parent'])

# Add labels and legend
ax.set_xlabel('Encoding Type')
ax.set_ylabel('Encoding Type Count')
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.legend()

# Display the plot
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.savefig(os.path.join(MAPPING_DIR, "matches.pdf"))
plt.show()


In [None]:
from utility.data import load_dataset

# Cohort dataset for the per-tokenizer comparison below (hard-coded local path).
# NOTE: this rebinds `dataset`, replacing the parquet dataset loaded earlier.
dataset = load_dataset("/home/niclas/Dokumente/cluster_data/correct_reduced_cohort/cohort")



In [None]:
from models.univeral_tokenizer import UniversalTokenizer

# The three tokenizers compared below, selected by `get_correct_tokenizer` as
# "base", "mimic" and "ime". The latter two are loaded from hard-coded local
# checkpoint directories (presumably the MIMIC and lab pretraining runs — confirm).
clmbr_t_base_tokenizer = get_tokenizer(None)
clmbr_t_mimic_tokenizer = FEMRTokenizer.from_pretrained("/home/niclas/Dokumente/cluster_data/pretraining_mimic/fm")
clmbr_t_ime_tokenizer = UniversalTokenizer.from_pretrained("/home/niclas/Dokumente/cluster_data/pretraining_lab/fm")

In [None]:
def get_tokenizer_match(measurement, tokenizer):
    """Return True if ``tokenizer`` yields at least one feature for ``measurement``.

    NOTE: this replaces the earlier one-argument ``get_tokenizer_match``; cells that
    call the old signature (e.g. ``extract_code_stats``) will fail if re-run after this.
    """
    feature_ids, _unused = tokenizer.get_feature_codes(None, measurement)
    return len(feature_ids) != 0

In [None]:
import pickle
# Rebind `ontology` to a pickled ontology (per the path, the adjusted mapping for the
# reduced cohort); `extract_mec_stats` uses it for ancestor lookups.
# NOTE: unpickling can execute arbitrary code — only load files you trust.
with open("/home/niclas/Dokumente/cluster_data/adjusted_mapping_reduced_cohort/ontology.pkl", "rb") as ontology_file:
    ontology = pickle.loads(ontology_file.read())

In [None]:
def get_correct_tokenizer(type):
    """Return the tokenizer registered under ``type`` ("base", "mimic" or "ime").

    Any other value yields ``None``.
    """
    if type == "base":
        selected = clmbr_t_base_tokenizer
    elif type == "mimic":
        selected = clmbr_t_mimic_tokenizer
    elif type == "ime":
        selected = clmbr_t_ime_tokenizer
    else:
        selected = None
    return selected

def extract_mec_stats(batch):
    """Count, per tokenizer, how many measurements in ``batch`` it can encode.

    For every measurement and each tokenizer type ("base", "mimic", "ime"):

    * it is counted in ``"total"``;
    * it is counted in ``"direct"`` if the tokenizer matches its own code;
    * otherwise it is counted in ``"parent"`` if the tokenizer matches any ontology
      ancestor of the code (``ontology.get_all_parents``).

    Ancestors are tried on a copy of the measurement. Every tokenizer therefore sees
    the original code, and the batch is left unmodified. Previously, the "base"
    parent search overwrote ``measurement["code"]`` and skewed the "mimic" and "ime"
    checks.

    Args:
        batch: A batch dict with an ``"events"`` column whose events hold a
            ``"measurements"`` list of dicts with a ``"code"`` key.

    Returns:
        ``{tokenizer_type: {"total": int, "direct": int, "parent": int}}``.
    """
    tokenizer_types = ("base", "mimic", "ime")
    stats = {name: {"total": 0, "direct": 0, "parent": 0} for name in tokenizer_types}
    for events in batch["events"]:
        for event in events:
            for measurement in event["measurements"]:
                code = measurement["code"]
                # The ancestors are the same for every tokenizer: fetch them lazily, once.
                parent_codes = None
                for tokenizer_type in tokenizer_types:
                    tokenizer = get_correct_tokenizer(tokenizer_type)
                    type_stats = stats[tokenizer_type]
                    type_stats["total"] += 1
                    if get_tokenizer_match(measurement, tokenizer):
                        type_stats["direct"] += 1
                        continue
                    if parent_codes is None:
                        parent_codes = [
                            parent for parent in ontology.get_all_parents(code) if parent != code
                        ]
                    if any(
                        get_tokenizer_match({**measurement, "code": parent}, tokenizer)
                        for parent in parent_codes
                    ):
                        type_stats["parent"] += 1
    return stats

def combine_mec_stats(stats1: dict, stats2: dict) -> dict:
    """Merge two ``extract_mec_stats`` results by summing each tokenizer's counters."""
    merged = {}
    for tokenizer_type in ("base", "mimic", "ime"):
        merged[tokenizer_type] = combine_stats(stats1[tokenizer_type], stats2[tokenizer_type])
    return merged

In [None]:
# Run `extract_mec_stats` over the cohort and merge batch results with `combine_mec_stats`.
# NOTE(review): 50 and 8 are presumably the batch size and number of worker processes
# of `femr.hf_utils.aggregate_over_dataset` — confirm.
encoding_stats = aggregate_over_dataset(dataset, extract_mec_stats, combine_mec_stats, 50, 8)

In [None]:
# Per tokenizer, report:
#   MEC:   the share of measurements matched by their own code;
#   MEC_O: the same share, plus ancestor-code matches weighted by 0.5.
# The loop variable is not called `type`, so the builtin is not shadowed for later cells.
for tokenizer_type in ["base", "mimic", "ime"]:
    type_stats = encoding_stats[tokenizer_type]
    total = type_stats["total"]
    if total == 0:
        # No measurements were counted for this tokenizer; avoid ZeroDivisionError.
        mec = mec_o = float("nan")
    else:
        mec = type_stats["direct"] / total
        mec_o = (type_stats["direct"] + 0.5 * type_stats["parent"]) / total
    print(f"Type: {tokenizer_type}")
    print(f"MEC: {mec}")
    print(f"MEC_O: {mec_o}")