In [234]:
import ghmm
from collections import OrderedDict
import cPickle as pickle
import numpy as np
from itertools import product as iterproduct, chain
from pprint import pprint
import pysam
import os
import pandas
from copy import deepcopy
import re
import editdistance
import sys
import math
import random
from nbwrapper import getargs
from multiprocessing import Pool
%run "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/lib/alignment_validation.ipynb"
%run "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/lib/alignment_lib.ipynb"
%run "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/lib/analysis_lib.ipynb"
%run "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/lib/feature_lib.ipynb"

In [235]:
args = getargs()

In [236]:
NMERS = 6

In [237]:
# args = {
#     "events" : "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/loman006-1_100.events.template.pickle",
#     "out_basename" : "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/loman006-1_100.called",
#     "ref": "/home/ibis/gregor.sturm/nanopore/NanoporeData/PublicData/LomanLab_MAP-006/ecoli_mg1655.fa",
#     "hmm_params": "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/loman006-1.model.pickle",
#     "corr_model": "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/context_prediction/models/model-test2.pickle",
#     "ncores": 62,
#     "nmers": NMERS,
#     "multivariate": False
# }
args = {
    "events" : "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/wouter_lambda006_100.events.template.pickle",
    "out_basename" : "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/loman006-1_100.called",
    "ref": "/home/ibis/gregor.sturm/nanopore/own/notebooks/03_pipeline/lambda_ref.fasta",
    "hmm_params": "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/loman006-1.model.pickle",
    "corr_model": "/home/ibis/gregor.sturm/nanopore/own/notebooks/05_MAP006-basecaller/context_prediction/models/model-test2.pickle",
    "ncores": 62,
    "nmers": NMERS,
    "multivariate": False
}


In [238]:
NMERS = int(args["nmers"])
NSTATES = 4**NMERS
MULTIVARIATE = bool(int(args["multivariate"]))
args["ncores"] = int(args["ncores"])

In [239]:
HMM_PARAMS = pickle.load(open(args["hmm_params"], 'rb'))
HMM_PARAMS = HMM_PARAMS["/opt/chimaera/model/r7.3_e6_70bps_6mer/template_median68pA.model"]
ALL_KMERS = ["".join(x) for x in iterproduct("ACGT", repeat=NMERS)]
assert HMM_PARAMS["kmer"].tolist() == ALL_KMERS

# Train Model 

In [174]:
def mk_transmat1(nmers):
    """make a transition matrix assuming move=1"""
    n_components = len(ALL_KMERS)
    transmat = np.empty((n_components, n_components))
    for j, from_kmer in enumerate(ALL_KMERS):
        for i, to_kmer in enumerate(ALL_KMERS):
            p = 1/4. if from_kmer[-(NMERS-1):] == to_kmer[:(NMERS-1)] else 0.
            transmat[j, i] = p          
            
    return transmat.tolist()

In [175]:
def mk_transmat0(nmers):
    """make a transition matrix assuming move=0 or move=1"""
    n_components = len(ALL_KMERS)
    transmat = np.empty((n_components, n_components))
    for j, from_kmer in enumerate(ALL_KMERS):
        for i, to_kmer in enumerate(ALL_KMERS):
            p = 0
            if from_kmer[-(NMERS-1):] == to_kmer[:(NMERS-1)]:
                """move=1"""
                p = (9/10.) * (1/4.) 
            elif from_kmer == to_kmer:
                """move=0"""
                p = (1/10.) * 1
            transmat[j, i] = p          
            
    return transmat.tolist()

In [176]:
def mk_transmat2(nmers):
    """make a transition matrix assuming move=0 or move=1 or move=2"""
    n_components = len(ALL_KMERS)
    transmat = np.empty((n_components, n_components))
    for j, from_kmer in enumerate(ALL_KMERS):
        for i, to_kmer in enumerate(ALL_KMERS):
            p = 0
            if from_kmer[-(NMERS-2):] == to_kmer[:(NMERS-2)]:
                """move=2"""
                p = (2/50.) * (1/16.)
            elif from_kmer[-(NMERS-1):] == to_kmer[:(NMERS-1)]:
                """move=1"""
                p = (47/50.) * (1/4.) 
            elif from_kmer == to_kmer:
                """move=0"""
                p = (1/50.) * 1
            transmat[j, i] = p          
            
    return transmat.tolist()

In [177]:
mk_transmat = mk_transmat2

In [178]:
def mk_model_simple(): 
    """ simple model, only taking the means into account. """
    A = mk_transmat(NMERS)
    B = HMM_PARAMS[["level_mean", "level_stdv"]].values.tolist() #mu, std of each state
    pi = [1/float(NSTATES)] * NSTATES   # initial probabilities per state
    # generate model from parameters
    model = ghmm.HMMFromMatrices(F,ghmm.GaussianDistribution(F), A, B, pi)
    return model

