## Feature Attribution
Run feature attribution (Deep SHAP or in-silico mutagenesis) for a selected region and plot the resulting sequence logos.

In [None]:
from AttributionModelWrapper import AttributionModelWrapper
from load_model import load_model
from eval_model import get_model_structure
from tangermeme.utils import random_one_hot
from tangermeme.deep_lift_shap import deep_lift_shap
from tangermeme.ersatz import substitute
import torch
from AttributionModelWrapper import AttributionModelWrapper
from matplotlib import pyplot as plt
import seaborn; seaborn.set_style('whitegrid')
from tangermeme.plot import plot_logo
from tangermeme.ism import saturation_mutagenesis
import numpy as np
from utils import load_observed, load_names
from combine_results import combine_trial_metrics, find_files
import os

In [None]:

# --- Experiment configuration -------------------------------------------------
# Directories of the two model families being compared, plus their display labels.
model1_dir = '/data/nchand/analysis/BPcm/BP68_L0_0/'
model2_dir = '/data/nchand/analysis/BPcm/BP68_L-1_5/'
output_dir = '/data/nchand/analysis/BPcm/BP68_analysis'
difference_label = 'Difference in Pearson correlation at lambda=0.5 and lambda=0.0'
n_trials_to_use = 5
label1 = 'l=0.0'
label2 = 'l=0.5'
model_type = 'BPcm'

# --- Model architecture: must match the checkpoints that get loaded -----------
n_celltypes = 90
n_filters = 300
bin_size = 1
bin_pooling_type = None
scalar_head_fc_layers = 1

# --- Evaluation data ----------------------------------------------------------
eval_set = 'validation'
info_path = '/data/nchand/ImmGen/mouse/BPprofiles1000/memmaped/complete_bias_corrected_normalized_3.7.23/memmap/info.txt'
# Reuse info_path instead of repeating the literal so the two can never drift apart.
peak_names = load_names(info_path, eval_set)
cell_names = np.load("/data/nchand/ImmGen/mouse/BPprofiles1000/ImmGenATAC1219.peak_matched_in_sorted.sl10004sh-4.celltypes.npy")

analysis_file_name = 'validation_analysis.npz'
model_structure = get_model_structure(model_type, n_filters, n_celltypes)
new_analysis = True
os.makedirs(output_dir, exist_ok=True)

# Quick check that a checkpoint loads: the first 'best_model' file found under model1_dir.
# model1 is reloaded from the best-performing trial in a later cell.
path_model1 = find_files(model1_dir, 'best_model')[0]
model1 = load_model(path_model1, model_structure=model_structure, n_filters=n_filters, verbose=False)

# --- Region to explain --------------------------------------------------------
# Use entry 10 of the pre-selected region indices (selected_idx.npy in the working directory).
selected_idx = np.load('selected_idx.npy')
region_idx = selected_idx[10]
region_name = peak_names[region_idx]

# One-hot encoding of the region. np.array copies the row out of the memmap; the transpose
# swaps its two axes (presumably (seq_len, 4) -> (4, seq_len), channels-first — TODO confirm).
onehot_encoding = load_observed(info_file=info_path, dataset_type=eval_set, data_name='onehot')
print('onehot shape', onehot_encoding.shape)
region_onehot = np.array(onehot_encoding[region_idx])
print(region_onehot.shape)
region_onehot = np.transpose(region_onehot, axes=(1, 0))
print(region_onehot.shape)

# Cell types with the minimum, median and maximum observed total counts for this region.
# The "median" cell type is the one whose count is closest to the median value.
obs_total_counts = load_observed(info_path, eval_set, 'total_counts')
tc = obs_total_counts[region_idx]
min_cell_index = np.argmin(tc)
print("min tc", np.min(tc))
max_cell_index = np.argmax(tc)
print("max tc", np.max(tc))
median = np.median(tc)
print("median", median)
median_cell_index = np.abs(tc - median).argmin()
cell_idx = [min_cell_index, median_cell_index, max_cell_index]


# In[ ]:


