[![Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/giordamaug/IEEE-JBHI/blob/main/CV_emb_lgbm.ipynb)

## Libraries

In [None]:
import math
import argparse
import pandas as pd
import numpy as np
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import warnings
warnings.simplefilter(action='ignore')
import category_encoders as ce
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import lightgbm as lgb
import seaborn as sns
from sklearn.metrics import (
    matthews_corrcoef, confusion_matrix, accuracy_score, roc_auc_score,
    precision_score, recall_score, f1_score
)
from tqdm.notebook import tqdm

In [None]:
!wget https://raw.githubusercontent.com/ggithub.com/giordamaug/IEEE-JBHI/blob/main/dataset_sintetico_completo.json
!wget https://raw.githubusercontent.com/ggithub.com/giordamaug/IEEE-JBHI/blob/main/targets.csv

## Setting parameters

In [128]:
class Settings:
    patologies = "0,1,2,3,4,5,6,7"
    input_file = "dataset_sintetico_completo.json"
    target_file = "targets.csv"
    splen = "1"
    min_events = 3
    display = False
    to_latex = True
    use_vars = "static"

args = Settings()

## Load the dataset

In [129]:
sequences = {}
with open(args.input_file) as f:
    synt_list = json.load(f)
    for elem in tqdm(synt_list, desc="Extracting events"):
        sequences[elem['id']] = [(elem['events'][i]['event'], elem['events'][i]['date']) for i in range(len(elem['events']))]
    sequences
    Xstatic = pd.read_json(args.input_file).drop(columns=['events'])

    print(f"X_static shape: {Xstatic.shape}")
    print(f"ATTRIBUTES: {sorted(Xstatic.columns)}")

Extracting events:   0%|          | 0/2399 [00:00<?, ?it/s]

X_static shape: (2399, 23)
ATTRIBUTES: ['base_pathology_area', 'bmi', 'days_after', 'days_before', 'dosage_num', 'drug', 'dyslipidemia', 'eventi_infettivi', 'gender', 'genotype_alpha1', 'genotype_alpha2', 'genotype_beta1', 'genotype_beta2', 'hbf', 'heparin', 'id', 'is_splenectomized?', 'primary_pathology', 'smoking', 'splenectomy_indication', 'splenectomy_method', 'splenectomy_response', 'tsh']


## Methods

### LSTM

In [130]:
# Creazione dataset
class TextDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]


class LSTMEmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, pooling=False):
        super().__init__()
        self.pooling = pooling
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=False)

    def forward(self, x):
        embedded = self.embedding(x)
        output, (hn, cn) = self.lstm(embedded)
        if self.pooling:
            return output.mean(dim=1)
        else:
            return hn.squeeze(0)

    def train_model(self, dataloader, num_epochs=10, enable_plot=False, disable=False):

        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        criterion = torch.nn.MSELoss()

        loss_history = []

        self.train()
        pbar = tqdm(range(num_epochs), disable=disable, desc=f"Embedding:")
        for epoch in pbar:
            total_loss = 0
            for batch in dataloader:
                optimizer.zero_grad()
                batch = batch.long()
                output = self(batch)

                target = torch.zeros_like(output)
                loss = criterion(output, target)

                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(dataloader)
            pbar.set_description(f"Embedding: Loss {avg_loss:.4f}")
            loss_history.append(avg_loss)

            if enable_plot:
                clear_output(wait=True)
                plt.plot(loss_history, label="Loss")
                plt.xlabel("Epoch")
                plt.ylabel("Loss")
                plt.title("Training Loss Over Time")
                plt.legend()
                plt.grid()
                plt.show()

def similarity_matrix(events, attributes, targets, emb1, emb2):
    attributes = [a.lower() for a in attributes]
    targets = [t.lower() for t in targets]
    zero_data = np.ones((len(events), len(attributes)))
    X_df = pd.DataFrame(zero_data, columns=attributes, index=events.keys())

    Wmul = emb1.T@emb2
    for id, dcount in tqdm(events.items(), desc="Risk calculating"):
        for concept in dcount.keys():
            concept_l = concept.lower()
            if concept_l in attributes:
                X_df.loc[id, concept_l] += dcount[concept] * np.array([math.exp(Wmul[concept_l][disease]) for disease in targets]).mean()
    return X_df

