In [1]:
import random
import math
import warnings
import pandas as pd
from utils import EnhancerDataset, split_dataset, train_model, regression_model_plot, plot_filter_weight
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import sys
sys.path.append('../model')  
from model import ConvNetDeep, DanQ, ExplaiNN,ConvNetDeep2, ExplaiNN2, ExplaiNN3

In [2]:

def _motif_pair_score(span_a, span_b, directionality_is_true):
    """Score one (motifA occurrence, motifB occurrence) pair.

    Args:
        span_a: (start, end) of a motifA occurrence, end exclusive.
        span_b: (start, end) of a motifB occurrence, end exclusive.
        directionality_is_true: when truthy, a pair whose motifB starts
            before its motifA scores 0.

    Returns:
        0 for a wrongly-ordered pair under directionality, otherwise
        exp(1 / gap) for a gap of at least one base, or exp(1 / 0.8) when the
        two occurrences touch or overlap.
    """
    start_a, end_a = span_a
    start_b, end_b = span_b

    if directionality_is_true and start_b < start_a:
        return 0

    if end_a <= start_b:
        gap = start_b - end_a
    elif end_b <= start_a:
        gap = start_a - end_b
    else:
        # The two occurrences overlap. Inserted motifs are never placed on top
        # of each other, so this only happens with chance occurrences in the
        # random background. Treat it like the closest non-overlapping case
        # (gap 0) instead of raising, which is what the old fractional
        # "distance" always ended up doing.
        gap = 0

    if gap >= 1:
        return math.exp(1 / gap)
    # gap == 0: exp(1 / 0) is undefined, so 0.8 stands in as "closer than 1".
    return math.exp(1 / 0.8)


def generate_random_dna(length, num_dna, motifA, motifB, directionality_is_true):
    """Generate random DNA sequences and calculate the motif distance score.

    Every sequence is `length` random bases drawn from A/T/C/G. motifA and
    motifB are each inserted with probability 1/2, independently, by
    overwriting bases at a random position (the sequence length is unchanged).
    When both are inserted, motifB is re-drawn until it does not overlap the
    inserted motifA.

    Args:
        length: number of bases per sequence.
        num_dna: number of sequences to generate.
        motifA: first motif string.
        motifB: second motif string.
        directionality_is_true: when truthy, a motifB occurrence starting
            before the motifA occurrence it is paired with scores 0.

    Returns:
        A tuple (dna_sequences, scores, proportion):
        dna_sequences is a list of `num_dna` strings, scores is the matching
        list of scores (0 unless both motifs were inserted), and proportion is
        a dict counting how many sequences fell into each insertion category.

    Raises:
        ValueError: if len(motifA) + len(motifB) exceeds length - 10.

    Warns:
        UserWarning if either motif is longer than length // 3.
    """
    if (len(motifA) + len(motifB)) > (length - 10):
        raise ValueError('Length of motif A plus motif B greater then dna length minus 10')
    if len(motifA) > length // 3 or len(motifB) > length // 3:
        warnings.warn("One of the motifs is longer than one-third of the DNA length.")

    dna_sequences = []
    scores = []
    proportion = {'no_A_no_B': 0, 'no_A_has_B': 0, 'has_A_no_B':0, 'has_A_has_B':0, 'A_before_B': 0, 'B_before_A': 0}

    nucleotides = ['A', 'T', 'C', 'G']

    for _ in range(num_dna):
        # Random background sequence.
        dna = ''.join(random.choices(nucleotides, k=length))

        # Decide independently whether each motif gets inserted.
        insert_motifA = random.choice([True, False])
        insert_motifB = random.choice([True, False])

        index_a = -1
        index_b = -1

        if insert_motifA:
            # Overwrite a random window of the background with motifA.
            index_a = random.randint(0, length - len(motifA))
            dna = dna[:index_a] + motifA + dna[index_a + len(motifA):]

        if insert_motifB:
            while True:
                index_b = random.randint(0, length - len(motifB))
                # Only reject overlapping positions when motifA was really
                # placed; comparing against the index_a == -1 sentinel used to
                # forbid motifB from the first len(motifA) - 1 positions.
                if (not insert_motifA
                        or index_b + len(motifB) <= index_a
                        or index_a + len(motifA) <= index_b):
                    break
            dna = dna[:index_b] + motifB + dna[index_b + len(motifB):]

        dna_sequences.append(dna)

        # Tally which insertion category this sequence belongs to.
        if insert_motifA and insert_motifB:
            proportion['has_A_has_B'] += 1
            if index_a < index_b:
                proportion['A_before_B'] += 1
            elif index_a > index_b:
                proportion['B_before_A'] += 1
        elif insert_motifA:
            proportion['has_A_no_B'] += 1
        elif insert_motifB:
            proportion['no_A_has_B'] += 1
        else:
            proportion['no_A_no_B'] += 1

        # Calculate the motif distance score. Only sequences with both motifs
        # inserted are scored; chance occurrences alone never score.
        score = 0
        if insert_motifA and insert_motifB:
            motifs_a = list(find_all_motifs(dna, motifA))
            motifs_b = list(find_all_motifs(dna, motifB))
            # NOTE(review): every pair overwrites `score`, so the value kept is
            # the one for the last motifA occurrence paired with the last
            # motifB occurrence, chance occurrences included. This is the
            # pre-existing behaviour; confirm whether an aggregate such as the
            # closest pair was intended.
            for span_a in motifs_a:
                for span_b in motifs_b:
                    score = _motif_pair_score(span_a, span_b, directionality_is_true)

        scores.append(score)
    return dna_sequences, scores, proportion

