<a href="https://colab.research.google.com/github/dyl4nm4rsh4ll/funsae/blob/master/seqsal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ------------------------------------------------------------
# "THE BEERWARE LICENSE" (Revision 42):
# ------------------------------------------------------------
# <dylan_marshall@fas.harvard.com>, <so@g.harvard.edu> and 
# <koo@cshl.edu> wrote this code. As long as you retain this
# notice, you can do whatever you want with this stuff. If we 
# meet someday, and you think this stuff is worth it, you can
# buy us a beer in return.
# -Dylan Marshall, Sergey Ovchinnikov and Peter Koo
# ------------------------------------------------------------

# Load Libraries

In [None]:
%tensorflow_version 2.x
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v1.keras.backend as K1
import tensorflow.keras.backend as K

tf1.disable_eager_execution()

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Dense, Flatten, Reshape, Activation, Dropout, BatchNormalization, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.initializers import Zeros, Constant

In [None]:
import numpy as np
from scipy.stats import spearmanr

# Models

In [None]:
def train_model(model, X, W, schedule, verbose=False):
  """Fit a model to reconstruct its own input (input and target are both X).

  model    := compiled Keras model,
  X        := training samples indexed along axis 0,
  W        := per-sample weights aligned with X,
  schedule := sequence of (batch_size, epochs) stages, run in order,
  verbose  := forwarded to model.fit"""
  n_samples = X.shape[0]
  order = np.arange(n_samples)
  for bs, n_epochs in schedule:
    if bs != n_samples:
      # mini-batch stage: reshuffle before every epoch, one fit call per epoch
      for _ in range(n_epochs):
        np.random.shuffle(order)
        model.fit(X[order], X[order], sample_weight=W[order],
                  batch_size=bs, verbose=verbose, epochs=1)
    else:
      # full-batch stage: a single batch covers every sample, no manual shuffle
      model.fit(X, X, sample_weight=W, batch_size=bs, verbose=verbose, epochs=n_epochs)

### MRF

In [None]:
def mrf(X, W, use_bias=False, lam=0.01, train=True):
  """Build (and optionally train) a Markov random field as a one-layer Keras net.

  The input (L, A) is flattened, passed through one Dense layer whose (L*A, L*A)
  kernel holds the pairwise couplings, reshaped back to (L, A) and soft-maxed over
  the last axis. Because the kernel constraint zeroes every (i, i) block, each
  position is predicted only from the other positions.

  X        := samples, unpacked as (N, L, A); presumably one-hot — TODO confirm,
  W        := per-sample weights of length N,
  use_bias := add a per-(position, symbol) bias initialized from weighted counts,
  lam      := regularization strength (divided by N below),
  train    := if True, compile with Adam and run 200 full-batch epochs via train_model.

  Returns the Keras Sequential model. Resets the default TF graph and Keras session."""
  
  N,L,A = X.shape
  F = L*A
  
  # clear graph
  tf1.reset_default_graph()
  K.clear_session()
    
  #####################################################
  # setup kernel
  #####################################################
  def cst_w(weights):
    # kernel constraint: symmetrize the (F, F) couplings and zero each (i, i)
    # block via the mask, so position i never conditions on itself
    weights = (weights + K.transpose(weights)) / 2
    mask = K.constant((1-np.eye(L))[:,None,:,None], dtype=tf.float32)
    weights = K.reshape(weights, (L,A,L,A)) * mask
    return K.reshape(weights,(F,F))
  
  # zero-initialized couplings; L2 strength is lam/N times (L-1)(A-1)/2
  params = {"units":F,
            "kernel_initializer":Zeros,
            "kernel_regularizer":l2((lam/N)*(L-1)*(A-1)/2),
            "kernel_constraint":cst_w}
  
  #####################################################
  # setup bias
  #####################################################
  if use_bias:
    # log of weighted symbol counts per (position, symbol) plus a pseudocount
    # lam*log(sum W), centered per position
    # NOTE(review): init_v has shape (L, A) while the Dense bias has shape
    # (L*A,); this relies on Constant reshaping the value — confirm
    init_v = np.log((X.T*W).sum(-1).T + lam*np.log(W.sum()))
    params["bias_initializer"] = Constant(init_v - init_v.mean(-1, keepdims=True))
    params["bias_regularizer"] = l2(lam/N)
  else:
    params["use_bias"] = False
   
  #####################################################
  # setup model
  #####################################################
  model = Sequential()
  model.add(Flatten(input_shape=(L,A)))
  model.add(Dense(**params))
  model.add(Reshape((L, A)))
  model.add(Activation("softmax"))
  
  #####################################################
  # compile and train
  #####################################################
  if train:
    # per-sample loss: cross entropy summed over the L positions
    def loss(p, q):
      return K.sum(K.categorical_crossentropy(p,q),-1)
    # step size grows with log(sum W) and shrinks with sequence length L
    model.compile(Adam(0.1*np.log(W.sum())/L), loss)
    train_model(model,X,W,[[N,200]])
  
  return model

### LAE

