In [1]:
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

import sys
import os
# Get the parent directory of train/Synthetic_Data_Generator
parent_dir = os.path.abspath(os.path.join(os.getcwd(), "../../"))
sys.path.append(parent_dir)
from train.utils import EnhancerDataset
from train import interpretation
from model.model import ExplaiNN3
from scripts.synthetic_prediction import generate_synthetic_distance_data, motif_score_prediction

# Generate Synthetic Data and Test Motif Distance

In [2]:
# Model hyper-parameters; the model below is built from these before the
# checkpoint in `weight_file` is loaded into it.
num_cnns, filter_size = 90, 19
# Mini-batch size shared by every DataLoader / prediction call in this notebook.
batch = 322
# Names of the two output heads; used as row labels for the final-layer weights.
target_labels = ['GFP+','GFP-']

# Input / output locations.
weight_file = '/pmglocal/ty2514/Enhancer/Enhancer/data/ExplaiNN_both_results/best_r2_model_epoch_53.pth'
meme_file_dir = '/pmglocal/ty2514/Enhancer/motif-clustering/results/all.db.meme'
output_pickle_file = '/pmglocal/ty2514/Enhancer/Enhancer/train/Synthetic_Data_Generator/synthetic_seq_dist.pkl'

In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the architecture with the hyper-parameters from the config cell.
# input_length=608 matches the (N, 4, 608) input tensors used further down.
# weight_path=None: the weights are loaded explicitly from `weight_file` below
# (presumably this skips any loading inside the constructor — confirm in model.model).
explainn = ExplaiNN3(num_cnns=num_cnns, input_length=608, num_classes=2,
                     filter_size=filter_size, num_fc=2, pool_size=7, pool_stride=7,
                     drop_out=0.3, weight_path=None)

# map_location remaps every stored tensor onto `device`. A checkpoint saved on any
# GPU (or on the CPU) therefore loads on this machine without a device-specific branch.
explainn.load_state_dict(torch.load(weight_file, map_location=device))
print(f"explainn loaded on {'GPU' if device.type == 'cuda' else 'CPU'}")

# Move the module to the target device, then switch to inference mode (disables dropout).
explainn.to(device)
explainn.eval()

explainn loaded on CPU




In [13]:
# Build synthetic 200-nt sequences carrying motif A (M00224_2.00) and motif B
# (M00303_2.00) inserted 10 nt apart, with 5 replicates (see the printed log below).
df = generate_synthetic_distance_data(
    200,
    ['M00224_2.00','M00303_2.00'],
    meme_file_dir,
    output_pickle_file,
    distance=10,
    replicate=5,
    save_plot=False,
)
# Score those synthetic sequences with the trained model
# (presumably one prediction column per entry of `target_labels` — confirm).
result_df = motif_score_prediction(
    model=explainn,
    df=df,
    device=device,
    batch=batch,
    target_labels=target_labels,
)

PWM for M00224_2.00 has length: 8
M00224_2.00: CCCGCCCC
PWM for M00303_2.00 has length: 11
M00303_2.00: GCTAATTACTG


Motif A CCCGCCCC has 20 200nt DNA segments, inserted with distance 10
Motif B GCTAATTACTG has 19 200nt DNA segments, inserted with distance 10
Motif A CCCGCCCC has 20 200nt DNA segments, inserted with distance 10
Motif B GCTAATTACTG has 19 200nt DNA segments, inserted with distance 10
Motif A CCCGCCCC has 20 200nt DNA segments, inserted with distance 10
Motif B GCTAATTACTG has 19 200nt DNA segments, inserted with distance 10
Motif A CCCGCCCC has 20 200nt DNA segments, inserted with distance 10
Motif B GCTAATTACTG has 19 200nt DNA segments, inserted with distance 10
Motif A CCCGCCCC has 20 200nt DNA segments, inserted with distance 10
Motif B GCTAATTACTG has 19 200nt DNA segments, inserted with distance 10


# Interpretation

In [4]:
# Motif-clustering metadata: maps each database motif_id to its cluster, tf_name and family_name.
cluster_results = pd.read_csv(
    "/pmglocal/ty2514/Enhancer/motif-clustering/JASPAR2024_mus_musculus_non-redundant_results/metadata.tsv",
    sep="\t",
    comment="#",
)
# Tomtom comparison of the learned filters (Query_ID) against database motifs (Target_ID).
tomtom_results = pd.read_csv(
    "/pmglocal/ty2514/Enhancer/Enhancer/data/ExplaiNN_both_results/tomtom_results/tomtom.tsv",
    sep="\t",
    comment="#",
)
# Smallest q-value per filter, taken before the significance cut below.
filters_with_min_q = tomtom_results.groupby('Query_ID')["q-value"].min()
# Keep only significant matches (q < 0.05) and the three columns used afterwards.
tomtom_results = tomtom_results.loc[tomtom_results["q-value"] < 0.05, ["Target_ID", "Query_ID", "q-value"]]

