In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from transformers import AutoModel
from sklearn.preprocessing import OneHotEncoder
import pandas as pd
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import f1_score

In [34]:
import torch
import torch.nn as nn
from torch.nn import LSTM
import torch.optim as optim
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModel
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import torch.nn.functional as F

class MultimodalRegressorDataset(Dataset):
    '''
    Parameters
    ----------
    static: the static DataFrame
    dynamic: dynamic dataset where each timestep is concatenated into one dimension
        ```
            id
            20008098    [[21.0, 20.0, 21.0, 8.9, 100.0, 0.9, 178.0, 13...
            20013244    [[13.0, 29.0, 17.0, 8.9, 103.0, 1.1, 127.0, 14...
            20015730    [[11.0, 25.0, 17.0, 8.1, 112.0, 1.6, 121.0, 14...
            20020562    [[12.0, 22.0, 21.0, 8.2, 104.0, 2.2, 91.0, 134...
            20021110    [[16.0, 27.0, 32.0, 9.7, 103.0, 1.2, 88.0, 141...
        ```

    id_lengths: a dictionary where the key is the patient_id and the value is the true length the time series associated with each patient id (to be used for packed padding)
        ```
            {
                20008098: 9,
                20013244: 7,
                20015730: 10,
                20020562: 10,
                20021110: 10,
                20022095: 6,
                20022465: 6,
                20024177: 7
            }

    Outputs
    -------
    packed_dynamic_X: A sequence of time steps, dynamically packed and padded, representing data for a specific patient
    '''

    def __init__(self, static, dynamic, id_lengths):
        self.static = static
        self.static_dict = static.set_index('id')['los_icu'].to_dict()
        self.dynamic = dynamic
        self.id_lengths = id_lengths

    def __len__(self):
        return len(self.static)
    
    def __getitem__(self, idx):
        patient_id = self.static.iloc[idx]['id']

        # time series
        dynamic_X = self.dynamic[patient_id]
        dynamic_X = torch.tensor(dynamic_X, dtype=torch.float32)
        patient_timesteps = self.id_lengths[patient_id]

        # los
        los = [self.static_dict.get(patient_id, [])]
        return dynamic_X, patient_timesteps, los
    
class MultimodalClassifierDataset(MultimodalRegressorDataset):
    def __init__(self, static, dynamic, id_lengths):
        super(MultimodalClassifierDataset, self).__init__(
            static, dynamic, id_lengths
        )

        ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        encoded = ohe.fit_transform(static[['los_bin']])
        ohe_cols = ohe.get_feature_names_out(['los_bin'])
        ohe_los = pd.DataFrame(encoded, columns=ohe_cols)

        self.static = pd.concat([static, ohe_los], axis=1)
        self.static_dict = {idx: row[ohe_cols].astype(np.float32).to_list() for idx, row in self.static.set_index('id').iterrows()}

    def __len__(self):
        return len(self.static)
    
    def __getitem__(self, idx):
        patient_id = self.static.iloc[idx]['id']

        # time series
        dynamic_X = self.dynamic[patient_id]
        dynamic_X = torch.tensor(dynamic_X, dtype=torch.float32)
        patient_timesteps = self.id_lengths[patient_id]

        # los
        los = self.static_dict.get(patient_id, [])
        return dynamic_X, patient_timesteps,los
    
class LOSNetWeighted(nn.Module):
    '''
    time_series_model: expects an input of packed padded sequences
    text_model: expects an input of dict with keys {'input_ids', 'token_type_ids', 'attention_mask'} of tokenized sequences
    '''
    def __init__(
            self, input_size, out_features, 
            hidden_size, 
            batch_first=True, 
            task='cls', num_cells=1, **kwargs
            ):
        
        assert (task == 'reg' or task == 'cls'), 'task must be either `reg` or `cls`'
        
        super(LOSNetWeighted, self).__init__(**kwargs)
        self.task = task
        
        self.time_series_model = LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=batch_first, num_layers=num_cells)

        self.fc = nn.Sequential(
            nn.LayerNorm(normalized_shape=hidden_size),
            nn.Linear(in_features=hidden_size, out_features=64, bias=True),
            nn.Linear(in_features=64,out_features=32,bias=True),
            nn.Linear(in_features=32,out_features=16,bias=True),
            nn.Linear(in_features=16,out_features=out_features,bias=True),
            nn.Softplus()
        )

    def forward(self, packed_dynamic_X_batch):
        _, (ht, _) = self.time_series_model(packed_dynamic_X_batch)
        ht = ht[-1]
        logits = self.fc(ht)
        y_pred = logits if self.task == 'reg' else F.softmax(logits, dim=-1)

        return y_pred

