In [1]:
import os
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from joblib import dump
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import f1_score, balanced_accuracy_score
from multimodal import MultimodalClassifierDataset, LOSNet, collation

In [2]:
static = pd.read_csv('../../data/static_cleaned.csv')

static['los_icu_binned'].value_counts()

los_icu_binned
4+ days        3712
2 to 4 days    3290
1 to 2 days    2694
Name: count, dtype: int64

In [3]:
base_path = '../../data/split/with-outliers/combined/one-hot-encoded'

static_train = pd.read_csv(f'{base_path}/static_train.csv')
static_val = pd.read_csv(f'{base_path}/static_val.csv')

In [4]:
to_drop = ['los_icu', 'icu_death']

to_scale = [
    'admission_age',
    'weight_admit',
    'charlson_score',
 ]

In [5]:
feature_cols = [col for col in static_train.select_dtypes(include=[np.number]).columns.tolist() if col not in to_drop]

static_train = static_train[feature_cols]
static_val = static_val[feature_cols]

In [6]:
scaler = StandardScaler()

static_train[to_scale] = scaler.fit_transform(static_train[to_scale])
static_val[to_scale] = scaler.transform(static_val[to_scale])

dump(scaler, '../../scalers/static_scaler.joblib')

['../../scalers/static_scaler.joblib']

In [7]:
dynamic = pd.read_csv('../../data/dynamic_cleaned.csv')
dynamic_train = dynamic[dynamic['id'].isin(static_train['id'])].copy()
dynamic_val = dynamic[dynamic['id'].isin(static_val['id'])].copy()

In [8]:
def truncate_and_average(df, id_col, max_records=4):
    df_sorted = df.sort_values(by=[id_col, 'charttime'])

    def process_group(group):
        if len(group) > max_records:
            average_data = group.iloc[:-max_records].drop(columns=['charttime']).mean().to_dict()
            average_data[id_col] = group[id_col].iloc[0]
            average_row = pd.DataFrame([average_data])

            return pd.concat([average_row, group.tail(max_records)], ignore_index=True)
        else:
            return group

    return df_sorted.groupby(id_col).apply(process_group).reset_index(drop=True)


In [9]:
dynamic_train = truncate_and_average(dynamic_train, 'id')
dynamic_val = truncate_and_average(dynamic_val, 'id')

  return df_sorted.groupby(id_col).apply(process_group).reset_index(drop=True)
  return df_sorted.groupby(id_col).apply(process_group).reset_index(drop=True)


In [10]:
dynamic_train.groupby('id').size().describe()

count    6980.000000
mean        4.081662
std         0.886078
min         3.000000
25%         3.000000
50%         4.000000
75%         5.000000
max         5.000000
dtype: float64

In [11]:
dynamic_val.groupby('id').size().describe()

count    1940.000000
mean        4.059278
std         0.883691
min         3.000000
25%         3.000000
50%         4.000000
75%         5.000000
max         5.000000
dtype: float64

In [12]:
features = ['aniongap', 'bicarbonate', 'bun', 'calcium', 'chloride', 'creatinine', 'glucose', 'sodium', 'potassium']

scaler = StandardScaler()

dynamic_train[features] = scaler.fit_transform(dynamic_train[features])
dynamic_val[features] = scaler.transform(dynamic_val[features])

dump(scaler, '../../scalers/dynamic_scaler.joblib')

dynamic_train.head()

Unnamed: 0,id,charttime,aniongap,bicarbonate,bun,calcium,chloride,creatinine,glucose,sodium,potassium
0,20001361,5/4/43 17:24,-0.417795,-0.060834,-0.314135,-2.174652,0.635494,0.15736,0.104639,-0.091321,1.853715
1,20001361,5/4/43 21:07,-0.218653,-0.421932,-0.165651,-1.960617,0.770345,0.15736,-0.342113,-0.091321,1.726925
2,20001361,5/5/43 15:02,0.378772,0.119715,0.131317,-0.462372,0.635494,0.71425,-0.354187,1.109421,-0.682088
3,20001361,5/5/43 4:27,-0.218653,0.119715,-0.017167,-1.211494,0.770345,0.34299,-0.656046,0.766352,-0.048137
4,20003491,12/17/97 15:33,-0.218653,-0.241383,-0.351256,0.500786,0.096092,-0.39953,0.599689,0.251748,2.107295


In [13]:
dynamic_train.groupby('id').size().describe()

count    6980.000000
mean        4.081662
std         0.886078
min         3.000000
25%         3.000000
50%         4.000000
75%         5.000000
max         5.000000
dtype: float64

