In [None]:
!pip install gym

In [None]:
%load_ext autoreload
%autoreload 2
from src.utils import *


In [91]:
import traceback

# Load the transplant table used throughout the notebook.
# read_organ_data() returns two values; the second one is discarded here.
all_transplants, _ = read_organ_data()

In [96]:
import os
import pandas as pd

data_directory = 'data'

# Columns copied from the transplant table onto every waitlist row.
enrichment_columns = [
    Column.DIAGNOSIS_CODE.name,
    Column.FUNCTIONAL_STATUS_AT_REGISTRATION.name,
    Column.FUNCTIONAL_CODE.name,
]

# Build the per-recipient lookup once: it does not depend on the file being processed.
# Rows missing DIAGNOSIS_CODE or FUNCTIONAL_CODE are dropped, then for every RECIPIENT_ID
# the row with the highest FUNCTIONAL_CODE is kept.
all_transplants_filtered = (
    all_transplants
    .dropna(subset=[Column.DIAGNOSIS_CODE.name, Column.FUNCTIONAL_CODE.name])
    .sort_values(by=Column.FUNCTIONAL_CODE.name, ascending=False)
    .drop_duplicates(subset=[Column.RECIPIENT_ID.name], keep='first')
)

for filename in os.listdir(data_directory):
    if not filename.endswith('waitlist.csv'):
        continue

    file_path = os.path.join(data_directory, filename)
    waitlist_df = pd.read_csv(file_path)

    # Only merge columns the file does not have yet. Re-running this cell on an
    # already enriched file would otherwise create duplicated `_x` / `_y` columns
    # and write them back to disk.
    missing_columns = [column for column in enrichment_columns if column not in waitlist_df.columns]
    if not missing_columns:
        continue

    waitlist_df = waitlist_df.merge(
        all_transplants_filtered[[Column.RECIPIENT_ID.name] + missing_columns],
        on=Column.RECIPIENT_ID.name,
        how='left'
    )
    # Save the updated DataFrame back to the CSV file (overwrites the input in place).
    waitlist_df.to_csv(file_path, index=False)

In [43]:
# Keep only rows where donor id, transplant id, recovery date, transplant date and end date
# are all present, ordered by transplant date.
# NOTE(review): this rebinds `all_transplants`, so every later cell sees the filtered table
# and the unfiltered one is only recoverable by re-running read_organ_data().
all_transplants = all_transplants.dropna(subset=[Column.DONOR_ID.name, Column.ORGAN_TRANSPLANT_ID.name, Column.ORGAN_RECOVERY_DATE.name, Column.TRANSPLANT_DATE.name, Column.END_DATE.name]).sort_values(by=Column.TRANSPLANT_DATE.name)
# all_transplants[[Column.DONOR_ID.name, Column.ORGAN_TRANSPLANT_ID.name, Column.ORGAN_RECOVERY_DATE.name, Column.TRANSPLANT_DATE.name, Column.END_DATE.name]].sort_values(by=Column.TRANSPLANT_DATE.name)

In [None]:
import matplotlib.pyplot as plt

# Filter data for the years 2020-2023 (bounds are inclusive).
# .copy() makes the result an independent frame, so adding the Year/Month columns
# below does not write into a slice of all_transplants (SettingWithCopyWarning).
filtered_transplants = all_transplants[
    (all_transplants[Column.ORGAN_RECOVERY_DATE.name] >= '2020-01-01') &
    (all_transplants[Column.ORGAN_RECOVERY_DATE.name] <= '2023-12-31')
].copy()

# Extract year and month from ORGAN_RECOVERY_DATE
filtered_transplants['Year'] = filtered_transplants[Column.ORGAN_RECOVERY_DATE.name].dt.year
filtered_transplants['Month'] = filtered_transplants[Column.ORGAN_RECOVERY_DATE.name].dt.month

# Plot distribution by year
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
filtered_transplants['Year'].value_counts().sort_index().plot(kind='bar')
plt.title('Distribution of ORGAN_RECOVERY_DATE by Year (2020-2023)')
plt.xlabel('Year')
plt.ylabel('Count')

# Plot distribution by month (all selected years pooled together)
plt.subplot(1, 2, 2)
filtered_transplants['Month'].value_counts().sort_index().plot(kind='bar')
plt.title('Distribution of ORGAN_RECOVERY_DATE by Month (2020-2023)')
plt.xlabel('Month')
plt.ylabel('Count')

plt.tight_layout()
plt.show()

In [None]:
# Largest number of rows that share a single ORGAN_RECOVERY_DATE value.
highest_rows_per_recovery_date = all_transplants.groupby(Column.ORGAN_RECOVERY_DATE.name).size().max()
highest_rows_per_recovery_date


In [97]:
# Reference date used by the exploratory cells below.
date = pd.Timestamp('2020-04-03')
# Organs for that date, as returned by the project helper (columns are defined in src.utils).
available_organs = get_mininal_columns_available_organs(by_date=date)

