In [243]:
from Bio import SeqIO
from Bio import motifs
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import click
import numpy as np
import os
import pandas as pd
import torch
from tqdm import tqdm
bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}"

In [244]:
from architectures import CAM, get_metrics
from jaspar import get_figure, reformat_motif
from sequence import one_hot_encode, rc_one_hot_encoding, rc, one_hot_decode
from train import _get_data_loaders, __get_handle
from predict import _predict

In [245]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [246]:
CAM_DIR = "/mnt/md1/home/oriol/CAM/results/IRF4"
ASSAY = "ChIP-seq"
LABEL = "T95R"
model_file = f"{CAM_DIR}/CAM/{ASSAY}/{LABEL}.1/best_model.pth.tar"
batch_size = 2**6
debugging = False
name = f"{ASSAY}.{LABEL}"
output_dir = f"{CAM_DIR}/CAM/{ASSAY}/{LABEL}.1"
rev_complement = True
threads = 1
FASTA_DIR = f"{CAM_DIR}/CAM/{ASSAY}/T95R-WT.1"
NAME = "test"
fasta_file = f"{FASTA_DIR}/{NAME}.fa"

In [247]:
# Initialize
torch.set_num_threads(threads)

In [248]:
# Create output dirs
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)
for subdir in ["predictions", "embeddings"]:
    if not os.path.isdir(os.path.join(output_dir, subdir)):
        os.makedirs(os.path.join(output_dir, subdir))

In [249]:
# Get model
selene_dict = torch.load(model_file)
model = CAM(
    selene_dict["options"]["cnn_units"],
    selene_dict["options"]["kernel_size"],
    selene_dict["options"]["sequence_length"],
    selene_dict["options"]["n_features"],
    selene_dict["options"]["clamp_weights"],
    selene_dict["options"]["no_padding"],
    selene_dict["options"]["weights_file"],
)
model.load_state_dict(selene_dict["state_dict"])
model.to(device)

CAM(
  (linears): Sequential(
    (0): Conv1d(512, 128, kernel_size=(19,), stride=(1,), padding=(19,), groups=128)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ExpAct()
    (3): MaxPool1d(kernel_size=7, stride=7, padding=0, dilation=1, ceil_mode=False)
    (4): Flatten(start_dim=1, end_dim=-1)
    (5): UnSqueeze()
    (6): Conv1d(3968, 12800, kernel_size=(1,), stride=(1,), groups=128)
    (7): BatchNorm1d(12800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Dropout(p=0.3, inplace=False)
    (10): Conv1d(12800, 128, kernel_size=(1,), stride=(1,), groups=128)
    (11): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): Flatten(start_dim=1, end_dim=-1)
  )
  (final): Linear(in_features=128, out_features=1, bias=True)
)

In [250]:
# Padding
if selene_dict["options"]["no_padding"]:
    padding = 0
else:
    padding = selene_dict["options"]["kernel_size"]

In [251]:
def _get_Xs_ys_seq_ids(fasta_file, debugging=False, reverse_complement=False):

    # Initialize
    Xs = []
    ys = []
    seq_ids = []

    # Xs / ys
    handle = __get_handle(fasta_file)
    for record in SeqIO.parse(handle, "fasta"):
        Xs.append(one_hot_encode(str(record.seq).upper()))
        ys.append([1.])
        seq_ids.append(record.id)
    handle.close()

    # Reverse complement
    if reverse_complement:
        n = len(Xs)
        for i in range(n):
            Xs.append(rc_one_hot_encoding(Xs[i]))
            ys.append(ys[i])
            seqs.append(rc(seqs[i]))

    # Return 1,000 sequences
    if debugging:
        return(np.array(Xs)[:1000], np.array(ys)[:1000], 
               np.array(seq_ids)[:1000])

    return(np.array(Xs), np.array(ys), np.array(seq_ids))

In [252]:
##############
# Load Data  #
##############

# Get data
Xs, ys, seq_ids = _get_Xs_ys_seq_ids(fasta_file, debugging, rev_complement)

# Get DataLoader
data_loader = _get_data_loaders(list(Xs), list(ys), batch_size=batch_size)

In [253]:
for x, _ in data_loader:
    for seq, seq_id in zip(x, seq_ids[:x.shape[0]]):
        print(one_hot_decode(seq), seq_id)
    break

AAATCTGCGTTTCATCATCTATAAGAAAGGTACCTATCGAGAACACCCTGCTGGCCAGTGTGTAAATATCTAAAGGAGGACTCAGAAAACACCGGGGAAGTCCAGCCTGCACGTGGTGGCTGGGCTTCAGTGAAGCATGCAGCACAACAGGAGTTGTAAGTAGTAGTTACATCAGCAGCCCTGGAAATTCTGCTCAGAACC hg38_dna
GGTTCTGAGCAGAATTTCCAGGGCTGCTGATGTAACTACTACTTACAACTCCTGTTGTGCTGCATGCTTCACTGAAGCCCAGCCACCACGTGCAGGCTGGACTTCCCCGGTGTTTTCTGAGTCCTCCTTTAGATATTTACACACTGGCCAGCAGGGTGTTCTCGATAGGTACCTTTCTTATAGATGATGAAACGCAGATTT hg38_dna


In [254]:
############
# Predict  #
############ 

# Initialize
input_data = "binary"
if selene_dict["options"]["no_padding"]:
    padding = 0
else:
    padding = selene_dict["options"]["kernel_size"]

