In [None]:
# !pip install --upgrade gensim
# !pip install transformers
# !pip install -U sentence-transformers
# !pip install kornia
# # # !pip install "torch==1.7.0"
# !pip install flair
# !pip install captum

In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, hamming_loss, confusion_matrix
from transformers import AutoTokenizer
from transformers import BertForSequenceClassification, AdamW, BertConfig
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.data.sampler import SubsetRandomSampler
import transformers
from transformers import BertTokenizer, BertModel, AdamW
from transformers import get_linear_schedule_with_warmup
import time
# from kornia.losses import BinaryFocalLossWithLogits

# from captum.attr import LayerIntegratedGradients
# from captum.attr import visualization as viz

# import flair
# from flair.data import Sentence
# from flair.embeddings import (
#     TransformerDocumentEmbeddings, 
#     SentenceTransformerDocumentEmbeddings, 
#     FlairEmbeddings, 
#     StackedEmbeddings, 
#     CharacterEmbeddings, 
#     DocumentPoolEmbeddings,
#     WordEmbeddings,
#     TransformerWordEmbeddings)

import warnings
import traceback

import logging
# logger = logging.getLogger('flair')
# logger.setLevel(logging.ERROR)

In [None]:
!nvidia-smi

Thu Jun 17 15:35:46 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.04   Driver Version: 450.119.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 2080    Off  | 00000000:03:00.0 Off |                  N/A |
| 26%   61C    P2   101W / 260W |   5274MiB /  7982MiB |     45%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 2080    Off  | 00000000:05:00.0 Off |                  N/A |
|  0%   43C    P8    27W / 260W |      3MiB /  7982MiB |      0%      Default |
|       

In [None]:
# Select the compute device.
# The notebook prefers the second GPU (index 1); the nvidia-smi output recorded
# above shows GPU 0 already in use — presumably the reason, confirm if it matters.
PREFERRED_GPU_INDEX = 1

if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()
    # Clamp the preferred index so a single-GPU host falls back to GPU 0
    # instead of failing on a non-existent "cuda:1".
    gpu_index = min(PREFERRED_GPU_INDEX, n_gpus - 1)

    # Tell PyTorch to use the GPU.
    device = torch.device(f"cuda:{gpu_index}")

    print('There are %d GPU(s) available.' % n_gpus)

    print('We will use the GPU:', torch.cuda.get_device_name(gpu_index))

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

# cuDNN is switched off globally for the whole notebook.
torch.backends.cudnn.enabled=False

There are 2 GPU(s) available.
We will use the GPU: GeForce RTX 2080


## Preparing Dataset

### Load dataset

In [None]:
# Support-strategy catalogue: numeric code -> Spanish label.
# The codes are not contiguous (2, 5 and 6 do not appear).
categories_number_words = dict([
    (1, "Apoyo Pedagógico en asignaturas"),
    (3, "Apoyo pedagógico personal"),
    (4, "Tutoría entre pares"),
    (7, "Hacer a la familia partícipe del proceso"),
    (8, "Apoyo psicóloga(o)"),
    (9, "Apoyo fonoaudióloga(o)"),
    (10, "Apoyo Educador(a) Diferencial"),
    (11, "Apoyo Kinesióloga(o)"),
    (12, "Apoyo Médico General"),
    (13, "Apoyo Terapeuta Ocupacional"),
    (14, "Control Neurólogo"),
    (15, "Apoyo Interdisciplinario"),
    (16, "Adecuación curricular de acceso"),
    (17, "Adecuación curricular de objetivos"),
])
# Reverse lookup: Spanish label -> numeric code.
categories_words_number = {
    label: number for number, label in categories_number_words.items()
}

# Diagnosis name -> integer code. The names are listed in code order, so the
# position of a name in the sequence below is its code (0..11).
diagnoses_codes = {
    diagnosis: position
    for position, diagnosis in enumerate([
        "Trastorno específico del lenguaje",
        "Trastorno por déficit atencional",
        "Dificultad específica de aprendizaje",
        "Discapacidad intelectual",
        "Discapacidad visual",
        "Trastorno del espectro autista",
        "Discapacidad auditiva - Hipoacusia",
        "Funcionamiento intelectual limítrofe",
        "Síndrome de Down",
        "Trastorno motor",
        "Multidéficit",
        "Retraso global del desarrollo",
    ])
}

# Diagnosis names in insertion (= code) order; used as one-hot column names.
diagnoses_keys = [diagnosis for diagnosis in diagnoses_codes]

def transform_diag_to_array(code):
    """Return a one-hot integer vector over ``diagnoses_keys`` for ``code``.

    Position ``k`` is 1 when the diagnosis named ``diagnoses_keys[k]`` maps to
    ``code`` in ``diagnoses_codes`` and 0 otherwise. A code that matches no
    diagnosis yields an all-zero vector.
    """
    return np.array(
        [1 if diagnoses_codes[name] == code else 0 for name in diagnoses_keys],
        dtype=int,
    )

In [None]:
# Read the pre-split train / validation / test CSVs.
# keep_default_na=False makes pandas keep empty cells as empty strings instead
# of converting them to NaN.
# (Commented-out variants of these reads pointed at 'gdrive/My Drive/magister/'.)
train_dataset = pd.read_csv('/research/jamunoz/datasets/train_ds.csv', keep_default_na=False)
val_dataset = pd.read_csv('/research/jamunoz/datasets/val_ds.csv', keep_default_na=False)
test_dataset = pd.read_csv('/research/jamunoz/datasets/test_ds.csv', keep_default_na=False)

# Append one one-hot column per diagnosis (named after diagnoses_keys) to every
# split, derived from its 'Encoded Diagnosis' column.
train_OHE_diags = [transform_diag_to_array(code) for code in train_dataset['Encoded Diagnosis']]
temp_train_diags_df = pd.DataFrame(train_OHE_diags, columns=diagnoses_keys)
train_dataset = pd.concat([train_dataset, temp_train_diags_df], axis=1)

val_OHE_diags = [transform_diag_to_array(code) for code in val_dataset['Encoded Diagnosis']]
temp_val_diags_df = pd.DataFrame(val_OHE_diags, columns=diagnoses_keys)
val_dataset = pd.concat([val_dataset, temp_val_diags_df], axis=1)

test_OHE_diags = [transform_diag_to_array(code) for code in test_dataset['Encoded Diagnosis']]
temp_test_diags_df = pd.DataFrame(test_OHE_diags, columns=diagnoses_keys)
test_dataset = pd.concat([test_dataset, temp_test_diags_df], axis=1)

In [None]:
# Target columns: one per support strategy, named by its Spanish label.
Y_KEYS = [strategy_name for strategy_name in categories_words_number]

# Split every dataset into features (all non-target columns) and targets.
X_train, Y_train = train_dataset.drop(Y_KEYS, axis=1), train_dataset[Y_KEYS]
X_val, Y_val = val_dataset.drop(Y_KEYS, axis=1), val_dataset[Y_KEYS]
X_test, Y_test = test_dataset.drop(Y_KEYS, axis=1), test_dataset[Y_KEYS]

