In [60]:
import os
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import seaborn as sns
from joblib import load
from tqdm import tqdm
import plotly.express as px
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.nn import LSTM
from sklearn.datasets import make_blobs
from sklearn.manifold import TSNE
from transformers import AutoTokenizer, AutoModel
from multimodal import MultimodalClassifierDataset, collation

In [21]:
class LOSNetWeighted(nn.Module):
    '''
    time_series_model: expects an input of packed padded sequences
    text_model: expects an input of dict with keys {'input_ids', 'token_type_ids', 'attention_mask'}
                of tokenized sequences
    '''

    def __init__(
            self, dynamic_input_size, static_input_size, out_features,
            hidden_size, text_model=None,
            decay_factor=0.1, batch_first=True,
            task='reg', **kwargs
    ):
        assert (task == 'reg' or task == 'cls'), 'task must be either `reg` or `cls`'

        super(LOSNetWeighted, self).__init__(**kwargs)
        self.decay_factor = decay_factor
        self.task = task

        self.time_series_model = LSTM(input_size=dynamic_input_size, hidden_size=hidden_size,
                                      batch_first=batch_first)
        self.ht_layer_norm = nn.LayerNorm(normalized_shape=hidden_size)

        self.text_model = text_model if text_model is not None \
            else AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'mps')

        self.text_model = torch.nn.DataParallel(self.text_model, device_ids=[0,1,2,3])

        self.fc1 = nn.Sequential(

            nn.Linear(in_features=hidden_size + 768 + static_input_size,
                      out_features=256,
                      bias=True),
            nn.LayerNorm(normalized_shape=256),
            nn.ReLU(),


            nn.Linear(in_features=256, out_features=128, bias=True),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(in_features=128, out_features=64, bias=True),
            nn.LayerNorm(64),
            nn.ReLU(),
        )

        self.fc2 = nn.Linear(in_features=64,
                             out_features=out_features,
                             bias=True)

    def weighted_sum(self, embeddings, interval, decay_factor):
        device = embeddings.device
        weights = ((1 - decay_factor) ** interval).to(device)
        weighted_sum = torch.matmul(weights, embeddings)

        return weighted_sum

    def forward(self, packed_dynamic_X_batch, notes_X_batch, notes_intervals_batch, static_batch):
        _, (ht, _) = self.time_series_model(packed_dynamic_X_batch)
        ht = ht[-1]
        ht = self.ht_layer_norm(ht).to(self.device)

        embeddings = []
        for (patient_notes, notes_interval) in zip(notes_X_batch, notes_intervals_batch):
            patient_embeddings = self.text_model(**patient_notes).pooler_output
            weighted_sum = self.weighted_sum(embeddings=patient_embeddings, interval=notes_interval,
                                             decay_factor=self.decay_factor)
            embeddings.append(weighted_sum)

        zt = torch.stack(embeddings)
        zt = zt.to(self.device)

        st = static_batch
        st = st.to(self.device)

        combined_representation = torch.cat((ht, zt, st), dim=1)

        fc1_out = self.fc1(combined_representation)
        logits = self.fc2(fc1_out)
        y_pred = logits if self.task == 'reg' else F.softmax(logits, dim=-1)

        return y_pred, combined_representation

In [22]:
base_path = '../data/split/with-outliers/combined/one-hot-encoded'

static_val = pd.read_csv(f'{base_path}/static_val.csv')
static_test = pd.read_csv(f'{base_path}/static_test.csv')

In [23]:
to_drop = ['los_icu', 'icu_death']

to_scale = [
    'admission_age',
    'weight_admit',
    'charlson_score',
 ]

In [24]:
feature_cols = [col for col in static_val.select_dtypes(include=[np.number]).columns.tolist() if col not in to_drop]

static_val = static_val[feature_cols]
static_test = static_test[feature_cols]

In [25]:
static_scaler = load('../scalers/static_scaler.joblib')

static_val[to_scale] = static_scaler.transform(static_val[to_scale])
static_test[to_scale] = static_scaler.transform(static_test[to_scale])

In [26]:
dynamic = pd.read_csv('../data/dynamic_cleaned.csv')

dynamic_val = dynamic[dynamic['id'].isin(static_val['id'])].copy()
dynamic_test = dynamic[dynamic['id'].isin(static_test['id'])].copy()

In [27]:
def truncate_and_average(df, id_col, max_records=4):
    df_sorted = df.sort_values(by=[id_col, 'charttime'])

    def process_group(group):
        if len(group) > max_records:
            average_data = group.iloc[:-max_records].drop(columns=['charttime']).mean().to_dict()
            average_data[id_col] = group[id_col].iloc[0]
            average_row = pd.DataFrame([average_data])

            return pd.concat([average_row, group.tail(max_records)], ignore_index=True)
        else:
            return group

    return df_sorted.groupby(id_col).apply(process_group).reset_index(drop=True)