In [None]:
# Linear Auto-Encoder
def lae(X, W, 
        enc=(), rank=256, dec=(),
        lam_w=0.1, lam_e=1.0, use_e=True,
        use_bias=False, train=True):
  """Build (and optionally train) a linear autoencoder with a softmax output.

  All Dense layers have no activation (linear); the only non-linearity is the
  final softmax over the symbol axis.

  X        := samples, unpacked as (N, L, A); presumably one-hot — TODO confirm,
  W        := per-sample weights of length N,
  enc      := widths of extra encoder layers before the rank bottleneck,
  rank     := bottleneck width,
  dec      := widths of extra decoder layers after the bottleneck,
  lam_w    := L2 strength for the encoder/decoder kernels (scaled by F/N),
  lam_e    := L2 strength for the optional per-position (A, A) output mixing layer,
  use_e    := add that output mixing Dense(A) layer,
  use_bias := bias terms on all Dense layers,
  train    := if True, compile with Adam and train with growing batch sizes.

  Returns the Keras Sequential model. Resets the default TF graph and Keras session."""
  
  N,L,A = X.shape
  F = L*A

  # clear graph
  tf1.reset_default_graph()
  K.clear_session()
    
  # model params
  params_w = {"use_bias":use_bias, "kernel_regularizer":l2(lam_w * F/N)}
  params_e = {"use_bias":use_bias, "kernel_regularizer":l2(lam_e)}
  
  #####################################################
  # encoder
  #####################################################
  model = Sequential()
  model.add(Flatten(input_shape=(L,A)))
  for unit in enc:
    model.add(Dense(unit, **params_w))
  model.add(Dense(rank, **params_w))
  
  #####################################################
  # decoder
  #####################################################
  for unit in dec:
    model.add(Dense(unit, **params_w))
  model.add(Dense(F, **params_w))
  model.add(Reshape((L,A)))
  # Dense on a (L, A) tensor acts on the last axis: one shared (A, A) map per position
  if use_e: model.add(Dense(A, **params_e))
  model.add(Activation("softmax"))
  
  #####################################################
  # compile and train
  #####################################################
  if train:
    # per-sample loss: cross entropy summed over the L positions
    def loss(p, q):
      return K.sum(K.categorical_crossentropy(p,q),-1)
    
    model.compile(Adam(0.1*np.log(W.sum())/L), loss)
    train_model(model,X,W,[[32,5],[64,10],[128,20],[N,100]])  
  
  return model

### VAE

In [None]:
def vae(X, W,
        enc=[512, 512], rank=32, dec=[512, 512],
        drop=0.5, lam=0.0, beta=0.5, train=True):
  """Build (and optionally train) a variational autoencoder over (L, A) inputs.

  X     := samples, unpacked as (N, L, A); presumably one-hot — TODO confirm,
  W     := per-sample weights, passed as sample_weight during training,
  enc   := encoder hidden widths (each: Dense(selu) -> Dropout -> BatchNorm),
  rank  := latent dimension,
  dec   := decoder hidden widths (same layer pattern as the encoder),
  drop  := dropout rate after every hidden layer,
  lam   := L2 kernel penalty (scaled by F/N) on hidden and output layers; off when 0,
  beta  := weight on the KL term (0.5 gives the standard Gaussian KL, see loss),
  train := if True, compile with Adam and train with growing batch sizes.

  Returns `model_mu`: the autoencoder that decodes the latent mean Z_mu
  (no sampling). The sampled model used for training is not returned.
  Resets the default TF graph and Keras session."""
  
  N,L,A = X.shape
  F = L * A
  
  # clear graph
  tf1.reset_default_graph()
  K.clear_session()
    
  # model params
  params = {"activation":"selu"}
  if lam > 0:
    params["kernel_regularizer"] = l2(lam * F/N)
  
  #####################################################
  # encoder (E)
  #####################################################
  E_I = Input((L,A))
  E = Flatten()(E_I)
  for unit in enc:
    E = Dense(unit, **params)(E)
    E = Dropout(drop)(E)
    E = BatchNormalization()(E)
  
  #####################################################
  # latent (Z)
  #####################################################
  Z_mu = Dense(rank)(E)
  # Z_log_sg is the log-variance; Z_sg = exp(0.5 * log var) is the std deviation
  Z_log_sg = Dense(rank)(E)
  Z_sg = Lambda(lambda x: K.exp(0.5 * x))(Z_log_sg)
  # reparameterization trick: Z = mu + sigma * eps, eps ~ N(0, I)
  Z = Lambda(lambda x: x[0]+x[1]*K.random_normal(K.shape(x[0])))([Z_mu, Z_sg])
  
  model_EN    = Model(E_I, Z,    name="model_EN")
  model_EN_mu = Model(E_I, Z_mu, name="model_EN_mu")
  
  #####################################################
  # decoder (D)
  #####################################################
  D_I = Input((rank,))
  D = D_I
  for unit in dec:
    D = Dense(unit, **params)(D)
    D = Dropout(drop)(D)
    D = BatchNormalization()(D)
    
  # NOTE(review): `params` includes activation="selu", so the logits fed to the
  # softmax pass through selu first — confirm a linear output was not intended
  D = Dense(F, **params)(D)
  D = Reshape((L,A))(D)
  D_O = Activation("softmax")(D)  
  model_DE = Model(D_I, D_O, name="model_DE")
  
  #####################################################
  # autoencoder
  #####################################################
  # model: sampled latent (used for training); model_mu: latent mean (returned)
  model    = Model(E_I, model_DE(model_EN(E_I)))
  model_mu = Model(E_I, model_DE(model_EN_mu(E_I)))

  #####################################################
  # compile and train
  #####################################################
  if train:  
    def loss(p,q):
      # reconstruction: cross entropy summed over the L positions
      RE = K.sum(K.categorical_crossentropy(p,q),-1)
      # KL(N(mu, sg^2) || N(0, I)) = 0.5 * sum(mu^2 + sg^2 - log sg^2 - 1);
      # Z_log_sg is log sg^2, so beta=0.5 reproduces it exactly
      KL = beta*K.sum(K.square(Z_mu)+K.square(Z_sg)-Z_log_sg-1.0,-1)
      return RE + KL
    model.compile("adam",loss)
    train_model(model,X,W,[[64,50],[128,50],[256,50],[512,50],[1024,50],[2048,50]])
        
  return model_mu

