In [37]:
import os
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import display, Markdown
from dotenv import load_dotenv
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from causal_chains.CausalChain import util  # https://github.com/helliun/causal-chains

import google.generativeai as genai

from IPython.display import display
from IPython.display import Markdown

import pathlib
import textwrap
def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

In [38]:
# Load WHO data
who_data = pd.read_csv("../data/corpus.csv")

In [41]:

# Load environment variables
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")

# Initialize the Gemini API client
genai.configure(api_key=gemini_api_key)


class CausalChain:
    def __init__(self, chunks=[]):
        self.chunks = chunks
        self.causes = []
        self.effects = []
        self.outlines = []

    def create_effects(self, batch_size=16):
        print("Analyzing causation...")

        for chunk in tqdm(self.chunks):
            cause_effect_pairs = self.extract_cause_effect(chunk)
            for pair in cause_effect_pairs:
                cause, effect = pair
                self.causes.append(cause)
                self.effects.append(effect)
                self.outlines.append(f"Cause: {cause} -> Effect: {effect}")
                # Print formatted causes and effects
                display(Markdown(f"**Cause:** {cause}  \n**Effect:** {effect}"))

    def extract_cause_effect(self, chunk):
        one_shot_example = """
        Text: The sudden appearance of unlinked cases of mpox in South Africa without a history of international travel, the high HIV prevalence among confirmed cases, and the high case fatality ratio suggest that community transmission is underway, and the cases detected to date represent a small proportion of all mpox cases that might be occurring in the community; it is unknown how long the virus may have been circulating. This may in part be due to the lack of early clinical recognition of an infection with which South Africa previously gained little experience during the ongoing global outbreak, potential pauci-symptomatic manifestation of the disease, or delays in care-seeking behaviour due to limited access to care or fear of stigma.
        Cause: lack of early clinical recognition of an infection -> Effect: community transmission of mpox
        Cause: pauci-symptomatic manifestation of the disease -> Effect: lack of early clinical recognition of an infection 
        Cause: delays in care-seeking behaviour -> Effect: lack of early clinical recognition of an infection 
        Cause: limited access to care -> Effect: delays in care-seeking behaviour 
        Cause: fear of stigma -> Effect: delays in care-seeking behaviour
        """

        prompt = f"""
        Here is an example of how to identify causes (drivers leading to the diseases) and their effects (intermediate drivers leading to the diseases, excluding mortality and impacts of diseases):

        {one_shot_example}
        Now, analyze the following text and identify the specific causes and their effects:
        Text: {chunk}
        List the causes and their corresponding effects in the format 'Cause: [cause] -> Effect: [effect]':
        """

        response = genai.GenerativeModel('gemini-1.5-pro').generate_content(prompt)
        response_text = response.candidates[0].content.parts[0].text
        cause_effect_pairs = []

        for line in response_text.split("\n"):
            if "Cause:" in line and "-> Effect:" in line:
                cause = line.split("Cause:")[1].split("-> Effect:")[0].strip()
                effect = line.split("-> Effect:")[1].strip()
                cause_effect_pairs.append((cause, effect))
        return cause_effect_pairs

In [42]:
text = who_data["Text"][9]
chunks = util.create_chunks(text)
cc = CausalChain(chunks)
cc.create_effects()

Analyzing causation...
Analyzing causation...


  0%|                                                   | 0/12 [00:00<?, ?it/s]

**Cause:** Most cases being reported in children under 15 years of age, especially young children  
**Effect:** High risk of severe disease and death in infants and children under five**

**Cause:** Limited or unavailable prompt optimal case management  
**Effect:** High risk of severe disease and death in infants and children under five**

**Cause:** Continued high weekly case count  
**Effect:** Geographic expansion of the outbreak**

  8%|███▌                                       | 1/12 [00:03<00:34,  3.10s/it]

  8%|███▌                                       | 1/12 [00:06<01:12,  6.62s/it]




IndexError: list index out of range

In [35]:
def create_causes_effects_dataframe(causes, effects):
    data = {"Cause": causes, "Effect": effects}
    df = pd.DataFrame(data)
    return df


df = create_causes_effects_dataframe(cc.causes, cc.effects)
display(df)

Unnamed: 0,Cause,Effect
