In [1]:
import pandas as pd
import numpy as np
import argparse
import seaborn as sns
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import nltk

In [2]:
!pip install sentence-transformers
!pip install transformers
!pip3 install pickle5

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence-transformers
  Downloading sentence-transformers-2.2.0.tar.gz (79 kB)
[K     |████████████████████████████████| 79 kB 6.8 MB/s 
[?25hCollecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.20.0-py3-none-any.whl (4.4 MB)
[K     |████████████████████████████████| 4.4 MB 49.8 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 52.9 MB/s 
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.7.0-py3-none-any.whl (86 kB)
[K     |████████████████████████████████| 86 kB 4.0 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 50.1 MB/s 
[?25hCollecting tokenizers!=0.

In [3]:
def load_pickle5_file(filepath):
  """Load and return the object pickled in ``filepath``.

  Uses the ``pickle5`` backport when it is installed (it is needed to read
  protocol-5 pickles on Python < 3.8). Otherwise it falls back to the
  standard-library ``pickle``, which supports protocol 5 natively from
  Python 3.8 onwards.

  Args:
    filepath: Path to the pickle file.

  Returns:
    The unpickled object.

  Raises:
    OSError: If the file cannot be opened.

  Warning: unpickling can execute arbitrary code, so only load trusted files.
  """
  try:
    import pickle5 as pickle
  except ImportError:
    import pickle
  with open(filepath, "rb") as fh:
    return pickle.load(fh)

In [4]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a kernel
# without CUDA (expect encoding to be much slower on CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [5]:
# Create the general sentence-embedding architecture: Bio_ClinicalBERT token
# embeddings followed by a pooling layer that yields one vector per sentence.
from sentence_transformers import SentenceTransformer, models
from transformers import AutoTokenizer, AutoModel

MODEL_NAME = 'emilyalsentzer/Bio_ClinicalBERT'
MAX_SEQ_LENGTH = 256  # longer inputs are truncated to this many tokens

word_embedding_model = models.Transformer(MODEL_NAME, max_seq_length=MAX_SEQ_LENGTH)
# Pooling with its default settings (sentence-transformers defaults to mean pooling);
# output size equals the transformer's embedding dimension.
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

# Pass `device` to the constructor as well as calling .to(), so that encode(),
# which runs on the model's configured device by default, uses this device too.
model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=str(device))
model.to(device)
print("model sent to device")

Downloading:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/416M [00:00<?, ?B/s]

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

model sent to device


## Training

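Before any training, let's see how the pretrained model performs out of the box: embed each condition and check how the embeddings cluster.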

In [7]:
# Get the conditions data.
# DATA_DIR appears to be the Google Drive mount in Colab; change it to run elsewhere.
DATA_DIR = '/content/drive/MyDrive'
conditions = pd.read_csv(os.path.join(DATA_DIR, 'db_int_conditions.csv'))

In [13]:
# NOTE(review): the column names ('0', 'Neuromuscular Blockade') suggest the CSV's
# header row is really the first data row; it is re-added here as a regular row — confirm.
# DataFrame.append was removed in pandas 2.0, so build a one-row frame and concat it.
# Re-running this cell appends the row again.
conditions = pd.concat(
    [conditions, pd.DataFrame([{'0': 10608, 'Neuromuscular Blockade': 'Neuromuscular Blockade'}])],
    ignore_index=True,
)

In [15]:
# Give the columns meaningful names. Reassign instead of using inplace=True so the
# step reads as an explicit transformation of `conditions`.
conditions = conditions.rename(columns={'0': 'id', 'Neuromuscular Blockade': 'condition'})
assert {'id', 'condition'} <= set(conditions.columns), "expected 'id' and 'condition' columns"

In [18]:
# Embed every condition string (one row of `encoded` per row of `conditions`).
# SentenceTransformer.encode expects a str or a list of str; pass a plain list so we
# don't depend on how it indexes a pandas Series (by label vs. by position).
encoded = model.encode(conditions['condition'].tolist())

In [22]:
# Store each embedding (as a plain Python list) next to its condition. The Series is
# built on `conditions.index` so rows pair up by position even if the index is not
# 0..n-1; a length mismatch raises instead of silently producing NaNs.
conditions['encoded'] = pd.Series(encoded.tolist(), index=conditions.index)

In [47]:
# Cluster the condition embeddings with average-linkage agglomerative clustering on
# cosine distance. The threshold is tiny, so only near-duplicate embeddings share a cluster.
from sklearn.cluster import AgglomerativeClustering

DISTANCE_THRESHOLD = .01  # clusters at or above this average cosine distance are not merged

try:
    # scikit-learn >= 1.2 calls the distance `metric`; `affinity` was removed in 1.4.
    cluster = AgglomerativeClustering(n_clusters=None, metric='cosine', linkage='average',
                                      distance_threshold=DISTANCE_THRESHOLD)
except TypeError:
    # Older scikit-learn (< 1.2) only accepts `affinity`.
    cluster = AgglomerativeClustering(n_clusters=None, affinity='cosine', linkage='average',
                                      distance_threshold=DISTANCE_THRESHOLD)
cluster.fit(encoded)
# labels_ is ordered like the rows of `encoded`, i.e. like the rows of `conditions`.
conditions['cluster_label'] = cluster.labels_

In [48]:
# Number of conditions per cluster, largest clusters first.
cluster_sizes = conditions['cluster_label'].value_counts()
cluster_sizes

4494    13
216     12
251     11
442     11
519     10
        ..
4991     1
7510     1
5689     1
7178     1
345      1
Name: cluster_label, Length: 9297, dtype: int64

In [51]:
# Inspect one of the larger clusters (label 519 had 10 members in this run).
# NOTE(review): label ids are only meaningful for this run; they may change if the embeddings do.
in_cluster = conditions['cluster_label'].eq(519)
conditions.loc[in_cluster]

Unnamed: 0,id,condition,encoded,cluster_label
2420,2421,Stage IVA Uterine Corpus Cancer AJCC v7,"[0.09006096422672272, 0.05761827155947685, -0....",519
3831,3832,Stage IA Uterine Corpus Cancer AJCC v7,"[0.10220805555582047, 0.12565310299396515, -0....",519
3832,3833,Stage IB Uterine Corpus Cancer AJCC v7,"[0.13736595213413239, 0.08130642771720886, -0....",519
3833,3834,Stage II Uterine Corpus Cancer AJCC v7,"[0.10111526399850845, 0.10834573954343796, -0....",519
3834,3835,Stage IIIA Uterine Corpus Cancer AJCC v7,"[0.06422608345746994, 0.10428615659475327, -0....",519
3835,3836,Stage IIIB Uterine Corpus Cancer AJCC v7,"[0.10565837472677231, 0.14559291303157806, -0....",519
3836,3837,Stage IIIC Uterine Corpus Cancer AJCC v7,"[0.10871592164039612, 0.0932564064860344, -0.2...",519
3883,3884,Stage I Uterine Corpus Cancer AJCC v7,"[0.0867515280842781, 0.10899213701486588, -0.2...",519
7927,7928,Stage III Uterine Corpus Cancer AJCC v7,"[0.09293092787265778, 0.12446162104606628, -0....",519
7928,7929,Stage IV Uterine Corpus Cancer AJCC v7,"[0.11784208565950394, 0.08415602892637253, -0....",519