# Other

In [None]:
def pw_saliency(model):
  """Pairwise saliency map of a model mapping (L, A) inputs to (L, A) outputs.

  For each output position i, the model is fed A all-zero inputs. Batch row a
  selects log p(symbol a at position i). The gradient of the summed selections
  with respect to the input gives an (A, L, A) slice. Stacking the slices over i
  yields pw[i, a, j, b]. The result is symmetrized by averaging with pw[j, b, i, a].

  Relies on TF1 graph mode (placeholder + session), as set up by
  tf1.disable_eager_execution(). Assumes the input shape equals the output
  shape (L, A) — TODO confirm for new models.

  Returns an (L, A, L, A) numpy array."""
  sess = K1.get_session()
  # position index fed at run time, so one gradient graph serves every position
  i = K1.placeholder(shape=[],dtype=tf.int32)
  out = model.output[:,i]
  L,A = [int(s) for s in model.output.get_shape()[1:]]
  # identity mask: batch row a keeps only log p(a at position i); the two
  # negations cancel, so sal is the gradient of the summed log-probabilities
  sal = -K1.gradients(-K.sum(np.eye(A)*K.log(out + 1e-8)), model.input)[0]
  # all-zero reference input, one copy per symbol (batch of A)
  null = np.zeros((A,L,A))
  pw = np.array([sess.run(sal, {i:j, model.input:null}) for j in range(L)])
  # symmetrize: average pw[i,a,j,b] with pw[j,b,i,a]
  return 0.5*(pw+np.transpose(pw,(2,3,0,1)))

def pw_contact_map(pw, n_aa=20):
  """Collapse a pairwise (L, A, L, A) array into an APC-corrected (L, L) contact map.

  pw   := array indexed (i, a, j, b) with L positions and at least n_aa symbols,
  n_aa := number of leading symbol channels kept (default 20; with this
          notebook's 21-symbol alphabet, index 20 is the gap "-", which is dropped).

  For each position pair, takes the Frobenius norm of the (n_aa, n_aa) block and
  zeroes the diagonal. It then subtracts the average product correction (APC),
  outer(s, s) / sum(s) with s the column sums, and zeroes the diagonal again.
  An all-zero input returns an all-zero map instead of NaNs from dividing by a zero total."""
  block = pw[:, :n_aa, :, :n_aa]
  l2_norm = np.sqrt(np.square(block).sum(axis=(1, 3)))
  np.fill_diagonal(l2_norm, 0.0)
  col_sums = l2_norm.sum(axis=0)
  total = col_sums.sum()
  if total > 0:
    # average product correction: removes per-position background coupling
    apc = np.outer(col_sums, col_sums) / total
  else:
    apc = np.zeros_like(l2_norm)
  l2_norm_apc = l2_norm - apc
  np.fill_diagonal(l2_norm_apc, 0.0)
  return l2_norm_apc

def contact_auc(pred, meas, thresh=0.01, min_sep=6):
  """Mean top-k precision of predicted contacts against measured contacts.

  Despite the name, this is not a ROC AUC. Pairs (i, j) with j - i >= min_sep
  are ranked by `pred` (highest first). For each k in L/10, 2L/10, ..., L
  (L = len(meas)), the function computes the fraction of the top-k pairs whose
  measured value exceeds `thresh`. It returns the mean of these precisions.

  pred, meas := square (L, L) arrays of predicted scores / measured contact values,
  thresh     := measured value above which a pair counts as a contact,
  min_sep    := minimum sequence separation j - i considered (default 6).

  Each k is clamped to at least 1 so that short sequences (L < 10) do not take
  means of empty slices. Returns NaN if no pair satisfies min_sep."""
  eval_idx = np.triu_indices_from(meas, min_sep)
  pred_, meas_ = pred[eval_idx], meas[eval_idx]
  if pred_.size == 0:
    return float("nan")
  top_ks = np.maximum((np.linspace(0.1, 1.0, 10) * len(meas)).astype("int"), 1)
  # measured contact flags, ordered by decreasing predicted score
  is_contact = meas_[np.argsort(pred_)[::-1]] > thresh
  return np.mean([is_contact[:k].mean() for k in top_ks])

