# Propagation of Disease-Demographic Co-occurrences to Model Logits


## Set up

**Just run this part**


### Paths and Dictionaries


In [30]:
import os
import pandas as pd
import numpy as np
import json
import sys
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import statsmodels.api as sm
import statsmodels.formula.api as smf
import statsmodels.api as sm
from statsmodels.formula.api import ols
from scipy.stats import kendalltau, rankdata
import rbo
from collections import Counter

In [31]:
project_root_relative_path = "../.."  # Adjust this path as necessary

# Use os.getcwd() to get the current working directory of the notebook
current_dir = os.getcwd()

# Construct the path to the root of the Cross-Care project
cross_care_root = os.path.normpath(
    os.path.join(current_dir, project_root_relative_path)
)

# Add the Cross-Care root to sys.path to allow imports
if cross_care_root not in sys.path:
    sys.path.append(cross_care_root)

print("Project root added to sys.path:", cross_care_root)

from dicts.dict_medical import medical_keywords_dict

Project root added to sys.path: /home/legionjgally/Desktop/mit/Cross-Care


In [32]:
race_categories = [
    "pacific islander",
    "hispanic",
    "asian",
    "indigenous",
    "white",
    "black",
]
gender_categories = [
    "male",
    "female",
    "nonbinary",
]
census_ratio = {
    "white": 61.6,
    "black": 12.6,
    "indigenous": 1.1,
    "asian": 6,
    "pacific islander": 0.2,
    "hispanic": 16.3,
}

### Load and Preprocessing


In [33]:
def load_and_combine_logits(
    models,
    root_path,
    dataset,
    demographic,
    debug=False,
):
    combined_df = pd.DataFrame()

    for model_name in models:
        # Generate the path for the current model's logits data
        logits_data_path = f"{root_path}/output_{dataset}/v2/logits/{model_name.replace('/', '_')}/logits_{demographic}.json"

        # Check if the file exists to avoid errors
        if os.path.exists(logits_data_path):
            with open(logits_data_path, "r") as f:
                data = json.load(f)

            # Convert the data into a DataFrame
            logit_df = pd.DataFrame(data)
            # Add a column for the model name
            logit_df["model_name"] = model_name

            # Append the current DataFrame to the combined DataFrame
            combined_df = pd.concat([combined_df, logit_df], ignore_index=True)
            print(f"Loaded logits data for model: {model_name}")
        else:
            print(f"Logits data file not found for model: {model_name}")

    disease_names = list(combined_df.keys())
    disease_names.remove("model_name")

    return combined_df


def reshape_logit_df(combined_df):
    reshaped_data = []

    # Iterate over each row in the DataFrame
    for index, row in combined_df.iterrows():
        model_name = row["model_name"]  # Extract the model name

        # Iterate over each disease column, excluding 'model_name'
        for disease in combined_df.columns[:-1]:
            # Directly use the value assuming it's already a list or another iterable type
            demographic_logit_pair = row[disease]

            if isinstance(demographic_logit_pair, list):
                demographic_category = demographic_logit_pair[0]
                logit_value = demographic_logit_pair[1]

                reshaped_data.append(
                    {
                        "disease": disease,
                        "demographic": demographic_category,
                        "logit_value": logit_value,
                        "model_name": model_name,
                        "model_size": model_size_mapping[model_name],
                    }
                )

    # Convert the list of dictionaries into a DataFrame
    reshaped_df = pd.DataFrame(reshaped_data)

    # reshape logits from a list to a row per logit-template pair
    reshaped_logits_df = reshape_logits_per_template(reshaped_df)

    return reshaped_logits_df


def reshape_logits_per_template(combined_df):
    reshaped_data = []

    # Iterate over each row in the DataFrame
    for index, row in combined_df.iterrows():
        disease = row["disease"]
        demographic = row["demographic"]
        model_name = row["model_name"]
        model_size = row["model_size"]
        logits = row["logit_value"]

        # Iterate over each logit in the logits list
        for template, logit in enumerate(logits):
            reshaped_data.append(
                {
                    "disease": disease,
                    "demographic": demographic,
                    "logit_value": logit,
                    "model_name": model_name,
                    "model_size": model_size,
                    "template": template,
                }
            )

    # Convert the list of dictionaries into a DataFrame
    reshaped_df = pd.DataFrame(reshaped_data)

    return reshaped_df

