In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

In [3]:
import numpy
import torch
import pandas
import matplotlib.pyplot as plt

from data_loading import extract_peaks
from performance_metrics import compute_performance_metrics

In [7]:
expt_name = "stranded"
run = 1

timestamp = "2022-08-16_21-29-37"   # change this to the timestamp for your best model

in_window = 2114
out_window = 1000


proj_root = "/users/kcochran/projects/drosophila_procap/"
sequence_path = proj_root + "refs/genome.fasta"

val_peak_path = proj_root + "/data/overlap_peaks.val.bed.gz"
plus_bw_path = proj_root + "/data/bothreps.bowtie2.filtered.uniq.pos.bigWig"
minus_bw_path = proj_root + "/data/bothreps.bowtie2.filtered.uniq.neg.bigWig"

model_save_dir = proj_root + "models/" + expt_name + "/"
model_path = model_save_dir + timestamp + "_run" + str(run) + ".model"

val_save_dir = proj_root + "model_out/" + expt_name + "/"
os.makedirs(val_save_dir, exist_ok=True) 

pred_counts_path = val_save_dir + timestamp + "_run" + str(run) + "_val.counts.npy"
pred_profiles_path = val_save_dir + timestamp + "_run" + str(run) + "_val.profs.npy"
metrics_path = val_save_dir + timestamp + "_run" + str(run) + "_metrics.tsv"

In [9]:
# Load Model

model = torch.load(model_path)
model.eval()
model = model.cuda()
    
# Load Data

val_sequences, val_profs = extract_peaks(sequence_path, 
    plus_bw_path, minus_bw_path, val_peak_path, in_window, out_window,
    max_jitter=0, verbose=True)

val_profs[val_profs > 5452.0] = 5452.0


# Predict on Validation Set

with torch.no_grad():
    val_sequences = torch.tensor(val_sequences, dtype=torch.float32).cuda()
    pred_profile, pred_counts = model.predict(val_sequences)
    
    
# Save Predictions

numpy.save(pred_profiles_path, pred_profile)
numpy.save(pred_counts_path, pred_counts)


# re-format arrays for performance metrics code

#val_profs = val_profs.reshape(val_profs.shape[0], -1)
val_profs = numpy.swapaxes(numpy.expand_dims(val_profs, 1),2,3)
val_counts = val_profs.sum(axis=2)

#pred_profile = pred_profile.reshape(pred_profile.shape[0], -1)
pred_profile = numpy.swapaxes(numpy.expand_dims(pred_profile, 1),2,3)
pred_counts = numpy.expand_dims(pred_counts, 1)


# Compute Performance Metrics

metrics = compute_performance_metrics(val_profs, pred_profile, 
    val_counts, pred_counts)

metrics_to_save = ["nll", "jsd", "profile_pearson"]
metrics_dict = { metric : list(vals.squeeze()) for metric, vals in metrics.items() if metric in metrics_to_save }
metrics_df = pandas.DataFrame(metrics_dict)
metrics_df.to_csv(metrics_path, sep="\t", index=False)

metrics_to_report = ["nll", "jsd", "profile_pearson", "count_pearson", "count_mse"]
metrics_summary = [str(metrics[metric].mean()) for metric in metrics_to_report] 

print("Peaks: " + val_peak_path)
print("Model: " + model_path)
print("Pred_profiles: " + pred_profiles_path)
print("Pred_counts: " + pred_counts_path)
mean_metrics = ["Mean " + metric + ": " + val for metric, val in zip(metrics_to_report, metrics_summary)]
print("\n".join(mean_metrics))

Reading FASTA: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.53it/s]
Loading Peaks: 4117it [00:04, 1015.15it/s]
  cross_ent = (-log_prob_pows_sum) / trials


Peaks: /users/kcochran/projects/drosophila_procap//data/overlap_peaks.val.bed.gz
Model: /users/kcochran/projects/drosophila_procap/models/stranded/2022-08-16_21-29-37_run1.model
Pred_profiles: /users/kcochran/projects/drosophila_procap/model_out/stranded/2022-08-16_21-29-37_run1_val.profs.npy
Pred_counts: /users/kcochran/projects/drosophila_procap/model_out/stranded/2022-08-16_21-29-37_run1_val.counts.npy
Mean nll: 748.8649046962082
Mean jsd: nan
Mean profile_pearson: 0.2233306852690775
Mean count_pearson: 0.590163
Mean count_mse: 3.5363805