In [None]:
# Display the organs frame fetched for `date`.
available_organs

In [69]:
# Waitlist for the same reference `date`, as returned by the project helper.
waitlist_members = get_mininal_columns_waitlist(by_date=date)

In [None]:
# Preview the first ten waitlist rows.
waitlist_members.head(10)

In [98]:
# get_next_day returns a date plus an (organs, waitlist) pair.
# NOTE(review): `date` is rebound to the returned value, so re-running this cell
# starts from whatever date the previous run returned.
date, [daily_organs, daily_waitlist_members] = get_next_day(date, allocated_ids =[])

In [None]:
# Display the waitlist frame returned by get_next_day above.
daily_waitlist_members

In [None]:
# Subset of waitlist columns shown in the sample below.
WAITLIST_FEATURES = [
        Column.RECIPIENT_ID.name, 
        Column.INIT_MELD_PELD_LAB_SCORE.name, 
        Column.DAYS_ON_WAITLIST.name,
        Column.RECIPIENT_BLOOD_TYPE.name,
        Column.RECIPIENT_BLOOD_TYPE_AS_CODE.name,
]
# First five waitlist rows restricted to those columns.
sample_waitlist_members = waitlist_members[WAITLIST_FEATURES].head(5)
sample_waitlist_members

In [43]:
def build_dataset(current_date: pd.Timestamp, end_date: pd.Timestamp):
    """Write per-day organ and waitlist snapshots to CSV files under ``data/``.

    Starting at ``current_date``, each iteration asks ``get_next_day`` for a date and
    an (organs, waitlist) pair, with no allocated ids and a waitlist cap of 1500.
    Both frames are saved under the date that ``get_next_day`` returned, and the loop
    then continues from the day after that returned date.

    Args:
        current_date: First date passed to ``get_next_day``.
        end_date: The loop stops once the running date is past this value.
    """
    one_day = pd.Timedelta(days=1)
    cursor = current_date
    while cursor <= end_date:
        cursor, snapshot = get_next_day(cursor, allocated_ids=[], max_waitlist=1500)
        organs, waitlist_members = snapshot

        # Both files share the date returned by get_next_day as their prefix.
        day_stamp = cursor.strftime('%Y-%m-%d')
        organs.to_csv(f"data/{day_stamp}_organs.csv", index=False)
        waitlist_members.to_csv(f"data/{day_stamp}_waitlist.csv", index=False)

        cursor = cursor + one_day

In [None]:
# Generate the daily CSV snapshots for the whole range (writes two files per day under data/).
# NOTE(review): the two-digit-year strings are presumably parsed month-first — confirm.
build_dataset(pd.Timestamp('01-01-20'), pd.Timestamp('05-31-21')) #pd.Timestamp('12-31-20'))

In [40]:
def read_dataframes_for_dates(start_date: pd.Timestamp, end_date: pd.Timestamp) -> List[List[pd.DataFrame]]:
    """Load the stored frames for every day from ``start_date`` to ``end_date`` inclusive.

    Args:
        start_date: First day passed to ``read_from_file``.
        end_date: Last day passed to ``read_from_file``.

    Returns:
        A list with one ``(organs, waitlist)`` tuple per day, in chronological order.
    """
    one_day = pd.Timedelta(days=1)
    per_day_frames = []

    day = start_date
    while day <= end_date:
        organs_df, waitlist_df = read_from_file(day)
        per_day_frames.append((organs_df, waitlist_df))
        day = day + one_day

    return per_day_frames

In [None]:
dataframes = read_dataframes_for_dates(pd.Timestamp('01-01-20'), pd.Timestamp('01-3-20'))
dataframes

In [17]:
import logging
# Silence all logging calls up to and including CRITICAL for the rest of the session.
# Swap in the commented line below (and remove the disable call) to see the environment's INFO logs.
logging.disable(logging.CRITICAL)
# logging.basicConfig(level=logging.INFO)

In [18]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random

class DQN(nn.Module):
    """Three-layer fully connected Q-network.

    The network consumes an already flattened state vector of length
    ``state_shape[0] * state_shape[1]`` and produces one value per action.

    Args:
        state_shape: Shape of the un-flattened state; only the first two entries are used.
        action_size: Number of outputs (one Q-value per action).
        hidden_size: Width of the two hidden layers.
    """

    def __init__(self, state_shape, action_size, hidden_size=64):
        super().__init__()
        # Callers flatten the 2-D state before the forward pass, so the first
        # layer takes rows * columns inputs.
        n_rows, n_cols = state_shape[0], state_shape[1]
        flat_features = n_rows * n_cols
        self.fc1 = nn.Linear(flat_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)

    def forward(self, x):
        """Return the Q-values for the flattened state(s) ``x``."""
        activations = x
        for hidden_layer in (self.fc1, self.fc2):
            activations = torch.relu(hidden_layer(activations))
        # No activation on the output layer: Q-values are unbounded.
        return self.fc3(activations)