# For each model, pick the trial with the highest mean Pearson correlation so attribution
# runs on the best-performing checkpoint.
# NOTE(review): `model1_corr` / `model2_corr` are not defined anywhere in this notebook, so this
# cell raises NameError as written. The code takes the mean over axis 0 and argmax of the result,
# so they are presumably (n_regions, n_trials) arrays of per-region correlations (perhaps built
# with `combine_trial_metrics`) — TODO define them before running this cell.
def _best_trial_checkpoint(model_dir, corr):
    """Return ``(trial_idx, path)`` of the 'best_model' file for the highest mean-correlation trial.

    Assumes ``find_files`` lists the trial checkpoints in the same order as the trial axis
    of ``corr`` — TODO confirm.

    Raises
    ------
    IndexError
        If the chosen trial index has no matching checkpoint under ``model_dir``.
    """
    trial_idx = int(np.argmax(np.mean(corr, axis=0)))
    checkpoints = find_files(model_dir, 'best_model')
    if trial_idx >= len(checkpoints):
        raise IndexError(
            f"best trial index {trial_idx} but only {len(checkpoints)} "
            f"'best_model' file(s) found in {model_dir}"
        )
    return trial_idx, checkpoints[trial_idx]


trial_idx_model1, path_model1 = _best_trial_checkpoint(model1_dir, model1_corr)
trial_idx_model2, path_model2 = _best_trial_checkpoint(model2_dir, model2_corr)

In [None]:
# Load the selected checkpoint of each model and wrap it for attribution.
# n_filters comes from the configuration cell so it always matches model_structure.
model1 = load_model(path_model1, model_structure=model_structure, n_filters=n_filters, verbose=False)
model2 = load_model(path_model2, model_structure=model_structure, n_filters=n_filters, verbose=False)
wrapped_model1 = AttributionModelWrapper(model1)
wrapped_model2 = AttributionModelWrapper(model2)


In [None]:
def get_logo(X_attr, ax, title, start=375, end=625):
    """Draw the attribution logo of the first example in ``X_attr`` onto ``ax``.

    Parameters
    ----------
    X_attr : attributions indexed as ``X_attr[example, channel, position]``; only
        example 0 is drawn.
    ax : matplotlib Axes to draw into.
    title : title set on ``ax``.
    start, end : half-open position window ``[start, end)`` to plot. The defaults
        reproduce the original fixed 375:625 window.
    """
    plot_logo(X_attr[0, :, start:end], ax=ax)
    ax.set_title(title)

In [None]:
# Display name for each supported attribution method, keyed by the `attribution_type` argument.
_ATTRIBUTION_METHOD_NAMES = {'deep_lift_shap': 'Deep SHAP', 'ism': 'ISM'}


def _compute_attribution(wrapped_model, X, target, attribution_type, device):
    """Return attributions of ``X`` for model output ``target`` using the chosen tangermeme method."""
    if attribution_type == 'deep_lift_shap':
        return deep_lift_shap(wrapped_model, X, target=target, device=device, random_state=0)
    return saturation_mutagenesis(wrapped_model, X, target=target, device=device)


def _visible_span(attr, threshold):
    """Return ``(first, last, peak)`` of positions whose largest |attribution| >= threshold, else None.

    ``attr`` holds one example's attributions with channels on axis 0 and positions on
    axis 1 (presumably (4, seq_len) — TODO confirm). ``peak`` is the magnitude at ``last``.
    """
    # Take abs *before* reducing over channels: a signed max ignores positions whose only
    # large values are negative.
    magnitude = torch.abs(attr).max(dim=0).values
    above = torch.nonzero(magnitude >= threshold).flatten()
    # Test emptiness with numel(): `.any()` on an index tensor is False when the only hit is 0.
    if above.numel() == 0:
        return None
    first, last = int(above[0]), int(above[-1])
    return first, last, float(magnitude[last])


