This notebook tests the model loading and prediction functionality in isolation. It creates a test folder containing a saved model and runs a prediction; the predicted values are meaningless because the model is untrained. Use it to understand the information flow and to debug problems such as a mismatch between the saved weights and the model architecture.

In [None]:
from typing import List, Optional
from dataclasses import dataclass
from pathlib import Path
from transformers import RobertaConfig

"""
This config file contains the parameters of the best trained model. They must be copied exactly in order to build the exact architecture the weights will subsequently fill.
"""

class LSTMBERTConfig(RobertaConfig):
    """RoBERTa configuration that also carries arbitrary extra hyper-parameters.

    Every keyword argument is first offered to ``RobertaConfig``. Any keyword
    that is still not available as an attribute afterwards is attached to the
    instance unchanged, so model-specific settings do not have to be declared
    here one by one. The ``LSTMBERT`` model in this file reads
    ``visit_time_dim``, ``lstm_hidden``, ``lstm_layers`` (optional),
    ``attn_dim`` and ``output_dim`` from this object.
    """

    model_type = "lstm-attn-bert"

    def __init__(self, **kwargs):
        # The parent class sets up all the standard RoBERTa attributes.
        super().__init__(**kwargs)

        # Attach whatever the parent did not already expose; attributes that
        # exist after the parent initializer are never overwritten.
        for name, value in kwargs.items():
            if hasattr(self, name):
                continue
            setattr(self, name, value)


In [None]:
from transformers import RobertaForSequenceClassification, RobertaModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn
import torch
from typing import Optional

"""
This is the model architecture corresponding to the weights saved in the saved_models folder. The layers should match exactly the parameters defined in model_config.py
"""

