# MCTS

## 1. Setup

### 1.1. Data

In [1]:
import os
# If the working directory does not contain pyproject.toml, move up one level —
# presumably so the relative "data/..." paths used below resolve from the
# project root. NOTE: this is not idempotent; re-running the cell climbs another
# level whenever the new directory also lacks pyproject.toml.
if "pyproject.toml" not in os.listdir():
    os.chdir("..")

In [2]:
from mcts import mcts

In [3]:
import random
import numpy as np
from typing import Optional, Dict, List

In [4]:
from dr_claude import kb_reading, datamodels

In [5]:
reader = kb_reading.NYPHKnowldegeBaseReader("data/NYPHKnowldegeBase.html")

In [6]:
kb = reader.load_knowledge_base()

In [7]:
prob_matrix = datamodels.DiseaseSymptomKnowledgeBaseTransformer.to_numpy(kb)

In [8]:
import abc
from typing import List, NamedTuple

In [9]:
class State(abc.ABC):
    """Abstract search state.

    Concrete subclasses must provide all five methods below before they can be
    instantiated; the method names follow the camelCase convention used by the
    search code that consumes these states.
    """

    @abc.abstractmethod
    def getCurrentPlayer(self):
        """Return the player whose turn it is in this state."""

    @abc.abstractmethod
    def getPossibleActions(self):
        """Return the actions that may be taken from this state."""

    @abc.abstractmethod
    def takeAction(self, action):
        """Return the state reached by applying ``action``."""

    @abc.abstractmethod
    def isTerminal(self) -> bool:
        """Return whether no further actions should be taken."""

    @abc.abstractmethod
    def getReward(self) -> float:
        """Return the reward associated with this state."""

In [10]:
from typing import Set

In [11]:
class SymptomConditionsSpace(NamedTuple):
    """The distinct conditions and symptoms present in a knowledge base.

    Both fields are lists in first-seen (knowledge-base) order. Keeping that
    order stable matters because positions in these lists are later used as
    indices into the action and state spaces.
    """

    conditions: List[datamodels.Condition]
    symptoms: List[datamodels.WeightedSymptom]

    @classmethod
    def from_kb(cls, kb: datamodels.DiseaseSymptomKnowledgeBase) -> 'SymptomConditionsSpace':
        """Collect the unique conditions and symptoms of ``kb``.

        Args:
            kb: Knowledge base whose ``condition_symptoms`` mapping is scanned.

        Returns:
            A space holding every condition key and every distinct symptom,
            each de-duplicated by hash/equality and ordered by first appearance.
        """
        # Mapping keys are already unique; iterating them preserves insertion order.
        conditions = list(kb.condition_symptoms)
        # dict.fromkeys de-duplicates exactly like a set would, but keeps a
        # deterministic order instead of a hash-dependent one.
        symptoms = list(
            dict.fromkeys(s for symps in kb.condition_symptoms.values() for s in symps)
        )
        return cls(conditions, symptoms)

In [12]:
# Build the space once and show (number of symptoms, number of conditions).
_space = SymptomConditionsSpace.from_kb(kb)
len(_space.symptoms), len(_space.conditions)

(406, 133)

In [13]:
class Action:
    """A named move that is either a symptom action or a condition action.

    Two actions are equal when both their name and their ``is_condition`` flag
    match; the hash is derived from the same pair so instances can be used as
    dictionary keys and set members.

    Args:
        name: Name of the symptom or condition.
        is_condition: True for a condition action, False for a symptom action.
    """

    def __init__(self, name: str, is_condition: bool) -> None:
        self.name = name
        self.is_condition = is_condition

    def _key(self):
        """Return the tuple that defines identity for hashing and equality."""
        return (self.name, self.is_condition)

    def __hash__(self) -> int:
        return hash(self._key())

    def __repr__(self) -> str:
        return f"{'Condition' if self.is_condition else 'Symptom'} Action: {self.name}"

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of raising
        # AttributeError, so `action == None` or list.index() over mixed
        # content simply compares unequal.
        if not isinstance(other, Action):
            return NotImplemented
        return self._key() == other._key()

In [14]:
from copy import deepcopy