# Hard-coded count per strategy.
# NOTE(review): presumably the number of positive rows over all three splits — confirm.
strats_amounts = {
    'Adecuación curricular de acceso': 2264,
    'Hacer a la familia partícipe del proceso': 2048,
    'Apoyo Interdisciplinario': 1441,
    'Apoyo Educador(a) Diferencial': 1311,
    'Apoyo pedagógico personal': 1240,
    'Apoyo fonoaudióloga(o)': 378,
    'Apoyo psicóloga(o)': 588,
    'Apoyo Terapeuta Ocupacional': 153,
    'Tutoría entre pares': 350,
    'Control Neurólogo': 63,
    'Apoyo Médico General': 64,
    'Apoyo Kinesióloga(o)': 32,
    'Adecuación curricular de objetivos': 281,
    'Apoyo Pedagógico en asignaturas': 1314,
}

# A strategy is "most unbalanced" when its count is below 15% or above 85% of
# the total number of rows across the three splits.
most_unbalanced_strategies = []
for strategy in Y_KEYS:
    total_rows = len(X_train) + len(X_val) + len(X_test)
    too_rare = strats_amounts[strategy] < total_rows*0.15
    if too_rare or strats_amounts[strategy] > total_rows*0.85:
        most_unbalanced_strategies.append(strategy)

less_unbalanced_strategies = [
    strategy for strategy in Y_KEYS if strategy not in most_unbalanced_strategies
]

# Single-strategy subset: just the first target column.
only_one_strat = [Y_KEYS[0]]

### Dataset

In [None]:
# Module-level tokenizer loaded from the `dccuchile/bert-base-spanish-wwm-uncased`
# checkpoint; AllJoinedObservationsDataset.__getitem__ reads it as a global.
tokenizer = BertTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-uncased")

In [None]:
class AllJoinedObservationsDataset(Dataset):
    """Dataset that tokenises the free-text perception columns of one row.

    Each item is a dict with one padded token-id tensor per text column plus
    two int tensors: ``labels`` (columns ``Y_KEYS``) and ``diagnostics``
    (columns ``diagnoses_keys``). The module-level ``tokenizer`` is used.

    Args:
        data: a pandas DataFrame holding the text, label and one-hot
            diagnosis columns (rows are read with ``.iloc``).
    """

    # (output key, source column) pairs, in the order the keys appear in the
    # returned sample dict.
    _TEXT_FIELDS = (
        ('all_tokens', 'All perceptions'),
        ('sne_tokens', 'Special Education Teacher Perceptions'),
        ('st_tokens', 'Speech Therapist Perceptions'),
        ('p_tokens', 'Psychologist Perceptions'),
        ('m_tokens', 'Medical Perceptions'),
    )

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _encode(text):
        """Tokenise ``text`` into a ``return_tensors="pt"`` id tensor.

        The text is truncated / padded to ``tokenizer.model_max_length`` and
        no special tokens are added (``add_special_tokens=False``).
        """
        return tokenizer.encode(text,
                                add_special_tokens=False,
                                max_length=tokenizer.model_max_length,
                                padding='max_length',
                                truncation=True,
                                return_tensors="pt")

    @staticmethod
    def _int_tensor(values):
        """Convert a row slice to a ``torch.int`` tensor.

        The values go through an explicit NumPy integer array first, so the
        result does not depend on how torch iterates a label-indexed
        (object-dtype) pandas Series.
        """
        return torch.tensor(np.asarray(values, dtype=np.int64), dtype=torch.int)

    def __getitem__(self, idx):
        """Return the sample dict for the row at position ``idx``."""
        data_row = self.data.iloc[idx]

        sample = {
            key: self._encode(data_row[column])
            for key, column in self._TEXT_FIELDS
        }
        sample['labels'] = self._int_tensor(data_row[Y_KEYS])
        sample['diagnostics'] = self._int_tensor(data_row[diagnoses_keys])
        return sample

### Dataloaders

In [None]:
# Samples per DataLoader batch; the same value is passed to the train,
# validation and test loaders created in this cell.
TRAIN_BATCH_SIZE=1

def my_collate1(batches):
  """Collate function that keeps samples separate instead of stacking them.

  Returns a new list holding a shallow copy of every sample dict in
  ``batches``; the values themselves are not copied.
  """
  return [{key: value for key, value in sample.items()} for sample in batches]

# Wrap each split in the tokenising dataset.
transformed_train_dataset = AllJoinedObservationsDataset(train_dataset)
transformed_val_dataset = AllJoinedObservationsDataset(val_dataset)
transformed_test_dataset = AllJoinedObservationsDataset(test_dataset)

# One loader per split; my_collate1 delivers each batch as a list of sample dicts.
# NOTE(review): no shuffle/sampler argument is passed, so the training loader
# also iterates in the DataLoader's default order — confirm this is intended.
train_data_loader = DataLoader(transformed_train_dataset, batch_size=TRAIN_BATCH_SIZE, collate_fn=my_collate1)
val_data_loader = DataLoader(transformed_val_dataset, batch_size=TRAIN_BATCH_SIZE, collate_fn=my_collate1)
test_data_loader = DataLoader(transformed_test_dataset, batch_size=TRAIN_BATCH_SIZE, collate_fn=my_collate1)

### Utils

In [None]:
import statistics

def get_results(targets, outputs):
  """Count confusion-matrix cells for binary predictions.

  A prediction equal to 0 is a true negative when its target is 0 and a false
  negative otherwise; any other prediction is a true positive when its target
  is 1 and a false positive otherwise.

  Returns:
    The tuple ``(TP, TN, FP, FN)``.
  """
  tally = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}
  for position, predicted in enumerate(outputs):
    if predicted == 0:
      cell = "TN" if targets[position] == 0 else "FN"
    else:
      cell = "TP" if targets[position] == 1 else "FP"
    tally[cell] += 1
  return tally["TP"], tally["TN"], tally["FP"], tally["FN"]

def findMinDiff(arr):
    """Return the smallest strictly positive gap between sorted neighbours.

    The result is capped at 0.5: that value is returned when ``arr`` has fewer
    than two elements or no positive gap is smaller than 0.5.
    """
    ordered = sorted(arr)
    smallest = 0.5
    for lower, upper in zip(ordered, ordered[1:]):
        gap = upper - lower
        if 0 < gap < smallest:
            smallest = gap
    return smallest