### Binary matrix

## Selection of patients by patology

In [131]:
# === Generazione eventi sequenze =========
def truncevents(sequences, infection_list, max_inf=1, max_flwup=5, debug=False):
    trunc_sequences = {}
    # truncate event sequence to the k-th occurrence of target
    for id in tqdm(sequences.keys(), desc=f"Truncating to {max_inf}"):
        inf_cnt = 0
        flw_cnt = 0
        new_evset = set()
        for e, d in sequences[id]:
              if e in infection_list:
                if debug: print(f"INF[{id}] {e}")
                new_evset.add((e,d))
                inf_cnt += 1
                if inf_cnt >= max_inf: break
              elif e == "followup" :
                new_evset.add((e,d))
                flw_cnt += 1
                if flw_cnt >= max_flwup: break
              else:
                if debug: print(f"eve[{id}] {e}")
                new_evset.add((e,d))
        trunc_sequences[id] = sorted(list(new_evset), key=lambda x: x[1])
    return trunc_sequences

In [135]:
# Selezione pazienti
def select_patients(df, event_counts, patologies, splen_flags=[0,1], min_ev_count=3):
    filtered_events = {}
    selected_patient_ids = df[
        df['base_pathology_area'].isin(patologies) &
        df['is_splenectomized?'].isin(splen_flags)
    ].index.tolist()

    filtered_events = {
        int(key): value for key, value in event_counts.items()
        if int(key) in selected_patient_ids and len(value.keys()) > min_ev_count
    }

    selected_patient_ids = np.array(list(filtered_events.keys()))
    return selected_patient_ids, filtered_events

# === Costruzione vocabolario ed eveconteggi eventi target ===
event_counts = {}
tot_counts = {}
for k, v in sequences.items():
    lista = list(map(lambda x: x[0], filter(lambda x: x[0] != ['followup'], v)))
    flat_list = list(itertools.chain.from_iterable(lista)) if lista and isinstance(lista[0], list) else lista
    event_counts[k] = dict((x, flat_list.count(x)) for x in set(flat_list))
    tot_counts[int(k)] = len(flat_list)


# === Vocabolario === 
vocab = set()
for patient_events in sequences.values():
    for event,_ in patient_events:
        vocab.update([event] if isinstance(event, str) and event != 'followup' else event)
word_to_idx = {word: idx for idx, word in enumerate(sorted(vocab))}
#word_to_idx["followup"] = 0
idx_to_word = {i: word for word, i in word_to_idx.items()}
print(f"VOCABOLARY SIZE: {len(vocab)}")

# Sequences indicizzate e padding (for embeddings)
indexed_sentences = [[word_to_idx[word] if word in word_to_idx else 0
                      for _, word in patient_events if word in word_to_idx]
                     for patient_events in sequences.values()]
padded_sentences = pad_sequence([torch.tensor(s) for s in indexed_sentences], batch_first=True)

# Target infezioni
#infections_terms = set(json.load(open("translator_infection.json")).values())
targets = pd.read_csv(args.target_file)['targets'].to_list()
embed_attributes = [w for w in sorted(vocab) if w not in targets]

# ===== Truncate sequences to first infection or up tofifth follow-up
sequences = truncevents(sequences, targets)

infected_ids = [id for id,l in sequences.items() if any([x[0].strip().lower() in targets for x in l]) ]
infected_ids_last = [id for id,l in sequences.items() if len(l) > 0 and l[-1][0] in targets]
print(f"1ST TARGET SEQUENCE: {sequences[infected_ids[0]]}")
print(f"# INFECTED: {len(infected_ids)}")
print(f"TARGETS: {targets}")

# === Selezione pazienti ===
patologies = []
for p in args.patologies.split(','):
    if int(p.strip()) in Xstatic['base_pathology_area'].unique():
        patologies.append(int(p.strip()))
    else:
        raise Exception(f"Primary patology {p} not in dataset!")
