In [None]:
!pip install gym

In [213]:
%load_ext autoreload
%autoreload 2
from src.utils import *


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
import traceback

all_transplants, _ = read_organ_data()

In [7]:
all_transplants = all_transplants.dropna(subset=[Column.DONOR_ID.name, Column.ORGAN_TRANSPLANT_ID.name, Column.ORGAN_RECOVERY_DATE.name, Column.TRANSPLANT_DATE.name, Column.END_DATE.name]).sort_values(by=Column.TRANSPLANT_DATE.name)
all_transplants[[Column.DONOR_ID.name, Column.ORGAN_TRANSPLANT_ID.name, Column.ORGAN_RECOVERY_DATE.name, Column.TRANSPLANT_DATE.name, Column.END_DATE.name]].sort_values(by=Column.TRANSPLANT_DATE.name)

Unnamed: 0,DONOR_ID,ORGAN_TRANSPLANT_ID,ORGAN_RECOVERY_DATE,TRANSPLANT_DATE,END_DATE
818,89832.0,A373229,2002-03-23,2002-03-23,2002-03-23
379,183045.0,A358946,2002-04-02,2002-04-03,2002-04-03
7381,319621.0,A294693,2002-04-03,2002-04-03,2002-04-03
11508,158846.0,A249016,2002-04-03,2002-04-04,2002-04-04
7568,268138.0,A41833,2002-04-05,2002-04-05,2002-04-05
...,...,...,...,...,...
5300,709040.0,A1041486,2024-09-28,2024-09-28,2024-09-28
8483,709098.0,A1041306,2024-09-29,2024-09-29,2024-09-29
4755,708557.0,A1041519,2024-09-28,2024-09-29,2024-09-29
157,708568.0,A1041578,2024-09-29,2024-09-30,2024-09-30


In [77]:
highest_rows_per_recovery_date = all_transplants.groupby(Column.ORGAN_RECOVERY_DATE.name).size().max()
highest_rows_per_recovery_date


9

In [78]:
date = pd.Timestamp('2002-04-03')
available_organs = get_mininal_columns_available_organs(by_date=date)

In [79]:
available_organs

Unnamed: 0,DONOR_ID,DONOR_BLOOD_TYPE,DONOR_BLOOD_TYPE_AS_CODE,ORGAN_RECOVERY_DATE
0,183045,O,7,2002-04-02
1,319621,B,6,2002-04-03
2,158846,O,7,2002-04-03


In [64]:
waitlist_members = get_mininal_columns_waitlist(by_date=date)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  


In [65]:
waitlist_members.head(5)

Unnamed: 0,ORGAN_TRANSPLANT_ID,RECIPIENT_ID,DONOR_ID,INIT_MELD_PELD_LAB_SCORE,FINAL_MELD_PELD_LAB_SCORE,PATIENT_SURVIVAL_TIME,GRAFT_LIFESPAN,DAYS_ON_WAITLIST,RECIPIENT_BLOOD_TYPE,RECIPIENT_BLOOD_TYPE_AS_CODE,RECIPIENT_AGE,TRANSPLANT_DATE,END_DATE,DIAGNOSIS
260,A395477,197725,243842,44,36,2341,2341,46,A,0,51.0,2002-05-14,2002-05-14,4215
6029,A180548,547245,141481,24,21,5879,5879,505,O,7,50.0,2003-07-26,2003-07-26,4215
10793,A152583,402877,98330,21,23,6,0,293,O,7,46.0,2003-01-16,2003-01-16,4307
2149,A216989,377703,183293,20,22,3505,3505,31,O,7,50.0,2002-04-07,2002-04-07,4307
12558,A81652,259215,291782,18,22,7700,7700,139,A,0,40.0,2002-07-30,2002-07-30,4204


In [37]:
WAITLIST_FEATURES = [
        Column.RECIPIENT_ID.name, 
        Column.INIT_MELD_PELD_LAB_SCORE.name, 
        Column.DAYS_ON_WAITLIST.name,
        Column.RECIPIENT_BLOOD_TYPE.name,
        Column.RECIPIENT_BLOOD_TYPE_AS_CODE.name,
]
sample_waitlist_members = waitlist_members[WAITLIST_FEATURES].head(5)
sample_waitlist_members

