In [83]:
import pysam
import bisect
from pyarrow import feather, Table
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchsummary import summary
from tqdm import tqdm
from nanopore_dataset import create_sample_map
from nanopore_dataset import create_splits
from nanopore_dataset import load_csv
from nanopore_dataset import NanoporeDataset

from resnet1d import ResNet1D
import seaborn as sns

In [128]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

In [129]:
model = ResNet1D(
            in_channels=1,
            base_filters=128,
            kernel_size=3,
            stride=2,
            groups=1,
            n_block=8,
            n_classes=2,
            downsample_gap=2,
            increasefilter_gap=4,
            use_do=False)
summary(model, (1, 400), device= device)

ADDSEQ_FN = './nanopore_classification/best_models/addseq_resnet1d.pt'

weights_path = ADDSEQ_FN
model.load_state_dict(torch.load(weights_path, map_location=torch.device(device)))
model.to(device)
model.eval()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1             [-1, 128, 400]             512
   MyConv1dPadSame-2             [-1, 128, 400]               0
       BatchNorm1d-3             [-1, 128, 400]             256
              ReLU-4             [-1, 128, 400]               0
            Conv1d-5             [-1, 128, 400]          49,280
   MyConv1dPadSame-6             [-1, 128, 400]               0
       BatchNorm1d-7             [-1, 128, 400]             256
              ReLU-8             [-1, 128, 400]               0
            Conv1d-9             [-1, 128, 400]          49,280
  MyConv1dPadSame-10             [-1, 128, 400]               0
       BasicBlock-11             [-1, 128, 400]               0
      BatchNorm1d-12             [-1, 128, 400]             256
             ReLU-13             [-1, 128, 400]               0
           Conv1d-14             [-1, 1

ResNet1D(
  (first_block_conv): MyConv1dPadSame(
    (conv): Conv1d(1, 128, kernel_size=(3,), stride=(1,))
  )
  (first_block_bn): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (first_block_relu): ReLU()
  (basicblock_list): ModuleList(
    (0): BasicBlock(
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU()
      (do1): Dropout(p=0.5, inplace=False)
      (conv1): MyConv1dPadSame(
        (conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,))
      )
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU()
      (do2): Dropout(p=0.5, inplace=False)
      (conv2): MyConv1dPadSame(
        (conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,))
      )
      (max_pool): MyMaxPool1dPadSame(
        (max_pool): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (1): BasicBlock(
      (bn1)

In [208]:
samfile_pos = pysam.AlignmentFile("/scratch/gabai/addseq/data/ctrl/unique.500.pass.sorted.bam", "rb")
parse_bam_pos = {}
for line in samfile_pos:
    read = str(line)
    readID = read.split('\t')[0]
    strandID = int(read.split('\t')[1])
    if strandID == 0:
        strand = 1
    elif strandID == 16:
        strand = -1
    else:
        strand = 0
    parse_bam_pos[readID] = strand
samfile_pos.close()

In [191]:
samfile_neg = pysam.AlignmentFile("/scratch/gabai/addseq/data/ctrl/unique.0.pass.sorted.bam", "rb")
parse_bam_neg = {}
for line in samfile_neg:
    read = str(line)
    readID = read.split('\t')[0]
    strandID = int(read.split('\t')[1])
    if strandID == 0:
        strand = 1
    elif strandID == 16:
        strand = -1
    else:
        strand = 0
    parse_bam_neg[readID] = strand
samfile_neg.close()

In [194]:
eventAlign_pos = '/gicephfs/brookslab/bsaintjo/220516_ang_conc_unique/unique.500.eventalign.tsv'

In [281]:
def modPredict(read, chrom, start, seq, signals, siganlLengthList, strandInt, 
               signalWindow = 400, downScores = 500, seqLength = 80):
    
    strand = {
        "strand": strandInt,
    }

    metadata = {
        "chrom": str(chrom),
        "length": int(len(seq)),
        "name": str(read),
        # seq will always be an empty string
        "seq": "addseq_pos",
        "start": int(start),
        # True is positive strand, False is negative strand
        "strand": strand,
    }

    scores = []
    outScore = {
        # left most position represent the position
        # ie T is 101, A is 102, T is 103, T is 104, etc.
        "kmer": "",
        "pos": 0,
        # For your case, "score" and "signal_score" should be the same
        # This is where you will put in the scores output by the method
        # you are working with.
        "score": 0,
        "signal_score": 0,
        # For your case, skip_score is always 0.0
        "skip_score": 0.0,
        # For your case, skipped will always be set to False
        "skipped": False,
    }

    input_tensor = torch.zeros((1, 1, 400)).to(device)
    sequence_tensor = torch.tensor(signals)

    for pos in tqdm(range(len(signals)-signalWindow)):
        input_tensor[:, :, :] = sequence_tensor[pos:pos+signalWindow]
        start = bisect.bisect_right(siganlLengthList, pos)
        end = start + seqLength
#         end = bisect.bisect_left(siganlLengthList, pos+signalWindow)
        kmer = seq[start:end]
        prob = model(input_tensor).sigmoid().item()
        # either the very first kmer or the new kmer
        if (kmer, start) != (outScore["kmer"], outScore["pos"]):
            # store kmer
            if outScore["kmer"]:
                scores.append({
                    # left most position represent the position
                    # ie T is 101, A is 102, T is 103, T is 104, etc.
                    "kmer": str(outScore["kmer"]),
                    "pos": int(outScore["pos"]),
                    # For your case, "score" and "signal_score" should be the same
                    # This is where you will put in the scores output by the method
                    # you are working with.
                    "score": float(np.max(probList)),
                    "signal_score": float(np.max(probList)),
                    # For your case, skip_score is always 0.0
                    "skip_score": 0.0,
                    # For your case, skipped will always be set to False
                    "skipped": False,
                })
                if len(scores) >= downScores:
                    outScore["kmer"] = ""
                    break
            outScore["kmer"] = kmer
            outScore["pos"] = start
            probList = [prob]
        else:
            probList.append(prob)
    if outScore["kmer"]:
        scores.append({
            # left most position represent the position
            # ie T is 101, A is 102, T is 103, T is 104, etc.
            "kmer": str(outScore["kmer"]),
            "pos": int(outScore["pos"]),
            # For your case, "score" and "signal_score" should be the same
            # This is where you will put in the scores output by the method
            # you are working with.
            "score": float(np.max(probList)),
            "signal_score": float(np.max(probList)),
            # For your case, skip_score is always 0.0
            "skip_score": 0.0,
            # For your case, skipped will always be set to False
            "skipped": False,
        })
    
    return (metadata, scores)

In [288]:
def alignScore(eventAlign, downReads = 10, downScore = 5000):
    
    cawlr_read = []
    readID = ''
    sequence = ''
    signalLength = 0
    siganlLengthList = []
    with open(eventAlign, 'r') as inFile:
        header = inFile.readline()
        for line in inFile:
            line = line.strip()
            read = line.split('\t')[3]
            # the very first read
            if readID != read:
                if sequence:
                    # Store information for the last read
                    metadata, scores = modPredict(read=readID, chrom=chrom, start=start,
                                                  seq = sequence, signals = siganlList,
                                                  siganlLengthList=siganlLengthList, strandInt=strand, 
                                                  downScores = 50)
                    cawlr_read.append({"metadata": metadata, "scores": scores})
                    if len(cawlr_read) > downReads:
                        sequence = ''
                        break
                    # Set variables back to initial state
                    readID = ''
                    sequence = ''
                    siganlList = []
                    signalLength = 0
                    siganlLengthList = []
                readID = read
                chrom = line.split('\t')[0]
                start = line.split('\t')[1]
                kmer = line.split('\t')[2]
                signals = [float(i) for i in line.split('\t')[13].split(',')]
                siganlList = signals
                signalLength = len(signals)
                strand = 'NA'
                if readID in parse_bam_pos:           
                    strand = parse_bam_pos[readID]
                sequence += kmer
            # next kmer within the same read
            else:
                signals = [float(i) for i in line.split('\t')[13].split(',')]
                siganlList += signals
                signalLength += len(signals)
                # different kmer
                #  (kmer1, chrom1, start1) = (kmer0, chrom0, start0)
                if (line.split('\t')[2], line.split('\t')[0], line.split('\t')[1]) != (kmer, chrom, start):
                    kmer = line.split('\t')[2]
                    sequence += kmer[-1]
                    siganlLengthList.append(signalLength)
        if sequence:
            metadata, scores = modPredict(read=readID, chrom=chrom, start=start,
                                          seq = sequence, signals = siganlList,
                                          siganlLengthList=siganlLengthList, strandInt=strand, 
                                          downScores = downScore)
            cawlr_read.append({"metadata": metadata, "scores": scores})
    
    return cawlr_read

In [289]:
outputFile = '/scratch/gabai/addseq/data/ctrl/unique.500.arrow'
cawlrRead = alignScore(eventAlign = eventAlign_pos, downReads=500, downScore=5000)
scored = {"scored": cawlrRead}
new_fdict = scored
output = Table.from_pydict(new_fdict, schema=table.schema)
feather.write_feather(output, outputFile)
print('done')

  0%|▏                                                                                                                               | 257/138963 [09:31<85:42:23,  2.22s/it]
  2%|██▍                                                                                                                               | 241/12803 [09:14<8:01:42,  2.30s/it]
  1%|▋                                                                                                                                | 253/48555 [07:23<23:30:24,  1.75s/it]
  1%|█▍                                                                                                                               | 285/26330 [07:29<11:24:50,  1.58s/it]
  0%|▋                                                                                                                                | 269/54120 [06:31<21:47:10,  1.46s/it]
  0%|▌                                                                                                                            

  7%|████████▋                                                                                                                            | 268/4081 [00:05<01:14, 51.36it/s]
 10%|█████████████                                                                                                                        | 254/2589 [00:05<00:49, 47.26it/s]
  3%|███▋                                                                                                                                 | 275/9970 [00:04<02:26, 66.12it/s]
  2%|██                                                                                                                                  | 278/17830 [00:04<04:52, 60.09it/s]
  0%|▍                                                                                                                                   | 287/87873 [00:04<23:24, 62.35it/s]
  7%|████████▊                                                                                                                    

  8%|██████████▌                                                                                                                          | 259/3277 [00:03<00:44, 67.93it/s]
  0%|▎                                                                                                                                  | 251/127739 [00:03<31:10, 68.17it/s]
 21%|████████████████████████████▍                                                                                                        | 274/1284 [00:04<00:14, 68.08it/s]
 13%|█████████████████                                                                                                                    | 242/1888 [00:03<00:25, 64.94it/s]
  3%|████▌                                                                                                                                | 271/7835 [00:04<01:57, 64.47it/s]
  1%|▊                                                                                                                            

  0%|▎                                                                                                                                  | 259/108844 [00:03<26:19, 68.75it/s]
  0%|▌                                                                                                                                   | 251/61930 [00:03<14:50, 69.24it/s]
  0%|▎                                                                                                                                   | 262/93196 [00:03<21:01, 73.70it/s]
  1%|█▏                                                                                                                                  | 271/31213 [00:03<07:13, 71.45it/s]
  1%|█▏                                                                                                                                  | 279/31168 [00:05<09:14, 55.74it/s]
  0%|▍                                                                                                                            

  1%|▉                                                                                                                                   | 278/39876 [00:04<11:20, 58.22it/s]
 12%|███████████████▋                                                                                                                     | 281/2386 [00:04<00:32, 63.80it/s]
  2%|██▋                                                                                                                                 | 325/15911 [00:05<04:12, 61.79it/s]
  0%|▍                                                                                                                                   | 276/77992 [00:04<21:01, 61.62it/s]
  0%|▌                                                                                                                                   | 270/70573 [00:04<19:46, 59.26it/s]
  9%|███████████▋                                                                                                                 

 11%|███████████████                                                                                                                      | 258/2272 [00:03<00:28, 71.04it/s]
 11%|██████████████▌                                                                                                                      | 250/2279 [00:03<00:29, 68.10it/s]
  4%|█████▋                                                                                                                               | 269/6260 [00:04<01:29, 67.09it/s]
  9%|███████████▌                                                                                                                         | 255/2926 [00:03<00:41, 63.91it/s]
  4%|█████▏                                                                                                                               | 247/6374 [00:03<01:21, 74.96it/s]
  4%|█████▊                                                                                                                       

done


In [258]:
table = feather.read_table("../data/pos_scores.arrow")

In [262]:
pos_pyarrow = Table.to_pydict(table)

In [275]:
# Brandon shared
table.schema

scored: struct<metadata: struct<name: string not null, chrom: string not null, start: uint64 not null, length: uint64 not null, strand: struct<strand: int8 not null> not null, seq: string not null> not null, scores: list<item: struct<pos: uint64 not null, kmer: string not null, skipped: bool not null, signal_score: double, skip_score: double not null, score: double not null> not null> not null> not null
  child 0, metadata: struct<name: string not null, chrom: string not null, start: uint64 not null, length: uint64 not null, strand: struct<strand: int8 not null> not null, seq: string not null> not null
      child 0, name: string not null
      child 1, chrom: string not null
      child 2, start: uint64 not null
      child 3, length: uint64 not null
      child 4, strand: struct<strand: int8 not null> not null
          child 0, strand: int8 not null
      child 5, seq: string not null
  child 1, scores: list<item: struct<pos: uint64 not null, kmer: string not null, skipped: bool not

In [276]:
output.schema

scored: struct<metadata: struct<name: string not null, chrom: string not null, start: uint64 not null, length: uint64 not null, strand: struct<strand: int8 not null> not null, seq: string not null> not null, scores: list<item: struct<pos: uint64 not null, kmer: string not null, skipped: bool not null, signal_score: double, skip_score: double not null, score: double not null> not null> not null> not null
  child 0, metadata: struct<name: string not null, chrom: string not null, start: uint64 not null, length: uint64 not null, strand: struct<strand: int8 not null> not null, seq: string not null> not null
      child 0, name: string not null
      child 1, chrom: string not null
      child 2, start: uint64 not null
      child 3, length: uint64 not null
      child 4, strand: struct<strand: int8 not null> not null
          child 0, strand: int8 not null
      child 5, seq: string not null
  child 1, scores: list<item: struct<pos: uint64 not null, kmer: string not null, skipped: bool not

In [None]:
# chrI	73384	ATTTAT	d066e5c0-4f9d-4d18-83ea-67b7a18165ed	t	1000	88.71	3.248	0.00150	ATAAAT	91.89	2.42	-1.11	91.2435,91.2435,86.3527,88.9728,91.5929,82.8593