In [None]:
# Colab only: mount Google Drive so the project folder (see the %cd in the next cell)
# and its relative ./data and ./model paths are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd /content/drive/MyDrive/Synchronisé/Cours/Illuin/
!pip install transformers
!pip install datasets
!pip install seqeval

/content/drive/MyDrive/Synchronisé/Cours/Illuin


In [None]:
from math import ceil
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
from datasets import load_metric
from typing import List, Optional, Set, Union
import re
from dataclasses import dataclass
from tqdm import tqdm
from scipy import spatial

In [None]:
# CSV indexes: one row per record, pointing to its txt / concept / rel / ast files.
TRAIN_CSV = './data/train.csv'
VALIDATION_CSV = './data/validation.csv'
# NOTE(review): these training hyper-parameters are not referenced by the embedding
# cells shown in this notebook (get_all_embeddings uses its own batch_size default).
NB_EPOCHS = 5
BATCH_SIZE = 16
LEARNING_RATE = 2e-5

## Preprocess data

In [None]:
# Both index CSVs share the same layout; the first column is used as the index.
df_train, df_valid = (pd.read_csv(csv_path, index_col=0) for csv_path in (TRAIN_CSV, VALIDATION_CSV))
df_train

Unnamed: 0,name,path,concept,rel,ast
0,614746156,./data/train_data/partners/txt/614746156.txt,./data/train_data/partners/concept/614746156.con,./data/train_data/partners/rel/614746156.rel,./data/train_data/partners/ast/614746156.ast
1,record-124,./data/train_data/beth/txt/record-124.txt,./data/train_data/beth/concept/record-124.con,./data/train_data/beth/rel/record-124.rel,./data/train_data/beth/ast/record-124.ast
2,917989835_RWH,./data/train_data/partners/txt/917989835_RWH.txt,./data/train_data/partners/concept/917989835_R...,./data/train_data/partners/rel/917989835_RWH.rel,./data/train_data/partners/ast/917989835_RWH.ast
3,433651389,./data/train_data/partners/txt/433651389.txt,./data/train_data/partners/concept/433651389.con,./data/train_data/partners/rel/433651389.rel,./data/train_data/partners/ast/433651389.ast
4,405868244_YC,./data/train_data/partners/txt/405868244_YC.txt,./data/train_data/partners/concept/405868244_Y...,./data/train_data/partners/rel/405868244_YC.rel,./data/train_data/partners/ast/405868244_YC.ast
...,...,...,...,...,...
131,record-51,./data/train_data/beth/txt/record-51.txt,./data/train_data/beth/concept/record-51.con,./data/train_data/beth/rel/record-51.rel,./data/train_data/beth/ast/record-51.ast
132,194442600_RWH,./data/train_data/partners/txt/194442600_RWH.txt,./data/train_data/partners/concept/194442600_R...,./data/train_data/partners/rel/194442600_RWH.rel,./data/train_data/partners/ast/194442600_RWH.ast
133,record-49,./data/train_data/beth/txt/record-49.txt,./data/train_data/beth/concept/record-49.con,./data/train_data/beth/rel/record-49.rel,./data/train_data/beth/ast/record-49.ast
134,record-15,./data/train_data/beth/txt/record-15.txt,./data/train_data/beth/concept/record-15.con,./data/train_data/beth/rel/record-15.rel,./data/train_data/beth/ast/record-15.ast


In [None]:
# Read each training record's raw text and its concept annotations (.con file).
# Both files are decoded as UTF-8 explicitly so the result does not depend on the
# platform's default encoding.
data = []
for _, row in df_train.iterrows():
    with open(row["path"], encoding="utf-8") as text_file:
        text = text_file.read()
    with open(row["concept"], encoding="utf-8") as concept_file:
        concepts = concept_file.read()
    data.append({"text": text, "concept": concepts})

In [None]:
@dataclass
class EntityAnnotation:
    """One concept annotation parsed from a line of a `.con` file.

    Built by `parse_concept_annotation` from lines shaped like
    `c="<text>" <start_line>:<start_word> <end_line>:<end_word>||t="<label>"`.
    """

    label: str  # concept type, e.g. "problem", "test" or "treatment"
    text: str  # surface text of the entity
    start_line: int  # line of the first word (format_data treats it as 1-based)
    end_line: int  # line of the last word
    start_word: int  # index of the first word within its line
    end_word: int  # index of the last word (inclusive: consumers use end_word + 1 as bound)

In [None]:
# One concept line: c="<text>" <line>:<word> <line>:<word>||t="<label>".
# The entity text is captured greedily so that it may itself contain '=' or spaces.
_CONCEPT_LINE_RE = re.compile(
    r'\s*c="(?P<text>.*)" (?P<start_line>\d+):(?P<start_word>\d+) '
    r'(?P<end_line>\d+):(?P<end_word>\d+)\|\|t="(?P<label>[^"]*)"'
)


def parse_concept_annotation(text: str) -> "Optional[EntityAnnotation]":
    """Parse one line of a `.con` file into an `EntityAnnotation`.

    Args:
        text: a single annotation line, optionally ending with "\\n" or "\\r\\n".

    Returns:
        The parsed annotation, or None when the line does not follow the concept
        format (e.g. the empty string left after the file's final newline).
    """
    match = _CONCEPT_LINE_RE.match(text)
    if match is None:
        return None
    return EntityAnnotation(
        label=match["label"],
        # Quotes are stripped from the entity text, as the original parser did; the
        # separating space before the offsets is outside the group, so no trailing space.
        text=match["text"].replace('"', ""),
        start_line=int(match["start_line"]),
        end_line=int(match["end_line"]),
        start_word=int(match["start_word"]),
        end_word=int(match["end_word"]),
    )

In [None]:
for doc in data:
    # One entry per .con line: an EntityAnnotation, or None for blank/unparseable lines.
    doc["labels"] = [parse_concept_annotation(line) for line in doc["concept"].split("\n")]
data[0]

{'concept': 'c="abdominal pain" 22:10 22:11||t="problem"\nc="nausea" 22:7 22:7||t="problem"\nc="abdominal ct" 45:0 45:1||t="test"\nc="primary colorectal adenocarcinoma" 45:9 45:11||t="problem"\nc="multiple liver metastases" 45:13 45:15||t="problem"\nc="avn" 29:5 29:5||t="problem"\nc="melena" 22:3 22:3||t="problem"\nc="the hematocrit" 47:11 47:12||t="test"\nc="hives" 35:10 35:10||t="problem"\nc="painless jaundice" 16:10 16:11||t="problem"\nc="packed red blood cells" 47:6 47:9||t="treatment"\nc="biopsy" 49:0 49:0||t="test"\nc="further treatment" 51:16 51:17||t="treatment"\nc="glucotrol" 33:0 33:0||t="treatment"\nc="painless jaundice" 44:11 44:12||t="problem"\nc="colonoscopy" 49:4 49:4||t="test"\nc="tenesmus" 24:2 24:2||t="problem"\nc="iron" 54:0 54:0||t="treatment"\nc="an increased appetite" 18:15 18:17||t="problem"\nc="hematochezia" 22:5 22:5||t="problem"\nc="last hemoglobin a1c" 27:10 27:12||t="test"\nc="a 23 pound weight loss" 18:3 18:7||t="problem"\nc="night sweats" 21:2 21:3||t="pro

## Experiments

In [None]:
def format_data(data):
    """Turn parsed documents into per-entity sentence records.

    Args:
        data: documents as dicts with a "text" string and a "labels" list whose
            items are EntityAnnotation instances or None (unparseable lines).

    Returns:
        One list per document. Each list holds one dict per annotation with the
        space-split "sentence" the entity starts on, its "label", and its
        "start_word"/"end_word" indices into that sentence.
    """
    formatted_data = []
    for doc in data:
        lines = doc["text"].split("\n")
        records = []
        for annotation in doc["labels"]:
            # Skip the None placeholders produced for unparseable concept lines.
            if not isinstance(annotation, EntityAnnotation):
                continue
            # NOTE(review): only the start line is used; an entity spanning several
            # lines would get an end_word that indexes a different line.
            records.append({
                "sentence": lines[annotation.start_line - 1].split(" "),  # start_line is 1-based
                "label": annotation.label,
                "start_word": annotation.start_word,
                "end_word": annotation.end_word,
            })
        formatted_data.append(records)
    return formatted_data

# One list of per-entity sentence records per training document.
formatted_data = format_data(data)

In [None]:
# Run the encoder on the GPU whenever one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [None]:
# Tokenizer of the "emilyalsentzer/Bio_ClinicalBERT" checkpoint.
# NOTE(review): the encoder is loaded from ./model; this assumes that checkpoint was
# fine-tuned from Bio_ClinicalBERT and shares its vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

Downloading:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

In [None]:
# Bare encoder from the local ./model checkpoint, configured to return the hidden
# states of every layer. Only `hidden_states` is read later, so the pooler weights
# reported as newly initialized are never used.
model = AutoModel.from_pretrained("./model", output_hidden_states=True)
model = model.to(device)

Some weights of the model checkpoint at ./model were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ./model and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
def get_hidden_states(encoded, token_ids_words, layers=(-4, -3, -2, -1), attention_mask=None):
    """Embed one token span per row by summing hidden layers and averaging tokens.

    The batch is pushed through the global `model`. The hidden states of `layers`
    (the last four by default) are stacked and summed, and for every row `i` the
    sub-word positions in `token_ids_words[i]` are averaged.

    Args:
        encoded: tensor of input ids with one row per sentence.
        token_ids_words: per-row token positions of the span of interest, either the
            tuple returned by `np.where` or any flat sequence of ints.
        layers: indices into `output.hidden_states` that are summed.
        attention_mask: optional mask for `encoded`. When omitted, it is derived from
            `tokenizer.pad_token_id` so that padding tokens are not attended to.

    Returns:
        A list with one 1-D numpy vector per entry of `token_ids_words`. An entry
        with no positions yields a NaN vector (mean over zero tokens).
    """
    if attention_mask is None and getattr(tokenizer, "pad_token_id", None) is not None:
        # Without a mask the model also attends to [PAD] tokens, so a sentence's
        # embedding would change with how much it was padded inside its batch.
        attention_mask = (encoded != tokenizer.pad_token_id).long()
    with torch.no_grad():
        output = model(encoded, attention_mask=attention_mask)
    states = output.hidden_states
    summed = torch.stack([states[layer] for layer in layers]).sum(0)
    hidden_states = []
    for row, ids in enumerate(token_ids_words):
        # np.where returns a 1-tuple of arrays; flattening accepts any id container
        # and always gives (n_tokens, hidden), even for a single token.
        positions = torch.as_tensor(np.asarray(ids).ravel(), dtype=torch.long, device=summed.device)
        hidden_states.append(summed[row, positions].mean(dim=0).cpu().numpy())
    return hidden_states

In [None]:
def get_embedding_from_doc(labels, target_label='problem', batch_size=16, layers=(-4, -3, -2, -1)):
    """Average contextual embedding of every `target_label` entity in one document.

    Args:
        labels: per-entity records of one document, as produced by `format_data`:
            dicts with "sentence" (list of words), "label", and the inclusive word
            span "start_word".."end_word" inside "sentence".
        target_label: only entities whose "label" equals this value are embedded.
        batch_size: number of entity sentences encoded per forward pass.
        layers: hidden-state layers summed by `get_hidden_states`.

    Returns:
        The mean of the entity embeddings as a 1-D numpy array, or None when the
        document has no `target_label` entity that maps to at least one sub-word token.
    """
    selected = [label for label in labels if label["label"] == target_label]
    if not selected:
        # Nothing to embed; avoids tokenizing and encoding an empty batch.
        return None

    entity_embeddings = []
    for offset in range(0, len(selected), batch_size):
        batch = selected[offset:offset + batch_size]
        # Tokenize batch by batch so rows are padded only to the longest sentence of
        # their own batch, not to the longest sentence of the whole document.
        encoded = tokenizer(
            [label["sentence"] for label in batch],
            is_split_into_words=True,
            return_tensors="pt",
            padding=True,
            # Without truncation, a sentence longer than the model's maximum length
            # cannot be encoded at all.
            truncation=True,
        )
        token_ids_words = []
        for row, label in enumerate(batch):
            first, last = label["start_word"], label["end_word"]
            positions = [
                position
                for position, word in enumerate(encoded.word_ids(row))
                if word is not None and first <= word <= last
            ]
            # Same 1-tuple-of-array layout that np.where returns, which is the form
            # get_hidden_states has always received.
            token_ids_words.append((np.asarray(positions, dtype=np.int64),))

        hidden_states = get_hidden_states(encoded.input_ids.to(device=device), token_ids_words, layers)
        # An entity with no sub-word token (span outside its sentence, or truncated
        # away) would average to NaN and poison the document embedding: drop it.
        entity_embeddings.extend(
            state for state, ids in zip(hidden_states, token_ids_words) if ids[0].size > 0
        )

    if not entity_embeddings:
        return None
    return np.stack(entity_embeddings).mean(axis=0)

In [None]:
def get_all_embeddings(docs, target_label='problem', batch_size=2, layers=[-4, -3, -2, -1]):
    """Compute one `target_label` embedding per document, with a progress bar.

    Returns a list aligned with `docs` holding whatever `get_embedding_from_doc`
    returns for each document (None when it has nothing to embed).
    """
    return [
        get_embedding_from_doc(doc, target_label, batch_size, layers)
        for doc in tqdm(docs)
    ]

In [None]:
# One embedding per training document (None when it has no "problem" entity).
embeddings = get_all_embeddings(formatted_data)
len(embeddings)

  
100%|██████████| 136/136 [00:50<00:00,  2.69it/s]


136

In [None]:
def normalize_embeddings(embeddings):
    """Mean-center document embeddings, keeping missing (None) entries as None.

    Args:
        embeddings: list of 1-D numpy vectors or None.

    Returns:
        A new list aligned with `embeddings` in which every vector has the mean of
        all non-None vectors subtracted. When there is no non-None vector there is
        nothing to center against, so a copy of the input list is returned.
    """
    present = [embedding for embedding in embeddings if embedding is not None]
    if not present:
        # np.stack raises on an empty sequence.
        return list(embeddings)
    mean_embedding = np.stack(present).mean(axis=0)
    return [None if embedding is None else embedding - mean_embedding for embedding in embeddings]

# Mean-centred copies of the document embeddings; None entries stay None.
normalized_embeddings = normalize_embeddings(embeddings)

In [None]:
def get_best_docs(i, docs, embeddings, top_k=5):
    """Return the `top_k` documents whose embeddings are most similar to document `i`.

    Similarity is cosine similarity. Document `i` itself is excluded explicitly
    rather than assumed to rank first, which is not guaranteed when another document
    ties with it. Documents without an embedding (None) are excluded too; giving them
    a score of 0 would rank them above real documents with negative similarity.

    Args:
        i: index of the query document.
        docs: per-document payloads returned alongside the scores.
        embeddings: per-document vectors or None, aligned with `docs`.
        top_k: number of neighbours to return.

    Returns:
        Up to `top_k` tuples `(index, similarity, docs[index])`, by decreasing
        similarity (ties keep index order).

    Raises:
        ValueError: if document `i` has no embedding.
    """
    query = embeddings[i]
    if query is None:
        raise ValueError(f"Document {i} has no embedding to compare against.")
    candidates = [
        (j, 1 - spatial.distance.cosine(query, embedding))
        for j, embedding in enumerate(embeddings)
        if j != i and embedding is not None
    ]
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    return [(j, similarity, docs[j]) for j, similarity in candidates[:top_k]]

## Qualitative tests: most similar documents

In [None]:
# With embeddings normalized

idx_compared = 56

# "problem" entity texts of every document, to eyeball why two documents were matched.
doc_labels = [[label.text for label in doc["labels"] if label is not None and label.label == "problem"] for doc in data]
best_docs = get_best_docs(idx_compared, doc_labels, normalized_embeddings)
print("This doc:\n")
print(doc_labels[idx_compared])
for rank, (k, score, labels) in enumerate(best_docs, start=1):
    print("\n-------------------------------------")
    print(f"Match #{rank} (doc {k}, cosine similarity {score:.4f}):\n")
    print(labels)

This doc:

['carotid disease ', 'hepatosplenomegaly ', 'right coronary artery dominant diseased ', 'high cholesterol ', 'high cholesterol ', 'ekg changes ', 'coronary artery disease ', 'pain ', 'atraumatic ', 'cyanosis ', 'lymphadenopathy ', 'clubbing ', 'an acute inferior myocardial infarction ', 'afebrile ', 'mitral regurgitation ', 'murmurs ', 'acute distress ', 'masses ', 'mid chest pain ', 'nondistended ', 'nontender ', 'right carotid bruits ', 'unresponsive ', 'burning ', 'insufficiency ', 'hypokinesis ', 'thyromegaly ', 'narrowed ', 'breathing ', 'edema ']

-------------------------------------
1th best doc (55th doc, score of 0.6866931915283203):

['left tm scarring ', 'chills ', 'anorexia ', 'melena ', 'nausea ', 'lymphadenopathy ', 'hepatosplenomegaly ', 'episodes ', 'recurrent biliary colic ', 'bright red blood per rectum ', 'gallbladder ', 'rub ', 'symptoms ', 'weight loss ', 'ultrasound proven gallstones ', 'apparent distress ', 'recurrent biliary colic ', 'rash ', 'nonten

In [None]:
# Without embeddings normalized

idx_compared = 56

# "problem" entity texts of every document, to eyeball why two documents were matched.
doc_labels = [[label.text for label in doc["labels"] if label is not None and label.label == "problem"] for doc in data]
best_docs = get_best_docs(idx_compared, doc_labels, embeddings)
print("This doc:\n")
print(doc_labels[idx_compared])
for rank, (k, score, labels) in enumerate(best_docs, start=1):
    print("\n-------------------------------------")
    print(f"Match #{rank} (doc {k}, cosine similarity {score:.4f}):\n")
    print(labels)

This doc:

['carotid disease ', 'hepatosplenomegaly ', 'right coronary artery dominant diseased ', 'high cholesterol ', 'high cholesterol ', 'ekg changes ', 'coronary artery disease ', 'pain ', 'atraumatic ', 'cyanosis ', 'lymphadenopathy ', 'clubbing ', 'an acute inferior myocardial infarction ', 'afebrile ', 'mitral regurgitation ', 'murmurs ', 'acute distress ', 'masses ', 'mid chest pain ', 'nondistended ', 'nontender ', 'right carotid bruits ', 'unresponsive ', 'burning ', 'insufficiency ', 'hypokinesis ', 'thyromegaly ', 'narrowed ', 'breathing ', 'edema ']

-------------------------------------
1th best doc (55th doc, score of 0.9809367060661316):

['left tm scarring ', 'chills ', 'anorexia ', 'melena ', 'nausea ', 'lymphadenopathy ', 'hepatosplenomegaly ', 'episodes ', 'recurrent biliary colic ', 'bright red blood per rectum ', 'gallbladder ', 'rub ', 'symptoms ', 'weight loss ', 'ultrasound proven gallstones ', 'apparent distress ', 'recurrent biliary colic ', 'rash ', 'nonten

In [None]:
# With embeddings normalized

idx_compared = 0

# Full record texts, to read the matched documents side by side.
doc_texts = [doc["text"] for doc in data]
best_docs = get_best_docs(idx_compared, doc_texts, normalized_embeddings)
print("This doc:\n")
print(doc_texts[idx_compared])
for rank, (k, score, text) in enumerate(best_docs, start=1):
    print("\n-------------------------------------")
    print(f"Match #{rank} (doc {k}, cosine similarity {score:.4f}):\n")
    print(text)

This doc:

614746156
CTMC
48720920
513332
7/3/1999 12:00:00 AM
Discharge Summary
Signed
DIS
Admission Date :
07/03/1999
Report Status :
Signed
Discharge Date :
EHC9
HISTORY OF PRESENT ILLNESS :
The patient is a 71-year-old police chief who presents with painless jaundice x1 day .
The patient was generally in excellent health with a past medical history significant only for noninsulin dependent diabetes mellitus who was presented with painless jaundice x2 days .
He also noted a 23 pound weight loss in the past 11 months despite having an increased appetite .
The patient also complained of fatigue and &quot; feeling down &quot; .
His wife noted personality changes with increased irritability .
Patient denies night sweats in the past month .
The patient denies melena , hematochezia , nausea , and abdominal pain .
The patient states that he is occasionally constipated .
Also denies tenesmus .
On the day prior to admission , the patient &apos;s family noted that &quot; he looked yellow &quo

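The results are not convincing. The nearest neighbours do not obviously share the query document's problems: the cardiac record 56 is matched first with record 55, which is mostly about biliary colic. Without normalization every similarity is close to 1 (0.98 for that best match, versus 0.69 once the embeddings are mean-centred), so raw cosine similarity barely separates documents.

Next step: try the same approach with Sentence-BERT (SBERT), whose embeddings are trained specifically to be compared with cosine similarity.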