splen_flags = list(map(int, args.splen.split(',')))
min_ev_count = args.min_events

selected_patient_ids, events = select_patients(Xstatic, event_counts, patologies, splen_flags=splen_flags, min_ev_count=min_ev_count)
print(f"# [{args.patologies}] PATIENTS: {len(selected_patient_ids)}")

VOCABOLARY SIZE: 101


Truncating to 1:   0%|          | 0/2399 [00:00<?, ?it/s]

1ST TARGET SEQUENCE: [('intervention adenoidectomy', '2020-01-01'), ('bacterial infection of the respiratory tract', '2020-01-17')]
# INFECTED: 1030
TARGETS: ['bacterial/viral infection of the heart', 'bacterial/viral infection of the throat', 'parasitic infection of the blood', 'fungal infection of the skin', 'bacterial infection of the gallbladder', 'bacterial infection of the biliary tract', 'autoimmune or inflammatory infection of the blood vessels', 'intestinal parasitic infection', 'viral infection of the skin and mucous membranes', 'bacterial infection of the pleural cavities', 'bacterial/viral infection of the pancreas', 'systemic viral infection', 'other infection', 'bacterial/viral infection of the respiratory tract', 'bacterial/parasitic infection of the blood', 'viral/automimetic infection of the eye', 'zoonotic bacterial infection', 'viral infection of the respiratory tract', 'bacterial/viral infection of the oral mucous membranes', 'bacterial infection of the vertebrae', 

## Cross validation + embedding

In [None]:
# Embedding parameters
num_epochs = 10  # 10
vocab_size = len(vocab)
embedding_dim = 64  # 16
hidden_dim = 128
batch_size = 32

window_size = 2  # for skipgram

# Bilanciamento automatico
lgb_params = {
    'objective': 'binary',
    'metric': 'None',
    'verbosity': -3,
    'is_unbalance': True
}

def mcc_eval(y_pred, dataset):
    y_true = dataset.get_label()
    y_pred_labels = (y_pred > 0.5).astype(int)
    mcc = matthews_corrcoef(y_true, y_pred_labels)
    return 'MCC', mcc, True

mcc_scores = []
acc_scores = []
rocauc_scores = []
prec_scores = []
recall_scores = []
f1_scores = []

y_valid_all = []
y_pred_all = []

from sklearn.model_selection import StratifiedKFold

# remove outliers
if False:
    selected_patient_ids = np.setdiff1d(selected_patient_ids, outliers_idx)
    events = {id: events[id] for id in selected_patient_ids}
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
zero_data = np.zeros((len(events.keys()),))
y_df = pd.DataFrame(zero_data, columns=['target']).set_index(pd.Series([e for e in list(events.keys())]))
cnt = 0
for id, dcount in tqdm(events.items(), desc="Targets"):
    #if disease in [e.lower() for e in dcount.keys()]:  # if there's at least one occurrence of the target set 1
    if len(set([e.lower() for e in dcount.keys()]).intersection(set(targets))) > 0:  # if there's at least one occurrence of the target set 1
        y_df.loc[id, 'target'] = 1
y = y_df.values.astype(np.float32).ravel()
print(y_df.shape)
cvfolding = tqdm(skf.split(selected_patient_ids, y), total=n_splits, desc="Folds")
print(np.unique(y_df.values, return_counts=True), f"count: {cnt}")

use_vars = []
for v in args.use_vars.split(","):
    if v.strip() in ["static", "binary", "lstm"]:
        use_vars += [v.strip()]
    else:
        raise Exception("Wrong methid in paramters")

print("\nüöÄ Inizio cross-validation...\n")
for fold, (t_idx, v_idx) in enumerate(cvfolding):

    train_idx, valid_idx = selected_patient_ids[t_idx], selected_patient_ids[v_idx]
    Xtrains = dict(zip(use_vars, [pd.DataFrame(index=train_idx)]* len(use_vars)))
    Xtests = dict(zip(use_vars, [pd.DataFrame(index=valid_idx)]* len(use_vars)))
    if "lstm" in use_vars:
        train_sentences = [[word_to_idx[word] for word,_ in sequences[id]] for id in train_idx]
        padded_train = pad_sequence([torch.tensor(s) for s in train_sentences], batch_first=True)
        # Dataset e dataloader
        embedding_dataset = TextDataset(padded_train)
        dataloader = DataLoader(embedding_dataset, batch_size=32, shuffle=True)

        # Inizializza e allena modello LSTM unidirezionale
        embmodel = LSTMEmbeddingModel(len(word_to_idx), embed_dim=64, hidden_dim=128, pooling=False)
        embmodel.train_model(dataloader, enable_plot=False, disable=False)
        word_indices = [idx for word, idx in word_to_idx.items() if word != "<PAD>"]
        word_tensors = torch.tensor(word_indices).unsqueeze(1)  # Shape (num_words, 1)
        embedding = embmodel(word_tensors).detach().numpy()
        W = pd.DataFrame(embedding.T, columns=[w.lower() for w in vocab])
        X_df = similarity_matrix(events, embed_attributes, list(targets), W, W)
        Xtrains['lstm'] = X_df.loc[train_idx]
        Xtests['lstm'] = X_df.loc[valid_idx]
    if "dome" in use_vars:
        raise Exception("DOME not included in demo...")
    if "static" in use_vars:
        Xtrains['static'] = Xstatic.loc[train_idx]
        Xtests['static'] = Xstatic.loc[valid_idx]
    if "binary" in use_vars:
        raise Exception("Binary not included in demo...")
        #Xtrains['binary'] = Xbin[bincolumns].loc[train_idx]
        #Xtests['binary'] = Xbin[bincolumns].loc[valid_idx]
    if "dummy" in use_vars:
        Xtrains['dummy'] = pd.DataFrame(np.random.rand(len(train_idx), 128), index=train_idx)
        Xtests['dummy'] = pd.DataFrame(np.random.rand(len(valid_idx), 128), index=valid_idx)
    for v in use_vars:
        print(f"X_{v} shape {Xtrains[v].shape} {Xtests[v].shape}", end=' ')
    X_df_train = pd.concat(list(Xtrains.values()), axis=1)
    X_df_tests = pd.concat(list(Xtests.values()), axis=1)
    X_train, X_valid = X_df_train.values.astype(np.float32), X_df_tests.values.astype(np.float32)
    y_train, y_valid = y_df.loc[train_idx].values.astype(np.float32).ravel(), y_df.loc[valid_idx].values.astype(np.float32).ravel()
    train_data = lgb.Dataset(X_train, label=y_train)
    valid_data = lgb.Dataset(X_valid, label=y_valid)

    model = lgb.train(
        lgb_params,
        train_data,
        num_boost_round=1000,
        valid_sets=[valid_data],
        feval=mcc_eval,
        callbacks=[
            lgb.early_stopping(50),
            lgb.log_evaluation(0)
        ]
    )

    y_pred = model.predict(X_valid)
    y_pred_labels = (y_pred > 0.5).astype(int)

    # Metriche fold
    mcc = matthews_corrcoef(y_valid, y_pred_labels)
    acc = accuracy_score(y_valid, y_pred_labels)
    rocauc = roc_auc_score(y_valid, y_pred)
    prec = precision_score(y_valid, y_pred_labels)
    recall = recall_score(y_valid, y_pred_labels)
    f1 = f1_score(y_valid, y_pred_labels)

    # Salva
    mcc_scores.append(mcc)
    acc_scores.append(acc)
    rocauc_scores.append(rocauc)
    prec_scores.append(prec)
    recall_scores.append(recall)
    f1_scores.append(f1)
    y_valid_all.extend(y_valid)
    y_pred_all.extend(y_pred_labels)

    # üß† Aggiorna tqdm
    cvfolding.set_postfix({
        "Fold": fold + 1,
        "MCC": f"{mcc:.4f}",
        "AUC": f"{rocauc:.4f}",
        "Acc": f"{acc:.4f}",
        "F1": f"{f1:.4f}"
    })

# Final confusion matrix
cm_final = confusion_matrix(y_valid_all, y_pred_all)

print(f"\nüìä Risultati medi su {n_splits} fold:")
print(f"üìà AUC:      {np.mean(rocauc_scores):.4f} ¬± {np.std(rocauc_scores):.4f}")
print(f"üß™ F1-score: {np.mean(f1_scores):.4f} ¬± {np.std(f1_scores):.4f}")
print(f"‚öñÔ∏è Precision:{np.mean(prec_scores):.4f} ¬± {np.std(prec_scores):.4f}")
print(f"üîÅ Recall:   {np.mean(recall_scores):.4f} ¬± {np.std(recall_scores):.4f}")
print(f"üßÆ MCC:      {np.mean(mcc_scores):.4f} ¬± {np.std(mcc_scores):.4f}")
print(f"üéØ Accuracy: {np.mean(acc_scores):.4f} ¬± {np.std(acc_scores):.4f}")

print(f"\nüß© Confusion Matrix finale (aggregata):\n{cm_final}")

if args.display:
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm_final, annot=True, fmt='d', cmap='Blues', cbar=False,
                xticklabels=['Pred 0', 'Pred 1'],
                yticklabels=['True 0', 'True 1'])
    plt.title('Final Confusion Matrix (all folds)')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.tight_layout()
    plt.show()

