In [1]:
# Interactive matplotlib backend.
# NOTE(review): the `notebook` backend only works in the classic Notebook UI;
# under JupyterLab / VS Code use `%matplotlib widget` (ipympl) instead.
%matplotlib notebook
import pandas as pd
# NOTE(review): pandasql (`pds`, `sqldf`), plotly (`go`), pyplot (`plt`) and
# seaborn (`sns`) are not used by the data-prep cell below — presumably they are
# used in later cells; verify and drop any that are truly unused.
import pandasql as pds
from pandasql import sqldf
import plotly.graph_objects as go
import matplotlib.pyplot as plt
# Display config: frames longer than max_rows are truncated, and min_rows is the
# number of rows shown in that truncated view — so long frames show 50 rows.
pd.set_option('display.max_rows', 50)
pd.set_option('display.min_rows', 50)
import seaborn as sns

In [None]:
# --- Build per-line regimen transitions for the Sankey diagram -------------
# Produces:
#   data_raw    - the CSV as loaded
#   data_sorted - one row per (PATIENTID, LINENUMBER), sorted, with
#                 NEXT_REGIMEN / PREVIOUS_REGIMEN / NEXT_LINENUMBER added and
#                 non-consecutive transitions removed
#   data        - the Sankey input: selected columns, lines in LINES_TO_KEEP,
#                 regimen labels suffixed with their line number
# and copies `data` to the clipboard.

# NOTE(review): absolute local path — consider a configurable DATA_DIR so the
# notebook runs on other machines.
SANKEY_CSV_PATH = r"C:\Users\MichaelDiFelice\Documents\Training\snowflake\Sankey Data\Sankey Data.csv"

# Lines of therapy kept in the Sankey input.
LINES_TO_KEEP = [1, 2, 3, 4]

# Columns needed for the current task.
SANKEY_COLUMNS = [
    "PATIENTID", "LINENUMBER", "REGIMEN", "PREVIOUS_REGIMEN", "NEXT_REGIMEN",
    "LEN_FLAG", "CD38_FLAG", "CD38_EXPOSED_FLAG", "TRANSPLANT_FLAG", "START_YEAR",
    "LEN_REFRACTORY_FLAG", "NEXT_LINENUMBER",
]

data_raw = pd.read_csv(SANKEY_CSV_PATH)

# Keep one row per (patient, line); drop_duplicates keeps the first such row
# in file order.
data_sorted = data_raw.drop_duplicates(subset=["PATIENTID", "LINENUMBER"])

# Sort by patient and line so the per-patient shifts below see adjacent lines.
data_sorted = data_sorted.sort_values(by=["PATIENTID", "LINENUMBER"])

# Neighbouring rows within each patient (NaN at the patient's first/last row).
patient_groups = data_sorted.groupby("PATIENTID")
next_regimen = patient_groups["REGIMEN"].shift(-1)
previous_regimen = patient_groups["REGIMEN"].shift(1)
next_line_actual = patient_groups["LINENUMBER"].shift(-1)

# Assign the filled Series back rather than `col.fillna(..., inplace=True)`:
# chained in-place updates are deprecated and do nothing under Copy-on-Write.
data_sorted["NEXT_REGIMEN"] = next_regimen.fillna("No Advancement")
data_sorted["PREVIOUS_REGIMEN"] = previous_regimen.fillna("")

# If NEXT_REGIMEN is 'No Advancement', NEXT_LINENUMBER equals LINENUMBER;
# otherwise it is the line number actually recorded on the patient's next row.
# (Previously this was LINENUMBER + 1, which made the filter below a no-op.)
has_next = data_sorted["NEXT_REGIMEN"] != "No Advancement"
data_sorted["NEXT_LINENUMBER"] = (
    next_line_actual
    .where(has_next, data_sorted["LINENUMBER"])
    # Every row is filled now, so restore LINENUMBER's dtype (shift made it float).
    .astype(data_sorted["LINENUMBER"].dtype)
)

# Filter out rows whose next line is not consecutive (e.g. line 1 followed by
# line 3); rows with 'No Advancement' are kept.
is_consecutive = data_sorted["LINENUMBER"] + 1 == data_sorted["NEXT_LINENUMBER"]
data_sorted = data_sorted[is_consecutive | ~has_next]

# Select the Sankey columns/lines as an explicit copy so the label edits below
# do not write into a slice of data_sorted (SettingWithCopyWarning).
data = data_sorted.loc[data_sorted["LINENUMBER"].isin(LINES_TO_KEEP), SANKEY_COLUMNS].copy()

# Suffix each regimen label with its line number so the same regimen at
# different lines becomes a distinct Sankey node.
line_number = data["LINENUMBER"]
data["REGIMEN"] = data["REGIMEN"].astype(str) + line_number.astype(str)
data["NEXT_REGIMEN"] = data["NEXT_REGIMEN"].astype(str) + (line_number + 1).astype(str)

# Only label rows that actually have a previous regimen; rows without one stay
# "". (The old `.str.replace('0', '')` also stripped zeros from regimen names.)
has_previous = data["PREVIOUS_REGIMEN"] != ""
data["PREVIOUS_REGIMEN"] = (
    data["PREVIOUS_REGIMEN"].astype(str) + (line_number - 1).astype(str)
).where(has_previous, "")

# Copy the result (index included) to the system clipboard; needs a clipboard
# backend on the machine running the kernel.
data.to_clipboard()