# Causality mining using OpenAI API


## Loading packages and data

In [1]:
from openai import OpenAI
from tqdm import tqdm
import pandas as pd
# from sentence_transformers import SentenceTransformer, util
# from sklearn.cluster import AgglomerativeClustering
# from sklearn.metrics.pairwise import cosine_distances
from IPython.display import display, Markdown
from dotenv import load_dotenv
import os
import re

In [3]:
# Read API key from the .env file
# load_dotenv() is called before the key is read so that OPENAI_API_KEY can be
# picked up via os.getenv; `client` is the module-level handle used by every
# API-calling function in this notebook.
load_dotenv()
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

In [2]:
# Load the report corpus, then derive one subset per InformationType of interest.
who_data = pd.read_csv("../data/corpus.csv")
who_data_epi = who_data.loc[who_data["InformationType"].eq("Epidemiology")]
who_data_assessment = who_data.loc[who_data["InformationType"].eq("Assessment")]
who_data_epi_and_assessment = who_data.loc[
    who_data["InformationType"].isin(["Epidemiology", "Assessment"])
]

In [2]:
# Import the driver-clustering workbook "3. Clustering Drivers.xlsx", sheet "V2_Peter"
driver_cat = pd.read_excel("../data/3. Clustering Drivers.xlsx", sheet_name="V2_Peter")

## Function definitions

### Prompt Generation