class ReplayMemory:
    """Fixed-size buffer of transitions; the oldest entries are dropped once full."""

    def __init__(self, capacity):
        self.capacity = capacity
        # deque(maxlen=...) discards from the left when a new item overflows it.
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done, info):
        """Store one transition as a 6-tuple."""
        transition = (state, action, reward, next_state, done, info)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 6-tuple ``(states, actions, rewards, next_states, dones, infos)``.
            The first five are numpy arrays (float32, int64, float32, float32, bool);
            ``infos`` is the tuple of the sampled info objects, left as-is.
        """
        drawn = random.sample(self.memory, batch_size)
        # Transpose the list of transitions into one sequence per field.
        states, actions, rewards, next_states, dones, infos = zip(*drawn)
        field_dtypes = (np.float32, np.int64, np.float32, np.float32, np.bool_)
        fields = (states, actions, rewards, next_states, dones)
        arrays = tuple(
            np.array(values, dtype=dtype)
            for values, dtype in zip(fields, field_dtypes)
        )
        return arrays + (infos,)

    def __len__(self):
        return len(self.memory)

In [148]:
class OrganMatchingEnv:
    """Episodic environment in which each step assigns the current donor organ to a waitlist row.

    Daily donor and waitlist frames come from ``get_next_day``. The state is a
    ``(waitlist rows, 5)`` float32 array and an action is a row index into the
    current waitlist frame.
    """

    def __init__(self, max_waitlist_members: int = 100, max_days: int = 3, start_date_month_bound: int = 1, end_date_month_bound: int = 3, is_test=False):
        """
        Args:
            max_waitlist_members: The number of max people to consider on the waitlist (it is sorted by MELD and days on waitlist)                             
            max_days: The horizon over which we measure performance or run the simulation.
            start_date_month_bound: Lowest month (inclusive) a random training start date may fall in.
            end_date_month_bound: Highest month (inclusive) a random training start date may fall in.
            is_test: When true, every episode starts on the fixed date 2020-11-01 instead of a random one.

        We assume:
        - Each day we have multiple donors and multiple recipients.
        - Each step in the environment involves selecting a recipient for the **current donor**.
          After all donors for that day are processed, we move to the next day.
        """
        self.start_date_month_bound = start_date_month_bound
        self.end_date_month_bound = end_date_month_bound
        self.max_waitlist_members = max_waitlist_members
        self.max_days = max_days
        self.is_test = is_test

        # These will be set at reset
        self.reset()
    
    def _get_start_date(self):
        """Pick a random 2020 date with the month inside the configured bounds and the day in 1-28."""
        return pd.Timestamp(f'2020-{random.randint(self.start_date_month_bound, self.end_date_month_bound):02d}-{random.randint(1, 28):02d}')

    def reset(self):
        """Start a new episode and return its first state."""
        # The counter starts at -1 and is incremented before the `done` check in step(),
        # so an episode lasts max_days + 1 steps.
        self.total_days_allocated = -1
        self.initial_date = self._get_start_date() if not self.is_test else pd.Timestamp('2020-11-1')
        self.current_day_idx = self.initial_date
        self.current_donor_idx = 0
        self.donor_ids_allocated = set()
        self.refetch_data(self.initial_date)
        return self._get_state()
    
    def refetch_data(self, date: pd.Timestamp):
        """Reload donor and waitlist frames for ``date`` and point at the first unallocated donor.

        If every donor in the returned frame is already allocated, the method
        recurses onto the following day until one is found.
        """
        logging.info(f'Refetching data for date: {date}. Allocated donor IDs: {self.donor_ids_allocated}')
        self.current_day_idx, data = get_next_day(date, allocated_ids=self.donor_ids_allocated, max_waitlist=self.max_waitlist_members)
        self.available_donor_df = data[0]
        self.waitlist_member_df = data[1]
        # Check if there are any donors left to allocate
        available_donors = self.available_donor_df[Column.DONOR_ID.name].tolist()
        unallocated_donors = [i for i, donor_id in enumerate(available_donors) if donor_id not in self.donor_ids_allocated]
        if len(unallocated_donors) == 0:
            logging.info(f'WARNING: No unallocated donors available for date: {self.current_day_idx}')
            self.refetch_data(self.current_day_idx + pd.Timedelta(days=1))
        else:
            self.current_donor_idx = unallocated_donors[0]
        
        if date < self.current_day_idx:
            logging.info(f'Moving onto next day {self.current_day_idx} with donor index {self.current_donor_idx} for donor ID: {self._get_donor()}')
        else: 
            logging.info(f'Staying on {self.current_day_idx} to keep allocating for donor idx {self.current_donor_idx} because we have the remaining donors: {self._get_remaining_donors()}')
    
    def _get_valid_actions_for_df(self, donor_id: int, organs: pd.DataFrame, waitlist: pd.DataFrame):
        """Return the waitlist row indices whose blood type matches the donor at position ``donor_id``.

        ``donor_id`` is a positional index into ``organs``, not a DONOR_ID value.
        A row is valid when ``get_match_value`` returns at least 1.0.

        Raises:
            IndexError: if ``donor_id`` is outside ``organs`` (raised by ``iloc`` after logging).
        """
        # Valid positions are 0 .. len(organs) - 1, hence `>=`.
        if donor_id >= len(organs):
            logging.info(f'ERROR: {donor_id} >= length of organs {len(organs)} on day {self.current_day_idx}')
        donor = organs.iloc[donor_id]
        donor_blood_type = donor[Column.DONOR_BLOOD_TYPE.name]
        # Iterate over the column once instead of materialising every row with iloc.
        recipient_blood_types = waitlist[Column.RECIPIENT_BLOOD_TYPE.name]
        valid_allocations = [
            i for i, recipient_blood_type in enumerate(recipient_blood_types)
            if get_match_value(donor_blood_type=donor_blood_type, recipient_blood_type=recipient_blood_type) >= 1.0
        ]
        return valid_allocations
    
    def _get_remaining_donors(self):
        """Return the DONOR_ID values of the currently loaded donor frame as a list."""
        return self.available_donor_df[Column.DONOR_ID.name].tolist()

    def get_valid_actions(self):
        """Return the valid waitlist row indices for the current donor."""
        return self._get_valid_actions_for_df(self.current_donor_idx, self.available_donor_df, self.waitlist_member_df)

    def _get_state(self):
        """Refetch the current day's data and encode the waitlist as a float32 array.

        Each row holds MELD/PELD lab score, days on waitlist, blood type code,
        diagnosis code and functional code, in that order.
        """
        self.refetch_data(self.current_day_idx)
        encoded_array = np.array([
            list(row) for row in zip(self.waitlist_member_df[Column.INIT_MELD_PELD_LAB_SCORE.name], 
                                    self.waitlist_member_df[Column.DAYS_ON_WAITLIST.name], 
                                    self.waitlist_member_df[Column.RECIPIENT_BLOOD_TYPE_AS_CODE.name],
                                    self.waitlist_member_df[Column.DIAGNOSIS_CODE.name],
                                    self.waitlist_member_df[Column.FUNCTIONAL_CODE.name])
        ])
        state = np.array(encoded_array, dtype=np.float32)
        return state
    
    def _get_donor(self):
        """Return the DONOR_ID of the donor at ``current_donor_idx`` as an int."""
        return int(self.available_donor_df.iloc[self.current_donor_idx][Column.DONOR_ID.name])
    
    def _get_chosen_recipient(self, action: int):
        """Return the waitlist row for ``action``, or an empty Series when it is past the last row."""
        # Valid positions are 0 .. len - 1. An empty Series is the "no such recipient"
        # marker that step() checks via `.empty`.
        if action >= len(self.waitlist_member_df):
            return pd.Series(dtype=object)
        return self.waitlist_member_df.iloc[action]

    def step(self, action: int):
        """
        Args:
            action: index corresponding to the chosen recipient in waitlist_member_df,
                    i.e. a row position in [0, len(waitlist_member_df) - 1].
        
        Returns:
            next_state: state after this allocation
            reward: float, e.g. graft lifespan or related metric
            done: bool, whether the episode ended
            info: dict, extra info
        """
        chosen_recipient = self._get_chosen_recipient(action)
        chosen_donor = self._get_donor()
        self.total_days_allocated += 1
        self.donor_ids_allocated.add(chosen_donor)
        reward = 0.0
        if chosen_recipient.empty:
            reward = -100
        else:
            # The donor recorded on the chosen recipient's row is also removed from the pool.
            recorded_donor_id = chosen_recipient[Column.DONOR_ID.name]
            if pd.notna(recorded_donor_id):
                self.donor_ids_allocated.add(int(recorded_donor_id))
            # It doesn't matter who they got the organ from. Just include the rewards of this person getting an organ.
            # pd.notna also rejects NaN, which `is not None` lets through and which
            # would turn the whole reward into NaN.
            graft_lifespan = chosen_recipient[Column.GRAFT_LIFESPAN.name]
            if pd.notna(graft_lifespan):
                reward += float(graft_lifespan)
            patient_survival_time = chosen_recipient[Column.PATIENT_SURVIVAL_TIME.name]
            if pd.notna(patient_survival_time):
                reward += float(patient_survival_time)
        # .get returns None for the empty "no recipient" Series instead of raising KeyError.
        logging.info(f'''Chose recipient ID: {chosen_recipient.get(Column.RECIPIENT_ID.name)} for donor {chosen_donor} with : {reward}''')

        # NOTE(review): non-positive rewards, including the -100 penalty above, become 0 here.
        reward = np.log10(reward) if reward > 0 else 0
        done = (self.total_days_allocated >= self.max_days)
        valid_allocations = []
        invalid_allocations = []
        if not done:
            next_state = self._get_state()
            next_actions = np.arange(self.max_waitlist_members)
            valid_allocations = self.get_valid_actions()
            # Set membership keeps building the mask linear in the number of actions.
            valid_lookup = set(valid_allocations)
            invalid_allocations = [i for i in next_actions if i not in valid_lookup]
        else:
            next_state = np.zeros_like(self._get_state())

        return next_state, reward, done, {'next_valid_allocations': valid_allocations, 'next_invalid_allocations': invalid_allocations}

In [None]:
# Smoke test: run one episode choosing a random valid recipient at every step.
done = False
env = OrganMatchingEnv()
# The constructor already calls reset(); this second call starts a fresh episode.
env.reset()
total_reward = 0
while not done:
    valid_actions = env.get_valid_actions()
    # NOTE(review): random.choice raises IndexError if no recipient is valid for the current donor.
    next_state, reward, done, _ = env.step(random.choice(valid_actions))
    total_reward += reward
    # Prints the running total, not the reward of the individual step.
    print(f'Reward: {total_reward} and reached done? {done}')

In [103]:
def get_filename(lr: float, batch_size: int, waitlist_members: int, max_days: int, ep: int, rewards: bool = False):
    """Build the artifact file name for one hyper-parameter configuration.

    Args:
        lr: Learning rate.
        batch_size: Training batch size.
        waitlist_members: Waitlist size used by the environment.
        max_days: Allocation horizon used by the environment.
        ep: Episode index; only embedded in policy checkpoint names.
        rewards: When true, name the rewards CSV instead of the policy checkpoint.

    Returns:
        ``rewards_..._ep.csv`` when ``rewards`` is true, otherwise ``policy_net_..._ep{ep}.pkl``.
    """
    if rewards:
        prefix, episode_tag, extension = 'rewards', '', 'csv'
    else:
        prefix, episode_tag, extension = 'policy_net', ep, 'pkl'
    return f"{prefix}_lr_{lr}_bs_{batch_size}_ws_{waitlist_members}_ma_{max_days}_ep{episode_tag}.{extension}"

def plot_rewards(all_rewards, batch_size, max_waitlist_members, max_days):
    """Plot raw and moving-average episode rewards, save the figure as PNG, then show it.

    Args:
        all_rewards: Sequence of per-episode rewards.
        batch_size: Used in the title and in the output file name.
        max_waitlist_members: Used in the title and in the output file name.
        max_days: Used in the title and in the output file name.
    """
    window_size = 5  # Adjust the window size for smoothing
    averaging_kernel = np.ones(window_size) / window_size
    # 'valid' keeps only the positions where the window fully overlaps the data.
    smoothed_rewards = np.convolve(all_rewards, averaging_kernel, mode='valid')
    smoothed_positions = range(len(smoothed_rewards))

    # Plot original and smoothed rewards
    plt.plot(all_rewards, label='Original', alpha=0.4)
    plt.plot(smoothed_positions, smoothed_rewards, label='Smoothed', color='red')
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    chart_title = f"Reward per Episode (batch_size={batch_size}, max_waitlist_members={max_waitlist_members}, max allocations={max_days})"
    plt.title(chart_title)
    plt.legend()

    # Save before show(): the figure is written while it is still the current one.
    output_path = f"bs_{batch_size}_ws_{max_waitlist_members}_ma_{max_days}.png"
    plt.savefig(output_path)
    plt.show()

In [116]:
def test_dqn(env: OrganMatchingEnv,
             policy_net: DQN,
             episodes: int = 10) -> "tuple[float, pd.DataFrame]":
    """
    Run inference using the trained policy_net on the given test environment.

    Actions are chosen greedily from the Q-values after invalid recipients
    (those not returned by ``env.get_valid_actions()``) have been masked out.

    Args:
        env: Environment to evaluate on; it is reset at the start of every episode.
        policy_net: Trained network. It is switched to eval mode and is expected to
            already live on the device selected below.
        episodes: Number of evaluation episodes.

    Returns:
        A tuple ``(avg_reward, chosen_patients)``: the mean episode reward (0.0 when
        ``episodes`` is 0) and a DataFrame with one waitlist row per action taken.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    
    # Set the policy_net to evaluation mode
    policy_net.eval()

    total_reward = 0.0
    # Collect the chosen rows and concatenate once at the end; growing a DataFrame
    # with pd.concat on every step is quadratic in the number of steps.
    chosen_rows = []

    for ep in range(episodes):
        state = env.reset()
        done = False
        episode_reward = 0.0

        while not done:
            # The available actions on this step
            action_size = env.max_waitlist_members
            valid_lookup = set(env.get_valid_actions())
            invalid_allocations = [i for i in range(action_size) if i not in valid_lookup]

            # Choose action greedily based on Q-values
            with torch.no_grad():
                flat_state = torch.as_tensor(state, dtype=torch.float32, device=device).reshape(-1)
                q_values = policy_net(flat_state)
                # Mask out invalid actions
                q_values[invalid_allocations] = -1e9
                action = q_values.argmax().item()
            chosen_rows.append(env.waitlist_member_df.iloc[[action]])

            next_state, reward, done, info = env.step(action)
            episode_reward += reward
            state = next_state

        total_reward += episode_reward
        print(f"Test Episode {ep+1}/{episodes}, Current date: {env.current_day_idx}, Reward: {episode_reward:.2f}")

    chosen_patients = pd.concat(chosen_rows, ignore_index=True) if chosen_rows else pd.DataFrame()
    avg_reward = total_reward / episodes if episodes else 0.0
    print(f"Average Test Reward over {episodes} episodes: {avg_reward:.2f}")
    return avg_reward, chosen_patients

