In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
from Bio import SeqIO
import matplotlib.pyplot as plt
import pandas as pd
from numba import prange, jit

#######################################################
# 1) Helper function: read_tensor_from_txt
#######################################################
def read_tensor_from_txt(filename):
    """
    Load a tensor from a plain-text dump.

    File layout, as parsed here:
      - line 1: the target dimensions as whitespace-separated integers;
      - a line starting with "Slice" closes the slice collected so far;
      - any other non-empty line is one row of comma-separated floats.

    The collected slices are turned into a tensor and reshaped with
    ``view`` to the dimensions given on the first line.
    """
    with open(filename, 'r') as handle:
        rows = handle.readlines()

    # Header line: the shape the data is reshaped to at the end.
    shape = [int(tok) for tok in rows[0].strip().split()]

    slices = []
    pending = []
    for raw in rows[1:]:
        text = raw.strip()
        if not text:
            continue
        if text.startswith("Slice"):
            # A marker only flushes when rows were actually collected, so
            # the very first marker (and repeated markers) add nothing.
            if pending:
                slices.append(pending)
                pending = []
            continue
        pending.append([float(tok) for tok in text.split(',')])

    # Flush the trailing slice, which has no closing marker after it.
    if pending:
        slices.append(pending)

    return torch.tensor(slices).view(*shape)


#######################################################
# 2) The multi-domain sub-block model
#######################################################
class MultiDomainAttentionSubBlock(nn.Module):
    """
    Parameter container for a single-layer attention model whose H heads
    are partitioned into three groups:
      - heads [0, H1)   => domain-1 heads:  Q1, K1, V1
      - heads [H1, H2)  => domain-2 heads:  Q2, K2, V2
      - heads [H2, H)   => inter-domain heads, split in two halves:
            Qint1, Kint1, Vint1 (domain1->domain2)
            Qint2, Kint2, Vint2 (domain2->domain1)

    Note that H1 and H2 are cumulative boundaries, not head counts: the
    number of domain-2 heads is H2 - H1 and the number of inter-domain
    heads is H - H2 (the domain2->domain1 group gets the extra head when
    that number is odd).

    Positions 0..domain1_end form domain 1 (length N_alpha); the remaining
    N - N_alpha positions form domain 2 (length N_beta).
    """
    def __init__(
        self,
        H=32,          # total heads
        d=23,          # dimension for Q,K
        N=176,         # total protein length
        q=22,          # amino-acid alphabet
        lambd=0.001,
        domain1_end=63,
        H1=10,
        H2=20,
        device='cpu'
    ):
        super().__init__()
        self.H, self.d, self.N, self.q = H, d, N, q
        self.lambd = lambd
        self.device = device
        self.H1, self.H2 = H1, H2

        # Domain boundaries: domain1 is [0..domain1_end], domain2 the rest.
        self.domain1_end = domain1_end
        self.domain2_start = domain1_end + 1
        self.N_alpha = self.domain2_start
        self.N_beta = self.N - self.N_alpha

        def fresh(*shape):
            # Standard-normal initialised parameter on the configured device.
            return nn.Parameter(torch.randn(*shape, device=self.device))

        len_a, len_b = self.N_alpha, self.N_beta

        # Head counts per group.
        n_dom1 = self.H1
        n_dom2 = self.H2 - self.H1
        n_inter = self.H - self.H2
        n_inter_12 = n_inter // 2           # domain1->domain2
        n_inter_21 = n_inter - n_inter_12   # domain2->domain1

        # The creation order below fixes both the registration order of the
        # parameters and the order of random draws; keep it as is.

        # (1) Domain1 heads
        self.Q1 = fresh(n_dom1, d, len_a)
        self.K1 = fresh(n_dom1, d, len_a)
        self.V1 = fresh(n_dom1, q, q)

        # (2) Domain2 heads
        self.Q2 = fresh(n_dom2, d, len_b)
        self.K2 = fresh(n_dom2, d, len_b)
        self.V2 = fresh(n_dom2, q, q)

        # (3) Inter-domain heads, domain1->domain2
        self.Qint1 = fresh(n_inter_12, d, len_a)
        self.Kint1 = fresh(n_inter_12, d, len_b)
        self.Vint1 = fresh(n_inter_12, q, q)

        # (4) Inter-domain heads, domain2->domain1
        self.Qint2 = fresh(n_inter_21, d, len_b)
        self.Kint2 = fresh(n_inter_21, d, len_a)
        self.Vint2 = fresh(n_inter_21, q, q)

    def forward(self, Z, weights, head_group='all'):
        """
        Placeholder: no forward computation is implemented in this class.
        Only the sub-block parameters are used (for the energy test), so
        this always returns None.
        """
        return None


