# A symmetric bilinear form for embeddings

In this notebook, we train a simple symmetric bilinear model for use in learnMSA. The task is categorical classification where the inputs are pairs of sequences: a query and a target. For each residue in the query sequence, the model should predict the residue of the target (if any) that it is aligned to in the reference MSA.

Let $x_i$ and $y_j$ be $d$-dimensional embeddings of the $i$-th query residue and the $j$-th target residue, respectively. The model is

$$f(x_i, y_j) = \frac{\exp(s_{i,j})} {\sum_{j'} \exp(s_{i,j'})} $$

with $s_{i,j} = x_i^T R R^T y_j + b$,

where $R \in \mathbb{R}^{d \times k}$ is a parameter matrix and $k<d$ can be controlled for dimensionality reduction. $b$ is a scalar bias.
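
The following NumPy sketch illustrates the formula above on random toy data. It is not the `SymmetricBilinearReduction` layer used later in this notebook; the function names and the toy sizes are assumptions made only for illustration.

```python
import numpy as np

def bilinear_scores(x, y, R, b):
    """Scores s[i, j] = x_i^T R R^T y_j + b for all query/target pairs."""
    # Project both sequences into the shared k-dimensional space first,
    # so the d x d matrix R R^T is never formed explicitly.
    return (x @ R) @ (y @ R).T + b

def softmax_over_targets(s):
    """Row-wise softmax: normalizes over target positions j' for each query i."""
    s = s - s.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d, k, L1, L2 = 8, 3, 5, 4          # toy sizes (hypothetical)
x = rng.normal(size=(L1, d))       # query embeddings
y = rng.normal(size=(L2, d))       # target embeddings
R = rng.normal(size=(d, k))        # random stand-in for the parameter matrix
b = 0.1

s_xy = bilinear_scores(x, y, R, b)
f = softmax_over_targets(s_xy)     # f[i, j] = f(x_i, y_j), shape (L1, L2)
```

Two properties follow directly from the formula and can be checked on `s_xy`: swapping query and target transposes the score matrix (the form is symmetric), and in this plain softmax form the scalar $b$ shifts every score in a row equally, so it does not change $f$.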

When used in learnMSA, we can compute embedding-based emission probabilities for a $k$-dimensional match kernel $m_i$ and a $d$-dimensional embedding $x_j$ with pre-trained $R$ and $b$ as follows:

$$P(x_j \mid m_i) = \frac{\exp(s'_{i,j})} {\sum_{j'} \exp(s'_{i,j'})}$$

with $s'_{i,j} = x_j^T R m_i + b$.

In this case only $m$ is learned and $R$ and $b$ are fixed. 
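
As a rough illustration of the emission formula, the sketch below evaluates $s'_{i,j}$ and the normalization over $j'$ with NumPy. It does not use learnMSA's API; `emission_probs` and all array sizes are hypothetical, and `R` and `b` are random stand-ins for the pre-trained values.

```python
import numpy as np

def emission_probs(x, m, R, b):
    """P[i, j] = exp(s'[i, j]) / sum_j' exp(s'[i, j']) with s'[i, j] = x_j^T R m_i + b."""
    s = m @ (x @ R).T + b                  # shape (num_match, num_residues)
    s = s - s.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(1)
d, k = 8, 3                        # toy dimensions (hypothetical)
num_match, num_residues = 4, 6
R = rng.normal(size=(d, k))        # stand-in for the frozen, pre-trained R
b = 0.1                            # stand-in for the frozen bias
m = rng.normal(size=(num_match, k))     # match kernels: the only learned part
x = rng.normal(size=(num_residues, d))  # residue embeddings

P = emission_probs(x, m, R, b)
```

Comparing with the training model, the kernel $m_i$ takes the place of the projected target $R^T y_j$, which is why it lives in the reduced $k$-dimensional space.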

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # must be set before TensorFlow is imported
import io
import sys
from multiprocessing import Pool

import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf

sys.path.append('../../learnMSA')
from learnMSA import msa_hmm
from BilinearSymmetric import SymmetricBilinearReduction, BackgroundEmbedding

2023-05-12 08:44:21.979238: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-12 08:44:22.440135: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/opt/conda/lib
2023-05-12 08:44:22.440335: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/opt/conda/lib


## Data preparation Task 2: Categorical classification of embedding

Use Pfam's clan hierarchy as the basis for train/test splitting and batch sampling. A clan is a collection of related Pfam entries. The relationship may be defined by similarity of sequence, structure, or profile HMM.

In [2]:
!head ../../PFAM/Pfam-A.clans.tsv

PF00001	CL0192	GPCR_A	7tm_1	7 transmembrane receptor (rhodopsin family)
PF00002	CL0192	GPCR_A	7tm_2	7 transmembrane receptor (Secretin family)
PF00003	CL0192	GPCR_A	7tm_3	7 transmembrane sweet-taste receptor of 3 GCPR
PF00004	CL0023	P-loop_NTPase	AAA	ATPase family associated with various cellular activities (AAA)
PF00005	CL0023	P-loop_NTPase	ABC_tran	ABC transporter
PF00006	CL0023	P-loop_NTPase	ATP-synt_ab	ATP synthase alpha/beta family, nucleotide-binding domain
PF00007	CL0079	Cystine-knot	Cys_knot	Cystine-knot domain
PF00008	CL0001	EGF	EGF	EGF-like domain
PF00009	CL0023	P-loop_NTPase	GTP_EFTU	Elongation factor Tu GTP binding domain
PF00010			HLH	Helix-loop-helix DNA-binding domain


In [3]:
clans_df = pd.read_csv("../../PFAM/Pfam-A.clans.tsv", header=None, sep="\t")
clans_df[1] = clans_df[1].fillna(clans_df[0]) #families with no clan become their own clans
clans_df.set_index(0, inplace=True)
clans_df.drop([2,3,4], axis=1, inplace=True)
clans_df.rename(columns={1 : "clan"}, inplace=True)

In [4]:
clans_df.loc["PF00001"].clan

'CL0192'

In [5]:
np.random.seed(77)

#sequences longer than this value were truncated by the LM 
#todo: redo the embeddings without truncation
#for simplicity we omit families with at least one sequence longer than this value (0.3% of all families)
lm_model_truncation_value = 1022

def get_family(filepath):
    return ".".join(os.path.basename(filepath).split(".")[:-1])

#load all fasta files
#(takes a while)
datasets = "../../PFAM/alignments/"
ext = ".fasta"

clans_df["fasta_index"] = np.nan

# load all ref alignments
fasta_files = []
to_drop = []
truncated_clans = set()
for file in os.listdir(datasets):
    if file.endswith(ext):
        family = ".".join(file.split(".")[:-1])
        fasta = msa_hmm.fasta.Fasta(datasets+file, aligned=True, single_seq_ok=True)
        #omit families with only one sequence or families with at least one truncated sequence
        if fasta.num_seq == 1 or np.any(fasta.seq_lens > lm_model_truncation_value):
            to_drop.append(family)
            truncated_clans.add(clans_df.loc[family].clan)
            continue
        fasta_files.append(fasta)
        clans_df.loc[family, "fasta_index"] = len(fasta_files)-1
        
#drop families
clans_df = clans_df.drop(to_drop, axis=0)

assert not clans_df.isna().values.any()
clans_df = clans_df.astype({"fasta_index": "int32"})

#preprocessing
seq_lens, starting_pos, seq_pos_to_column = {}, {}, {}
for f in fasta_files:
    family = get_family(f.filename)
    seq_lens[family] = f.seq_lens
    starting_pos[family] = f.starting_pos
    seq_pos_to_column[family] = f.membership_targets

unique_clans = clans_df.clan.unique()
num_clans = unique_clans.size
print(f"{num_clans} clans in total")
clans_df.head()

11865 clans in total


           clan  fasta_index
0
PF00001  CL0192        13877
PF00002  CL0192        12061
PF00003  CL0192        12652
PF00004  CL0023         6410
PF00005  CL0023        13836


In [6]:
prot_model = "esm"

def load(clan):
    clan_families = clans_df[clans_df.clan == clan]
    return {family : np.load(f"{prot_model}/pfam/{family}.npy").astype(np.float16) for family in clan_families.index}

# warning, this loads ~1.8TB directly into memory
with Pool(8) as p:
    clan_embeddings = p.map(load, unique_clans)

In [7]:
# embedding dimension, read from any family of the first clan
emb_dim = next(iter(clan_embeddings[0].values())).shape[-1]
# number of families per clan
clan_sizes = np.array([len(emb) for emb in clan_embeddings])

def _get_features_labels(emb, lens, start, pos_to_col, rand):
    n = lens.size
    i = int(np.floor(rand * n))
    s = start[i]
    t = s + lens[i]
    return emb[s:t], pos_to_col[s:t], lens[i]

def make_dataset(clans, batch_size):
    def _gen_inputs():
        """ Generates a batch of training examples where one example is generated as:
            1. Sample a random clan
            2. Sample a random family from this clan
            3. Sample 2 random sequences from this family
            4. Compute the label matrix that has the 2 sequence lengths as rows/columns 
                and holds a 1 where residues share the same alignment column.
        Returns:
            (query_seqs, target_seqs), labels
            Where query_seqs and target_seqs are batches of embedded sequences (b, L1, d), (b, L2, d) 
            and labels are of shape (b, L1, L2)
        """
        while True:
            #sample clan
            batch_clans = np.random.choice(clans, size=batch_size)
            #sample family
            rand_family = np.random.rand(batch_size)
            batch_families = np.floor(rand_family * clan_sizes[batch_clans]).astype(batch_clans.dtype)
            #sample sequences (with replacement for speed)
            rand_seq = np.random.rand(batch_size, 2)
            emb1, emb2, labels = [], [], []
            max_len_1, max_len_2 = -1, -1
            for c,f,r in zip(batch_clans, batch_families, rand_seq):
                f_name = list(clan_embeddings[c].keys())[f]
                emb = clan_embeddings[c][f_name]
                lens = seq_lens[f_name]
                start = starting_pos[f_name]
                pos_to_col = seq_pos_to_column[f_name]
                assert emb.shape[0] == pos_to_col.size
                e1, c1, l1 = _get_features_labels(emb, lens, start, pos_to_col, r[0])
                e2, c2, l2 = _get_features_labels(emb, lens, start, pos_to_col, r[1])
                max_len_1 = max(max_len_1, l1)
                max_len_2 = max(max_len_2, l2)
                emb1.append(e1)
                emb2.append(e2)
                labels.append((c1[:, np.newaxis] == c2[np.newaxis, :]).astype(np.float32))
            #merge everything in padded tensors
            batched_emb1 = np.zeros((batch_size, max_len_1, emb_dim), dtype=np.float32)
            batched_emb2 = np.zeros((batch_size, max_len_2, emb_dim), dtype=np.float32)
            batched_labels = np.zeros((batch_size, max_len_1, max_len_2), dtype=np.float32)
            for i, (e1, e2, label) in enumerate(zip(emb1, emb2, labels)):
                batched_emb1[i, :e1.shape[0]] = e1
                batched_emb2[i, :e2.shape[0]] = e2
                batched_labels[i, :e1.shape[0], :e2.shape[0]] = label
            yield (batched_emb1, batched_emb2), batched_labels
            
    output_signature = ((tf.TensorSpec(shape=(batch_size, None, emb_dim), dtype=tf.float32), 
                         tf.TensorSpec(shape=(batch_size, None, emb_dim), dtype=tf.float32)), 
                            tf.TensorSpec(shape=(batch_size, None, None), dtype=tf.float32))
    ds = tf.data.Dataset.from_generator(_gen_inputs, output_signature = output_signature)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds
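
The label construction and padding in `_gen_inputs` can be illustrated on a toy pair. The sketch assumes that each entry of `pos_to_col` is the alignment column index of a residue, as the variable name suggests; the column values and the padded lengths below are made up.

```python
import numpy as np

# Alignment column of each residue in two toy sequences (hypothetical values).
cols_query = np.array([0, 1, 3, 4])
cols_target = np.array([1, 2, 3])

# labels[i, j] is 1 where query residue i and target residue j share a column.
labels = (cols_query[:, np.newaxis] == cols_target[np.newaxis, :]).astype(np.float32)

# Pad into a larger matrix, as done per batch with the maximum lengths.
max_len_1, max_len_2 = 5, 4
padded = np.zeros((max_len_1, max_len_2), dtype=np.float32)
padded[:labels.shape[0], :labels.shape[1]] = labels

# Rows that are all zero belong to query residues without an aligned target
# residue or to padding.
rows_without_target = np.flatnonzero(~padded.any(axis=-1))
```

Here query residues in columns 1 and 3 each match one target residue, while the residues in columns 0 and 4 and the padding row have no partner.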

In [8]:
#split by clan, test on held-out clans
test_p = 0.15
val_p = 0.15
assert test_p + val_p < 1
num_test = int(num_clans * test_p)
num_val = int(num_clans * val_p)
num_train = num_clans - num_test - num_val
a = np.arange(num_clans)
np.random.shuffle(a)
test_clans = a[:num_test]
val_clans = a[num_test:num_test+num_val]
train_clans = a[num_test+num_val:]
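
A quick sanity check for a split like the one above is that the three index sets are disjoint and together cover all clans. The sketch below wraps the same slicing in a hypothetical helper `split_clans`; it uses `np.random.default_rng` instead of the global seed, so the resulting permutation differs from the one in this notebook.

```python
import numpy as np

def split_clans(num_clans, test_p=0.15, val_p=0.15, seed=77):
    """Shuffle clan indices and cut them into train/val/test index arrays."""
    assert test_p + val_p < 1
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_clans)
    num_test = int(num_clans * test_p)
    num_val = int(num_clans * val_p)
    test = perm[:num_test]
    val = perm[num_test:num_test + num_val]
    train = perm[num_test + num_val:]
    return train, val, test

train, val, test = split_clans(11865)  # clan count printed earlier in the notebook

# No clan may appear in more than one split, and none may be lost.
all_idx = np.concatenate([train, val, test])
is_partition = np.array_equal(np.sort(all_idx), np.arange(11865))
```

Because the split is made on clan indices rather than on families or sequences, related families never end up on both sides of the train/test boundary.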

In [9]:
batch_size = 32
test_ds = make_dataset(test_clans, batch_size)
val_ds = make_dataset(val_clans, batch_size)
train_ds = make_dataset(train_clans, batch_size)

## Training

In [7]:
def masked_crossentropy(y_true, y_pred):
    mask = tf.reduce_any(tf.not_equal(y_true, 0), -1)
    y_true_masked = tf.boolean_mask(y_true, mask)
    y_pred_masked = tf.boolean_mask(y_pred, mask)
    cee = tf.keras.metrics.categorical_crossentropy(y_true_masked, y_pred_masked)
    return tf.reduce_mean(cee)
    
def make_model(reduced_dim = 256, dropout = 0.2):
    # input to the training pipeline are pairs of embeddings
    emb1 = tf.keras.layers.Input(shape=(None, emb_dim))
    emb2 = tf.keras.layers.Input(shape=(None, emb_dim))

    # outputs are homology probabilities 
    output = SymmetricBilinearReduction(reduced_dim,
                                        dropout, 
                                        use_attention_scores = True)(emb1, emb2)

    # construct a model and compile it for the categorical classification task
    model = tf.keras.models.Model(inputs=[emb1, emb2], outputs=output)

    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-2,
        decay_steps=1000,
        decay_rate=0.9)

    model.compile(loss=masked_crossentropy, 
                  optimizer=tf.keras.optimizers.Adam(lr_schedule),
                  metrics=[tf.keras.metrics.CategoricalAccuracy()])
    
    return model
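
To make the masking in `masked_crossentropy` concrete, here is a NumPy approximation evaluated on a tiny hand-made example. It is a sketch only: the Keras function additionally renormalizes the predictions, and the helper name and the numbers below are hypothetical.

```python
import numpy as np

def masked_crossentropy_np(y_true, y_pred, eps=1e-7):
    """Mean categorical cross-entropy over query positions that have a label."""
    # A query position is kept only if at least one target residue is aligned to it.
    mask = np.any(y_true != 0, axis=-1)
    yt = y_true[mask]
    yp = np.clip(y_pred[mask], eps, 1.0)  # avoid log(0)
    return float(np.mean(-np.sum(yt * np.log(yp), axis=-1)))

y_true = np.array([[[0., 1., 0.],
                    [0., 0., 0.],     # unaligned or padded query position
                    [1., 0., 0.]]])
y_pred = np.array([[[0.2, 0.7, 0.1],
                    [0.3, 0.3, 0.4],  # ignored by the mask
                    [0.5, 0.25, 0.25]]])

loss = masked_crossentropy_np(y_true, y_pred)
```

Only the first and third rows contribute, so the loss is the mean of $-\log 0.7$ and $-\log 0.5$; whatever the model predicts for the masked row has no effect.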

In [11]:
full_ds = make_dataset(a, batch_size)
final_model = make_model(reduced_dim = 64, dropout = 0.2)
final_model.fit(full_ds, epochs=10, steps_per_epoch=10000)
final_model.save(f"{prot_model}/bilinear_form_model_pfam_attention")

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10




INFO:tensorflow:Assets written to: esm/bilinear_form_model_pfam_attention/assets
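
For orientation, the learning-rate schedule configured in `make_model` can be evaluated by hand. The sketch assumes the non-staircase form of `ExponentialDecay`, where the rate at a given step is `initial_learning_rate * decay_rate ** (step / decay_steps)`; the helper below is plain Python and not part of the notebook.

```python
def exponential_decay(step, initial_learning_rate=1e-2, decay_steps=1000, decay_rate=0.9):
    """Learning rate after `step` optimizer steps (continuous, non-staircase decay)."""
    return initial_learning_rate * decay_rate ** (step / decay_steps)

total_steps = 10 * 10000  # epochs * steps_per_epoch used in the fit call above
for step in (0, 1000, 10000, total_steps):
    print(f"step {step:>6}: lr = {exponential_decay(step):.3e}")
```

With these settings the rate is multiplied by 0.9 every 1000 steps, so over the full run it shrinks by a factor of $0.9^{100}$ and is below $10^{-6}$ by the last step.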




In [None]:
model = tf.keras.models.load_model("esm/bilinear_form_model_pfam_attention", 
                                  custom_objects = {"SymmetricBilinearReduction" : SymmetricBilinearReduction, 
                                                    "masked_crossentropy" : masked_crossentropy})

In [6]:
model.save_weights("esm/bilinear_form_model_pfam_attention/checkpoints")