In [106]:
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle

def train_dqn(env: OrganMatchingEnv,
              episodes: int = 100,
              gamma: float=0.99,
              epsilon_start: float = 1.0,
              epsilon_end: float = 0.01,
              epsilon_decay: int = 500,
              lr: float = 1e-3,
              batch_size: int = 128,
              memory_capacity: int = 1000,
              target_update: int = 200):
    """Train a DQN on ``env`` with epsilon-greedy exploration and a replay buffer.

    Args:
        env: Environment to train on; reset at the start of every episode.
        episodes: Number of training episodes.
        gamma: Discount factor applied to the next-state value.
        epsilon_start: Initial exploration probability.
        epsilon_end: Asymptotic exploration probability.
        epsilon_decay: Decay constant, in environment steps, of the exponential schedule.
        lr: Adam learning rate.
        batch_size: Replay batch size; learning starts once the buffer holds this many transitions.
        memory_capacity: Maximum number of transitions kept in the replay buffer.
        target_update: The target network is synchronised whenever ``ep % target_update == 0``
            (counted in episodes, not steps).

    Returns:
        ``(policy_net, all_rewards)``: the trained network and the list of per-episode rewards.

    Side effects:
        After episode 50, whenever an episode matches or beats the best reward so far,
        the policy weights and the reward history are written to the working directory
        and the reward curve is plotted and saved.
    """
    # Initialize replay memory
    memory = ReplayMemory(memory_capacity)

    state = env.reset()
    state_size = state.shape
    action_size = env.max_waitlist_members
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    policy_net = DQN(state_size, action_size).to(device)
    target_net = DQN(state_size, action_size).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()
    
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    epsilon = epsilon_start
    steps_done = 0
    all_rewards = []

    for ep in tqdm(range(episodes), desc="Training Episodes"):
        state = env.reset()
        done = False
        episode_reward = 0.0
        
        while not done:
            steps_done += 1

            # The actual available actions on this step is equal to the current day's recipients
            valid_allocations = env.get_valid_actions()
            
            # Epsilon-greedy action selection
            if random.random() < epsilon:
                # Random action from the valid range of recipients
                action = random.choice(valid_allocations)
            else:
                # The mask is only needed on the greedy branch, so build it here.
                valid_lookup = set(valid_allocations)
                invalid_allocations = [i for i in range(action_size) if i not in valid_lookup]
                with torch.no_grad():
                    flat_state = torch.as_tensor(state, dtype=torch.float32, device=device).reshape(-1)
                    q_values = policy_net(flat_state)
                    # Mask out invalid actions by setting Q-values of invalid actions to a large negative number
                    q_values[invalid_allocations] = -1e9
                    action = q_values.argmax().item()

            next_state, reward, done, info = env.step(action)
            episode_reward += reward
            
            memory.push(state, action, reward, next_state, done, info)
            state = next_state

            # Train if memory is sufficient
            if len(memory) >= batch_size:
                states_b, actions_b, rewards_b, next_states_b, dones_b, infos = memory.sample(batch_size)
                
                states_t = torch.tensor(states_b, dtype=torch.float32, device=device)
                # size (batch size, 1) so it can index the action dimension in gather
                actions_t = torch.tensor(actions_b, device=device).unsqueeze(1)
                rewards_t = torch.tensor(rewards_b, dtype=torch.float32, device=device)
                next_states_t = torch.tensor(next_states_b, dtype=torch.float32, device=device)
                dones_t = torch.tensor(dones_b, dtype=torch.bool, device=device)
                
                # Compute Q(s,a)
                q_values = policy_net(states_t.view(batch_size, -1)).gather(1, actions_t)
                
                # The bootstrap target is a constant for the loss, so it is computed
                # without autograd: no graph is built through target_net and the
                # in-place masking below is not tracked.
                with torch.no_grad():
                    # next_q_values_all is (batch_size, action_size)
                    next_q_values_all = target_net(next_states_t.view(batch_size, -1))

                    for batch_idx, info_entry in enumerate(infos):
                        next_invalid = info_entry['next_invalid_allocations']
                        if len(next_invalid) > 0:
                            next_q_values_all[batch_idx, next_invalid] = -1e9

                    # max Q(s',a') over the still-valid actions; terminal transitions contribute 0
                    next_q_values = next_q_values_all.max(1)[0]
                    next_q_values[dones_t] = 0.0
                    target = rewards_t + gamma * next_q_values
                
                # squeeze(1) keeps the batch dimension even when batch_size == 1
                loss = loss_fn(q_values.squeeze(1), target)
            
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Decay epsilon
            epsilon = epsilon_end + (epsilon_start - epsilon_end) * np.exp(-steps_done / epsilon_decay)
        
        # Update target network
        if ep % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())

        all_rewards.append(episode_reward)
        # all_rewards already contains this episode, so the comparison is true exactly
        # when this episode ties or sets the best reward so far.
        if ep > 50 and episode_reward >= max(all_rewards):
            # Save the policy net to a pickle file
            with open(get_filename(lr, batch_size, env.max_waitlist_members, env.max_days, ep, rewards=False), 'wb') as f:
                torch.save(policy_net.state_dict(), f)
            # Save the rewards to a CSV file
            rewards_df = pd.DataFrame(all_rewards, columns=['Reward'])
            rewards_df.to_csv(get_filename(lr, batch_size, env.max_waitlist_members, env.max_days, ep, rewards=True), index=False)
            plot_rewards(all_rewards, batch_size, env.max_waitlist_members, env.max_days)

        print(f"Episode {ep+1}/{episodes}, Reward: {episode_reward:.2f}, Date: {env.current_day_idx} Epsilon: {epsilon:.2f}")
    return policy_net, all_rewards