if args.to_latex:
    outstr = '||'.join(['\\mathbf{X}^{\\text{'+attrtype+'}}' for attrtype in use_vars])
    #print(f"{'+'.join(args.patologies.split(','))}({'+'.join(['S' if s == '1' else 'N' if s == '0' else 'X' for s in args.splen.split(',')])}) & {len(selected_patient_ids)} ", end='')
    print(f"{'+'.join(args.patologies.split(','))} & {len(selected_patient_ids)} ", end='')
    print(f"& ${outstr}$ & ", end='')
    print(f"${np.mean(rocauc_scores):.4f}\\pm{np.std(rocauc_scores):.4f}$ & ", end='')
    print(f"${np.mean(f1_scores):.4f}\\pm{np.std(f1_scores):.4f}$ & ", end='')
    print(f"${np.mean(prec_scores):.4f}\\pm{np.std(prec_scores):.4f}$ & ", end='')
    print(f"${np.mean(recall_scores):.4f}\\pm{np.std(recall_scores):.4f}$ & ", end='')
    print(f"${np.mean(mcc_scores):.4f}\\pm{np.std(mcc_scores):.4f}$ & ", end='')
    #print(f"${np.mean(acc_scores):.4f}\\pm{np.std(acc_scores):.4f}$ & ", end='')
    print(f"{cm_final}".replace('\n', ''), end='')
    print(f"\\\\ \n", end='')