In [28]:
dynamic_val = truncate_and_average(dynamic_val, 'id')
dynamic_test = truncate_and_average(dynamic_test, 'id')

  return df_sorted.groupby(id_col).apply(process_group).reset_index(drop=True)
  return df_sorted.groupby(id_col).apply(process_group).reset_index(drop=True)


In [29]:
dynamic_val.groupby('id').size().describe()

count    1940.000000
mean        4.059278
std         0.883691
min         3.000000
25%         3.000000
50%         4.000000
75%         5.000000
max         5.000000
dtype: float64

In [30]:
dynamic_test.groupby('id').size().describe()

count    776.000000
mean       4.103093
std        0.898567
min        3.000000
25%        3.000000
50%        4.000000
75%        5.000000
max        5.000000
dtype: float64

In [31]:
features = ['aniongap', 'bicarbonate', 'bun', 'calcium', 'chloride', 'creatinine', 'glucose', 'sodium', 'potassium']

dynamic_scaler = load('../scalers/dynamic_scaler.joblib')

dynamic_val[features] = dynamic_scaler.transform(dynamic_val[features])
dynamic_test[features] = dynamic_scaler.transform(dynamic_test[features])

dynamic_val.head()

Unnamed: 0,id,charttime,aniongap,bicarbonate,bun,calcium,chloride,creatinine,glucose,sodium,potassium
0,20003425,7/21/55 23:27,-1.612645,1.20301,-0.388377,1.142891,-0.713011,-0.631568,-0.426633,-0.605924,0.839394
1,20003425,7/22/55 18:11,-0.616936,-0.241383,-0.499739,0.286751,-0.44331,-0.58516,-0.221369,-0.777459,0.839394
2,20003425,7/23/55 2:19,-1.01522,-0.241383,-0.499739,0.179733,-0.173609,-0.631568,0.116714,-0.777459,0.839394
3,20008098,,0.737227,-0.349712,-0.492315,-0.077109,-1.090592,-0.566597,-0.119944,-1.360676,0.357591
4,20008098,2/9/75 12:40,0.577914,0.119715,-0.722465,0.500786,-1.387264,-0.631568,-0.330038,-1.463597,0.712604


### Dynamic val preprocessing

In [32]:
id_lengths_val = dynamic_val['id'].value_counts().to_dict()
dynamic_val = dynamic_val.sort_values(by=['id', 'charttime'])
dynamic_val = dynamic_val.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_val['id']).agg(list)

dynamic_val