In [None]:
# Training run configuration.
max_waitlist_members = 250
max_days = 100
episodes = 125
env = OrganMatchingEnv(max_waitlist_members=max_waitlist_members, max_days=max_days)
# NOTE(review): train_dqn's default target_update is 200 episodes, so with 125 episodes
# the target network is only synchronised at episode 0.
trained_policy, rewards = train_dqn(env,batch_size=128, episodes=episodes)

In [None]:
# Evaluate the in-memory policy on the fixed test start date.
max_waitlist_members = 250
# test_policy_net = DQN((max_waitlist_members, 5), max_waitlist_members)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# NOTE(review): `filename` is only assigned in this commented-out line, yet a later cell reads it.
# filename = "policy_net_lr_0.001_bs_128_ws_250_ma_100_ep160.pkl" 

# with open(filename, 'rb') as f:
#     state_dict = torch.load(f, map_location=device)
#     test_policy_net.load_state_dict(state_dict)

# Move the model to the device and set to eval mode
# Requires `trained_policy` from the training cell (the checkpoint loading above is disabled).
test_policy_net = trained_policy
test_policy_net.to(device)
test_policy_net.eval()
# is_test=True pins the episode start date instead of sampling one.
test_env = OrganMatchingEnv(max_waitlist_members=max_waitlist_members, max_days=500, is_test=True)
print(f"Starting at {test_env.initial_date}")
average_test_reward, chosen_patients = test_dqn(test_env, test_policy_net, episodes=1)
print(f"Reward: {average_test_reward}")
chosen_patients

