In [1]:
import pandas as pd
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score
import torch
from chronos import ChronosPipeline

2025-04-07 20:09:11.293027: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-07 20:09:13.710006: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744049354.534865   86008 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744049354.776279   86008 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744049356.740272   86008 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [2]:
# Root directory holding the set-*.parquet time series and Outcomes-*.txt label files.
DATA_DIR = "../../data"


def load_split(split):
    """Load one data split: its per-record time series and its labels.

    Parameters
    ----------
    split : str
        Split suffix used in the file names, e.g. ``"a"`` for ``set-a.parquet``
        and ``Outcomes-a.txt``.

    Returns
    -------
    (pd.DataFrame, pd.Series)
        ``(series, labels)``. ``series`` has ``ICUType`` and ``Time`` removed.
        Its rows are sorted by ``RecordID`` and then ``Time`` before ``Time`` is
        dropped, so each record's rows stay in ``Time`` order. ``labels`` is the
        ``In-hospital_death`` column indexed by ``RecordID``.
    """
    series = (
        pd.read_parquet(f"{DATA_DIR}/set-{split}.parquet")
        .drop(columns=["ICUType"])
        # Sort while Time is still present; downstream code relies on row order only.
        .sort_values(by=["RecordID", "Time"])
        .drop(columns=["Time"])
    )
    labels = (
        pd.read_csv(f"{DATA_DIR}/Outcomes-{split}.txt")
        .sort_values(by=["RecordID"])
        .set_index("RecordID")["In-hospital_death"]
    )
    return series, labels


df_a, outcomes_a = load_split("a")
# NOTE(review): set B is not used by the training/evaluation cells below —
# confirm whether it is meant to serve as a validation split.
df_b, outcomes_b = load_split("b")
df_c, outcomes_c = load_split("c")




In [3]:
# Pretrained Chronos checkpoint used to embed each univariate series.
CHRONOS_MODEL = "amazon/chronos-t5-small"
# Use the GPU when one is available; otherwise fall back to CPU so the
# notebook still runs end-to-end on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

pipeline = ChronosPipeline.from_pretrained(
    CHRONOS_MODEL,
    device_map=DEVICE,
    torch_dtype=torch.float32,
)

In [4]:
def create_dataset(df, outcomes, embed_pipeline=None, id_column="RecordID"):
    """Build one fixed-length embedding per record, together with its label.

    Rows are grouped into records by ``id_column``. Within each record, every
    other column is passed to the embedding pipeline as its own univariate
    series. The per-position embeddings are averaged into one vector per
    column, and the column vectors are then averaged into one vector per record.

    Parameters
    ----------
    df : pd.DataFrame
        One row per timestep, containing ``id_column`` plus the feature columns.
        Rows are used in their existing order within each record, so sort them
        beforehand.
    outcomes : pd.Series
        Labels indexed by record id; each is looked up as ``outcomes[record_id]``.
    embed_pipeline : object, optional
        Anything exposing ``embed(context) -> (embeddings, state)``. It defaults
        to the notebook-level ``pipeline``.
    id_column : str, default "RecordID"
        Column used to group rows into records. It is excluded from the
        embedded features.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``(X, y)``, where ``X[i]`` is the averaged embedding of the i-th record
        in ascending id order (``groupby`` sorts its keys by default) and
        ``y[i]`` is that record's label.

    Raises
    ------
    KeyError
        If a record id present in ``df`` is missing from ``outcomes``.
    """
    embedder = pipeline if embed_pipeline is None else embed_pipeline
    # Iterating a groupby keeps the grouping column in each group, so drop it
    # explicitly — otherwise the record id itself is embedded as a "feature".
    feature_columns = [col for col in df.columns if col != id_column]

    record_embeddings = []
    labels = []
    for record_id, group in df.groupby(id_column):
        labels.append(outcomes[record_id])

        column_embeddings = []
        for column in feature_columns:
            # Missing values become NaN. Going through to_numpy also handles
            # pandas nullable dtypes (pd.NA), which torch.tensor cannot ingest.
            values = group[column].to_numpy(dtype=np.float32, na_value=np.nan)
            context = torch.tensor(values, dtype=torch.float32)

            # Inference only: skip building an autograd graph.
            with torch.no_grad():
                embeddings, _ = embedder.embed(context)

            # One embedding per position. Flatten the leading dimensions and
            # average over positions. Unlike squeeze(), this also works when
            # there is only a single position.
            per_position = embeddings.reshape(-1, embeddings.shape[-1])
            column_embeddings.append(per_position.mean(dim=0).detach().cpu().numpy())

        # Average across feature columns to get one vector per record.
        record_embeddings.append(np.mean(column_embeddings, axis=0))

    return np.array(record_embeddings), np.array(labels)

In [None]:
# Set A is embedded for training; set C is the held-out evaluation set.
train_frames = (df_a, outcomes_a)
test_frames = (df_c, outcomes_c)
X_train, y_train = create_dataset(*train_frames)
X_test, y_test = create_dataset(*test_frames)

In [None]:
# Class-balanced logistic regression on the record embeddings.
# fit() returns the estimator itself, so construction and fitting can be chained.
model = LogisticRegression(
    class_weight='balanced',
    max_iter=1000,
    random_state=42,
).fit(X_train, y_train)

# Column 1 is the probability of model.classes_[1] (the positive label,
# assuming 0/1 labels).
probs = model.predict_proba(X_test)[:, 1]
auroc, auprc = roc_auc_score(y_test, probs), average_precision_score(y_test, probs)

print(f"Logistic Regression - AUROC: {auroc:.4f}, AUPRC: {auprc:.4f}")