def get_thresholds(targets, outputs):
  """Choose one decision threshold per label column by maximising the G-mean.

  For each column ``i`` of ``outputs`` (scores) and ``targets`` (0/1 truth),
  both indexed as ``[:, i]``:

  * positive ratio > 0.6: scan thresholds upwards from the smallest score in
    steps of 0.0001 while the threshold is below 1, stopping early once true
    positives fall below true negatives;
  * positive ratio < 0.4: scan downwards from the largest score while the
    threshold is above 0, stopping early once true negatives fall below true
    positives;
  * otherwise the threshold is fixed at 0.5.

  The threshold with the highest G-mean (sqrt(recall * specificity)) seen
  during the scan is kept; if none improves on 0, the initial value (0 for the
  upward scan, 1 for the downward scan) is kept.

  A column containing only one class makes recall or specificity undefined;
  the undefined rate is treated as 0 rather than raising ZeroDivisionError.

  Returns:
    A list with one threshold per column.
  """
  delta_threshold = 0.0001  # scan step
  best_thresholds = []
  for i in range(len(outputs[0])):
    real_preds = outputs[:, i]
    trues = targets[:, i]
    max_g = 0
    positive_ratio = sum(trues)/len(trues)
    if positive_ratio > 0.6:
      local_best = 0
      curr_threshold = min(real_preds)
      while curr_threshold < 1:
        preds = [1 if pred > curr_threshold else 0 for pred in real_preds]
        tp, tn, fp, fn = get_results(trues, preds)
        # Guard the denominators: a single-class column has no positives or no negatives.
        recall = tp/(tp+fn) if (tp+fn) else 0.0
        specificity = tn/(tn+fp) if (tn+fp) else 0.0
        g_mean = np.sqrt(recall*specificity)
        if tp < tn:
          break
        if g_mean > max_g:
          max_g = g_mean
          local_best = curr_threshold
        curr_threshold += delta_threshold
      best_thresholds.append(local_best)
    elif positive_ratio < 0.4:
      local_best = 1
      curr_threshold = max(real_preds)
      while curr_threshold > 0:
        preds = [1 if pred > curr_threshold else 0 for pred in real_preds]
        tp, tn, fp, fn = get_results(trues, preds)
        recall = tp/(tp+fn) if (tp+fn) else 0.0
        specificity = tn/(tn+fp) if (tn+fp) else 0.0
        g_mean = np.sqrt(recall*specificity)
        if tn < tp:
          break
        if g_mean > max_g:
          max_g = g_mean
          local_best = curr_threshold
        curr_threshold -= delta_threshold
      best_thresholds.append(local_best)
    else:
      best_thresholds.append(0.5)
  return best_thresholds

def get_individual_threshold(target, output):
    """Choose a decision threshold for one label by maximising the G-mean.

    With a positive ratio above 0.5 the threshold is scanned upwards from the
    smallest score in steps of 0.0001 while it is below 1, stopping early once
    true positives fall below true negatives. Otherwise it is scanned
    downwards from the largest score while it is above 0, stopping early once
    true negatives fall below true positives. The threshold with the highest
    G-mean (sqrt(recall * specificity)) seen during the scan is returned.

    When one class is absent, the undefined recall or specificity is treated
    as 0 rather than raising ZeroDivisionError.

    Returns:
        The best threshold. It is the initial 0 (upward scan) or 1 (downward
        scan) when no candidate produced a positive G-mean; otherwise it has
        the type obtained from ``min``/``max`` of ``output`` plus the step.
    """
    real_preds = output
    trues = target
    max_g = 0
    delta_threshold = 0.0001  # scan step
    positive_ratio = sum(trues)/len(trues)
    if positive_ratio > 0.5:
        local_best = 0
        curr_threshold = min(real_preds)
        while curr_threshold < 1:
            preds = [1 if pred > curr_threshold else 0 for pred in real_preds]
            tp, tn, fp, fn = get_results(trues, preds)
            # Guard the denominators: one of the classes may be absent.
            recall = tp/(tp+fn) if (tp+fn) else 0.0
            specificity = tn/(tn+fp) if (tn+fp) else 0.0
            g_mean = np.sqrt(recall*specificity)
            if tp < tn:
                break
            if g_mean > max_g:
                max_g = g_mean
                local_best = curr_threshold
            curr_threshold += delta_threshold
        return local_best
    else:
        local_best = 1
        curr_threshold = max(real_preds)
        while curr_threshold > 0:
            preds = [1 if pred > curr_threshold else 0 for pred in real_preds]
            tp, tn, fp, fn = get_results(trues, preds)
            recall = tp/(tp+fn) if (tp+fn) else 0.0
            specificity = tn/(tn+fp) if (tn+fp) else 0.0
            g_mean = np.sqrt(recall*specificity)
            if tn < tp:
                break
            if g_mean > max_g:
                max_g = g_mean
                local_best = curr_threshold
            curr_threshold -= delta_threshold
        return local_best

In [None]:
def loss_fun(outputs, targets):
    """Binary cross-entropy on raw logits (``nn.BCEWithLogitsLoss``).

    Args:
        outputs: model logits.
        targets: float targets; ``BCEWithLogitsLoss`` requires the same shape
            as ``outputs``.

    Returns:
        The loss tensor.

    Raises:
        Whatever ``BCEWithLogitsLoss`` raises. The offending tensors and the
        traceback are printed first, then the exception is re-raised so the
        caller does not continue with a ``None`` loss.
    """
    loss = nn.BCEWithLogitsLoss()
    try:
        return loss(outputs, targets)
    except Exception:
        print(outputs, targets)
        traceback.print_exc()
        raise

def individual_evaluation(target, predicted):
  """Per-column accuracy and F1 for 2-D target / prediction matrices.

  Returns a dict keyed by the column index as a string; each value is a dict
  with ``accuracy`` and ``f1`` for that column.

  NOTE(review): a later cell defines another function with this same name
  (single-label, threshold-searching); once that cell runs it replaces this one.
  """
  n_columns = len(target[0])
  return {
      str(column): {
          "accuracy": accuracy_score(target[:, column], predicted[:, column]),
          "f1": f1_score(target[:, column], predicted[:, column]),
      }
      for column in range(n_columns)
  }

def evaluate(target, predicted):
    """Multi-label evaluation with per-label thresholds from ``get_thresholds``.

    Args:
        target: 2-D 0/1 matrix (rows are samples, columns are labels).
        predicted: 2-D matrix of scores with the same layout.

    Returns:
        A dict with subset accuracy, Jaccard scores (sample-averaged, macro,
        micro), macro / micro / weighted F1, Hamming loss, and under
        ``"Individual"`` the accuracy and F1 of every label column, keyed by
        the column index as a string.
    """
    # jaccard_score is not part of the notebook's top-level sklearn.metrics
    # import, so it is imported here to keep this function self-contained.
    from sklearn.metrics import jaccard_score

    thresholds = get_thresholds(target, predicted)
    print('thresholds: ', thresholds)
    true_predicted = np.array([[1 if val > thresholds[i] else 0 for (i, val) in enumerate(pred)] for pred in predicted])
    accuracy = accuracy_score(target, true_predicted)
    macro_f1 = f1_score(target, true_predicted, average='macro')
    micro_f1 = f1_score(target, true_predicted, average='micro')
    weighted_f1 = f1_score(target, true_predicted, average='weighted')
    hl = hamming_loss(target, true_predicted)
    # The default average='binary' is rejected for multilabel input; 'samples'
    # averages the Jaccard index computed per row.
    js = jaccard_score(target, true_predicted, average="samples")
    macro_js = jaccard_score(target, true_predicted, average="macro")
    micro_js = jaccard_score(target, true_predicted, average="micro")
    # Per-label metrics are computed inline: the name `individual_evaluation`
    # is redefined later in the notebook with a different (single-label) contract.
    individual = {
        str(i): dict(
            accuracy=accuracy_score(target[:, i], true_predicted[:, i]),
            f1=f1_score(target[:, i], true_predicted[:, i])
        )
        for i in range(len(target[0]))
    }
    return {
        "accuracy": accuracy,
        "jaccard_score_average": js,
        "jaccard_score_macro": macro_js,
        "jaccard_score_micro": micro_js,
        "macro-f1": macro_f1,
        "micro-f1": micro_f1,
        "weighted-f1": weighted_f1,
        "Hamming Loss": hl,
        "Individual": individual
    }

