In [9]:
import logging
import torch
import numpy as np
from typing import Dict
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, RobertaConfig, get_scheduler
import operator
import os

from models.overlap_bert import MeanBERT, LSTMBERT
from data.dataset import OverlapDataset, collate_cases, collate_cases_time, DateVisitsDataset
from data.load_data import load_real_data
from pipeline.multi_bert import train_epoch, evaluate_model
from utils.evaluation import singlelabel_eval

from utils.import_config import import_config
from utils.logger import setup_logger

In [3]:
comparison_ops = {
    'val_loss': operator.lt,
    'ham_loss': operator.lt,
    'val_accuracy': operator.gt,
    'val_macro_f1': operator.gt,
    'val_weighted_f1': operator.gt,
}

In [5]:
class temp_args:
    def __init__(self):
        self.action = 'train'
        self.log = 'info'

args = temp_args()
config = import_config("config.yaml")

logger = setup_logger(config['general']['log_dir'], args)

Loading configuration from config/config.yaml


# Data

In [None]:
train_df, test_df, val_df = load_real_data(config['general']['dataset'], stratify_col='label', test_size=0.2, dev_size=0.1)

tokenizer = AutoTokenizer.from_pretrained(config['model']['tokenizer_path'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

method = 'lstm-attn'

output_dir = os.path.join(config['training']['output_dir'], method)

# Create datasets and dataloaders
train_dataset = OverlapDataset(train_df['visits'].to_list(), train_df['label'].to_list())
val_dataset = OverlapDataset(val_df['visits'].to_list(), val_df['label'].to_list())
test_dataset = OverlapDataset(test_df['visits'].to_list(), test_df['label'].to_list())

train_loader = DataLoader(
    train_dataset, 
    batch_size=config['training']['batch_size'], 
    shuffle=True, 
    collate_fn=lambda b: collate_cases(b, tokenizer, max_length=512)
    )
val_loader = DataLoader(
    val_dataset, 
    batch_size=config['training']['batch_size'], 
    collate_fn=lambda b: collate_cases(b, tokenizer, max_length=512)
    )
test_loader = DataLoader(
    test_dataset, 
    batch_size=config['training']['batch_size'], 
    collate_fn=lambda b: collate_cases(b, tokenizer, max_length=512)
    )

# Model

In [12]:
hf_config = RobertaConfig.from_pretrained(config['model']['tokenizer_path'])
hf_config.loss_fct = config['training']['loss']
hf_config.output_dim = config['model']['output_dim']
hf_config.num_tasks = config['model']['num_tasks']
hf_config.freeze_bert = config['model']['freeze_bert']
hf_config.classifier_dropout = config['model']['dropout']
if method == 'mean':
    model = MeanBERT(hf_config, pretrained_model=config['model']['pretrained_model'])
elif method == 'lstm':
    hf_config.lstm_hidden = 256
    model = LSTMBERT(hf_config, pretrained_model=config['model']['pretrained_model'])
elif method == 'lstm-attn':
    hf_config.lstm_hidden = 256
    hf_config.attn_dim = 128
    model = LSTMBERT(hf_config, pretrained_model=config['model']['pretrained_model'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Initialize training methods
accelerator = Accelerator()
optimizer = AdamW(model.parameters(), lr=config['training']['learning_rate'])

num_epochs = config['training']['num_epochs']
num_training_steps = num_epochs * len(train_loader)
scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

model, optimizer, train_loader, val_loader, test_loader, scheduler = accelerator.prepare(
    model, optimizer, train_loader, val_loader, test_loader, scheduler
)

Some weights of RobertaModel were not initialized from the model checkpoint at /home/isglobal.lan/erodriguez1/.cache/huggingface/hub/models--PlanTL-GOB-ES--bsc-bio-ehr-es/snapshots/cb0a3d3d85692d19e7cbd8b3f02f9263fa343837 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Train

In [None]:
# Training loop
best_es_metric_val = 0
es_drag = 0
es_metric = config['training']['early_stopping_metric']
es_patience = config['training']['early_stopping_patience']
es_compare = comparison_ops[es_metric]
for epoch in range(num_epochs):
    logger.info(f"\nEpoch {epoch + 1}/{num_epochs}")
    logger.info("-" * 50)
    
    # Train
    train_loss = train_epoch(
        model, train_loader, optimizer, scheduler, device
    )
    logger.info(f"Train Loss: {train_loss:.4f}")
    
    # Evaluate
    val_loss, preds, actual = evaluate_model(
        model, val_loader, device
    )
    logger.info(f"Val Loss: {val_loss:.4f}")

    val_metrics = singlelabel_eval(
        np.array(actual),
        np.array(preds),
        logging.getLogger(),
        target_names=['Neither', 'Cardiological Only', 'Digestive Only',]
    )
    val_metrics['es']['val_loss'] = val_loss  
    logger.info(f"Val {es_metric}: {val_metrics['es'][es_metric]:.4f}")

    # Save best model based on average F1 score
    if es_compare(val_metrics['es'][es_metric], best_es_metric_val) or epoch == 0:
        best_es_metric_val = val_metrics['es'][es_metric]
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
        tokenizer.save_pretrained(output_dir)
        logger.info(f"New best model saved in {output_dir}!")
        es_drag = 0
    else:
        es_drag += 1
        if es_drag >= es_patience:
            logger.info(f"Early stopping triggered after {epoch+1} epochs.")
            break

# Miscellaneous

In [20]:
visits_df = train_df['visits'].explode()

In [65]:
len("16oct2009 ")

10

In [68]:
"22jan2010ñlkjj"[9:]

'ñlkjj'

In [87]:
from typing import List, Union
from datetime import timedelta, date

def impute_missing_dates(dates: List[Union[date,None]]) -> List[date]:
    if all(d is None for d in dates):
        return [date(2024, 1, 1) for _ in dates]
    elif all(d is not None for d in dates):
        return dates

    n = len(dates)
    result = dates.copy()

    # Fill missing in middle using avg of left and right
    for i in range(n):
        if result[i] is None:
            # find left
            left = next((result[j] for j in range(i - 1, -1, -1) if result[j] is not None), None)
            # find right
            right = next((result[j] for j in range(i + 1, n) if result[j] is not None), None)

            if left and right:
                avg_seconds = (left.timestamp() + right.timestamp()) / 2
                result[i] = date.fromtimestamp(avg_seconds)
            elif left and not right:
                result[i] = left + timedelta(days=1)
            elif right and not left:
                result[i] = right - timedelta(days=1)
            else:
                result[i] = date(2024, 1, 1)

    return result
def pop_dates(visits: List[str]) -> pd.Series[List[str], List[date]]:
    new_visits = []
    dates = []
    for visit in visits:
        new_visit = visit.split("Fecha: ")[1]
        try:
            date_str, new_visit = new_visit[:9], new_visit[9:]
            d = date.strptime(date_str, "%d%b%Y")
        except ValueError:
            d = None
        new_visits.append(new_visit)
        dates.append(d)
    return pd.Series([new_visits, impute_missing_dates(dates)])

In [88]:
df[['new_visits', 'dates']] = df['visits'].apply(pop_dates)

In [90]:
df

Unnamed: 0,person_id,label,visits,num_visits,visits_num,new_visits,dates
0,PATTERN_M-00314259,1,[Fecha: 16oct2009 OBSERVACIONES: ESPECIMEN: B0...,8,8,[ OBSERVACIONES: ESPECIMEN: B0922962 FECHA ENT...,"[2009-10-16, 2015-06-08, 2015-06-29, 2017-03-1..."
1,PATTERN_M-00337378,1,[Fecha: 14nov2018 ANOTACIONES: Enfermedad de C...,2,2,[ ANOTACIONES: Enfermedad de Chagas en posible...,"[2018-11-14, 2018-12-12]"
2,PATTERN_M-00710666,1,[Fecha: 08mar2017 PROBLEMAS DE SALUD: CONSULTA...,9,9,[ PROBLEMAS DE SALUD: CONSULTA GENÉRICA GINECO...,"[2017-03-08, 2017-04-24, 2018-02-01, 2018-06-1..."
3,PATTERN_M-01649881,2,[Fecha: 15jul2013 ANOTACIONES: CONSULTA 15.7.1...,22,22,[ ANOTACIONES: CONSULTA 15.7.13 Mejoria con ci...,"[2013-07-15, 2013-12-12, 2014-01-08, 2014-01-2..."
4,PATTERN_M-01756648,1,[Fecha: 12sep2019 ANOTACIONES: Enfermedad de C...,105,40,[ ANOTACIONES: Enfermedad de Chagas en fase cr...,"[2019-09-12, 2019-11-18, 2020-02-11, 2020-02-2..."
...,...,...,...,...,...,...,...
163,PATTERN_M-03510507,0,[Fecha: 02sep2015 OBSERVACIONES: Problema Actu...,25,25,[ OBSERVACIONES: Problema Actual: Supuracion O...,"[2015-09-02, 2015-12-20, 2016-05-12, 2016-11-0..."
164,PATTERN_M-02301744,0,[Fecha: 17may2012 OBSERVACIONES: ESPECIMEN: C1...,18,18,[ OBSERVACIONES: ESPECIMEN: C1214844 FECHA ENT...,"[2012-05-17, 2017-07-02, 2020-05-12, 2020-06-1..."
165,PATTERN_M-02021706,0,[Fecha: 25may2021 PROBLEMAS DE SALUD: Antecede...,22,22,[ PROBLEMAS DE SALUD: Antecedentes Personales ...,"[2021-05-25, 2021-11-22, 2021-12-16, 2021-12-2..."
166,PATTERN_M-03657032,0,[Fecha: 21may2017 OBSERVACIONES: Motivo de Ing...,21,21,[ OBSERVACIONES: Motivo de Ingreso: Ojo rojo d...,"[2017-05-21, 2019-04-24, 2021-03-28, 2021-04-0..."


In [28]:
import pandas as pd
import ast

df = pd.read_csv("data/data_dir/cleaned_chagas_data.csv")
df['visits'] = df['visits'].apply(ast.literal_eval)
df['visits'] = df.apply(lambda row: row['visits'] if row['num_visits'] <= 40 else row['visits'][-40:], axis=1) # if there are more than 40 visits, take only the last 40
df.rename(columns={'clasificacion_diag': 'label'}, inplace=True)

In [36]:
visits_df.apply(lambda x: x.startswith('Fecha: ')).value_counts()

visits
True    2569
Name: count, dtype: int64

In [56]:
"Fecha: 22jan2010ñlkjj".split("Fecha: ")[1][:9]

'22jan2010'

In [57]:
fechas = visits_df.apply(lambda x: x.split("Fecha: ")[1][:9])

In [62]:
from torch import log1p



In [63]:
log1p(torch.tensor([0,1,32,45]))

tensor([0.0000, 0.6931, 3.4965, 3.8286])

In [60]:
fecha

'. OBSERVA'

In [61]:
for fecha in fechas:
    try:
        datetime.strptime(fecha, "%d%b%Y")
    except ValueError:
        print(fecha)

. OBSERVA
. OBSERVA
. OBSERVA
. OBSERVA
. OBSERVA
. OBSERVA


In [40]:
from datetime import datetime

date = '05mar2021'
datetime.strptime(date, "%d%b%Y")

datetime.datetime(2021, 3, 5, 0, 0)

In [44]:
dates = visits_df.apply(lambda x: datetime.strptime(x.split("Fecha: ")[1][:9], "%d%b%Y"))

ValueError: time data '. OBSERVA' does not match format '%d%b%Y'

In [33]:
df.visits.iloc[10]

["Fecha: 05mar2021 ANOTACIONES: Vuelve a citarse de nuevo porque nota inflamacion, sensacion de calor y oscurecimiento de lesiones de forma intermitente que dura unos 3 dias y remite espontaneamente. La ulitma vez hace una semana. Refiere mas molestias desde que usa mascarilla y con cambios de temperatura. Mujer de 51 anos, trasplantada hepatica el dia 17/04/2019 en relacion con hepatitis autoinmune overlap CEP, con AP de LES y EII. Enfermedad de injerto contra huesped aguda con afectacion mucocutanea severa tipo NET en 2019, resuelta. Tratamiento habitual: Tacrolimus 5 mg, Famotidina, Dapsona, Prednisona 5 mg, Dolquine, Mesalazina, Aciclovir, Amlodipino. EPO alfa A la exploracion presenta maculas marrones calras en zona lateral de mejillas y ligero eritema en mejillas y zona cejas sin descamacion. Unas pintadas. JC. Flushing/cuperosis. Tambien el tacrolimus puede aumentar la fotosensibilidad. Plan. Se recomiendo no usar esmalte unas Usar mascarilla de tela bajo la mascarilla quir'ruga

In [30]:
df['visits_num'] = df['visits'].apply(len)

In [31]:
df

Unnamed: 0,person_id,label,visits,num_visits,visits_num
0,PATTERN_M-00314259,1,[Fecha: 16oct2009 OBSERVACIONES: ESPECIMEN: B0...,8,8
1,PATTERN_M-00337378,1,[Fecha: 14nov2018 ANOTACIONES: Enfermedad de C...,2,2
2,PATTERN_M-00710666,1,[Fecha: 08mar2017 PROBLEMAS DE SALUD: CONSULTA...,9,9
3,PATTERN_M-01649881,2,[Fecha: 15jul2013 ANOTACIONES: CONSULTA 15.7.1...,22,22
4,PATTERN_M-01756648,1,[Fecha: 12sep2019 ANOTACIONES: Enfermedad de C...,105,40
...,...,...,...,...,...
163,PATTERN_M-03510507,0,[Fecha: 02sep2015 OBSERVACIONES: Problema Actu...,25,25
164,PATTERN_M-02301744,0,[Fecha: 17may2012 OBSERVACIONES: ESPECIMEN: C1...,18,18
165,PATTERN_M-02021706,0,[Fecha: 25may2021 PROBLEMAS DE SALUD: Antecede...,22,22
166,PATTERN_M-03657032,0,[Fecha: 21may2017 OBSERVACIONES: Motivo de Ing...,21,21
