In [5]:
import sys
import pickle
import matplotlib.pyplot as plt

### Panel B,C

#### Download validation data:
* `wget https://public.gi.ucsc.edu/brookslab/addseq/231205_valdata_pred_out.pkl`

In [7]:
# Load the validation predictions (the file fetched with the wget command above).
pred_out = '231205_valdata_pred_out.pkl'
# SECURITY NOTE: pickle.load can execute arbitrary code stored in the file.
# Only load this pickle if it was obtained from the trusted source listed above.
with open(pred_out, 'rb') as pred_outf:
    seq_preds = pickle.load(pred_outf)
# Print a short summary instead of the whole object: the cells below iterate
# over every entry of seq_preds, so dumping it would flood the notebook output.
print("Loaded Python object:", type(seq_preds).__name__, "with", len(seq_preds), "entries")

In [1]:
# Pooled accuracy, precision and recall over every individual prediction.
# The first element of each seq_preds key is used as the ground-truth class
# (0 or 1); the value is the collection of scores for that sequence, and each
# score is rounded to obtain the predicted class.
# NOTE(review): `np` and `tqdm` are not imported in the import cell visible at
# the top of this notebook — presumably they come from another cell or a star
# import; TODO confirm so that Restart & Run All works.
print('computing accuracy...')
correct = {0: 0, 1: 0}  # number of predictions matching the label, per class
total = {0: 0, 1: 0}    # number of predictions seen, per class

for seq_id in tqdm(seq_preds):
    truth = seq_id[0]
    called = np.round(np.array(seq_preds[seq_id]))
    n_called = len(called)
    # Every prediction of a sequence shares the sequence-level label.
    expected = np.zeros(n_called) if truth == 0 else np.ones(n_called)
    correct[truth] += np.sum(called == expected)
    total[truth] += n_called

n_predictions = float(total[0] + total[1])
accuracy = (correct[0] + correct[1]) / n_predictions

# Confusion-matrix cells: a hit on class 0 is a true negative, a hit on
# class 1 a true positive; the misses of each class are the complement.
true_negatives, true_positives = correct[0], correct[1]
false_negatives = total[1] - true_positives
false_positives = total[0] - true_negatives

precision = true_positives / float(true_positives + false_positives)
recall = true_positives / float(true_positives + false_negatives)

report = (
    ("True negatives:", true_negatives),
    ("True positives:", true_positives),
    ("False negatives:", false_negatives),
    ("False positives:", false_positives),
    ("Accuracy:", accuracy),
    ("Precision:", precision),
    ("Recall:", recall),
)
for caption, value in report:
    print(caption, value)


In [None]:
# Plot prediction mean and std for each validation sequence
# One point per sequence, grouped by the label stored in the first element of
# the sequence key (0 -> 'negative', 1 -> 'positive').
seq_means = {0: [], 1: []}
seq_stds = {0: [], 1: []}
for seq_id in tqdm(seq_preds):
    label = seq_id[0]
    seq_means[label].append(np.mean(seq_preds[seq_id]))
    seq_stds[label].append(np.std(seq_preds[seq_id]))
fig = plt.figure(figsize=(5,4))
plt.scatter(seq_means[0], seq_stds[0], label='negative')
plt.scatter(seq_means[1], seq_stds[1], label='positive')
plt.legend()
plt.xlabel('Prediction Mean')
plt.ylabel('Prediction Std')
# Save BEFORE showing: plt.show() can release the current figure (it does under
# the notebook inline backend), after which savefig would write an empty canvas.
fig.savefig('/private/groups/brookslab/gabai/projects/addseq_manuscript/figures/Figure4_resnet1d_val_mean_std.pdf', dpi = 1000)
plt.show()
plt.close(fig)

# Scatter of per-sequence accuracy against sequence length, one series per class.
# Accuracy of a sequence = fraction of its rounded predictions that equal the
# label stored in the first element of the sequence key.
print('Plot accuracy by sequence length for each sequence in validation set...')
seq_lens = {0: [], 1: []}
seq_accs = {0: [], 1: []}
for seq_id in tqdm(seq_preds):
    seq_len = len(seq_preds[seq_id])
    label = seq_id[0]
    if label == 0:
        label_arr = np.zeros(seq_len)
    else:
        label_arr = np.ones(seq_len)
    pred_arr = np.round(np.array(seq_preds[seq_id]))
    correct_arr = (pred_arr == label_arr)
    # Use a dedicated name for the per-sequence count: the accuracy cell keeps
    # its per-class totals in a dict called `correct`, which this loop used to
    # overwrite with a scalar.
    n_correct = np.sum(correct_arr)
    seq_acc = float(n_correct) / seq_len
    seq_lens[label].append(seq_len)
    seq_accs[label].append(seq_acc)
fig = plt.figure(figsize=(5,4))
plt.scatter(seq_lens[0], seq_accs[0], label='negative')
plt.scatter(seq_lens[1], seq_accs[1], label='positive')
plt.legend()
plt.xlabel('Sequence Length')
plt.ylabel('Sequence Accuracy')
plt.savefig('/private/groups/brookslab/gabai/projects/addseq_manuscript/figures/Figure4_resnet1d_val_accurracy_vs_seqlen.pdf', dpi = 1000)
plt.close()