def individual_evaluation(target, predicted):
    """Evaluate one label at its G-mean-optimal threshold and at 0.5.

    The threshold comes from ``get_individual_threshold(target, predicted)``.

    Returns:
        A dict with the positive rate, the chosen threshold, and accuracy /
        F1 / recall / specificity for both the chosen threshold and the
        default 0.5 threshold (``default_*`` keys). Recall or specificity is
        reported as 0.0 when its denominator is zero (one class absent).
    """
    def _summarise(hard_preds):
        # accuracy, f1, recall, specificity for one set of hard predictions
        tp, tn, fp, fn = get_results(target, hard_preds)
        return (
            accuracy_score(target, hard_preds),
            f1_score(target, hard_preds),
            tp/(tp+fn) if (tp+fn) else 0.0,
            tn/(tn+fp) if (tn+fp) else 0.0,
        )

    threshold = get_individual_threshold(target, predicted)
    print('threshold: ',threshold)
    true_predicted = np.array([1 if val > threshold else 0 for val in predicted])
    default_true_predicted = np.array([1 if val > 0.5 else 0 for val in predicted])

    accuracy, f1, recall, specificity = _summarise(true_predicted)
    default_accuracy, default_f1, default_recall, default_specificity = _summarise(default_true_predicted)
    pr = sum(target)/len(target)

    return {
        "Positive Rate": pr,
        # The threshold may be a plain scalar (the initial 0/1) or a
        # one-element array; ravel handles both where `threshold[0]` would
        # fail on a scalar.
        "threshold": np.ravel(threshold)[0],
        "accuracy": accuracy,
        "f1": f1,
        "recall": recall,
        "specificity": specificity,
        "default_accuracy": default_accuracy,
        "default_f1": default_f1,
        "default_recall": default_recall,
        "default_specificity": default_specificity,
    }

In [None]:
def individual_test(target, predicted, threshold):
    """Evaluate one label at a given threshold and at the default 0.5.

    Args:
        target: 0/1 ground truth for the label.
        predicted: scores for the label.
        threshold: decision threshold to apply (a score strictly above it
            counts as positive).

    Returns:
        A dict with the positive rate and accuracy / F1 / recall /
        specificity for ``threshold`` and for 0.5 (``default_*`` keys).
        Recall or specificity is reported as 0.0 when its denominator is zero
        (one class absent) instead of raising ZeroDivisionError.
    """
    def _summarise(hard_preds):
        # accuracy, f1, recall, specificity for one set of hard predictions
        tp, tn, fp, fn = get_results(target, hard_preds)
        return (
            accuracy_score(target, hard_preds),
            f1_score(target, hard_preds),
            tp/(tp+fn) if (tp+fn) else 0.0,
            tn/(tn+fp) if (tn+fp) else 0.0,
        )

    true_predicted = np.array([1 if val > threshold else 0 for val in predicted])
    default_true_predicted = np.array([1 if val > 0.5 else 0 for val in predicted])

    accuracy, f1, recall, specificity = _summarise(true_predicted)
    default_accuracy, default_f1, default_recall, default_specificity = _summarise(default_true_predicted)
    pr = sum(target)/len(target)

    return {
        "Positive Rate": pr,
        "accuracy": accuracy,
        "f1": f1,
        "recall": recall,
        "specificity": specificity,
        "default_accuracy": default_accuracy,
        "default_f1": default_f1,
        "default_recall": default_recall,
        "default_specificity": default_specificity,
    }

In [None]:
def individual_eval_loop_fun1(data_loader, model, device, label_index=0):
    """Run ``model`` over ``data_loader`` in eval mode for one label.

    Inputs are built the same way as in ``individual_train_loop_fun1``: the
    ``all_tokens`` ids of each sample are concatenated with its ``diagnostics``
    vector along dim 1 and passed as ``input_ids_diags``. (The previous body
    relied on flair's ``Sentence``, whose import is commented out, and on an
    ``"all_perceptions"`` key that the dataset samples do not contain.)

    Args:
        data_loader: iterable of batches; each batch is a list of sample dicts
            with ``all_tokens``, ``labels`` and ``diagnostics`` tensors.
        model: module called as ``model(input_ids_diags=...)``.
            # assumes it returns one logit per sample, shape (batch, 1) — confirm
        device: torch device the tensors are moved to.
        label_index: position of the evaluated label inside ``labels``.

    Returns:
        ``(probabilities, targets, losses)``: sigmoid outputs and targets as
        NumPy arrays concatenated over batches, and the per-batch loss values.
    """
    model.eval()
    fin_targets = []
    fin_outputs = []
    losses = []
    for batch in data_loader:
        # (batch, seq_len) token ids; each sample contributes a (1, seq_len) row.
        text = torch.cat([data["all_tokens"] for data in batch]).to(device)
        # (batch, n_diags) one-hot diagnoses.
        diagnostics = torch.stack([data["diagnostics"] for data in batch]).to(device, dtype=torch.long)
        # (batch, 1) float targets for the selected label.
        targets = torch.stack([data["labels"][label_index] for data in batch]).unsqueeze(1)
        targets = targets.to(device, dtype=torch.float)

        input_ids_diags = torch.cat((text, diagnostics), dim=1)

        with torch.no_grad():
            outputs = model(input_ids_diags=input_ids_diags)
            loss = loss_fun(outputs, targets)
            losses.append(loss.item())
            outputs = torch.sigmoid(outputs)

        fin_targets.append(targets.cpu().detach().numpy())
        fin_outputs.append(outputs.cpu().detach().numpy())
    return np.concatenate(fin_outputs), np.concatenate(fin_targets), losses

In [None]:
def individual_test_loop_fun1(data_loader, model, device, label_index=0):
    """Run ``model`` over ``data_loader`` in eval mode and collect predictions.

    Inputs are built the same way as in ``individual_train_loop_fun1``: the
    ``all_tokens`` ids of each sample are concatenated with its ``diagnostics``
    vector along dim 1 and passed as ``input_ids_diags``. (The previous body
    relied on flair's ``Sentence``, whose import is commented out, and on an
    ``"all_perceptions"`` key that the dataset samples do not contain.)

    Args:
        data_loader: iterable of batches; each batch is a list of sample dicts
            with ``all_tokens``, ``labels`` and ``diagnostics`` tensors.
        model: module called as ``model(input_ids_diags=...)``.
            # assumes it returns one logit per sample, shape (batch, 1) — confirm
        device: torch device the tensors are moved to.
        label_index: position of the evaluated label inside ``labels``.

    Returns:
        ``(probabilities, targets)``: sigmoid outputs and targets as NumPy
        arrays concatenated over batches.
    """
    model.eval()
    fin_targets = []
    fin_outputs = []
    for batch in data_loader:
        # (batch, seq_len) token ids; each sample contributes a (1, seq_len) row.
        text = torch.cat([data["all_tokens"] for data in batch]).to(device)
        # (batch, n_diags) one-hot diagnoses.
        diagnostics = torch.stack([data["diagnostics"] for data in batch]).to(device, dtype=torch.long)
        # (batch, 1) float targets for the selected label.
        targets = torch.stack([data["labels"][label_index] for data in batch]).unsqueeze(1)
        targets = targets.to(device, dtype=torch.float)

        input_ids_diags = torch.cat((text, diagnostics), dim=1)

        with torch.no_grad():
            outputs = torch.sigmoid(model(input_ids_diags=input_ids_diags))

        fin_targets.append(targets.cpu().detach().numpy())
        fin_outputs.append(outputs.cpu().detach().numpy())
    return np.concatenate(fin_outputs), np.concatenate(fin_targets)

