# MCTS: Monte Carlo Tree Search for symptom-based diagnosis

## 1. Setup

### 1.1. Data

In [4]:
import os
# Move to the project root when the kernel was started one directory below it
# (i.e. pyproject.toml is not in the current directory), so relative paths
# such as "data/..." used in later cells resolve from the root.
_cwd_entries = os.listdir()
if "pyproject.toml" not in _cwd_entries:
    os.chdir("..")

In [5]:
from mcts import mcts

In [283]:
import random
import numpy as np
from typing import Optional

In [83]:
from dr_claude import kb_reading, datamodels

In [11]:
# Reader for the NYPH disease/symptom knowledge base stored as an HTML file
# under data/. NOTE: the "Knowldege" spelling is part of the project's class
# name and file name, so it must not be "fixed" here.
reader = kb_reading.NYPHKnowldegeBaseReader("data/NYPHKnowldegeBase.html")

In [12]:
# Parsed knowledge base; later cells read kb.condition_symptoms, a mapping of
# condition -> list of weighted symptoms.
kb = reader.load_knowledge_base()

In [16]:
# NumPy view of the knowledge base (presumably conditions x symptoms
# probabilities — TODO confirm). Not referenced by the cells shown below.
prob_matrix = datamodels.DiseaseSymptomKnowledgeBaseTransformer.to_numpy(kb)

In [27]:
import abc
from typing import List, NamedTuple

In [28]:
class State(abc.ABC):
    """Game-state interface driven by the ``mcts`` searcher.

    Concrete states must implement all five methods; the abstract versions
    below only document the expected contract.
    """

    @abc.abstractmethod
    def getCurrentPlayer(self):
        """Return the player to move (1 = maximiser, -1 = minimiser)."""

    @abc.abstractmethod
    def getPossibleActions(self):
        """Return the actions that may be taken from this state."""

    @abc.abstractmethod
    def takeAction(self, action):
        """Return the state that results from applying ``action``."""

    @abc.abstractmethod
    def isTerminal(self) -> bool:
        """Return True when the episode has ended."""

    @abc.abstractmethod
    def getReward(self) -> float:
        """Return the reward associated with this (terminal) state."""

In [90]:
from typing import Set

In [109]:
class SymptomConditionsSpace(NamedTuple):
    """Every distinct condition and symptom that appears in a knowledge base.

    Both fields are lists built from sets: they contain no duplicates, but
    their order is arbitrary and not guaranteed to be stable between runs.
    """

    # Distinct conditions (the keys of ``kb.condition_symptoms``).
    conditions: List[datamodels.Condition]
    # Distinct symptom entries listed under any condition.
    symptoms: List[datamodels.WeightedSymptom]

    @classmethod
    def from_kb(cls, kb: datamodels.DiseaseSymptomKnowledgeBase) -> 'SymptomConditionsSpace':
        """Collect the condition and symptom universe of ``kb``."""
        # Iterating the mapping yields its keys, i.e. the conditions.
        conditions = list(set(kb.condition_symptoms))
        # Union of every condition's symptom list; the set drops repeats.
        symptoms = list({s for symps in kb.condition_symptoms.values() for s in symps})
        return cls(conditions, symptoms)

In [110]:
# Build the space once and display its size as (n_symptoms, n_conditions).
_space = SymptomConditionsSpace.from_kb(kb)
len(_space.symptoms), len(_space.conditions)

(406, 133)

In [199]:
class Action:
    """A move available to the search: ask about a symptom or name a condition.

    Two actions are equal when both ``name`` and ``is_condition`` match, and
    equal actions hash equally, so instances work as dict keys, set members
    and with ``list.index``.
    """

    def __init__(self, name: str, is_condition: bool) -> None:
        self.name = name
        # True -> the action names a condition (a diagnosis);
        # False -> it asks about a symptom.
        self.is_condition = is_condition

    def __hash__(self) -> int:
        # Hash exactly the fields compared in __eq__.
        return hash((self.name, self.is_condition))

    def __repr__(self) -> str:
        return f"{'Condition' if self.is_condition else 'Symptom'} Action: {self.name}"

    def __eq__(self, other) -> bool:
        # For foreign types, defer to Python's fallback instead of raising
        # AttributeError on a missing ``.name`` / ``.is_condition``.
        if not isinstance(other, Action):
            return NotImplemented
        return self.name == other.name and self.is_condition == other.is_condition

In [200]:
from copy import deepcopy

