In [321]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import math
import copy
import pysam
import scipy.stats as stats
import random
%matplotlib inline

In [322]:
torch.__version__

'1.9.0'

In [323]:
if hasattr(torch, 'cuda') and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [337]:

STOP_TOKEN_INDEX = 4

INDEX_TO_BASE = [
    'A', 'C', 'G', 'T'
]

def base_index(base):
    if base=='A':
        return 0
    elif base=='C':
        return 1
    elif base=='G':
        return 2
    elif base=='T':
        return 3
    raise ValueError("Expected [ACTG]")
    
def update_from_base(base, tensor):
    if base=='A':
        tensor[0] = 1
    elif base=='C':
        tensor[1] = 1
    elif base=='G':
        tensor[2] = 1
    elif base=='T':
        tensor[3] = 1
    elif base=='N':
        tensor[0:3] = 0.25
    elif base=='-':
        tensor[0:3] = 0.0
    return tensor

def encode_basecall(base, qual, cigop):
    ebc = torch.zeros(6)
    ebc = update_from_base(base, ebc)
    ebc[4] = qual / 100 - 0.5
    ebc[5] = cigop
    return ebc

def encode_cigop(readpos, refpos):
    if readpos == refpos:
        return 0
    elif readpos is None:
        return -1
    elif refpos is None:
        return 1
    return 0

def variantrec_to_tensor(rec):
    seq = []
    for readpos, refpos in rec.get_aligned_pairs():
        if readpos is not None and refpos is not None:
            seq.append(encode_basecall(rec.query_sequence[readpos], rec.query_qualities[readpos], encode_cigop(readpos, refpos)))
        elif readpos is None and refpos is not None:
            seq.append(encode_basecall('-', 50, encode_cigop(readpos, refpos)))  # Deletion
        elif readpos is not None and refpos is None:
            seq.append(encode_basecall(rec.query_sequence[readpos], rec.query_qualities[readpos], encode_cigop(readpos, refpos)))  # Insertion
        
    return torch.vstack(seq)


def string_to_tensor(bases):
    return torch.vstack([encode_basecall(b, 50, 0) for b in bases])


def target_string_to_tensor(bases, include_stop=True):
    """
    The target version doesn't include the qual or cigop features
    """
    result = torch.tensor([base_index(b) for b in bases]).long()
    return result

In [349]:
from torch.nn.utils.rnn import pad_sequence

def pad_zeros(pre, data, post):
    if pre:
        prepad = torch.zeros(pre, data.shape[-1])
        data = torch.cat((prepad, data))
    if post:
        postpad = torch.zeros(post, data.shape[-1])
        data = torch.cat((data, postpad))
    return data
    
def tensors_from_seq(refseq, numreads, readlen, error_rate=0.0, align_to_ref=True):
    seqs = []
    for i in range(numreads):
        startpos = random.randint(0, len(refseq) - readlen)
        if align_to_ref:
            seqs.append(
                pad_zeros(startpos, 
                      string_to_tensor(mutate_seq(refseq[startpos:startpos+readlen], error_rate)),
                      len(refseq) - startpos - readlen)
            )
        else:
            seqs.append(
                string_to_tensor(mutate_seq(refseq[startpos:startpos+readlen], error_rate)),
            )
    return torch.stack(seqs)


def mutate_seq(seq, error_rate):
    output = []
    for b in seq:
        if np.random.rand() < error_rate:
            c = random.choice('ACTG')
            while c == b:
                c = random.choice('ACTG')
        else:
            c = b
        output.append(c)
    return "".join(output)
                

def random_bases(n):
    return "".join(random.choices("ACTG", k=n))

In [392]:
def pad_and_tensorize_reads(reads):
    minref = min(r.reference_start for r in reads)
    maxref = max(r.reference_start + r.query_length for r in reads)
    reflen = maxref - minref
    print(f"{minref} - {maxref}")
    tensors = [
        pad_zeros(r.reference_start - minref, variantrec_to_tensor(r), maxref - (r.reference_start + r.query_length))
        for r in reads
    ]
    for r, t in zip(reads, tensors):
        print(f"{r.cigarstring} {r.reference_start - minref}, {maxref - r.reference_end} \t{t.shape}")
    return torch.stack(tensors)