In [34]:
def group_mean_logits(df, group_ranges):
    grouped_data = []

    for group_id, (start, end) in enumerate(group_ranges, start=1):
        # Filter the DataFrame based on the template range
        group_df = df[(df["template"] >= start) & (df["template"] <= end)]

        # Calculate the mean logit value for each group, disease, demographic, model name, and model size
        mean_df = (
            group_df.groupby(["disease", "demographic", "model_name", "model_size"])
            .agg({"logit_value": "mean"})
            .reset_index()
        )
        mean_df["group"] = f"{group_id}"

        grouped_data.append(mean_df)

    # Concatenate the results for all groups
    result_df = pd.concat(grouped_data)

    # make group int
    result_df["group"] = result_df["group"].astype(int)

    return result_df

In [35]:
def replace_disease_codes(df, medical_keywords_dict):
    for index, row in df.iterrows():
        disease = row["disease"]
        # Check if the last two characters are '.0'
        if isinstance(disease, str) and disease.endswith(".0"):
            # Lookup the code in the dictionary and get the first name
            name_list = medical_keywords_dict.get(disease)
            if name_list:
                df.at[index, "disease"] = name_list[0]
    return df


def load_cooccurrence_data(cross_care_root, dataset, window, demographic, debug=False):
    # Load co-occurrence data
    counts_data_path = f"{cross_care_root}/output_{dataset}/aggregated_counts/aggregated_{demographic}_{window}.csv"
    counts_df = pd.read_csv(counts_data_path)

    if debug:
        counts_df = counts_df.head(10)

    demographic_mapping = {
        "white/caucasian": "white",
        "black/african american": "black",
        "hispanic/latino": "hispanic",
        "asian": "asian",
        "native american/indigenous": "indigenous",
        "pacific islander": "pacific islander",
    }

    # Rename the columns
    counts_df = counts_df.rename(
        columns={
            "Disease": "disease",
            "Demographics": "demographic",
            "Counts": "mention_count",
        }
    )
    # Map the demographics to the simplified names
    counts_df["demographic"] = counts_df["demographic"].map(demographic_mapping)

    # Replace disease codes with names
    counts_df = replace_disease_codes(counts_df, medical_keywords_dict)

    return counts_df

In [36]:
def calculate_deciles(df, column_name, decile_column_name):
    # Calculate deciles (1=lowest, 10=highest). Use 'duplicates="drop"' to handle bins with the same edges.
    df[decile_column_name] = (
        pd.qcut(df[column_name], 10, labels=False, duplicates="drop") + 1
    )

In [37]:
def format_data(combined_df):
    # NUMERICS
    combined_df["mention_count"] = pd.to_numeric(
        combined_df["mention_count"], errors="coerce"
    )

    combined_df["logit_value"] = pd.to_numeric(
        combined_df["logit_value"], errors="coerce"
    )
    combined_df["model_size"] = pd.to_numeric(
        combined_df["model_size"], errors="coerce"
    )

    # CATEGORICALS
    combined_df["demographic"] = combined_df["demographic"].astype("category")
    combined_df["disease"] = combined_df["disease"].astype("category")

    # create basic stats_df
    combined_df.dropna(inplace=True)
    stats_df = combined_df.copy()

    # sort by disease, model_size
    stats_df = stats_df.sort_values(by=["disease", "model_size", "template", "window"])

    return stats_df

In [38]:
def add_normalization_by_total_disease_counts(counts_df, total_counts_csv):
    # Load total disease counts
    total_counts_df = pd.read_csv(total_counts_csv)

    # Merge the total counts into the co-occurrence DataFrame
    counts_df = pd.merge(counts_df, total_counts_df, on="disease", how="left")

    # Perform normalization and add as a new column
    counts_df["normalized_by_total_counts"] = (
        counts_df["mention_count"] / counts_df["total_count"]
    ) * 100

    # You may choose to drop the 'total_count' column if it's no longer needed
    counts_df = counts_df.drop(columns=["total_count"])

    return counts_df


