---
title: "Semantic clustering of PanUKB phenotypes"
author: "Saikat Banerjee"
format:
  html: default
date: "2024-04-02"
file-modified: "2024-04-02"
abstract: "We use pre-trained language models to cluster PanUKB phenotypes"

---

In [2]:
import os
import numpy as np
import pandas as pd
import pickle
import re

import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils

mpl_stylesheet.banskt_presentation(splinecolor = 'black', dpi = 120)

In [3]:
import torch
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from sentence_transformers import util as st_util

In [4]:
data_dir = "/gpfs/commons/home/sbanerjee/work/npd/PanUKB/data"
trait_df  = pd.read_pickle(os.path.join(data_dir, "modselect/traits_all_with_desc.pkl"))
trait_df

Unnamed: 0,zindex,trait_type,phenocode,pheno_sex,coding,modifier,description,description_more,coding_description,category,BIN_QT,n_cases_EUR,n_controls_EUR,N,Neff,filename,aws_link,estimates.final.h2_observed,long_description,short_description
0,1,icd10,A04,both_sexes,,,A04 Other bacterial intestinal infections,truncated: true,,Chapter I Certain infectious and parasitic dis...,BIN,3088,417443.0,420531,6130.649032,icd10-A04-both_sexes.tsv.bgz,https://pan-ukb-us-east-1.s3.amazonaws.com/sum...,0.0033,A04 Other bacterial intestinal infections,A04 Bacterial intestinal infections
1,2,icd10,A08,both_sexes,,,A08 Viral and other specified intestinal infec...,truncated: true,,Chapter I Certain infectious and parasitic dis...,BIN,1107,419424.0,420531,2208.171897,icd10-A08-both_sexes.tsv.bgz,https://pan-ukb-us-east-1.s3.amazonaws.com/sum...,0.0001,A08 Viral and other specified intestinal infec...,"A08 Viral, other intestinal infections"
2,3,icd10,A09,both_sexes,,,A09 Diarrhoea and gastro-enteritis of presumed...,truncated: true,,Chapter I Certain infectious and parasitic dis...,BIN,9029,411502.0,420531,17670.286180,icd10-A09-both_sexes.tsv.bgz,https://pan-ukb-us-east-1.s3.amazonaws.com/sum...,0.0035,A09 Diarrhoea and gastro-enteritis of presumed...,"A09 Diarrhoea, infectious gastro-enteritis"
3,4,icd10,A41,both_sexes,,,A41 Other septicaemia,truncated: true,,Chapter I Certain infectious and parasitic dis...,BIN,5512,415019.0,420531,10879.505810,icd10-A41-both_sexes.tsv.bgz,https://pan-ukb-us-east-1.s3.amazonaws.com/sum...,0.0011,A41 Other septicaemia,A41 Other septicaemia
4,5,icd10,B34,both_sexes,,,B34 Viral infection of unspecified site,truncated: true,,Chapter I Certain infectious and parasitic dis...,BIN,2129,418402.0,420531,4236.443249,icd10-B34-both_sexes.tsv.bgz,https://pan-ukb-us-east-1.s3.amazonaws.com/sum...,0.0003,B34 Viral infection of unspecified site,B34 Viral infection of unspecified site
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2478,2479,continuous,Smoking,both_sexes,,Ever_Never,"Smoking status, ever vs never",Ever (previous + current smoker) vs never base...,,,QT,418817,,418817,418817.000000,continuous-Smoking-both_sexes-Ever_Never.tsv.bgz,https://pan-ukb-us-east-1.s3.amazonaws.com/sum...,0.1100,"Smoking status, ever vs never","Smoking status, ever vs never"
2479,2480,continuous,eGFR,both_sexes,,irnt,"Estimated glomerular filtration rate, serum cr...",eGFR based on serum creatinine (30700) using t...,,,QT,401867,,401867,401867.000000,continuous-eGFR-both_sexes-irnt.tsv.bgz,https://pan-ukb-us-east-1.s3.amazonaws.com/sum...,0.2070,"Estimated glomerular filtration rate, serum cr...","Estimated GFR, serum creatinine"
2480,2481,continuous,eGFRcreacys,both_sexes,,irnt,"Estimated glomerular filtration rate, cystain C",eGFR based on cystain C (30720) using the CKD-...,,,QT,401570,,401570,401570.000000,continuous-eGFRcreacys-both_sexes-irnt.tsv.bgz,https://pan-ukb-us-east-1.s3.amazonaws.com/sum...,0.2380,"Estimated glomerular filtration rate, cystain C","Estimated GFR, cystain C"
2481,2482,continuous,eGFRcys,both_sexes,,irnt,"Estimated glomerular filtration rate, serum cr...",eGFR based on serum creatinine (30700) and cys...,,,QT,402031,,402031,402031.000000,continuous-eGFRcys-both_sexes-irnt.tsv.bgz,https://pan-ukb-us-east-1.s3.amazonaws.com/sum...,0.2240,"Estimated glomerular filtration rate, serum cr...","Estimated GFR, serum creatinine + cystain C"


# Phenotype Description to Sentences

First, we collect the long description of each phenotype into a list of sentences. These sentences are the inputs to the embedding models.

In [5]:
long_desc  = trait_df['long_description'].tolist()

# Get embeddings from sentences

I decided to use [Sentence Transformers](https://www.sbert.net/) to embed the long description of each phenotype, and then cluster the embeddings. 
Measuring the semantic similarity of short texts is a natural use case for pre-trained language models. 

Several pre-trained models are available on [Hugging Face](https://huggingface.co/), and any of them can be loaded. Some are trained specifically for Sentence Transformers, as you can search [here](https://huggingface.co/models?library=sentence-transformers&sort=trending). I found a few that are trained on medical data:
1. https://huggingface.co/menadsa/S-Bio_ClinicalBERT
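2. https://huggingface.co/ls-da3m0ns/bge_large_medical

Of these, only `ls-da3m0ns/bge_large_medical` is used below. I also compare two clinical BERT models that were not trained for Sentence Transformers, `medicalai/ClinicalBERT` and `emilyalsentzer/Bio_ClinicalBERT`; for these, sentence embeddings are obtained by mean pooling of the token embeddings.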
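Sentence embeddings are usually compared with cosine similarity, which is also the score that `st_util.community_detection` thresholds on. The sketch below computes a pairwise cosine similarity matrix in plain NumPy, similar in spirit to `st_util.cos_sim`; the function name and the 3-dimensional toy vectors are made up for illustration and are not model output.

```python
import numpy as np

def cosine_similarity_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of `a` and the rows of `b` (hypothetical helper)."""
    # Normalize each row to unit length; clip the norm to avoid division by zero
    a_unit = a / np.clip(np.linalg.norm(a, axis=1, keepdims=True), 1e-12, None)
    b_unit = b / np.clip(np.linalg.norm(b, axis=1, keepdims=True), 1e-12, None)
    # Dot products of unit vectors are cosine similarities
    return a_unit @ b_unit.T

# Toy "embeddings" (made-up values)
toy = np.array([
    [1.0, 0.0, 0.0],
    [2.0, 0.0, 0.0],  # same direction as row 0, different length
    [0.0, 1.0, 0.0],  # orthogonal to row 0
])
sim = cosine_similarity_matrix(toy, toy)
```

Rows 0 and 1 get similarity 1 despite their different lengths, and rows 0 and 2 get 0: only the direction of an embedding matters, not its magnitude.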

In [6]:
# Mean pooling: average the token embeddings, using the attention mask so that padding tokens are ignored
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def get_embeddings(sentences, model_name, use_pooling = False):
    """Return one embedding per sentence as a CPU tensor of shape (n_sentences, dim)."""
    if use_pooling:
        # Plain transformer: load tokenizer and model from the HuggingFace Hub
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)

        # Tokenize all sentences in a single padded batch.
        # Note: truncation=True has no effect when the tokenizer defines no maximum length;
        # this is what triggers the "Asking to truncate to max_length" warning.
        encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

        # Compute token embeddings without tracking gradients
        with torch.no_grad():
            model_output = model(**encoded_input)

        # Pool token embeddings into sentence embeddings (mean pooling)
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    else:
        # Sentence Transformer model: pooling is already part of the model
        model = SentenceTransformer(model_name)
        sentence_embeddings = model.encode(sentences, batch_size=64, show_progress_bar=True, convert_to_tensor=True)

    # Move to CPU so that scikit-learn can consume the tensor downstream
    return sentence_embeddings.cpu()
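The `mean_pooling` helper averages only the real (non-padding) tokens of each sentence. The same arithmetic in plain NumPy, on a made-up padded batch, shows why the mask matters: the padding vector `[100, 100]` does not leak into the sentence embedding. The function name `masked_mean_pool` is hypothetical.

```python
import numpy as np

def masked_mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors per sentence, ignoring padding positions.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len) with 0/1 entries.
    """
    mask = attention_mask[:, :, None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)                   # sum over real tokens only
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                    # number of real tokens
    return summed / counts

# One sentence with 2 real tokens followed by 1 padding token (toy values)
tokens = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
pooled = masked_mean_pool(tokens, mask)  # → [[2.0, 3.0]]
```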

In [7]:
model_names = {
    "ls-da3m0ns/bge_large_medical" : "SentenceTransformer",
    "medicalai/ClinicalBERT" : "Transformer",
    "emilyalsentzer/Bio_ClinicalBERT" : "Transformer"
}

In [8]:
embeddings = dict()
for model_name, model_type in model_names.items():
    use_pooling = (model_type != "SentenceTransformer")
    embeddings[model_name] = get_embeddings(long_desc, model_name, use_pooling = use_pooling)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


# Compute clusters from embeddings

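We compute clusters in two ways: agglomerative (hierarchical) clustering into a fixed number of clusters, and community detection (`sentence_transformers.util.community_detection`), where the cosine-similarity threshold is lowered until at least the requested number of communities is found. Note that, despite its name, `compute_clusters_kmeans` (and hence the `clusters_kmeans` results) uses `AgglomerativeClustering` with scikit-learn's default Ward linkage, not K-Means.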

In [22]:
from sklearn.cluster import AgglomerativeClustering

def compute_clusters_kmeans(embeddings, n_clusters = 30):
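    # Hierarchical (agglomerative) clustering into a fixed number of clusters, using the default Ward linkage.
    # A cosine-based alternative: AgglomerativeClustering(n_clusters=n_clusters, metric='cosine', linkage='average')
    # (`metric` was called `affinity` before scikit-learn 1.2)
    clustering_model = AgglomerativeClustering(n_clusters = n_clusters)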
    clustering_model.fit(embeddings)
    cluster_assignment = clustering_model.labels_
    phenotype_clusters = [list(np.where(cluster_assignment == i)[0]) for i in range(np.max(cluster_assignment) + 1)]
    return phenotype_clusters

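def compute_cluster_community(embeddings, n_size = 10, n_clusters = 30, thres_step = 0.05, min_threshold = 0.0):
    # Start from the strictest cosine-similarity threshold and relax it step by step until
    # community detection returns at least `n_clusters` communities of size >= `n_size`.
    # `min_threshold` bounds the search so that the loop always terminates.
    threshold = 1.0
    clusters = st_util.community_detection(embeddings, min_community_size = n_size, threshold = threshold)
    while len(clusters) < n_clusters and threshold - thres_step >= min_threshold:
        threshold -= thres_step
        clusters = st_util.community_detection(embeddings, min_community_size = n_size, threshold = threshold)
    return clusters, threshold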
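Community detection groups embeddings whose cosine similarity reaches the threshold and keeps only groups with at least `min_community_size` members, so, unlike agglomerative clustering, the communities need not cover every phenotype. The following is a simplified greedy sketch of that idea in NumPy; it is not the library's actual implementation, and `greedy_communities` and the toy data are hypothetical.

```python
import numpy as np

def greedy_communities(embeddings: np.ndarray, threshold: float, min_size: int) -> list:
    """Simplified greedy community detection on cosine similarity (illustrative sketch only)."""
    # Cosine similarity between all pairs of rows
    unit = embeddings / np.clip(np.linalg.norm(embeddings, axis=1, keepdims=True), 1e-12, None)
    sim = unit @ unit.T
    # Candidate community around each point: all points at or above the threshold (including itself)
    candidates = [np.flatnonzero(row >= threshold).tolist() for row in sim]
    candidates = [c for c in candidates if len(c) >= min_size]
    # Consider the largest candidates first
    candidates.sort(key=len, reverse=True)
    used, communities = set(), []
    for cand in candidates:
        # Keep only members not already claimed by an earlier (larger) community
        fresh = [i for i in cand if i not in used]
        if len(fresh) >= min_size:
            communities.append(fresh)
            used.update(fresh)
    # Points never claimed by any community remain unassigned
    return sorted(communities, key=len, reverse=True)

# Toy 2-D embeddings: two tight groups plus one outlier (made-up values)
toy = np.array([
    [1.0, 0.0], [0.99, 0.1], [0.98, 0.15],  # group A
    [0.0, 1.0], [0.1, 0.99],                # group B
    [-1.0, 0.0],                            # outlier
])
comms = greedy_communities(toy, threshold=0.9, min_size=2)
```

The outlier (index 5) belongs to no community because its only neighbour above the threshold is itself. Lowering `threshold`, as `compute_cluster_community` does, lets looser groups pass the size filter.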

In [25]:
clusters_community = dict()
clusters_kmeans = dict()
for model_name in model_names.keys():
    clusters_community[model_name], threshold = compute_cluster_community(embeddings[model_name], thres_step = 0.01, n_clusters = 20)
    clusters_kmeans[model_name] = compute_clusters_kmeans(embeddings[model_name], n_clusters = 20)
    
    kmeans_sizes = ", ".join([f"{len(x)}" for x in clusters_kmeans[model_name]])
    print (f"{model_name}\n\tCommunity: {len(clusters_community[model_name])} clusters | auto threshold = {threshold:.2f}\n\tKMeans: [{kmeans_sizes}]\n")

ls-da3m0ns/bge_large_medical
	Community: 21 clusters | auto threshold = 0.70
	KMeans: [128, 177, 117, 184, 188, 75, 99, 148, 438, 98, 56, 24, 150, 189, 85, 154, 50, 35, 45, 43]

medicalai/ClinicalBERT
	Community: 20 clusters | auto threshold = 0.83
	KMeans: [263, 125, 238, 117, 141, 72, 117, 144, 260, 38, 94, 188, 268, 127, 21, 54, 72, 45, 56, 43]

emilyalsentzer/Bio_ClinicalBERT
	Community: 29 clusters | auto threshold = 0.93
	KMeans: [68, 345, 334, 210, 77, 170, 147, 47, 93, 115, 76, 34, 69, 119, 163, 109, 31, 55, 106, 115]
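The two methods return lists of index lists, but agglomerative clustering assigns every phenotype while community detection may leave some out. Converting each result to one label per phenotype, with `-1` for "no cluster", makes the two easier to compare side by side. A minimal sketch; `clusters_to_labels` is a hypothetical helper and the toy clusters are made up.

```python
import numpy as np

def clusters_to_labels(clusters, n_items: int) -> np.ndarray:
    """Map a list of index lists to one integer label per item; -1 marks items in no cluster."""
    labels = np.full(n_items, -1, dtype=int)
    for cluster_id, members in enumerate(clusters):
        labels[np.asarray(members, dtype=int)] = cluster_id
    return labels

# Toy clusters over 6 items; item 5 belongs to no cluster
labels = clusters_to_labels([[0, 1, 2], [3, 4]], n_items=6)  # labels == [0, 0, 0, 1, 1, -1]
n_unassigned = int((labels == -1).sum())
```

In this notebook, a call such as `clusters_to_labels(clusters_community[model_name], len(long_desc))` would give the per-phenotype community labels for one model.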



In [26]:
for m, citems in clusters_community.items():
    
    print (f"Model name: {m}")
    
    nc = len(citems)
    
    # Show the first 3 and last 3 members of the 3 largest and 3 smallest communities
    # (community_detection returns communities sorted by size, largest first)
    for i in list(range(3)) + list(range(nc))[-3:]:
        cluster = citems[i]
        print("\nCluster {}, #{} Elements ".format(i + 1, len(cluster)))
        for sentence_id in cluster[0:3]:
            print("\t", long_desc[sentence_id])
        print("\t", "...")
        for sentence_id in cluster[-3:]:
            print("\t", long_desc[sentence_id])
    print ("---------------------")
    print ()

Model name: ls-da3m0ns/bge_large_medical

Cluster 1, #32 Elements 
	 Other fruit intake
	 Mixed fruit intake
	 Other sweets intake
	 ...
	 Grapefruit juice intake
	 Other cheese intake
	 Side salad intake

Cluster 2, #30 Elements 
	 6mm cylindrical power angle (left)
	 6mm cylindrical power angle (right)
	 3mm cylindrical power angle (left)
	 ...
	 3mm strong meridian angle (right)
	 3mm weak meridian angle (right)
	 6mm regularity index (right)

Cluster 3, #26 Elements 
	 Systolic blood pressure, automated reading
	 Systolic blood pressure, combined automated + manual reading
	 Systolic blood pressure, automated reading, adjusted by medication
	 ...
	 Pulse pressure, manual reading, adjusted by medication
	 Mean arterial pressure, manual reading, adjusted by medication
	 Elevated blood pressure reading without diagnosis of hypertension

Cluster 19, #10 Elements 
	 Heart valve disorders
	 Rheumatic disease of the heart valves
	 Aortic valve disease
	 ...
	 I35 Nonrheumatic aortic valve

# Save the clusters

In [31]:
def save_cluster_list(filepath, clusters):
    # Create the parent directory if needed; model names such as "org/model" produce nested directories
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, "wb") as fh:
        pickle.dump(clusters, fh, protocol=pickle.HIGHEST_PROTOCOL)
        
outdir = "/gpfs/commons/home/sbanerjee/work/npd/PanUKB/results/llm"

for method, clusters in clusters_community.items():
    m_filename = os.path.join(outdir, f"{method}/community_clusters.pkl")
    save_cluster_list(m_filename, clusters)
    
for method, clusters in clusters_kmeans.items():
    m_filename = os.path.join(outdir, f"{method}/kmeans_clusters.pkl")
    save_cluster_list(m_filename, clusters)
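To reuse the saved clusters in a later analysis, the pickles can be read back with `pickle.load`. The sketch below assumes a hypothetical `load_cluster_list` helper, repeats `save_cluster_list` so that the block is self-contained, and does a round trip with toy clusters in a temporary directory rather than the real output path.

```python
import os
import pickle
import tempfile

def save_cluster_list(filepath, clusters):
    # Create parent directories (including nested ones such as "org/model") if needed
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, "wb") as fh:
        pickle.dump(clusters, fh, protocol=pickle.HIGHEST_PROTOCOL)

def load_cluster_list(filepath):
    """Read back a list of clusters written by save_cluster_list (hypothetical helper)."""
    with open(filepath, "rb") as fh:
        return pickle.load(fh)

# Round trip with toy clusters; the "some-org/some-model" layout mimics a Hub model name
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "some-org/some-model/community_clusters.pkl")
    save_cluster_list(path, [[0, 1, 2], [3, 4]])
    restored = load_cluster_list(path)
```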