1566

In [162]:
def compare_to_baseline(current_date: pd.Timestamp, end_date: pd.Timestamp, max_allocations: int):
    """Score the allocations recorded in the data over a date range.

    For each day, the waitlist rows whose TRANSPLANT_DATE equals that day are taken
    (capped so the running total never exceeds ``max_allocations``). Each row with a
    positive graft lifespan plus patient survival time contributes the log10 of that sum.

    Args:
        current_date: First day to inspect.
        end_date: Last day to inspect (inclusive).
        max_allocations: Stop once this many allocations have been collected.

    Returns:
        ``(total_reward, patients)``: the summed log10 rewards and a DataFrame with
        all selected waitlist rows in chronological order.
    """
    total_rewards = []
    # Daily frames are concatenated once at the end instead of growing a DataFrame per day.
    daily_frames = []
    allocated_ids = []

    date = current_date
    while date <= end_date and len(allocated_ids) < max_allocations:
        # Get the next day's data
        waitlist_members = get_mininal_columns_waitlist(by_date=date)
        transplanted_today = waitlist_members[waitlist_members[Column.TRANSPLANT_DATE.name] == date]
        # head() already returns the whole frame when it has fewer rows than requested;
        # the loop condition guarantees the remaining budget is positive.
        transplanted_today = transplanted_today.head(max_allocations - len(allocated_ids))
        allocated_ids.extend(transplanted_today[Column.DONOR_ID.name].to_list())

        # Graft plus patient survival per row; only positive totals earn a log10 reward
        # (NaN totals fail the comparison and are skipped).
        survival = (
            transplanted_today[Column.GRAFT_LIFESPAN.name]
            + transplanted_today[Column.PATIENT_SURVIVAL_TIME.name]
        ).astype(float)
        total_rewards.extend(np.log10(survival[survival > 0]))

        daily_frames.append(transplanted_today)
        date += pd.Timedelta(days=1)

    all_filtered_patients = pd.concat(daily_frames, ignore_index=True) if daily_frames else pd.DataFrame()
    return sum(total_rewards), all_filtered_patients