def sco(p,q):
  """Per-sample log-likelihood score: sum over axes (1, 2) of p * log(q + 1e-8)."""
  log_q = np.log(q + 1e-8)
  return np.sum(p * log_q, axis=(1, 2))

# Data

In [None]:
%%bash
# install gdown (Google Drive downloader) and fetch the preprocessed data.npy
# that the following cells load
pip -q install gdown
gdown https://drive.google.com/uc?id=1MqhiNhwgLNJ-rnD_YCzEniDYizEc2-YP

Downloading...
From: https://drive.google.com/uc?id=1MqhiNhwgLNJ-rnD_YCzEniDYizEc2-YP
To: /content/data.npy
0.00B [00:00, ?B/s]4.72MB [00:00, 17.2MB/s]27.3MB [00:00, 23.8MB/s]34.6MB [00:00, 27.6MB/s]59.8MB [00:00, 37.6MB/s]71.8MB [00:00, 42.5MB/s]86.0MB [00:00, 53.8MB/s]101MB [00:01, 52.3MB/s] 111MB [00:01, 49.0MB/s]129MB [00:01, 62.4MB/s]141MB [00:01, 70.9MB/s]158MB [00:01, 86.1MB/s]174MB [00:01, 100MB/s] 188MB [00:02, 87.3MB/s]213MB [00:02, 108MB/s] 229MB [00:02, 78.6MB/s]252MB [00:02, 85.0MB/s]274MB [00:02, 104MB/s] 288MB [00:03, 95.0MB/s]301MB [00:03, 97.3MB/s]313MB [00:03, 89.8MB/s]329MB [00:03, 103MB/s] 344MB [00:03, 95.9MB/s]368MB [00:03, 116MB/s] 383MB [00:03, 85.2MB/s]409MB [00:04, 107MB/s] 426MB [00:04, 83.1MB/s]440MB [00:04, 95.5MB/s]454MB [00:04, 87.7MB/s]466MB [00:04, 77.4MB/s]488MB [00:04, 95.9MB/s]502MB [00:05, 87.6MB/s]524MB [00:05, 107MB/s] 539MB [00:05, 84.7MB/s]563MB [00:05, 88.1MB/s]587MB [00:05, 109MB/s] 602MB [00:06, 103MB/s]6

In [None]:
# Inspect the loaded arrays. NOTE: requires `data` from the np.load cell below;
# the saved output (keys "dy", no "cons") predates the current data.npy layout.
for a, b in data.items():
  print(a, type(b))

X <class 'numpy.ndarray'>
W <class 'numpy.ndarray'>
dX <class 'numpy.ndarray'>
dy <class 'numpy.ndarray'>


In [None]:
# data.npy stores a pickled dict, hence allow_pickle + .item();
# only load files from a trusted source
data = np.load("data.npy", allow_pickle=True).item()
# one-hot encode the integer-coded sequences over 21 symbols
X = np.eye(21)[data["X"]]
dX = np.eye(21)[data["dX"]]
W = data["W"]
dY = data["dY"]
cons = data["cons"]

In [None]:
# Train an ensemble of VAEs. For each model, report the Spearman correlation
# between predicted scores and measured effects (dY), both per model and for
# the running ensemble mean, plus each model's contact-prediction precision.
# (The original loop started at k=34 with A/B left over in the kernel; on a
# fresh kernel A/B were undefined.)
o = 256          # latent rank of each VAE
n_models = 1000  # number of ensemble members to train
A = []  # per-model scores of the DMS mutants: sco(dX, model(dX))
B = []  # per-model scores of the DMS mutants under the reference (X[0]) reconstruction
for k in range(n_models):
  model = vae(X, W, rank=o)
  loss = model.evaluate(X,X,sample_weight=W,verbose=False)
  w = pw_saliency(model)
  cons_pred = pw_contact_map(w)
  cons_auc = contact_auc(cons_pred, cons)
  dY_sco = sco(dX, model.predict(dX))
  dY_pssm_sco = sco(dX, model.predict(X[0,None]))
  A.append(dY_sco)
  B.append(dY_pssm_sco)
  # rank, iteration, loss, rho(model), rho(reference), contact precision,
  # rho(ensemble mean), rho(ensemble mean, reference)
  print(o,k,loss,spearmanr(dY, dY_sco)[0], spearmanr(dY, dY_pssm_sco)[0], cons_auc, 
        spearmanr(dY, np.mean(A,0))[0], spearmanr(dY, np.mean(B,0))[0])

0.6520947988885377

0.7242218963189837

In [None]:
# Train a single VAE with default hyperparameters. The saved output shows a
# NameError, presumably from running this cell before the definitions/data cells.
vae_model = vae(X, W)

NameError: ignored

In [None]:
# Contact prediction from the VAE's pairwise saliency (was computed from the
# undefined `mrf_model` instead of `vae_model`)
vae_w = pw_saliency(vae_model)
vae_cons = pw_contact_map(vae_w)
vae_cons_auc = contact_auc(vae_cons, cons)
vae_cons_auc