In [4]:
class PromptDesigner:
    """Hold the reusable sections of the causality-extraction prompt.

    Each section is stored as a plain string attribute so that experiments can
    switch sections on and off through ``generate_prompt``.
    """

    def __init__(self):
        """Populate every prompt section with its fixed text."""
        # Role and task given to the model
        self.persona_task_description = """
        You are an epidemiologist tasked with identifying sentences or phrases from outbreak reports that describe the drivers or contributors to the emergence or transmission of emerging pests and pathogens.
        """

        # Definition of the DPSIR framework used to frame what a "driver" is
        self.domain_localization = """
        Here is the definition of DPSIR (Drivers, Pressure, State, Impacts, and Responses) framework, where it shows how drivers are associated with the emergence of disease.
        Drivers: underlying socio-economic, environmental, or ecological forces that create conditions favourable for how a disease emerges, spreads or sustains transmission in human, animals or plants.
        Pressure: human anthropogenic activities that are mainly responsible for the chances of spillover events and the transmission of pests and pathogens.
        State: the current circulation of pests and pathogens, represented as either new case detected, an endemic, an epidemic or a pandemic.
        Impacts: the effects caused by pests and pathogens on individuals, communities' socio-economic, and political.
        Responses: the actions and interventions taken by governments to manage the occurrence of drivers and pressures, and to control the spread of pests and pathogens and to mitigate the impacts.
        """

        # Intra- versus inter-sentence causality
        self.causality_definition = """
        Causality definition: In the reports, causality can take two forms. The first form is "Intra-sentence causality", where the “cause” and the “effect” lie in a single sentence, while in "Inter-sentence causality", the “cause” and the “effect” lie in different sentences.
        """
        
        # One fully worked input/output example showing the expected two-part answer
        self.extraction_guide = """
        Input text: The sudden appearance of unlinked cases of mpox in South Africa without a history of international travel, the high HIV prevalence among confirmed cases, and the high case fatality ratio suggest that community transmission is underway, and the cases detected to date represent a small proportion of all mpox cases that might be occurring in the community; it is unknown how long the virus may have been circulating. This may in part be due to the lack of early clinical recognition of an infection with which South Africa previously gained little experience during the ongoing global outbreak, potential pauci-symptomatic manifestation of the disease, or delays in care-seeking behaviour due to limited access to care or fear of stigma.
        
        Expected output
        1. Raw text with marked causes and effects
        The sudden appearance of unlinked cases of mpox in South Africa without a history of international travel, the high HIV prevalence among confirmed cases, and the high case fatality ratio suggest that (E1) community transmission (E1) is underway, and the cases detected to date represent a small proportion of all mpox cases that might be occurring in the community; it is unknown how long the virus may have been circulating. This may in part be due to the (C1) lack of early clinical recognition of an infection (C1) with which South Africa previously gained little experience during the ongoing global outbreak, potential (C1) pauci-symptomatic manifestation of the disease (C1), or (C1, E2) delays in care-seeking behavior (C1, E2) due to (C2) limited access to care (C2) or (C2) fear of stigma (C2).
       
        2. Extracted causes and effects
        C1: lack of early clinical recognition of an infection -> E1: community transmission 
        C1: pauci-symptomatic manifestation of the disease -> E1: community transmission 
        C1: delays in care-seeking behavior -> E1: community transmission 
        C2: limited access to care -> E2: delays in care-seeking behaviour
        C2: fear of stigma -> E2: delays in care-seeking behaviour
        """

        # Few-shot examples covering the four cause/effect cardinality types
        self.few_shot_examples = """
        Below are some examples how causality can be reported in different forms:
        - Single cause, single effect (Type 1)

        Example 1: (C1) High population density and mobility in urban areas (C1) have facilitated (E1) the rapid spread of the virus (E1). 

        Example 2: There is (C1) no vaccine for Influenza A(H1N1)v infection currently licensed for use in humans (C1). Seasonal influenza vaccines against human influenza viruses are generally not expected to protect people from (E1) infection with influenza viruses (E1) that normally circulate in pigs, but they can reduce severity.


        - Single cause, multiple effects (Type 2)

        Example 3: Several countries including Cameroon, Ethiopia, Haiti, Lebanon, Nigeria (north-east of the country), Pakistan, Somalia, Syria and the Democratic Republic of Congo (eastern part of the country) are in the midst of complex (C1) humanitarian crises (C1) with (E1) fragile health systems (E1), (E1) inadequate access to clean water and sanitation (E1) and have (E1) insufficient capacity to respond to the outbreaks (E1)

        - Multiple causes, single effect (Type 3)
        Example 4: Moreover, (C1) a low index of suspicion (C1), (C1) socio-cultural norms (C1), (C1) community resistance (C1), (C1) limited community knowledge regarding anthrax transmission (C1), (C1) high levels of poverty (C1) and (C1) food insecurity (C1), (C1) a shortage of available vaccines and laboratory reagents (C1), (C1) inadequate carcass disposal (C1) and (C1) decontamination practices (C1) significantly contribute to hampering (E1) the containment of the anthrax outbreak (E1).

        Example 5:
        The (E1) risk at the national level (E1) is assessed as 'High' due to the following:
        + In other parts of Timor-Leste (C1) health workers have limited knowledge dog bite and scratch case management (C1) including PEP and RIG administration
        + (C2) Insufficient stock of human rabies vaccines (C2) in the government health facilities.

        - Multiple causes, multiple effects (Type 4) - Chain of causalities
        The text may describe a chain of causality, where one effect becomes then the cause of another effect. To describe the chain, you should number the causes and effects. For example, cause 1 (C1) -> effect 1 (E1), but since effect 1 is also cause of effect 2, you should do cause 1 (C1) -> effect 1 (E1, C2) -> effect 2 (E2). 

        Example 6: (E2) The risk of insufficient control capacities (E2) is considered high in Zambia due to (C1) concurrent public health emergencies in the country (cholera, measles, COVID-19) (C1) that limit the country’s human and (E1, C2) financial capacities to respond to the current anthrax outbreak adequately (E1, C2).

        Example 7: (C1) Surveillance systems specifically targeting endemic transmission of chikungunya or Zika are weak or non-existent (C1) -> (E1, C2) Misdiagnosis between diseases  & Skewed surveillance (E1, C2) -> (E2, C3) Misinform policy decisions (E2, C3) -> (E3)reduced accuracy on the estimation of the true burden of each diseases (E3), poor risk assessments (E3), and non optimal clinical management and resource allocation (E3). 

        Example 8: (C1) Changes in the predominant circulating serotype (C1) -> (E1, C2) increase the population risk of subsequent exposure to a heterologous DENV serotype (E1, C2), -> (E2) which increases the risk of higher rates of severe dengue and deaths (E2).

        """

        # Examples of causal statements that must NOT be extracted
        self.negative_cases = """
        Irrelevant causality (negative cases): Some sentences contain causal relationships, but the effect may not be related to the disease transmission or emergence. Avoid classifying those causal relationships.

        Example 1 (no causality): Because these viruses continue to be detected in swine populations worldwide, further human cases following direct or indirect contact with infected swine can be expected.

        Example 2 (no relevant causality): There is some (E1) pressure on the healthcare capacity (E1) due to the (C1) very high number of admissions for dengue (C1); (C1) high vector density (C1); and an (C1) anticipated prolonged monsoon (C1). 

        Example 3 (no relevant causality): (C1) MVD is a highly virulent disease (C1) that can cause (E1) haemorrhagic fever (E1) and is clinically similar to Ebola virus disease.

        """

        # Optional instruction to tag mechanisms with (M)
        self.mechanism_of_causality = """
        When the text describes/list possible mechanisms behind the cause of transmission or emergence, tag them with (M). A mechanism of causality describes the specific interactions between the pathogen, host, and environment that causes the transmission / emergence. They often describe interactions at the physiological level. 

        Example 1: The global outbreak 2022 — 2024 has shown that (C1) sexual contact (C1) enables faster and more efficient (E1) spread of the virus (E1) from one person  to another due to (M1) direct contact of mucous membranes between people (M1), (M1) contact with multiple partners (M1), (M1) a possibly shorter incubation period on average (M1), and (M1) a longer infectious period for immunocompromised individuals (M1).

        """

        # Optional instruction to annotate the polarity (+/-) of causes and effects
        self.sign_of_causality = """
        For each cause-effect relationship, indicate whether each cause (C) is positive (C+) or negative (C-) and each effect (E) is positive (E+) or negative (E-). 
        Use the list of positive and negative sign words provided to help determine the sign of each cause and effect. Be mindful of sentences with negations (e.g., “does not improve”), which reverses polarity. 
        Positive sign words: increase, facilitate, support, improve, expand, promote, enable, enhance, accelerate, advance, grow, boost, strengthen, benefit, contribute, progress, initiate, develop, elevate, stimulate, alleviate, optimize, revitalize. 
        Negative sign words: limit, decrease, reduce, hamper, hinder, restrict, suppress, impair, inhibit, undermine, challenge, disrupt, lack, insufficient, incomplete, challenge, deficit, obstacle, barrier, diminish, shortage, scarcity, obstruct, worsen, decline. 

        Example 1: “(C1-) a lack of timely access to diagnostics in many areas (C1-), (C1-) incomplete epidemiological investigations (C1-), (C1-) challenges in contact tracing and extensive but inconclusive animal investigations (C1-) continue to hamper rapid response (E1-)”

        Example 2: Moreover, (C1-) a low index of suspicion (C1-), (C1) socio-cultural norms (C1), (C1) community resistance (C1), (C1-) limited community knowledge regarding anthrax transmission (C1-), (C1+) high levels of poverty (C1+) and (C1) food insecurity (C1), (C1-) a shortage of available vaccines and laboratory reagents (C1-), (C1-) inadequate carcass disposal (C1-) and (C1) decontamination practices (C1) significantly contribute to hampering (E1-) the containment of the anthrax outbreak (E1-).
        """

    def generate_prompt(self, include_persona=False, include_domain=False, include_causality=False, include_guidance = False, include_examples=False, include_negative=False, include_mechanism=False, include_sign=False):
        """
        Dynamically generate a prompt based on the specified parts.

        Each ``include_*`` flag switches one section on. Enabled sections are
        concatenated in a fixed order (persona, domain, causality definition,
        extraction guide, few-shot examples, negative cases, mechanism, sign),
        each followed by a newline. With every flag off the result is "".
        """
        # (flag, section text) pairs, listed in the order they appear in the prompt
        sections = (
            (include_persona, self.persona_task_description),
            (include_domain, self.domain_localization),
            (include_causality, self.causality_definition),
            (include_guidance, self.extraction_guide),
            (include_examples, self.few_shot_examples),
            (include_negative, self.negative_cases),
            (include_mechanism, self.mechanism_of_causality),
            (include_sign, self.sign_of_causality),
        )
        return "".join(text + "\n" for enabled, text in sections if enabled)