### Utils (repeated — the cells below redefine the same helpers as the "Utils" section above)

In [None]:
import statistics

def get_results(targets, outputs):
  """Count confusion-matrix cells for binary predictions.

  A prediction equal to 0 is a true negative when its target is 0 and a false
  negative otherwise; any other prediction is a true positive when its target
  is 1 and a false positive otherwise.

  Returns:
    The tuple ``(TP, TN, FP, FN)``.
  """
  tally = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}
  for position, predicted in enumerate(outputs):
    if predicted == 0:
      cell = "TN" if targets[position] == 0 else "FN"
    else:
      cell = "TP" if targets[position] == 1 else "FP"
    tally[cell] += 1
  return tally["TP"], tally["TN"], tally["FP"], tally["FN"]

def findMinDiff(arr):
    """Return the smallest strictly positive gap between sorted neighbours.

    The result is capped at 0.5: that value is returned when ``arr`` has fewer
    than two elements or no positive gap is smaller than 0.5.
    """
    ordered = sorted(arr)
    smallest = 0.5
    for lower, upper in zip(ordered, ordered[1:]):
        gap = upper - lower
        if 0 < gap < smallest:
            smallest = gap
    return smallest

def get_thresholds(targets, outputs):
  """Choose one decision threshold per label column by maximising the G-mean.

  For each column ``i`` of ``outputs`` (scores) and ``targets`` (0/1 truth),
  both indexed as ``[:, i]``:

  * positive ratio > 0.6: scan thresholds upwards from the smallest score in
    steps of 0.0001 while the threshold is below 1, stopping early once true
    positives fall below true negatives;
  * positive ratio < 0.4: scan downwards from the largest score while the
    threshold is above 0, stopping early once true negatives fall below true
    positives;
  * otherwise the threshold is fixed at 0.5.

  The threshold with the highest G-mean (sqrt(recall * specificity)) seen
  during the scan is kept; if none improves on 0, the initial value (0 for the
  upward scan, 1 for the downward scan) is kept.

  A column containing only one class makes recall or specificity undefined;
  the undefined rate is treated as 0 rather than raising ZeroDivisionError.

  Returns:
    A list with one threshold per column.
  """
  delta_threshold = 0.0001  # scan step
  best_thresholds = []
  for i in range(len(outputs[0])):
    real_preds = outputs[:, i]
    trues = targets[:, i]
    max_g = 0
    positive_ratio = sum(trues)/len(trues)
    if positive_ratio > 0.6:
      local_best = 0
      curr_threshold = min(real_preds)
      while curr_threshold < 1:
        preds = [1 if pred > curr_threshold else 0 for pred in real_preds]
        tp, tn, fp, fn = get_results(trues, preds)
        # Guard the denominators: a single-class column has no positives or no negatives.
        recall = tp/(tp+fn) if (tp+fn) else 0.0
        specificity = tn/(tn+fp) if (tn+fp) else 0.0
        g_mean = np.sqrt(recall*specificity)
        if tp < tn:
          break
        if g_mean > max_g:
          max_g = g_mean
          local_best = curr_threshold
        curr_threshold += delta_threshold
      best_thresholds.append(local_best)
    elif positive_ratio < 0.4:
      local_best = 1
      curr_threshold = max(real_preds)
      while curr_threshold > 0:
        preds = [1 if pred > curr_threshold else 0 for pred in real_preds]
        tp, tn, fp, fn = get_results(trues, preds)
        recall = tp/(tp+fn) if (tp+fn) else 0.0
        specificity = tn/(tn+fp) if (tn+fp) else 0.0
        g_mean = np.sqrt(recall*specificity)
        if tn < tp:
          break
        if g_mean > max_g:
          max_g = g_mean
          local_best = curr_threshold
        curr_threshold -= delta_threshold
      best_thresholds.append(local_best)
    else:
      best_thresholds.append(0.5)
  return best_thresholds

def get_individual_threshold(target, output):
    """Choose a decision threshold for one label by maximising the G-mean.

    With a positive ratio above 0.5 the threshold is scanned upwards from the
    smallest score in steps of 0.0001 while it is below 1, stopping early once
    true positives fall below true negatives. Otherwise it is scanned
    downwards from the largest score while it is above 0, stopping early once
    true negatives fall below true positives. The threshold with the highest
    G-mean (sqrt(recall * specificity)) seen during the scan is returned.

    When one class is absent, the undefined recall or specificity is treated
    as 0 rather than raising ZeroDivisionError.

    Returns:
        The best threshold. It is the initial 0 (upward scan) or 1 (downward
        scan) when no candidate produced a positive G-mean; otherwise it has
        the type obtained from ``min``/``max`` of ``output`` plus the step.
    """
    real_preds = output
    trues = target
    max_g = 0
    delta_threshold = 0.0001  # scan step
    positive_ratio = sum(trues)/len(trues)
    if positive_ratio > 0.5:
        local_best = 0
        curr_threshold = min(real_preds)
        while curr_threshold < 1:
            preds = [1 if pred > curr_threshold else 0 for pred in real_preds]
            tp, tn, fp, fn = get_results(trues, preds)
            # Guard the denominators: one of the classes may be absent.
            recall = tp/(tp+fn) if (tp+fn) else 0.0
            specificity = tn/(tn+fp) if (tn+fp) else 0.0
            g_mean = np.sqrt(recall*specificity)
            if tp < tn:
                break
            if g_mean > max_g:
                max_g = g_mean
                local_best = curr_threshold
            curr_threshold += delta_threshold
        return local_best
    else:
        local_best = 1
        curr_threshold = max(real_preds)
        while curr_threshold > 0:
            preds = [1 if pred > curr_threshold else 0 for pred in real_preds]
            tp, tn, fp, fn = get_results(trues, preds)
            recall = tp/(tp+fn) if (tp+fn) else 0.0
            specificity = tn/(tn+fp) if (tn+fp) else 0.0
            g_mean = np.sqrt(recall*specificity)
            if tn < tp:
                break
            if g_mean > max_g:
                max_g = g_mean
                local_best = curr_threshold
            curr_threshold -= delta_threshold
        return local_best

In [None]:
def loss_fun(outputs, targets):
    """Binary cross-entropy on raw logits (``nn.BCEWithLogitsLoss``).

    Args:
        outputs: model logits.
        targets: float targets; ``BCEWithLogitsLoss`` requires the same shape
            as ``outputs``.

    Returns:
        The loss tensor.

    Raises:
        Whatever ``BCEWithLogitsLoss`` raises. The offending tensors and the
        traceback are printed first, then the exception is re-raised so the
        caller does not continue with a ``None`` loss.
    """
    loss = nn.BCEWithLogitsLoss()
    try:
        return loss(outputs, targets)
    except Exception:
        print(outputs, targets)
        traceback.print_exc()
        raise

