In [3]:
import pandas as pd
import os, pickle, pysam
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as patches  # <-- For drawing rectangular boxes

# Set the global matplotlib font size to 16
plt.rcParams.update({'font.size': 16})

In [2]:
# Absolute, machine-specific location of the reference genome FASTA; adjust when running elsewhere.
reference_genome_path="/home/pdutta/Data/Human_Genome_Data/GRCh38.primary_assembly.genome.fa"
# Open the FASTA through pysam; presumably this is the handle later passed to get_sequences() — TODO confirm.
reference_fasta = pysam.FastaFile(reference_genome_path)

In [3]:
# Load the tab-separated accuracy table; the rendered output shows 'tags' and 'eval_acc' columns among others.
df_tfbs= pd.read_csv("/home/pdutta/Github/Postdoc/DNABERT_data_processing/TFBS/TFBS_accuracy_Stat.tsv", sep="\t")
# Keep only rows whose eval_acc is at least 0.85 and renumber the index from 0.
df_top = df_tfbs[df_tfbs['eval_acc']>=0.85].reset_index(drop=True)
# Bare expression so the notebook renders the filtered table.
df_top

Unnamed: 0.1,Unnamed: 0,tags,eval_acc
0,66,CTCFL,0.983854
1,599,ZNF426,0.970906
2,389,SAFB,0.970158
3,369,RBM34,0.970000
4,439,TAF15,0.967351
...,...,...,...
457,130,GFI1B,0.851123
458,623,ZNF558,0.851034
459,551,ZNF19,0.850575
460,406,SMAD5,0.850318


In [4]:
# Define a function that searches for motifs in the provided sequence.
def find_motifs(sequence, motif_dict):
    """
    Report which motifs from ``motif_dict`` occur in ``sequence``.

    Only the dictionary keys are consulted; each key is tested as a plain
    substring of ``sequence``. The matching keys are returned as a list, in
    the dictionary's own iteration order (empty list when nothing matches).
    """
    return [candidate for candidate in motif_dict.keys() if candidate in sequence]

In [5]:
# Define the function that checks if the variant falls within a given motif.
def is_variant_in_motif(candidate_seq, variant_relative_pos, motif):
    """
    Tell whether a variant position lies inside any occurrence of ``motif``.

    candidate_seq: The candidate sequence to search (described by the
        original author as the 300bp candidate sequence).
    variant_relative_pos: Zero-indexed position of the variant within candidate_seq.
    motif: The core motif (e.g., "GGCCG").

    Returns True when ``variant_relative_pos`` falls within the inclusive
    span of at least one occurrence of ``motif`` in ``candidate_seq``.
    Every occurrence is examined, including overlapping ones, so a variant
    sitting in a second copy of the motif is not missed. Returns False when
    the motif is absent or empty.
    """
    # An empty motif spans no positions, so nothing can fall inside it.
    if not motif:
        return False

    motif_len = len(motif)
    motif_start = candidate_seq.find(motif)
    while motif_start != -1:
        # Occurrences are visited left to right; once one starts past the
        # variant, no later occurrence can contain it.
        if motif_start > variant_relative_pos:
            return False
        motif_end = motif_start + motif_len - 1
        if variant_relative_pos <= motif_end:
            return True
        # Step by one base (not by motif_len) so overlapping hits are seen.
        motif_start = candidate_seq.find(motif, motif_start + 1)
    return False