def find_all_motifs(sequence, motif):
    """Yield a (start, end) pair for every occurrence of `motif` in `sequence`.

    `end` is exclusive (start + len(motif)). Occurrences may overlap, because
    each new search resumes one position after the previous hit's start.
    """
    span = len(motif)
    hit = sequence.find(motif)
    while hit != -1:
        yield (hit, hit + span)
        # Resume just past this hit's first character so overlaps are found.
        hit = sequence.find(motif, hit + 1)

# Example usage
# Seed the stdlib RNG first so the synthetic dataset, and the category counts
# printed below, come out the same on every Restart & Run All. 42 matches the
# `seed` value the training cell passes to split_dataset / train_model.
random.seed(42)
dna_sequences, scores, portions = generate_random_dna(length=608, num_dna=20000, motifA='ATG', motifB='CGT', directionality_is_true=False)
print(portions)

{'no_A_no_B': 5024, 'no_A_has_B': 4927, 'has_A_no_B': 5087, 'has_A_has_B': 4962, 'A_before_B': 2502, 'B_before_A': 2460}


In [3]:
# Define some hyperparameters
seed = 42
batch = 200
num_cnns = 90
learning_rate = 2e-4
target_labels = ['Motif Distance Score']
# NOTE(review): absolute, machine-specific path; point this at a local
# directory before running elsewhere.
output_dir = '/pmglocal/ty2514/Enhancer/Enhancer/data/ExplaiNN_synthetic_motif_results'

# One row per generated sequence with its motif distance score.
df = pd.DataFrame({'sequence': dna_sequences, 'score': scores})

# Take the model's input length from the data instead of repeating the
# generator's `length` as a magic number, and fail early on ragged input.
sequence_lengths = df['sequence'].str.len()
assert sequence_lengths.nunique() == 1, "all sequences must have the same length"
input_length = int(sequence_lengths.iloc[0])

# Keep the split DataFrames under their own names so `train` / `test` only
# ever refer to the Dataset objects.
train_df, test_df = split_dataset(df, split_type='random', cutoff=0.8, seed=seed)

train = EnhancerDataset(train_df, label_mode='score', scale_mode='none')
test = EnhancerDataset(test_df, label_mode='score', scale_mode='none')

# DataLoader setup: shuffle only the training data; evaluation order is kept
# fixed so test batches are identical from epoch to epoch.
train_loader = DataLoader(dataset=train, batch_size=batch, shuffle=True)
test_loader = DataLoader(dataset=test, batch_size=batch, shuffle=False)

input_model = ExplaiNN3(num_cnns=num_cnns, input_length=input_length, num_classes=1,
                        filter_size=19, num_fc=2, pool_size=7, pool_stride=7,
                        drop_out=0.3, weight_path=None)

# Training
_, _, model, train_losses_by_batch, test_losses_by_batch, results, best_pearson_epoch, best_r2_epoch, device = train_model(
    input_model, train_loader, test_loader,
    target_labels=target_labels, num_epochs=200,
    batch_size=batch, learning_rate=learning_rate,
    criteria='mse', optimizer_type="adam", patience=10,
    seed=seed, save_model=True, dir_path=output_dir)

In [None]:
# Evaluate the checkpoint for the epoch train_model reported as best by R^2.
# The path is built from `output_dir` (the dir_path given to train_model) so
# the two locations cannot drift apart.
# Presumably train_model writes one `model_epoch_<n>.pth` per epoch there; confirm.
model_path = f'{output_dir}/model_epoch_{best_r2_epoch}.pth'

# NOTE(review): the datasets were built with label_mode='score' but this call
# passes label_mode="distance"; confirm that is the value
# regression_model_plot expects here.
mse, rmse, mae, r2, pearson_corr, spearman_corr = regression_model_plot(
    model, test_loader, train_losses_by_batch, test_losses_by_batch,
    device, results, label_mode="distance", save_plot=False, dir_path=None,
    model_path=model_path, best_model=best_r2_epoch)

print(f"MSE: {mse:.4f}, RMSE: {rmse:.4f}, MAE: {mae:.4f}")
print(f"R^2: {r2:.4f}, Pearson Correlation: {pearson_corr:.4f}, Spearman Correlation: {spearman_corr:.4f}")