# motif_id -> annotation lookups.
motif_to_cluster, motif_to_tf_name, motif_to_family_name = (
    cluster_results.set_index('motif_id')[column].to_dict()
    for column in ('cluster', 'tf_name', 'family_name')
)
# Filters with at least one significant match, in first-seen order.
filters = tomtom_results["Query_ID"].unique()

annotation_data = []
for query_id in filters:
    # At most the first five significant targets for this filter, in file order.
    top_targets = tomtom_results.loc[tomtom_results["Query_ID"] == query_id, "Target_ID"].head(5).values
    # dict.fromkeys drops repeated annotations while keeping first-seen order;
    # the remaining values are joined with '/'.
    annotation_data.append({
        'filter': query_id,
        'cluster': "/".join(dict.fromkeys(motif_to_cluster[m] for m in top_targets)),
        'tf_name': "/".join(dict.fromkeys(motif_to_tf_name[m] for m in top_targets)),
        'family_name': "/".join(dict.fromkeys(motif_to_family_name[m] for m in top_targets)),
    })

annotation_df = pd.DataFrame(annotation_data)
# Preview the first five annotated filters.
annotation_df.head(5)


Unnamed: 0,filter,cluster,tf_name,family_name
0,filter2,AC0069,Pax4,['Paired box factors']
1,filter6,AC0069,Pax4,['Paired box factors']
2,filter7,AC0066,Gli1/Gli2,['C2H2 zinc finger factors']
3,filter8,AC0069,Pax4,['Paired box factors']
4,filter10,AC0069,Pax4,['Paired box factors']


In [5]:
# Final-layer weights of the model: one row per target label, one column per CNN unit.
weights = explainn.final.weight.detach().cpu().numpy()
print(f'weight_df has shape: {weights.shape} (number of labels, number of filters)')

# Default column names 'f0' .. 'f{num_cnns-1}' (column i <-> CNN unit i).
filters = [f"f{i}" for i in range(num_cnns)]
# Position of each default name, so each rename is a dict lookup rather than list.index().
filter_position = {name: i for i, name in enumerate(filters)}

# Tomtom query IDs look like 'filter{i}'. Rename column 'f{i}' to
# '{cluster}({tf_name})-f{i}' so that annotated units are recognisable by the '-'.
# The loop variable is deliberately not called `filter`: at notebook top level that
# would shadow the builtin for every later cell.
for row in annotation_df.itertuples(index=False):
    short_name = 'f' + row.filter.split('filter', 1)[1].strip()
    position = filter_position.get(short_name)
    if position is not None:
        filters[position] = f"{row.cluster}({row.tf_name})-{short_name}"

weight_df = pd.DataFrame(weights, index=target_labels, columns=filters)
# Optional: persist the table, e.g.
# weight_df.to_csv(os.path.join(result_dir, 'filter_weights.csv'), index=True)
weight_df

weight_df has shape: (2, 90) (number of labels, number of fileters)
['f0', 'f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20', 'f21', 'f22', 'f23', 'f24', 'f25', 'f26', 'f27', 'f28', 'f29', 'f30', 'f31', 'f32', 'f33', 'f34', 'f35', 'f36', 'f37', 'f38', 'f39', 'f40', 'f41', 'f42', 'f43', 'f44', 'f45', 'f46', 'f47', 'f48', 'f49', 'f50', 'f51', 'f52', 'f53', 'f54', 'f55', 'f56', 'f57', 'f58', 'f59', 'f60', 'f61', 'f62', 'f63', 'f64', 'f65', 'f66', 'f67', 'f68', 'f69', 'f70', 'f71', 'f72', 'f73', 'f74', 'f75', 'f76', 'f77', 'f78', 'f79', 'f80', 'f81', 'f82', 'f83', 'f84', 'f85', 'f86', 'f87', 'f88', 'f89']


Unnamed: 0,f0,f1,AC0069(Pax4)-f2,f3,f4,f5,AC0069(Pax4)-f6,AC0066(Gli1/Gli2)-f7,AC0069(Pax4)-f8,f9,...,f80,f81,f82,f83,AC0069(Pax4)-f84,f85,f86,f87,AC0069(Pax4)-f88,f89
GFP+,0.081534,0.181157,0.248874,0.18156,0.002539,0.047324,0.237871,0.110817,0.233839,0.033559,...,0.212228,0.11542,0.157949,0.121214,0.217157,0.134688,0.163591,0.001927,0.169545,-0.001799
GFP-,0.000444,0.083332,0.188044,0.242838,0.058648,0.123694,0.170986,0.155349,0.08056,0.121754,...,0.099247,0.197698,0.133829,0.035363,0.230519,0.041809,0.102332,0.035025,0.165388,-0.024118