def add_windowed_normalization(
    cross_care_root, dataset, demographic, windows, census_ratio, demographic_categories
):
    all_windows_df = pd.DataFrame()
    demographic_categories = [
        "asian",
        "black",
        "hispanic",
        "indigenous",
        "pacific islander",
        "white",
    ]
    for window in windows:
        window_counts_df = load_cooccurrence_data(
            cross_care_root, dataset, window, demographic
        )
        print(f"Loaded co-occurrence data for window: {window}")

        # Ensure all disease-demographic pairs are present
        unique_diseases = window_counts_df["disease"].unique()
        complete_rows = []
        for disease in unique_diseases:
            for demo in demographic_categories:
                if not (
                    (window_counts_df["disease"] == disease)
                    & (window_counts_df["demographic"] == demo)
                ).any():
                    # Add missing disease-demographic pair with mention_count 0
                    complete_rows.append(
                        {
                            "disease": disease,
                            "demographic": demo,
                            "mention_count": 0,
                            "window": window,
                        }
                    )

        # If there are complete rows to add, concatenate them with the current window data
        if complete_rows:
            complete_df = pd.DataFrame(complete_rows)
            window_counts_df = pd.concat(
                [window_counts_df, complete_df], ignore_index=True
            )

        window_counts_df = add_normalization_by_disease_demo_mentions(
            window_counts_df, census_ratio
        )
        window_counts_df["window"] = window

        all_windows_df = pd.concat(
            [all_windows_df, window_counts_df], ignore_index=True
        )

    all_windows_df.sort_values(by=["disease", "window"], inplace=True)
    return all_windows_df


def add_normalization_by_disease_demo_mentions(counts_df, census_ratio):
    # Calculate the total mention count across all demographics for each disease
    total_by_disease = (
        counts_df.groupby("disease")["mention_count"].sum().rename("total_demo_count")
    )

    # Merge this total back into the original DataFrame
    counts_df = counts_df.merge(total_by_disease, on="disease", how="left")

    # Perform normalization and add as a new column
    counts_df["normalized_by_demo_mentions"] = (
        counts_df["mention_count"] / counts_df["total_demo_count"]
    ) * 100

    # Add a column for relative census representation
    counts_df["relative_census_representation"] = (
        (
            counts_df["normalized_by_demo_mentions"]
            - counts_df["demographic"].map(census_ratio)
        )
        / counts_df["demographic"].map(census_ratio)
    ) * 100

    return counts_df

### Ranking


In [39]:
def plot_logits_group_comparison(df, title="Logit Group Comparisons"):
    """
    Plots the logits for group comparisons using Plotly and returns the figure object.

    Parameters:
    - df: DataFrame containing the logit data with columns ['disease', 'demographic', 'model_name', 'model_size', 'logit_value', 'group'].
    - title: Title for the plot.

    Returns:
    - Plotly subplot figure object.
    """
    # Get the unique groups
    groups = df["group"].unique()
    num_groups = len(groups)

    # Get unique demographics and assign colors
    demographics = df["demographic"].unique()
    color_map = px.colors.qualitative.Plotly[: len(demographics)]
    color_dict = {
        demographic: color for demographic, color in zip(demographics, color_map)
    }

    # Create a subplot layout with 1 row and num_groups columns
    fig = make_subplots(
        rows=1, cols=num_groups, subplot_titles=[f"Group {group}" for group in groups]
    )

    # Determine the global minimum and maximum logit values
    global_min = df["logit_value"].min()
    global_max = df["logit_value"].max()

    # Loop through each group and plot the mean logit values
    for i, group in enumerate(groups, start=1):
        group_df = df[df["group"] == group]

        # Calculate the mean logit value for each disease and demographic
        mean_df = (
            group_df.groupby(["disease", "demographic"])
            .agg({"logit_value": "mean"})
            .reset_index()
        )

        # Sort the data by disease
        mean_df = mean_df.sort_values(by="disease")

        # Add the bar plot to the subplot
        for demographic in mean_df["demographic"].unique():
            demographic_df = mean_df[mean_df["demographic"] == demographic]
            fig.add_trace(
                go.Bar(
                    x=demographic_df["disease"],
                    y=demographic_df["logit_value"],
                    name=demographic,
                    marker_color=color_dict[demographic],
                    showlegend=(i == num_groups),  # Show legend only for the last plot
                ),
                row=1,
                col=i,
            )

    # Update layout
    fig.update_layout(
        title_text=title,
        autosize=False,
        width=700 * num_groups,  # Adjusted width for multiple subplots
        height=800,
        barmode="group",
        legend_title="Demographic",
        legend=dict(orientation="v", yanchor="middle", y=0.5, xanchor="left", x=1.05),
    )

    # Set consistent y-axis range for all subplots
    fig.update_yaxes(range=[global_min, global_max])

    fig.show()