Unnamed: 0,RECIPIENT_ID,INIT_MELD_PELD_LAB_SCORE,DAYS_ON_WAITLIST,RECIPIENT_BLOOD_TYPE,RECIPIENT_BLOOD_TYPE_AS_CODE
260,197725,44,46,A,0
7381,407356,38,2,B,6
11508,132222,27,2,B,6
6029,547245,24,505,O,7
10793,402877,21,293,O,7


In [218]:
def compare_to_baseline(current_date: pd.Timestamp, end_date: pd.Timestamp):
    total_graft_lifespan = 0
    total_patient_survival_time = 0

    date = current_date

    while current_date <= end_date:
        # Get the next day's data
        date, [available_organs, waitlist_members] = get_next_day(date, allocated_ids=[])

        # Sum up the GRAFT_LIFESPAN and PATIENT_SURVIVAL_TIME for the current day's waitlist members
        total_graft_lifespan += waitlist_members[Column.GRAFT_LIFESPAN.name].sum()
        total_patient_survival_time += waitlist_members[Column.PATIENT_SURVIVAL_TIME.name].sum()

    total_graft_lifespan, total_patient_survival_time

In [None]:
baseline_graft_lifespan, baseline_patient_survival_time = compare_to_baseline(pd.Timestamp('02-04-24'), pd.Timestamp('02-04-24'))

In [204]:
import logging
logging.basicConfig(level=logging.INFO)

In [230]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random

class OrganMatchingEnv:
    def __init__(self, max_waitlist_members: int = 10, max_days: int = 3):
        """
        Args:
            daily_available_donors_list: A list of DataFrames, one per day, of available donors.
                                         Each day's DataFrame has columns like: 
                                         [DONOR_ID, DONOR_BLOOD_TYPE, ...]
                                         The length of this DataFrame is the number of donors for that day.

            daily_waitlist_members_list: A list of DataFrames, one per day, of waitlist recipients.
                                         Each day's DataFrame has columns like:
                                         [RECIPIENT_ID, FINAL_MELD_PELD_LAB_SCORE, GRAFT_LIFESPAN (predicted or known), ...]
                                         The length of this DataFrame is the number of recipients for that day.
                                         
            max_days: The horizon over which we measure performance or run the simulation.

        We assume:
        - Each day we have multiple donors and multiple recipients.
        - Each step in the environment involves selecting a recipient for the **current donor**.
          After all donors for that day are processed, we move to the next day.
        """
        self.initial_date = pd.Timestamp('2002-04-02')
        self.max_waitlist_members = max_waitlist_members
        self.max_days = max_days

        self.current_day_idx = self.initial_date
        self.current_donor_idx = 0
        self.total_days_allocated = -1

        # These will be set at reset
        self.available_donor_df = None
        self.waitlist_member_df = None
        self.donor_ids_allocated = []

    def reset(self):
        self.current_donor_idx = 0
        self.donor_ids_allocated = []
        self.refetch_data(self.initial_date)
        return self._get_state()
    
    def refetch_data(self, date: pd.Timestamp):
        next_date, data = get_next_day(date, allocated_ids=self.donor_ids_allocated)
        # Reset back to 0
        if date != next_date:
            self.current_donor_idx = 0
            logging.info(f'Moving onto next day {next_date} with donor index {self.current_donor_idx} for donor ID: {self._get_donor()}')
        self.current_day_idx = next_date
        self.available_donor_df = data[0]
        self.waitlist_member_df = data[1]

    def get_valid_actions(self):
        if self.current_donor_idx >= len(self.available_donor_df):
            logging.info(f'{self.current_donor_idx} is bigger than available donor df: {len(self.available_donor_df)}')
        donor = self.available_donor_df.iloc[self.current_donor_idx]
        donor_blood_type = donor[Column.DONOR_BLOOD_TYPE.name]
        valid_allocations = [i for i in range(len(self.waitlist_member_df)) if get_match_value(donor_blood_type=donor_blood_type, recipient_blood_type=self.waitlist_member_df.iloc[i][Column.RECIPIENT_BLOOD_TYPE.name]) >= 1.0]
        return valid_allocations

    def _get_state(self):
        self.refetch_data(self.current_day_idx)
        encoded_array = np.array([
            list(row) for row in zip(self.waitlist_member_df[Column.RECIPIENT_ID.name], 
                                    self.waitlist_member_df[Column.INIT_MELD_PELD_LAB_SCORE.name], 
                                    self.waitlist_member_df[Column.DAYS_ON_WAITLIST.name], 
                                    self.waitlist_member_df[Column.RECIPIENT_BLOOD_TYPE_AS_CODE.name])
        ])

        # Pad the array with additional rows if necessary
        # This should never get hit
        if encoded_array.shape[0] < self.max_waitlist_members:
            padding = np.full((self.max_waitlist_members - encoded_array.shape[0], encoded_array.shape[1]), -1)
            encoded_array = np.vstack([encoded_array, padding])

        state = np.array(encoded_array, dtype=np.float32)
        return state
    
    def _get_donor(self):
        return self.available_donor_df.iloc[self.current_donor_idx][Column.DONOR_ID.name]
    
    def _get_chosen_recipient(self, action: int):
        if action > len(self.waitlist_member_df):
            return pd.Series.empty()
        return self.waitlist_member_df.iloc[action]

    def step(self, action: int):
        """
        Args:
            action: index corresponding to the chosen recipient in waitlist_member_df.
                    This means action is in [0, max_waitlist_members_to_consider].
        
        Returns:
            next_state: state after this allocation
            reward: float, e.g. graft lifespan or related metric
            done: bool, whether the episode ended
            info: dict, extra info
        """
        chosen_recipient = self._get_chosen_recipient(action)
        self.total_days_allocated += 1
        self.donor_ids_allocated.append(self._get_donor())
        self.donor_ids_allocated.append(chosen_recipient[Column.DONOR_ID.name])
        reward = 0.0
        if chosen_recipient.empty:
            reward = -1000
        else:  
            # It doesn't matter who they got the organ from. Just include the rewards of this person getting an organ.
            if chosen_recipient[Column.GRAFT_LIFESPAN.name] is not None:
                reward += float(chosen_recipient[Column.GRAFT_LIFESPAN.name])
            if chosen_recipient[Column.GRAFT_LIFESPAN.name] is not None:
                reward += float(chosen_recipient[Column.PATIENT_SURVIVAL_TIME.name])
        logging.info(f'Chose recipient with reward {chosen_recipient[Column.RECIPIENT_ID.name]} : {reward}')

        # Move to next donor
        chosen_donor = self._get_donor()
        self.current_donor_idx += 1
        if self.current_donor_idx >= len(self.available_donor_df):
            if self.total_days_allocated < self.max_days:
                logging.info(f'Donor {chosen_donor} is the end of donors')
                self.refetch_data(self.current_day_idx)
            else:
                # No more days left
                done = True
                next_state = np.zeros_like(self._get_state())  # terminal state
                return next_state, reward, done, {}
        
        done = (self.total_days_allocated >= self.max_days)
        if not done:
            next_state = self._get_state()
        else:
            next_state = np.zeros_like(self._get_state())

        return next_state, reward, done, {}