def individual_evaluation(target, predicted):
  """Per-column accuracy and F1 for 2-D target / prediction matrices.

  Returns a dict keyed by the column index as a string; each value is a dict
  with ``accuracy`` and ``f1`` for that column.

  NOTE(review): the same cell later defines another function with this same
  name (single-label, threshold-searching), which replaces this one.
  """
  n_columns = len(target[0])
  return {
      str(column): {
          "accuracy": accuracy_score(target[:, column], predicted[:, column]),
          "f1": f1_score(target[:, column], predicted[:, column]),
      }
      for column in range(n_columns)
  }

def evaluate(target, predicted):
    """Multi-label evaluation with per-label thresholds from ``get_thresholds``.

    Args:
        target: 2-D 0/1 matrix (rows are samples, columns are labels).
        predicted: 2-D matrix of scores with the same layout.

    Returns:
        A dict with subset accuracy, Jaccard scores (sample-averaged, macro,
        micro), macro / micro / weighted F1, Hamming loss, and under
        ``"Individual"`` the accuracy and F1 of every label column, keyed by
        the column index as a string.
    """
    # jaccard_score is not part of the notebook's top-level sklearn.metrics
    # import, so it is imported here to keep this function self-contained.
    from sklearn.metrics import jaccard_score

    thresholds = get_thresholds(target, predicted)
    print('thresholds: ', thresholds)
    true_predicted = np.array([[1 if val > thresholds[i] else 0 for (i, val) in enumerate(pred)] for pred in predicted])
    accuracy = accuracy_score(target, true_predicted)
    macro_f1 = f1_score(target, true_predicted, average='macro')
    micro_f1 = f1_score(target, true_predicted, average='micro')
    weighted_f1 = f1_score(target, true_predicted, average='weighted')
    hl = hamming_loss(target, true_predicted)
    # The default average='binary' is rejected for multilabel input; 'samples'
    # averages the Jaccard index computed per row.
    js = jaccard_score(target, true_predicted, average="samples")
    macro_js = jaccard_score(target, true_predicted, average="macro")
    micro_js = jaccard_score(target, true_predicted, average="micro")
    # Per-label metrics are computed inline: the name `individual_evaluation`
    # is redefined right after this function with a different (single-label) contract.
    individual = {
        str(i): dict(
            accuracy=accuracy_score(target[:, i], true_predicted[:, i]),
            f1=f1_score(target[:, i], true_predicted[:, i])
        )
        for i in range(len(target[0]))
    }
    return {
        "accuracy": accuracy,
        "jaccard_score_average": js,
        "jaccard_score_macro": macro_js,
        "jaccard_score_micro": micro_js,
        "macro-f1": macro_f1,
        "micro-f1": micro_f1,
        "weighted-f1": weighted_f1,
        "Hamming Loss": hl,
        "Individual": individual
    }

def individual_evaluation(target, predicted):
    """Evaluate one label at its G-mean-optimal threshold and at 0.5.

    The threshold comes from ``get_individual_threshold(target, predicted)``.

    Returns:
        A dict with the positive rate, the chosen threshold, and accuracy /
        F1 / recall / specificity for both the chosen threshold and the
        default 0.5 threshold (``default_*`` keys). Recall or specificity is
        reported as 0.0 when its denominator is zero (one class absent).
    """
    def _summarise(hard_preds):
        # accuracy, f1, recall, specificity for one set of hard predictions
        tp, tn, fp, fn = get_results(target, hard_preds)
        return (
            accuracy_score(target, hard_preds),
            f1_score(target, hard_preds),
            tp/(tp+fn) if (tp+fn) else 0.0,
            tn/(tn+fp) if (tn+fp) else 0.0,
        )

    threshold = get_individual_threshold(target, predicted)
    print('threshold: ',threshold)
    true_predicted = np.array([1 if val > threshold else 0 for val in predicted])
    default_true_predicted = np.array([1 if val > 0.5 else 0 for val in predicted])

    accuracy, f1, recall, specificity = _summarise(true_predicted)
    default_accuracy, default_f1, default_recall, default_specificity = _summarise(default_true_predicted)
    pr = sum(target)/len(target)

    return {
        "Positive Rate": pr,
        # The threshold may be a plain scalar (the initial 0/1) or a
        # one-element array; ravel handles both where `threshold[0]` would
        # fail on a scalar.
        "threshold": np.ravel(threshold)[0],
        "accuracy": accuracy,
        "f1": f1,
        "recall": recall,
        "specificity": specificity,
        "default_accuracy": default_accuracy,
        "default_f1": default_f1,
        "default_recall": default_recall,
        "default_specificity": default_specificity,
    }

In [None]:
def individual_test(target, predicted, threshold):
    """Evaluate one label at a given threshold and at the default 0.5.

    Args:
        target: 0/1 ground truth for the label.
        predicted: scores for the label.
        threshold: decision threshold to apply (a score strictly above it
            counts as positive).

    Returns:
        A dict with the positive rate and accuracy / F1 / recall /
        specificity for ``threshold`` and for 0.5 (``default_*`` keys).
        Recall or specificity is reported as 0.0 when its denominator is zero
        (one class absent) instead of raising ZeroDivisionError.
    """
    def _summarise(hard_preds):
        # accuracy, f1, recall, specificity for one set of hard predictions
        tp, tn, fp, fn = get_results(target, hard_preds)
        return (
            accuracy_score(target, hard_preds),
            f1_score(target, hard_preds),
            tp/(tp+fn) if (tp+fn) else 0.0,
            tn/(tn+fp) if (tn+fp) else 0.0,
        )

    true_predicted = np.array([1 if val > threshold else 0 for val in predicted])
    default_true_predicted = np.array([1 if val > 0.5 else 0 for val in predicted])

    accuracy, f1, recall, specificity = _summarise(true_predicted)
    default_accuracy, default_f1, default_recall, default_specificity = _summarise(default_true_predicted)
    pr = sum(target)/len(target)

    return {
        "Positive Rate": pr,
        "accuracy": accuracy,
        "f1": f1,
        "recall": recall,
        "specificity": specificity,
        "default_accuracy": default_accuracy,
        "default_f1": default_f1,
        "default_recall": default_recall,
        "default_specificity": default_specificity,
    }

### Models