def plot_counts_for_window(counts_df, count="raw", window="10"):
    # Filter the DataFrame for the specified window
    window_df = counts_df[counts_df["window"] == window]
    if count == "raw":
        count_column = "mention_count"
        title = f"Co-occurrence Counts by Demographic Categories for Various Diseases"
    elif count == "normalized":
        count_column = "normalized_by_demo_mentions"
        title = f"Normalized Co-occurrence Counts by Demographic Categories for Various Diseases"
    elif count == "census":
        count_column = "relative_census_representation"
        title = f"Relative Census Representation by Demographic Categories for Various Diseases"
    else:
        raise ValueError("Invalid count type. Use 'raw' or 'normalized' or 'census.")

    # Create the bar plot using Plotly Express
    fig = px.bar(
        window_df,
        x="disease",
        y=count_column,
        color="demographic",
        barmode="group",
        title=f"{title} (Window {window})",
    )

    # Customize the layout
    fig.update_layout(
        xaxis_title="Disease",
        yaxis_title="Co-occurrence Count",
        legend_title="Demographic",
        autosize=False,
        width=1400,
        height=800,
    )

    fig.update_xaxes(categoryorder="total descending")
    fig.show()

### **Default settings**

Run all models <br>
Demographics= Race


In [40]:
dataset = "pile"
demographic = "race"
debug = False

models = [
    "EleutherAI/pythia-70m-deduped",
    "EleutherAI/pythia-160m-deduped",
    "EleutherAI/pythia-410m-deduped",
    "EleutherAI/pythia-1b-deduped",
    "EleutherAI/pythia-2.8b-deduped",
    "EleutherAI/pythia-6.9b-deduped",
    "EleutherAI/pythia-12b-deduped",
    "state-spaces/mamba-130m",
    "state-spaces/mamba-370m",
    "state-spaces/mamba-790m",
    "state-spaces/mamba-1.4b",
    "state-spaces/mamba-2.8b-slimpj",
    # "state-spaces/mamba-2.8b"
]

model_size_mapping = {
    "EleutherAI/pythia-70m-deduped": 70,
    "EleutherAI/pythia-160m-deduped": 160,
    "EleutherAI/pythia-410m-deduped": 410,
    "EleutherAI/pythia-1b-deduped": 1000,  # 1 billion parameters = 1000 million
    "EleutherAI/pythia-2.8b-deduped": 2800,  # 2.8 billion parameters = 2800 million
    "EleutherAI/pythia-6.9b-deduped": 6900,  # 6.9 billion parameters = 6900 million
    "EleutherAI/pythia-12b-deduped": 12000,  # 12 billion parameters = 12000 million
    "state-spaces/mamba-130m": 130,
    "state-spaces/mamba-370m": 370,
    "state-spaces/mamba-790m": 790,
    "state-spaces/mamba-1.4b": 1400,
    "state-spaces/mamba-2.8b-slimpj": 2800,
    "state-spaces/mamba-2.8b": 2800,
}

In [41]:
# set demographic categories and disease names
if demographic == "race":
    demographic_categories = race_categories
else:
    demographic_categories = gender_categories

# Create Combined Logit-Count df


## Load logits


In [42]:
combined_df = load_and_combine_logits(
    models,
    cross_care_root,
    dataset,
    demographic,
    debug,
)
combined_logits_df = reshape_logit_df(combined_df)
combined_logits_df