def get_attributions(wrapped_model, attribution_type, outpath, x_axis_show_threshold=1,
                     window=(375, 625), device='cpu'):
    """Plot and save attribution logos of the selected region for each cell type in ``cell_idx``.

    Reads the notebook globals ``region_name``, ``region_onehot``, ``cell_idx`` and
    ``cell_names``. The model output at each cell-type index is explained. Each result is drawn
    twice:

    - in a "dynamic" figure, whose x-range widens to cover every position whose largest
      |attribution| reaches ``x_axis_show_threshold``;
    - in a "static" figure fixed to ``window``.

    All rows of both figures share one y-range. Both figures are saved as PNGs, then shown and
    closed.

    Parameters
    ----------
    wrapped_model : model handed to the tangermeme attribution function.
    attribution_type : ``'deep_lift_shap'`` (Deep SHAP) or ``'ism'`` (saturation mutagenesis).
    outpath : directory for the PNGs (created if missing); a falsy value saves to the
        current working directory.
    x_axis_show_threshold : minimum |attribution| for a position to widen the dynamic x-range.
    window : ``(start, end)`` x-limits of the static figure; also the minimum dynamic range.
    device : device string forwarded to the attribution function.

    Raises
    ------
    ValueError
        If ``attribution_type`` is not supported or ``cell_idx`` is empty.
    """
    # Validate before any figure is created so a bad call leaves no orphan figures behind.
    if attribution_type not in _ATTRIBUTION_METHOD_NAMES:
        raise ValueError("Invalid attribution type. Choose 'deep_lift_shap' or 'ism'.")
    num_cells = len(cell_idx)
    if num_cells == 0:
        raise ValueError('cell_idx is empty; nothing to plot.')
    method_name = _ATTRIBUTION_METHOD_NAMES[attribution_type]

    print('region name', region_name)
    X = torch.tensor(np.expand_dims(region_onehot, axis=0), dtype=torch.float32)

    fig_dynamic, axes_dynamic = plt.subplots(num_cells, 1, figsize=(10, 2 * num_cells), squeeze=False)
    fig_static, axes_static = plt.subplots(num_cells, 1, figsize=(10, 2 * num_cells), squeeze=False)
    axes_dynamic, axes_static = axes_dynamic.ravel(), axes_static.ravel()

    vmin, vmax = float('inf'), float('-inf')
    x_start, x_end = window

    for cell, ax_dynamic, ax_static in zip(cell_idx, axes_dynamic, axes_static):
        print("celltype:", cell_names[cell])
        attr = _compute_attribution(wrapped_model, X, cell, attribution_type, device)[0]

        # Plain floats so the axis limits are not torch tensors.
        vmin = min(vmin, float(attr.min()))
        vmax = max(vmax, float(attr.max()))

        span = _visible_span(attr, x_axis_show_threshold)
        if span is not None:
            first, last, peak = span
            print('largest visible idx', last, 'with value', peak)
            x_start = min(x_start, first)
            x_end = max(x_end, last)

        # Title each row inside the loop; setting it afterwards would give every row the
        # last cell type's name.
        title = f"{region_name} - {cell_names[cell]} ({method_name})"
        for ax in (ax_dynamic, ax_static):
            plot_logo(attr, ax)
            ax.set_title(title)

    for ax in axes_dynamic:
        ax.set_ylim(vmin, vmax)
        ax.set_xlim(x_start, x_end)
    for ax in axes_static:
        ax.set_ylim(vmin, vmax)
        ax.set_xlim(*window)

    fig_dynamic.tight_layout()
    fig_static.tight_layout()

    out_dir = outpath or ''
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    dynamic_filepath = os.path.join(out_dir, f'{region_name}_{num_cells}_logos_{attribution_type}_dynamic.png')
    static_filepath = os.path.join(out_dir, f'{region_name}_{num_cells}_logos_{attribution_type}_static.png')

    # Save before showing: some backends (e.g. Jupyter inline) close figures on show().
    fig_dynamic.savefig(dynamic_filepath, dpi=300, bbox_inches='tight')
    fig_static.savefig(static_filepath, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig_dynamic)
    plt.close(fig_static)
    
# Example usage:
# get_attributions(wrapped_model1, 'deep_lift_shap', output_dir)
# get_attributions(wrapped_model1, 'ism', output_dir, 0.2)


In [None]:
# In-silico mutagenesis logos for model 2, saved under BP68 with a 0.2 visibility threshold.
get_attributions(
    wrapped_model2,
    'ism',
    outpath='/data/nchand/analysis/BPcm/BP68',
    x_axis_show_threshold=0.2,
)

In [None]:
# Deep SHAP logos for the same model, output directory and threshold as the ISM cell.
# The model must be passed first; get_attributions(wrapped_model, attribution_type, outpath, threshold).
get_attributions(wrapped_model2, 'deep_lift_shap', '/data/nchand/analysis/BPcm/BP68', 0.2)