In [None]:
class New_Model(nn.Module):
  """Single-logit classifier: embedding model -> BiLSTM -> linear head.

  The forward input is one tensor whose last ``n_diags`` columns are the
  diagnosis vector and whose remaining columns are token ids. The token ids go
  through ``embedding_model`` and a one-layer bidirectional LSTM; the final
  hidden states of both directions are concatenated with the diagnosis vector
  and mapped to one logit per row.

  NOTE(review): the embedding model is called with ``input_ids`` only (no
  attention mask), so padded positions are presumably treated like real
  tokens — confirm this is intended.
  """

  def __init__(self, embedding_model, n_diags):
    super().__init__()

    self.embedding_model = embedding_model
    self.n_diags = n_diags

    # Put the embedding model in eval mode and clear its gradients at construction time.
    self.embedding_model.eval()
    self.embedding_model.zero_grad()

    # Hidden size of each LSTM direction; the LSTM input size is fixed at 768.
    self.lstm_output = 100
    self.lstm = nn.LSTM(768, self.lstm_output, num_layers=1, bidirectional=True)
    # Head input: both LSTM directions plus the diagnosis columns.
    self.out = nn.Linear(self.lstm_output*2 + n_diags, 1)

  def forward(self, input_ids_diags):
    """Return one logit per row of ``input_ids_diags``."""
    split_at = -self.n_diags
    token_ids, diag_columns = input_ids_diags[:, :split_at], input_ids_diags[:, split_at:]

    embedded = self.embedding_model(input_ids=token_ids).last_hidden_state
    # The LSTM is not batch_first, so swap the first two axes before feeding it.
    _, (final_hidden, _final_cell) = self.lstm(embedded.transpose(0, 1))

    features = torch.cat((final_hidden[0], final_hidden[1]), dim=1)
    features = torch.cat((features, diag_columns), dim=1)
    return self.out(features)

In [None]:
def individual_train_loop_fun1(data_loader, model, optimizer, device, grad_accs, scheduler=None, label_index=0):
    """Run one training epoch with gradient accumulation.

    Each batch yielded by ``data_loader`` is iterated as a sequence of dicts
    with the keys ``"all_tokens"``, ``"labels"`` and ``"diagnostics"``. The
    loss comes from the notebook-level ``loss_fun``.

    Args:
        data_loader: iterable of batches; must support ``len()`` (used for the
            progress message).
        model: module called as ``model(input_ids_diags=...)``.
        optimizer: optimizer stepped once per ``grad_accs`` batches.
        device: device the tensors are moved to.
        grad_accs: number of batches whose gradients are accumulated before
            each optimizer update.
        scheduler: optional LR scheduler, stepped after each optimizer update.
        label_index: which entry of each sample's ``"labels"`` is the target.

    Returns:
        list of per-batch loss values (unscaled by ``grad_accs``).
    """

    def _apply_update():
        # One optimizer update (and scheduler tick) for the accumulated grads.
        optimizer.step()
        optimizer.zero_grad()
        if scheduler:
            scheduler.step()

    model.train()
    t0 = time.time()
    losses = []
    optimizer.zero_grad()
    pending = 0  # batches back-propagated since the last optimizer update
    for batch_idx, batch in enumerate(data_loader):
        text = [data["all_tokens"] for data in batch]
        labels = [data["labels"][label_index] for data in batch]
        if len(labels) > 1:
            # Each selected label entry is itself iterated here.
            targets = [
                torch.stack([torch.tensor([label]) for label in label_set])
                for label_set in labels
            ]
        else:
            targets = [torch.stack([torch.tensor([labels[0]])])]
        diagnostics = [data["diagnostics"] for data in batch]

        text = torch.cat(text).to(device)
        targets = torch.cat(targets).to(device, dtype=torch.float)
        # All diagnostics are flattened into a single row.
        # NOTE(review): the dim=1 concatenation below therefore only lines up
        # when `text` also has exactly one row — confirm batches hold a single
        # sample.
        diagnostics = torch.stack([torch.cat(diagnostics)]).to(device, dtype=torch.long)

        input_ids_diags = torch.cat((text, diagnostics), dim=1).to(device)

        outputs = model(input_ids_diags=input_ids_diags)
        loss = loss_fun(outputs, targets)
        # Scale so the accumulated gradient is the mean over the window.
        (loss / grad_accs).backward()
        model.float()  # kept from the original; no change for a float32 model
        pending += 1
        if (batch_idx + 1) % grad_accs == 0:
            _apply_update()
            pending = 0
        losses.append(loss.item())
        if batch_idx % 250 == 0:
            print(
                f"___ batch index = {batch_idx} / {len(data_loader)} ({100*batch_idx / len(data_loader):.2f}%), loss = {np.mean(losses[-10:]):.4f}, time = {time.time()-t0:.2f} seconds ___")
            t0 = time.time()

    # Apply the trailing partial window; otherwise, when the number of batches
    # is not a multiple of grad_accs, those gradients are thrown away by the
    # zero_grad() at the start of the next call.
    if pending:
        _apply_update()
    return losses