In [255]:
predictions, _ = _predict(model, data_loader, input_data)
predictions.shape

100%|████████████████████| 1/1 [00:00<00:00, 96.91it/s]


(4, 1)

In [256]:
def __get_max_predictions(predictions):

    # Initialize
    data = []
    fwd = []
    rev = []
    strands = []

    # DataFrame
    for i, values in enumerate(predictions[:len(predictions)//2]):
        fwd.append(values.tolist())
        data.append([i] + fwd[-1])       
    for i, values in enumerate(predictions[len(predictions)//2:]):
        rev.append(values.tolist())
        data.append([i] + rev[-1])
    df = pd.DataFrame(data)

    # Get max. values
    df = df.groupby(0).max()
    max_predictions = df.values.tolist()

    # Get strands
    for i in range(len(max_predictions)):
        strands.append([])
        for j in range(len(max_predictions[i])):
            if max_predictions[i][j] == fwd[i][j]:
                strands[-1].append("+")
            else:
                strands[-1].append("-")

    return(np.array(max_predictions), np.array(strands))

In [257]:
#max_predictions, strands = __get_max_predictions(predictions)
seq_ids = seq_ids.reshape(1, -1).T
#data = np.concatenate((seq_ids, max_predictions, strands), axis=1)
data = np.concatenate((seq_ids, predictions), axis=1)
df = pd.DataFrame(data)
tsv_file = f"{output_dir}/predictions/{NAME}.tsv"
df.to_csv(tsv_file, sep="\t", header=False, index=False)

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 2 and the array at index 1 has size 4

In [259]:
predictions

array([[4.3109158e-04],
       [1.0000000e+00],
       [9.9999893e-01],
       [3.6856499e-02]], dtype=float32)

In [260]:
seq_ids

array([['hg38_dna'],
       ['hg38_dna']], dtype='<U8')

In [197]:
def _get_embeddings(model, data_loader, input_data):

    # Initialize
    embeddings = []
    labels = []

    with torch.no_grad():
        for x, label in tqdm(iter(data_loader), total=len(data_loader),
                             bar_format=bar_format):

            # Get embeddings
            x = x.to(device)
            x = x.repeat(1, model._options["cnn_units"], 1)
            out = model.linears(x)
            embeddings.extend(out.detach().cpu().numpy())

            # Get labels
            labels.extend(label.numpy())

    return(np.array(embeddings), np.array(labels))

In [198]:
embeddings, _ = _get_embeddings(model, data_loader, input_data)

100%|████████████████████| 250/250 [00:01<00:00, 140.54it/s]


In [199]:
#values = np.concatenate(np.array_split(embeddings, 2), axis=1)
#data = np.concatenate((seq_ids, values), axis=1)
data = np.concatenate((seq_ids, embeddings), axis=1)
df = pd.DataFrame(data)
tsv_file = f"{output_dir}/embeddings/{NAME}.tsv"
df.to_csv(tsv_file, sep="\t", header=False, index=False)

In [200]:
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,119,120,121,122,123,124,125,126,127,128
0,chr1:890347-890548,0.7464627,0.0,1.3187275,0.0,1.0244732,0.0,0.0,0.44794384,1.1286832,...,0.0,0.0,2.202464,1.6232232,0.2522869,0.0,0.9141333,0.0,1.0658158,1.2202611
1,chr1:1213019-1213220,0.09739119,0.0,0.0,0.0,0.40994245,0.0,0.0,0.0,0.0,...,0.0,1.8096246,1.4389081,0.0,0.27763152,0.25536928,0.37086844,2.42282,0.0,1.6291996
2,chr1:1219451-1219652,0.0,0.0,0.0,0.3033954,0.0,0.0,0.0,0.0,0.4743773,...,0.0,0.0,0.0,0.0,0.0,0.3831599,0.0,0.0,0.0,0.0
3,chr1:1375323-1375524,0.0,1.3739494,0.0,1.7883452,1.372933,2.0850608,0.0,0.0,0.13726093,...,0.0,0.79185915,0.0,0.0,0.0,1.0642388,0.65111166,2.4851842,1.7214277,0.85132915
4,chr1:1417947-1418148,0.3857286,0.0,0.0,2.1507537,0.9763753,0.0,0.6279799,0.0,0.0,...,0.0,0.0,0.0,1.1950945,0.0,0.13153571,0.0,1.3038348,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15964,chrY:20798523-20798724,0.59028876,0.0,0.15509619,0.023913326,1.1694824,0.90835696,0.45059773,0.0,1.3306165,...,0.0,0.30142453,0.0,0.0,1.1054194,1.2165644,0.3669954,2.77028,0.0,0.2251991
15965,chrY:20799643-20799844,1.4428897,0.0,2.149443,0.0,1.9104419,1.2625622,2.5944998,0.97249174,0.78397524,...,0.0,0.0,0.0,0.06868644,2.0274549,0.0,0.0,0.0,0.0,0.37953374
15966,chrY:20891622-20891823,0.0,0.73754203,0.92224365,0.4588326,1.3890985,0.0,0.63063276,0.3797961,0.66554224,...,1.517823,0.0,1.4503822,0.22513184,1.6412098,0.2508307,0.68962467,0.0,0.0,0.0
15967,chrY:21248603-21248804,0.6614586,0.103128314,0.57888025,0.64559454,0.0,1.0684323,0.0,0.0,1.6861072,...,0.0,0.0,2.2804859,0.0,0.0,1.975696,1.312235,0.72986907,0.21027112,0.0
