In [None]:
import tensorflow as tf

In [None]:
# Make sure TensorFlow is below 2.16! Otherwise the model won't be able to load.
# The version is only printed here, not enforced, so check the output manually.
print(tf.__version__)

In [None]:
# GPU info (runs the nvidia-smi command through the IPython shell escape):
!nvidia-smi

In [None]:
# Restrict this process to one GPU by exporting its id before TensorFlow
# enumerates devices:
GPU_id = '7'
import os
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_id

# Ask TensorFlow which GPUs it can see and report them:
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"Num GPUs Available: {len(gpu_devices)}")

if not gpu_devices:
    print("GPU not working")
else:
    print("GPU working")
    for device in gpu_devices:
        print(f"Device name: {device.name}")
        print(f"Device type: {device.device_type}")

In [None]:
# Function to generate a list of all possible mutations for a seq:
def all_possible_mutations(dna_seq):
    """Return every single-position substitution of ``dna_seq``.

    Positions are visited left to right; at each position the replacement
    letters are tried in the order A, T, C, G, skipping the letter that is
    already there. The input itself is not included in the result.
    """
    alphabet = ("A", "T", "C", "G")
    return [
        dna_seq[:pos] + replacement + dna_seq[pos + 1:]
        for pos, current in enumerate(dna_seq)
        for replacement in alphabet
        if replacement != current
    ]

# function to plot nucleotide saliencies:
import logomaker
import matplotlib.pyplot as plt
def plot_saliency(df, negative=False,
                  start=None, end=None, figsize=[8,2],
                  xticks=False, yticks=False,
                  spines=False, ylim=None):
    """Draw a logomaker logo from a per-position saliency DataFrame.

    Parameters:
        df: DataFrame handed to ``logomaker.Logo`` (one row per position).
        negative: when True, the values are negated before plotting.
        start, end: optional row bounds; rows are taken as
            ``df[start : end+1]``, so ``end`` is inclusive.
            NOTE(review): integer slices on a DataFrame are presumably
            positional, so with a 1-based index these act as 0-based row
            offsets -- confirm this is the intended convention.
        figsize: figure size passed to ``plt.subplots``.
        xticks, yticks: when False, the corresponding ticks are removed.
        spines: when False, the axes spines are hidden.
        ylim: optional y-axis limits.

    Returns the result of ``fig.tight_layout()``.
    """
    fig, ax = plt.subplots(1,1,figsize=figsize)

    # Restrict the plotted rows only if at least one bound was given:
    if start is not None or end is not None:
        stop = None if end is None else end + 1
        df = df[start:stop]

    # Optionally flip the sign of every saliency:
    if negative == True:
        df = -df

    logo = logomaker.Logo(df, ax=ax)

    if ylim is not None:
        logo.ax.set_ylim(ylim)

    if spines==False:
        logo.style_spines(visible=False)

    # Strip the ticks of each axis that was not requested:
    for keep_ticks, set_ticks in ((xticks, ax.set_xticks), (yticks, ax.set_yticks)):
        if keep_ticks==False:
            set_ticks([])

    return logo.fig.tight_layout()

In [None]:
# Define the model:
# Path of the saved model file; the prediction functions below pass it to load_model().
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
model = '/rd4/users/liangn/mywork/L5-220528_em5-LSTM64x32x0.5-64x0.5-rep4.hdf5'
# Length that encoded inputs are padded to (used as `maxlen` for pad_sequences):
model_x_length = 46
# Sequence length that is predicted directly; also the default sliding-window
# size (k) used for sequences of any other length:
optimal_x_length = 45


# Function to convert a DNA sequence to vector:
vocab = ['pad','N','A','T','C','G']
# Map each vocabulary entry to its position in `vocab`:
char2idx = dict(zip(vocab, range(len(vocab))))
def vectorize_dna_seq(dna_seq):
    """Translate each character of ``dna_seq`` into its integer code.

    Codes come from ``char2idx``; a character that is not in the vocabulary
    raises ``KeyError``.
    """
    codes = []
    for symbol in dna_seq:
        codes.append(char2idx[symbol])
    return codes