def collation(batch):
    dynamic_X_batch, patient_timesteps, los_batch = zip(*batch)
    padded_dynamic_batch = pad_sequence(dynamic_X_batch, batch_first=True, padding_value=0.0)
    packed_dynamic_X_batch = pack_padded_sequence(input=padded_dynamic_batch, lengths=patient_timesteps, batch_first=True, enforce_sorted=False)

    los_batch= torch.tensor(los_batch, dtype=torch.float32)

    return packed_dynamic_X_batch,los_batch

In [35]:
!nvidia-smi

Sun Apr 14 10:51:52 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A4000    Off  | 00000000:00:05.0 Off |                  Off |
| 41%   55C    P8    20W / 140W |   1229MiB / 16376MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [36]:
static_train = pd.read_csv('static_train_3.csv')
static_val = pd.read_csv('static_val_3.csv')
static_test = pd.read_csv('static_test_3.csv')

In [37]:
dynamic = pd.read_csv('dynamic_cleaned.csv')
def truncate_and_average(df, id_col, max_records=4):
    df_sorted = df.sort_values(by=[id_col, 'charttime'])

    def process_group(group):
        if len(group) > max_records:
            average_data = group.iloc[:-max_records].drop(columns=['charttime']).mean().to_dict()
            average_data[id_col] = group[id_col].iloc[0]
            average_row = pd.DataFrame([average_data])

            return pd.concat([average_row, group.tail(max_records)], ignore_index=True)
        else:
            return group

    return df_sorted.groupby(id_col).apply(process_group).reset_index(drop=True)

dynamic = truncate_and_average(dynamic, 'id')
dynamic_train = dynamic[dynamic['id'].isin(static_train['id'])].copy()
dynamic_val = dynamic[dynamic['id'].isin(static_val['id'])].copy()
dynamic_test = dynamic[dynamic['id'].isin(static_test['id'])].copy()

In [38]:
features = ['aniongap', 'bicarbonate', 'bun', 'calcium', 'chloride', 'creatinine', 'glucose', 'sodium', 'potassium']

scaler = StandardScaler()

dynamic_train.loc[:, features] = scaler.fit_transform(dynamic_train[features])
dynamic_val.loc[:, features] = scaler.transform(dynamic_val[features])
dynamic_test.loc[:, features] = scaler.transform(dynamic_test[features])  

dynamic_train.head()

Unnamed: 0,id,charttime,aniongap,bicarbonate,bun,calcium,chloride,creatinine,glucose,sodium,potassium
0,20001305,3/25/78 13:45,-0.613818,0.482155,0.430577,2.606973,0.631369,-0.584669,-0.036981,0.420316,0.465106
1,20001305,3/25/78 21:55,-0.613818,0.300675,0.504882,2.606973,0.766311,-0.584669,-0.255141,0.5921,-0.296522
2,20001305,3/25/78 8:20,-0.214408,0.119194,0.393424,3.241469,0.766311,-0.63097,0.023619,0.763883,-0.042646
3,20001361,5/4/43 17:24,-0.414113,-0.062287,-0.312478,-2.151747,0.631369,0.156136,0.108458,-0.095034,1.861422
4,20001361,5/4/43 21:07,-0.214408,-0.425249,-0.163867,-1.940248,0.766311,0.156136,-0.33998,-0.095034,1.734484


In [39]:
id_lengths_train = dynamic_train['id'].value_counts().to_dict()
dynamic_train = dynamic_train.sort_values(by=['id', 'charttime'])
dynamic_train = dynamic_train.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_train['id']).agg(list)

dynamic_train