Targets:   0%|          | 0/1188 [00:00<?, ?it/s]

(1188, 1)


Folds:   0%|          | 0/5 [00:00<?, ?it/s]

(array([0., 1.]), array([1014,  174])) count: 0

üöÄ Inizio cross-validation...

X_static shape (950, 23) (238, 23) Training until validation scores don't improve for 50 rounds
Early stopping, best iteration is:
[10]	valid_0's MCC: 0.0527838
X_static shape (950, 23) (238, 23) Training until validation scores don't improve for 50 rounds
Early stopping, best iteration is:
[12]	valid_0's MCC: 0.0516939
X_static shape (950, 23) (238, 23) Training until validation scores don't improve for 50 rounds
Early stopping, best iteration is:
[1]	valid_0's MCC: 0
X_static shape (951, 23) (237, 23) Training until validation scores don't improve for 50 rounds
Early stopping, best iteration is:
[101]	valid_0's MCC: 0.101251
X_static shape (951, 23) (237, 23) Training until validation scores don't improve for 50 rounds
Early stopping, best iteration is:
[63]	valid_0's MCC: 0.139568

üìä Risultati medi su 5 fold:
üìà AUC:      0.5319 ¬± 0.0393
üß™ F1-score: 0.1847 ¬± 0.1060
‚öñÔ∏è Precision:0.1629 ¬± 