# A symmetric bilinear form for embeddings

In this notebook, we use a simple symmetric bilinear model to classify pairs of embeddings as homologous or not (i.e. whether the corresponding residues are in the same column of a reference MSA).

Let $x,y$ be $d$-dimensional embeddings. The model is

$$f(x,y) = \sigma(x^T R R^T y + b)$$

where $R \in \mathbb{R}^{d \times k}$ is a parameter matrix and $b$ is a scalar bias. Choosing $k<d$ reduces the dimensionality: $R R^T$ is a symmetric $d \times d$ matrix of rank at most $k$, so $f(x,y) = f(y,x)$.
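As a minimal NumPy sketch of the formula above (not the `SymmetricBilinearReduction` layer itself, whose internals are not shown in this notebook), both embeddings can be projected to $k$ dimensions and compared with a dot product. The toy sizes, the random `R` and the helper names `sigmoid` and `bilinear_score` are assumptions made for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bilinear_score(x, y, R, b):
    """Compute sigma(x^T R R^T y + b) for single embeddings x, y of shape (d,)."""
    # x @ R equals R^T x, i.e. the k-dimensional projection of x
    return sigmoid((x @ R) @ (y @ R) + b)

rng = np.random.default_rng(0)
d, k = 8, 3                       # toy sizes; the notebook uses d = 2560
R = rng.normal(size=(d, k)) * 0.3 # random stand-in for the trained matrix
b = -1.0
x, y = rng.normal(size=d), rng.normal(size=d)

p_xy = bilinear_score(x, y, R, b)
p_yx = bilinear_score(y, x, R, b)
print(p_xy, p_yx)                 # identical, because R R^T is symmetric
```

Projecting first costs $O(dk)$ per embedding instead of $O(d^2)$ for a full $d \times d$ form.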

The model is trained in a binary classification task, where the homology labels are induced by reference multiple sequence alignments, i.e. the label is 1 if the positions of two embeddings share the same column and 0 otherwise.
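To make the labelling rule concrete, here is a small sketch on a made-up three-sequence alignment. The helper names `residue_columns` and `label` are hypothetical, and the gap characters are an assumption; the notebook itself obtains the column of each residue from learnMSA's `membership_targets`.

```python
def residue_columns(aligned_seq, gap_chars="-."):
    """For each residue (non-gap character), in order, return its alignment column."""
    return [col for col, ch in enumerate(aligned_seq) if ch not in gap_chars]

# toy reference alignment (invented for illustration)
msa = ["MK-LV",
       "M-ALV",
       "MKAL-"]
columns = [residue_columns(s) for s in msa]

def label(seq_a, pos_a, seq_b, pos_b):
    """1 if the two residues share an alignment column, else 0."""
    return int(columns[seq_a][pos_a] == columns[seq_b][pos_b])

print(columns)
print(label(0, 2, 1, 2))  # both residues are the 'L' in column 3
print(label(0, 1, 1, 1))  # 'K' in column 1 vs. 'A' in column 2
```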

When used in learnMSA, we can compute embedding-based emission probabilities for a $d$-dimensional embedding $x$ and a $k$-dimensional match kernel $m$ with pre-trained $R$ and $b$ like so:

$P(x \mid m) = \sigma(x^T R m + b)$.

In this case only $m$ is learned and $R$ and $b$ are fixed.
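The emission formula can be evaluated for all embeddings and all match kernels at once with two matrix products. The following is only a sketch under assumed shapes (`X` holds $n$ embeddings row-wise, `M` holds $q$ match kernels row-wise) and is not taken from the learnMSA code; `emission_probs` is a hypothetical helper.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def emission_probs(X, M, R, b):
    """Return an (n, q) matrix with entries sigma(x_i^T R m_j + b)."""
    return sigmoid((X @ R) @ M.T + b)

rng = np.random.default_rng(1)
d, k, n, q = 8, 3, 5, 4              # toy sizes
R = rng.normal(size=(d, k)) * 0.3    # stand-in for the pre-trained, fixed R
b = -0.5                             # stand-in for the pre-trained, fixed b
X = rng.normal(size=(n, d))          # embeddings of one sequence
M = rng.normal(size=(q, k))          # match kernels (the only learned part)
P = emission_probs(X, M, R, b)

# sanity check: a kernel of the form m = R^T y recovers the pairwise model f(x, y)
Y = rng.normal(size=(q, d))
P_pairs = emission_probs(X, Y @ R, R, b)
```

The sanity check shows how the two formulas relate: the match kernel $m$ takes the place of the projected second embedding $R^T y$.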

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1" #select the GPU before TensorFlow is imported
import io
import sys
from multiprocessing import Pool
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
sys.path.append('../../learnMSA')
from learnMSA import msa_hmm
from BilinearSymmetric import SymmetricBilinearReduction, BackgroundEmbedding

## Data preparation Task 1: Binary classification of embeddings sharing a column

We use Pfam's clan hierarchy as the basis for train/test splitting and batch sampling. A clan is a collection of related Pfam entries. The relationship may be defined by similarity of sequence, structure or profile-HMM.
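The point of splitting by clan is that related families never end up on both sides of the split. A plain-Python sketch of that idea follows; the family-to-clan pairs are copied from the first rows of `Pfam-A.clans.tsv`, while the 15% fraction, the seed and the variable names are assumptions for illustration.

```python
import random

family_to_clan = {
    "PF00001": "CL0192", "PF00002": "CL0192", "PF00003": "CL0192",
    "PF00004": "CL0023", "PF00005": "CL0023", "PF00007": "CL0079",
    "PF00008": "CL0001", "PF00010": None,   # no clan assigned
}
# families without a clan become their own singleton clan
family_to_clan = {f: (c if c is not None else f) for f, c in family_to_clan.items()}

# shuffle the clans (not the families) and hold out a fraction of them
clans = sorted(set(family_to_clan.values()))
random.Random(0).shuffle(clans)
n_test = max(1, int(0.15 * len(clans)))
test_clans = set(clans[:n_test])

test_families = [f for f, c in family_to_clan.items() if c in test_clans]
train_families = [f for f, c in family_to_clan.items() if c not in test_clans]
print(test_families, train_families)
```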



In [2]:
!head ../../PFAM/Pfam-A.clans.tsv

PF00001	CL0192	GPCR_A	7tm_1	7 transmembrane receptor (rhodopsin family)
PF00002	CL0192	GPCR_A	7tm_2	7 transmembrane receptor (Secretin family)
PF00003	CL0192	GPCR_A	7tm_3	7 transmembrane sweet-taste receptor of 3 GCPR
PF00004	CL0023	P-loop_NTPase	AAA	ATPase family associated with various cellular activities (AAA)
PF00005	CL0023	P-loop_NTPase	ABC_tran	ABC transporter
PF00006	CL0023	P-loop_NTPase	ATP-synt_ab	ATP synthase alpha/beta family, nucleotide-binding domain
PF00007	CL0079	Cystine-knot	Cys_knot	Cystine-knot domain
PF00008	CL0001	EGF	EGF	EGF-like domain
PF00009	CL0023	P-loop_NTPase	GTP_EFTU	Elongation factor Tu GTP binding domain
PF00010			HLH	Helix-loop-helix DNA-binding domain


In [3]:
clans_df = pd.read_csv("../../PFAM/Pfam-A.clans.tsv", header=None, sep="\t")
clans_df[1] = clans_df[1].fillna(clans_df[0]) #families with no clan become their own clans
clans_df.set_index(0, inplace=True)
clans_df.drop([2,3,4], axis=1, inplace=True)
clans_df.rename(columns={1 : "clan"}, inplace=True)

In [4]:
clans_df.loc["PF00001"].clan

'CL0192'

In [5]:
np.random.seed(77)

#sequences longer than this value were truncated by the LM 
#todo: redo the embeddings without truncation
#for simplicity we omit families with at least one sequence longer than this value (0.3% of all families)
lm_model_truncation_value = 1022

def get_family(filepath):
    return ".".join(os.path.basename(filepath).split(".")[:-1])

#load all fasta files
#(takes a while)
datasets = "../../PFAM/alignments/"
ext = ".fasta"

clans_df["fasta_index"] = np.nan

# load all ref alignments
fasta_files = []
to_drop = []
truncated_clans = set()
for file in os.listdir(datasets):
    if file.endswith(ext):
        family = ".".join(file.split(".")[:-1])
        fasta = msa_hmm.fasta.Fasta(datasets+file, aligned=True, single_seq_ok=True)
        #omit families with only one sequence or families with at least one truncated sequence
        if fasta.num_seq == 1 or np.any(fasta.seq_lens > lm_model_truncation_value):
            to_drop.append(family)
            truncated_clans.add(clans_df.loc[family].clan)
            continue
        fasta_files.append(fasta)
        clans_df.loc[family, "fasta_index"] = len(fasta_files)-1
        
#drop families
clans_df = clans_df.drop(to_drop, axis=0)

assert not clans_df.isna().values.any()
clans_df = clans_df.astype({"fasta_index": "int32"})

#preprocessing
seq_lens, starting_pos, seq_pos_to_column = {}, {}, {}
for f in fasta_files:
    family = get_family(f.filename)
    seq_lens[family] = f.seq_lens
    starting_pos[family] = f.starting_pos
    seq_pos_to_column[family] = f.membership_targets

unique_clans = clans_df.clan.unique()
num_clans = unique_clans.size
print(f"{num_clans} clans in total")
clans_df.head()

11865 clans in total


| 0 | clan | fasta_index |
|---|---|---|
| PF00001 | CL0192 | 13877 |
| PF00002 | CL0192 | 12061 |
| PF00003 | CL0192 | 12652 |
| PF00004 | CL0023 | 6410 |
| PF00005 | CL0023 | 13836 |


In [6]:
#we cannot load all precomputed embeddings into memory at once,
#and loading them on demand in the data pipeline is too slow
#with a memory limit of 100GB,
#assuming that a single embedding vector has size 2560 and the datatype is float16,
#one embedding takes 2560 * 2 / 1000 / 1000 = 0.00512MB
#therefore we can store 100000 / 0.00512 = 19531250 embeddings
#that's about 16 per sequence
#and about 1560 per clan
!du -h esm/pfam

1.8T	esm/pfam
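The budget in the comments above can be re-derived in a few lines. The sketch below uses only numbers that appear in this notebook (100GB limit, 2560-dimensional float16 vectors, 11865 clans, 1500 sampled embeddings per clan) and assumes decimal gigabytes, as the comments do.

```python
emb_dim = 2560
bytes_per_value = 2                      # float16
memory_limit_bytes = 100 * 1000**3       # 100GB, decimal units

bytes_per_embedding = emb_dim * bytes_per_value
max_embeddings = memory_limit_bytes // bytes_per_embedding

num_clans = 11865
embeddings_per_clan = 1500
sampled_bytes = num_clans * embeddings_per_clan * bytes_per_embedding

print(f"{bytes_per_embedding} bytes per embedding")
print(f"{max_embeddings} embeddings fit into the limit")
print(f"sampled array needs {sampled_bytes / 1000**3:.1f} GB")
```

With these numbers the sampled array takes roughly 91GB, so 1500 embeddings per clan stays within the limit.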


In [7]:
prot_model = "esm"
!mkdir -p {prot_model}/tmp
embeddings_per_clan = 1500
    
if not os.path.exists(f"{prot_model}/sampled_embeddings.npy"):
    def load(clan):
        if os.path.exists(f"{prot_model}/tmp/{clan}_embeddings.npy"):
            return
        clan_families = clans_df[clans_df.clan == clan]
        clan_embeddings = {family : np.load(f"{prot_model}/pfam/{family}.npy").astype(np.float16) for family in clan_families.index}
        family_sample = clan_families.sample(embeddings_per_clan, replace=True).index
        emb, fam, col = [], [], []
        for j, family in enumerate(family_sample):
            seq = np.random.randint(seq_lens[family].size, size=1)[0]
            pos_in_seq = np.random.randint(seq_lens[family][seq], size=1)[0]
            emb.append( clan_embeddings[family][starting_pos[family][seq] + pos_in_seq] )
            fam.append( clans_df.index.get_loc(family) )
            col.append( seq_pos_to_column[family][starting_pos[family][seq] + pos_in_seq] )
        np.save(f"{prot_model}/tmp/{clan}_embeddings.npy", np.stack(emb, axis=0))
        np.save(f"{prot_model}/tmp/{clan}_families.npy", np.array(fam))
        np.save(f"{prot_model}/tmp/{clan}_columns.npy", np.array(col))

    with Pool(8) as p:
        p.map(load, unique_clans)

In [8]:
if not os.path.exists(f"{prot_model}/sampled_embeddings.npy"):
    #merge the per-clan samples into single arrays and cache them on disk
    emb_dim = 2560
    sampled_embeddings = np.zeros((num_clans, embeddings_per_clan, emb_dim), dtype=np.float16)
    sampled_families = np.zeros((num_clans, embeddings_per_clan), dtype=np.int32)
    sampled_columns = np.zeros((num_clans, embeddings_per_clan), dtype=np.int32)
    for i, clan in enumerate(unique_clans):
        sampled_embeddings[i] = np.load(f"{prot_model}/tmp/{clan}_embeddings.npy")
        sampled_families[i] = np.load(f"{prot_model}/tmp/{clan}_families.npy")
        sampled_columns[i] = np.load(f"{prot_model}/tmp/{clan}_columns.npy")
    np.save(f"{prot_model}/sampled_embeddings.npy", sampled_embeddings)
    np.save(f"{prot_model}/sampled_families.npy", sampled_families)
    np.save(f"{prot_model}/sampled_columns.npy", sampled_columns)
else:
    #reuse the cached arrays
    sampled_embeddings = np.load(f"{prot_model}/sampled_embeddings.npy")
    sampled_families = np.load(f"{prot_model}/sampled_families.npy")
    sampled_columns = np.load(f"{prot_model}/sampled_columns.npy")

In [19]:
emb_dim = sampled_embeddings.shape[-1]
A = np.arange(embeddings_per_clan)
def make_dataset(clans, batch_size):
    def _gen_inputs():
        """ Endlessly yields batches of random embedding pairs from the same family
            together with their column-sharing labels.
        """
        while True:
            #sample random clans
            c = np.random.choice(clans, size=batch_size)
            
            #sample a random embedding from each clan
            i = np.random.randint(embeddings_per_clan, size=batch_size)
            
            #sample another random embedding from the same family as the first one
            batch_families = sampled_families[c, i]
            rand = np.random.rand(batch_size)
            j = []
            for clan, f, r in zip(sampled_families[c], batch_families, rand):
                same_family = A[clan == f]
                j.append(same_family[np.floor(r * same_family.size).astype(i.dtype)])
            j = np.array(j)
            
            #label is 1 if and only if they share the same alignment column
            label = np.float32(sampled_columns[c,i] == sampled_columns[c,j])
            yield (sampled_embeddings[c,i], sampled_embeddings[c,j]), np.reshape(label, (batch_size))
            
    output_signature = ((tf.TensorSpec(shape=(batch_size, emb_dim), dtype=tf.float32), 
                         tf.TensorSpec(shape=(batch_size, emb_dim), dtype=tf.float32)), 
                        tf.TensorSpec(shape=(batch_size,), dtype=tf.float32))
    ds = tf.data.Dataset.from_generator(_gen_inputs, output_signature = output_signature)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds
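The pair sampling inside the generator can be followed more easily on a toy example. The sketch below mimics the logic for a single pair with tiny made-up `families` and `columns` arrays (two clans, five sampled embeddings each); `sample_pair` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# toy stand-ins for sampled_families and sampled_columns: shape (num_clans, embeddings_per_clan)
families = np.array([[0, 0, 1, 1, 1],
                     [2, 2, 2, 3, 3]])
columns = np.array([[4, 4, 7, 7, 9],
                    [1, 2, 1, 5, 5]])
positions = np.arange(families.shape[1])

def sample_pair(c):
    """Draw embedding i from clan c, then a partner j from the same family."""
    i = rng.integers(families.shape[1])
    candidates = positions[families[c] == families[c, i]]
    j = rng.choice(candidates)
    # label is 1 if and only if both residues share an alignment column
    label = float(columns[c, i] == columns[c, j])
    return i, j, label

pairs = [(c,) + sample_pair(c) for c in (0, 1) for _ in range(100)]
print(pairs[:5])
```

As in the generator, the candidate set includes `i` itself, so an embedding can be paired with itself and then receives label 1.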

In [20]:
#split by clan and evaluate on held-out clans
test_p = 0.15
val_p = 0.15
assert test_p + val_p < 1
num_test = int(num_clans * test_p)
num_val = int(num_clans * val_p)
num_train = num_clans - num_test - num_val
a = np.arange(num_clans)
np.random.shuffle(a)
test_clans = a[:num_test]
val_clans = a[num_test:num_test+num_val]
train_clans = a[num_test+num_val:]

In [21]:
batch_size = 512
test_ds = make_dataset(test_clans, batch_size)
val_ds = make_dataset(val_clans, batch_size)
train_ds = make_dataset(train_clans, batch_size)

In [22]:
np.mean([y for x,y in test_ds.take(100)])

0.011054687

In [23]:
np.mean([y for x,y in train_ds.take(100)])

0.010488281

## Training

In [39]:
#weight positive examples more heavily to counter the class imbalance (only about 1% of the pairs are positive)
def make_weighted_bce(pos_weight):
    bce = tf.keras.losses.BinaryCrossentropy()
    def weighted_bce(y_true, y_pred):
        weights = y_true * pos_weight + (1-y_true)
        return bce(y_true, y_pred, sample_weight=weights)
    return weighted_bce
    
def make_model(reduced_dim = 256, dropout = 0.2, l2 = 0., pos_weight=1.):
    # input to the training pipeline are pairs of embeddings
    emb1 = tf.keras.layers.Input(shape=(emb_dim,))
    emb2 = tf.keras.layers.Input(shape=(emb_dim,))

    # outputs are homology probabilities 
    output = SymmetricBilinearReduction(reduced_dim, dropout, l2)(emb1[:,tf.newaxis], emb2[:,tf.newaxis])

    # construct a model and compile for a standard binary classification task
    model = tf.keras.models.Model(inputs=[emb1, emb2], outputs=output)

    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-2,
        decay_steps=1000,
        decay_rate=0.9)

    model.compile(loss=make_weighted_bce(pos_weight), 
                  optimizer=tf.keras.optimizers.Adam(lr_schedule),
                  metrics=["accuracy", tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
    
    return model
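The effect of `pos_weight` can be seen in a small NumPy sketch of a per-example weighted binary cross-entropy. This illustrates the intended weighting only; it is an assumption that it matches the reduction Keras applies inside `weighted_bce`, which also depends on the output shape of the model.

```python
import numpy as np

def weighted_bce_np(y_true, y_pred, pos_weight, eps=1e-7):
    """Mean binary cross-entropy where positive examples count pos_weight times."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_example = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    weights = y_true * pos_weight + (1.0 - y_true)
    return np.mean(weights * per_example)

# one missed positive among three well-classified negatives
y_true = np.array([1.0, 0.0, 0.0, 0.0])
y_pred = np.array([0.2, 0.1, 0.1, 0.1])

plain = weighted_bce_np(y_true, y_pred, pos_weight=1.0)
weighted = weighted_bce_np(y_true, y_pred, pos_weight=4.0)
print(plain, weighted)
```

With `pos_weight=4`, the missed positive dominates the loss, which pushes the model towards higher recall at the cost of some precision.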

In [None]:
with open(f"{prot_model}/fit_log.txt", "w") as file:
    for dim in [64, 128, 256]:
        for dropout in [0.1, 0.2]:
            model = make_model(dim, dropout)
            model.fit(train_ds, epochs=20, steps_per_epoch=10000, verbose=0)
            file.write(f"dim = {dim} dropout = {dropout}\n")
            file.write(f"val results = {model.evaluate(val_ds, steps=10000, verbose=0)}\n")
            file.write(f"test results = {model.evaluate(test_ds, steps=10000, verbose=0)}\n")
            file.flush()
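It helps to see how quickly the learning rate shrinks over such a run. The sketch below evaluates the documented non-staircase formula of `ExponentialDecay`, `initial_learning_rate * decay_rate ** (step / decay_steps)`, with the values from `make_model`; the helper `exponential_decay` is a stand-in written for illustration, not a TensorFlow call.

```python
def exponential_decay(step, initial_learning_rate=1e-2, decay_steps=1000, decay_rate=0.9):
    """Learning rate after `step` optimizer steps (continuous, non-staircase decay)."""
    return initial_learning_rate * decay_rate ** (step / decay_steps)

steps_per_epoch = 10000
for epoch in (0, 1, 5, 10, 20):
    lr = exponential_decay(epoch * steps_per_epoch)
    print(f"after epoch {epoch:2d}: lr = {lr:.3e}")
```

Under this schedule the rate drops by a factor of about 0.35 per epoch, so most of the parameter movement happens within the first few epochs.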

## Retrain the best configuration on all data

In [57]:
reduced_dim = 128
dropout = 0.1
l2 = 0
pos_weight = 4
full_ds = make_dataset(a, batch_size)
final_model = make_model(reduced_dim = reduced_dim, dropout = dropout, l2 = l2, pos_weight = pos_weight)
final_model.fit(full_ds, epochs=10, steps_per_epoch=10000)
final_model.save(f"{prot_model}/bilinear_form_model_pfam_{reduced_dim}_{dropout}_{l2}_{pos_weight}")

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10




INFO:tensorflow:Assets written to: esm/bilinear_form_model_pfam_128_0.1_0_4/assets


In [47]:
loaded_model = tf.keras.models.load_model(f"{prot_model}/bilinear_form_model_pfam_{reduced_dim}_{dropout}_{l2}_{pos_weight}", 
                                          custom_objects = {"SymmetricBilinearReduction" : SymmetricBilinearReduction, 
                                                            "weighted_bce" : make_weighted_bce(pos_weight)})
loaded_model.layers[-1].b

<tf.Variable 'symmetric_bilinear_reduction_7/b:0' shape=(1,) dtype=float32, numpy=array([-39.687576], dtype=float32)>

## Test if the model can be loaded again without errors

In [55]:
loaded_model = tf.keras.models.load_model(f"{prot_model}/bilinear_form_model_pfam_{reduced_dim}_{dropout}_{l2}_{pos_weight}", 
                                          custom_objects = {"SymmetricBilinearReduction" : SymmetricBilinearReduction, 
                                                            "weighted_bce" : make_weighted_bce(pos_weight)})

In [56]:
loaded_model.evaluate(test_ds, steps=1000)



[0.04159269854426384,
 0.995437502861023,
 0.8132928609848022,
 0.7790518403053284]
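The four numbers follow the order given to `compile`: loss, accuracy, precision and recall. Since only about 1% of the pairs are positive, accuracy alone says little, so the sketch below derives an F1 score and the accuracy of a trivial all-negative classifier from the values printed above. Note that this model was trained on all clans, including the test clans, so the numbers only confirm that loading works and are not a held-out estimate.

```python
# values copied from the evaluate() output above
loss, accuracy, precision, recall = (
    0.04159269854426384,
    0.995437502861023,
    0.8132928609848022,
    0.7790518403053284,
)

# harmonic mean of precision and recall
f1 = 2 * precision * recall / (precision + recall)

# fraction of positive labels measured on the test dataset earlier in the notebook
positive_rate = 0.011054687
baseline_accuracy = 1.0 - positive_rate   # always predicting "not homologous"

print(f"F1 = {f1:.4f}")
print(f"accuracy = {accuracy:.4f} vs. all-negative baseline = {baseline_accuracy:.4f}")
```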