# Compute ROC curve
# All predictions are pooled across sequences; each prediction inherits the
# label stored in the first element of its sequence key.
print('computing roc...')
pred_list = []
label_list = []
for seq_id in tqdm(seq_preds):
    preds = seq_preds[seq_id]
    # Constant label vector of the same length as the sequence's predictions.
    make_labels = np.zeros if seq_id[0] == 0 else np.ones
    pred_list.append(preds)
    label_list.append(make_labels(len(preds)))

pred_cat = np.concatenate(pred_list)
label_cat = np.concatenate(label_list)

fpr, tpr, thresholds = roc_curve(label_cat, pred_cat)
roc_auc = auc(fpr, tpr)

fig = plt.figure(figsize=(5,4))
# Diagonal drawn first as the reference line, then the measured curve.
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.plot(
    fpr,
    tpr,
    color="darkorange",
    lw=2,
    label="AUC = %0.2f" % roc_auc,
)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title("Receiver operating characteristic (ROC)", size = 'medium')
plt.legend(loc="lower right")
plt.savefig('/private/groups/brookslab/gabai/projects/addseq_manuscript/figures/Figure4_resnet1d_val_roc.pdf', dpi = 1000)
plt.close()

# Report the threshold at which (TPR - FPR) is largest.
tpr_minus_fpr = tpr - fpr
best_idx = np.argmax(tpr_minus_fpr)
print('best cutoff:', thresholds[best_idx])

# Calculate kernel density estimate
# Pool the predicted scores of all sequences into one list per class
# (label 0 -> negative control, anything else -> positive control).
pos_reads = []
neg_reads = []
for seq_id in tqdm(seq_preds):
    pooled = neg_reads if seq_id[0] == 0 else pos_reads
    pooled.extend(seq_preds[seq_id])

pos_kde = gaussian_kde(pos_reads)
neg_kde = gaussian_kde(neg_reads)

# Evaluation grid for the x-axis (100 points from -0.01 to 1.01)
pos_values = np.linspace(-0.01,1.01, 100)
neg_values = np.linspace(-0.01,1.01, 100)

# Draw both density curves on one figure
fig = plt.figure(figsize=(5,4))
for grid, kde, caption in (
    (pos_values, pos_kde, 'Pos Ctrl'),
    (neg_values, neg_kde, 'Neg Ctrl'),
):
    plt.plot(grid, kde(grid), label=caption)

# Axis labels, title and legend
plt.xlabel('Predicted scores')
plt.ylabel('Density')
plt.title('Density of predicted scores')
plt.legend()

# Write the figure to disk and release it
plt.savefig('/private/groups/brookslab/gabai/projects/addseq_manuscript/figures/Figure4_resnet1d_val_density.pdf', dpi = 1000)
plt.close()

### Panel D

In [12]:
import sys
# Put the NEMO src directory at the front of the module search path so that
# `plotUtil` can be imported from it.
sys.path.insert(0, '/private/groups/brookslab/gabai/tools/NEMO/src/')
# NOTE(review): the star import hides which names plotUtil provides.
# `plotAllTrack` (called below) presumably comes from here — TODO confirm and
# switch to explicit imports.
from plotUtil import *

In [None]:
# Inputs for the Panel D track plot.
cln2_pred  = '../data/240116_resnetv1_CLN2_step20_chrXVI:66000-67600_all_prediction.tsv'
# Query region in 'chrom:start-end' form; `bins` below is derived from it.
region = 'chrXVI:66000-67600'
# Second region passed to plotAllTrack as `pregion` — presumably the window
# that is actually drawn; TODO confirm against plotUtil.
pregion = 'chrXVI:66400-67550'
step = 20
# Split 'chrom:start-end' once instead of re-parsing the string per field.
chrom, span = region.split(':')
qstart, qend = (int(coord) for coord in span.split('-'))
# Bin start coordinates covering [qstart, qend) in increments of `step`.
# The original referenced the undefined names qStart/qEnd here (NameError).
bins = np.arange(qstart, qend, step)

In [None]:
# Draw the Panel D tracks for the CLN2 locus and save the figure under ../figures/.
# NOTE(review): `gtfFile` is not defined in the cells visible here — presumably
# it is provided by the plotUtil star import or another cell; TODO confirm.
plotAllTrack(
    cln2_pred,
    gtfFile,
    # output location
    outpath = '../figures/',
    prefix ='240117_resnetv1_CLN2_3_clusters_seed42_step20',
    savefig = True,
    # genomic window and binning
    bins = bins,
    region = region,
    pregion = pregion,
    step=20,
    vlines = {'tss':66800, 'enh': 67200, 'cds': 66614},
    # clustering options
    method = 'pca',
    ncluster = 3,
    seed=42,
    # figure appearance
    trackHeight = 0.8,
    colorRange=(0.3, 0.55, 0.6),
    fig_size=(6,4),
    track_ylim_adjust=10,
)