In [15]:
class Patient(State):
    """Search state for one simulated patient.

    ``state`` is a column vector with one row per symptom action, plus one
    trailing row that nothing in this class writes to. A row holds ``1`` when
    the symptom was asked about and the patient has it, ``-1`` when it was
    asked about and is absent, and ``0`` otherwise.

    Args:
        condition: The patient's underlying condition.
        symptoms: The symptoms this patient exhibits (objects with a ``name``).
        action_space: Every action available to the search; the non-condition
            actions define the rows of ``state``.
    """

    def __init__(self, condition, symptoms: List, action_space: List[Action]) -> None:
        self.condition = condition
        self.symptoms = symptoms
        self.symptom_names = [s.name for s in symptoms]
        self.action_space = action_space
        self.state_space = [a for a in self.action_space if not a.is_condition]
        # Initialise `state` and `is_done` here so the object is usable even
        # when the caller never calls reset() explicitly.
        self.reset()

    def reset(self) -> None:
        """Clear all observations and mark the episode as not finished."""
        self.is_done = False
        self.state = np.zeros(shape=(len(self.state_space) + 1, 1))

    def getCurrentPlayer(self):
        """Return the constant player id ``1``."""
        return 1

    def getPossibleActions(self):
        """Return the full action space.

        NOTE(review): symptoms that were already asked about are not filtered
        out, so the same question can be selected repeatedly.
        """
        return self.action_space

    def _spawn(self) -> "Patient":
        """Return a successor that shares read-only fields but owns its ``state`` array."""
        # Local import: only `deepcopy` is imported at notebook level. A shallow
        # copy is sufficient because the shared lists are never mutated here,
        # and it avoids deep-copying the whole action space on every step.
        from copy import copy

        child = copy(self)
        child.state = self.state.copy()
        return child

    def takeAction(self, action: Action):
        """Return the state that results from ``action`` without modifying ``self``.

        A condition action yields a terminal successor. A symptom action yields
        a successor whose matching row records presence (``1``) or absence
        (``-1``). An action that is neither a condition nor a known symptom
        returns ``self`` unchanged.
        """
        if action.is_condition:
            # Previously this flipped `is_done` on `self`, which silently
            # turned the parent search state terminal as well.
            child = self._spawn()
            child.is_done = True
            return child
        try:
            i: int = self.state_space.index(action)
        except ValueError:
            return self
        child = self._spawn()
        child.state[i] = 1 if action.name in self.symptom_names else -1
        return child

    def isTerminal(self):
        """Return True once a condition action has been taken."""
        return self.is_done

    def getReward(self) -> float:
        """Return a constant reward of ``1.0``.

        NOTE(review): the reward does not depend on the chosen condition, so
        it cannot steer the search towards the correct diagnosis.
        """
        return 1.0

    def __repr__(self) -> str:
        return f"Patient([{self.condition.name} || {[s.name for s in self.symptoms]}])"

In [18]:
import pandas as pd

In [119]:
def kb_to_dataframe(kb: datamodels.DiseaseSymptomKnowledgeBase) -> pd.DataFrame:
    """Flatten a knowledge base into one row per (condition, symptom) pair.

    Args:
        kb: Knowledge base whose ``condition_symptoms`` mapping is flattened.

    Returns:
        A DataFrame with the columns "Disease Code", "Disease", "Symptom Code",
        "Symptom", "Weight" and "Noise", in knowledge-base iteration order.
    """
    cols = ("Disease Code", "Disease", "Symptom Code", "Symptom", "Weight", "Noise")
    # Each row carries six values (four identifiers plus weight and noise rate);
    # the previous local annotation claimed four strings and named an
    # unimported `Tuple`.
    rows = [
        (condition.umls_code, condition.name, s.umls_code, s.name, s.weight, s.noise_rate)
        for condition, symptoms in kb.condition_symptoms.items()
        for s in symptoms
    ]
    return pd.DataFrame(rows, columns=cols)

In [131]:
# NOTE(review): "Waeight" is not one of the columns produced by kb_to_dataframe
# (the column is "Weight"), so .get() falls back to its default and this cell
# prints 'hi' for the first row. Unclear whether the key is a typo or a
# deliberate probe of the default — confirm before relying on it.
for _, row in kb_to_dataframe(kb).iterrows():
    print(row.get("Waeight", 'hi'))
    break

hi


In [21]:
# Same call as the earlier `prob_matrix` cell in the Data section; this rebinds
# the name to a freshly computed matrix from the same `kb`.
prob_matrix = datamodels.DiseaseSymptomKnowledgeBaseTransformer.to_numpy(kb)