In [196]:
env = OrganMatchingEnv()
env.reset()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  waitlist_members[Column.DONOR_ID.name] = waitlist_members[Column.DONOR_ID.name].astype(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  waitlist_members[Column.DONOR_ID.name] = waitlist_members[Column.DONOR_ID.name].astype(


array([[1.97725e+05, 4.40000e+01, 4.60000e+01, 0.00000e+00],
       [4.07356e+05, 3.80000e+01, 2.00000e+00, 6.00000e+00],
       [1.32222e+05, 2.70000e+01, 2.00000e+00, 6.00000e+00],
       [5.47245e+05, 2.40000e+01, 5.05000e+02, 7.00000e+00],
       [4.02877e+05, 2.10000e+01, 2.93000e+02, 7.00000e+00],
       [3.77703e+05, 2.00000e+01, 3.10000e+01, 7.00000e+00],
       [2.59215e+05, 1.80000e+01, 1.39000e+02, 0.00000e+00],
       [4.10354e+05, 1.70000e+01, 1.60000e+01, 0.00000e+00],
       [3.88101e+05, 1.70000e+01, 2.10000e+01, 7.00000e+00],
       [1.97548e+05, 1.60000e+01, 9.90000e+01, 6.00000e+00]],
      dtype=float32)

In [194]:
env.available_donor_df

Unnamed: 0,DONOR_ID,DONOR_BLOOD_TYPE,DONOR_BLOOD_TYPE_AS_CODE,ORGAN_RECOVERY_DATE
0,183293,O,7,2002-04-07


In [193]:
env.waitlist_member_df.head(10)

Unnamed: 0,ORGAN_TRANSPLANT_ID,RECIPIENT_ID,DONOR_ID,INIT_MELD_PELD_LAB_SCORE,FINAL_MELD_PELD_LAB_SCORE,PATIENT_SURVIVAL_TIME,GRAFT_LIFESPAN,DAYS_ON_WAITLIST,RECIPIENT_BLOOD_TYPE,RECIPIENT_BLOOD_TYPE_AS_CODE,RECIPIENT_AGE,TRANSPLANT_DATE,END_DATE,DIAGNOSIS
260,A395477,197725,243842,44,36,2341,2341,46,A,0,51.0,2002-05-14,2002-05-14,4215
10793,A152583,402877,98330,21,23,6,0,293,O,7,46.0,2003-01-16,2003-01-16,4307
2149,A216989,377703,183293,20,22,3505,3505,31,O,7,50.0,2002-04-07,2002-04-07,4307
12558,A81652,259215,291782,18,22,7700,7700,139,A,0,40.0,2002-07-30,2002-07-30,4204
37,A313443,410354,185424,17,14,2988,2988,16,A,0,68.0,2002-04-13,2002-04-13,4220
5794,A346550,197548,203529,16,13,3865,3865,99,B,6,53.0,2002-06-15,2002-06-15,4215
2685,A218998,157944,284553,15,12,3763,3763,160,A,0,43.0,2002-08-10,2002-08-10,4216
4044,A103642,495544,296852,15,19,2101,2101,330,O,7,49.0,2003-02-13,2003-02-13,4216
4835,A301806,87283,267798,15,15,8034,8034,68,O,7,57.0,2002-06-08,2002-06-08,4401


In [188]:
env.waitlist_member_df[env.waitlist_member_df[Column.DONOR_ID.name] == 183045]

Unnamed: 0,ORGAN_TRANSPLANT_ID,RECIPIENT_ID,DONOR_ID,INIT_MELD_PELD_LAB_SCORE,FINAL_MELD_PELD_LAB_SCORE,PATIENT_SURVIVAL_TIME,GRAFT_LIFESPAN,DAYS_ON_WAITLIST,RECIPIENT_BLOOD_TYPE,RECIPIENT_BLOOD_TYPE_AS_CODE,RECIPIENT_AGE,TRANSPLANT_DATE,END_DATE,DIAGNOSIS


In [191]:
env.get_valid_actions()

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [226]:
done = False
env = OrganMatchingEnv()
env.reset()
total_reward = 0
while not done:
    valid_actions = env.get_valid_actions()
    next_state, reward, done, _ = env.step(random.choice(valid_actions))
    total_reward += reward
    print(f'Reward: {total_reward} and reached done? {done}')

INFO:root:Chose recipient with reward 259215 : 15400.0
INFO:root:Donor 183045 is the end of donors
INFO:root:Moving onto next day 2002-04-03 00:00:00 with donor index 0 for : 183045


Reward: 15400.0 done False


INFO:root:Chose recipient with reward 132222 : 15218.0
INFO:root:Moving onto next day 2002-04-05 00:00:00 with donor index 0 for : 319621


Reward: 15218.0 done False


INFO:root:Chose recipient with reward 410354 : 5976.0
INFO:root:Donor 268138 is the end of donors
INFO:root:Moving onto next day 2002-04-07 00:00:00 with donor index 0 for : 268138


Reward: 5976.0 done False


INFO:root:Chose recipient with reward 197548 : 7730.0
INFO:root:Moving onto next day 2002-04-24 00:00:00 with donor index 0 for : 183293


Reward: 7730.0 done True


In [181]:
env.current_day_idx

Timestamp('2002-04-03 00:00:00')

In [180]:
env.donor_ids_allocated

[183045, 319621]

In [227]:
class DQN(nn.Module):
    def __init__(self, state_size, action_size, hidden_size=64):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)
        
    def push(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
        
    def sample(self, batch_size):
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states, dtype=np.float32),
                np.array(actions, dtype=np.int64),
                np.array(rewards, dtype=np.float32),
                np.array(next_states, dtype=np.float32),
                np.array(dones, dtype=np.bool_))
        
    def __len__(self):
        return len(self.memory)