In [126]:
import matplotlib.pyplot as plt
import math
import pandas as pd

def plot_filter_weight(weight_df, dir_save_plot, save=False):
    """Plot one horizontal bar chart of filter weights per target label.

    Each row of ``weight_df`` becomes its own figure. The columns (filters) are
    sorted by weight with the largest at the top. Filters whose column name
    contains '-' (annotated, e.g. 'AC0069(Pax4)-f2') are drawn in red, the rest in
    royal blue, and every bar is labelled with its value.

    Args:
        weight_df: DataFrame indexed by target label with one numeric column per filter.
        dir_save_plot: directory that PNGs are written to when ``save`` is True.
        save: when True, write each figure to
            ``<dir_save_plot>/filter_weights_<target>.png``. Defaults to False,
            which matches the previous behaviour where saving was disabled.
    """
    num_filters = weight_df.shape[1]
    # Figure height scales with the number of bars; keep a floor so small frames stay legible.
    fig_height = max(2, math.ceil(0.15 * num_filters))

    for target, row in weight_df.iterrows():
        sorted_row = row.sort_values(ascending=False)
        if sorted_row.empty:
            continue
        labels = [str(label) for label in sorted_row.index]
        values = sorted_row.to_numpy()
        colors = ['red' if '-' in label else 'royalblue' for label in labels]

        fig, ax = plt.subplots(figsize=(8, fig_height))
        ax.barh(labels, values, color=colors)
        ax.set_title(f'Weights for Target: {target}')
        ax.set_xlabel('Weight')
        ax.set_ylabel('Filters')
        # barh draws the first (largest) bar at the bottom; flip so it ends up on top.
        ax.invert_yaxis()
        # Small margin so the value labels do not touch the frame.
        ax.set_xlim(values.min() - 0.05, values.max() + 0.05)

        # Write each value just past the end of its bar.
        for y, value in enumerate(values):
            ax.text(value, y, f'{value:.3f}', va='center',
                    ha='left' if value >= 0 else 'right', fontsize=10)

        fig.tight_layout()
        if save:
            # Save before showing: savefig after plt.show() can write an empty image.
            plot_filename = os.path.join(dir_save_plot, f'filter_weights_{target}.png')
            fig.savefig(plot_filename)
            print(f'Saved plot for {target} at {plot_filename}')
        plt.show()
        # Close explicitly so that plotting many rows does not accumulate open figures.
        plt.close(fig)


In [7]:
# Keep only the annotated filters: renamed columns look like '<cluster>(<tf>)-f<i>',
# and the '-' separator is what distinguishes them from the plain 'f<i>' names.
annotated_weight_df = weight_df.filter(like='-', axis=1)
# To plot: plot_filter_weight(weight_df, 'asdf') or plot_filter_weight(annotated_weight_df, 'asdf')

In [8]:
# Residual cutoff: a sample is kept only if |label - prediction| <= upper_bound
# for every target.
upper_bound = 0.25
input_data_dir = '/pmglocal/ty2514/Enhancer/Enhancer/data/input_data.csv'
df = pd.read_csv(input_data_dir)
full_dataset = EnhancerDataset(df, feature_list=['G+','G-'], scale_mode = 'none')
# shuffle=False keeps the sample order identical across the two passes below
# (prediction, then input collection), so the mask lines up with the inputs.
full_loader = DataLoader(dataset=full_dataset, batch_size=batch, shuffle=False)

# Predictions and true labels for every sequence, in loader order.
predictions, labels = interpretation.get_explainn_predictions(full_loader, explainn, device, isSigmoid=False)

# Absolute residual per sample and target (assumes 2-D (samples, targets) arrays — TODO confirm).
residuals = np.abs(labels - predictions)
print(f'Using Bound = {upper_bound} as a cutoff to select high confident predictions.')
# High-confidence mask: True where every target's residual is within the bound.
mask = (residuals <= upper_bound).all(axis=1)

# Collect the inputs and labels in the same order as the predictions.
data_inp = []
data_out = []
for batch_features, batch_labels in full_loader:
    data_inp.append(batch_features)
    data_out.append(batch_labels)
data_inp = torch.cat(data_inp, dim=0)
data_out = torch.cat(data_out, dim=0)
print(f'Total number of input samples: {len(data_inp)}')

# Fail loudly if predictions and inputs ever get out of step.
if len(mask) != len(data_inp):
    raise ValueError(
        f'Residual mask has {len(mask)} entries but the loader yielded {len(data_inp)} samples')
keep = torch.as_tensor(mask, dtype=torch.bool).cpu()
data_inp = data_inp[keep]
data_out = data_out[keep]
print(data_inp.shape)
print(data_out.shape)
print(f'Number of input samples with high confident prediction: {len(data_inp)}')

