# Application example | **Embedding of protein families across protein language models**

This notebook reproduces the application example of the ema-tool library described in the accompanying preprint. Each function is covered in more detail in the `HCN1-variant-example.ipynb` notebook.

In [1]:
import numpy as np
import pandas as pd

from ema import EmbeddingHandler


In [15]:
DATA_DIR = '../examples/ion-channel-proteins/'
FP_METADATA = DATA_DIR + 'metadata.csv'
FP_EMB_ESM1b = DATA_DIR + 'esm1b_t33_650M_UR50S-embeddings.npy'
FP_EMB_ESM2 = DATA_DIR + 'esm2_t33_650M_UR50D-embeddings.npy'
FP_EMB_ESM1v = DATA_DIR + 'esm1v_t33_650M_UR90S_1-embeddings.npy'
FP_EMB_t5 = DATA_DIR + 't5_u50.npy'

In [16]:
# load metadata and embeddings 

metadata = pd.read_csv(FP_METADATA)
emb_esm1b = np.load(FP_EMB_ESM1b)
emb_esm2 = np.load(FP_EMB_ESM2)
emb_esm1v = np.load(FP_EMB_ESM1v)
emb_t5 = np.load(FP_EMB_t5)

print(emb_esm1b.shape, emb_esm2.shape)
metadata.head()

(102, 1280) (102, 1280)


Unnamed: 0,gene_name,family
0,KCNA1,Kv
1,KCNA2,Kv
2,KCNA3,Kv
3,KCNA4,Kv
4,KCNA5,Kv


In [4]:
# initialize embedding handler
emb_handler = EmbeddingHandler(metadata)

# add embeddings to the handler
emb_handler.add_emb_space(embeddings=emb_esm1b, emb_space_name='esm1b')
emb_handler.add_emb_space(embeddings=emb_esm2, emb_space_name='esm2')
emb_handler.add_emb_space(embeddings=emb_esm1v, emb_space_name='esm1v')
emb_handler.add_emb_space(embeddings=emb_t5, emb_space_name='t5')

102 samples loaded.
Categories in meta data: ['family']
Numerical columns in meta data: []
8 clusters calculated for esm1b.
Embedding space esm1b added.
Embeddings have length 1280.
8 clusters calculated for esm2.
Embedding space esm2 added.
Embeddings have length 1280.
8 clusters calculated for esm1v.
Embedding space esm1v added.
Embeddings have length 1280.
8 clusters calculated for t5.
Embedding space t5 added.
Embeddings have length 1024.


## Unsupervised Clustering x Metadata

By default, ema sets the number of clusters to the mean number of categories in the metadata (here 8, the number of distinct values in `family`). This is a reasonable starting point, but you can also set the number of clusters explicitly.

In [5]:
# set an explicit number of clusters (t5 is not recalculated and keeps its default of 8)
emb_handler.recalculate_clusters(n_clusters=5, emb_space_name="esm1b")
emb_handler.recalculate_clusters(n_clusters=5, emb_space_name="esm2")
emb_handler.recalculate_clusters(n_clusters=5, emb_space_name="esm1v")

5 clusters calculated for esm1b.
5 clusters calculated for esm2.
5 clusters calculated for esm1v.


In [6]:
emb_handler.get_value_count_per_group(group="family")

Unnamed: 0,family,count
0,Kv,36
1,Kir,15
2,K2P,15
3,TRP,12
4,TRPML,12
5,CNG,5
6,KCa,4
7,HCN,3
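
The default of 8 clusters logged when the embedding spaces were added matches the number of distinct families in the table above. Below is a minimal sketch of that rule, assuming "mean number of categories" means the average count of unique values over the categorical metadata columns. The helper `default_n_clusters` and its rounding are assumptions for illustration, not ema's implementation; the family sizes are copied from the table above.

```python
import pandas as pd

# Family sizes copied from the value counts shown above.
family_counts = {"Kv": 36, "Kir": 15, "K2P": 15, "TRP": 12,
                 "TRPML": 12, "CNG": 5, "KCa": 4, "HCN": 3}

# Rebuild a metadata-like table with one row per protein.
toy_metadata = pd.DataFrame({
    "family": [name for name, n in family_counts.items() for _ in range(n)]
})


def default_n_clusters(metadata: pd.DataFrame, categorical_columns: list) -> int:
    """Hypothetical helper: rounded mean number of unique values per categorical column."""
    return int(round(metadata[categorical_columns].nunique().mean()))


n_clusters = default_n_clusters(toy_metadata, ["family"])
print(len(toy_metadata), n_clusters)
```

With a single categorical column the mean is simply the number of families, so the sketch returns one cluster per family.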


### Unsupervised Clusters

In [7]:
emb_handler.plot_feature_cluster_overlap(emb_space_name="esm1v",
                                         feature='family')

### Overlap of unsupervised clusters between embedding spaces

In [8]:
emb_handler.plot_feature_cluster_overlap(emb_space_name="esm1b",
                                         feature='cluster_esm2')

In [9]:
emb_handler.plot_feature_cluster_overlap(emb_space_name="esm2",
                                         feature='cluster_esm1b')

In [10]:
emb_handler.plot_feature_cluster_overlap(emb_space_name="t5",
                                         feature='cluster_esm1b')
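
The overlap between cluster assignments from two embedding spaces can also be inspected as a plain contingency table. The sketch below uses made-up cluster labels and `pandas.crosstab`; it illustrates the idea only and is not necessarily what `plot_feature_cluster_overlap` computes internally.

```python
import pandas as pd

# Hypothetical cluster assignments for ten samples in two embedding spaces.
clusters_a = pd.Series([0, 0, 0, 1, 1, 1, 2, 2, 2, 2], name="cluster_a")
clusters_b = pd.Series([1, 1, 1, 0, 0, 2, 2, 2, 2, 2], name="cluster_b")

# Contingency table: rows are clusters in space A, columns are clusters in space B.
overlap = pd.crosstab(clusters_a, clusters_b)

# Row-normalise so each row shows how a cluster in A is spread over the clusters in B.
overlap_fraction = overlap.div(overlap.sum(axis=1), axis=0)

print(overlap)
print(overlap_fraction.round(2))
```

A row with a single dominant entry means the cluster is preserved in the other space, even if its label differs; a row spread over several columns means the cluster is split.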

## Pairwise distances

### Similarities between ESM1b, ESM1v and ESM2 embeddings

In [11]:
emb_handler.plot_emb_dis_scatter(
    emb_space_name_1='esm1b', 
    emb_space_name_2='t5', 
    distance_metric = 'cityblock_normalised',
)

For this set of sequences, the pairwise distances between embeddings correlate more strongly between ESM1b and ESM1v than between ESM1b and ESM2.

In [12]:
emb_handler.plot_emb_dis_scatter(
    emb_space_name_1='esm1b', 
    emb_space_name_2='esm1v', 
    distance_metric = 'euclidean',
)
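
The scatter plots compare, for every pair of samples, the distance in one embedding space with the distance in another. A minimal sketch of that comparison with NumPy only is shown below. The embeddings are synthetic stand-ins and the helper `pairwise_euclidean` is hypothetical; ema's own computation may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for two embedding spaces of the same 20 samples.
emb_a = rng.normal(size=(20, 8))
emb_b = emb_a @ rng.normal(size=(8, 6)) + 0.1 * rng.normal(size=(20, 6))


def pairwise_euclidean(emb: np.ndarray) -> np.ndarray:
    """Return the Euclidean distances between all unordered sample pairs."""
    diff = emb[:, None, :] - emb[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Keep each unordered pair once (upper triangle without the diagonal).
    i, j = np.triu_indices(len(emb), k=1)
    return dist[i, j]


dist_a = pairwise_euclidean(emb_a)
dist_b = pairwise_euclidean(emb_b)

# Pearson correlation between the two sets of pairwise distances.
pearson_r = np.corrcoef(dist_a, dist_b)[0, 1]
print(f"{len(dist_a)} pairs, Pearson r = {pearson_r:.3f}")
```

For n samples there are n(n-1)/2 pairs, so the 102 proteins in this notebook give 5151 points per scatter plot.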

The same comparison using the normalised Euclidean distance, which accounts for the different scales of the embeddings:

In [None]:
emb_handler.plot_emb_dis_scatter(
    emb_space_name_1='esm1b', 
    emb_space_name_2='esm1v', 
    distance_metric = 'euclidean_normalised',
)
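
To see why a normalisation helps when embedding lengths differ (1280 for the ESM models, 1024 for t5), the sketch below compares raw and scaled distances on synthetic data. Dividing by the square root of the embedding length is an assumption chosen for illustration; the exact normalisation behind `euclidean_normalised` in ema is not shown in this notebook and may differ.

```python
import numpy as np

rng = np.random.default_rng(3)

# Synthetic stand-ins with the two embedding lengths used in this notebook.
emb_1280 = rng.normal(size=(30, 1280))
emb_1024 = rng.normal(size=(30, 1024))


def mean_pairwise_euclidean(emb: np.ndarray, normalise: bool) -> float:
    """Mean Euclidean distance over all sample pairs, optionally divided by sqrt(d)."""
    # Squared distances from the Gram matrix; avoids building an (n, n, d) array.
    sq_norms = (emb ** 2).sum(axis=1)
    sq_dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (emb @ emb.T)
    dist = np.sqrt(np.clip(sq_dist, 0.0, None))
    i, j = np.triu_indices(len(emb), k=1)
    pair_dist = dist[i, j]
    if normalise:
        # ASSUMPTION: scale by the square root of the embedding length.
        pair_dist = pair_dist / np.sqrt(emb.shape[1])
    return float(pair_dist.mean())


raw_1280 = mean_pairwise_euclidean(emb_1280, normalise=False)
raw_1024 = mean_pairwise_euclidean(emb_1024, normalise=False)
norm_1280 = mean_pairwise_euclidean(emb_1280, normalise=True)
norm_1024 = mean_pairwise_euclidean(emb_1024, normalise=True)

print(f"raw:        {raw_1280:.2f} vs {raw_1024:.2f}")
print(f"normalised: {norm_1280:.3f} vs {norm_1024:.3f}")
```

On this synthetic data the raw distances grow with the embedding length, while the scaled distances land on a common scale.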

### Unsupervised clusters

We can inspect how close the initial clusters from one embedding space are in the other embedding spaces.

In [None]:
emb_handler.plot_emb_dis_scatter(
    emb_space_name_1='esm1b', 
    emb_space_name_2='esm2', 
    distance_metric = 'euclidean',
    colour_group="cluster_esm1b",
    colour_value_1="2",
)

In [None]:
emb_handler.plot_emb_dis_scatter(
    emb_space_name_1='esm1b', 
    emb_space_name_2='esm1v', 
    distance_metric = 'euclidean',
    colour_group="cluster_esm1b",
    colour_value_1="2",
)

## Visualisation of dimensionality reduction x Unsupervised Clustering

### PCA

#### ESM1b coloured by unsupervised clusters in ESM1b

In [13]:
emb_handler.visualise_emb_pca(emb_space_name="esm1b", 
                              colour="cluster_esm1b")

#### ESM1v coloured by unsupervised clusters in ESM1b

In [14]:
emb_handler.visualise_emb_pca(emb_space_name="esm1v", 
                              colour="cluster_esm1b")

#### ESM2 coloured by unsupervised clusters in ESM1b

In [None]:
emb_handler.visualise_emb_pca(emb_space_name="esm2",
                              colour="cluster_esm1b")
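
The plots above project one embedding space onto two principal components and colour the points by labels that come from elsewhere (here, the ESM1b clusters). A minimal sketch of that projection via SVD is shown below, on synthetic data with made-up group labels. ema's `visualise_emb_pca` may use a different implementation and preprocessing.

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic embeddings: two groups of 15 samples, shifted apart along every dimension.
group_labels = np.repeat([0, 1], 15)
emb = rng.normal(size=(30, 16)) + 3.0 * group_labels[:, None]

# PCA via SVD of the mean-centred data matrix.
centred = emb - emb.mean(axis=0)
u, s, vt = np.linalg.svd(centred, full_matrices=False)
coords_2d = centred @ vt[:2].T          # projection onto the first two components
explained = (s ** 2) / (s ** 2).sum()   # fraction of variance per component

# Mean position of each externally supplied group along the first component.
pc1_mean_group0 = coords_2d[group_labels == 0, 0].mean()
pc1_mean_group1 = coords_2d[group_labels == 1, 0].mean()

print(coords_2d.shape, explained[:2].round(3))
```

If labels taken from another embedding space still separate along the leading components, the two spaces organise the samples in a similar way.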

## Adding Metadata

### Unsupervised Clustering x Metadata

We can check whether the unsupervised clusters are related to the provided metadata.

In [None]:
emb_handler.plot_feature_cluster_overlap(emb_space_name="esm1b",
                                         feature='family')

In [None]:
emb_handler.plot_feature_cluster_overlap(emb_space_name="esm1v",
                                         feature='family')

In [None]:
emb_handler.plot_feature_cluster_overlap(emb_space_name="esm2",
                                         feature='family')

In [None]:
emb_handler.plot_feature_cluster_overlap(emb_space_name="t5",
                                         feature='family')

### Pairwise distance x Metadata

In [None]:
emb_handler.plot_emb_dis_per_group(emb_space_name="esm1b",
                                   distance_metric='euclidean',
                                   group="family")

In [None]:
emb_handler.plot_emb_dis_per_group(emb_space_name="esm2",
                                   distance_metric='euclidean',
                                   group="family")

In [None]:
emb_handler.plot_emb_dis_per_group(emb_space_name="t5",
                                   distance_metric='euclidean',
                                   group="family",
                                   # group_value="TRPML"
                                   )
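
One way to relate pairwise distances to a metadata column is to compare distances between samples of the same group with distances between samples of different groups. The sketch below does this on synthetic embeddings with three hypothetical families; it illustrates the idea only and is not necessarily what `plot_emb_dis_per_group` displays.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(2)

# Synthetic embeddings for three hypothetical families with different centres.
families = np.array(["A"] * 6 + ["B"] * 6 + ["C"] * 6)
centres = {"A": 0.0, "B": 4.0, "C": 8.0}
emb = np.stack([rng.normal(loc=centres[f], size=10) for f in families])

# Full Euclidean distance matrix.
diff = emb[:, None, :] - emb[None, :, :]
dist = np.sqrt((diff ** 2).sum(axis=-1))

# Label each unordered pair as within-family or between-family.
i, j = np.triu_indices(len(emb), k=1)
pairs = pd.DataFrame({
    "distance": dist[i, j],
    "same_family": families[i] == families[j],
})

within_mean = pairs.loc[pairs["same_family"], "distance"].mean()
between_mean = pairs.loc[~pairs["same_family"], "distance"].mean()
print(f"within: {within_mean:.2f}, between: {between_mean:.2f}")
```

A family whose within-group distances are clearly smaller than its distances to other families forms a compact region in that embedding space.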

In [None]:
emb_handler.plot_emb_dis_scatter(
    emb_space_name_1='esm1b', 
    emb_space_name_2='esm1v', 
    distance_metric = 'euclidean',
    colour_group="family",
    colour_value_1="Kir",
)

In [None]:
emb_handler.plot_emb_dis_scatter(
    emb_space_name_1='esm1b', 
    emb_space_name_2='esm2', 
    distance_metric = 'euclidean',
    colour_group="family",
    colour_value_1="CNG",
    colour_value_2="HCN"
)

In [None]:
emb_handler.plot_emb_dis_scatter(
    emb_space_name_1='esm1b', 
    emb_space_name_2='t5', 
    distance_metric = 'cityblock_normalised',
    colour_group="family",
    colour_value_1="TRPML",
    # colour_value_2="HCN"
)

In [None]:
emb_handler.visualise_emb_tsne(emb_space_name="esm1b",
                               colour="family")

In [None]:
emb_handler.visualise_emb_tsne(emb_space_name="t5",
                               colour="family")