#######################################################
# 3) A function to compute "domain1->domain2" 
#    inter-domain energy for a single seq
#######################################################
def compute_energy_inter_subblock_domain12(seqAB, model):
    """
    Compute the "domain1->domain2" inter-domain energy of one sequence.

    For every inter-domain head h the attention logits are
    ``e[h, i, j] = sum_d Qint1[h, d, i] * Kint1[h, d, j]``, softmax-normalised
    over the domain-2 position j. The energy is

        sum_{h, i, j} softmax_j(e)[h, i, j] * Vint1[h, seqA[i], seqB[j]]

    Args:
        seqAB: 1-D integer tensor of amino-acid indices. The first
            ``model.N_alpha`` entries are domain 1, the rest are domain 2.
            The values are used to index the last two axes of
            ``model.Vint1``.
        model: object exposing ``N_alpha``, ``Qint1``, ``Kint1`` and ``Vint1``.

    Returns:
        The energy as a Python float.
    """
    # 1) Split the sequence into its two domains.
    n_alpha = model.N_alpha
    seqA = seqAB[:n_alpha]   # domain1
    seqB = seqAB[n_alpha:]   # domain2

    # Only a scalar is returned, so no autograd graph is needed.
    with torch.no_grad():
        # 2) Raw logits, laid out directly as (#heads, L1, L2).
        logits = torch.einsum('hdi,hdj->hij', model.Qint1, model.Kint1)

        # 3) Softmax over the domain-2 positions j.
        attn = torch.softmax(logits, dim=-1)

        # 4) Gather Vint1[h, a_i, b_j] for all heads at once. The two index
        #    tensors broadcast (L1,1) x (1,L2) -> (L1,L2), and the leading
        #    slice keeps the head axis, giving (#heads, L1, L2) without a
        #    Python loop over heads.
        idx_a = seqA.reshape(-1, 1)
        idx_b = seqB.reshape(1, -1)
        pair_vals = model.Vint1[:, idx_a, idx_b]

        # 5) Weight by attention and reduce to a scalar.
        return (attn * pair_vals).sum().item()



#######################################################
# 4) Example usage: 
#    Compare "correct vs. random pairs" in a paired MSA
#######################################################

from Bio import SeqIO

# Default alphabet: a character's position in this string is its integer token.
tokens_protein = "ACDEFGHIKLMNPQRSTVWY-"


def encode_sequence_lore(sequence : str, tokens : str=tokens_protein):
    """
    Translate `sequence` into a NumPy array of integer tokens, where each
    character is replaced by its position in `tokens`.

    A character that does not occur in `tokens` raises KeyError.
    """
    index_of = dict(zip(tokens, range(len(tokens))))
    encoded = [index_of[ch] for ch in sequence]
    return np.array(encoded)

def tokenize_seqs(seqs, tokens=tokens_protein):
    """
    Tokenize a character matrix row by row.

    `seqs` is an array whose rows are sequences of single characters
    (shape (N, L)). Each row is joined into one string and encoded with
    encode_sequence_lore; the encoded rows are returned stacked in a
    NumPy array.
    """
    # Join every row of characters into one string per sequence.
    joined = np.array(["".join(chars) for chars in seqs.astype(str)])
    # Encode each string and stack the results.
    return np.array([encode_sequence_lore(text, tokens=tokens) for text in joined])