class LSTMBERT(RobertaForSequenceClassification):
    """RoBERTa visit encoder followed by a BiLSTM with attention pooling.

    Every visit (one tokenized text) is encoded by RoBERTa and mean-pooled over
    its unmasked tokens. The visit vectors, each concatenated with its time
    features, form one sequence that is fed to a bidirectional LSTM. A small
    feed-forward attention layer scores every LSTM output, the softmax of those
    scores weights the outputs into a single vector, and a linear layer turns
    that vector into the logits.

    NOTE(review): the sub-module attribute names (``roberta``, ``lstm``,
    ``attn``, ``classifier``) presumably determine the parameter names stored
    in saved checkpoints, so they should not be renamed.
    """

    def __init__(self, config: LSTMBERTConfig, **kwargs):
        """Build all layers from ``config``.

        Required config attributes: ``visit_time_dim``, ``lstm_hidden``,
        ``attn_dim`` and ``output_dim``. ``lstm_layers`` is optional and
        defaults to 1. ``**kwargs`` is accepted but not used.
        """
        super().__init__(config)
        self.config = config

        # Replaces the encoder created by the parent initializer.
        self.roberta = RobertaModel(config)

        # One LSTM step per visit: pooled text vector plus its time features.
        visit_feature_dim = self.roberta.config.hidden_size + config.visit_time_dim
        self.lstm = nn.LSTM(
            input_size=visit_feature_dim,
            hidden_size=config.lstm_hidden,
            num_layers=getattr(config, "lstm_layers", 1),
            batch_first=True,
            bidirectional=True,
        )

        attn_dim = config.attn_dim
        bilstm_dim = config.lstm_hidden * 2  # forward + backward directions
        self.attn = nn.Sequential(
            nn.Linear(bilstm_dim, attn_dim),
            nn.Tanh(),
            nn.Linear(attn_dim, 1),
        )

        # Replaces the classification head created by the parent initializer.
        self.classifier = nn.Linear(bilstm_dim, config.output_dim)

    def _mean_pool_visit(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Encode a single visit and average its unmasked token states.

        ``ids`` and ``mask`` are the (1, S) slices of one visit; the result has
        shape (1, hidden).
        """
        encoded = self.roberta(
            input_ids=ids,
            attention_mask=mask,
            return_dict=True,
        )
        token_weights = mask.unsqueeze(-1)  # (1, S, 1)
        summed = (encoded.last_hidden_state * token_weights).sum(dim=1)
        # The clamp guards against a division by zero for a fully masked visit.
        return summed / token_weights.sum(dim=1).clamp(min=1e-9)

    def forward(self, *args, visit_times: torch.Tensor, **kwargs) -> SequenceClassifierOutput:
        """Score one case made of V visits.

        ``input_ids`` and ``attention_mask`` may be given as keywords or as
        the first and second positional argument; both are (V, S) tensors,
        where V is the number of visits and S the sequence length per visit.

        Args:
            visit_times: keyword-only tensor of shape (V, visit_time_dim).

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` has shape
            (output_dim,) and whose ``attentions`` is a plain Python list with
            the V attention weights over the visits.

        Raises:
            ValueError: if ``input_ids``/``attention_mask`` are missing, if
                ``visit_times`` is None, or if its shape is not
                (V, visit_time_dim).
        """
        positional_ids = args[0] if len(args) > 0 else None
        positional_mask = args[1] if len(args) > 1 else None
        input_ids = kwargs.get("input_ids", positional_ids)
        attention_mask = kwargs.get("attention_mask", positional_mask)
        if input_ids is None or attention_mask is None:
            raise ValueError("You have to specify input_ids and attention_mask")
        num_visits, _ = input_ids.shape

        if visit_times is None:
            raise ValueError("You have to provide visit_times tensor")
        expected_shape = (num_visits, self.config.visit_time_dim)
        if visit_times.shape != expected_shape:
            raise ValueError(f"visit_times shape must be (V, {self.config.visit_time_dim})")

        # Visits are encoded one at a time, each as a batch of size 1.
        visit_vectors = [
            self._mean_pool_visit(input_ids[i:i + 1], attention_mask[i:i + 1])
            for i in range(num_visits)
        ]
        features = torch.cat(visit_vectors, dim=0)  # (V, hidden)
        features = torch.cat([features, visit_times], dim=-1)  # (V, hidden + visit_time_dim)

        # The whole case is a single LSTM sequence, hence the batch dim of 1.
        lstm_out, _ = self.lstm(features.unsqueeze(0))  # (1, V, 2*lstm_hidden)
        lstm_out = lstm_out.squeeze(0)  # (V, 2*lstm_hidden)

        # Attention over visits: one score per visit, normalised across visits.
        scores = self.attn(lstm_out).squeeze(-1)  # (V,)
        attention_weights = torch.softmax(scores, dim=0)  # (V,)
        pooled = (attention_weights.unsqueeze(-1) * lstm_out).sum(dim=0)  # (2*lstm_hidden,)

        logits = self.classifier(pooled)  # (output_dim,)

        return SequenceClassifierOutput(
            logits=logits,
            attentions=attention_weights.detach().cpu().numpy().tolist(),
        )

In [None]:
from transformers import RobertaTokenizerFast

"""
Script to save the model architecture and tokenizer in a folder so that they can be reloaded later for testing purposes. In production, this model shouldn't be used since it's untrained.
"""

# The config must already exist in the target folder with the right parameters.
# The tokenizer is the one that belongs to the encoder used during training.
TEST_MODEL_DIR = "test"
ENCODER_TOKENIZER_NAME = 'PlanTL-GOB-ES/bsc-bio-ehr-es'

cfg = LSTMBERTConfig.from_pretrained(TEST_MODEL_DIR)
model = LSTMBERT(cfg)
tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_TOKENIZER_NAME)

# Persist the (untrained) architecture and the tokenizer side by side.
model.save_pretrained(TEST_MODEL_DIR)
tokenizer.save_pretrained(TEST_MODEL_DIR)

('test/tokenizer_config.json',
 'test/special_tokens_map.json',
 'test/vocab.json',
 'test/merges.txt',
 'test/added_tokens.json',
 'test/tokenizer.json')

In [None]:
## utils for date processing

import pandas as pd
import numpy as np
from typing import cast, List
from datetime import date, datetime

def date_linear_impute(dates: list[datetime | None]) -> list[date]:
    s = pd.to_datetime(pd.Series(dates), errors="coerce")
    n = len(s)

    # All null -> [1,2,...,n]
    if s.dropna().count() < 1: # if only nulls or only one non-null (don't have reference to interpolate)
        return [date(2024 + (i // 12), (i % 12) + 1 , 1) for i in range(n)] # by default, do it monthly from Jan 1, 2024

    # No nulls -> return same values (as floats)
    if all([d is not None for d in dates]):
        return cast(List[date], list(dates))

    s_int = s.apply(lambda x: x.value if pd.notna(x) else np.nan).astype("float64")  # convert Timestamp -> integer ns since epoch (use float to allow NaN)
    # Linear interpolation, allow extrapolation at ends
    s_interp = s_int.interpolate(method="linear", limit_direction="both").tolist()
    
    # there is at least 2 non-nulls, so we can extrapolate
    if not dates[0]:
        first_valid_index = s.first_valid_index()
        assert type (first_valid_index) is int
        step = s_interp[first_valid_index + 1] - s_interp[first_valid_index]
        for i in range(first_valid_index - 1, -1, -1):
            s_interp[i] = s_interp[i + 1] - step

    if not dates[-1]:
        last_valid_index = s.last_valid_index()
        assert type (last_valid_index) is  int
        step = s_interp[last_valid_index] - s_interp[last_valid_index - 1]
        for i in range(last_valid_index + 1, n):
            s_interp[i] = s_interp[i - 1] + step
    
    dt_series = pd.to_datetime(s_interp)
    return [pd.Timestamp(x).date() for x in dt_series.tolist()]

def dates_to_log_deltas(case_dates: list[date]) -> list[tuple[float, float]]:
    """Convert one case's ordered dates into log-scaled time differences.

    For every visit two features are produced:
      - log_prev: log1p(days since the previous visit)  (first visit -> 0)
      - log_start: log1p(days since the first visit)    (first visit -> 0)

    The logarithm is evaluated in float32 with torch. Dates are expected in
    chronological order; a visit more than one day earlier than its reference
    produces NaN because log1p is undefined below -1.

    Args:
        case_dates: the visit dates of one case, in visit order. All entries
            must support subtraction with each other (all ``date`` or all
            ``datetime``).

    Returns:
        A list of tuples [(log_prev0, log_start0), ...] with
        ``len(case_dates)`` entries; an empty input gives an empty list.
    """
    if not case_dates:
        return []

    seconds_per_day = 86400.0  # the deltas are expressed in (fractional) days

    def log1p_float32(days: float) -> float:
        return float(torch.log1p(torch.tensor(days, dtype=torch.float32)).item())

    first = case_dates[0]
    prev = first  # the first visit is compared with itself -> delta 0

    out = []
    for current in case_dates:
        days_since_prev = (current - prev).total_seconds() / seconds_per_day
        days_since_first = (current - first).total_seconds() / seconds_per_day
        out.append((log1p_float32(days_since_prev), log1p_float32(days_since_first)))
        prev = current

    return out

In [None]:
from transformers import RobertaTokenizer
from abc import ABC, abstractmethod
from datetime import datetime

class ModelClass(ABC):
    """Base class for prediction pipelines with a shared output serializer."""

    @abstractmethod
    def predict(self, case: list[str], dates: list[str]) -> tuple[float, list[float]]:
        """Run the model on one case; concrete pipelines must implement this.

        Args:
            case: the texts of the case, one string per visit.
            dates: one date string per entry of ``case``.
        """
        pass

    def serialize(self, case: list[str], dates: list[str], syn_prob: float, attn_weights: list[float], footer: dict):
        """Wrap a prediction in the Common Data Model v2 structure.

        Args:
            case: the texts that were scored; stored as-is in the metadata.
            dates: the dates that were passed with ``case``; stored as-is.
            syn_prob: value stored under ``"syntomatic_probability"``.
            attn_weights: value stored under ``"attention_weights"``.
            footer: note metadata; must contain ``provider_id``,
                ``person_id``, ``visit_detail_id``, ``note_id``,
                ``note_type_concept_id``, ``note_datetime`` and
                ``note_title``.

        Returns:
            A dict with the two top-level keys ``"nlp_output"`` and
            ``"nlp_service_info"``.

        Raises:
            KeyError: if ``footer`` lacks one of the required keys.
        """
        pipeline_name = self.__class__.__name__

        # The three timestamps are taken separately at serialization time.
        record_metadata = {
            "clinical_site_id": footer['provider_id'],
            "patient_id": footer['person_id'],
            "admission_id": footer['visit_detail_id'],
            "record_id": footer['note_id'],
            "record_type": footer['note_type_concept_id'],
            "record_format": "json",
            "record_creation_date": footer['note_datetime'],
            "record_lastupdate_date": datetime.now().isoformat(),
            "record_character_encoding": "UTF-8",
            "record_extraction_date": datetime.now().isoformat(),
            "report_section": footer['note_title'],
            "report_language": "es",
            "deidentified": "no",
            "deidentification_pipeline_name": "",
            "deidentification_pipeline_version": "",
            "case": case,
            "dates": dates,
            "nlp_processing_date": datetime.now().isoformat(),
            "nlp_processing_pipeline_name": pipeline_name,
            "nlp_processing_pipeline_version": "1.0",
        }

        service_info = {
            "service_app_name": "NLP Chagas Prediction",
            "service_language": "es",
            "service_version": "1.0",
            "service_model": pipeline_name,
        }

        return {
            "nlp_output": {
                "record_metadata": record_metadata,
                "syntomatic_probability": syn_prob,
                "attention_weights": attn_weights,
                "processing_success": True,
            },
            "nlp_service_info": service_info,
        }


class PredictionPipeline(ModelClass):
    """Load a saved LSTMBERT model and score cases with it."""

    def __init__(
            self,
            local_model_path: str
        ):
        """Load config, tokenizer and weights from ``local_model_path``.

        The tokenizer and the weights are loaded with ``local_files_only=True``.
        The model is put in evaluation mode and moved to CUDA when available,
        otherwise it stays on the CPU.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.cfg = LSTMBERTConfig.from_pretrained(local_model_path)

        self.tokenizer = RobertaTokenizer.from_pretrained(
            local_model_path,
            local_files_only=True
        )

        self.model = LSTMBERT.from_pretrained(local_model_path, config=self.cfg, local_files_only=True)
        self.model.eval()
        self.model.to(self.device) # type: ignore

    def predict(
            self,
            case: list[str],
            dates: list[str]
        ) -> tuple[float, list[float]]:
        """Score one case.

        Args:
            case: the visit texts of the case, in visit order.
            dates: one date string per visit (see ``format_dates``); an empty
                string marks an unknown date.

        Returns:
            ``(syn_prob, attn_list)``: the softmax probability at index 1 of
            the logits, and the attention weight the model assigned to each
            visit.

        Raises:
            ValueError: if ``case`` is empty or ``case`` and ``dates`` differ
                in length.
        """
        # Fail early with a clear message instead of a shape error raised
        # deep inside the model.
        if not case:
            raise ValueError("case must contain at least one visit")
        if len(case) != len(dates):
            raise ValueError(
                f"case and dates must have the same length (got {len(case)} and {len(dates)})"
            )

        inputs = self.tokenizer(case, return_tensors='pt', max_length=self.cfg.max_length, truncation=True, padding='max_length')
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        visit_times_list = self.format_dates(dates)
        visit_times_tensor = torch.tensor(visit_times_list, dtype=torch.float32).to(self.device)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(inputs['input_ids'], inputs['attention_mask'], visit_times=visit_times_tensor)

        syn_prob = outputs.logits.softmax(dim=-1)[1].item()
        attn_list = outputs.attentions  # already a plain list of floats
        return syn_prob, attn_list

    @staticmethod
    def format_dates(dates_str: list[str]) -> list[tuple[float, float]]:
        """
        Convert one case's ordered string dates into two difference features:
        - log_prev: log1p(delta since previous visit)  (first visit -> 0)
        - log_start: log1p(delta since first visit)    (first visit -> 0)

        Empty entries are treated as unknown dates and imputed with
        ``date_linear_impute`` before the deltas are computed.

        Inputs:
        list of string dates in this format: [10Jan2024, 9Apr2024, ...]
        Returns:
        list of tuples [(log_prev0, log_start0), ...] length == len(dates_str)
        """
        dates = [datetime.strptime(d, "%d%b%Y") if d else None for d in dates_str]
        dates_imp = date_linear_impute(dates)
        return dates_to_log_deltas(dates_imp)
        

In [None]:
# Load the untrained model saved in the "test" folder by the earlier cell and
# run one prediction on a two-visit case. The numbers are meaningless; the
# point is to check that loading and the forward pass work end to end.
pipe = PredictionPipeline(local_model_path="test")
test_inputs = ["Paciente con fiebre y dolor de cabeza.", "Se observa inflamación en las articulaciones."]
test_visit_dates = ['10Jan2024', '9Apr2024']
pipe.predict(test_inputs, test_visit_dates)

torch.Size([2])


(0.5387078523635864, [0.49486151337623596, 0.5051384568214417])