In [None]:
# MRF baseline: `mrf_cons` was never defined in this notebook; build it from an
# MRF model's pairwise saliency, mirroring the VAE cell above
mrf_model = mrf(X, W)
mrf_cons = pw_contact_map(pw_saliency(mrf_model))
contact_auc(mrf_cons, cons)

0.8574528783672488

# Data preparation (builds `data.npy`)

In [None]:
# import internals
import copy, itertools, json, os, pickle, sys, time
# import externals
from matplotlib import animation, cm, colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objs as g
import seaborn as sns
from sklearn.decomposition import PCA
from scipy import special
from scipy import signal as sig
from scipy.spatial.distance import jensenshannon, pdist, squareform, hamming
import scipy.stats as stats


In [None]:
def cce(p, q):
  """categorical cross entropy per sample: -sum over axes (1, 2) of p * log(q + 1e-8)"""
  eps = 1e-8
  return -(p * np.log(q + eps)).sum(axis=(1, 2))

def collate_dms(dms_data, wrt, considered="v_", dms_info=(
  "mut", "x", "y", "ind", "pw", "v_μ", "v_1", "v_2", "v_3", "v_4", "v_5"
)):
  """clean DMS data for evaluation

    dms_data := raw DMS data, {experiment: {mutant: {field: value}}},
    wrt := field whose value must be finite (e.g. "y"),
    considered := substring selecting the experiments in which a mutant's
      `wrt` value must be finite for it to be kept,
    dms_info := fields to stack into arrays in the output

    returns {experiment: {field: np.stack(values)}}. For each experiment it keeps
    mutants whose "x" is not None and whose `wrt` value is finite in every
    considered experiment (original insertion order preserved).
    raises KeyError if a kept mutant is missing from a considered experiment"""

  # clean s.t. viable mutants wrt MSA (drop mutants without an "x" encoding)
  dms_msa_pre = {
    dms: {
      mut: dict(xy.items()) for mut, xy in mut_xy.items() if xy["x"] is not None
    } for dms, mut_xy in dms_data.items()
  }
  considered_cols = [col for col in dms_data.keys() if considered in col]

  def _finite_everywhere(mut):
    # ensure edge cases (infs, nans) D.N.E. in any considered experiment;
    # a list (not a generator) so every lookup is performed, as before
    return all([np.isfinite(dms_msa_pre[col][mut][wrt]) for col in considered_cols])

  collated = {}
  for dms in dms_data.keys():
    # filter once per experiment instead of once per requested field
    kept = [val for mut, val in dms_msa_pre[dms].items() if _finite_everywhere(mut)]
    collated[dms] = {v: np.stack([val[v] for val in kept]) for v in dms_info}
  return collated

def load_pkl(fname):
  """load pickled file; prints a hint and returns None if `fname` does not contain "pkl" """
  if "pkl" not in fname:
    print("check file")
    return None
  # NOTE: pickle.load can execute arbitrary code — only load trusted files
  with open(fname, "rb") as handle:
    return pickle.load(handle)

def map_labels(l, _l2c):
  """sequence header phyla to color; labels missing from `_l2c` map to `_l2c["else"]`"""
  mapped = []
  for label in l:
    if label in _l2c:
      mapped.append(_l2c[label])
    else:
      mapped.append(_l2c["else"])
  return mapped

def sequence_identity(u, v, gap_idx=None):
  """calculate sequence identity for two encoded sequences

    u, v := (length, alphabet) arrays (e.g. one-hot); argmax over the last axis
      gives each position's symbol index,
    gap_idx := symbol index treated as a gap; defaults to the index of "-" in
      the notebook-level `_2_params["a2i"]` mapping

    returns 1 - Hamming distance over positions that are non-gap in both sequences
    raises AssertionError if the sequences differ in length"""
  assert len(u) == len(v), "sequences must have equal length, got %d and %d" % (len(u), len(v))
  if gap_idx is None:
    gap_idx = _2_params["a2i"]["-"]
  U = np.argmax(u, axis=-1)
  V = np.argmax(v, axis=-1)
  # compare only positions where neither sequence has a gap
  keep = (U != gap_idx) & (V != gap_idx)
  return 1 - hamming(U[keep], V[keep])