In [None]:
import matplotlib.pyplot as plt

def compare_along(column: Column, label: str, filename: str):
    """Draw a side-by-side histogram of one column for chosen vs. baseline patients.

    Reads the notebook-level DataFrames ``chosen_patients`` and ``baseline_patients``,
    which must both contain ``column.name``. The figure is saved as
    ``{label}_{filename without its extension}.png``.

    Args:
        column: Column enum member whose ``name`` selects the data to compare.
        label: Human-readable name used in the title, x-axis and output file name.
        filename: Name whose stem is appended to the output file name.
    """
    plt.figure(figsize=(12, 6))

    # Create a side-by-side chart for chosen and baseline patients
    chosen_scores = chosen_patients[column.name]
    baseline_scores = baseline_patients[column.name]

    # Unit-width bins. The last edge is int(max) + 1 so that the largest value gets
    # its own bin [max, max + 1]: stopping at int(max) merges the top two integer
    # values, drops fractional maxima, and leaves a single edge when all values are equal.
    lowest = int(min(chosen_scores.min(), baseline_scores.min()))
    highest = int(max(chosen_scores.max(), baseline_scores.max()))
    bins = range(lowest, highest + 2)

    # Plot side-by-side histogram; `color` already assigns one face colour per dataset.
    plt.hist([chosen_scores, baseline_scores], bins=bins, color=['skyblue', 'salmon'], edgecolor='black', stacked=False, label=['Chosen Patients', 'Baseline Patients'])

    plt.title(f'Side-by-Side Comparison of {label} for Chosen and Baseline Patients')
    plt.xlabel(label)
    plt.ylabel('Frequency')
    plt.grid(axis='y', alpha=0.75)
    plt.legend()  # Add legend
    # Save the plot to a file with the same name as 'filename' but with a .png extension
    plt.savefig(f"{label}_{filename.rsplit('.', 1)[0]}.png")