In [22]:
from dr_claude.mcts.base import compute_condition_posterior_flat_prior, compute_symptom_posterior_flat_prior_dict

In [23]:
%%time
condition_probs = compute_condition_posterior_flat_prior(prob_matrix, pertinent_positives=[], pertinent_negatives=[])

CPU times: user 25 µs, sys: 31 µs, total: 56 µs
Wall time: 55.8 µs


In [26]:
symptom_condition_space = SymptomConditionsSpace.from_kb(kb)

In [27]:
# Symptom actions come first, followed by one condition action per condition.
action_space = [
    *(Action(symptom.name, False) for symptom in symptom_condition_space.symptoms),
    *(Action(condition.name, True) for condition in symptom_condition_space.conditions),
]

In [28]:
def sample_patient(kb: datamodels.DiseaseSymptomKnowledgeBase) -> Patient:
    """Draw a random patient from the knowledge base.

    A condition is picked uniformly at random; each of its symptoms is then
    kept independently when a uniform draw on [0, 1] falls below the symptom's
    weight. The patient is built against the notebook-level ``action_space``.
    """
    candidate_conditions = list(kb.condition_symptoms)
    condition = random.choice(candidate_conditions)
    present_symptoms = [
        symptom
        for symptom in kb.condition_symptoms[condition]
        if random.uniform(0, 1) < symptom.weight
    ]
    return Patient(condition, present_symptoms, action_space)

In [29]:
patient = sample_patient(kb)

In [30]:
patient.condition

Condition(name='suicide attempt', umls_code='C0038663')

In [31]:
patient.reset()

In [32]:
# NOTE(review): timeLimit is presumably expressed in milliseconds (i.e. a one
# second search budget per call) — confirm against the mcts package docs.
searcher = mcts(timeLimit=1000)

In [33]:
action = searcher.search(initialState=patient)

In [34]:
action

Symptom Action: metastatic lesion

In [35]:
from langchain.llms import Anthropic
from langchain import LLMChain, PromptTemplate

In [36]:
llm = Anthropic(model='claude-2', temperature=0.0, max_tokens_to_sample=2000)
prompt_template = """Here is a list of symptoms for the condition {condition}.
                    Symptoms: {symptoms_list}.

                    Here is the output schema:
                    <?xml version="1.0" encoding="UTF-8"?>
                        <xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
                          
                          <xs:element name="covidSymptoms">
                            <xs:complexType>
                              <xs:sequence>
                                <xs:element name="symptom" maxOccurs="unbounded">
                                  <xs:complexType>
                                    <xs:sequence>
                                      <xs:element name="name" type="xs:string"/>
                                      <xs:element name="frequency">
                                        <xs:simpleType>
                                          <xs:restriction base="xs:string">
                                            <xs:enumeration value="Very common"/>
                                            <xs:enumeration value="Common"/> 
                                            <xs:enumeration value="Uncommon"/>
                                            <xs:enumeration value="Rare"/>
                                          </xs:restriction>
                                        </xs:simpleType>
                                      </xs:element>
                                    </xs:sequence>
                                  </xs:complexType>
                                </xs:element>
                              </xs:sequence>
                            </xs:complexType>
                          </xs:element>
                        
                        </xs:schema>
            Please parse the symptoms into the above schema, assigning a correct frequency value to each symptom.
            """
prompt = PromptTemplate.from_template(prompt_template)



In [37]:
import xml.etree.ElementTree as ET

In [49]:
DEFAULT_FREQ_TERM_TO_WEIGHT = {"Very common": 0.9, "Common": 0.6, "Uncommon": 0.3, "Rare": 0.1}

In [55]:
from langchain.schema import BaseOutputParser
from abc import abstractmethod

DEFAULT_FREQ_TERM_TO_WEIGHT = {"Very common": 0.9, "Common": 0.6, "Uncommon": 0.3, "Rare": 0.1}