def split_data(msa, frac=0.1, numpy=True):
  """(training / validation) split

    msa := multiple sequence alignment. For numpy=True, a dict with a "clean"
      (N, L, A) array and per-sequence "weights", "phyla", "seq_id" arrays.
      For numpy=False, any object with .get("train") / .get("test") returning
      array-likes (presumably an open h5py.File — TODO confirm),
    frac := fraction of sequences assigned to validation,
    numpy := file is numpy format

    returns, for numpy=True: {"train": {...}, "valid": {...}}, each with keys
    "x", "weights", "phyla", "seq_id". For numpy=False: {"train", "valid", "test"}
    float32 arrays. Shuffling uses the global np.random state."""
  if numpy:
    data = msa["clean"]
    # samples, length, amino acids
    N, L, A = data.shape
    # number partition samples, shuffled indices
    n, shuff = int(N * frac), np.random.permutation(N)
    # validation takes the first n shuffled indices, training the rest
    parts = {"train": shuff[n:], "valid": shuff[:n]}
    return {
      part: {
        "x": data[idx].reshape((-1, L, A)),
        "weights": msa["weights"][idx],
        "phyla": msa["phyla"][idx],
        "seq_id": msa["seq_id"][idx]
      } for part, idx in parts.items()
    }
  # file is HDF5-like: read from `msa` (the previous code read an unbound `data`)
  train_valid = np.array(msa.get("train")).astype(np.float32)
  # number valid samples, shuffled indices
  N, L, A = train_valid.shape
  num, shuff = int(N * frac), np.random.permutation(N)
  # train, valid
  return {
    "train": train_valid[shuff[num:]].reshape((-1, L, A)),
    "valid": train_valid[shuff[:num]].reshape((-1, L, A)),
    "test": np.array(msa.get("test")).astype(np.float32)
  }

In [None]:
%%bash
# fetch the beta-lactamase (P62593) MSA pickle and its DMS pickle from Google Drive
# (same ids as file_data["beta_lactamase_P62593"]["msa"] / ["dms"])
pip -q install gdown
gdown https://drive.google.com/uc?id=1RaH9ErtltosAEtKvokOBje2NCcye6oo3
gdown https://drive.google.com/uc?id=1odSnJIjK95a_KsNfFnZ6x3NsNEye5qa1