In [228]:
def train_dqn(env: OrganMatchingEnv,
              episodes: int =100,
              gamma: float=0.99,
              epsilon_start: float = 1.0,
              epsilon_end: float = 0.01,
              epsilon_decay: int = 500,
              lr: float = 1e-3,
              batch_size: int = 32,
              memory_capacity: int = 1000,
              target_update: int = 10):
    # Initialize replay memory
    memory = ReplayMemory(memory_capacity)

    state = env.reset()
    state_size = len(state)
    action_size = env.max_waitlist_members
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    policy_net = DQN(state_size, action_size).to(device)
    target_net = DQN(state_size, action_size).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()
    
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)
    epsilon = epsilon_start
    steps_done = 0
    
    for ep in range(episodes):
        state = env.reset()
        done = False
        episode_reward = 0.0
        
        while not done:
            steps_done += 1

            # The actual available actions on this step is equal to the current day's recipients
            actions = np.arange(action_size)
            valid_allocations = env.get_valid_actions()
            invalid_allocations = [i for i in actions if i not in valid_allocations]
            
            # Epsilon-greedy action selection
            if random.random() < epsilon:
                # Random action from the valid range of recipients
                action = random.choice(valid_allocations)
            else:
                with torch.no_grad():
                    q_values = policy_net(torch.tensor([state], dtype=torch.float32))
                    # Mask out invalid actions by setting Q-values of invalid actions to a large negative number
                    q_values[0, invalid_allocations] = -1e9
                    action = q_values.argmax(dim=1).item()

            next_state, reward, done, info = env.step(action)
            episode_reward += reward
            
            memory.push(state, action, reward, next_state, done)
            state = next_state

            # Train if memory is sufficient
            if False: #len(memory) > batch_size:
                states_b, actions_b, rewards_b, next_states_b, dones_b = memory.sample(batch_size)
                
                states_t = torch.tensor(states_b, dtype=torch.float32, device=device)
                actions_t = torch.tensor(actions_b, device=device).unsqueeze(1)
                rewards_t = torch.tensor(rewards_b, dtype=torch.float32, device=device)
                next_states_t = torch.tensor(next_states_b, dtype=torch.float32, device=device)
                dones_t = torch.tensor(dones_b, dtype=torch.bool, device=device)
                
                # Compute Q(s,a)
                q_values = policy_net(states_t).gather(1, actions_t)
                
                # Compute max Q(s',a') from target net
                next_q_values_all = target_net(next_states_t)
                # For next states, we need to consider that the action space might be different.
                # If you know next_states_t corresponds to a certain day with fewer recipients,
                # you would mask out invalid actions similarly as during selection.
                # For simplicity, assume max_action_size recipients are always possible.
                # (If not, you must store the action space size per transition and mask accordingly.)
                
                next_q_values = next_q_values_all.max(1)[0]
                next_q_values[dones_t] = 0.0
                
                target = rewards_t + gamma * next_q_values
                
                loss = nn.MSELoss()(q_values.squeeze(), target)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Decay epsilon
            epsilon = epsilon_end + (epsilon_start - epsilon_end) * np.exp(-steps_done / epsilon_decay)
        
        # Update target network
        if ep % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())
        
        print(f"Episode {ep+1}/{episodes}, Reward: {episode_reward:.2f}, Epsilon: {epsilon:.2f}")
    
    return policy_net