# Function to convert a list of DNA into x array for ANN inputs:
from tensorflow.keras.preprocessing.sequence import pad_sequences
def prepare_x(dna_list, x_lenth):
    """Encode a list of DNA strings as one padded integer array.

    Each sequence is converted with ``vectorize_dna_seq`` and the encoded
    sequences are then passed to ``pad_sequences`` with ``maxlen=x_lenth``
    and ``padding='post'``.
    """
    encoded = [vectorize_dna_seq(dna) for dna in dna_list]
    return pad_sequences(encoded, maxlen=x_lenth, padding='post')


# function to split a string into k-mers:
def kmerize(string, k):
    """Return all overlapping substrings of length ``k``, left to right.

    A string shorter than ``k`` yields an empty list.
    """
    n_windows = len(string) - k + 1
    kmers = []
    for offset in range(n_windows):
        kmers.append(string[offset:offset + k])
    return kmers


# function to predict big seq with sliding windows:
import pandas as pd
from keras.models import load_model
import numpy as np
# Models already loaded by predict_big, keyed by the `model` argument, so that
# repeated calls do not deserialize the same file again. Re-running this cell
# resets the cache.
_PREDICT_BIG_MODEL_CACHE = {}


def predict_big(big_seqs, k=optimal_x_length, model_x_length=model_x_length, model=model):
    """Predict long sequences by averaging predictions over sliding windows.

    Every sequence is cut into all overlapping windows of length ``k``, each
    window is encoded and padded to ``model_x_length``, the model predicts
    every window, and the predictions of the windows belonging to one
    sequence are averaged.

    Parameters:
        big_seqs: one sequence (str) or an iterable of sequences.
        k: window length.
        model_x_length: ``maxlen`` passed to ``pad_sequences`` (post-padding).
        model: value passed to ``load_model``; the loaded model is cached.

    Returns:
        A 2-D float array with one row per input sequence (in input order)
        and one column per model output.

    Raises:
        ValueError: if no sequence is given, ``k`` is not positive, or a
            sequence is shorter than ``k``. Such a sequence has no window;
            it used to be dropped silently, which shifted the output rows
            against the inputs.
    """
    # A single sequence is treated as a one-element list:
    if isinstance(big_seqs, str):
        big_seqs = [big_seqs]
    big_seqs = list(big_seqs)

    if not big_seqs:
        raise ValueError("big_seqs is empty; nothing to predict")
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    too_short = [pos for pos, seq in enumerate(big_seqs) if len(seq) < k]
    if too_short:
        raise ValueError(
            f"sequences at positions {too_short} are shorter than the window size k={k}")

    # Split each sequence into k-mers and remember how many windows it owns;
    # the windows of one sequence are stored contiguously:
    windows = []
    window_counts = []
    for seq in big_seqs:
        kmers = kmerize(seq, k)
        windows.extend(kmers)
        window_counts.append(len(kmers))

    # Prepare x for predictions:
    x = pad_sequences([vectorize_dna_seq(window) for window in windows],
                      maxlen=model_x_length, padding='post')

    # Load the model once per distinct `model` value:
    try:
        Model = _PREDICT_BIG_MODEL_CACHE[model]
    except KeyError:
        Model = _PREDICT_BIG_MODEL_CACHE[model] = load_model(model)
    except TypeError:
        # Unhashable handle: cannot be cached, load it directly.
        Model = load_model(model)
    y_pred = Model.predict(x)

    # Average the predictions of each sequence's own block of windows:
    stops = np.cumsum(window_counts)
    starts = stops - np.asarray(window_counts)
    y_pred_mean = np.zeros((len(big_seqs), y_pred.shape[1]))
    for row, (start, stop) in enumerate(zip(starts, stops)):
        y_pred_mean[row] = np.mean(y_pred[start:stop], axis=0)
    #
    return y_pred_mean


# function to get a df of the regulatory relevance of each nucleotide of a seq:
from keras.models import load_model
import pandas as pd
# Models already loaded by saliency_df, keyed by the `model` argument.
# connections_array calls saliency_df once per mutant, so without this the
# same file would be deserialized on every call. Re-running this cell resets
# the cache.
_SALIENCY_MODEL_CACHE = {}


