In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import pickle as pkl
import glob
import numpy as np
from scipy import stats
plt.rcParams['svg.fonttype'] = 'none'

In [None]:
###########

def compute_Li_sliding_window_div(file_name, window_size: int = 5, stride: int = 1, top_k=None) -> None:
    """Score local attention along the diagonal of a pickled attention map and save it as CSV.

    The pickled tensor is averaged over its first dimension, trimmed to the
    square region ending at the last column whose sum is non-zero, and scanned
    with square windows whose top-left corner moves along the main diagonal.
    For every window and every column ``j`` of that window:

    * ``median_j`` is the median of column ``j``;
    * ``dist_j`` is the mean absolute offset between ``j`` and the row indices
      of the ``top_k`` largest entries of column ``j``;
    * ``Li_j = median_j / dist_j``.

    The window score is the mean of ``Li_j`` over the window's columns. Scores
    are written to ``f"{file_name}.{window_size}.div.csv"`` with the same layout
    as before (pandas index, an ``index`` column, and the score column ``0``).

    Args:
        file_name: Path of the pickle to read; also the prefix of the CSV path.
        window_size: Side length of the square sliding window.
        stride: Step, in positions, between consecutive window starts.
        top_k: How many of the strongest entries per column enter ``dist_j``.
            ``None`` (default) means ``window_size``, which reproduces the
            original hard-coded behaviour for ``window_size == 5``.

    Raises:
        ValueError: If the parameters are out of range, the averaged map is
            entirely zero, or the trimmed map is smaller than one window.

    NOTE(review): with ``top_k == window_size`` the top-k indices are always a
    permutation of ``0..window_size-1``, so ``dist_j`` depends only on ``j``
    and not on the attention values. Confirm that this is intended; a smaller
    ``top_k`` makes the distance data-dependent (and can make it 0, giving
    ``inf`` scores, when the strongest entries sit on the diagonal).
    """
    if window_size < 1 or stride < 1:
        raise ValueError("window_size and stride must be positive integers")
    if top_k is None:
        top_k = window_size
    if not 1 <= top_k <= window_size:
        raise ValueError("top_k must satisfy 1 <= top_k <= window_size")

    # SECURITY: pickle.load can execute arbitrary code; only point this at
    # files produced by a trusted pipeline.
    with open(file_name, 'rb') as f:
        predicted_attention = pkl.load(f)

    # Detach and move to the CPU so that the final .numpy() call cannot fail
    # for tensors saved from a GPU or with gradient tracking enabled.
    attention_mean = predicted_attention.detach().cpu().mean(0)

    # Trim trailing all-zero columns (presumably padding - TODO confirm).
    nonzero_cols = torch.nonzero(attention_mean.sum(0) != 0)
    if nonzero_cols.numel() == 0:
        raise ValueError(f"{file_name}: averaged attention map contains only zeros")
    valid = int(nonzero_cols.max()) + 1
    attention = attention_mean[:valid, :valid]

    size = attention.shape[0]
    if size < window_size:
        raise ValueError(
            f"{file_name}: trimmed attention map has {size} rows, "
            f"fewer than window_size={window_size}"
        )

    # Column vector 0..window_size-1, broadcast against the top-k index matrix.
    positions = torch.arange(window_size).reshape(window_size, 1)

    Li_list = []
    for start in range(0, size - window_size + 1, stride):
        stop = start + window_size
        window = attention[start:stop, start:stop]

        column_median = torch.median(window, 0).values
        # Row j of window.T is column j of the window, so these are the row
        # indices of the strongest entries in each column.
        top_rows = torch.topk(window.T, top_k).indices
        # Cast before averaging: torch.mean is not defined for integer tensors.
        dist = (top_rows - positions).abs().to(torch.float32).mean(1)

        Li_list.append((column_median / dist).mean())

    Li_results = torch.stack(Li_list).numpy()
    df_Li = pd.DataFrame(data=Li_results).reset_index()
    df_Li.to_csv(f"{file_name}.{window_size}.div.csv")

In [None]:
# The glob also matches files this notebook writes itself: the "*.div.csv"
# score tables and the "*.png" / "*.svg" figures all contain "attention" and
# "long". Skip them so that re-running the notebook does not try to unpickle
# its own outputs.
derived_suffixes = ('.csv', '.png', '.svg')
file_list = sorted(
    path
    for path in glob.glob("*attention*long*")
    if not path.endswith(derived_suffixes)
)
for file_name in file_list:
    compute_Li_sliding_window_div(file_name)

In [None]:
gene_list = ['APP', 'g3bp1', 'hnRNPA1', 'tau', 'FUS']
window_size = 5
# Tick labels are shifted by half a window (window_size // 2 == 2 for the
# default window) relative to the window index stored in the CSV.
half_win = window_size // 2


def plot_Li_profile(df_scores, out_stem, half_win, xlim=None, note=None):
    """Plot the score columns of ``df_scores`` and save the figure as PNG and SVG.

    Args:
        df_scores: DataFrame whose index is the window index and whose columns
            are the series to draw.
        out_stem: Output path without extension; ``.png`` and ``.svg`` are appended.
        half_win: Offset added to every x tick label.
        xlim: Optional ``(left, right)`` x limits, in window-index units.
        note: Optional text drawn inside the axes.
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 6), dpi=300)
    df_scores.plot(ax=ax)
    ax.set_ylabel("Li Score")
    ax.set_xlabel("Position")
    if xlim is not None:
        ax.set_xlim(*xlim)

    # Pin the tick locations before relabelling them: fixed labels on an
    # automatic locator trigger a matplotlib warning and can drift away from
    # their ticks. Restore the view afterwards because set_xticks may widen it.
    view = ax.get_xlim()
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels([f"{int(tick) + half_win}" for tick in xticks])
    ax.set_xlim(view)

    if note is not None:
        ax.text(0.5, 0.8, note, transform=ax.transAxes)

    fig.savefig(f"{out_stem}.png")
    fig.savefig(f"{out_stem}.svg")
    # Render inline, then release the figure so the loop does not accumulate
    # open 300-dpi canvases.
    plt.show()
    plt.close(fig)


for gene in gene_list:
    df_A = pd.read_csv(f"{gene}_WT_predicted_attentionA_long_A.{window_size}.div.csv", header=0, index_col=0)
    df_B = pd.read_csv(f"{gene}_WT_predicted_attentionB_long_B.{window_size}.div.csv", header=0, index_col=0)
    # Each CSV carries a redundant 'index' column next to the score column.
    df_A_B = pd.concat([df_A, df_B], axis=1).drop('index', axis=1)
    df_A_B.columns = ['A', 'B']

    # concat pads the shorter series with NaN; drop those so the t-test does
    # not return NaN when A and B have different lengths.
    _, p = stats.ttest_ind(df_A_B['A'].dropna(), df_A_B['B'].dropna())

    plot_Li_profile(
        df_A_B,
        f"{gene}_WT_predicted_attention_long_A_B.{window_size}.div",
        half_win,
        note=f"t test:p={p:.2f}",
    )
    plot_Li_profile(
        df_A_B,
        f"{gene}_WT_predicted_attention_long_A_B_50_zoomin.{window_size}.div",
        half_win,
        xlim=(0, 50),
    )