In [179]:
F = ghmm.Float()  # emission domain of this model
def mk_model():
    if MULTIVARIATE: 
        return mk_model_multivariate()
    else: 
        return mk_model_simple()

In [180]:
model = mk_model()
s = str(model)
print(s)

GaussianEmissionHMM(N=4096)
  state 0 (initial=0.00, mu=62.78, sigma=0.84)
    Transitions: ->0 (0.00), ->1 (0.00), ->2 (0.00), ->3 (0.00), ->4 (0.00), ->5 (0.00), ->6 (0.00), ->7 (0.00), ->8 (0.00), ->9 (0.00), ->10 (0.00), ->11 (0.00), ->12 (0.00), ->13 (0.00), ->14 (0.00), ->15 (0.00)
  state 1 (initial=0.00, mu=58.02, sigma=0.66)
    Transitions: ->1 (0.02), ->4 (0.23), ->5 (0.23), ->6 (0.23), ->7 (0.23), ->16 (0.00), ->17 (0.00), ->18 (0.00), ->19 (0.00), ->20 (0.00), ->21 (0.00), ->22 (0.00), ->23 (0.00), ->24 (0.00), ->25 (0.00), ->26 (0.00), ->27 (0.00), ->28 (0.00), ->29 (0.00), ->30 (0.00), ->31 (0.00)

  ...

  state 4094 (initial=0.00, mu=45.36, sigma=0.64)
    Transitions: ->4064 (0.00), ->4065 (0.00), ->4066 (0.00), ->4067 (0.00), ->4068 (0.00), ->4069 (0.00), ->4070 (0.00), ->4071 (0.00), ->4072 (0.00), ->4073 (0.00), ->4074 (0.00), ->4075 (0.00), ->4076 (0.00), ->4077 (0.00), ->4078 (0.00), ->4079 (0.00), ->4088 (0.23), ->4089 (0.23), ->4090 (0.23), ->4091 (0.23), ->409

In [181]:
def result_to_seq(result):
    states = result[0]
    kmers = [ALL_KMERS[x] for x in states]
    seq = [kmer[0] for kmer in kmers] + [kmers[-1][1:]]
    return "".join(seq)

In [182]:
def predict(events):
    """mixed is a set of tuples (event_mean, event_stdv)"""
    emissions = [x[0] for x in events]
    seq = ghmm.EmissionSequence(F, emissions)
    result = model.viterbi(seq)
    return result_to_seq(result)

In [183]:
s = model.sampleSingle(10)
s = [x for x in s]
seq = zip([s[i] for i in range(0, len(s), 2)], [s[i] for i in range(1, len(s), 2)])

In [184]:
predict(seq)

'TAGTGGCGCC'

In [185]:
model

<ghmm.GaussianEmissionHMM at 0x7f2f00435210>

## multistep-prediction

In [240]:
from sklearn import ensemble
import joblib
import mltools

In [241]:
corr_model = joblib.load(args["corr_model"])

In [242]:
OFFSET = 20

In [270]:
def get_correction(features):
    correction = corr_model.predict(features)
    i_min = np.argmin([abs(x) for x in correction])
    correction = mltools.normalize_between(correction, -.5, .5)
    shift = correction[i_min]
    correction = [x - shift for x in correction]
    return correction

In [252]:
def get_features(events, seq): 
    corr_range = (OFFSET, len(seq)-OFFSET-NMERS)
 
    features = []
    for i in range(*corr_range):
        mean, stdv = events[i]
        context = seq[i-OFFSET:i+NMERS+OFFSET]
        assert seq[i:i+6] == context[20:26]
        features.append(mk_feature_all(mean, stdv, context))

    features = postprocess_features(features)
    
    return features

In [271]:
def correct_events(events, seq): 
    corr_range = (OFFSET, len(seq)-OFFSET-NMERS)
 
    features = get_features(events, seq)
    correction = get_correction(features)
    for i, j  in enumerate(range(*corr_range)):
        mean, stdv = events[j]
        tmp_mean = mean - correction[i]
        ratio = tmp_mean/mean
        tmp_stdv = ratio * stdv
        events[j] = tmp_mean, tmp_stdv

    return events

In [272]:
def predict_iterative(events, n_steps=5): 
    for _ in range(n_steps):
        seq = predict(events)
        if _ < n_steps-1: 
            events = correct_events(events, seq)
    return seq

In [273]:
file_data = pickle.load(open(args["events"], 'rb'))
file_data = [f for f in file_data if f is not None]