In [14]:
dynamic_val.groupby('id').size().describe()

count    1940.000000
mean        4.059278
std         0.883691
min         3.000000
25%         3.000000
50%         4.000000
75%         5.000000
max         5.000000
dtype: float64

### Dynamic train preprocessing

In [15]:
id_lengths_train = dynamic_train['id'].value_counts().to_dict()
dynamic_train = dynamic_train.sort_values(by=['id', 'charttime'])
dynamic_train = dynamic_train.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_train['id']).agg(list)

dynamic_train

id
20001361    [[-0.41779475159963625, -0.06083354720195645, ...
20003491    [[-0.21865309460251256, -0.24138256220365767, ...
20009330    [[0.17963021939173485, -0.9635786222104625, -0...
20012928    [[1.7727634753687245, 0.30026448280144596, 0.9...
20013244    [[-1.0152197225910073, 1.0224605428082507, -0....
                                  ...                        
29990494    [[-0.01951143760538886, -0.7830296072087612, 0...
29991539    [[0.3787718763888585, 1.203009557809952, 0.836...
29996513    [[0.5779135333859822, -0.24138256220365767, 0....
29997500    [[-0.21865309460251256, 1.925205617816757, -0....
29998399    [[-1.214361379588131, 0.11971546779974476, -0....
Length: 6980, dtype: object

### Dynamic val preprocessing

In [16]:
id_lengths_val = dynamic_val['id'].value_counts().to_dict()
dynamic_val = dynamic_val.sort_values(by=['id', 'charttime'])
dynamic_val = dynamic_val.apply(lambda x: list(x[features]), axis=1).groupby(dynamic_val['id']).agg(list)

dynamic_val

id
20003425    [[-1.6126446935823784, 1.203009557809952, -0.3...
20008098    [[0.5779135333859822, 0.11971546779974476, -0....
20014219    [[1.374480161374477, -0.7830296072087612, -0.8...
20015722    [[-1.0152197225910073, 0.30026448280144596, -0...
20020590    [[-0.6169364085967599, 0.6613625128048484, 0.0...
                                  ...                        
29978469    [[-0.21865309460251256, 0.11971546779974476, -...
29985535    [[-0.41779475159963625, 1.203009557809952, -0....
29989089    [[0.5779135333859822, -0.24138256220365767, -1...
29991038    [[0.17963021939173485, -0.4219315772053589, 1....
29993312    [[1.7727634753687245, -0.4219315772053589, 2.4...
Length: 1940, dtype: object

In [17]:
train_data = MultimodalClassifierDataset(
    static=static_train, dynamic=dynamic_train, 
    id_lengths=id_lengths_train
    )
validation_data = MultimodalClassifierDataset(
    static=static_val, dynamic=dynamic_val, 
    id_lengths=id_lengths_val
    )

train_loader = DataLoader(train_data, batch_size=1000, shuffle=True, collate_fn=collation)
val_loader = DataLoader(validation_data, batch_size=400, shuffle=False, collate_fn=collation)

In [18]:
seed_value = 24
num_lstm_cells = 1
out_features = 3

torch.manual_seed(seed_value)

cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

In [19]:
static_input_size = 14
dynamic_input_size = 9
hidden_size = 32

model = LOSNet(static_input_size=static_input_size, dynamic_input_size=dynamic_input_size, out_features=out_features, hidden_size=hidden_size, num_cells=num_lstm_cells)
device = torch.device('cuda' if torch.cuda.is_available() else 'mps')

model = model.to(device)

print(f'device: {device}')

device: mps


In [20]:
print(f'total fc input size: {static_input_size + hidden_size}')

total fc input size: 46


In [21]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
epochs = 200

  from .autonotebook import tqdm as notebook_tqdm


In [22]:
def gen_dir(num_lstm_cells, dataset_type='combined_regression', hidden_size=32, trial_num=1):
    loss_base_path = f'../../losses/trial_num_{trial_num}/{dataset_type}/{num_lstm_cells}_cells_{epochs}_epochs_{hidden_size}_hidden_size'
    model_save_path = f'../../saved-models/trial_num_{trial_num}/{num_lstm_cells}_cells_{dataset_type}_{epochs}_epochs_{hidden_size}_hidden_size'
    tensorboard_path = f'../../tensorboard/runs/trial_num_{trial_num}/{dataset_type}_static_dynamic_{num_lstm_cells}_cells_{epochs}_epochs_{hidden_size}_hidden_size'
    
    
    if not os.path.exists(loss_base_path):
        os.makedirs(loss_base_path)
        print(f"Created loss directory: {loss_base_path}")
    else:
        print(f"Directory for loss exists: {loss_base_path}")
    
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)
        print(f"Created model directory: {model_save_path}")
    else:
        print(f"Directory for model exists: {model_save_path}")

    return loss_base_path, model_save_path, tensorboard_path

In [23]:
loss_base_path, model_save_path, tensorboard_path = gen_dir(num_lstm_cells, 'combined_classification_expansion_fc_truncated', hidden_size=hidden_size, trial_num=5)

writer = SummaryWriter(tensorboard_path)

Directory for loss exists: ../../losses/trial_num_5/combined_classification_expansion_fc_truncated/1_cells_200_epochs_32_hidden_size
Created model directory: ../../saved-models/trial_num_5/1_cells_combined_classification_expansion_fc_truncated_200_epochs_32_hidden_size


In [24]:
f1_type = 'weighted'
training_loss = []
validation_loss = []
train_f1_scores = []
val_f1_scores = []
train_acc_scores = []
val_acc_scores = []
patience = 10
stagnation = 0

In [25]:
for epoch in range (1, epochs+1):

    print(f'training epoch: [{epoch}/{epochs}]')
    model.train()
    training_loss_epoch = 0
    all_true_labels = []
    all_predicted_labels = []

    for step, batch in enumerate(train_loader):
        packed_dynamic_X, static_X, los = batch

        packed_dynamic_X = packed_dynamic_X.to(device)
        static_X = static_X.to(device)
        los = los.to(device)

        outputs = model(packed_dynamic_X, static_X)
        predicted_labels = torch.argmax(outputs, dim=1)
        true_labels = torch.argmax(los, dim=1)

        loss = criterion(outputs, true_labels)
        writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + step)
        training_loss_epoch += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        all_true_labels.extend(true_labels.cpu().numpy())
        all_predicted_labels.extend(predicted_labels.cpu().numpy())

        if step % max(1, round(len(train_loader) * 0.1)) == 0:
            print(f'step: [{step+1}/{len(train_loader)}] | loss: {loss.item():.4}')

            if step+1 == 1 and epoch == 1:
                with open(f'{loss_base_path}/loss_step.txt', 'w') as loss_step_f:
                    loss_step_f.write(f'{loss.item():.4f}\n')

            else:
                with open(f'{loss_base_path}/loss_step.txt', 'a') as loss_step_f:
                    loss_step_f.write(f'{loss.item():.4f}\n')

    avg_training_loss_epoch = training_loss_epoch / len(train_loader)
    writer.add_scalar('Loss/train_avg', avg_training_loss_epoch, epoch)

    training_loss.append(avg_training_loss_epoch)
    print(f'\nTraining epoch loss: {avg_training_loss_epoch:.4f}')

    train_f1_score = f1_score(all_true_labels, all_predicted_labels, average=f1_type)
    train_f1_scores.append(round(train_f1_score, 4))
    print(f'Training {f1_type} F1 epoch score: {train_f1_score:.4f}')
    writer.add_scalar('F1/train', train_f1_score, epoch)

    train_acc_score = balanced_accuracy_score(all_true_labels, all_predicted_labels)
    train_acc_scores.append(round(train_acc_score, 4))
    print(f'Training balanced accuracy epoch score: {train_acc_score:.4f}\n')
    writer.add_scalar('Accuracy/train', train_acc_score, epoch)

    if epoch == 1:
        with open(f'{loss_base_path}/training_loss_epoch.txt', 'w') as loss_epoch_train_f:
            loss_epoch_train_f.write(f'{avg_training_loss_epoch:.4f}\n')

    else:
        with open(f'{loss_base_path}/training_loss_epoch.txt', 'a') as loss_epoch_train_f:
            loss_epoch_train_f.write(f'{avg_training_loss_epoch:.4f}\n')

    print(f'validation epoch: [{epoch}/{epochs}]')

    model.eval()
    with torch.no_grad():
        validation_loss_epoch = 0
        val_all_true_labels = []
        val_all_predicted_labels = []

        for val_step, val_batch in enumerate(val_loader):
            packed_dynamic_X_val, static_X_val, los_val = val_batch

            packed_dynamic_X_val = packed_dynamic_X_val.to(device)
            static_X_val = static_X_val.to(device)
            los_val = los_val.to(device)
        
            val_outputs = model(packed_dynamic_X_val, static_X_val)
            val_predicted_labels = torch.argmax(val_outputs, dim=1)
            val_true_labels = torch.argmax(los_val, dim=1)

            val_loss = criterion(val_outputs, val_true_labels)
            writer.add_scalar('Loss/val', val_loss.item(), epoch * len(val_loader) + val_step)
            validation_loss_epoch += val_loss.item()

            val_all_true_labels.extend(val_true_labels.cpu().numpy())
            val_all_predicted_labels.extend(val_predicted_labels.cpu().numpy())

        avg_validation_loss = validation_loss_epoch / len(val_loader)
        writer.add_scalar('Loss/val_avg', avg_validation_loss, epoch)
        print(f'Validation epoch loss: {avg_validation_loss:.4f}')

        val_f1_score = f1_score(val_all_true_labels, val_all_predicted_labels, average=f1_type)
        print(f'Validation {f1_type} F1 epoch score: {val_f1_score:.4f}')
        writer.add_scalar('F1/val', val_f1_score, epoch)

        val_acc_score = balanced_accuracy_score(val_all_true_labels, val_all_predicted_labels)
        print(f'Validation balanced accuracy epoch score: {val_acc_score:.4f}\n')
        writer.add_scalar('Accuracy/val', val_acc_score, epoch)
        
        if len(validation_loss) == 0 or (avg_validation_loss < min(validation_loss)):
            stagnation = 0
            torch.save(model.state_dict(), f'{model_save_path}/lowest_loss_model.pth')
            print(f'new minimum validation loss')
            print(f'model saved\n')

        if len(val_f1_scores) == 0 or (val_f1_score > max(val_f1_scores)):
            torch.save(model.state_dict(), f'{model_save_path}/highest_f1_model.pth')
            print(f'new max {f1_type} F1 score')
            print(f'model saved\n')

        if len(val_acc_scores) == 0 or (val_acc_score > max(val_acc_scores)):
            torch.save(model.state_dict(), f'{model_save_path}/highest_accuracy_model.pth')
            print(f'new max balanced accuracy score')
            print(f'model saved\n')

        else:
            stagnation += 1

        validation_loss.append(avg_validation_loss)
        val_f1_scores.append(round(val_f1_score, 4))
        val_acc_scores.append(round(val_acc_score, 4))

        if epoch == 1:
            with open(f'{loss_base_path}/validation_loss_epoch.txt', 'w') as loss_epoch_val_f:
                loss_epoch_val_f.write(f'{avg_validation_loss:.4f}\n')

        else:
            with open(f'{loss_base_path}/validation_loss_epoch.txt', 'a') as loss_epoch_val_f:
                loss_epoch_val_f.write(f'{avg_validation_loss:.4f}\n')

        if stagnation >= patience:
            print(f'No improvement over {patience} epochs')
            print('Early stopping\n')
            break

    model.train()

    print('==========================================\n')

writer.close()
print(f'min training loss: {min(training_loss):.4f}')
print(f'min validation loss: {min(validation_loss):.4f}\n')

print(f'max training f1: {max(train_f1_scores):.4f}')
print(f'max validation f1: {max(val_f1_scores):.4f}\n')

print(f'max training acc: {max(train_acc_scores):.4f}')
print(f'max validation acc: {max(val_acc_scores):.4f}')

training epoch: [1/200]
step: [1/7] | loss: 1.159
step: [2/7] | loss: 1.135
step: [3/7] | loss: 1.107
step: [4/7] | loss: 1.109
step: [5/7] | loss: 1.098
step: [6/7] | loss: 1.104
step: [7/7] | loss: 1.096

Training epoch loss: 1.1154
Training weighted F1 epoch score: 0.3315
Training balanced accuracy epoch score: 0.3368

validation epoch: [1/200]
Validation epoch loss: 1.0957
Validation weighted F1 epoch score: 0.2292
Validation balanced accuracy epoch score: 0.3349

new minimum validation loss
model saved

new max weighted F1 score
model saved

new max balanced accuracy score
model saved


training epoch: [2/200]
step: [1/7] | loss: 1.094
step: [2/7] | loss: 1.089
step: [3/7] | loss: 1.108
step: [4/7] | loss: 1.095
step: [5/7] | loss: 1.097
step: [6/7] | loss: 1.094
step: [7/7] | loss: 1.089

Training epoch loss: 1.0950
Training weighted F1 epoch score: 0.2504
Training balanced accuracy epoch score: 0.3332

validation epoch: [2/200]
Validation epoch loss: 1.0907
Validation weighted F