pad_and_tensorize_reads(reads).shape

27023285 - 27023435
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 0, 4 	torch.Size([150, 6])
146M 1, 3 	torch.Size([150, 6])
146M 1, 3 	torch.Size([150, 6])
146M 1, 3 	torch.Size([150, 6])
146M 1, 3 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 2, 2 	torch.Size([150, 6])
146M 3, 1 	torch.Size([150, 6])
146M 3, 1 	torch.Siz

torch.Size([50, 150, 6])

In [390]:
def tensor_pileup(bam, chrom, start, maxnumreads):
    bamit = bam.fetch(chrom, start)
    reads = []
    num_reads = 50
    while len(reads) < num_reads:
        reads.append(next(bamit))
    return pad_and_tensorize_reads(reads)
    

In [400]:
tensor_pileup(bam, "1", 27023400, 50)

27023255 - 27023404
146M 0, 3 	torch.Size([149, 6])
146M 0, 3 	torch.Size([149, 6])
146M 0, 3 	torch.Size([149, 6])
121M3D25M 0, 0 	torch.Size([152, 6])
146M 0, 3 	torch.Size([149, 6])
146M 0, 3 	torch.Size([149, 6])
146M 0, 3 	torch.Size([149, 6])
146M 0, 3 	torch.Size([149, 6])
146M 0, 3 	torch.Size([149, 6])
121M3D25M 0, 0 	torch.Size([152, 6])
146M 0, 3 	torch.Size([149, 6])
146M 0, 3 	torch.Size([149, 6])
146M 0, 3 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 1, 2 	torch.Size([149, 6])
146M 2, 1 	torch.Size([149, 6])
146M 2, 1 	torch.Size([149, 6])
146M 2, 1 

RuntimeError: stack expects each tensor to be equal size, but got [149, 6] at entry 0 and [152, 6] at entry 3

In [338]:
target_string_to_tensor("ACTG")

tensor([0, 1, 3, 2])

In [425]:
bam = pysam.AlignmentFile("/home/brendan/Public/genomics/onccn_15_car641/bam/roi.bam")
bamit = bam.fetch("1", 27022800)

reads = []
num_reads = 500
while len(reads) < num_reads:
    reads.append(next(bamit))

    

In [430]:
def get_read_by_name(bam, name):
    results = []
    for read in bam:
        if read.query_name == name:
            results.append(read)
    return results
readhits = get_read_by_name(reads, "A00576:57:HCN2WDRXX:1:1115:28646:35070:CTGTG+ATCTC")
print(readhits)

[<pysam.libcalignedsegment.AlignedSegment object at 0x7f6cd738ae20>, <pysam.libcalignedsegment.AlignedSegment object at 0x7f6cd7386f40>]


In [427]:
variantrec_to_tensor(read).shape

torch.Size([146, 6])

In [474]:
def iterate_cigar(rec):
    cigtups = rec.cigartuples
    bases = rec.query_sequence
    quals = rec.query_qualities
    cig_index = 0
    n_bases_cigop = cigtups[cig_index][1]
    cigop = cigtups[cig_index][0]
    is_ref_consumed = cigop in {0,   2,    4, 5, 7} # 2 is deletion
    is_seq_consumed = cigop in {0, 1,   3, 4, 5, 7} # 1 is insertion, 3 is 'ref skip'
    base_index = 0
    refstart = rec.query_alignment_start - (n_bases_cigop if cigop in {4,5} else 0)
    refpos = refstart
    while True:
        reftok = refpos if is_ref_consumed else "-"
        if is_seq_consumed:
            base = bases[base_index]
            qual = quals[base_index]
            base_index += 1
        else:
            base = "-"
        if is_ref_consumed:
            refpos += 1
    
        print(f"{base}\t{reftok}\t cig op: {cigop} num bases left in cig op: {n_bases_cigop}")
        encoded_cig = 0
        if is_ref_consumed and is_seq_consumed:
            encoded_cig = 0
        elif is_ref_consumed and not is_seq_consumed:
            encoded_cig = -1
        else:
            encoded_cig = 1
        yield encode_basecall(base, qual, encoded_cig), is_ref_consumed
        n_bases_cigop -= 1
        if n_bases_cigop <= 0:
            cig_index += 1
            if cig_index >= len(cigtups):
                break
            n_bases_cigop = cigtups[cig_index][1]
            cigop = cigtups[cig_index][0]
            is_ref_consumed = cigop in {0, 2, 4, 5, 7}
            is_seq_consumed = cigop in {0, 1, 3, 4, 5, 7}