In [267]:
class Patient(State):
    """A simulated patient used as the search state.

    The patient has one hidden ``condition`` and a list of ``symptoms``. Each
    step either asks about a symptom, recording +1 (present) or -1 (absent)
    in ``self.state``, or names a condition, which ends the episode.
    """

    def __init__(self, condition, symptoms: List, action_space: List[Action]) -> None:
        self.condition = condition
        self.symptoms = symptoms
        self.symptom_names = [s.name for s in symptoms]
        # Set mirror of symptom_names for O(1) membership tests in takeAction.
        self._symptom_name_set = set(self.symptom_names)
        self.is_done = False
        self.action_space = action_space
        # Only symptom questions occupy a slot in the observation vector.
        self.state_space = [a for a in self.action_space if not a.is_condition]
        # Action -> slot index, keeping the FIRST occurrence of a repeated
        # action to match list.index semantics.
        self._state_index = {}
        for i, a in enumerate(self.state_space):
            self._state_index.setdefault(a, i)
        # Create the observation up front so takeAction works without an
        # explicit reset() call.
        self.reset()

    def reset(self) -> None:
        """Clear every recorded answer and mark the episode as not finished."""
        # One slot per symptom question plus one trailing slot that takeAction
        # never writes (its purpose is not visible here — TODO confirm).
        self.state = np.zeros(shape=(len(self.state_space) + 1, 1))
        self.is_done = False

    def getCurrentPlayer(self):
        # Single-agent search: always the maximising player.
        return 1

    def getPossibleActions(self):
        return self.action_space

    def takeAction(self, action: Action):
        """Return the state reached by applying ``action``; never mutates ``self``.

        - Condition action: a copy with ``is_done=True`` (the episode ends).
        - Known symptom action: a copy whose slot for that symptom is +1 if the
          patient has it, otherwise -1.
        - Any other action: ``self`` unchanged.
        """
        # The searcher applies many actions to the same state object during
        # expansion and rollouts, so mutating self here would leak a terminal
        # flag back into the tree.
        if action.is_condition:
            new = self._copy_with(self.state.copy())
            new.is_done = True
            return new
        i = self._state_index.get(action)
        if i is None:
            return self
        state = self.state.copy()
        state[i] = 1 if action.name in self._symptom_name_set else -1
        return self._copy_with(state)

    def _copy_with(self, state):
        """Shallow-copy this patient, replacing only the observation array."""
        # Shallow copy is enough: condition, symptoms and the action lists are
        # never mutated, so sharing them avoids deep-copying every Action per step.
        import copy

        new = copy.copy(self)
        new.state = state
        return new

    def isTerminal(self):
        return self.is_done

    def getReward(self) -> float:
        # NOTE(review): placeholder — every state scores 1.0 regardless of
        # whether the named condition matches self.condition, so the search
        # currently has no signal to prefer one action over another.
        return 1.0

    def __repr__(self) -> str:
        return f"Patient([{self.condition.name} || {[s.name for s in self.symptoms]}])"

In [298]:
class PatientState(State):
    """Work-in-progress, observation-only state (no hidden patient attached).

    NOTE(review): ``takeAction`` is not implemented yet and raises
    NotImplementedError; ``getReward`` is a constant placeholder.
    """

    def __init__(self, action_space: List[Action], state: Optional[np.ndarray] = None) -> None:
        self.action_space = action_space
        # Symptom questions only; condition actions are guesses, not observations.
        self.state_space = [a for a in self.action_space if not a.is_condition]
        self.is_done = False
        # `if not state` raises for multi-element arrays; test for None explicitly.
        if state is None:
            self.state = np.zeros(shape=(len(self.state_space), 1))
        else:
            self.state = state

    def getCurrentPlayer(self):
        return 1

    def getPossibleActions(self):
        return self.action_space

    def takeAction(self, action: Action):
        # TODO: action could be a disease prediction or a symptom question.
        raise NotImplementedError("PatientState.takeAction is not implemented yet")

    def isTerminal(self):
        return self.is_done

    def getReward(self) -> float:
        # TODO: reward based on condition probabilities (constant placeholder for now).
        return 1.0

    def __repr__(self) -> str:
        return f"PatientState(is_done={self.is_done}, state_shape={self.state.shape})"

In [285]:
# Distinct conditions and symptoms of the knowledge base; used to build the
# action space in the next cell.
symptom_condition_space = SymptomConditionsSpace.from_kb(kb)

In [286]:
# Search actions: one "ask about symptom" per distinct symptom name, followed
# by one "diagnose" per distinct condition name. Names are de-duplicated
# (first occurrence kept) because distinct WeightedSymptom entries may share a
# name — presumably with different per-condition weights, TODO confirm — and
# Action keeps only the name, so repeats would be indistinguishable duplicates
# that bias random rollouts.
_symptom_names = dict.fromkeys(s.name for s in symptom_condition_space.symptoms)
_condition_names = dict.fromkeys(c.name for c in symptom_condition_space.conditions)
action_space = [Action(name, False) for name in _symptom_names] + [
    Action(name, True) for name in _condition_names
]

In [287]:
def sample_patient(
    kb: datamodels.DiseaseSymptomKnowledgeBase,
    actions: Optional[List[Action]] = None,
    rng: Optional[random.Random] = None,
) -> Patient:
    """Draw a random synthetic patient from the knowledge base.

    A condition is chosen uniformly from ``kb.condition_symptoms``. Each of
    its symptoms is then kept independently with probability ``s.weight``
    (for weights in [0, 1]).

    Args:
        kb: knowledge base whose ``condition_symptoms`` maps condition -> symptoms.
        actions: action space given to the Patient; defaults to the
            notebook-level ``action_space``.
        rng: optional ``random.Random`` for reproducible sampling; defaults to
            the global ``random`` module state (same draws as before).
    """
    rng = random if rng is None else rng
    actions = action_space if actions is None else actions
    condition = rng.choice(list(kb.condition_symptoms))
    # random() yields the same value uniform(0, 1) would, so the default
    # random stream is consumed exactly as before.
    selected_symptoms = [s for s in kb.condition_symptoms[condition] if rng.random() < s.weight]
    return Patient(condition, selected_symptoms, actions)

In [288]:
# Draw one synthetic patient: a random condition plus weight-sampled symptoms.
patient = sample_patient(kb)

In [289]:
# Show the hidden condition sampled for this patient.
patient.condition

Condition(name='encephalopathy', umls_code='C0085584')

In [290]:
# Start from an all-zero observation vector before searching.
patient.reset()

In [291]:
# timeLimit is in milliseconds (mcts package convention): about 1 s per search.
searcher = mcts(timeLimit=1000)

In [296]:
# Run the tree search from the sampled patient and keep the action the
# searcher ranks best at the root.
action = searcher.search(initialState=patient)

In [297]:
# Display the recommended first action.
action

Symptom Action: abdomen acute