Loaded logits data for model: EleutherAI/pythia-70m-deduped
Loaded logits data for model: EleutherAI/pythia-160m-deduped
Loaded logits data for model: EleutherAI/pythia-410m-deduped
Loaded logits data for model: EleutherAI/pythia-1b-deduped
Loaded logits data for model: EleutherAI/pythia-2.8b-deduped
Loaded logits data for model: EleutherAI/pythia-6.9b-deduped
Loaded logits data for model: EleutherAI/pythia-12b-deduped
Loaded logits data for model: state-spaces/mamba-130m
Loaded logits data for model: state-spaces/mamba-370m
Loaded logits data for model: state-spaces/mamba-790m
Loaded logits data for model: state-spaces/mamba-1.4b
Logits data file not found for model: state-spaces/mamba-2.8b-slimpj


Unnamed: 0,disease,demographic,logit_value,model_name,model_size,template
0,hiv/aids,hispanic,-81.559959,EleutherAI/pythia-70m-deduped,70,0
1,hiv/aids,hispanic,-85.460709,EleutherAI/pythia-70m-deduped,70,1
2,hiv/aids,hispanic,-133.617447,EleutherAI/pythia-70m-deduped,70,2
3,hiv/aids,hispanic,-135.471909,EleutherAI/pythia-70m-deduped,70,3
4,hiv/aids,hispanic,-96.987312,EleutherAI/pythia-70m-deduped,70,4
...,...,...,...,...,...,...
122755,arrhythmia,pacific islander,-181.375000,state-spaces/mamba-1.4b,1400,15
122756,arrhythmia,pacific islander,-201.625000,state-spaces/mamba-1.4b,1400,16
122757,arrhythmia,pacific islander,-183.500000,state-spaces/mamba-1.4b,1400,17
122758,arrhythmia,pacific islander,-246.875000,state-spaces/mamba-1.4b,1400,18


## Load Co-occurrences of Demographic-diseases in The Pile


<details>
<summary><b>Normalization by Total Mentions of Disease</b></summary>

Normalization of mention counts relative to the total mentions of the disease across all demographics provides a way to assess the prominence of a disease within specific demographic groups in comparison to its overall discussion frequency.

**Formula:**
The normalization formula for this approach is:

$$
\text{Normalized Mention Count} = \left( \frac{\text{Mention Count of Disease with Demographic}}{\text{Total Mention Count of Disease with and without demographics}} \right) \times 100
$$

</details>

<details>
<summary><b>Normalization by Total Mentions of Disease When Any Demographic is Mentioned</b></summary>

This method focuses on normalizing the mention counts of a disease within demographic-specific discussions against the total mentions of that disease when any demographic term is mentioned. It highlights how frequently a disease is associated with specific demographic groups in the context of broader demographic discussions.

**Formula:**
The normalization formula used is:

$$
\text{Normalized Mention Count} = \left( \frac{\text{Mention Count of Disease with Demographic}}{\text{Total Mention Count of Disease with Any Demographic}} \right) \times 100
$$

</details>

<details>
<summary><b>No Normalization (Raw Counts)</b></summary>

In some analyses, raw mention counts are used without any normalization. This approach provides the absolute frequency of disease mentions within demographic-specific contexts or overall, without adjusting for disparities in mention volumes across different demographics or diseases.

**Explanation:**
No normalization means the raw mention counts are directly compared or analyzed. This can be useful for understanding the volume of discussion but may require careful interpretation when comparing diseases or demographics with widely varying baseline mention frequencies.

</details>

<details>
<summary><b>Relative Census Representation</b></summary>

This approach involves comparing the normalized mention counts of diseases within demographic groups to the respective demographic representation in the census. It provides insight into whether certain demographics are over- or underrepresented in disease discussions relative to their population size.

**Formula:**
The formula for calculating the relative census representation is:

$$
\text{Relative Census Representation} = \left( \frac{\text{Normalized Mention Count} - \text{Census Percentage}}{\text{Census Percentage}} \right) \times 100
$$

**Explanation:**
A positive value indicates overrepresentation in disease discussions compared to the census, while a negative value indicates underrepresentation.

</details>


In [43]:
# Get counts for each disease and demographic and window
windows = ["10", "50", "100", "250"]
all_windows_df = add_windowed_normalization(
    cross_care_root, dataset, demographic, windows, census_ratio, demographic_categories
)