In [274]:
file_obj = correct_read(file_data[0], col="mean")
events = [(x["mean"], x["stdv"]) for x in file_obj["events"].to_dict("records")]

In [275]:
seq = predict(events)

In [276]:
seq[18:40]

'GATTAACGTGGCGCAGATTCGG'

In [277]:
events[18:30]

[(55.39867036449739, 1.3852944632352422),
 (60.686479184018104, 0.8191980317255078),
 (55.58010979370342, 0.6470169739678482),
 (59.338799977196054, 1.2207921358584932),
 (52.37617320636567, 0.5578358994299949),
 (50.49390191853866, 0.6509070383102106),
 (53.193770634269384, 0.583154757922479),
 (58.19354351144337, 0.9033973004060635),
 (63.71277754231278, 0.6658554436211722),
 (61.8072902034906, 0.5674972037985707),
 (67.49056713916954, 1.3460100491010998),
 (65.73566348616725, 0.6411588641045136)]

In [278]:
features = get_features(events, seq)
[x for x in enumerate(features[0])]

[(0, 0.4489413386974811),
 (1, 0.10150871820044287),
 (2, 0),
 (3, 1),
 (4, 0),
 (5, 0),
 (6, 0),
 (7, 1),
 (8, 0),
 (9, 0),
 (10, 0),
 (11, 1),
 (12, 0),
 (13, 0),
 (14, 0),
 (15, 0),
 (16, 0),
 (17, 1),
 (18, 0),
 (19, 0),
 (20, 0),
 (21, 1),
 (22, 0),
 (23, 0),
 (24, 0),
 (25, 1),
 (26, 0),
 (27, 1),
 (28, 0),
 (29, 0),
 (30, 0),
 (31, 0),
 (32, 1),
 (33, 0),
 (34, 0),
 (35, 0),
 (36, 0),
 (37, 1),
 (38, 0),
 (39, 0),
 (40, 0),
 (41, 1),
 (42, 0),
 (43, 0),
 (44, 0),
 (45, 1),
 (46, 0),
 (47, 0),
 (48, 0),
 (49, 1),
 (50, 0),
 (51, 0),
 (52, 1),
 (53, 0),
 (54, 0),
 (55, 0),
 (56, 1),
 (57, 0),
 (58, 0),
 (59, 0),
 (60, 0),
 (61, 1),
 (62, 0),
 (63, 0),
 (64, 1),
 (65, 0),
 (66, 0),
 (67, 1),
 (68, 0),
 (69, 0),
 (70, 0),
 (71, 0),
 (72, 0),
 (73, 1),
 (74, 0),
 (75, 0),
 (76, 1),
 (77, 0),
 (78, 1),
 (79, 0),
 (80, 0),
 (81, 0),
 (82, 0),
 (83, 0),
 (84, 0),
 (85, 1),
 (86, 0),
 (87, 0),
 (88, 0),
 (89, 1),
 (90, 1),
 (91, 0),
 (92, 0),
 (93, 0),
 (94, 1),
 (95, 0),
 (96, 0),
 (97,

In [279]:
correction = corr_model.predict(features)
i_min = np.argmin([abs(x) for x in correction])
correction = mltools.normalize_between(correction, -.5, .5)
shift = correction[i_min]
correction = [x - shift for x in correction]
correction

[0.039291781364663991,
 0.057211576490652394,
 0.046722957474960303,
 -0.0048799989126879773,
 0.0048976354112869602,
 0.039863685604725807,
 0.068242534056340287,
 0.097276967164115657,
 0.12564511413966734,
 0.087245911786453467,
 0.086135591116433707,
 0.027252272942833666,
 0.00014055641369381533,
 0.014026946177669586,
 0.082784726803055708,
 0.0059986423749062157,
 0.051074668952371838,
 0.12229809548929171,
 0.19610895227276554,
 0.11060871392137234,
 0.20060761956812218,
 0.0096344240790190439,
 0.15626639798217962,
 0.20568505606225507,
 0.0055133055088059657,
 0.15450940442474048,
 0.077429156575923197,
 0.044322001084346196,
 0.029421104534374021,
 0.32552676156322496,
 0.18681732351401009,
 0.073406573826504895,
 0.018419158235339128,
 0.0043922835484535017,
 0.21054453347493735,
 0.099765728890016114,
 0.12058103528437419,
 0.29121888489927544,
 0.00023365807981751141,
 0.032242720701280869,
 0.19760685086892255,
 0.060483882287316049,
 0.040557046473721936,
 0.02660392398

In [280]:
events = correct_events(events, seq)
events[18:30]

[(55.39867036449739, 1.3852944632352422),
 (60.686479184018104, 0.8191980317255078),
 (55.540818012338761, 0.64655957203800807),
 (59.281588400705402, 1.2196151076292958),
 (52.329450248890709, 0.55733827347888953),
 (50.498781917451346, 0.65096994542410991),
 (53.188872998858095, 0.58310106593263489),
 (58.153679825838644, 0.9027784560500316),
 (63.64453500825644, 0.66514224817528633),
 (61.710013236326489, 0.56660403396896186),
 (67.36492202502987, 1.3435042235699004),
 (65.648417574380801, 0.64030790304727048)]

# Validate Model 

In [281]:
!pwd

/home/ibis/gregor.sturm/nanopore/own/notebooks


In [282]:
assert os.path.isfile(args["events"])

In [283]:
ref = load_ref(args["ref"])

['>gi|556503834|ref|NC_000913.3| Escherichia coli str. K-12 substr. MG1655, complete genome']
AGCTTTTCATTCTGACTGCAACGGGCAATATGTCTCTGTGTGGATTAAAAAAAGAGTGTCTGATAGCAGCTTCTGAACTGGTTACCTGCCGTGAGTAAAT


In [284]:
file_data = pickle.load(open(args["events"], 'rb'))
file_data = [f for f in file_data if f is not None]

In [285]:
prepare_filemap(file_data)

In [286]:
def basecall_read(file_obj):
    file_obj = correct_read(file_obj, col="mean")
    events = [(x["mean"], x["stdv"]) for x in file_obj["events"].to_dict("records")]
    called_seq = predict_iterative(events)
    return (file_obj["channel"], file_obj["file_id"], called_seq)

In [287]:
# """ train with baum-welch """
# for i, file_obj in enumerate(file_data): 
#     sys.stdout.write('\rdone {0:%}'.format(i/float(len(file_data))))
#     train_read(file_obj)

In [288]:
p = Pool(args["ncores"])

In [None]:
#basecall_read(file_data[3])

In [None]:
print("Prediction: ")
results = []
try:
    for i, res in enumerate(p.imap_unordered(basecall_read, file_data), 1):
        results.append(res)
        sys.stdout.write('\rdone {0:%}'.format(i/float(len(file_data))))
    p.close()
    p.join()
except KeyboardInterrupt:
    p.terminate()

### Stats

In [None]:
types = ["metrichor", "called", "random"]
fasta_files = {t: "{0}.{1}.fa".format(args["out_basename"], t) for t in types}

In [None]:
## metrichor fasta
with open(fasta_files["metrichor"], 'w') as f: 
    for file_obj in file_data: 
        f.write(">ch{0}_file{1}_metrichor".format(file_obj["channel"], file_obj["file_id"])+ "\n")
        f.write(file_obj["fastq"].split("\n")[1] + "\n")

In [None]:
## called fasta/random fasta
with open(fasta_files["called"], 'w') as f: 
    with open(fasta_files["random"], 'w') as fr:
        for channel, file_id, seq in results: 
            f.write(">ch{0}_file{1}_called".format(channel, file_id)+ "\n")
            fr.write(">ch{0}_file{1}_random".format(channel, file_id)+ "\n")
            f.write(seq + "\n")
            fr.write("".join([random.choice("ACGT") for _ in range(len(seq))]))

In [None]:
for t in types: 
    sam_file = "{0}.{1}.sam".format(args["out_basename"], t)
    graphmap(args["ref"], fasta_files[t], sam_file, args["ncores"])
    prepare_sam("{0}.{1}".format(args["out_basename"], t))

In [None]:
def mk_stat(t):
    samfile = "{0}.{1}.sorted.bam".format(args["out_basename"], t)
    sst = samstats(samfile, ref, ncores=args["ncores"])
    return pandas.DataFrame(sst.print_summary())

In [None]:
stats = map(mk_stat, types)
print(types)
side_by_side(*stats)

In [None]:
# for t, df in zip(types, stats):
#     with open("{0}.stats.{1}.html".format(args["out_basename"], t), 'w') as f:
#         f.write(df.to_html())

In [None]:
# def score_consensus(t):
#     consensus = mk_consensus("{0}.{1}.sorted.bam".format(args["out_basename"], t), ref_file)
#     return(consensus)
#     consensus = consensus.split("\n")[1].to_upper()
#     score = needle(ref, consensus)
#     return (consensus, score)

In [None]:
# p = Pool(args["ncores"])
# try:
#     consensus = p.map(score_consensus, types)
#     p.close()
# except KeyboardInterrupt:
#     p.terminate()

In [None]:
# consensus

In [None]:
# mk_consensus("{0}.{1}.sorted.bam".format(args["out_basename"], "metrichor"), ref_file)