def paired_fasta_to_labeled_array(fasta_name, L_A):
    """
    Load a FASTA file of concatenated domainA||domainB sequences and split
    every record at column L_A.

    Despite the function name, no label columns are produced: only the
    residue characters are returned, and record descriptions are ignored.

    Args:
        fasta_name: FASTA source handed to ``SeqIO.parse(..., "fasta")``.
        L_A: number of leading columns that belong to domain A.

    Returns:
        (MSA_A, MSA_B): two 2-D arrays of single characters, in file order.
        MSA_A holds the first L_A columns of each record, MSA_B the
        remaining columns.

    Note:
        Assumes all records have the same length (an aligned MSA);
        otherwise a rectangular array cannot be built.
    """
    # Build explicit character rows so the (N, L) layout does not depend on
    # how NumPy happens to unpack Biopython Seq objects.
    rows = [list(str(record.seq)) for record in SeqIO.parse(fasta_name, "fasta")]
    seqs = np.array(rows)

    MSA_A = seqs[:, :L_A]  # domain1
    MSA_B = seqs[:, L_A:]  # domain2
    return MSA_A, MSA_B

# NOTE: intentionally NOT decorated with numba's @jit. In nopython mode Numba
# cannot type a torch nn.Module argument nor the plain-Python helpers called
# below, so compilation fails with a TypingError before anything runs.
def test_interdomain_energy(model, MSA_A, MSA_B, correct_criterion="index"):
    """
    Score every (row of MSA_A, row of MSA_B) pair with the domain1->domain2
    inter-domain energy and separate correct from incorrect pairs.

    Args:
        model: passed through to compute_energy_inter_subblock_domain12.
        MSA_A: character array, shape (N_A, L1), domain-1 sequences.
        MSA_B: character array, shape (N_B, L2), domain-2 sequences.
        correct_criterion: with "index", pair (i, j) counts as correct
            exactly when i == j. Any other value marks every pair as
            incorrect (no other criterion is implemented).

    Returns:
        (corr_scores, incorr_scores): 1-D float arrays with the energies of
        the correct and the incorrect pairs, each in row-major (i, then j)
        order.
    """
    # tokenize
    A_tks = tokenize_seqs(MSA_A)  # => shape(N_A,L1)
    B_tks = tokenize_seqs(MSA_B)  # => shape(N_B,L2)
    N_A = A_tks.shape[0]
    N_B = B_tks.shape[0]

    # scores[i, j] = energy of A-row i concatenated with B-row j.
    scores = np.empty((N_A, N_B), dtype=float)
    for i in range(N_A):
        seqA = A_tks[i]
        for j in range(N_B):
            seqAB = torch.tensor(np.concatenate([seqA, B_tks[j]]), dtype=torch.long)
            scores[i, j] = compute_energy_inter_subblock_domain12(seqAB, model)

    # Boolean mask of correct pairs; always a bool array, so the masking
    # below also works when there are no pairs at all.
    if correct_criterion == "index":
        correct = np.eye(N_A, N_B, dtype=bool)
    else:
        # or do a species label check, etc.
        correct = np.zeros((N_A, N_B), dtype=bool)

    corr_scores = scores[correct]
    incorr_scores = scores[~correct]

    return corr_scores, incorr_scores


#######################################################
# 5) Putting it all together in "main"
#######################################################