In [229]:
env = OrganMatchingEnv(max_days=3)
trained_policy = train_dqn(env, episodes=2)

  from .autonotebook import tqdm as notebook_tqdm
INFO:root:Chose recipient with reward 132222 : 15218.0
INFO:root:Donor 183045 is the end of donors
INFO:root:Moving onto next day 2002-04-03 00:00:00 with donor index 0 for : 183045
INFO:root:Chose recipient with reward 388101 : 4352.0
INFO:root:Donor 319621 is the end of donors
INFO:root:Moving onto next day 2002-04-07 00:00:00 with donor index 0 for : 319621
INFO:root:Chose recipient with reward 157944 : 7526.0
INFO:root:Donor 183293 is the end of donors
INFO:root:Moving onto next day 2002-04-12 00:00:00 with donor index 0 for : 183293
INFO:root:Chose recipient with reward 267319 : 766.0
INFO:root:Moving onto next day 2002-04-24 00:00:00 with donor index 0 for : 185424


Episode 1/2, Reward: 27862.00, Epsilon: 0.99


INFO:root:Chose recipient with reward 407356 : 60.0
INFO:root:Moving onto next day 2002-04-03 00:00:00 with donor index 0 for : 183045


Episode 2/2, Reward: 60.00, Epsilon: 0.99
