In [24]:
import torch
import torch.nn as nn

import numpy as np
from collections import defaultdict

%load_ext autoreload
%autoreload 1
    
import sys
# Put the local dnaseq2seq checkout at the front of the import path so the
# project modules imported below resolve to it. The guard must be "not in":
# inserting only when the entry is already present never adds it on a fresh
# kernel and duplicates it otherwise.
if "/Users/brendanofallon/src/jovian/dnaseq2seq" not in sys.path:
    sys.path.insert(0, "/Users/brendanofallon/src/jovian/dnaseq2seq")
    
import vcf
import util
import loader
import call

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
# --- Model files ---
# Model loaded below with call.load_model.
modelpath = "/Users/brendanofallon/data/jovian/100M_s28_cont_mapsus_lolr2_epoch4.model"
# Second model file; not referenced by any of the cells visible in this notebook.
classifier = "/Users/brendanofallon/data/jovian/s28ce40_bamfix.model"

# --- Sequencing inputs ---
# Note: despite the variable name, this path points at a .cram file.
bam = "/Users/brendanofallon/data/WGS/99702111878_NA12878_1ug.cram"
# Reference FASTA passed alongside the alignment file when encoding a region.
refpath = "/Users/brendanofallon/data/ref/human_g1k_v37_decoy_phiXAdaptr.fasta.gz"


In [14]:
# Load the model stored at `modelpath` using the project's `call` module.
# The result is later passed to predict_sequence, which calls its
# encode() and decode() methods.
model = call.load_model(modelpath)

In [8]:
# Options forwarded unchanged to call.encode_and_save_region.
encode_opts = dict(
    region=("21", 0, 22504100, 22504110),
    max_read_depth=150,
    window_size=150,
    min_reads=5,
    batch_size=16,
    window_step=25,
)

# The returned value is handed straight to torch.load, so it is presumably
# the path of a file written under "." -- confirm against call.py.
datapath = call.encode_and_save_region(bam, refpath, ".", **encode_opts)

# The loaded object is indexed by the keys 'start_positions' and
# 'encoded_pileup' in later cells.
data = torch.load(datapath)

In [39]:
# All-zero token of shape (1, util.FEATURE_DIM) used to seed decoding.
# Stored as float64 (what dtype=float means to torch); predict_sequence
# casts it to float32 before use.
START_TOKEN = torch.zeros(1, util.FEATURE_DIM, dtype=torch.float64)
def predict_sequence(src, model, reftoks, n_output_toks, device):
    """
    Generate a predicted sequence with next-word prediction by repeatedly calling the model.

    Decoding is greedy: at every step the highest-scoring class of the last
    decoded position is one-hot encoded and appended to the running target,
    which is then fed back into ``model.decode``.

    Parameters:
        src: Input batch handed to ``model.encode``; ``src.shape[0]`` is the
            batch size. It is moved to ``device`` before encoding.
        model: Object providing ``encode(src)`` and
            ``decode(mem, tgt, tgt_mask=...)``.
        reftoks: Unused. Kept so existing call sites keep working.
        n_output_toks: Controls the number of decoding steps. Note that the
            loop runs ``n_output_toks + 1`` times, so that many positions are
            returned.
        device: Device on which the decoding tensors are placed.

    Returns:
        A tuple ``(predictions, probs)``:
        ``predictions`` has shape ``(batch, 2, n_output_toks + 1, n_classes)``
        and holds the one-hot encoded picks (start token removed), where
        ``n_classes`` is the width of ``START_TOKEN``;
        ``probs`` has shape ``(batch, 2, n_output_toks + 1)`` and holds the
        maximum decoder output value for each pick.

    For pure inference, callers may want to wrap the call in
    ``torch.no_grad()``; this function does not disable autograd itself.
    """
    batch_size = src.shape[0]
    # Every other tensor used here is placed on `device`, so do the same for the input.
    src = src.to(device)

    # Two start tokens stacked on a new leading axis -> (2, 1, n_classes),
    # then broadcast across the batch -> (batch, 2, 1, n_classes).
    predictions = (
        torch.stack((START_TOKEN, START_TOKEN), dim=0)
        .expand(batch_size, -1, -1, -1)
        .float()
        .to(device)
    )
    # Placeholder column matching the start token; stripped before returning.
    probs = torch.zeros(batch_size, 2, 1).float().to(device)

    # Derive the class count from the start token rather than hard-coding it,
    # so the one-hot width always matches the tensor it is concatenated onto.
    n_classes = predictions.shape[-1]

    mem = model.encode(src)
    for _ in range(n_output_toks + 1):
        # The target grows by one position per step, so the mask is rebuilt each time.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(predictions.shape[-2]).to(device)
        # Only the output for the most recent position is needed.
        new_preds = model.decode(mem, predictions, tgt_mask=tgt_mask)[:, :, -1:, :]
        new_probs, tophit = torch.max(new_preds, dim=-1)
        # one_hot yields int64; cast explicitly instead of relying on promotion in concat.
        p = torch.nn.functional.one_hot(tophit, num_classes=n_classes).to(predictions.dtype)
        predictions = torch.concat((predictions, p), dim=2)
        probs = torch.concat((probs, new_probs), dim=-1)

    # Drop the leading start token / placeholder column.
    return predictions[:, :, 1:, :], probs[:, :, 1:]


In [40]:
# Inspect the 'start_positions' entry of the loaded data. The values shown
# below are 25 apart, which matches window_step=25 in the encoding call.
data['start_positions']

[22503995, 22504020, 22504045, 22504070]

In [41]:
enc = data['encoded_pileup']
# Keep only the first entry along dim 0. Slicing with 0:1 rather than
# indexing with 0 preserves that axis, so the result is still 4-D.
e = enc[0:1, :, :, :]
e.shape

torch.Size([1, 150, 150, 10])

In [42]:
# Decode on CPU with n_output_toks=10. The third argument (reftoks) is not
# used inside predict_sequence, so None is passed.
preds, probs = predict_sequence(e, model, None, 10, 'cpu')

In [44]:
# Dim 2 has 11 positions for n_output_toks=10 because predict_sequence
# loops n_output_toks + 1 times.
preds.shape

torch.Size([1, 2, 11, 260])

In [45]:
# For batch item 0 and index 1 of the size-2 axis, take the argmax over the
# last axis to get one class index per decoded position.
ptoks = torch.argmax(preds[0, 1, :, :], dim=-1)

In [46]:
# Display the predicted class indices.
ptoks

tensor([ 15,  28,  67, 252, 204, 201, 204, 204, 225,  28,  92])

In [47]:
# Convert the class indices to a base string with the project's util helper;
# util.i2s is presumably an index-to-string lookup -- confirm in util.py.
# The output shown has 44 characters for 11 tokens, i.e. 4 bases per token.
pseq = util.kmer_idx_to_str(ptoks, util.i2s)
pseq

'AATTACTACAATTTTATATATAGCTATATATATGACACTACCTA'