def rec_tensor_it(read, minref, maxref):
    for i in range(read.reference_start - minref):
        yield encode_basecall('-', 50, 0), True

    try:
        for t in iterate_cigar(read):
            yield t 
    except StopIteration:
        pass
    
    for i in range(maxref - read.reference_start - read.query_length):
        yield encode_basecall('-', 50, 0), True
        
print(read.cigarstring)
# for t in rec_tensor_it(read, read.reference_start-5, read.reference_start + read.query_length+5):
#     print(t)

27S28M2I89M


In [None]:


def iterate_aligned_reads(reads):
    EMPTY_TENSOR = encode_basecall('-', 50, 0)
    minref = min(r.reference_start for r in reads)
    maxref = max(r.reference_start + r.query_length for r in reads)
    its = [rec_tensor_it(r, minref, maxref) for r in reads]
    refpos = minref
    pos_tensors = [next(it) for it in its]
    
    while refpos < maxref:
        any_insertion = any(r[1] for r in pos_tensors)
        thispos = []
        while any_insertion:
            for i, (it, pos_tensor) in enumerate(zip(its, pos_tensors)):
                if ins:
                    thispos.append(pos_tensor[0])
                    pos_tensors[i] = next(it)
                else:
                    thispos.append(EMPTY_TENSOR)
        
            any_insertion = any(r[1] for r in pos_tensors)
    
        all_stacked = torch.stack(thispos)
        refpos += 1
        for i, (it, pos_tensor) in enumerate(zip(its, pos_tensors)):
            thispos.append(pos_tensor[0])
            pos_tensors[i] = next(it)
        
        
        
    

In [368]:


def make_batch(batchsize, seqlen, readsperbatch, readlength, error_rate=0):
    src = []
    tgt = []
    for i in range(batchsize):
        seq = random_bases(seqlen)
        reads = tensors_from_seq(seq, readsperbatch, readlength, error_rate, align_to_ref=True)
        src.append(reads)
        tgt.append(target_string_to_tensor(seq))
    return torch.stack(src).transpose(1,2).to(device), torch.stack(tgt).to(device)

In [799]:



def make_het_snv_batch(batchsize, seqlen, readsperbatch, readlength, error_rate):
    src = []
    tgt = []
    for i in range(batchsize):
        seq = random_bases(seqlen)
        reads, altseq = make_het_snv(seq, readlength, readsperbatch, 0.5, error_rate)
        src.append(reads)
        
        
        alt_t = target_string_to_tensor(altseq)
        seq_t = target_string_to_tensor(seq)
        x = torch.stack((seq_t, alt_t))
        tgt.append(x)
    return torch.stack(src).transpose(1,2).to(device), torch.stack(tgt).to(device)
    
src, tgt = make_het_snv_batch(5, 25, 20, 16, 0)
print(to_pileup(src[0, :, :]))

...ATGCTTGTAGCACTAC......
.GAATGCTTGTAGTACT........
.........GTAGCACTACGATGAT
....TGCTTGTAGCACTACG.....
....TGCTTGTAGCACTACG.....
.....GCTTGTAGTACTACGA....
..AATGCTTGTAGCACTA.......
.....GCTTGTAGCACTACGA....
.GAATGCTTGTAGCACT........
.......TTGTAGCACTACGATG..
TGAATGCTTGTAGCAC.........
.....GCTTGTAGTACTACGA....
.....GCTTGTAGCACTACGA....
.GAATGCTTGTAGTACT........
..AATGCTTGTAGTACTA.......
.....GCTTGTAGTACTACGA....
......CTTGTAGTACTACGAT...
TGAATGCTTGTAGTAC.........
.....GCTTGTAGCACTACGA....
....TGCTTGTAGTACTACG.....