In [31]:
def get_sequences(df, reference_fasta):
    """
    Build the alternative-allele sequence for every row and store it in a new
    ``alt_seq`` column.

    Parameters
    ----------
    df : pandas.DataFrame
        The first three columns are read by position (chromosome, window
        start, window end). The columns ``START_POS``, ``END_POS``, ``REF``,
        ``ALT`` and ``ref_seq`` are read by name. The frame is modified in
        place: ``alt_seq`` is inserted at column position 4.
    reference_fasta : object
        Must provide ``fetch(chrom, start, end)``. It is only consulted for
        deletions, to pull the bases that follow the window end.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, now carrying the ``alt_seq`` column.

    Raises
    ------
    ValueError
        From ``DataFrame.insert`` if ``df`` already has an ``alt_seq`` column
        (for example when the cell is re-run on the same frame).
    """
    alt_sequences = []
    for _, row in df.iterrows():
        # Positional access must go through .iloc: integer keys on a
        # label-indexed Series are treated as labels by current pandas.
        chrom = row.iloc[0]
        ref_start = row.iloc[1]
        ref_end = row.iloc[2]
        ref_nucleotide = row['REF']
        alt = row['ALT']
        ref_seq = row['ref_seq']

        # Offsets of the variant inside ref_seq (window-relative coordinates).
        variant_pos_start = row['START_POS'] - ref_start
        variant_pos_end = row['END_POS'] - ref_start

        upstream = ref_seq[:variant_pos_start]
        length_change = len(alt) - len(ref_nucleotide)

        if length_change > 0:
            # Insertion: the allele adds bases, so the same number of bases is
            # dropped from the tail to keep the window length unchanged.
            alt_seq = upstream + alt + ref_seq[variant_pos_end:len(ref_seq) - length_change]
        elif length_change < 0:
            # Deletion: the allele removes bases, so the window is refilled
            # with the reference bases that follow the window end.
            missing = -length_change
            extra_bases = reference_fasta.fetch(chrom, ref_end, ref_end + missing)
            alt_seq = upstream + alt + ref_seq[variant_pos_end:] + extra_bases
        else:
            # Equal-length substitution (SNV or multi-base replacement).
            alt_seq = upstream + alt + ref_seq[variant_pos_end:]

        # Per-row trace of the REF/ALT alleles being processed.
        print(ref_nucleotide, alt)
        alt_sequences.append(alt_seq)

    df.insert(loc=4, column='alt_seq', value=alt_sequences)
    return df

In [32]:
def kmers_to_sequence(kmer_str, k=6):
    """
    Rebuild a nucleotide sequence from its whitespace-separated k-mer string.

    Arguments:
    kmer_str -- str of overlapping k-mers, e.g. "AGAAAG GAAAGA AAAGAA ..."
    k -- int, nominal k-mer length (default 6); accepted for interface
         compatibility but not consulted by the reconstruction.

    Returns:
    The first k-mer followed by the final character of every later k-mer
    (consecutive k-mers are taken to overlap by all but one base). An input
    with no k-mers yields the empty string.
    """
    tokens = kmer_str.split()
    if len(tokens) == 0:
        return ""
    first, *following = tokens
    # Each later k-mer contributes exactly one new base: its last character.
    return first + "".join(token[-1] for token in following)

In [33]:
def add_variant_annotations(ax, sequence, attention_scores, start_index, variant_start, variant_end, cmap, variant_length, nu_color='red'):
    """
    Write each nucleotide letter onto a one-row heatmap, highlighting the variant.

    Parameters
    ----------
    ax : axes object the letters are drawn on via ``ax.text``.
    sequence : str, the nucleotides to draw (one letter per heatmap cell).
    attention_scores : sequence of numbers, one score per letter; used only to
        derive each cell's colour so the letter contrasts with it.
    start_index : int, offset of the displayed slice; subtracted from
        ``variant_start`` to locate the variant inside ``sequence``.
    variant_start : int, variant start in the same coordinate system as
        ``start_index``.
    variant_end : accepted for call compatibility; not used by this function.
    cmap : callable mapping a normalised score to an RGBA tuple.
    variant_length : int, number of consecutive letters to highlight.
    nu_color : colour for the highlighted letters. Defaults to 'red' so that
        callers passing only the first eight arguments keep working.

    Letters and scores are paired with ``zip``, so drawing stops at the
    shorter of the two.
    """
    # Normalize the attention_scores to the range of the colormap
    norm = plt.Normalize(vmin=np.min(attention_scores), vmax=np.max(attention_scores))

    # Cell colours, in the same order as the letters
    rgba_colors = [cmap(norm(score)) for score in attention_scores]

    # Indices (within `sequence`) of the letters that belong to the variant
    first_variant_idx = variant_start - start_index
    variant_positions = range(first_variant_idx, first_variant_idx + variant_length)

    for idx, (nucleotide, rgba) in enumerate(zip(sequence, rgba_colors)):
        # Weighted sum of the RGB channels as a brightness estimate of the cell
        luminance = 0.299 * rgba[0] + 0.587 * rgba[1] + 0.114 * rgba[2]
        # White letters on darker cells, black on lighter ones (0.7 is a tunable cut-off)
        text_color = 'white' if luminance < 0.7 else 'black'
        # Variant letters always use the highlight colour
        if idx in variant_positions:
            text_color = nu_color
        # Centre the letter at (idx + 0.5, 0.5) in the axes' data coordinates
        ax.text(idx + 0.5, 0.5, nucleotide, horizontalalignment='center', verticalalignment='center', color=text_color, fontsize=20)

