In [2]:
import sys
sys.path.append('../ESM')  # Path to the ESM folder
import numpy as np
import pickle
import scipy.fft as fft
from pathlib import Path
import pandas as pd
import re
import os
import torch
import matplotlib.pyplot as plt  
import random 
from itertools import product
from tqdm import tqdm
# Seed every RNG-bearing library imported above (stdlib random, numpy, torch)
# so any stochastic step — e.g. inside the ESM helpers — is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# qsft and ESM-specific packages
from src.utils import get_file_path, get_protein_path, load_esm_model
# Packaged from https://github.com/amirgroup-codes/InteractionRecovery
from fourier_utils import gwht, extract_wildtype, sampling_function, sampling




In [24]:
ALL_AAS = ("A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y")
q = 20  # alphabet size used for the q-ary Fourier transform (matches len(ALL_AAS))

# Per-protein results, one entry per row of the metadata table.
all_proteins = []
all_positions = []

ruggedness = []
ruggedness_by_order = []

sparsity = []
sparsity_by_order = []


def _plot_order_bars(values, n, ylabel, title, path):
    """Save a bar chart with one bar per interaction order k = 1..n.

    Shared by the sparsity and ruggedness figures, which differ only in the
    plotted values, y-axis label, title, and output path.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(np.arange(1, n + 1), values, edgecolor='black', align='edge', color='#00a087ff')
    ax.set_xlabel('$k^{th}$ order interactions')
    ax.set_ylabel(ylabel)
    ax.set_xticks(np.arange(1, n + 1))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_title(title)
    fig.savefig(path)
    plt.close(fig)


df_protein_meta = pd.read_csv("ArnoldDatasetsMeta.csv")
for _, row in df_protein_meta.iterrows():
    protein = row['proteinName']
    n = int(row['numPositions'])
    print(protein, q, n)
    all_proteins.append(protein)
    all_positions.append(n)

    # --- Fourier transform and coefficients ---
    # Adapted from https://github.com/amirgroup-codes/InteractionRecovery
    samples_ESM = np.load(f'TG_results/{protein}_q{q}_n{n}.npy')
    F_k = np.real(gwht(samples_ESM, q, n))
    # Interaction order of a coefficient = number of non-zero digits in its
    # base-q index. NOTE(review): assumes gwht returns coefficients in the same
    # lexicographic order as itertools.product(range(q), repeat=n) — confirm.
    samples = np.array(list(product(range(q), repeat=n)))
    order = np.count_nonzero(samples, axis=1)
    # Stable sort keeps coefficients of equal order in their lexicographic order.
    sorted_indices = np.argsort(order, kind='stable')
    order_sorted = order[sorted_indices]
    F_k = F_k[sorted_indices]
    # After sorting, index 0 is the single order-0 coefficient; 1: are orders >= 1.

    # Full Fourier spectrum, with its mean as a reference line.
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(F_k)
    ax.axhline(np.mean(F_k), color='C1', linestyle='--')
    ax.set_title(f'{protein}')
    fig.savefig(f'TG_results/FFT_{protein}_q{q}_n{n}.png')
    plt.close(fig)

    # Cropped spectrum (order-0 term dropped). all_std is also the sparsity threshold below.
    all_std = np.std(np.abs(F_k[1:])) / n
    fig, ax = plt.subplots(figsize=(10, 4))
    # Plot against the original sorted index so the order breakpoints (indices
    # into order_sorted) and the x-limits line up with the plotted coefficients.
    ax.plot(np.arange(1, len(F_k)), F_k[1:])
    # First index of each new interaction order; drawn as vertical dotted lines.
    breakpoints = np.where(np.diff(order_sorted) == 1)[0] + 1
    for bp in breakpoints:
        ax.axvline(bp, color='red', linestyle='--')
    last_index_highest_order = np.where(order_sorted == n)[0][-1]
    ax.axhline(all_std, color='C2', linestyle='--')
    ax.axhline(-all_std, color='C2', linestyle='--')
    ax.set_xlim([1, last_index_highest_order])
    ax.set_title(f'{protein} up to {n} interactions')
    fig.savefig(f'TG_results/croppedFFT_{protein}_q{q}_n{n}.png')
    plt.close(fig)

    # --- Sparsity ---
    # Adapted from https://github.com/amirgroup-codes/InteractionRecovery
    # A coefficient counts as "non-zero" when |F_k| exceeds all_std (= std(|F_k[1:]|)/n).
    above_threshold = np.abs(F_k) > all_std
    # Fraction of non-zero coefficients within each order k = 1..n. The maximum
    # order is always n, so no extra cap is required. float() keeps the values
    # plain Python floats so they print/serialize cleanly under NumPy >= 2.
    coeffs = [float(np.mean(above_threshold[order_sorted == j])) for j in range(1, n + 1)]
    # Overall fraction across every coefficient of order >= 1.
    sparsity_norm = float(np.mean(above_threshold[1:]))
    _plot_order_bars(coeffs, n, 'Fraction of nonzero coefficients',
                     f'Sparsity: {round(sparsity_norm, 3)}',
                     f'TG_results/sparsity_{protein}_q{q}_n{n}.png')
    sparsity.append(sparsity_norm)
    sparsity_by_order.append(coeffs)

    # --- Ruggedness ---
    # Adapted from https://github.com/amirgroup-codes/InteractionRecovery
    # Percent of the non-constant (order >= 1) spectral energy carried by each order.
    total_var = np.sum(F_k[1:] ** 2)
    var = [float(np.sum(F_k[order_sorted == j] ** 2) / total_var * 100) for j in range(1, n + 1)]
    # Ruggedness = variance-weighted mean interaction order.
    weighted_avg = float(np.average(np.arange(1, n + 1), weights=var))
    _plot_order_bars(var, n, '% Variance Explained',
                     f'Ruggedness: {round(weighted_avg, 3)}',
                     f'TG_results/ruggedness_{protein}_q{q}_n{n}.png')
    ruggedness.append(weighted_avg)
    ruggedness_by_order.append(var)

    print(sparsity_norm, coeffs, weighted_avg, var)


# The list-valued columns are written to CSV as their string representation.
data = {'Protein': all_proteins,
        'NumPositions': all_positions,
        'Sparsity': sparsity,
        'Sparsity by Interaction Order': sparsity_by_order,
        'Ruggedness': ruggedness,
        'Ruggedness by Interaction Order': ruggedness_by_order
    }

df = pd.DataFrame(data)
df.to_csv('TG_results/InteractionOrderSummary.csv', index=False)



DHFR 20 3
0.19814976872109014 [0.9649122807017544, 0.6094182825484764, 0.1268406473246829] 1.402597958512278 [63.591989818985084, 32.556224510802316, 3.851785670212789]
GB1 20 4
0.44529028306426915 [1.0, 0.9344413665743305, 0.7452981484181368, 0.3736773045019605] 2.223795314980473 [20.820700889919504, 43.00528184760371, 29.147802136987494, 7.026215125489707]
ParD2 20 3
0.25415676959619954 [0.9649122807017544, 0.6842105263157895, 0.18034698935704913] 1.2459149823383266 [80.17004807726202, 15.068405611642977, 4.7615463110948175]
ParD3 20 3
0.26428303537942244 [1.0, 0.7054478301015698, 0.18851144481702872] 1.2262922467979003 [82.62672097821459, 12.117333363780517, 5.255945658004739]
TrpB3A 20 3
0.7775971996499562 [0.9298245614035088, 0.8070175438596491, 0.7716868348155708] 2.387636948783652 [22.023400129768, 17.189504862098616, 60.78709500813313]
TrpB3B 20 3
0.8277284660582572 [0.9824561403508771, 0.8587257617728532, 0.8215483306604461] 2.676437339248935 [8.070509177633626, 16.21524771983

In [25]:
# Reload the summary written by the analysis cell and print it per protein.
# The "by Interaction Order" columns come back from CSV as strings (the list
# repr that was written), so they are printed verbatim rather than parsed.
df_summary = pd.read_csv("TG_results/InteractionOrderSummary.csv")
for _, row in df_summary.iterrows():
    print(row['Protein'], row['NumPositions'])
    for field in ('Sparsity', 'Sparsity by Interaction Order',
                  'Ruggedness', 'Ruggedness by Interaction Order'):
        print(f'\t{field}:', row[field])

DHFR 3
	Sparsity: 0.1981497687210901
	Sparsity by Interaction Order: [0.9649122807017544, 0.6094182825484764, 0.1268406473246829]
	Ruggedness: 1.402597958512278
	Ruggedness by Interaction Order: [63.591989818985084, 32.556224510802316, 3.851785670212789]
GB1 4
	Sparsity: 0.4452902830642691
	Sparsity by Interaction Order: [1.0, 0.9344413665743305, 0.7452981484181368, 0.3736773045019605]
	Ruggedness: 2.223795314980473
	Ruggedness by Interaction Order: [20.820700889919504, 43.00528184760371, 29.147802136987494, 7.026215125489707]
ParD2 3
	Sparsity: 0.2541567695961995
	Sparsity by Interaction Order: [0.9649122807017544, 0.6842105263157895, 0.18034698935704913]
	Ruggedness: 1.2459149823383266
	Ruggedness by Interaction Order: [80.17004807726202, 15.068405611642977, 4.7615463110948175]
ParD3 3
	Sparsity: 0.2642830353794224
	Sparsity by Interaction Order: [1.0, 0.7054478301015698, 0.18851144481702872]
	Ruggedness: 1.2262922467979005
	Ruggedness by Interaction Order: [82.62672097821459, 12.117