In [1015]:

def stack_refalt_tensrs(refseq, altseq, readlength, totreads, vaf=0.5):
    assert len(refseq) == len(altseq), f"Sequences must be the same length (got {len(refseq)} and {len(altseq)})"
    num_altreads = stats.binom(totreads - 2, vaf).rvs(1)[0] + 1
    reftensors = tensors_from_seq(refseq, totreads-num_altreads, readlength, error_rate)
    alttensors = tensors_from_seq(altseq, num_altreads, readlength, error_rate)
    combined = torch.cat([reftensors, alttensors])
    idx = np.random.permutation(totreads)
    combined[range(totreads)] = combined[idx]
    return combined, altseq

def make_het_snv(seq, readlength, totreads, vaf, error_rate):
    snvpos = random.choice(range(max(0, len(seq) // 2 - 8), min(len(seq), len(seq) // 2 + 8)))
    altseq = list(seq)
    altseq[snvpos] = random.choice('ACTG')
    while altseq[snvpos] == seq[snvpos]:
        altseq[snvpos] = random.choice('ACTG')
    altseq = "".join(altseq)
    return stack_refalt_tensrs(seq, altseq, readlength, totreads, vaf)

def make_het_del(seq, readlength, totreads, vaf=0.5):
    del_len = random.choice(range(10))
    delpos = random.choice(range(max(0, len(seq) // 2 - 8), min(len(seq) - del_len, len(seq) // 2 + 8)))
    ls = list(seq)
    for i in range(del_len):
        del ls[delpos]
    altseq = "".join(ls + ["A"] * del_len)
    return stack_refalt_tensrs(seq, altseq, readlength, totreads, vaf)


def make_het_ins(seq, readlength, totreads, vaf=0.5):
    ins_len = random.choice(range(10)) + 1
    inspos = random.choice(range(max(0, len(seq) // 2 - 8), min(len(seq) - ins_len, len(seq) // 2 + 8)))
    altseq = "".join(seq[0:inspos]) + "".join(random.choices("ACTG", k=ins_len)) + "".join(seq[inspos:-ins_len])
    altseq = altseq[0:len(seq)]
    return stack_refalt_tensrs(seq, altseq, readlength, totreads, vaf)
    

def make_het_ins_batch(batchsize, seqlen, readsperbatch, readlength, error_rate):
    src = []
    tgt = []
    for i in range(batchsize):
        seq = [b for b in random_bases(seqlen)]
        reads, altseq = make_het_ins(seq, readlength, readsperbatch, vaf=0.5)
        src.append(reads)
        alt_t = target_string_to_tensor(altseq)
        seq_t = target_string_to_tensor(seq)
        x = torch.stack((seq_t, alt_t))
        tgt.append(x)
    return torch.stack(src).transpose(1,2).to(device), torch.stack(tgt).to(device)



def make_het_del_batch(batchsize, seqlen, readsperbatch, readlength, error_rate):
    src = []
    tgt = []
    for i in range(batchsize):
        seq = [b for b in random_bases(seqlen)]
        reads, altseq = make_het_del(seq, readlength, readsperbatch, vaf=0.5)
        src.append(reads)
        alt_t = target_string_to_tensor(altseq)
        seq_t = target_string_to_tensor(seq)
        x = torch.stack((seq_t, alt_t))
        tgt.append(x)
    return torch.stack(src).transpose(1,2).to(device), torch.stack(tgt).to(device)
    
    

# reads, seq = make_het_del([b for b in random_bases(20)], 15, 10)
src, tgt = make_het_ins_batch(5, 30, 10, 20, error_rate=0)
# print(to_pileup(src[0, :, :, :]))
# # src, tgt = make_het_ins("AAAAAAAAAAAAAAA", 8, 10)
# # print(to_pileup(src[:, :, :].transpose(0,1)))
src.shape

torch.Size([5, 30, 10, 6])

In [1017]:
def make_mixed_batch(size, seqlen, readsperbatch, readlength, error_rate):
    snv_w = 9 # Bigger values here equal less variance among sizes
    del_w = 8
    ins_w = 8
    mix = np.random.dirichlet((snv_w, del_w, ins_w)) * size
    snv_src, snv_tgt = make_het_snv_batch(int(mix[0]), seqlen, readsperbatch, readlength, error_rate)
    del_src, del_tgt = make_het_del_batch(int(mix[1]), seqlen, readsperbatch, readlength, error_rate)
    ins_src, ins_tgt = make_het_ins_batch(int(mix[1]), seqlen, readsperbatch, readlength, error_rate)
    return torch.cat((snv_src, del_src, ins_src)), torch.cat((snv_tgt, del_tgt, ins_tgt))
src, tgt = make_mixed_batch(100, 30, 10, 18, 0)
src.shape

torch.Size([95, 30, 10, 6])

In [942]:
def sort_by_ref(seq, reads):
    results = []
    for batch in range(reads.shape[0]):
        w = reads[batch, :, :, 0:4].sum(dim=-1)
        t = reads[batch, :, :, 0:4].argmax(dim=-1)
        matchsum = (t == (seq[batch, 0, :].repeat(reads.shape[2], 1).transpose(0,1)*w).long()).sum(dim=0)
        results.append(reads[batch, :, torch.argsort(matchsum), :])
    return torch.stack(results)

print(to_pileup(src[1,:,:,:]))
srt_result = sort_by_ref(tgt, src)
print("Sorted:")
print(to_pileup(srt_result[1, :, :, :]))


.........AATGGGCGAAATCGGGCCAC.
AGCCGCATCAATGGGCGAAA..........
......ATCAATGGGCGATCGGGCCA....
....GCATCAATGGGCGATCGGGC......
AGCCGCATCAATGGGCGAAA..........
.......TCAATGGGCGATCGGGCCAC...
...CGCATCAATGGGCGATCGGG.......
.GCCGCATCAATGGGCGAAAT.........
........CAATGGGCGATCGGGCCACT..
....GCATCAATGGGCGATCGGGC......
Sorted:
........CAATGGGCGATCGGGCCACT..
.......TCAATGGGCGATCGGGCCAC...
......ATCAATGGGCGATCGGGCCA....
....GCATCAATGGGCGATCGGGC......
....GCATCAATGGGCGATCGGGC......
...CGCATCAATGGGCGATCGGG.......
AGCCGCATCAATGGGCGAAA..........
.........AATGGGCGAAATCGGGCCAC.
AGCCGCATCAATGGGCGAAA..........
.GCCGCATCAATGGGCGAAAT.........


In [930]:
torch.argsort(matchsum)

tensor([1, 2, 4, 3, 9, 0, 8, 7, 5, 6], device='cuda:0')

In [442]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [1041]:

class TwoHapDecoder(nn.Module):
    
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, 128)
        self.fc2 = nn.Linear(128, out_dim)
        self.fc3 = nn.Linear(128, out_dim)
        self.elu = nn.ELU()
        self.softmax = nn.Softmax(dim=2)
        
    def forward(self, x):
        x = self.elu(self.fc1(x))
        x1 = self.softmax(self.fc2(x))
        x2 = self.softmax(self.fc3(x))
        return x1, x2

class VarTransformer(nn.Module):
    
    def __init__(self, in_dim, out_dim, nhead=6, d_hid=256, n_encoder_layers=2, p_dropout=0.1):
        super().__init__()
        self.embed_dim = nhead * 12 # Was 24
        self.fc1 = nn.Linear(in_dim, self.embed_dim)
        self.pos_encoder = PositionalEncoding(self.embed_dim, p_dropout)
        encoder_layers = nn.TransformerEncoderLayer(self.embed_dim, nhead, d_hid, p_dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, n_encoder_layers)
        self.decoder = TwoHapDecoder(self.embed_dim, out_dim)
        self.elu = torch.nn.ELU()
    
    def forward(self, src):
        src = self.elu(self.fc1(src))
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src)
        output = self.decoder(output)
        return output
        

In [1042]:
num_reads = 16
feats_per_read = 6 
ref_seq_length = 30
read_length = 15
src, tgt = make_het_snv_batch(5, seqlen=ref_seq_length, readsperbatch=num_reads, readlength=read_length, error_rate=0)
print(src.shape)
vt = VarTransformer(in_dim=num_reads * feats_per_read, out_dim=4).to(device)
seq1, seq2 = vt(src.flatten(start_dim=2))
print(seq1.shape)
print(seq2.shape)

torch.Size([5, 30, 16, 6])
torch.Size([5, 30, 4])
torch.Size([5, 30, 4])


In [1043]:
tgt.shape

torch.Size([5, 2, 30])

In [1044]:
tgt_seq1 = tgt[:, 0, :]
tgt_seq1.shape

torch.Size([5, 30])

In [1051]:
ref_seq_length = 100
read_length = 70

In [1046]:
num_reads = 25
feats_per_read = 6 
in_dim = num_reads * feats_per_read
print(f"Input dimension: {in_dim}")
model = VarTransformer(in_dim=in_dim, out_dim=4, nhead=5, d_hid=200, n_encoder_layers=2).to(device)


Input dimension: 150


In [1047]:
print(f"Model has {sum(p.numel() for p in model.parameters() if p.requires_grad)} params")

Model has 96180 params


In [1048]:
criterion = nn.CrossEntropyLoss()
# mbloss = MatchingBasesLoss()
# divloss = nn.KLDivLoss()
lr = 0.001 # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)


In [1052]:
steps = 50
batch_size = 128
error_rate = 0.02


for step in range(steps):    
    batchsrc, batchtgt = make_mixed_batch(batch_size * 10, seqlen=ref_seq_length, readsperbatch=num_reads, readlength=read_length, error_rate=error_rate)
#     batchsrc, batchtgt = make_het_del_batch(batch_size * 10, seqlen=ref_seq_length, readsperbatch=num_reads, readlength=read_length, error_rate=0.02)
    batchoffset = 0
    epoch_loss_sum = 0
    while batchoffset < batchsrc.shape[0]:
        unsorted_src = batchsrc[batchoffset:batchoffset + batch_size, :, :, :]
        tgt = batchtgt[batchoffset:batchoffset + batch_size, :, :]
        src = sort_by_ref(tgt, unsorted_src)
        batchoffset += batch_size
        optimizer.zero_grad()

        tgt_seq1 = tgt[:, 0, :]
        tgt_seq2 = tgt[:, 1, :]

        seq1preds, seq2preds = model(src.flatten(start_dim=2))

        loss = criterion(seq1preds.flatten(start_dim=0, end_dim=1), tgt_seq1.flatten())
        loss += 2*criterion(seq2preds.flatten(start_dim=0, end_dim=1), tgt_seq2.flatten())

        with torch.no_grad():
            matches1 = (torch.argmax(seq1preds.flatten(start_dim=0, end_dim=1), dim=1) == tgt_seq1.flatten()).float().mean()
            matches2 = (torch.argmax(seq2preds.flatten(start_dim=0, end_dim=1), dim=1) == tgt_seq2.flatten()).float().mean()
            
        loss.backward(retain_graph=True)
        optimizer.step()
        epoch_loss_sum += loss.detach().item()
        
        
    print(f"Step: {step} loss: {(epoch_loss_sum) / 10:.4f} frac ref matching bases ref / alt: {matches1.item():.4f} / {matches2.item():.4f}" )
    epoch_loss_sum = 0
    
    

Step: 0 loss: 2.2867 frac ref matching bases ref / alt: 0.9726 / 0.9743
Step: 1 loss: 2.2918 frac ref matching bases ref / alt: 0.9645 / 0.9776
Step: 2 loss: 2.2908 frac ref matching bases ref / alt: 0.9749 / 0.9766
Step: 3 loss: 2.7463 frac ref matching bases ref / alt: 0.9744 / 0.9804
Step: 4 loss: 2.7422 frac ref matching bases ref / alt: 0.9730 / 0.9819
Step: 5 loss: 2.5140 frac ref matching bases ref / alt: 0.9708 / 0.9777
Step: 6 loss: 2.2879 frac ref matching bases ref / alt: 0.9658 / 0.9808
Step: 7 loss: 2.7435 frac ref matching bases ref / alt: 0.9721 / 0.9769
Step: 8 loss: 2.2844 frac ref matching bases ref / alt: 0.9753 / 0.9782
Step: 9 loss: 1.6007 frac ref matching bases ref / alt: 0.9653 / 0.9747
Step: 10 loss: 2.0583 frac ref matching bases ref / alt: 0.9726 / 0.9794
Step: 11 loss: 2.9750 frac ref matching bases ref / alt: 0.9766 / 0.9773
Step: 12 loss: 2.2863 frac ref matching bases ref / alt: 0.9713 / 0.9776
Step: 13 loss: 2.5132 frac ref matching bases ref / alt: 0.97

KeyboardInterrupt: 

tensor(0.7090, device='cuda:0')

In [1050]:
which = 2
def readstr(t):
    t = t.detach().cpu().numpy()
    assert len(t.shape) == 2, "Need two dimensional input"
    output = []
    for pos in range(t.shape[0]):
        if t[pos, :].sum() == 0:
            output.append(".")
        else:
            output.append(INDEX_TO_BASE[np.argmax(t[pos, :])])

    return "".join(output)

def to_pileup(data):
    pileup = []
    for i in range(data.shape[1]):
        pileup.append(readstr(data[:, i, :]))
    return "\n".join(pileup)

def predprobs(t):
    t = t.detach().cpu().numpy()
    output = []
    for pos in range(t.shape[0]):
        if t[pos, :].sum() == 0:
            output.append(".")
        else:
            output.append(f"{t[pos, np.argmax(t[pos, :])]:.1}f")
    return "".join(output)

def correctstr(seq, predseq):
    seq = "".join(INDEX_TO_BASE[b] for b in seq)
    output = "".join('*' if a==b else 'x' for a,b in zip(seq, predseq))
    return output
    

print(to_pileup(src[which, :, :, :]))

print("")
predstr = readstr(seq1preds[which, :, :])
print(predstr)
print(correctstr(tgt_seq1[which, :], predstr))
predstr = readstr(seq2preds[which, :, :])
print(predstr)
print(correctstr(tgt_seq2[which, :], predstr))


....................CGACACGTCTTTAATGCAAAATTAACCTGA
...................GCGACACGTCTTTAAAGCAAAATTAACCTG.
...............TTGGGCGACACGTCTTTAAAGCAAAATTAA.....
..............CTTGGGCGACACGTCTTTAAAGCAAAATTA......
............CGCTTGGGCGACACGTCTTTAAAGCAAAAT........
..........AGCGCTTGGGCGACACGTCTTTAAAGCAAA..........
......ATGAAGCGCTTGGGGGACACGTCTTTAAAG..............
......ATGAAGCGCTTGGGCGACACGTCTTCAAAG..............
...CGCATTAAGCGCTTGGGCGACACGTCTTTA.................
ACCCGCATCAAGCGCTTGGGCGACACGTCT....................
..CCGCATGAAGCGCTTGGGCGACACGTCTTT..................
...........GCGCAAGGGCGACACGTTTTAAAGCAAAAT.........
ACCCCCATGAAGCGCTTGGGCGACACGTTT....................
.CCCGCATGAATCGCTTGGGCGACACGTTTT...................
..............CTTGGGCGACACGTTTTAATGCAAAATTAA......
...............TTGGGCGACACGTTTTAAAGCAATATTAAC.....
......ATGAAGCGCTTGGGCGACACATTTTAAAGC..............
................TGGGCGACACGTTTTAAAGCAAAATGAACC....
...CGCATGAAGCGCTTGGGCAACACGTTTTAA.................
..CCGCATGAAGCGCTTGGGCGACACGTTTT

In [512]:
t = make_het_snv('AAAAAAAAAAAAAAAAAAAA', readlength=10, totreads=10, vaf=0.5, error_rate=0).transpose(0,1)
print(to_pileup(t))

AttributeError: 'tuple' object has no attribute 'transpose'

In [460]:
criterion(predictions.flatten(start_dim=0, end_dim=1), tgt.flatten())

tensor(1.3756, device='cuda:0', grad_fn=<NllLossBackward>)

In [1055]:
x = torch.rand(2, 149)
z = torch.zeros(2,1)
torch.cat((x,z), dim=1).shape

torch.Size([2, 150])