### Causality Extraction

In [7]:
# Function to split text into chunks
def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items each."""
    total = len(iterable)
    for start in range(0, total, n):
        # Clamp the end of the final slice to the length of the input
        stop = start + n
        if stop > total:
            stop = total
        yield iterable[start:stop]

class CausalChain:
    """Extract cause-effect pairs from report text with an OpenAI chat model.

    Results accumulate in ``self.outlines`` as dictionaries with the keys
    ``DonId``, ``Cause``, ``Effect``, ``Causality_Type`` and ``Raw_Text``.
    """

    def __init__(self, dataframe, prompt_designer=None):
        """Store the input rows and set up empty result containers.

        Args:
            dataframe: table whose rows provide the ``Text`` and
                ``DonID_standardized`` columns read by ``create_effects``.
            prompt_designer: object providing ``generate_prompt``; a fresh
                ``PromptDesigner`` is created when omitted.
        """
        self.dataframe = dataframe
        self.outlines = []  # Store a list of dictionaries to represent complex relationships
        self.prompt_designer = prompt_designer if prompt_designer else PromptDesigner()
        self.processed_chunks = set()  # Track processed chunks to avoid repetition

    def create_effects(self, batch_size=16, prompt_parts=None):
        """Run extraction over every row and append the results to ``self.outlines``.

        Each row's text is split on ". " and regrouped into chunks of three
        sentences; every new (article id, chunk) pair is sent to the model once.
        Raw API responses are written to ``api_responses.md`` as they arrive.

        Args:
            batch_size: accepted for backward compatibility; the chunk size is
                currently fixed at 3 sentences and this value is not used.
            prompt_parts: keyword flags forwarded to
                ``PromptDesigner.generate_prompt``; ``None`` means no flags.
        """
        print("Extracting causal relationships from text...")

        with open("api_responses.md", "w", encoding="utf-8") as file:
            for _, row in tqdm(self.dataframe.iterrows(), total=self.dataframe.shape[0]):
                text = row['Text']
                don_id = row['DonID_standardized']

                # Split text into sentences and then into chunks of 3 sentences
                sentences = text.split(". ")
                chunks = [". ".join(a) + "." for a in batch(sentences, 3)]

                for chunk in chunks:
                    # Skip if the chunk has already been processed
                    if (don_id, chunk) in self.processed_chunks:
                        continue

                    cause_effect_pairs, raw_texts, causality_types, response_text = self.extract_cause_effect_openai(chunk, prompt_parts, don_id)

                    # Mark the chunk as processed
                    self.processed_chunks.add((don_id, chunk))

                    # Write the response to the file immediately after receiving it
                    file.write(f"\n\n## API Response for Article ID {don_id}:\n\n{response_text}\n\n")

                    if not cause_effect_pairs and not raw_texts:
                        self.outlines.append({
                            "DonId": don_id,
                            "Cause": None,
                            "Effect": None,
                            "Causality_Type": "No relevant causality",
                            "Raw_Text": chunk
                        })
                        print(f"No cause-effect pairs found for chunk: {chunk}")

                    for pair, raw_text, causality_type in zip(cause_effect_pairs, raw_texts, causality_types):
                        # parse_response has already removed the "E1:" marker.
                        # Stripping again here would cut off any effect text
                        # that legitimately contains a colon.
                        cause, effect = pair
                        self.outlines.append({
                            "DonId": don_id,
                            "Cause": cause,
                            "Effect": effect,
                            "Causality_Type": causality_type,
                            "Raw_Text": raw_text
                        })

        # Print the raw texts, causes, effects, and types of causality
        if self.outlines:
            for outline in self.outlines:
                print(f"DonId: {outline['DonId']}")
                print(f"Raw Text: {outline['Raw_Text']}")
                print(f"Cause: {outline['Cause']}")
                print(f"Effect: {outline['Effect']}")
                print(f"Causality Type: {outline['Causality_Type']}")
                print("\n")
        else:
            print("No cause-effect pairs found in the entire dataset.")

    def extract_cause_effect_openai(self, chunk, prompt_parts=None, don_id=None):
        """Send one text chunk to the model and parse its answer.

        Args:
            chunk: text to analyse.
            prompt_parts: keyword flags for ``generate_prompt``; ``None`` means
                no flags.
            don_id: article identifier, used only in the debug print.

        Returns:
            ``(cause_effect_pairs, raw_texts, causality_types, response_text)``.
        """
        # Use the PromptDesigner to generate the customized prompt
        prompt = self.prompt_designer.generate_prompt(**(prompt_parts or {}))

        # Append the text chunk to the prompt and provide a clear format for the response
        full_prompt = f"""{prompt}

        Input text: {chunk}

        Expected output format:
        1. Raw text with marked causes and effects:
        [Provide the input text with marked causes and effects]

        2. Extracted causes and effects:
        C1: [cause] -> E1: [effect], Causality type: [T1/T2/...]
        C2: [cause] -> E2: [effect], Causality type: [T1/T2/...]
        ...
        """

        # Call the OpenAI API
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": full_prompt}],
            max_tokens=2048,
            temperature=0,
        )

        # The message content may be None; normalise to "" so parsing and the
        # file write in create_effects keep working
        response_text = response.choices[0].message.content or ""
        print(f"API Response for Article ID {don_id}: {response_text}")  # Print the API response for debugging
        return self.parse_response(response_text) + (response_text,)

    @staticmethod
    def parse_response(response_text):
        """Parse the model's two-section answer into parallel lists.

        Section "1. Raw text with marked causes and effects" is accumulated into
        one string. In section "2. Extracted causes and effects", lines of the
        form ``C1: cause -> E1: effect, Causality type: T1`` become pairs, and
        any other non-empty line without "->" is recorded as a ``(None, None)``
        pair with the type "No causality". Blank lines are ignored.

        Returns:
            ``(cause_effect_pairs, raw_texts, causality_types)``, three lists of
            equal length. All are empty for an empty or missing response.
        """
        cause_effect_pairs = []
        raw_texts = []
        causality_types = []  # Store the causality types

        if not response_text or not response_text.strip():
            print("Empty response received from API.")
            return cause_effect_pairs, raw_texts, causality_types

        raw_text_section = False
        extracted_pairs_section = False
        raw_text = ""

        for line in response_text.split("\n"):
            line = line.strip()

            # Blank lines carry no information; without this guard each one in
            # section 2 would be recorded as a spurious "No causality" row
            if not line:
                continue

            if line.startswith("1. Raw text with marked causes and effects"):
                raw_text_section = True
                extracted_pairs_section = False
                raw_text = ""  # Reset raw text for each new section
                continue

            if line.startswith("2. Extracted causes and effects"):
                raw_text_section = False
                extracted_pairs_section = True
                continue

            if raw_text_section:
                raw_text += line + " "
            elif extracted_pairs_section:
                if line.startswith("C") and "->" in line:
                    try:
                        segments = line.split("->")
                        # Drop only the leading "C1:" marker so a cause that
                        # itself contains a colon is kept whole
                        cause = segments[0].split(":", 1)[1].strip()
                        effect = segments[1].split(", Causality type:")[0].strip()
                        causality_type = line.split("Causality type:")[1].strip()
                    except IndexError:
                        print(f"Malformed line: {line}")
                        continue
                    # Remove markers like "E1:" from the effect
                    if ":" in effect:
                        effect = effect.split(":", 1)[1].strip()
                    cause_effect_pairs.append((cause, effect))
                    raw_texts.append(raw_text.strip())
                    causality_types.append(causality_type)
                elif "->" not in line:
                    # Free-text line such as "No relevant causality"
                    cause_effect_pairs.append((None, None))
                    raw_texts.append(raw_text.strip())
                    causality_types.append("No causality")

        return cause_effect_pairs, raw_texts, causality_types

def create_causes_effects_dataframe(outlines):
    """Turn the list of outline dictionaries into a DataFrame.

    Any of the columns DonId, Cause, Effect, Causality_Type and Raw_Text that
    is absent from the result is added and filled with 'Unknown'.
    """
    expected_columns = ("DonId", "Cause", "Effect", "Causality_Type", "Raw_Text")
    table = pd.DataFrame(outlines)

    absent = [name for name in expected_columns if name not in table.columns]
    for name in absent:
        table[name] = "Unknown"

    return table


## Experimenting with the model

In [6]:
prompt_designer = PromptDesigner()

# Prompt sections enabled for this run; the keys match the keyword
# arguments of PromptDesigner.generate_prompt.
prompt_parts = dict(
    include_persona=True,
    include_domain=True,
    include_causality=True,
    include_guidance=True,
    include_examples=True,
    include_negative=True,
    include_mechanism=False,
    include_sign=False,
)

In [None]:
# Restrict the experiment to the first 30 rows of the "Assessment" subset
example_data = who_data_assessment.iloc[0:30]


# Create a CausalChain instance with the dataset
causal_chain = CausalChain(dataframe=example_data, prompt_designer=prompt_designer)

# Generate effects based on the chunks of text
# (one API call is issued per chunk; results accumulate in causal_chain.outlines)
causal_chain.create_effects(prompt_parts=prompt_parts)

In [None]:
# Create a DataFrame with the causes, effects, and other related information
# (columns: DonId, Cause, Effect, Causality_Type, Raw_Text)
result_df = create_causes_effects_dataframe(causal_chain.outlines)

# Display the resulting DataFrame
display(result_df)

In [12]:
# Export data to csv
# NOTE(review): this writes to the current working directory, whereas the
# following cell reads '../data/result_df_31 Oct.csv' — confirm which
# location is intended.
result_df.to_csv('result_df_31 Oct.csv', index=False)

## Drivers Categorization

In [13]:
# Import result_df_31 Oct.csv
# Reloading from disk lets the categorisation steps run without repeating
# the extraction; missing Cause/Effect values come back as NaN.
result_df = pd.read_csv('../data/result_df_31 Oct.csv')

In [None]:
# Function to get a summary for each cause-effect pair
def get_summary(text):
    """Ask the model for a two-word, disease-agnostic category for ``text``.

    Args:
        text: a cause or effect phrase.

    Returns:
        The model's reply, or ``None`` when ``text`` is missing (``None``, NaN
        or any other non-string) or blank. Such values occur for rows without
        a cause/effect and would otherwise be sent to the API as 'nan'.
    """
    if not isinstance(text, str) or not text.strip():
        return None

    prompt = (
        f"Analyze the following text to identify common categories related to drivers of infectious diseases. "
        f"Avoid mentioning specific diseases or too specific terms. "
        f"Summarize the text into two words only. For example:\n"
        f"- Text: 'transmission of ebola' -> Summary: 'disease transmission'\n"
        f"- Text: 'COVID-19 infection' -> Summary: 'disease transmission'\n"
        f"Please summarize the following: '{text}'"
    )
    
    response = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=100,
        temperature=0
    )
    return response.choices[0].message.content

# Apply function to each row in cause and effect columns
# (get_summary issues one API request per value, so this runs two requests per row)
result_df['Cause_category'] = result_df['Cause'].apply(get_summary)
result_df['Effect_category'] = result_df['Effect'].apply(get_summary)


### Using predefined list to categorize drivers

In [None]:
# Rename "Peter's name" to "Category", forward-fill the blank Category cells,
# and drop duplicate (Category, Consolidated Name) pairs. No sorting is done.
driver_cat_rm_dups = (
    driver_cat
    .rename(columns={"Peter's name": "Category"})
    .assign(Category=lambda df: df['Category'].ffill()) # a blank Category inherits the value from the row above
    .drop_duplicates(subset=["Category", "Consolidated Name"])
) 

# Map each Category to the list of its distinct, non-null consolidated names
predefined_driver_category = (
    driver_cat_rm_dups
    .groupby("Category")["Consolidated Name"]
    .apply(lambda x: x.dropna().unique().tolist())
    .to_dict()
)

# Print the result
print(predefined_driver_category)

In [11]:
def format_prompt(text, category_dict):
    """Build the categorisation prompt for ``text``.

    ``category_dict`` maps a category name to its list of consolidated names;
    at most the first 50 names of each category are listed in the prompt.
    """
    # One bullet per category: "- <category>: <name>, <name>, ..."
    bullet_lines = []
    for category, consolidated_names in category_dict.items():
        listed_names = ', '.join(consolidated_names[:50])
        bullet_lines.append(f"- {category}: {listed_names}...")
    category_examples = "\n".join(bullet_lines)

    # Prompt fragments, concatenated in order without any separator
    pieces = [
        "Analyze the following text and map it to a predefined category from the list below. ",
        "Return the output in this exact format:\n",
        "consolidate_name: [name], category: [category]\n\n",
        f"Categories and examples:\n{category_examples}\n\n",
        "Example mappings:\n",
        "- Text: 'Socio-economic factors, high levels of poverty' -> consolidate_name: socioeconomic, category: Economy\n",
        "- Text: 'favorable conditions for vector populations during the monsoon season in affected areas' -> consolidate_name: climate, category: Climate/Weather\n",
        "- Text: 'Lack of laboratory capacity' -> consolidate_name: infrastructure, category: Build infrastructure\n\n",
        "In case you cannot match the orginal text with any of the consolidated names, but the original text is about diseases transmission process\n\n",
        "Summarize the text into two words only for the consolidate_name. Avoid mentioning specific diseases or too specific terms. For example:\n",
        "- Text: 'contact with infected poultry or environments that have been contaminated' -> consolidate_name: 'poultry exposure', category: Disease transmission\n",
        "- Text: 'close contact with A(H5N1)-infected live or dead birds or mammals' -> consolidate_name: 'animal exposure', category: Disease transmission\n",
        "If none of the above applies, please return consolidate_name: 'Undefined', category: Undefined\n\n",
        "Now, analyze the following text:\n",
        f"'{text}'\n",
        "Provide your answer in this format: consolidate_name: [name], category: [category]",
    ]
    return "".join(pieces)

def _parse_category_response(response_text):
    """Split 'consolidate_name: X, category: Y' into (X, Y); (None, None) if malformed."""
    if not response_text:
        return None, None

    cleaned = response_text.strip()
    parts = cleaned.split(", category:")
    # Exactly one separator is required, as with the original two-way unpacking
    if not cleaned.startswith("consolidate_name:") or len(parts) != 2:
        return None, None

    # Strip the label and surrounding quotes from the name, so that a quoted
    # name ('poultry exposure') is handled like the category is
    consolidate_name = parts[0].replace("consolidate_name:", "").strip().strip("'\"").strip()
    category = parts[1].replace("'", "").strip()
    return consolidate_name, category


def get_summary_with_prelist(text, category_dict):
    """Map ``text`` to a consolidated name and category using the model.

    Args:
        text: a cause or effect phrase.
        category_dict: mapping of category name to list of consolidated names,
            passed to ``format_prompt``.

    Returns:
        ``(consolidate_name, category)``. Returns ``(None, None)`` when ``text``
        is missing or blank (no API call is made) or when the reply does not
        follow the requested ``consolidate_name: ..., category: ...`` format.
    """
    # Rows without a cause/effect hold None/NaN; do not spend an API call on them
    if not isinstance(text, str) or not text.strip():
        return None, None

    # Format the prompt with the dictionary
    prompt = format_prompt(text, category_dict)
    
    # Send the prompt to OpenAI
    response = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=100,
        temperature=0
    )
    response_text = response.choices[0].message.content
    print(f"API Response: {response_text}")  # Debugging line
    
    # Extract consolidate_name and category
    return _parse_category_response(response_text)


def categorize_text(row, column_name, category_dict):
    """
    Look up the consolidated name and category for ``row[column_name]``.

    Returns a two-entry Series labelled ``<column_name>_consolidate_name`` and
    ``<column_name>_category_new``.
    """
    name, category = get_summary_with_prelist(row[column_name], category_dict)
    labelled = {
        f'{column_name}_consolidate_name': name,
        f'{column_name}_category_new': category,
    }
    return pd.Series(labelled)



In [None]:
# Extend every row with the Cause-related columns followed by the
# Effect-related columns, keeping the original row data first.
result_df = result_df.apply(
    lambda row: pd.concat(
        [pd.Series(row)]
        + [
            categorize_text(row, column, predefined_driver_category)
            for column in ('Cause', 'Effect')
        ]
    ),
    axis=1,
)

In [None]:
# Preview only the original text next to its assigned name and category
(result_df
 .filter(['DonId', 'Cause', 'Cause_consolidate_name', 'Cause_category_new', 'Effect', 'Effect_consolidate_name', 'Effect_category_new']))

In [17]:
# Save the categorised results (the row index is not written)
result_df.to_csv('../data/result_df_19 Nov.csv', index=False)

In [2]:
# Reload the categorised results saved above
result_df = pd.read_csv('../data/result_df_19 Nov.csv')

In [5]:
# Keep the cause text alongside both of its categorisations for comparison
new_df = result_df[['Cause', 'Cause_category', 'Cause_consolidate_name']]

# NLP algorithms

- Sentence Embeddings with Clustering: Create a vector capturing semantic meaning, then cluster the vectors with algorithms such as K-Means, Hierarchical Clustering, or DBSCAN
- Topic Modeling: Latent Dirichlet Allocation (LDA) or Non-Negative Matrix Factorization (NMF) -> typically used for longer texts
- Text Similarity using Cosine Similarity and Clustering
- Self-Supervised Clustering with Transformers: BERTopic uses transformer embeddings with dimensionality reduction and topic representation, enabling a more dynamic clustering approach suitable for nuanced or dense datasets

In [28]:
# NOTE(review): this file must already contain the Cause_category and
# Effect_category columns — confirm it was re-exported after categorisation.
cosine_test = pd.read_csv('result_df_31 Oct.csv')

# Create a list of unique categories.
# Missing values are dropped because the sentence encoder expects strings, and
# first-appearance order is kept so the list (and therefore the embedding row
# order) is the same on every run, unlike iteration over a set.
unique_drivers_categories = (
    pd.concat([cosine_test['Cause_category'], cosine_test['Effect_category']])
    .dropna()
    .drop_duplicates()
    .tolist()
)

In [29]:
# Imported here because the module-level sentence_transformers import is
# commented out, which would otherwise leave these names undefined.
from sentence_transformers import SentenceTransformer, util

# Load the transformer model
model = SentenceTransformer('all-mpnet-base-v2')

# Generate embeddings for each unique category (encoded once; the original
# cell repeated this call with identical arguments)
category_embeddings = model.encode(unique_drivers_categories, convert_to_tensor=True)

# Calculate the cosine similarity matrix
cosine_similarity_matrix = util.pytorch_cos_sim(category_embeddings, category_embeddings)

In [31]:
# Imported here because the module-level sklearn import is commented out,
# which would otherwise leave AgglomerativeClustering undefined.
from sklearn.cluster import AgglomerativeClustering

# Convert cosine similarities to distances for clustering
cosine_distance_matrix = 1 - cosine_similarity_matrix.cpu().numpy()

# Apply Agglomerative Clustering on the precomputed distance matrix
clustering_model = AgglomerativeClustering(
    metric='precomputed',
    linkage='average',
    n_clusters=5  # Choose the number of clusters or use distance_threshold
)
cluster_labels = clustering_model.fit_predict(cosine_distance_matrix)

In [None]:
# Create a DataFrame for unique categories and their cluster labels
# (the two lists are aligned by position: label i belongs to category i)
clustered_categories = pd.DataFrame({
    'category': unique_drivers_categories,
    'cluster': cluster_labels
})

# Display grouped categories
display(clustered_categories.sort_values(by='cluster'))

In [None]:
# Merge cluster labels for Cause and Effect categories back
# Each left join adds a 'cluster' column, which is renamed to the side it
# describes; the helper 'category' key column is then dropped.
cosine_test = cosine_test.merge(clustered_categories, left_on='Cause_category', right_on='category', how='left').rename(columns={'cluster': 'Cause_cluster'}).drop(columns=['category'])
cosine_test = cosine_test.merge(clustered_categories, left_on='Effect_category', right_on='category', how='left').rename(columns={'cluster': 'Effect_cluster'}).drop(columns=['category'])

# Display result with cluster labels
display(cosine_test[['Cause_category', 'Cause_cluster', 'Effect_category', 'Effect_cluster']])


In [None]:
# Save the table with cluster labels to the current working directory
cosine_test.to_csv('cosine_test_31 Oct.csv', index=False)

In [None]:
# Imported here because the module-level sentence_transformers import is
# commented out, which would otherwise leave `util` undefined.
from sentence_transformers import util

# Lookup from category label to its embedding. The original cell used this
# name without defining it; the pairing relies on category_embeddings having
# been encoded from unique_drivers_categories in the same order.
category_embedding_dict = dict(zip(unique_drivers_categories, category_embeddings))

# Map embeddings for Cause and Effect categories (None when a category has no embedding)
cosine_test['Cause_embedding'] = cosine_test['Cause_category'].apply(category_embedding_dict.get)
cosine_test['Effect_embedding'] = cosine_test['Effect_category'].apply(category_embedding_dict.get)

# Calculate cosine similarity between each Cause and Effect embedding;
# rows lacking either embedding get NaN instead of raising
cosine_test['cosine_similarity'] = cosine_test.apply(
    lambda row: (
        util.pytorch_cos_sim(row['Cause_embedding'], row['Effect_embedding']).item()
        if row['Cause_embedding'] is not None and row['Effect_embedding'] is not None
        else float('nan')
    ),
    axis=1
)

print(cosine_test[['Cause_category', 'Effect_category', 'cosine_similarity']])


In [None]:
# Grouping by DonId and calculating the percentage of rows without causality
def calculate_no_relevant_causality_percentage(df, labels=("No causality", "No relevant causality")):
    """Return, per DonId, the percentage of rows whose Causality_Type is a no-causality label.

    Args:
        df: DataFrame with ``DonId`` and ``Causality_Type`` columns.
        labels: Causality_Type values counted as "no causality". The extraction
            step emits both "No causality" and "No relevant causality", so both
            are counted by default.

    Returns:
        Series indexed by DonId holding percentages in the range 0-100.
    """
    is_no_causality = df['Causality_Type'].isin(labels)
    return is_no_causality.groupby(df['DonId']).mean() * 100

# Example usage
# result_df must provide the DonId and Causality_Type columns
grouped_percentage = calculate_no_relevant_causality_percentage(result_df)
print(grouped_percentage)