In [38]:
# Draws one two-panel figure (reference vs. alternative) per variant and saves it as a PNG.
# Relies on names that are not defined in this cell: `kmers`, `df_TRIM22`,
# `visualization_path`, and the helper `add_variant_annotations`.
# `kmers` is read as alternating entries: index i is the reference sequence,
# index i + 1 the alternative one, and row i // 2 of `df_TRIM22` describes the pair.

buffer_length = 10  # Nucleotides of context kept on each side of the variant

# Colormap shared by both heatmaps and by the letter-contrast logic (loop-invariant).
cmap = sns.color_palette("YlGnBu", as_cmap=True)

for i in range(0, len(kmers), 2):
    # Lightweight progress indicator
    if i % 20 == 0:
        print(i)

    pair_idx = i // 2
    sequence_ref = kmers[i]
    sequence_alt = kmers[i + 1]

    tfbs_start = df_TRIM22['TFBS_start_position'].iloc[pair_idx]
    tfbs_end = df_TRIM22['TFBS_end_position'].iloc[pair_idx]
    variant_start_pos = df_TRIM22['variant_start_position'].iloc[pair_idx]
    variant_end_pos = df_TRIM22['variant_end_position'].iloc[pair_idx]

    # Variant coordinates relative to the start of the TFBS region
    relative_variant_start = variant_start_pos - tfbs_start
    relative_variant_end = variant_end_pos - tfbs_start

    # Window around the variant, clamped to the sequence boundaries
    start_index = max(0, relative_variant_start - buffer_length)
    end_index = min(len(sequence_ref), relative_variant_end + buffer_length)

    if start_index >= end_index:
        print("Invalid slice range, skipping...")
        continue

    # NOTE(review): the sequences are sliced one position to the right of the
    # attention scores. Kept exactly as in the original; confirm the intended
    # alignment between sequence characters and attention entries.
    sliced_sequence_ref = sequence_ref[start_index + 1:end_index + 1]
    sliced_sequence_alt = sequence_alt[start_index + 1:end_index + 1]
    sliced_attention_ref = df_TRIM22['Ref_attn'].iloc[pair_idx][start_index:end_index]
    sliced_attention_alt = df_TRIM22['Alt_attn'].iloc[pair_idx][start_index:end_index]

    # Same window expressed in the original coordinates, clamped to the TFBS region
    original_start_index = max(tfbs_start, variant_start_pos - buffer_length)
    original_end_index = min(tfbs_end, variant_end_pos + buffer_length)

    # One tick every 5 nucleotides, labelled with the original coordinate
    num_ticks = (original_end_index - original_start_index) // 5 + 1
    tick_positions = [tick * 5 for tick in range(num_ticks)]
    tick_labels = [original_start_index + pos for pos in tick_positions]
    # Drop any tick whose label would lie beyond the plotted end
    tick_positions = [pos for pos, label in zip(tick_positions, tick_labels) if label <= original_end_index]
    tick_labels = [label for label in tick_labels if label <= original_end_index]

    ref_nucleotides = df_TRIM22['ref_nucleotide'].iloc[pair_idx]
    alt_nucleotides = df_TRIM22['alternative_nucleotide'].iloc[pair_idx]
    ref_probab = df_TRIM22['Ref_probab'].iloc[pair_idx]
    alt_probab = df_TRIM22['Alt_probab'].iloc[pair_idx]
    chrom = df_TRIM22['CHROM'].iloc[pair_idx]

    # A fresh figure per variant: reference panel on top, alternative below
    fig, axs = plt.subplots(2, 1, figsize=(18, 3))
    fig.subplots_adjust(hspace=1)

    # Header lines stacked above the panels (y is in figure coordinates, > 1 means above the figure)
    header_lines = [
        (1.30, f"Region: {tfbs_start}-{tfbs_end}"),
        (1.24, f"Variant Region: {variant_start_pos}-{variant_end_pos}"),
        (1.18, f"Region Plotted: {original_start_index}-{original_end_index}"),
        (1.12, f"Reference Neucleotide: {ref_nucleotides}; Alternative Neucleotide: {alt_nucleotides}"),
        (1.06, f"Reference probablity: {ref_probab:.4f}; Alternative probablity: {alt_probab:.4f}"),
    ]
    for y_pos, line in header_lines:
        fig.text(0.5, y_pos, line, horizontalalignment='center', verticalalignment='center')

    panels = [
        (axs[0], "Reference Sequence", sliced_sequence_ref, sliced_attention_ref, ref_nucleotides),
        (axs[1], "Alternative Sequence", sliced_sequence_alt, sliced_attention_alt, alt_nucleotides),
    ]
    for ax, title, panel_sequence, panel_attention, allele in panels:
        sns.heatmap([panel_attention], cmap=cmap, cbar=True, annot=False, ax=ax)
        # The highlight colour is passed explicitly: add_variant_annotations
        # declares it as its ninth parameter.
        add_variant_annotations(
            ax,
            panel_sequence,
            panel_attention,
            start_index,
            relative_variant_start,
            relative_variant_end,
            cmap,
            len(allele),
            'red',
        )
        ax.set_title(title, pad=10)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)

    # Caption below the panels
    fig.text(0.5, -0.2, f"Plotted Region across variant region {chrom}:{original_start_index}-{original_end_index}", ha='center', fontsize=16, fontweight='bold')

    # Save, then close this figure so figures do not pile up in memory across iterations
    output_name = f"Paired_sequence_{chrom}_{original_start_index}_{original_end_index}.png"
    fig.savefig(os.path.join(visualization_path, output_name), bbox_inches='tight')
    plt.close(fig)

