In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from af22c.proteome import MultipleSeqAlign
import matplotlib.pyplot as plt
from timeit import timeit
from tqdm import tqdm
from itertools import combinations, chain
import time

In [None]:
# path of the example MSA (a3m file) used throughout this notebook
prot_path = "../data/A0A0A0MRZ7.a3m"
# parse it with the project's MultipleSeqAlign class; query_seq and matches are used below
prot = MultipleSeqAlign.from_a3m(prot_path)

In [None]:
# alphabet of the MSA: every distinct character of the query and the aligned matches, sorted
vocab = sorted({
    ch
    for seq in [prot.query_seq] + [match.aligned_seq for match in prot.matches]
    for ch in str(seq)
})
"".join(vocab), len(vocab)

In [None]:
# lookup tables between characters and token ids (id = position in the sorted vocabulary)
stoi = dict(zip(vocab, range(len(vocab))))
itos = dict(enumerate(vocab))

In [None]:
query_len = len(prot.query_seq)  # length of the query = number of MSA columns used below
num_matches = len(prot.matches)
num_seqs = num_matches + 1 # include query
query_len,num_matches,num_seqs

In [None]:
# Encode every sequence (query first) as one row of token ids. The rows are built as
# Python lists and converted with a single torch.tensor call; writing each character
# into the tensor separately costs one tensor indexing call per character.
seqs = [prot.query_seq] + [match.aligned_seq for match in prot.matches]
encmsa = torch.tensor(
    [[stoi[colval] for colval in seq] for seq in seqs],
    dtype=torch.get_default_dtype(),  # same dtype torch.zeros used before
)
# every aligned sequence must have the query's length; the per-character version
# silently left shorter rows zero-padded
assert encmsa.shape == (num_seqs, query_len), "aligned sequences must all have the query length"
encmsa.shape

In [None]:
# token ids of the first 100 rows (query first, then matches) of the encoded MSA as an image
plt.imshow(encmsa[:100])

In [None]:
# toy example: the first 10 sequences and columns 8-17 of the encoded MSA
smallmsa = encmsa[:10,8:18] # (num_seqs, query_len)
# adjust variables for toy example
# NOTE(review): this overwrites the globals that described the full MSA; every later cell
# that reads num_seqs / query_len / num_matches sees the toy sizes
num_seqs, query_len = smallmsa.shape
num_matches = num_seqs - 1

In [None]:
# the toy MSA; colour = token id
plt.imshow(smallmsa)
plt.ylabel("match")
plt.xlabel("AA index")

In [None]:
# pairwise hamming distances
# Pairwise Hamming distances: broadcast every row against every other row, giving a
# (num_seqs, num_seqs, query_len) mismatch mask, then count mismatches per pair.
pwdists = (smallmsa.unsqueeze(0) != smallmsa.unsqueeze(1)).sum(dim=-1)
# pairwise sequence identity = 1 - Hamming distance / sequence length
pwseq = 1 - pwdists / query_len

In [None]:
# pairwise identity of the toy MSA: symmetric, with ones on the diagonal (distance 0 to itself)
plt.imshow(pwseq)
plt.xlabel("target seq idx")
plt.ylabel("query seq idx")

In [None]:
# identities of the query (row 0) with itself and with a few other sequences
pwseq[0,0],pwseq[0,3],pwseq[0,4],pwseq[0,7]

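With our MSA sizes, it is probably not feasible to compute sequence identities for the entire MSA at once: the all-pairs comparison above builds a mismatch mask of size num_seqs × num_seqs × query_len. Therefore, we process the pairs of sequences in smaller batches.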

In [None]:
batch_size = 4 # number of sequence pairs per batch
# every pair (i, j) with i < j, in lexicographic order, as a (num_pairs, 2) tensor
pairs = torch.tensor(list(combinations(range(num_seqs), 2)))
pairs.shape

In [None]:
# each batch should yield a matrix with batch_size elements
# number of batches needed to cover every pair (ceiling division via negated floor division)
num_batches = -(-len(pairs) // batch_size)
num_batches

In [None]:
# Batched version of the pairwise identity computation above: process batch_size pairs at
# a time. Each pair (i, j) yields one score, written to both (i, j) and (j, i) because the
# identity matrix is symmetric; the diagonal stays 1 from torch.eye.
bpwseq = torch.eye(num_seqs) # matrix containing similarity scores for two sequences
for batch_idx in range(num_batches):
    # indices into `pairs` for this batch; the last batch may hold fewer than batch_size pairs
    pairs_idx = torch.arange(batch_idx*batch_size, min((batch_idx + 1)*batch_size, len(pairs)))
    batch_pairs = pairs[pairs_idx]

    # copy both sequences of every pair into a (batch_size, 2, query_len) tensor;
    # rows past len(batch_pairs) stay zero and their scores are never written back
    batch_seqs = torch.zeros((batch_size, 2, query_len))
    for pair_idx, (i, j) in enumerate(batch_pairs):
        # TODO: get rid of float conversion?
        batch_seqs[pair_idx, 0] = smallmsa[i].float()
        batch_seqs[pair_idx, 1] = smallmsa[j].float()

    # identity = 1 - (number of differing positions) / query_len
    batch_pwdists = torch.sum(batch_seqs[:,0,:] != batch_seqs[:,1,:], axis=-1)
    batch_pwseq = 1 - batch_pwdists / query_len
    # mirror every score into both triangles of the matrix
    for pair_idx, (i, j) in enumerate(batch_pairs):
        bpwseq[i,j] = bpwseq[j,i] = batch_pwseq[pair_idx]

In [None]:
# side by side: batched result vs. the all-at-once result (they should look identical)
fig,(ax1,ax2) = plt.subplots(ncols=2)
ax1.imshow(bpwseq)
ax1.set_title("batch pairwise seq identity")
ax2.imshow(pwseq)
ax2.set_title("complete MSA pairwise seq identity")

In [None]:
# the batched computation must reproduce the all-at-once identities (within float tolerance)
assert torch.allclose(bpwseq, pwseq)

In [None]:
# time one run of the project's reference Neff implementation
# baseline: 30.474739761004457s
timeit(lambda: prot.compute_neff(), number=1)

In [None]:
# GPU Neff implementation from the project module
# NOTE(review): a later cell defines its own `neff`, which shadows this import when the
# notebook runs top to bottom; re-running cells out of order changes which one is used
from neff_gpu import neff
neff

In [None]:
#timeit(lambda: compute_neff(prot), number=1)
# NOTE(review): expects the neff imported from neff_gpu (it is called with the MSA object);
# after the later `def neff` cell has run, this name refers to the notebook version instead
# baseline: 20.309782647993416s
timeit(lambda: neff(prot, batch_size=10000), number=1)

## Combined functionality

The cells below wrap the batched pairwise sequence identity computation from the cells above into functions (`fn1`, `fn2`, …), so that small changes can be checked against the reference result and timed.

In [None]:
def fn1(encmsa, batch_size=4):
    """
    Batched pairwise sequence identity, naive version (reference for fn2).

    The pairs (i, j) with i < j are built with Python loops. For every batch, the two
    sequences of each pair are copied one by one into a preallocated tensor.

    encmsa: 2D tensor (num_seqs, query_len) of token ids. Its shape is read from the
        argument. Previously the notebook globals `num_seqs` / `query_len` were used, so
        the result silently depended on whichever MSA those globals last described.
    batch_size: number of sequence pairs per batch (default 4, as before).
    returns: (num_seqs, num_seqs) float tensor of identities with ones on the diagonal.
    """
    num_seqs, query_len = encmsa.shape

    # calculate all pairs for which pairwise sequence identities need to be calculated
    pairs = []
    for i in range(num_seqs):
        for j in range(i+1, num_seqs):
            pairs.append((i,j))
    pairs = torch.tensor(pairs)

    # number of batches needed to cover all pairs (ceiling division)
    num_batches = (len(pairs) + batch_size - 1) // batch_size

    bpwseq = torch.eye(num_seqs) # pairwise identities; the diagonal is 1 (sequence vs. itself)
    for batch_idx in range(num_batches):
        # indices of the pairs in this batch (the last batch may be shorter)
        pairs_idx = torch.arange(batch_idx*batch_size, min((batch_idx + 1)*batch_size, len(pairs)))
        batch_pairs = pairs[pairs_idx]

        # copy both sequences of each pair; unused rows of a short last batch stay zero
        batch_seqs = torch.zeros((batch_size, 2, query_len))
        for pair_idx, (i, j) in enumerate(batch_pairs):
            batch_seqs[pair_idx, 0] = encmsa[i].float()
            batch_seqs[pair_idx, 1] = encmsa[j].float()

        # identity = 1 - normalized Hamming distance
        batch_pwdists = torch.sum(batch_seqs[:,0,:] != batch_seqs[:,1,:], axis=-1)
        batch_pwseq = 1 - batch_pwdists / query_len
        # write every score to (i, j) and (j, i); padding rows are never written back
        for pair_idx, (i, j) in enumerate(batch_pairs):
            bpwseq[i,j] = bpwseq[j,i] = batch_pwseq[pair_idx]
    return bpwseq
# sanity check: fn1 must reproduce the all-at-once identities of the toy MSA
assert torch.allclose(fn1(smallmsa), pwseq)

In [None]:
def fn2(encmsa, batch_size=4):
    """
    Batched pairwise sequence identity with vectorized gather and write-back.

    Same result as fn1, but the sequences of a batch are gathered with a single indexing
    operation, and the scores are scattered back into the matrix with advanced indexing.

    encmsa: 2D tensor (num_seqs, query_len) of token ids. Its shape is read from the
        argument instead of from the notebook globals `num_seqs` / `query_len`.
    batch_size: number of sequence pairs per batch (default 4, as before).
    returns: (num_seqs, num_seqs) float tensor of identities with ones on the diagonal.
    """
    num_seqs, query_len = encmsa.shape

    # calculate all pairs for which pairwise sequence identities need to be calculated
    pairs = []
    for i in range(num_seqs):
        for j in range(i+1, num_seqs):
            pairs.append((i,j))
    pairs = torch.tensor(pairs)

    # number of batches needed to cover all pairs (ceiling division)
    num_batches = (len(pairs) + batch_size - 1) // batch_size

    bpwseq = torch.eye(num_seqs) # pairwise identities; the diagonal is 1 (sequence vs. itself)
    for batch_idx in range(num_batches):
        # pairs of this batch (the last batch may be shorter); advanced indexing returns a
        # contiguous copy, so it can be flattened with view below
        pairs_idx = torch.arange(batch_idx*batch_size, min((batch_idx + 1)*batch_size, len(pairs)))
        batch_pairs = pairs[pairs_idx]
        batch_pairs_flat = batch_pairs.view(-1)

        # gather both sequences of every pair: (pairs in batch, 2, query_len)
        batch_seqs = encmsa[batch_pairs_flat].view((-1, 2, query_len))

        # identity = 1 - normalized Hamming distance
        batch_pwdists = torch.sum(batch_seqs[:,0,:] != batch_seqs[:,1,:], axis=-1)
        batch_pwseq = 1 - batch_pwdists / query_len

        # put at right location in result matrix (both triangles)
        bpwseq[batch_pairs[:,0],batch_pairs[:,1]] = batch_pwseq
        bpwseq[batch_pairs[:,1],batch_pairs[:,0]] = batch_pwseq
    return bpwseq
# sanity check: fn2 must reproduce the all-at-once identities of the toy MSA
assert torch.allclose(fn2(smallmsa), pwseq)

In [None]:
# total time of 100 runs each on the toy MSA: per-pair copies (fn1) vs. vectorized gather (fn2)
print(f"fn1: {timeit(lambda: fn1(smallmsa), number=100)}s")
print(f"fn2: {timeit(lambda: fn2(smallmsa), number=100)}s")

## Move to GPU

In [None]:
def fn3(encmsa, device=None, batch_size = 4096, **kwargs):
    """
    Batched pairwise sequence identity of all rows of an encoded MSA.

    encmsa: 2D tensor (num_seqs, query_len) of token ids.
    device: device for the pair indices and the result matrix. Defaults to encmsa.device.
        Previously it defaulted to the CPU, so a CUDA `encmsa` without an explicit device
        failed when CUDA scores were written into the CPU result matrix.
    batch_size: number of sequence pairs compared per batch.
    kwargs: ignored; accepted so callers such as neff() can forward extra keywords.
    returns: (num_seqs, num_seqs) float tensor, symmetric, with ones on the diagonal.
    """
    if device is None:
        device = encmsa.device
    num_seqs, query_len = encmsa.shape

    # every (i, j) with i < j; the lower triangle is filled by mirroring below
    pairs = torch.triu_indices(num_seqs, num_seqs, 1, device=device).T
    num_batches = (len(pairs) + batch_size - 1) // batch_size

    bpwseq = torch.eye(num_seqs, device=device)
    for batch_idx in tqdm(range(num_batches), desc="running batches"):
        # a plain slice of the pair list; no index tensor has to be built per batch
        start = batch_idx * batch_size
        batch_pairs = pairs[start:start + batch_size]

        # both sequences of every pair: (pairs in batch, 2, query_len). Indexing with the
        # 2D pair tensor avoids flattening it, which a transposed slice cannot do via view
        batch_seqs = encmsa[batch_pairs]

        # identity = 1 - normalized Hamming distance
        batch_pwdists = torch.sum(batch_seqs[:, 0, :] != batch_seqs[:, 1, :], dim=-1)
        batch_pwseq = 1 - batch_pwdists / query_len

        # write each score to (i, j) and (j, i)
        bpwseq[batch_pairs[:, 0], batch_pairs[:, 1]] = batch_pwseq
        bpwseq[batch_pairs[:, 1], batch_pairs[:, 0]] = batch_pwseq
    return bpwseq

In [None]:
def gapcount(encmsa, weights=None, gaptok=None, stoi=None, nongap=False, **kwargs):
    """
    Count gaps (or non-gaps) in every MSA column.

    encmsa: 2D tensor (num_seqs, query_len) of token ids.
    weights: optional per-sequence weights of length num_seqs; when given, each sequence
        contributes its weight instead of 1.
    gaptok: token id of the gap character. If None, it is looked up as stoi['-'].
    stoi: character -> token id mapping, used only when gaptok is None.
    nongap: if False (default), gaps are counted; otherwise non-gaps are counted.
    kwargs: ignored; accepted so neff() can forward its keyword arguments.
    returns: vector of length query_len. Integer counts without weights, otherwise the
        weighted sums in the dtype of `weights`.
    raises: AssertionError if neither gaptok nor stoi is given.
    """
    if gaptok is None:
        assert stoi is not None, "either gaptok or stoi must be given to identify gaps"
        gaptok = stoi['-']
    gapindicator = encmsa != gaptok if nongap else encmsa == gaptok
    if weights is not None:
        # cast to the weights' dtype: a hard-coded float32 indicator makes the matmul fail
        # for float64 (or half precision) weights
        return weights @ gapindicator.to(weights.dtype)
    return torch.sum(gapindicator, dim=0)

def neff(encmsa, pwseqfn=fn3, seqid_thres=0.8, **kwargs):
    """
    Calculate per-column Neff scores for an encoded MSA.

    Each sequence is weighted by 1 / (number of sequences whose pairwise identity with it
    is >= seqid_thres). With the pwseq functions in this notebook, the diagonal of the
    identity matrix is 1, so a sequence always counts itself. A column's score is the
    weighted number of non-gap tokens in it.

    encmsa: 2D tensor (num_seqs, query_len) of token ids.
    pwseqfn: function returning the (num_seqs, num_seqs) pairwise identity matrix.
    seqid_thres: identity at or above which two sequences count as neighbours.
    kwargs: forwarded to both pwseqfn (e.g. device, batch_size) and gapcount. gapcount
        needs either `gaptok` or `stoi` to identify gaps.
    returns: tensor of length query_len.
    """
    pwseq = pwseqfn(encmsa, **kwargs)
    # neighbour counts per sequence (dim 0 or 1 gives the same result because pwseq is symmetric)
    neffweights = 1 / torch.sum(pwseq >= seqid_thres, dim=0)
    return gapcount(encmsa, weights=neffweights, nongap=True, **kwargs)

## Accelerating file loading

In the new neff score calculation, a lot of time is spent encoding the MSA. Can we combine file loading and encoding?

In [None]:
# reference Neff scores from the project implementation; compared against the fast version below
neffref = prot.compute_neff()

In [None]:
from string import ascii_lowercase
def loadmsa(path, stoi=None):
    """
    Load an a3m MSA directly as an encoded tensor.

    The file is read as alternating header / sequence lines. The first record is the
    query. Later records whose header (stripped) equals the query id are repeats of the
    query and are skipped. Lowercase characters (a3m insertions) are dropped, so every
    remaining row has the query's length.

    path: path to the a3m file.
    stoi: character -> token id mapping. Defaults to the notebook-level `stoi` built from
        the reference MSA above, which is what this function always used before.
    returns: uint8 tensor of shape (num_seqs, query_len).
    raises: AssertionError for an empty file, an odd number of lines, or a header line
        that does not start with ">"; KeyError for characters missing from stoi.
    """
    if stoi is None:
        # NB: falls back to the global stoi defined earlier in the notebook
        stoi = globals()["stoi"]

    with open(path) as f:
        # splitlines() also handles a last line without a trailing newline; the previous
        # seq[:-1] slicing cut off the last residue of such a line
        lines = f.read().splitlines()
    assert lines, f"MSA at {path} is empty"
    assert not len(lines) & 1, f"MSA at {path} should contain an even number of lines"
    headers, seqs = lines[::2], lines[1::2]
    assert all(hdr.startswith(">") for hdr in headers), f"MSA at {path} has a header line without '>'"
    query_id = headers[0].split()[0]

    # Keep the query plus every later record that is not a repeat of it. Counting and
    # encoding both use this one list, so the tensor size always matches the rows written.
    # (The previous loop only skipped repeats when idx > 1, so a repeat directly after the
    # query was encoded and overflowed the preallocated tensor.)
    records = [seqs[0]] + [seq for hdr, seq in zip(headers[1:], seqs[1:]) if hdr.strip() != query_id]
    query_len = sum(1 for ch in seqs[0].strip() if ch not in ascii_lowercase)

    encmsa = torch.zeros((len(records), query_len), dtype=torch.uint8)
    for idx, seq in enumerate(records):
        encmsa[idx, :] = torch.tensor(
            [stoi[ch] for ch in seq.strip() if ch not in ascii_lowercase],
            dtype=torch.uint8,
        )
    return encmsa

In [None]:
# shape check: the number of rows is expected to equal the query plus all parsed matches
loadedmsa=loadmsa(prot_path)
loadedmsa.shape,len(prot.matches)+1

In [None]:
# avg 5.556711397999607s per run
def f():
    """Load the MSA file and compute its Neff scores; this is the unit being timed."""
    return neff(loadmsa(prot_path), stoi=stoi)
timeit(f, number=4)

In [None]:
# one result of the fast pipeline, compared with neffref below
nefffast = f()

In [None]:
# element-wise difference between the reference and the fast Neff scores
print(torch.tensor(neffref).float() - nefffast)
torch.allclose(torch.tensor(neffref).float(), nefffast) # yields true, are they "equal enough"? :3

In [None]:
# overlay reference and fast Neff curves; they should coincide
plt.plot(neffref)
plt.plot(nefffast)

In [None]:
# end-to-end time of one run: load + encode + Neff (numbers below are from earlier runs)
# baseline: 15.939979004993802s
# with better pair generation: 4.461970445991028s
timeit(lambda: neff(loadmsa(prot_path), stoi=stoi), number=1)

## Generating pairs

The second bottleneck is generating the list of pairs that need to be processed. This can be done faster: the cells below compare a pure-Python approach (`itertools.combinations`) with `torch.triu_indices`.

In [None]:
n, m = 100, 100

def combipairs(num_seqs):
    """All index pairs (i, j) with i < j, as a list of tuples, generated with itertools."""
    pair_iter = combinations(range(num_seqs), 2)
    return list(pair_iter)

def triupairs(num_seqs):
    """The same pairs as a (num_pairs, 2) tensor, read from the strict upper triangle."""
    rows_and_cols = torch.triu_indices(num_seqs, num_seqs, 1)
    return rows_and_cols.T

In [None]:
# time 1000 calls of the itertools-based pair generation
timeit(lambda: combipairs(n), number=1000)

In [None]:
# time 1000 calls of the torch.triu_indices-based pair generation
timeit(lambda: triupairs(n), number=1000)

In [None]:
# display both results for comparison; they should list the same pairs in the same order
combipairs(4),triupairs(4)

In [None]:
# use the GPU when one is present, otherwise fall back to the CPU so the cell still runs
device = "cuda" if torch.cuda.is_available() else "cpu"
encmsa = loadmsa(prot_path).to(device)
# gapcount needs either gaptok or stoi to find gap tokens; without one, neff raises an
# AssertionError
neff(encmsa, device=device, stoi=stoi)

In [None]:
batch_size = 5
num_seqs = 5

pairs = torch.triu_indices(*(num_seqs,)*2,1).T
num_full_batches = len(pairs) // batch_size
# Split at the end of the last full batch. Slicing with -(len(pairs) % batch_size) breaks
# when the remainder is 0: pairs[:-0] is empty and pairs[-0:] is everything.
split = num_full_batches * batch_size
batch_pairs = pairs[:split]
rest_pairs = pairs[split:]
pairs.shape,batch_pairs.shape,rest_pairs.shape

In [None]:
# inspect the full pair list and its split into full batches and remainder
pairs,batch_pairs,rest_pairs

In [None]:
def ptr(t):
    """Address of the storage backing tensor t; all views of the same data share it."""
    return t.storage().data_ptr()
ptr(pairs), ptr(batch_pairs), ptr(rest_pairs)

In [None]:
# splitting the first dimension of batch_pairs can be expressed as a view, so the
# batched tensor is expected to share the storage of `pairs` (same pointer)
ptr(batch_pairs.view(num_full_batches, -1, 2))

In [None]:
# Every batch should be a view into the storage of `pairs` (same pointer, no copy).
# `pairs` is a transposed tensor, so each batch slice is non-contiguous and
# bp.view(-1) raises; print bp's own storage pointer and its contiguity instead.
for bp in chain(batch_pairs.view(num_full_batches, -1, 2), [rest_pairs]):
    print(ptr(bp), bp.is_contiguous())

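Compare the performance of two versions of the pairwise sequence identity function: `pwseq1` (baseline, per-batch index tensors) and `pwseq2` (pairs pre-split into batches, with per-stage timings).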

In [None]:
# baseline
def pwseq1(encmsa, device=None, batch_size=2**18, **kwargs):
  """
  return pairwise sequence identity calculated with pytorch (baseline version)

  encmsa: 2D tensor (num_seqs, query_len) of token ids.
  device: device for pair indices and result. Defaults to encmsa.device. Previously it
    defaulted to the CPU, so a CUDA `encmsa` without an explicit device failed when CUDA
    scores were written into the CPU result matrix.
  batch_size: number of sequence pairs per batch.
  kwargs: ignored.
  returns: (num_seqs, num_seqs) float tensor, symmetric, with ones on the diagonal.
  """
  if device is None:
    device = encmsa.device
  num_seqs,query_len = encmsa.shape

  # calculate all pairs (i < j) for which pairwise sequence identities need to be calculated
  pairs = torch.triu_indices(*(num_seqs,)*2, 1, device=device).T

  # number of batches needed to cover all pairs (ceiling division)
  num_batches = (len(pairs) + batch_size - 1) // batch_size

  # each pair yields one score, written to (i, j) and (j, i) because the matrix is symmetric
  bpwseq = torch.eye(num_seqs, device=device) # matrix containing similarity scores for two sequences
  for batch_idx in tqdm(range(num_batches), desc="running batches"):
    # pairs of this batch; advanced indexing returns a contiguous copy, so view(-1) works
    pairs_idx = torch.arange(batch_idx*batch_size, min((batch_idx + 1)*batch_size, len(pairs)))
    batch_pairs = pairs[pairs_idx]
    batch_pairs_flat = batch_pairs.view(-1)

    # calculate sequences in batch: (pairs in batch, 2, query_len)
    batch_seqs = encmsa[batch_pairs_flat]
    batch_seqs = batch_seqs.view(-1, 2, query_len)

    # calculate pairwise distances -> identity = 1 - normalized Hamming distance
    batch_pwdists = torch.sum(batch_seqs[:,0,:] != batch_seqs[:,1,:], axis=-1)
    batch_pwseq = 1 - batch_pwdists / query_len

    # put at right location in result matrix (and make symmetric)
    bpwseq[batch_pairs[:,0],batch_pairs[:,1]] = batch_pwseq
    bpwseq[batch_pairs[:,1],batch_pairs[:,0]] = batch_pwseq
  return bpwseq

In [None]:
def pwseq2(encmsa, device=None, batch_size=2**12, **kwargs):
  """
  return pairwise sequence identity calculated with pytorch, printing per-stage timings

  Variant of pwseq1 that splits the pair list into batches up front (as a view, without
  copying) and indexes encmsa with each (pairs in batch, 2) block of pairs directly.

  encmsa: 2D tensor (num_seqs, query_len) of token ids.
  device: device for pair indices and result; defaults to encmsa.device.
  batch_size: number of sequence pairs per batch.
  kwargs: ignored.
  returns: (num_seqs, num_seqs) float tensor, symmetric, with ones on the diagonal.
  """
  if device is None:
    device = encmsa.device
  num_seqs, query_len = encmsa.shape

  # calculate all pairs (i < j) for which pairwise sequence identities need to be calculated
  pairs = torch.triu_indices(*(num_seqs,)*2, 1, device=device).T

  num_batches = (len(pairs) + batch_size - 1) // batch_size
  num_full_batches = len(pairs) // batch_size

  # Split explicitly at the end of the last full batch. Slicing with
  # -(len(pairs) % batch_size) gives an empty head (and everything as "rest") when the
  # remainder is 0.
  split = num_full_batches * batch_size
  full_pairs = pairs[:split]
  rest_pairs = pairs[split:]

  # put pairs into batches; an explicit batch_size (not -1) keeps the view well-defined
  # even when there are no full batches at all
  batches = full_pairs.view(num_full_batches, batch_size, 2)
  if num_batches != num_full_batches:
    batches = chain(batches, [rest_pairs])

  checkpoints = "seq_extract pwdists putback"
  t = {name: 0 for name in checkpoints.split()}
  on_cuda = encmsa.is_cuda or torch.device(device).type == "cuda"

  def _stop(name, start):
    # CUDA kernels run asynchronously: wait for them so the measured time covers the work
    if on_cuda:
      torch.cuda.synchronize()
    t[name] += time.perf_counter() - start

  # each pair yields one score, written to (i, j) and (j, i) because the matrix is symmetric
  bpwseq = torch.eye(num_seqs, device=device) # matrix containing similarity scores for two sequences
  for batch_pairs in tqdm(batches, total=num_batches, desc="running batches"):
    # extract both sequences of every pair: (pairs in batch, 2, query_len)
    start = time.perf_counter()
    batch_seqs = encmsa[batch_pairs]
    _stop("seq_extract", start)

    # calculate pairwise distances -> identity = 1 - normalized Hamming distance
    start = time.perf_counter()
    batch_pwdists = torch.sum(batch_seqs[:,0,:] != batch_seqs[:,1,:], axis=-1)
    batch_pwseq = 1 - batch_pwdists / query_len
    _stop("pwdists", start)

    # put at right location in result matrix (and make symmetric)
    start = time.perf_counter()
    bpwseq[batch_pairs[:,0],batch_pairs[:,1]] = batch_pwseq
    bpwseq[batch_pairs[:,1],batch_pairs[:,0]] = batch_pwseq
    _stop("putback", start)

  c = sum(t.values())
  for name, total in t.items():
    # guard against division by zero when there were no pairs (fewer than 2 sequences)
    share = 100*total/c if c else 0.0
    per_batch = total/num_batches if num_batches else 0.0
    print(f"{name}: {total}s={share:.01f}%, {per_batch}s/batch")
  print(f"sum={c}")

  return bpwseq

In [None]:
# NOTE(review): uses the encmsa moved to `device` earlier and batch_size = 5 from the toy
# split cell, so this runs the full MSA in very small batches
pwseq2(encmsa, batch_size=batch_size)

In [None]:
# reload the MSA (not moved to a device this time) for the timing comparison below
encmsa = loadmsa("../data/A0A0A0MRZ7.a3m")
encmsa.shape

In [None]:
# number of sequence pairs per batch, shared by both versions being compared
batch_size=2**12

In [None]:
# total time of 3 runs of the baseline version
print("baseline", timeit(lambda: pwseq1(encmsa, batch_size=batch_size), number=3))

In [None]:
# total time of 3 runs of the pre-split version (also prints its per-stage breakdown)
print("version 2", timeit(lambda: pwseq2(encmsa, batch_size=batch_size), number=3))