In [None]:
def individual_eval_loop_fun1(data_loader, model, device, label_index=0):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Each batch is iterated as a sequence of dicts with the keys
    ``"all_tokens"``, ``"labels"`` and ``"diagnostics"``. The loss is computed
    with the notebook-level ``loss_fun`` on the raw model outputs.

    Returns:
        (probabilities, targets, losses): the sigmoid of the model outputs and
        the float targets, each concatenated over all batches as numpy arrays,
        plus the list of per-batch loss values.
    """
    model.eval()
    collected_targets, collected_probs, losses = [], [], []
    for batch in data_loader:
        token_chunks = [sample["all_tokens"] for sample in batch]
        labels = [sample["labels"][label_index] for sample in batch]
        if len(labels) > 1:
            # Each selected label entry is itself iterated here.
            target_chunks = [
                torch.stack([torch.tensor([value]) for value in label_set])
                for label_set in labels
            ]
        else:
            target_chunks = [torch.stack([torch.tensor([labels[0]])])]
        diag_chunks = [sample["diagnostics"] for sample in batch]

        text = torch.cat(token_chunks).to(device)
        targets = torch.cat(target_chunks).to(device, dtype=torch.float)
        # Flatten every sample's diagnostics into a single leading row.
        diagnostics = torch.cat(diag_chunks).unsqueeze(0).to(device, dtype=torch.long)

        # Token ids first, diagnostics appended as trailing columns.
        input_ids_diags = torch.cat((text, diagnostics), dim=1).to(device)

        with torch.no_grad():
            logits = model(input_ids_diags=input_ids_diags)
            losses.append(loss_fun(logits, targets).item())

        probs = torch.sigmoid(logits)
        collected_targets.append(targets.cpu().detach().numpy())
        collected_probs.append(probs.cpu().detach().numpy())
    return np.concatenate(collected_probs), np.concatenate(collected_targets), losses

In [None]:
def individual_test_loop_fun1(data_loader, model, device, label_index=0):
    """Run ``model`` over ``data_loader`` and collect probabilities and targets.

    Each batch is iterated as a sequence of dicts with the keys
    ``"all_tokens"``, ``"labels"`` and ``"diagnostics"``. No loss is computed.

    Returns:
        (probabilities, targets): the sigmoid of the model outputs and the
        float targets, each concatenated over all batches as numpy arrays.
    """
    model.eval()
    collected_targets, collected_probs = [], []
    for batch in data_loader:
        token_chunks = [sample["all_tokens"] for sample in batch]
        labels = [sample["labels"][label_index] for sample in batch]
        if len(labels) > 1:
            # Each selected label entry is itself iterated here.
            target_chunks = [
                torch.stack([torch.tensor([value]) for value in label_set])
                for label_set in labels
            ]
        else:
            target_chunks = [torch.stack([torch.tensor([labels[0]])])]
        diag_chunks = [sample["diagnostics"] for sample in batch]

        text = torch.cat(token_chunks).to(device)
        targets = torch.cat(target_chunks).to(device, dtype=torch.float)
        # Flatten every sample's diagnostics into a single leading row.
        diagnostics = torch.cat(diag_chunks).unsqueeze(0).to(device, dtype=torch.long)

        # Token ids first, diagnostics appended as trailing columns.
        input_ids_diags = torch.cat((text, diagnostics), dim=1).to(device)

        with torch.no_grad():
            logits = model(input_ids_diags=input_ids_diags)

        probs = torch.sigmoid(logits)
        collected_targets.append(targets.cpu().detach().numpy())
        collected_probs.append(probs.cpu().detach().numpy())
    return np.concatenate(collected_probs), np.concatenate(collected_targets)

In [None]:
class Single_Flair_Model(nn.Module):
    """Document embedding plus diagnostic features mapped to a single logit."""

    def __init__(self, n_diags):
        """Build the document embedder and a linear head over 768 + ``n_diags`` inputs."""
        super().__init__()
        self.embedding = TransformerDocumentEmbeddings(
            "dccuchile/bert-base-spanish-wwm-cased",
            fine_tune=True,
            layers="-1",
        )
        self.out = nn.Linear(768 + n_diags, 1)

    def forward(self, sentence_inp, diagnostics):
        """Embed ``sentence_inp``, append ``diagnostics`` column-wise and return the head output.

        Args:
            sentence_inp: iterable of objects exposing ``get_embedding()``;
                they are passed to ``self.embedding.embed`` first.
            diagnostics: tensor concatenated to the stacked embeddings on dim 1.
        """
        # embed() is called for its effect on the sentence objects; the
        # vectors are read back from each one afterwards.
        self.embedding.embed(sentence_inp)
        doc_vectors = torch.stack([item.get_embedding() for item in sentence_inp])
        features = torch.cat((doc_vectors, diagnostics), dim=1)
        return self.out(features)

In [None]:
# Load the previously fine-tuned embedding wrapper onto `device`.
# NOTE(review): `.embedding.model` is read from this object further down, so it
# is presumably a pickled Single_Flair_Model; if so, that class (and the
# library providing its embedding) must be importable before this cell runs.
# torch.load unpickles arbitrary objects — only load checkpoints you trust.
emb_model = torch.load("/research/jamunoz/models/flair_fine_tuning/model_ft_0.pt", map_location=device)

In [None]:
EPOCH = 30
LABEL_INDEX = 0
GRADIENT_ACCUMULATIONS = 16
lr = 1e-5

# The scheduler is stepped once per optimizer update, i.e. once every
# GRADIENT_ACCUMULATIONS batches, so the schedule length must be counted in
# updates, not batches; counting batches makes the linear decay barely move.
# Ceiling division keeps this an upper bound whether or not a trailing partial
# accumulation window is applied at the end of each epoch.
updates_per_epoch = -(-len(train_data_loader) // GRADIENT_ACCUMULATIONS)
num_training_steps = updates_per_epoch * EPOCH

model = New_Model(embedding_model=emb_model.embedding.model, n_diags=len(diagnoses_keys)).to(device)
optimizer = AdamW(model.parameters(), lr=lr)
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=num_training_steps)
val_losses = []
batches_losses = []
val_acc = []

# NOTE(review): best_model / best_f1 are initialised but never updated in this
# cell; the model saved and the threshold used for testing both come from the
# last epoch, not the best one.
best_model = None
best_f1 = 0
th = 0

for epoch in range(EPOCH):
    t0 = time.time()
    print(f"\n=============== EPOCH {epoch+1} / {EPOCH} ===============\n")
    # Modify according to individual or all
    batches_losses_tmp = individual_train_loop_fun1(train_data_loader, model, optimizer, device, GRADIENT_ACCUMULATIONS, scheduler=scheduler, label_index=LABEL_INDEX)
    epoch_loss = np.mean(batches_losses_tmp)
    print(f"\n*** avg_loss : {epoch_loss:.2f}, time : ~{(time.time()-t0)//60} min ({time.time()-t0:.2f} sec) ***\n")
    t1 = time.time()
    # Modify according to individual or all
    output, target, val_losses_tmp = individual_eval_loop_fun1(val_data_loader, model, device, label_index=LABEL_INDEX)
    print(f"==> evaluation : avg_loss = {np.mean(val_losses_tmp):.2f}, time : {time.time()-t1:.2f} sec\n")
    tmp_evaluate = individual_evaluation(target, output)
    print(f"=====>\t{tmp_evaluate}")
    # Threshold reported by the validation evaluation; reused on the test set.
    th = tmp_evaluate["threshold"]
    val_acc.append(tmp_evaluate['accuracy'])
    val_losses.append(val_losses_tmp)
    batches_losses.append(batches_losses_tmp)

# Persist the whole module (pickle), keyed by the label index being trained.
torch.save(model, f"/research/jamunoz/models/flair_fine_tuning/lstm_model_{LABEL_INDEX}_v2.pt")
# model = torch.load("/research/jamunoz/models/flair_fine_tuning/lstm_model_"+str(LABEL_INDEX)+".pt", map_location=device)
# th = 0.44041115
output, target = individual_test_loop_fun1(test_data_loader, model, device, label_index=LABEL_INDEX)
tmp_test = individual_test(target, output, th)



___ batch index = 0 / 1836 (0.00%), loss = 0.5615, time = 0.85 seconds ___
___ batch index = 250 / 1836 (13.62%), loss = 0.9095, time = 101.21 seconds ___
___ batch index = 500 / 1836 (27.23%), loss = 0.5609, time = 101.80 seconds ___
___ batch index = 750 / 1836 (40.85%), loss = 0.6130, time = 100.98 seconds ___
___ batch index = 1000 / 1836 (54.47%), loss = 0.7059, time = 101.98 seconds ___
___ batch index = 1250 / 1836 (68.08%), loss = 0.7225, time = 101.14 seconds ___
___ batch index = 1500 / 1836 (81.70%), loss = 0.5673, time = 101.66 seconds ___
___ batch index = 1750 / 1836 (95.32%), loss = 0.8253, time = 102.23 seconds ___

*** avg_loss : 0.70, time : ~12.0 min (747.08 sec) ***

==> evaluation : avg_loss = 0.70, time : 64.73 sec

threshold:  [0.34935382]
=====>	{'Positive Rate': array([0.4332784], dtype=float32), 'threshold': 0.34935382, 'accuracy': 0.4695222405271829, 'f1': 0.47039473684210525, 'recall': 0.5437262357414449, 'specificity': 0.4127906976744186, 'default_accurac

In [None]:
# Show the test-set metrics dict returned by individual_test in the previous cell.
print(tmp_test)

{'Positive Rate': array([0.44425675], dtype=float32), 'accuracy': 0.6452702702702703, 'f1': 0.6534653465346536, 'recall': 0.752851711026616, 'specificity': 0.5592705167173252, 'default_accuracy': 0.6452702702702703, 'default_f1': 0.6379310344827587, 'default_recall': 0.7034220532319392, 'default_specificity': 0.5987841945288754}