def saliency_df(seq, optimal_x_length=optimal_x_length, model_x_length=model_x_length, model=model,
                target='expression'):
    """Score each nucleotide of ``seq`` by the effect of mutating it.

    The original sequence and all of its single-nucleotide substitutions are
    predicted. For every position, the saliency is the median over that
    position's substitutions of ``value(original) - value(mutant)``, where
    ``value`` is the column selected by ``target``.

    Parameters:
        seq: DNA string.
        optimal_x_length: sequences of exactly this length are predicted
            directly; longer ones go through ``predict_big`` with this
            window size.
        model_x_length: padded input length of the model.
        model: value passed to ``load_model`` (cached) / ``predict_big``.
        target: one of 'nuc', 'cyt', 'expression', 'export'.
            'nuc' and 'cyt' are 2 ** the two model outputs
            (presumably nuclear and cytoplasmic levels -- confirm),
            'expression' is cyt*16/17 + nuc*1/17 and 'export' is cyt/nuc.

    Returns:
        A float64 DataFrame suitable for logomaker: index 1..len(seq),
        columns A, C, G, T (plus any other letter present in ``seq``), with
        each row holding the saliency in the column of that position's own
        letter and 0 elsewhere.

    Raises:
        ValueError: if ``seq`` is shorter than ``optimal_x_length``; no
            window can be formed for such a sequence.
        KeyError: if ``target`` is not one of the computed columns.
    """
    # check seq length:
    seq_len = len(seq)
    if seq_len < optimal_x_length:
        raise ValueError(
            f"seq has length {seq_len}, but at least {optimal_x_length} "
            "nucleotides (optimal_x_length) are required")

    # original seq first, followed by all point mutations:
    all_mutants = [seq] + all_possible_mutations(seq)

    # predict the mutants, using a different strategy by seq length:
    if seq_len == optimal_x_length:
        x = prepare_x(all_mutants, model_x_length)
        try:
            Model = _SALIENCY_MODEL_CACHE[model]
        except KeyError:
            Model = _SALIENCY_MODEL_CACHE[model] = load_model(model)
        except TypeError:
            # Unhashable handle: cannot be cached, load it directly.
            Model = load_model(model)
        y_pred = Model.predict(x)
    else:
        y_pred = predict_big(big_seqs=all_mutants, k=optimal_x_length, model_x_length=model_x_length,
                             model=model)

    # convert array to dataframe and derive the target columns:
    pred_df = pd.DataFrame(y_pred, columns = ['nuc','cyt'])
    pred_df['nuc'] = 2**pred_df['nuc']
    pred_df['cyt'] = 2**pred_df['cyt']
    pred_df['expression'] = pred_df['cyt']*(16/17) + pred_df['nuc']*(1/17)
    pred_df['export'] = pred_df['cyt']/pred_df['nuc']

    # delta = original minus mutant, one entry per mutant:
    values = pred_df[target].values
    deltas = values[0] - values[1:]

    # Median delta per position. all_possible_mutations emits one mutant for
    # every letter of "ATCG" that differs from the original, so a position
    # holding A/T/C/G owns 3 consecutive deltas, while any other symbol
    # (e.g. N) owns 4. Grouping by the real counts keeps positions aligned.
    mutants_per_position = np.array(
        [sum(letter != base for letter in "ATCG") for base in seq], dtype=int)
    stops = np.cumsum(mutants_per_position)
    starts = stops - mutants_per_position
    delta_medians = [np.median(deltas[start:stop]) for start, stop in zip(starts, stops)]

    # form the final data.frame suitable for logomaker:
    # columns A, C, G, T first, then any other letter in order of appearance.
    letters = list(dict.fromkeys(['A', 'C', 'G', 'T'] + list(seq)))
    column_of = {letter: col for col, letter in enumerate(letters)}
    matrix = np.zeros((seq_len, len(letters)), dtype='float64')
    for row, base in enumerate(seq):
        matrix[row, column_of[base]] = delta_medians[row]
    # Missing medians become 0, as with the previous fillna(0):
    matrix[np.isnan(matrix)] = 0.0

    # row index is 1,2,3...:
    df = pd.DataFrame(matrix, columns=letters, index=range(1, seq_len + 1))
    #
    return df