# Baseline over one month, capped at 501 allocations.
baseline_rewards, baseline_patients = compare_to_baseline(pd.Timestamp('2020-07-01'), pd.Timestamp('2020-08-01'), max_allocations = 501)
print(f"Baseline rewards: {baseline_rewards}")
# Not the last expression of the cell, so this line displays nothing.
baseline_patients
# NOTE(review): `filename` is not defined in this cell. Its only explicit assignment in the
# evaluation cell is commented out, so these calls rely on a value left over from another
# cell (e.g. an os.listdir loop variable) — define it explicitly before running.
compare_along(Column.INIT_MELD_PELD_LAB_SCORE, 'MELD', filename)
compare_along(Column.FUNCTIONAL_CODE, 'Functional Status', filename)
compare_along(Column.RECIPIENT_AGE, 'Recipient Age', filename)

In [None]:
import os
import pandas as pd

# Initialize a dictionary to store rewards data
# Maps a legend label to the 'Reward' column of one rewards CSV.
rewards_data = {}

# Iterate over files in the current directory
for filename in os.listdir('.'):
    if filename.startswith('rewards_') and filename.endswith('.csv'):
        # Extract parameters from the filename
        # Names follow get_filename(..., rewards=True):
        # rewards_lr_<lr>_bs_<bs>_ws_<ws>_ma_<ma>_ep.csv, so each value sits right after its tag.
        parts = filename.split('_')
        lr = parts[2]  # Correct index for 'lr'
        ws = parts[6]  # Correct index for 'ws'
        ma = parts[8]  # Correct index for 'ma'
        
        # Create a label for the legend
        label = f"lr: {lr}, waitlist size: {ws}, max allocations: {ma}"
        
        # Read the CSV file
        df = pd.read_csv(filename)
        
        # Store the rewards data with the label
        rewards_data[label] = df['Reward']

# Plot the rewards data
plt.figure(figsize=(12, 6))
# NOTE(review): the loop variable `rewards` overwrites the `rewards` list returned by train_dqn
# in the training cell.
for label, rewards in rewards_data.items():
    plt.plot(rewards, label=label)

plt.title('Rewards Comparison Across Different Parameters')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.legend()
plt.grid(True)
plt.show()


In [160]:
# Persist the weights of the in-memory policy under the standard checkpoint name.
# NOTE(review): lr, batch size and the episode tag (200) are hard-coded here rather than taken
# from the training cell, which ran 125 episodes — confirm the intended file name.
with open(get_filename(0.001, 128, env.max_waitlist_members, env.max_days, 200, rewards=False), 'wb') as f:
    torch.save(trained_policy.state_dict(), f)