if __name__ == "__main__":
    import sys

    # (A) Run configuration: must match the training run whose tensors are
    #     read from disk below.
    cwd = os.getcwd()
    H=60
    d=12
    N=174
    q=22
    domain1_end=63
    H1=25
    H2=H1+15
    family = "HKRR_25_15_20_but_with_new_NEW_model_sep_optimizer_newreg_withHKRRtrainingfasta_d12_500batch"
    n_epochs=500
    loss_type="without_J"

    # All saved tensors of one run live in the same directory.
    results_dir = f"{cwd}/results/{H}_{d}_{family}_{loss_type}_{n_epochs}"

    # read them:
    Qint1 = read_tensor_from_txt( f"{results_dir}/Qint1_tensor.txt" )
    Kint1 = read_tensor_from_txt( f"{results_dir}/Kint1_tensor.txt" )
    Vint1 = read_tensor_from_txt( f"{results_dir}/Vint1_tensor.txt" )

    Qint2 = read_tensor_from_txt( f"{results_dir}/Qint2_tensor.txt" )
    Kint2 = read_tensor_from_txt( f"{results_dir}/Kint2_tensor.txt" )
    Vint2 = read_tensor_from_txt( f"{results_dir}/Vint2_tensor.txt" )

    device='cpu'
    model = MultiDomainAttentionSubBlock(
        H=H, d=d, N=N, q=q,
        domain1_end=domain1_end,
        H1=H1, H2=H2,
        device=device
    ).to(device)

    # assign them: overwrite the randomly initialised inter-domain
    # parameters with the tensors loaded from disk.
    model.Qint1.data = Qint1
    model.Kint1.data = Kint1
    model.Vint1.data = Vint1
    model.Qint2.data = Qint2
    model.Kint2.data = Kint2
    model.Vint2.data = Vint2

    print("Loaded model with separate heads for domain1, domain2, interdomain")

    # (B) We load or define the MSA test data
    # e.g. a 'paired' FASTA with domain1+domain2
    # or you can have domainA.fasta, domainB.fasta if separate
    # We'll assume it's domain1||domain2 in one file

    L_A = model.N_alpha  # 64
    # Use the GPU only when one is actually present; an unconditional
    # .to('cuda') raises on CPU-only machines.
    if torch.cuda.is_available():
        model.to('cuda')
    test_fasta = f"{cwd}/CODE/DataAttentionDCA/data/lisa_data/HK-RR_174_test.fasta"
    MSA_A, MSA_B = paired_fasta_to_labeled_array(test_fasta, L_A=L_A)
    # MSA_A => shape (N, L_A), MSA_B => shape(N, L_B)

    # (C) get correct vs. incorrect energies
    corr_scores, incorr_scores = test_interdomain_energy(
        model,
        MSA_A,
        MSA_B,
        correct_criterion="index"
    )

    # (D) Plot hist
    import seaborn as sns
    mypalette=sns.color_palette("Set2")
    plt.figure()
    bin_width=5
    def_val=1e5  # if you used a default penalty for mismatched species, etc.

    # Exclude any huge def_val if you do species checks
    incorr_scores_filtered = incorr_scores[incorr_scores < def_val]

    # Shared bin edges so both histograms are directly comparable.
    minval = min(corr_scores.min(), incorr_scores_filtered.min())
    maxval = max(corr_scores.max(), incorr_scores_filtered.max())
    bins = np.arange(minval, maxval+bin_width, bin_width)

    plt.hist(corr_scores, bins=bins, alpha=0.5, label="Correct Pairs", color=mypalette[0], density=True)
    plt.hist(incorr_scores_filtered, bins=bins, alpha=0.5, label="Incorrect Pairs", color=mypalette[1], density=True)

    plt.xlabel("Inter-domain energy (domain1->domain2)")
    plt.ylabel("Density")
    plt.title(f"Comparison of correct vs. random pairs, H={H}, d={d}")
    plt.legend()
    plt.show()


Loaded model with separate heads for domain1, domain2, interdomain


TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'tokenize_seqs': Cannot determine Numba type of <class 'function'>

File "../../../../../../../tmp/ipykernel_15632/1050797239.py", line 219:
<source missing, REPL/exec in use?> 

This error may have been caused by the following argument(s):
- argument 0: Cannot determine Numba type of <class '__main__.MultiDomainAttentionSubBlock'>