all_windows_df.head(20)

Loaded co-occurrence data for window: 10
Loaded co-occurrence data for window: 50
Loaded co-occurrence data for window: 100
Loaded co-occurrence data for window: 250


Unnamed: 0,disease,demographic,mention_count,window,total_demo_count,normalized_by_demo_mentions,relative_census_representation
387,als,asian,11,10,383,2.872063,-52.132289
388,als,black,96,10,383,25.065274,98.930747
389,als,hispanic,9,10,383,2.349869,-85.583623
390,als,indigenous,13,10,383,3.394256,208.568716
391,als,white,254,10,383,66.318538,7.659964
538,als,pacific islander,0,10,383,0.0,-100.0
1538,als,asian,133,100,2846,4.673226,-22.112907
1539,als,black,816,100,2846,28.67182,127.554128
1540,als,hispanic,88,100,2846,3.092059,-81.030313
1541,als,indigenous,64,100,2846,2.24877,104.433655


## Compare Co-occurrences to Model Logits


**Overall df**


In [44]:
# Merge the overall dataset across windows and templates
combined_df = pd.merge(
    combined_logits_df, all_windows_df, on=["disease", "demographic"], how="inner"
)

combined_df = format_data(combined_df)

combined_df.head(20)

Unnamed: 0,disease,demographic,logit_value,model_name,model_size,template,mention_count,window,total_demo_count,normalized_by_demo_mentions,relative_census_representation
73700,als,hispanic,-71.686768,EleutherAI/pythia-70m-deduped,70,0,9,10,383,2.349869,-85.583623
153560,als,black,-61.771404,EleutherAI/pythia-70m-deduped,70,0,96,10,383,25.065274,98.930747
233420,als,asian,-62.411766,EleutherAI/pythia-70m-deduped,70,0,11,10,383,2.872063,-52.132289
313280,als,white,-63.56736,EleutherAI/pythia-70m-deduped,70,0,254,10,383,66.318538,7.659964
393140,als,indigenous,-72.332314,EleutherAI/pythia-70m-deduped,70,0,13,10,383,3.394256,208.568716
473000,als,pacific islander,-86.192299,EleutherAI/pythia-70m-deduped,70,0,0,10,383,0.0,-100.0
73701,als,hispanic,-71.686768,EleutherAI/pythia-70m-deduped,70,0,88,100,2846,3.092059,-81.030313
153561,als,black,-61.771404,EleutherAI/pythia-70m-deduped,70,0,816,100,2846,28.67182,127.554128
233421,als,asian,-62.411766,EleutherAI/pythia-70m-deduped,70,0,133,100,2846,4.673226,-22.112907
313281,als,white,-63.56736,EleutherAI/pythia-70m-deduped,70,0,1737,100,2846,61.033029,-0.920408


In [45]:
# Save the combined DataFrame to a CSV file
combined_df.to_csv(
    f"{cross_care_root}/src/logits/combined_df_{demographic}.csv", index=False
)

## Visualize Logits


### Compare groups of templates


In [46]:
# create template groups
group_ranges = [(1, 10), (11, 20)]
grouped_mean_logits = group_mean_logits(combined_df, group_ranges)
grouped_mean_logits.head()

Unnamed: 0,disease,demographic,model_name,model_size,logit_value,group
0,als,asian,EleutherAI/pythia-12b-deduped,70,,1
1,als,asian,EleutherAI/pythia-12b-deduped,130,,1
2,als,asian,EleutherAI/pythia-12b-deduped,160,,1
3,als,asian,EleutherAI/pythia-12b-deduped,370,,1
4,als,asian,EleutherAI/pythia-12b-deduped,410,,1


In [47]:
# Comparisons
plot_logits_group_comparison(grouped_mean_logits, title="Logit Group Comparisons")

## Visualise co-occurrences


In [48]:
# Plot the counts for a specific window
## windows- "10", "50", "100", "250"
## count- "raw", "normalized"

# plot_counts_for_window(all_windows_df, count="raw", window="10")
# plot_counts_for_window(all_windows_df, count="normalized", window="50")
plot_counts_for_window(combined_df, count="census", window="100")