# Loader over the high-confidence subset only; used by the interpretation cells below.
dataset = torch.utils.data.TensorDataset(data_inp, data_out)
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch, shuffle=False)



Using Bound = 0.25 as a cutoff to select high confident predictions.
Total number of input samples: 28800
torch.Size([21449, 4, 608])
torch.Size([21449, 2])
Number of input samples with high confident prediction: 21449


In [9]:
# Per-unit importance for each target, computed over the high-confidence subset
# (`data_loader` from the filtering cell).
# NOTE(review): `activations` is loaded from a precomputed file. It presumably must
# describe the same samples, in the same order, as `data_loader` — confirm.
activations = np.load('/pmglocal/ty2514/Enhancer/Enhancer/data/ExplaiNN_both_results/Model_Activations.npy')

# The unit outputs depend only on data_loader, explainn (in eval mode) and device, none
# of which change inside the loop. Compute them once instead of re-running the model
# over the whole loader for every unit.
# Assumes get_specific_unit_importance does not modify unit_outputs in place — TODO confirm.
unit_outputs = interpretation.get_explainn_unit_outputs(data_loader, explainn, device)

unit_importance_GFP_pos = []
unit_importance_GFP_neg = []

# One importance entry per CNN unit, for each target label.
for unit_index in range(num_cnns):
    print(unit_index)  # progress indicator
    importance = interpretation.get_specific_unit_importance(activations, explainn, unit_outputs, unit_index, target_labels)
    unit_importance_GFP_pos.append(importance['GFP+'])
    unit_importance_GFP_neg.append(importance['GFP-'])


0
1
2
3
4
5
6
7
8
9
10


KeyboardInterrupt: 

In [None]:
"""Plot the Ranked Importance Value of Each Filter"""
num_filter_plot = 10
# Create a list to store common important filters
common_filters = []


def plot_importance(unit_importance_values, unit_names, title_suffix, top_n=None):
    """Draw horizontal box plots of per-filter importance, ranked by mean importance.

    Args:
        unit_importance_values: one sequence of importance values per filter
            (e.g. ``unit_importance_GFP_pos``), in unit order.
        unit_names: display name for each filter, in the same order. Plain unit
            names ('f12' or 'filter12') are drawn green; any other (annotated)
            name, e.g. 'AC0069(Pax4)-f2', is drawn pink.
        title_suffix: appended to the plot title, e.g. "GFP Positive".
        top_n: if given, draw only the ``top_n`` filters with the highest mean.

    Returns:
        The matplotlib Axes, or None when there is nothing to plot.
    """
    import re

    # zip() stops at the shorter input, so a partially computed importance list
    # (e.g. after an interrupted loop) still pairs each value list with its own name.
    ranked = sorted(
        ((np.mean(values), name, values) for name, values in zip(unit_names, unit_importance_values)),
        key=lambda item: item[0],
        reverse=True,
    )
    if top_n is not None:
        ranked = ranked[:top_n]
    if not ranked:
        print(f'No importance values to plot for {title_suffix}')
        return None
    _, sorted_names, sorted_values = zip(*ranked)

    fig, ax = plt.subplots(figsize=(8, max(2, math.ceil(0.15 * len(sorted_names)))))
    flierprops = dict(marker='o', color='black', markersize=6)
    box_width = 0.6
    # One box per filter so that each can get its own color.
    for position, (name, data) in enumerate(zip(sorted_names, sorted_values), start=1):
        is_plain_unit = re.fullmatch(r'f(ilter)?\d+', str(name)) is not None
        color = "#228833" if is_plain_unit else "#ff9999"
        ax.boxplot(data, positions=[position], widths=box_width, notch=True, patch_artist=True, vert=False,
                   boxprops=dict(facecolor=color, color=color), flierprops=flierprops)

    # Tick labels/titles are Axes methods; pyplot has no set_yticks/set_title functions.
    ax.set_yticks(range(1, len(sorted_names) + 1))
    ax.set_yticklabels(sorted_names, rotation=0)
    # Highest mean importance at the top, as in plot_filter_weight.
    ax.invert_yaxis()
    ax.set_title(f"Unit Importance of Each Filter on Predicting {title_suffix}")
    ax.set_xlabel("Importance Values")
    fig.tight_layout()
    plt.show()
    return ax


# Unit names come from the columns of `weight_df` (column i <-> CNN unit i), so
# annotated filters are shown with their TF label.
unit_names = list(weight_df.columns)

# Ranked importance for the GFP+ target.
plot_importance(unit_importance_GFP_pos, unit_names, "GFP Positive");

# Ranked importance for the GFP- target.
plot_importance(unit_importance_GFP_neg, unit_names, "GFP Negative");



In [None]:
# Inspect the raw importance values collected for the GFP+ target (one entry per unit processed).
unit_importance_GFP_pos