# function to calculate the connections between nucleotides:
import numpy as np
import matplotlib.pylab as plt
def connections_array (seq, target='expression', optimal_x_length=optimal_x_length,
                     model_x_length=model_x_length, model=model):
    """Measure how mutating one position changes the saliency of the others.

    A per-position saliency profile is computed (via ``saliency_df``) for the
    original sequence and for each of its single-nucleotide substitutions.
    Every mutant profile is divided by the original profile, and the fold
    changes of the mutants that share a mutated position are summarized by
    their median.

    Parameters:
        seq: DNA string.
        target, optimal_x_length, model_x_length, model: forwarded to
            ``saliency_df``.

    Returns:
        A 2-D array with one row per mutated position and one column per
        affected position. Positions whose original saliency is 0 produce
        inf/nan entries, because the fold change divides by it.
    """
    # original seq first, followed by all point mutations:
    variants = [seq] + all_possible_mutations(seq)

    # one saliency profile per variant; each row of a saliency frame has a
    # single non-zero letter column, so the row sum is that position's value:
    profiles = [
        saliency_df(seq=variant, optimal_x_length=optimal_x_length,
                    model_x_length=model_x_length, model=model,
                    target=target).sum(axis=1)
        for variant in variants
    ]
    all_saliencies = np.vstack(profiles)

    # calculate fold changes of every mutant relative to the original:
    all_saliencies_fc = np.divide(all_saliencies[1:], all_saliencies[0])

    # Group the mutants by the position of their mutation.
    # all_possible_mutations emits one mutant for every letter of "ATCG"
    # that differs from the original, i.e. 3 per A/T/C/G position and 4 for
    # any other symbol, so split on the real counts instead of assuming 3.
    mutants_per_position = [sum(letter != base for letter in "ATCG") for base in seq]
    split_points = np.cumsum(mutants_per_position)[:-1]
    sub_arrays = np.split(all_saliencies_fc, split_points)

    # take medians across the mutants of each position:
    medians_array = np.array([np.median(sub_array, axis=0) for sub_array in sub_arrays])
    #
    return medians_array


# function to plot the connections:
import seaborn as sns
def plot_connections (array, size):
    """Show a connections matrix as a square heatmap.

    Parameters:
        array: 2-D array; rows are labelled 'Mutated position' and columns
            'Affected position'.
        size: width and height of the figure.

    The colour scale runs from 0 to 2 and is centred on 1. Every second
    position (2, 4, 6, ...) gets a tick label, placed at the centre of its
    cell, and the y axis is inverted so that position 1 is at the bottom.
    """
    fig, ax = plt.subplots(figsize=(size, size))
    sns.heatmap(array, linewidth=0, center=1, cbar_kws={"shrink": .5}, vmin=0, vmax=2, ax=ax)
    ax.set_aspect('equal')
    ax.invert_yaxis()
    ax.set_ylabel('Mutated position', fontsize = 15, weight='bold')
    ax.set_xlabel('Affected position', fontsize = 15, weight='bold')
    #
    # 1-based labels for every second position:
    x_labels = list(range(2, array.shape[1]+1, 2))
    y_labels = list(range(2, array.shape[0]+1, 2))
    # Heatmap cell p (1-based) spans [p-1, p], so its centre is p - 0.5.
    # Deriving the ticks from the labels keeps them aligned for both odd and
    # even sizes; an evenly spaced grid ending at size-1.5 only works for odd
    # sizes.
    ax.set_xticks([label - 0.5 for label in x_labels])
    ax.set_yticks([label - 0.5 for label in y_labels])
    #
    ax.set_xticklabels(x_labels)
    ax.set_yticklabels(y_labels)
    plt.show()


In [None]:
# Example sequence (45 nt) and its per-nucleotide saliency for the 'export' target:
N45 = 'CGATCATCTTCATCATCGTCATCATCCGTCTTCCATCCATCCAGT'
the_N45_saliency_df = saliency_df(seq=N45, target='export')
# Display the resulting table:
the_N45_saliency_df

In [None]:
# Draw the saliency logo for N45 with both axes' ticks kept and the signs unchanged:
plot_saliency(df=the_N45_saliency_df, negative=False, figsize=[8,2], xticks=True, yticks=True)

In [None]:
# Compute the position-by-position connections for N45 using the 'export' target.
# This calls saliency_df once for the original sequence and once per point mutant.
con_arr = connections_array (seq=N45, target='export')
con_arr

In [None]:
# Display the connections array again (same value as shown by the previous cell):
con_arr

In [None]:
# Wrap the connections array in a DataFrame (default 0-based row and column labels):
con_df = pd.DataFrame(con_arr)
con_df

In [None]:
# save:
# Writes the table to 'fig4bbot4.csv' in the current working directory,
# including the row index and the column header.
con_df.to_csv('fig4bbot4.csv', index=True, header=True)

In [None]:
# Draw the connections heatmap on a square figure of size 9:
plot_connections(array=con_arr, size=9)