In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import torch

In [3]:
HERE = Path(os.getcwd())
DATA_PATH = HERE.parent / 'data' / '1.1' / 'dev'

In [4]:
import xml.etree.ElementTree as ET

tree = ET.parse(DATA_PATH / 'archehr-qa.xml')
root = tree.getroot()

for c in root.findall('case'):
    print(c.find('patient_narrative').text)
    print(c.find('patient_question').text)
    break


I had severe abdomen pain and was hospitalised for 15 days in ICU, diagnoised with CBD sludge thereafter on udiliv. Doctor advised for ERCP. My question is if the sludge was there does not the medication help in flushing it out? Whether ERCP was the only cure?
        

            


In [5]:
import json

with open(DATA_PATH / 'archehr-qa_key.json', 'r') as f:
    data = json.load(f)

In [6]:
data_pairs = []

for c, label in zip(root.findall('case'), data, strict=True):
    data_pairs.append(
        {
            'narrative': c.find('patient_narrative').text,
            'patient_question': c.find('patient_question').find('phrase').text,
            'clinician_question': c.find('clinician_question').text,
            'sentences': [
                (i, sentence.text, answer['relevance'])
                for i, (sentence, answer) in enumerate(
                    zip(c.find('note_excerpt_sentences').findall('sentence'), label['answers'], strict=True)
                )
            ]
        }
    )

In [7]:
from collections import Counter
# Look at the distribution of phrases importance

l = [a for c in data_pairs for (_, _, a) in c['sentences']]
Counter(l)

Counter({'not-relevant': 239, 'essential': 138, 'supplementary': 51})

## Test a Cross-encoder

cross-encoders are mdels that take both the query and the phrase as inputs and return an embedding.  
Here we will use the `MS Marco` model (popular on [huggingface](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2)) 

In [None]:
from sentence_transformers import CrossEncoder

cross_model_name = 'cross-encoder/ms-marco-MiniLM-L6-v2'
model = CrossEncoder(cross_model_name)

In [177]:
# Encode queries for first case
c = data_pairs[0]
queries = [c['narrative'], c['patient_question'], c['clinician_question']]
sentences = [s for (_, s, _) in c['sentences']]

# Make the pairs
questions = [
    (q, d) for q in queries for d in sentences 
]

In [179]:
results = model.predict(questions)

## Test deberta model

