In [1]:
from transformers import ElectraConfig, ElectraForSequenceClassification, ElectraTokenizerFast
from transformers import RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer
from transformers import LongformerConfig, LongformerForSequenceClassification, LongformerTokenizer, LongformerSelfAttention
import torch
import datasets
import numpy as np
import pandas as pd

# Model loading

### Test Bio-lm

In [2]:
model_path = "models/RoBERTa-base-PM-M3-Voc-distill-align-hf"
biolm_config = RobertaConfig.from_pretrained(model_path)
biolm_config.num_labels = 10
biolm_config.num_multi_labels = [3, 4, 3]

biolm_model = RobertaForSequenceClassification.from_pretrained(
    model_path,
    config=biolm_config
)
biolm_tokenizer = RobertaTokenizer.from_pretrained(model_path)

x = biolm_tokenizer("The quick brown fox jumps over the lazy dog", return_tensors='pt')
out = biolm_model(**x)
out

Some weights of the model checkpoint at models/RoBERTa-base-PM-M3-Voc-distill-align-hf were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'lm_head.decoder.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at models/RoBERTa-base-

SequenceClassifierOutput(loss=None, logits=tensor([[ 0.4522,  0.0965, -0.4064,  0.7278,  0.2293,  0.1794,  0.3154, -0.0503,
         -0.1672,  0.1967]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)

In [3]:
biolm_model

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50008, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerN

### Test BioElectra

In [4]:
bioelectra_config = ElectraConfig.from_pretrained("models/bioelectra-base-discriminator-pubmed")
bioelectra_config.num_labels = 10
bioelectra_config.num_multi_labels = [3, 4, 3]

bioelectra_model = ElectraForSequenceClassification.from_pretrained(
    "models/bioelectra-base-discriminator-pubmed", 
    config=bioelectra_config
)
bioelectra_tokenizer = ElectraTokenizerFast.from_pretrained("models/bioelectra-base-discriminator-pubmed")

sentence = "The quick brown fox jumps over the lazy dog"
x = bioelectra_tokenizer(sentence, return_tensors="pt")
out = bioelectra_model(**x)
out

Some weights of the model checkpoint at models/bioelectra-base-discriminator-pubmed were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at models/bioelectra-base-discriminator-pubmed and are newly initial

SequenceClassifierOutput(loss=None, logits=tensor([[ 0.0275,  0.0487,  0.0708,  0.0400,  0.0409,  0.0927, -0.1049,  0.1389,
         -0.0044, -0.0192]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)

In [5]:
bioelectra_model

ElectraForSequenceClassification(
  (electra): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ElectraEncoder(
      (layer): ModuleList(
        (0): ElectraLayer(
          (attention): ElectraAttention(
            (self): ElectraSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ElectraSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm

# Test multilabel

In [6]:
from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset

In [7]:
class BioDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

encodings = bioelectra_tokenizer([sentence, sentence], return_tensors="pt")
labels = [[0, 1, 2],[0, 2, 1]]

train_dataset = BioDataset(encodings, labels)
val_dataset = BioDataset(encodings, labels)
test_dataset = BioDataset(encodings, labels)

In [8]:
class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fct = torch.nn.CrossEntropyLoss()

        rem_logits = logits
        loss = 0
        if 'num_multi_labels' in bioelectra_model.config.__dict__:
            for i, num_label in enumerate(self.model.config.num_multi_labels):
                class_logits = rem_logits[:,:num_label]
                rem_logits = rem_logits[:,num_label:]
                
                class_loss = loss_fct(class_logits, labels[:,i])
                loss += class_loss
        else:
            loss = loss_fct(logits, labels)

        return (loss, outputs) if return_outputs else loss

In [9]:
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10
)

In [10]:
## Define Metrics
acc_metric = datasets.load_metric("accuracy")
f1_metric = datasets.load_metric("f1","")
num_multi_labels = biolm_model.config.num_multi_labels
tags = {0: 'diag', 1: 'readmission', 2: 'mortality'}

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    rem_logits = logits
    metrics = {}
    for i, num_label in enumerate(num_multi_labels):
        class_logits = rem_logits[:,:num_label]
        rem_logits = rem_logits[:,num_label:]

        class_pred =  np.argmax(class_logits, axis=-1)
        acc = acc_metric.compute(predictions=class_pred, references=labels[:,i])
        f1 = f1_metric.compute(predictions=class_pred, references=labels[:,i], average="macro")

        metrics[f'{tags[i]}_acc'] = acc
        metrics[f'{tags[i]}_f1'] = f1
    return metrics

In [11]:
# Run Bio-lm
trainer = MultilabelTrainer(
    model=biolm_model,          # the instantiated 🤗 Transformers model to be trained
    args=training_args,              # training arguments, defined above
    train_dataset=train_dataset,     # training dataset
    eval_dataset=val_dataset,        # evaluation dataset
    tokenizer=biolm_tokenizer,
    compute_metrics=compute_metrics
)
trainer.train()
trainer.evaluate()

***** Running training *****
  Num examples = 2
  Num Epochs = 3
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 3


Step,Training Loss




Training completed. Do not forget to share your model on huggingface.co/models =)


***** Running Evaluation *****
  Num examples = 2
  Batch size = 64


Trainer is attempting to log a value of "{'accuracy': 1.0}" of type <class 'dict'> for key "eval/diag_acc" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'f1': 1.0}" of type <class 'dict'> for key "eval/diag_f1" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'accuracy': 0.0}" of type <class 'dict'> for key "eval/readmission_acc" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'f1': 0.0}" of type <class 'dict'> for key "eval/readmission_f1" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'accuracy': 0.5}" of type <class 'dict'> for key "eval/mortality_acc" as a scalar. This invoca

{'eval_loss': 3.4143261909484863,
 'eval_diag_acc': {'accuracy': 1.0},
 'eval_diag_f1': {'f1': 1.0},
 'eval_readmission_acc': {'accuracy': 0.0},
 'eval_readmission_f1': {'f1': 0.0},
 'eval_mortality_acc': {'accuracy': 0.5},
 'eval_mortality_f1': {'f1': 0.3333333333333333},
 'eval_runtime': 0.1554,
 'eval_samples_per_second': 12.867,
 'eval_steps_per_second': 6.433,
 'epoch': 3.0}

In [12]:
## Define Metrics
acc_metric = datasets.load_metric("accuracy")
f1_metric = datasets.load_metric("f1")
num_multi_labels = bioelectra_model.config.num_multi_labels
tags = {0: 'diag', 1: 'readmission', 2: 'mortality'}

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    rem_logits = logits
    metrics = {}
    for i, num_label in enumerate(num_multi_labels):
        class_logits = rem_logits[:,:num_label]
        rem_logits = rem_logits[:,num_label:]

        class_pred =  np.argmax(class_logits, axis=-1)
        acc = acc_metric.compute(predictions=class_pred, references=labels[:,i])
        f1 = f1_metric.compute(predictions=class_pred, references=labels[:,i], average="macro")

        metrics[f'{tags[i]}_acc'] = acc
        metrics[f'{tags[i]}_f1'] = f1
    return metrics

In [13]:
# Run BioElectra
trainer = MultilabelTrainer(
    model=bioelectra_model,          # the instantiated 🤗 Transformers model to be trained
    args=training_args,              # training arguments, defined above
    train_dataset=train_dataset,     # training dataset
    eval_dataset=val_dataset,        # evaluation dataset
    tokenizer=bioelectra_tokenizer,
    compute_metrics=compute_metrics
)
trainer.train()
trainer.evaluate()

***** Running training *****
  Num examples = 2
  Num Epochs = 3
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 3


Step,Training Loss




Training completed. Do not forget to share your model on huggingface.co/models =)


***** Running Evaluation *****
  Num examples = 2
  Batch size = 64


Trainer is attempting to log a value of "{'accuracy': 0.0}" of type <class 'dict'> for key "eval/diag_acc" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'f1': 0.0}" of type <class 'dict'> for key "eval/diag_f1" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'accuracy': 0.5}" of type <class 'dict'> for key "eval/readmission_acc" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'f1': 0.3333333333333333}" of type <class 'dict'> for key "eval/readmission_f1" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'accuracy': 0.0}" of type <class 'dict'> for key "eval/mortality_acc" as a scal

{'eval_loss': 3.609358310699463,
 'eval_diag_acc': {'accuracy': 0.0},
 'eval_diag_f1': {'f1': 0.0},
 'eval_readmission_acc': {'accuracy': 0.5},
 'eval_readmission_f1': {'f1': 0.3333333333333333},
 'eval_mortality_acc': {'accuracy': 0.0},
 'eval_mortality_f1': {'f1': 0.0},
 'eval_runtime': 0.1999,
 'eval_samples_per_second': 10.004,
 'eval_steps_per_second': 5.002,
 'epoch': 3.0}

# Test Longformer for Roberta and Electra

In [None]:
def convert_roberta_like_to_longformer(state_dict, model_name):
    orig_keys = [key for key in state_dict]
    for key in orig_keys:
        if model_name in key:
            new_key = key.replace(model_name,'longformer')
            state_dict[new_key] = state_dict[key]
            if 'query.' in new_key:
                state_dict[new_key.replace('.query.','.query_global.')] = state_dict[key]
            if 'key.' in new_key:
                state_dict[new_key.replace('.key.','.key_global.')] = state_dict[key]
            if 'value.' in new_key:
                state_dict[new_key.replace('.value.','.value_global.')] = state_dict[key]

            if '.position_embeddings' in new_key:
                state_dict[new_key] = state_dict[new_key].repeat([8,1])

            if '.position_ids' in new_key:
                    state_dict[new_key] = torch.arange(state_dict[key].shape[1] * 8).view(1, -1)
            del state_dict[key]
    return state_dict

### Load Long Bio-lm

In [14]:
long_biolm_config = RobertaConfig.from_pretrained("models/RoBERTa-base-PM-M3-Voc-distill-align-hf")
long_biolm_config.num_labels = 10
long_biolm_config.num_multi_labels = [3, 4, 3]
long_biolm_config.attention_window = [512] * 12
long_biolm_config.max_position_embeddings = long_biolm_config.max_position_embeddings * 8
long_biolm_model = LongformerForSequenceClassification(config=long_biolm_config)

long_biolm_state_dict = convert_roberta_like_to_longformer(biolm_model.state_dict(), 'roberta')
long_biolm_model.load_state_dict(long_biolm_state_dict, strict=True)

loading configuration file models/RoBERTa-base-PM-M3-Voc-distill-align-hf/config.json
Model config RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.12.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50008
}



<All keys matched successfully>

### Load Long BioElectra

In [15]:
long_bioelectra_config = ElectraConfig.from_pretrained("models/bioelectra-base-discriminator-pubmed")
long_bioelectra_config.num_labels = 10
long_bioelectra_config.num_multi_labels = [3, 4, 3]
long_bioelectra_config.num_multi_labels = [3, 4, 3]
long_bioelectra_config.attention_window = [512] * 12
long_bioelectra_config.max_position_embeddings = long_bioelectra_config.max_position_embeddings * 8
long_bioelectra_model = LongformerForSequenceClassification(config=long_bioelectra_config)

long_bioelectra_state_dict = convert_roberta_like_to_longformer(bioelectra_model.state_dict(), 'electra')
long_bioelectra_model.load_state_dict(long_bioelectra_state_dict, strict=True)

loading configuration file models/bioelectra-base-discriminator-pubmed/config.json
Model config ElectraConfig {
  "architectures": [
    "ElectraForPreTraining"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "embedding_size": 768,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "electra",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "summary_activation": "gelu",
  "summary_last_dropout": 0.1,
  "summary_type": "first",
  "summary_use_proj": true,
  "transformers_version": "4.12.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}



<All keys matched successfully>

# Test Forward with multiple text

In [48]:
from itertools import chain
from torch.utils.data import Dataset

In [70]:
labels = [0, 1, 2]
features = [0,1,1,1,0,0]
patient_notes = [
    "The quick brown fox jumps over the lazy dog",
    "round the rugged rocks the ragged rascal ran",
    "peter piper pickled pepper picker",
    "Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod",
    "tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam,",
    "quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo",
    "consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse",
    "cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non",
    "proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
] * 10

notes_data = biolm_tokenizer(patient_notes)
# notes_data['labels'] = [labels]
# notes_data['features'] = features
notes_data['input_ids'] = [list(chain.from_iterable(notes_data['input_ids']))]
notes_data['attention_mask'] = [list(chain.from_iterable(notes_data['attention_mask']))]
if 'token_type_ids' in notes_data:
    notes_data['token_type_ids'] = [list(chain.from_iterable(notes_data['token_type_ids']))]

notes_data = biolm_tokenizer.pad(notes_data, pad_to_multiple_of=512, return_tensors='pt')
out = long_biolm_model(**notes_data)
out

Initializing global attention on CLS token...


LongformerSequenceClassifierOutput(loss=None, logits=tensor([[-0.0098, -0.1054, -0.0342, -0.1817,  0.2197, -0.1663, -0.0367,  0.0484,
          0.1922, -0.2257]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None, global_attentions=None)

In [71]:
labels = [0, 1, 2]
features = [0,1,1,1,0,0]
patient_notes = [
    "The quick brown fox jumps over the lazy dog",
    "round the rugged rocks the ragged rascal ran",
    "peter piper pickled pepper picker",
    "Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod",
    "tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam,",
    "quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo",
    "consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse",
    "cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non",
    "proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
] * 10

notes_data = bioelectra_tokenizer(patient_notes)
# notes_data['labels'] = [labels]
# notes_data['features'] = features
notes_data['input_ids'] = [list(chain.from_iterable(notes_data['input_ids']))]
notes_data['attention_mask'] = [list(chain.from_iterable(notes_data['attention_mask']))]
if 'token_type_ids' in notes_data:
    notes_data['token_type_ids'] = [list(chain.from_iterable(notes_data['token_type_ids']))]

notes_data = bioelectra_tokenizer.pad(notes_data, pad_to_multiple_of=512, return_tensors='pt')
out = long_bioelectra_model(**notes_data)
out

Initializing global attention on CLS token...


LongformerSequenceClassifierOutput(loss=None, logits=tensor([[ 0.2746, -0.3405, -0.0646,  0.0806,  0.1690, -0.0237, -0.1101, -0.0776,
         -0.1227, -0.0241]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None, global_attentions=None)

# Wrap everything into Dataset

In [None]:
###
# EHR Dataset
###
class EHRDataset(Dataset):
    @staticmethod
    def group_readmission(relative_readmission):
        relative_readmission[relative_readmission <= 30] = 0
        relative_readmission[relative_readmission <= 60] = 1
        relative_readmission[relative_readmission <= 90] = 2
        relative_readmission[relative_readmission <= 120] = 3
        relative_readmission[relative_readmission <= 150] = 4
        relative_readmission[relative_readmission <= 180] = 5
        relative_readmission[relative_readmission > 180] = 6
        return relative_readmission
        
    def __init__(self, patient_to_path_dict, tokenizer, eval_patient_to_path_dict=None):
        self.patient_to_path_dict = patient_to_path_dict
        self.patient_id_list = list(self.patient_to_path_dict.keys())
        self.tokenizer = tokenizer
        
        self.eval_patient_to_path_dict = None
        self.eval_patient_id_list = []
        if eval_patient_to_path_dict is not None:
            self.eval_patient_to_path_dict = eval_patient_to_path_dict
            self.eval_patient_id_list = list(self.eval_patient_to_path_dict.keys())
            
        
    def __getitem__(self, index):
        if index < len(self.patient_id_list):
            # Load file
            patient_id = self.patient_id_list[index]
            data_path, note_path = self.patient_to_path_dict[patient_id]
            patient_df = pd.read_pickle(data_path).dropna()
            notes_df = pd.read_pickle(note_path).dropna()
            
            # Extract texts, features, & labels            
            next_diags = patient_df['next_diagnosis'].values
            next_mortals = patient_df['next_mortality'].values
            next_rel_readmis = self.group_readmission(patient_df['next_relative_readmission'].values)
            labels = [next_diags, next_rel_readmis, next_mortals]

            last_readmission = patient_df['readmission'].values[-1]
            features = np.stack(patient_df['features'].values)[:,:-1]
            texts = notes_df.loc[notes_df['timestamp'] < patient_df.iloc[-1]['timestamp'],'texts'].values
        else:
            # Load file
            patient_id = self.eval_patient_id_list[index - len(self.patient_id_list)]
            data_path = self.eval_patient_to_path_dict[patient_id]
            patient_df = pd.read_pickle(data_path).dropna()
               
            # Extract texts, features & labels                  
            next_diags = patient_df['next_diagnosis'].values[:-1],
            next_mortals = patient_df['next_mortality'].values[:-1]
            next_rel_readmis = self.group_readmission(patient_df['next_relative_readmission'].values[:-1])
            labels = [next_diags, next_rel_readmis, next_mortals]
            
            last_readmission = patient_df['readmission'].values[-2]
            features = np.stack(patient_df['features'].values)[:-1,:-1]
            texts = patient_df.loc[patient_df['readmission'] < last_readmission, 'texts'].values            
            
    # Process texts
    patient_data = self.tokenizer(texts)
    patient_data['labels'] = labels
    patient_data['features'] = features
    patient_data['input_ids'] = list(chain.from_iterable(patient_data['input_ids']))
    patient_data['attention_mask'] = list(chain.from_iterable(patient_data['attention_mask']))            
    if 'token_type_ids' in notes_data:
        patient_data['token_type_ids'] = [list(chain.from_iterable(patient_data['token_type_ids']))]

    return patient_data
        
    def __len__(self):
        return len(self.patient_id_list) + len(self.eval_patient_id_list)