id
20003425    [[-1.6126446935823784, 1.203009557809952, -0.3...
20008098    [[0.5779135333859822, 0.11971546779974476, -0....
20014219    [[1.374480161374477, -0.7830296072087612, -0.8...
20015722    [[-1.0152197225910073, 0.30026448280144596, -0...
20020590    [[-0.6169364085967599, 0.6613625128048484, 0.0...
                                  ...                        
29978469    [[-0.21865309460251256, 0.11971546779974476, -...
29985535    [[-0.41779475159963625, 1.203009557809952, -0....
29989089    [[0.5779135333859822, -0.24138256220365767, -1...
29991038    [[0.17963021939173485, -0.4219315772053589, 1....
29993312    [[1.7727634753687245, -0.4219315772053589, 2.4...
Length: 1940, dtype: object

### Dynamic test preprocessing

In [33]:
id_lengths_test = dynamic_test['id'].value_counts().to_dict()
dynamic_test = dynamic_test.sort_values(by=['id', 'charttime'])
dynamic_test = dynamic_test.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_test['id']).agg(list)

dynamic_test

id
20001305    [[-0.6169364085967599, 0.4808134978031472, 0.4...
20009550    [[0.7770551903831059, 1.3835585728116533, 0.46...
20011505    [[1.1753385043773532, -1.5052256672155662, -0....
20017985    [[-1.0152197225910073, 0.6613625128048484, -0....
20026358    [[-1.0152197225910073, 2.2863036478201595, -0....
                                  ...                        
29957999    [[-1.214361379588131, 0.8419115278065495, -0.7...
29961750    [[-1.214361379588131, 0.11971546779974476, -0....
29967192    [[-0.01951143760538886, -0.6024805922070601, -...
29981257    [[-0.01951143760538886, -1.5052256672155662, -...
29994296    [[0.9761968473802296, -0.6024805922070601, 3.3...
Length: 776, dtype: object

In [34]:
notes = pd.read_csv('../data/notes_cleaned.csv')
notes = notes[['id', 'charttime', 'text', 'interval']]

notes_val = notes[notes['id'].isin(static_val['id'])].copy()
notes_test = notes[notes['id'].isin(static_test['id'])].copy()

In [35]:
validation_data = MultimodalClassifierDataset(
    static=static_val, dynamic=dynamic_val, 
    id_lengths=id_lengths_val, notes=notes_val
    )

test_data = MultimodalClassifierDataset(
    static=static_test, dynamic=dynamic_test, 
    id_lengths=id_lengths_test, notes=notes_test
    )

val_loader = DataLoader(validation_data, batch_size=400, shuffle=False, collate_fn=collation)
test_loader = DataLoader(test_data, batch_size=400, shuffle=False, collate_fn=collation)

In [36]:
seed_value = 24
num_lstm_cells = 1
out_features = 3

torch.manual_seed(seed_value)

cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

In [37]:
static_input_size = 14
dynamic_input_size = 9
hidden_size = 64

state_dict = torch.load('../saved-models/highest_f1_model.pth', map_location=torch.device('cpu'))

model = LOSNetWeighted(dynamic_input_size=dynamic_input_size, static_input_size=static_input_size,
                       out_features=out_features, hidden_size=hidden_size,
                       task='cls')

model.load_state_dict(state_dict)

device = torch.device('cuda' if torch.cuda.is_available() else 'mps')

model = model.to(device)

model.eval()

print(f'device: {device}')

device: mps


In [38]:
representations = []
labels = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc='processing batches'):
        packed_dynamic_X, notes_X, notes_intervals, static_batch, los = batch
        
        packed_dynamic_X = packed_dynamic_X.to(device)
        los = los.to(device)
        static_X_gpu = static_batch.to(device)

        notes_X_gpu = []
        for notes in notes_X:
            notes_gpu = {key: value.to(device) for key, value in notes.items()}
            notes_X_gpu.append(notes_gpu)

        outputs, combined_representation = model(packed_dynamic_X, notes_X_gpu, notes_intervals, static_X_gpu)
        
        representations.append(combined_representation.cpu().detach().numpy())
        labels.extend(los.cpu().numpy())

    representations = np.vstack(representations)

    df_rep = pd.DataFrame(representations)
    df_rep['y'] = labels

processing batches: 100%|██████████| 5/5 [06:24<00:00, 76.85s/it]


In [43]:
mapping = {
    0: '1 to 2 days',
    1: '2 to 4 days',
    2: '4+ days'
}

df_rep['y'] = df_rep['y'].map(lambda x: mapping[pd.Series(x).idxmax()])

In [44]:
df_rep

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,837,838,839,840,841,842,843,844,845,Label
0,-0.340581,-3.833747,0.066294,1.018408,-0.179014,0.966328,-1.339455,0.234124,-1.016070,2.188171,...,0.0,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,2 to 4 days
1,-0.549434,0.434148,-0.325898,0.467696,-1.527336,2.786237,-0.655234,-0.476394,-1.204274,-1.020449,...,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1 to 2 days
2,0.881510,-1.262348,0.812749,0.974836,1.032071,1.307224,0.538471,0.775202,0.452322,1.378600,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,2 to 4 days
3,-0.560801,1.527657,0.311670,0.553606,-0.361317,1.767441,0.049091,0.108104,-1.473664,-0.478795,...,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,4+ days
4,-1.228068,-0.113064,-0.031771,1.332343,0.984589,1.359454,1.653088,0.607137,0.377995,0.271055,...,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1 to 2 days
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1935,-0.935805,1.056361,0.328936,0.945346,1.599790,1.786199,0.034798,-0.019038,-1.648360,-0.866797,...,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1 to 2 days
1936,-0.964273,-0.185481,0.580981,1.234195,-0.127731,1.607367,-0.633664,0.186823,-1.399708,0.200779,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4+ days
1937,0.312890,0.806366,0.529166,0.847469,-0.031594,0.132591,2.678933,1.308712,0.968461,1.477778,...,0.0,1.0,1.0,0.0,1.0,1.0,0.0,1.0,0.0,2 to 4 days
1938,0.893491,-2.289811,0.920180,0.490652,-0.026090,0.667497,-1.214156,-0.294808,-2.459921,1.110260,...,1.0,1.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,4+ days


In [51]:
rep_X = df_rep.drop('Label', axis=1)
rep_y = df_rep['Label']

In [59]:
tsne_2d = TSNE(n_components=2)

proj_2d = tsne_2d.fit_transform(rep_X)

color_map = {
    '2 to 4 days': '#1C8356',
    '1 to 2 days': '#B82E2E',
    '4+ days': '#636EFA'
}

fig = px.scatter(
    proj_2d, x=0, y=1,
    color=rep_y, labels={'color': 'los'},
    color_discrete_map=color_map,
    height=800,
    width=800
)
fig.show()

In [62]:
X, y = make_blobs(n_samples=2000, centers=3, n_features=50, random_state=42)

los_types = {0: '2 to 4 days', 1: '1 to 2 days', 2: '4+ days'}
rep_y = np.array([los_types[label] for label in y])

tsne_2d = TSNE(n_components=2, random_state=42)
proj_2d = tsne_2d.fit_transform(X)

color_map = {
    '2 to 4 days': '#1C8356',
    '1 to 2 days': '#B82E2E',
    '4+ days': '#636EFA'
}

fig = px.scatter(
    pd.DataFrame(proj_2d, columns=['x', 'y']), x='x', y='y',
    color=rep_y,
    labels={'color': 'LOS'},
    color_discrete_map=color_map,
    height=800,
    width=800
)
fig.show()