Downloading...
From: https://drive.google.com/uc?id=1RaH9ErtltosAEtKvokOBje2NCcye6oo3
To: /content/beta_lactamase_P62593.pkl
0.00B [00:00, ?B/s]4.72MB [00:00, 34.9MB/s]8.91MB [00:00, 33.9MB/s]42.5MB [00:00, 45.5MB/s]67.6MB [00:00, 57.3MB/s]98.0MB [00:00, 75.7MB/s]119MB [00:00, 93.5MB/s] 137MB [00:00, 107MB/s] 154MB [00:01, 114MB/s]181MB [00:01, 137MB/s]202MB [00:01, 153MB/s]227MB [00:01, 144MB/s]256MB [00:01, 169MB/s]287MB [00:01, 195MB/s]315MB [00:01, 215MB/s]340MB [00:01, 160MB/s]370MB [00:02, 185MB/s]395MB [00:02, 201MB/s]422MB [00:02, 217MB/s]446MB [00:02, 186MB/s]474MB [00:02, 206MB/s]504MB [00:02, 223MB/s]530MB [00:02, 233MB/s]555MB [00:02, 217MB/s]582MB [00:03, 231MB/s]607MB [00:03, 201MB/s]636MB [00:03, 221MB/s]660MB [00:03, 226MB/s]684MB [00:03, 218MB/s]707MB [00:03, 208MB/s]735MB [00:03, 223MB/s]763MB [00:03, 237MB/s]789MB [00:03, 242MB/s]816MB [00:04, 251MB/s]842MB [00:04, 241MB/s]869MB [00:04, 248MB/s]897MB [00:04, 256MB/s]911MB [00:04,

In [None]:
%%bash
# download and unpack the confind binary together with its rotamer libraries (rotlibs/)
wget -q -nc https://grigoryanlab.org/confind/confind-msl-bin.tar.gz
gunzip confind-msl-bin.tar.gz
tar -xvf confind-msl-bin.tar
rm confind-msl-bin.tar
# fetch the structure listed as file_data[...]["pdb"] and keep confind's
# "contact" lines as the measured contact table read by the loading cell
wget -q -nc https://files.rcsb.org/view/1ERO.pdb
./confind --p 1ERO.pdb --rLib ./rotlibs | grep "contact" > meas_con.txt

confind
rotlibs/
rotlibs/EBL.out
rotlibs/BEBL.out


In [None]:
################################################################################
# define plotting / globals / settings
plt.style.use("default")
# primary: one-letter alphabet (20 amino acids + gap "-") and the matching
# three-letter codes, listed in the same order
_1_params = dict(
  alphabet="ARNDCQEGHILKMFPSTWYV-",
  tri_alphabet=[
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY",
    "HIS", "ILE", "LEU", "LYS", "MET", "PHE", "PRO", "SER",
    "THR", "TRP", "TYR", "VAL", "GAP"
  ],
)
# secondary: lookup tables derived from the primary alphabets
_2_params = {
  # letter -> index
  "a2i": {aa: idx for idx, aa in enumerate(_1_params["alphabet"])},
  # three-letter code -> letter
  "aaa2a": {aaa: aa for aaa, aa in zip(_1_params["tri_alphabet"], _1_params["alphabet"])},
  # index -> letter
  "i2a": dict(enumerate(_1_params["alphabet"])),
}
# define BLOSUM62
# 24x24 log-odds matrix; rows/columns follow b62_raw["aa"], which also includes
# the ambiguity codes B, Z, X in addition to the 20 amino acids and "-"
b62_raw = {
  "aa": np.array([
    "A", "R", "N", "D", "C", "Q", "E", "G",
    "H", "I", "L", "K", "M", "F", "P", "S",
    "T", "W", "Y", "V", "B", "Z", "X", "-"
  ]),
  "log_odds": """
    4 -1 -2 -2  0 -1 -1  0 -2 -1 -1 -1 -1 -2 -1  1  0 -3 -2  0 -2 -1  0 -4 
    -1  5  0 -2 -3  1  0 -2  0 -3 -2  2 -1 -3 -2 -1 -1 -3 -2 -3 -1  0 -1 -4 
    -2  0  6  1 -3  0  0  0  1 -3 -3  0 -2 -3 -2  1  0 -4 -2 -3  3  0 -1 -4 
    -2 -2  1  6 -3  0  2 -1 -1 -3 -4 -1 -3 -3 -1  0 -1 -4 -3 -3  4  1 -1 -4 
    0 -3 -3 -3  9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 -3 -3 -2 -4 
    -1  1  0  0 -3  5  2 -2  0 -3 -2  1  0 -3 -1  0 -1 -2 -1 -2  0  3 -1 -4 
    -1  0  0  2 -4  2  5 -2  0 -3 -3  1 -2 -3 -1  0 -1 -3 -2 -2  1  4 -1 -4 
    0 -2  0 -1 -3 -2 -2  6 -2 -4 -4 -2 -3 -3 -2  0 -2 -2 -3 -3 -1 -2 -1 -4 
    -2  0  1 -1 -3  0  0 -2  8 -3 -3 -1 -2 -1 -2 -1 -2 -2  2 -3  0  0 -1 -4 
    -1 -3 -3 -3 -1 -3 -3 -4 -3  4  2 -3  1  0 -3 -2 -1 -3 -1  3 -3 -3 -1 -4 
    -1 -2 -3 -4 -1 -2 -3 -4 -3  2  4 -2  2  0 -3 -2 -1 -2 -1  1 -4 -3 -1 -4 
    -1  2  0 -1 -3  1  1 -2 -1 -3 -2  5 -1 -3 -1  0 -1 -3 -2 -2  0  1 -1 -4 
    -1 -1 -2 -3 -1  0 -2 -3 -2  1  2 -1  5  0 -2 -1 -1 -1 -1  1 -3 -1 -1 -4 
    -2 -3 -3 -3 -2 -3 -3 -3 -1  0  0 -3  0  6 -4 -2 -2  1  3 -1 -3 -3 -1 -4 
    -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4  7 -1 -1 -4 -3 -2 -2 -1 -2 -4 
    1 -1  1  0 -1  0  0  0 -1 -2 -2  0 -1 -2 -1  4  1 -3 -2 -2  0  0  0 -4 
    0 -1  0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1  1  5 -2 -2  0 -1 -1  0 -4 
    -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1  1 -4 -3 -2 11  2 -3 -4 -3 -2 -4 
    -2 -2 -2 -3 -2 -1 -2 -3  2 -1 -1 -2 -1  3 -3 -2 -2  2  7 -1 -3 -2 -1 -4 
    0 -3 -3 -3 -1 -2 -2 -3 -3  3  1 -2  1 -1 -2 -2  0 -3 -1  4 -3 -2 -1 -4 
    -2 -1  3  4 -3  0  1 -1  0 -3 -4  0 -3 -3 -2  0 -1 -4 -3 -3  4  1 -1 -4 
    -1  0  0  1 -3  3  4 -2  0 -3 -3  1 -1 -3 -1  0 -1 -3 -2 -2  1  4 -1 -4 
    0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2  0  0 -2 -1 -1 -1 -1 -1 -4 
    -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4  1
  """
}
# b62: BLOSUM62 restricted to the symbols of _1_params["alphabet"] (20 amino
# acids + "-"), as an int DataFrame. Each parsed matrix row becomes a column keyed
# by its letter (the matrix is symmetric, so rows and columns coincide).
# NOTE(review): the row index keeps the integer positions from b62_raw["aa"]
# (0..19, and 23 for "-"), not letters — use .iloc, or relabel before .loc lookups
b62 = pd.DataFrame({
  aa: row for aa, row in zip(b62_raw["aa"], [row for row in np.array([
    x for x in b62_raw["log_odds"].replace("\n", " ").split(" ") if x != ""
  ]).reshape((24, 24))])
})[[c for c in b62_raw["aa"] if c in _1_params["alphabet"]]].iloc[[
  i for i, c in enumerate(b62_raw["aa"]) if c in _1_params["alphabet"]
]].astype("int")

In [None]:
# data sources per protein family; "pdb" names the structure file fetched by the confind cell
file_data = dict(
  beta_lactamase_P62593=dict(
    msa="https://drive.google.com/uc?id=1RaH9ErtltosAEtKvokOBje2NCcye6oo3",
    dms="https://drive.google.com/uc?id=1odSnJIjK95a_KsNfFnZ6x3NsNEye5qa1",
    pdb="1ERO.pdb",
  ),
  # TODO: modular
)
# family to analyze, and the DMS experiment used for evaluation
file_pref = "beta_lactamase_P62593"
considered_dms = "Ranganathan2015"

In [None]:
%%time
msa = file_data[file_pref]["msa"]
dms = file_data[file_pref]["dms"]
pdb = file_data[file_pref]["pdb"]
load_data = {
  "msa": load_pkl(file_pref + ".pkl"),
  "dms": collate_dms(
    dms_data=load_pkl(file_pref + ".DMS.pkl"),
    wrt="y",
    considered=considered_dms
  )
}
# measured contacts dataframe
meas_con_df = pd.read_csv("meas_con.txt", sep="\t", header=None, names=[
  "kind", "i", "j", "con", "aa_i", "aa_j"
])
# initialize measured contact map data
meas_con_data = {
  "con_i_a": dict(zip(
    np.array([int(ai[2:]) for ai in np.concatenate([
      meas_con_df["i"].values, [meas_con_df["j"].values[-1]]
    ])]),
    np.array([_2_params["aaa2a"][aaa] for aaa in np.concatenate([
      meas_con_df["aa_i"].values, [meas_con_df["aa_j"].values[-1]]
    ])])
  ))
}
meas_con_data.update({
  "vals": np.zeros((
    np.max(np.array([int(x[2:]) for x in meas_con_df["j"]])) + 1,
    np.max(np.array([int(x[2:]) for x in meas_con_df["j"]])) + 1
  ))
})
# fill measured contact map data
for row in meas_con_df.itertuples():
  i, j = int(row[2][2:]), int(row[3][2:])
  meas_con_data["vals"][i, j] = row[4]
  meas_con_data["vals"][j, i] = row[4]
# valid measured contact indices            
v_con_idx = np.intersect1d(
  np.where(np.sum(meas_con_data["vals"], 0) > 0)[0],
  np.where(np.sum(meas_con_data["vals"], 0) > 0)[0]
)                                  
meas_con_data["vals"] = meas_con_data["vals"][v_con_idx, :][:, v_con_idx]
# reference sequence, measured contact sequence
raw_ref_seq = "".join([
  _2_params["i2a"][np.argmax(i)] for i in load_data["msa"]["raw"][0]
])
meas_con_seq = "".join(list(meas_con_data["con_i_a"].values()))
# map measured contact sequence against raw reference sequence
meas_con_idx = raw_ref_seq.find(meas_con_seq)
# PDB validity
assert meas_con_idx > 0, "\ncheck PDB file\n"
# viable PDB indices
meas_con_data.update({
  "good_idx": np.array([
    i for i, j in enumerate(np.arange(meas_con_idx, len(raw_ref_seq)))
      if j in load_data["msa"]["non_gap"]
  ])
})
# clean namespace
del msa, dms, v_con_idx, raw_ref_seq, meas_con_seq, meas_con_idx, row, i, j
# examine MSA / DMS data
[print("\nload_data['msa']:")] + [
  print("  ", a, b.shape) for a, b in load_data["msa"].items()]
[print("\nload_data['dms']:")] + [
  print("  ", a, b["x"].shape) for a, b in load_data["dms"].items()];
[print("\nmeas_con_data:")] + [
  print("  ", a, len(b)) for a, b in meas_con_data.items()];
print()


load_data['msa']:
   raw (10062, 286, 21)
   non_gap (252,)
   clean (10062, 252, 21)
   weights (10062,)
   phyla (10062,)
   seq_id (10062,)

load_data['dms']:
   Ranganathan2015.1 (4769, 252, 21)
   Ranganathan2015.2 (4769, 252, 21)
   Ranganathan2015.μ (4769, 252, 21)
   Palzkill2012 (4769, 252, 21)
   Tenaillon2013 (971, 252, 21)
   Ostermeier2014 (4575, 252, 21)

meas_con_data:
   con_i_a 263
   vals 263
   good_idx 252

CPU times: user 2.46 s, sys: 1.64 s, total: 4.1 s
Wall time: 4.11 s


In [None]:
# references: raw / clean
r_raw = load_data["msa"]["raw"][0][None, :]
r_clean = load_data["msa"]["clean"][0][None, :]

# cleaned one-hot seqs, effective weights
X, W = load_data["msa"]["clean"], load_data["msa"]["weights"]

# DMS one-hot seqs
dX = load_data["dms"][considered_dms + ".μ"]["x"]

# DMS effect
dy = load_data["dms"][considered_dms + ".μ"]["y"]

# DeepSeq VAE prediction
dDS = load_data["dms"][considered_dms + ".μ"]["v_μ"]

# DeepSeq pairwise prediction
dDS_pw = load_data["dms"][considered_dms + ".μ"]["pw"]

# DeepSeq indices
DS_idx = np.isfinite(load_data["dms"][considered_dms + ".μ"]["v_μ"])

In [None]:
# compact, integer-coded dataset saved as data.npy. The contact map is restricted
# to the good_idx positions on both axes.
data = {
  "X": X.argmax(-1),
  "W": W,
  "dX": dX.argmax(-1),
  "dY": dy,
  "cons": meas_con_data["vals"][:, meas_con_data["good_idx"]][meas_con_data["good_idx"], :],
}

In [None]:
# persist the dataset; np.save pickles the dict, so reading it back needs
# np.load("data.npy", allow_pickle=True).item()
np.save("data.npy",data)