class WeightedSymptomXMLOutputParser(BaseOutputParser[List[str]]):
    """Parse an XML completion into a list of ``WeightedSymptom`` objects.

    The root element's direct children are treated as symptoms; each is
    expected to contain a ``name`` and a ``frequency`` child element.

    Args:
        frequency_term_to_weight: Mapping from a frequency term to the
            causal weight that it constitutes.
    """
    # NOTE(review): the type parameter above says List[str], but parse()
    # returns WeightedSymptom objects; left unchanged to avoid touching the
    # class's declared base.
    frequency_term_to_weight: Dict[str, float] = DEFAULT_FREQ_TERM_TO_WEIGHT

    @property
    def _type(self) -> str:
        return "xml"

    def parse(self, text: str) -> List[datamodels.WeightedSymptom]:
        """Parse the output of an LLM call.

        Args:
            text: Completion text containing one XML document, optionally
                surrounded by prose.

        Returns:
            One ``WeightedSymptom`` per child element that has a non-empty
            name. A frequency term that is missing or not in the mapping falls
            back to ``min_weight``.

        Raises:
            xml.etree.ElementTree.ParseError: If no well-formed XML document
                can be read from ``text``.
        """
        # Keep only the span from the first '<' to the last '>' so that prose
        # or whitespace around the document does not make it unparseable. In a
        # well-formed document nothing but whitespace can sit outside that
        # span, so previously accepted inputs parse exactly as before.
        start, end = text.find("<"), text.rfind(">")
        xml_text = text[start:end + 1] if 0 <= start < end else text
        root = ET.fromstring(xml_text)
        symptoms = []
        for symptom_elem in root:
            # findtext() returns None for a missing child and "" for an empty
            # one; normalise both instead of failing on `.text` of None.
            name = (symptom_elem.findtext('name') or "").strip()
            if not name:
                # A symptom cannot be built without a name; skip the element.
                continue
            frequency = (symptom_elem.findtext('frequency') or "").strip()
            weight = self.frequency_term_to_weight.get(frequency, self.min_weight)
            symptom = datamodels.WeightedSymptom(umls_code="none", name=name, weight=weight)
            symptoms.append(symptom)
        return symptoms

    @property
    def min_weight(self) -> float:
        """Smallest configured weight, used when a frequency term is unknown."""
        return min(self.frequency_term_to_weight.values())

In [56]:
llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=WeightedSymptomXMLOutputParser())

In [57]:
result = llm_chain.run(condition="COVID-19", symptoms_list="Fever, cough, vomiting")

In [59]:
class LLMWeightUpdateReader(kb_reading.KnowledgeBaseReader):
    """Update the weights of an existing KnowledgeBase using an LLMChain.

    Args:
        kb_reader: Reader that supplies the knowledge base to re-weight.
        weight_updater_chain: Chain invoked once per condition with the
            ``condition`` and ``symptoms_list`` inputs. It is assumed to return
            the re-weighted symptom list directly, the way the notebook's
            ``llm_chain`` is used elsewhere — confirm that the chain has an
            output parser attached.
    """

    def __init__(self, kb_reader: kb_reading.KnowledgeBaseReader, weight_updater_chain: LLMChain) -> None:
        self._kb_reader = kb_reader
        self._weight_updater_chain = weight_updater_chain

    def load_knowledge_base(self) -> datamodels.DiseaseSymptomKnowledgeBase:
        """Load the wrapped knowledge base and re-weight each condition's symptoms.

        Returns:
            A new knowledge base with the same conditions. When the chain's
            output cannot be parsed as XML for a condition, that condition
            keeps its original symptoms rather than being dropped.
        """
        base_kb = self._kb_reader.load_knowledge_base()
        updated_condition_symptoms = {}
        for condition, symptoms in base_kb.condition_symptoms.items():
            symptoms_str = ", ".join(s.name for s in symptoms)
            try:
                updated = self._weight_updater_chain.run(
                    condition=condition.name, symptoms_list=symptoms_str
                )
            except ET.ParseError:
                # Unparseable completion: fall back to the existing weights.
                updated = list(symptoms)
            updated_condition_symptoms[condition] = updated
        return datamodels.DiseaseSymptomKnowledgeBase(condition_symptoms=updated_condition_symptoms)

In [60]:
# The constructor expects a KnowledgeBaseReader, so pass `reader` rather than
# the already-loaded `kb`.
weight_update_reader = LLMWeightUpdateReader(reader, llm_chain)

In [74]:
import nest_asyncio
# Patch asyncio to allow nested event loops; presumably needed so the
# asyncio.run(...) call further down works inside the notebook's own loop.
nest_asyncio.apply()

In [101]:
from tqdm import tqdm
from xml.etree.ElementTree import ParseError