G GC
C CG
TG T
GT G
T A
G A
C G
Processing CTCFL
0
GAGCC
CCCCC
CCTGT
GGTGG
CCCTA
AAGCG
CTCCA_CCACAA
G A
T C
TCAAA T
CAAAAAATA C
CAATAAATA C
Processing ZNF426
0
CGCCA
GATTC
CTCAA
CTCAA
CTCAA
C T
G T
Processing SAFB
0
CCGCA
CCGCA
Processing RBM34
G A
TG T
G C
AGC A
AGCGC A
AGCGCGC A
AGCGCGCGC A
CCCTT TC
C T
C G
G A
G C
C CA
G T
C T
G A
G A
C T
Processing TAF15
0
CTCGT
CTGCC
CTGCC
GGACA
GGACA
GGACA
GGACA
CTGCC
CGAGA
GCGGG
20
GTTCC
GCCAG
CCGGT
GGTCG
TGCGA
GGTCG_GTGCT
GGTCG_GGACA
CGAAC
TG T
AGAGGCGGG A
T A
TGGC T
TGGCG T
GGC G
CGGG C
TGAGAGAGA T
CGGGGG C
CGGGGGG C
CGGGGGGG C
CGGGGGGGG C
C CT
ACCCCC A
TTCCC T
C T
TCC T
TCCC T
CG C
GCGCACA G
A T
GC G
C T
A ACCCCCC
C T
G C
C T
CGG C
CGGGG C
CGGGGGG C
C CG
G C
Processing PCBP1
0
GTGGG
AGAGG
CCTCC
TGGCG
TGGCG
TGGCG
TGGCG
CTGAG
GACGG
GACGG
20
GACGG
GACGG
AGCTC
CACCC
TTCCC
GCCCC
TTCCC
TTCCC
TGGCG
GCGCA
40
GCGCA
GGATG
GCCCC
GCCCC
CACCC
CCTCC
AGGAG
GCGGT_TGGCG
GGGGC
GGGGC
60
GGGGC
CCTCC
GGCTT_GGGGC
T G
C CCCT
T G
T G
Processing SRSF3
0
TGGGG_GGGTG
G