id
20001305    [[-0.613818013628808, 0.48215538458125584, 0.4...
20001361    [[-0.41411294618117345, -0.062286997871766066,...
20003491    [[-0.21440787873353895, -0.24376779202277338, ...
20009330    [[0.18500225616173, -0.9696909686268026, -0.72...
20009550    [[0.7841174585046334, 1.3895593553362924, 0.46...
                                  ...                        
29991539    [[0.3847073236093645, 1.2080785611852851, 0.83...
29994296    [[0.9838225259522679, -0.606729380324788, 3.40...
29996513    [[0.5844123910569989, -0.24376779202277338, 0....
29997500    [[-0.21440787873353895, 1.9340017377893144, -0...
29998399    [[-1.2129332159717114, 0.11919379627924125, -0...
Length: 7756, dtype: object

In [40]:
id_lengths_val = dynamic_val['id'].value_counts().to_dict()
dynamic_val = dynamic_val.sort_values(by=['id', 'charttime'])
dynamic_val = dynamic_val.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_val['id']).agg(list)

dynamic_val

id
20015722    [[-1.0132281485240768, 0.3006745904302485, -0....
20020590    [[-0.613818013628808, 0.6636361787322631, 0.05...
20026345    [[-1.2129332159717114, -0.062286997871766066, ...
20031816    [[-0.613818013628808, 1.5710401494872996, -0.5...
20034400    [[-0.014702811285904484, -0.7882101744757953, ...
                                  ...                        
29946363    [[-0.41411294618117345, 1.3895593553362924, -0...
29970039    [[-0.8135230810764423, 0.3006745904302485, -0....
29978469    [[-0.21440787873353895, 0.11919379627924125, -...
29985535    [[-0.41411294618117345, 1.2080785611852851, -0...
29989089    [[0.5844123910569989, -0.24376779202277338, -1...
Length: 970, dtype: object

In [41]:
id_lengths_test = dynamic_test['id'].value_counts().to_dict()
dynamic_test = dynamic_test.sort_values(by=['id', 'charttime'])
dynamic_test = dynamic_test.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_test['id']).agg(list)

dynamic_test

id
20015722    [[-1.0132281485240768, 0.3006745904302485, -0....
20020590    [[-0.613818013628808, 0.6636361787322631, 0.05...
20026345    [[-1.2129332159717114, -0.062286997871766066, ...
20031816    [[-0.613818013628808, 1.5710401494872996, -0.5...
20034400    [[-0.014702811285904484, -0.7882101744757953, ...
                                  ...                        
29946363    [[-0.41411294618117345, 1.3895593553362924, -0...
29970039    [[-0.8135230810764423, 0.3006745904302485, -0....
29978469    [[-0.21440787873353895, 0.11919379627924125, -...
29985535    [[-0.41411294618117345, 1.2080785611852851, -0...
29989089    [[0.5844123910569989, -0.24376779202277338, -1...
Length: 970, dtype: object

In [42]:
train_data = MultimodalClassifierDataset(
    static=static_train, dynamic=dynamic_train, 
    id_lengths=id_lengths_train
    )
validation_data = MultimodalClassifierDataset(
    static=static_val, dynamic=dynamic_val, 
    id_lengths=id_lengths_val
    )
test_data = MultimodalClassifierDataset(
    static=static_test, dynamic=dynamic_test, 
    id_lengths=id_lengths_test
    )
train_loader = DataLoader(train_data, batch_size=2000, shuffle=True, collate_fn=collation)
val_loader = DataLoader(validation_data, batch_size=200, shuffle=True, collate_fn=collation)
test_loader = DataLoader(test_data, batch_size = 200, shuffle= True, collate_fn= collation)

In [43]:
out_features = static_train['los_bin'].nunique()

out_features

3

In [44]:
model = LOSNetWeighted(input_size=9, out_features=out_features, hidden_size=32,task='cls')
seed_value = 22
torch.manual_seed(seed_value)

cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'
model = model.to(device)

print(f'device: {device}')

device: cuda


In [45]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [48]:
from sklearn.metrics import f1_score, accuracy_score
epochs = 200
training_loss = []
validation_loss = []
train_weighted_f1_scores = []
val_weighted_f1_scores = []
train_accuracies = []
val_accuracies = []
test_accuracies = []
test_weighted_f1_scores = []
patience = 10
stagnation = 0

for epoch in range(1, epochs + 1):

    print(f'Training epoch: [{epoch}/{epochs}]')
    model.train()
    training_loss_epoch = 0
    all_true_labels = []
    all_predicted_labels = []

    for step, batch in enumerate(train_loader):
        packed_dynamic_X, los = batch
        packed_dynamic_X = packed_dynamic_X.to(device)
        los = los.to(device)

        outputs = model(packed_dynamic_X)
        predicted_labels = torch.argmax(outputs, dim=1)
        true_labels = torch.argmax(los, dim=1)

        loss = criterion(outputs, true_labels)
        training_loss_epoch += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        all_true_labels.extend(true_labels.cpu().numpy())
        all_predicted_labels.extend(predicted_labels.cpu().numpy())

    avg_training_loss_epoch = training_loss_epoch / len(train_loader)
    training_loss.append(avg_training_loss_epoch)

    train_f1_score = f1_score(all_true_labels, all_predicted_labels, average='macro')
    train_weighted_f1_score = f1_score(all_true_labels, all_predicted_labels, average='weighted')
    train_accuracy = accuracy_score(all_true_labels, all_predicted_labels)

    train_weighted_f1_scores.append(train_weighted_f1_score)
    train_accuracies.append(train_accuracy)

    print(f'Validation epoch: [{epoch}/{epochs}]')
    model.eval()
    validation_loss_epoch = 0
    val_all_true_labels = []
    val_all_predicted_labels = []

    for val_step, val_batch in enumerate(val_loader):
        packed_dynamic_X_val, los_val = val_batch
        packed_dynamic_X_val = packed_dynamic_X_val.to(device)
        los_val = los_val.to(device)

        val_outputs = model(packed_dynamic_X_val)
        val_predicted_labels = torch.argmax(val_outputs, dim=1)
        val_true_labels = torch.argmax(los_val, dim=1)

        val_loss = criterion(val_outputs, val_true_labels)
        validation_loss_epoch += val_loss.item()

        val_all_true_labels.extend(val_true_labels.cpu().numpy())
        val_all_predicted_labels.extend(val_predicted_labels.cpu().numpy())

    avg_validation_loss = validation_loss_epoch / len(val_loader)
    validation_loss.append(avg_validation_loss)

    val_weighted_f1_score = f1_score(val_all_true_labels, val_all_predicted_labels, average='weighted')
    val_accuracy = accuracy_score(val_all_true_labels, val_all_predicted_labels)

    val_weighted_f1_scores.append(val_weighted_f1_score)
    val_accuracies.append(val_accuracy)

    if avg_validation_loss < min(validation_loss):
        stagnation = 0
        torch.save(model.state_dict(), 'lstm_fc.pth')
        print('New minimum validation loss, model saved')
    else:
        stagnation += 1

    if stagnation >= patience:
        print(f'No improvement over {patience} epochs, early stopping')
        break

    print('===============================')
    
model.eval()
with torch.no_grad():
    test_all_true_labels = []
    test_all_predicted_labels = []

    for test_step, test_batch in enumerate(test_loader):
        packed_dynamic_X_test, los_test = test_batch
        packed_dynamic_X_test = packed_dynamic_X_test.to(device)
        los_test = los_test.to(device)

        test_outputs = model(packed_dynamic_X_test)
        test_predicted_labels = torch.argmax(test_outputs, dim=1)
        test_true_labels = torch.argmax(los_test, dim=1)

        test_all_true_labels.extend(test_true_labels.cpu().numpy())
        test_all_predicted_labels.extend(test_predicted_labels.cpu().numpy())

    test_f1 = f1_score(test_all_true_labels, test_all_predicted_labels, average='weighted')
    test_acc = accuracy_score(test_all_true_labels, test_all_predicted_labels)
    test_weighted_f1_scores.append(test_f1)
    test_accuracies.append(test_acc)

print(f'Max Training Weighted F1 Score: {max(train_weighted_f1_scores):.4f}')
print(f'Max Validation Weighted F1 Score: {max(val_weighted_f1_scores):.4f}')
print(f'Max Training Accuracy: {max(train_accuracies):.4f}')
print(f'Max Validation Accuracy: {max(val_accuracies):.4f}')
print(f'Max Test Weighted F1 Score: {max(test_weighted_f1_scores):.4f}')
print(f'Max Test Accuracy: {max(test_accuracies):.4f}')

Training epoch: [1/200]
Validation epoch: [1/200]
Training epoch: [2/200]
Validation epoch: [2/200]
Training epoch: [3/200]
Validation epoch: [3/200]
Training epoch: [4/200]
Validation epoch: [4/200]
Training epoch: [5/200]
Validation epoch: [5/200]
Training epoch: [6/200]
Validation epoch: [6/200]
Training epoch: [7/200]
Validation epoch: [7/200]
Training epoch: [8/200]
Validation epoch: [8/200]
Training epoch: [9/200]
Validation epoch: [9/200]
Training epoch: [10/200]
Validation epoch: [10/200]
No improvement over 10 epochs, early stopping
Max Training Weighted F1 Score: 0.4074
Max Validation Weighted F1 Score: 0.3874
Max Training Accuracy: 0.4357
Max Validation Accuracy: 0.4093
Max Test Weighted F1 Score: 0.3804
Max Test Accuracy: 0.4052