In [102]:
# First four (condition, symptoms) entries in knowledge-base order — a small
# subset, presumably for trial runs of the re-weighting pipeline.
sample_condition_symptoms = dict(list(kb.condition_symptoms.items())[:4])

In [113]:
import asyncio

async def process_condition(sem, condition, symptoms):
    """Ask the notebook-level ``llm_chain`` to re-weight one condition's symptoms.

    The call is made while holding ``sem``. Returns ``(condition, result)``
    on success, or ``None`` when the chain raises ``ParseError``; any other
    exception propagates to the caller.
    """
    async with sem:
        joined_names = ", ".join(symptom.name for symptom in symptoms)
        try:
            parsed = await llm_chain.arun(
                condition=condition.name,
                symptoms_list=joined_names,
            )
        except ParseError:
            return None
        else:
            return (condition, parsed)

async def main(
    condition_symptoms: Dict[datamodels.Condition, List[datamodels.WeightedSymptom]],
    max_concurrency: int = 1,
):
    """Re-weight every condition's symptoms via ``process_condition``.

    Args:
        condition_symptoms: Mapping from each condition to its symptoms.
        max_concurrency: Maximum number of ``process_condition`` calls allowed
            to hold the semaphore at once. Defaults to 1 (one call at a time).

    Returns:
        A dict from condition to the result returned for it, in the iteration
        order of ``condition_symptoms``. Conditions for which
        ``process_condition`` returned ``None`` are omitted.

    Raises:
        ValueError: If ``max_concurrency`` is smaller than 1.
    """
    if max_concurrency < 1:
        # A semaphore with no permits would block every task forever.
        raise ValueError("max_concurrency must be at least 1")

    sem = asyncio.Semaphore(max_concurrency)
    tasks = [
        asyncio.ensure_future(process_condition(sem, condition, symptoms))
        for condition, symptoms in condition_symptoms.items()
    ]

    completed = {}
    with tqdm(total=len(tasks)) as progress:
        # as_completed lets the progress bar advance as soon as any task ends;
        # every task has been awaited once this loop is done.
        for finished in asyncio.as_completed(tasks):
            outcome = await finished
            if outcome is not None:
                condition, result = outcome
                completed[condition] = result
            progress.update(1)

    # Re-key in input order so the result does not depend on which call
    # happened to finish first.
    return {
        condition: completed[condition]
        for condition in condition_symptoms
        if condition in completed
    }

# sample_result = asyncio.run(main(sample_condition_symptoms))

In [114]:
# Runs the LLM re-weighting over the entire knowledge base. The recorded run
# below processed 133 conditions in roughly 23 minutes, so this cell is
# expensive to re-execute.
final_result = asyncio.run(main(kb.condition_symptoms))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 133/133 [23:15<00:00, 10.49s/it]


In [116]:
weight_updated_kb = datamodels.DiseaseSymptomKnowledgeBase(condition_symptoms=final_result)

In [5]:
# kb_to_dataframe(weight_updated_kb).to_csv("data/ClaudeKnowledgeBase.csv")

In [6]:
kb_reading.CSVKnowledgeBaseReader('data/ClaudeKnowledgeBase.csv').load_knowledge_base()

DiseaseSymptomKnowledgeBase(condition_symptoms={Condition(name='hypertensive disease', umls_code='C0020538'): [WeightedSymptom(name='pain chest', umls_code='none', noise_rate=0.03, weight=0.6), WeightedSymptom(name='shortness of breath', umls_code='none', noise_rate=0.03, weight=0.6), WeightedSymptom(name='dizziness', umls_code='none', noise_rate=0.03, weight=0.6), WeightedSymptom(name='asthenia', umls_code='none', noise_rate=0.03, weight=0.3), WeightedSymptom(name='fall', umls_code='none', noise_rate=0.03, weight=0.3), WeightedSymptom(name='syncope', umls_code='none', noise_rate=0.03, weight=0.3), WeightedSymptom(name='vertigo', umls_code='none', noise_rate=0.03, weight=0.3), WeightedSymptom(name='sweat', umls_code='none', noise_rate=0.03, weight=0.6), WeightedSymptom(name='sweating increased', umls_code='none', noise_rate=0.03, weight=0.6), WeightedSymptom(name='palpitation', umls_code='none', noise_rate=0.03, weight=0.6), WeightedSymptom(name='nausea', umls_code='none', noise_rate=0