The deBerta model ([here](https://huggingface.co/cross-encoder/nli-deberta-v3-base))
gives for a sentence pair the corresponding labels: contradiction, entailment, neutral.  

For our case we can just convert those labels to non-essential, essential, supplementary ?

In [8]:
from sentence_transformers import CrossEncoder

cross_model_name = 'cross-encoder/nli-deberta-v3-base'
model = CrossEncoder(cross_model_name)

  from .autonotebook import tqdm as notebook_tqdm


In [32]:
list(model.named_modules())[-1]

('activation_fn', Identity())

In [38]:
list(model.named_modules())[:-1]

[('',
  CrossEncoder(
    (model): DebertaV2ForSequenceClassification(
      (deberta): DebertaV2Model(
        (embeddings): DebertaV2Embeddings(
          (word_embeddings): Embedding(128100, 768, padding_idx=0)
          (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): DebertaV2Encoder(
          (layer): ModuleList(
            (0-11): 12 x DebertaV2Layer(
              (attention): DebertaV2Attention(
                (self): DisentangledSelfAttention(
                  (query_proj): Linear(in_features=768, out_features=768, bias=True)
                  (key_proj): Linear(in_features=768, out_features=768, bias=True)
                  (value_proj): Linear(in_features=768, out_features=768, bias=True)
                  (pos_dropout): Dropout(p=0.1, inplace=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): DebertaV2SelfOutput(
 

In [37]:
# Remove the last layer
new_model = torch.nn.Sequential(*list(model.named_modules())[:-1])
new_model

TypeError: tuple is not a Module subclass

In [None]:
# Make the predictions
scores = model.predict(questions)

label_mapping = ['non-essential', 'essential', 'supplementary']
labels = [label_mapping[score_max] for score_max in scores.argmax(axis=1)]

In [16]:
def make_questions(c):
    # Make the queries
    queries = [c['narrative'], c['patient_question'], c['clinician_question']]
    sentences = [s for (_, s, _) in c['sentences']]

    # Make the pairs
    questions = [
        (q, d) for q in queries for d in sentences 
    ]

    # Get the answers
    answers = [r for (_, _, r) in c['sentences']]

    return questions, answers

In [18]:
scores = []
answers = []

for c in data_pairs:
    # Make the questions
    questions, a = make_questions(c)

    # Calculate the scores
    s = model.predict(questions)
    scores.extend(s)
    answers.extend(a)


In [69]:
import numpy as np

def make_prediction(scores, label_mapping):
    scores = np.array(scores)
    labels = np.array(
        [label_mapping[score_max] for score_max in scores.argmax(axis=1)]
    ).reshape((len(scores) // 3, 3))
    
    def pref(row):
        if 'essential' in row:
            return 'essential'
        
        elif 'not-relevant' in row:
            return 'not-relevant'

        else:
            return 'supplementary'
    
    labels = np.apply_along_axis(pref, axis=1, arr=labels)

    return labels

In [87]:
# Get the real values
truth = np.array(answers)

# Make all the possible permutations
cl = set(np.unique(truth))
choices = [[a] for a in cl]

for i in range(len(cl) - 1):
    new_choices = [l + [a] for l in choices for a in (cl - set(l))]
    choices = new_choices

# For each possible permutation look at the acc
choice_scores = []

for choice in choices:
    labels = make_prediction(scores, choice)

    mask = (truth == 'essential')
    acc = sum(labels[mask] == truth[mask]) / len(labels[mask])

    choice_scores.append((choice, float(acc)))

best_choice, acc = sorted(choice_scores, key=lambda x: x[1], reverse=True)[0]

labels = make_prediction(scores, best_choice)
print(f"{acc:.1%}")

91.3%


In [95]:
scores[:5]

[array([-2.6728544, -2.2185981,  4.880722 ], dtype=float32),
 array([-2.147493 , -2.4036314,  4.593408 ], dtype=float32),
 array([-2.4452193, -2.8953137,  5.480044 ], dtype=float32),
 array([ 2.9081976, -4.7010617,  2.8155947], dtype=float32),
 array([ 3.8878734, -5.3210015,  2.6661558], dtype=float32)]

In [94]:
np.array(scores).reshape((len(truth), 3, 3))[:3]

array([[[-2.6728544, -2.2185981,  4.880722 ],
        [-2.147493 , -2.4036314,  4.593408 ],
        [-2.4452193, -2.8953137,  5.480044 ]],

       [[ 2.9081976, -4.7010617,  2.8155947],
        [ 3.8878734, -5.3210015,  2.6661558],
        [-1.6935804, -3.488221 ,  5.5374184]],

       [[-1.8459907, -2.6642559,  4.602802 ],
        [-1.0479984, -3.4409842,  4.8996525],
        [ 1.2287819, -4.7787423,  4.574991 ]]], dtype=float32)

## Test a Sentence transformer

Embeddings models are useful for RAG models, here we test to embed
the narrative / patient_question / clinician_question as well as the
sentences of the excerpt.

Then we use both embeddings to try to predict the utility of the phrase.

In [None]:
from sentence_transformers import SentenceTransformer

# Load the model
model_name = 'Snowflake/snowflake-arctic-embed-l-v2.0'
model = SentenceTransformer(model_name)

# Encode queries for first case
c = data_pairs[0]
queries = [c['narrative'], c['patient_question'], c['clinician_question']]
documents = [s for (_, s, _) in c['sentences']]

query_embeddings = model.encode(queries, prompt_name='query')
doc_embeddings = model.encode(documents)

In [36]:
scores = model.similarity(query_embeddings, doc_embeddings)

In [160]:
# Make a small classification head to see if it grants good results
import torch
import torch.nn as nn
from collections import OrderedDict

class MLP(nn.Module):

    def __init__(
        self,
        in_dim: int,
        n_layers: int = 3,
        hidden_dim: int = 2048,
        output_dim: int = 2,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.layers = self._make_stack(
            self.n_layers,
            self.in_dim,
            self.hidden_dim,
            self.output_dim
        )
    
    @staticmethod
    def _make_stack(n_layers, in_dim, hidden_dim, output_dim):
        layers = OrderedDict()
        layers['layer0'] = nn.Linear(in_dim, hidden_dim)
        layers['gelu0'] = nn.ReLU()

        for i in range(n_layers - 1):
            layers[f'layer{i+1}'] = nn.Linear(hidden_dim, hidden_dim)
            layers[f'gelu{i+1}'] = nn.ReLU()

        layers['final_layer'] = nn.Linear(hidden_dim, output_dim)

        return nn.Sequential(layers)

    def forward(self, x):
        return self.layers(x)


In [None]:
from torch.utils.data import Dataset

# Train with the clinician question
class CustomDataset(Dataset):
   
    def __init__(self, data, embed_model):
        super().__init__()
        self.embed_model = embed_model
        self.data = self._make_pairs(data, embed_model)
        self.translate_dict = {
            a: i 
            for i, a in enumerate(set([lab for _, lab in self.data]))
        }

    def _make_pairs(self, data, model_name):
        model = SentenceTransformer(model_name)
        
        d = []
        for c in data:
            query = torch.tensor(model.encode(c['clinician_question'], prompt_name='query'))
            phrases = [
                (torch.tensor(model.encode(p)), lab) for _, p, lab in c['sentences']
            ]

            d.extend([(torch.hstack((query, p)), lab) for p, lab in phrases])
        
        return d
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        v, lab = self.data[index]
        return v, self.translate_dict[lab]



In [155]:
def eval_model(model, dataloader, target, progress_bar = None):
    model.eval()
    preds, labels = [], []

    with torch.no_grad():
        for inp, lab in dataloader:
            # calculate the predictions of the model
            p = model(inp)
            preds.append(p.argmax(dim=1))
            labels.append(lab)

        labels = torch.concat(labels)
        preds = torch.concat(preds)

        # find tp & fo
        tp = sum((labels == preds)[labels == target])
        fp = sum((labels != preds)[preds == target])

        # calculate the metrics
        acc = (sum(labels == preds) / len(labels)).item()
        rec = (tp / sum(labels == target)).item()
        ppv = (tp / (tp + fp)).item()

    if progress_bar:
        progress_bar.set_postfix(acc=f'{acc:.1%}', recall=f'{rec:.1%}', ppv=f'{ppv:.1%}')

    return acc, rec, ppv



In [157]:
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

# Instantiate everything needed
# The data
dataset = CustomDataset(data_pairs, model_name)
dataloader = DataLoader(dataset, batch_size=32)


MLP(
  (layers): Sequential(
    (layer0): Linear(in_features=2048, out_features=128, bias=True)
    (gelu0): GELU(approximate='none')
    (layer1): Linear(in_features=128, out_features=128, bias=True)
    (gelu1): GELU(approximate='none')
    (layer2): Linear(in_features=128, out_features=128, bias=True)
    (gelu2): GELU(approximate='none')
    (final_layer): Linear(in_features=128, out_features=3, bias=True)
  )
)

In [163]:
# The model
model = MLP(in_dim=2048, hidden_dim=1024, output_dim=len(dataset.translate_dict))

# The loss and optimizer
loss = nn.CrossEntropyLoss()
optim = AdamW(model.parameters())

model.train()

MLP(
  (layers): Sequential(
    (layer0): Linear(in_features=2048, out_features=1024, bias=True)
    (gelu0): ReLU()
    (layer1): Linear(in_features=1024, out_features=1024, bias=True)
    (gelu1): ReLU()
    (layer2): Linear(in_features=1024, out_features=1024, bias=True)
    (gelu2): ReLU()
    (final_layer): Linear(in_features=1024, out_features=3, bias=True)
  )
)

In [164]:
n_epochs = 1000
accuracy, recall, positive_pred = [], [], [] 

for i in (progress_bar := tqdm(range(n_epochs))):
    model.train()
    for inputs, labels in dataloader:
        # calculate the predictions of the model
        preds = model(inputs)
        
        # calculate the loss
        l = loss(preds, labels)
        
        # backpropagation
        l.backward()

        # optimize parameters
        optim.step()

    if i % 100 == 0:
        # calculate the metrics
        acc, rec, ppv = eval_model(
            model,
            dataloader,
            dataset.translate_dict.get('essential'),
            progress_bar
        )

        accuracy.append(acc)
        recall.append(rec)
        positive_pred.append(ppv)

100%|██████████| 1000/1000 [07:46<00:00,  2.14it/s, acc=56.5%, ppv=nan%, recall=0.0%]  


In [128]:
model.eval()

preds, labels = [], [] 

for inp, lab in tqdm(dataloader):
    p = model(inp)

    preds.append(p.argmax(dim=1))
    labels.append(lab)

100%|██████████| 14/14 [00:00<00:00, 91.64it/s]
