In [1]:
from Bio import SeqIO
from Bio import motifs
import click
from click_option_group import optgroup
import gc
import gzip
from io import StringIO
import numpy as np
import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [2]:
from utils.architectures import CAM, NonStrandSpecific
from utils.jaspar import get_figure, reformat_jaspar_motif
from utils.sequence import one_hot_encode, one_hot_decode, rc_one_hot_encoding

In [3]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [7]:
model_file = "../../results/TF-Binding-Matrix/CAM.cnn-units=16/TP53/best_model.pth.tar"
training_file = "../../results/TF-Binding-Matrix/FASTA/Train/TP53.fa.gz"
batch_size = 2**6
debugging = False
output_dir = "../../results/TF-Binding-Matrix/CAM.cnn-units=16/TP53/"
threads = 4

In [8]:
# Create output dirs
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)
for subdir in ["profiles", "logos"]:
    if not os.path.isdir(os.path.join(output_dir, subdir)):
        os.makedirs(os.path.join(output_dir, subdir))

In [9]:
def _fix_state_dict(state_dict):

    for k in frozenset(state_dict.keys()):
        state_dict[k[6:]] = state_dict.pop(k)

    return(state_dict)

In [10]:
# Load model
selene_dict = torch.load(model_file)
model = CAM(
    selene_dict["options"]["cnn_units"],
    selene_dict["options"]["motif_length"],
    selene_dict["options"]["sequence_length"],
    selene_dict["options"]["n_features"],
    selene_dict["options"]["input_data"],
    selene_dict["options"]["weights_file"],
)
try:
    model.load_state_dict(selene_dict["state_dict"])
except:
    model.load_state_dict(_fix_state_dict(selene_dict["state_dict"]))
model.to(device)

CAM(
  (linears): Sequential(
    (0): Conv1d(64, 16, kernel_size=(26,), stride=(1,), groups=16)
    (1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ExpAct()
    (3): MaxPool1d(kernel_size=7, stride=7, padding=0, dilation=1, ceil_mode=False)
    (4): Flatten(start_dim=1, end_dim=-1)
    (5): UnSqueeze()
    (6): Conv1d(400, 1600, kernel_size=(1,), stride=(1,), groups=16)
    (7): BatchNorm1d(1600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Dropout(p=0.3, inplace=False)
    (10): Conv1d(1600, 16, kernel_size=(1,), stride=(1,), groups=16)
    (11): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): Flatten(start_dim=1, end_dim=-1)
  )
  (final): Linear(in_features=16, out_features=1, bias=True)
)

In [11]:
def _get_Xs_ys(fasta_file, debugging=False):

    # Initialize
    Xs = []
    ys = []

    # Xs / ys
    handle = _get_handle(fasta_file)
    for record in SeqIO.parse(handle, "fasta"):
        _, y_list = record.description.split()
        Xs.append(one_hot_encode(str(record.seq).upper()))
        ys.append([float(y) for y in y_list.split(";")])

    if debugging:
        return(np.array(Xs)[:10000], np.array(ys)[:10000])

    return(np.array(Xs), np.array(ys))

def _get_data_loader(Xs_train, ys_train, batch_size=2**6, threads=1):

    # Reverse complement
    n = len(Xs_train)
    for i in range(n):
        encoded_seq = rc_one_hot_encoding(Xs_train[i])
        Xs_train.append(encoded_seq)
        ys_train.append(ys_train[i])

    # TensorDatasets
    train_set = TensorDataset(torch.Tensor(Xs_train), torch.Tensor(ys_train))

    # DataLoaders
    kwargs = dict(batch_size=batch_size, num_workers=threads)
    train_loader = DataLoader(train_set, **kwargs)

    return(train_loader)

def _get_handle(file_name):
    if file_name.endswith("gz"):
        handle = gzip.open(file_name, "rt")
    else:
        handle = open(file_name, "rt")
    return(handle)

def _release_memory(my_object):
   del my_object
   gc.collect()

In [12]:
# Get Xs/ys
Xs, ys = _get_Xs_ys(training_file, debugging)

# Get DataLoader
data_loader = _get_data_loader(list(Xs), list(ys), batch_size, threads)

# Free up memory
_release_memory(Xs)
_release_memory(ys)

In [13]:
def __get_activations(model, data_loader):

    # Initialize
    activations = torch.tensor([], dtype=torch.float32)

    with torch.no_grad():
        for x, _ in tqdm(data_loader, total=len(data_loader)):
            x = x.to(device)
            x = x.repeat(1, model._options["cnn_units"], 1)
            activation = model.linears[:3](x)
            activations = torch.cat([activations, activation.cpu()])

    return(activations.numpy())

In [14]:
def __get_pfms(sequences, activations, motif_length=19):

    """
    For each filter, build a Position Frequency Matrix (PFM) from all sites
    reaching at least ½ the maximum activation value for that filter across all input sequences
    
     the
    activations and the original sequences, and keep the sites used to
    derive such matrix.

    params :
        actvations (np.array) : (N*N_filters*L) array containing the ourput for each filter and selected sequence of the test set
        sequnces (np.array) : (N*4*200) selected sequences (ACGT)
        y (np.array) : (N*T) original target of the selected sequnces
        output_file_path (str) : path to directory to store the resulting pwm meme file
    """

    # Initialize
    n_filters = activations.shape[1]
    pfms = np.zeros((n_filters, 4, motif_length))
    # sites = [[] for _ in range(n_filters)]

    # Find the threshold value for activations (i.e. 50%)
    activation_thresholds = 0.5*np.amax(activations, axis=(0, 2))

    # For each filter...
    for i in range(n_filters):

        activated_sequences_list = []

        # For each sequence...
        for j in range(len(sequences)):

            # Get indices of sequences that activate the filter
            idx = np.where(activations[j,i,:] > activation_thresholds[i])

            for ix in idx[0]:

                s = sequences[j][:,ix:ix+motif_length]
                # activated_sequences_list.append(s)

                # Build PFM
                pfms[i] = np.add(pfms[i], s)

        # # If activated sequences...
        # if activated_sequences_list:

        #     # Convert activated sequences to array
        #     activated_sequences_arr = np.stack(activated_sequences_list)

        #     # Build PFM
        #     pfms[i] = np.sum(activated_sequences_arr, axis=0)

            # # Save sites that activated the filter
            # for s in activated_sequences_list:
            #     sites[i].append(one_hot_decode(s))

    # return(pfms, sites)
    return(pfms)

In [15]:
# Initialize
outputs = []
labels = []
sequences = []
profiles = []

with torch.no_grad():
    for x, label in tqdm(data_loader, total=len(data_loader)):
        for encoded_seq in x:
            sequence = "N" * model._options["motif_length"]
            sequence += one_hot_decode(encoded_seq.numpy())
            sequence += "N" * model._options["motif_length"]
            sequences.append(one_hot_encode(sequence))
        x = x.to(device)
        out = model(x)
        if model._options["input_data"] == "binary":
            out = torch.sigmoid(out)
        outputs.extend(out.detach().cpu().numpy())
        labels.extend(label.numpy())

100%|██████████| 85/85 [00:03<00:00, 22.68it/s]


In [16]:
# Get activations
activations = __get_activations(model, data_loader)

100%|██████████| 85/85 [00:01<00:00, 80.66it/s] 


In [17]:
# Free memory
_release_memory(data_loader)

# Filter sequences/activations
sequences = np.array(sequences)
outputs = np.array(outputs)
labels = np.array(labels)
if model._options["input_data"] == "binary":
    ixs = np.where(labels == 1.)
    sequences = sequences[ixs, :, :][0]
    activations = activations[ixs, :, :][0]

In [18]:
# Get Position Frequency Matrices (PFMs)
pfms = __get_pfms(sequences, activations, model._options["motif_length"])

In [19]:
# Get profiles
for j in range(len(pfms)):
    handle = StringIO("\n".join(["\t".join(map(str, i)) for i in pfms[j]]))
    profiles.append(motifs.read(handle, "pfm-four-rows"))

In [20]:
# Get weights
weights = model.final.weight.detach().cpu().numpy().flatten().tolist()

In [21]:
weights

[1.2451865673065186,
 0.0,
 0.13975411653518677,
 0.0,
 1.3888148069381714,
 1.2927360534667969,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.316008448600769,
 0.0,
 0.0,
 0.0]

In [22]:
for z in zip(profiles, weights):
    print(z[0].consensus, z[1])

CACATTTTTCAAGATATAAAATTTTA 1.2451865673065186
CCCTGGGGTGGTTCTATTGAACAGTT 0.0
AGTTTAAACAGGAGCTGAGACGAACC 0.13975411653518677
AGATGGAAGAAAAAGGGAGGGAATGA 0.0
TCTTTAAGAAATCTATTTTGTATTAC 1.3888148069381714
GATATGAAAAACTACACAGGGGACAT 1.2927360534667969
TTATAATAAGAAGTAAAATTAAAAAA 0.0
CGTGATGGAACGTCGGTGCGGTTGGG 0.0
GAGCTTTTCGAGATAGGCTAGCCCAG 0.0
GGATGTTCTAATGGTGGGGAGTGGTA 0.0
CATGTTCTGAAAGATTGATTAGTTTT 0.0
GCAACCACGGAATATATCAAGTCTGT 0.0
TATCCAGAGATGAAATATTCCACTGC 1.316008448600769
GTGAGATTTTAAAGAACAGAGGTCTT 0.0
AGGATACGCCGTGTGTGCCAGAGGGA 0.0
TATGGAGATGGAAATAAGTCCGAACT 0.0


In [23]:
pfms

array([[[161., 164., 167., ..., 167., 162., 190.],
        [162., 154., 177., ..., 149., 137., 184.],
        [129., 134., 121., ..., 139., 151., 153.],
        [158., 160., 148., ..., 204., 214., 139.]],

       [[  0.,   0.,   0., ...,   2.,   0.,   0.],
        [  2.,   2.,   2., ...,   0.,   1.,   0.],
        [  0.,   0.,   0., ...,   2.,   1.,   0.],
        [  1.,   1.,   1., ...,   0.,   2.,   4.]],

       [[  5.,   3.,   2., ...,   5.,   2.,   1.],
        [  3.,   1.,   4., ...,   4.,   7.,  10.],
        [  2.,   7.,   2., ...,   3.,   1.,   2.],
        [  4.,   3.,   6., ...,   4.,   6.,   3.]],

       ...,

       [[  1.,   5.,   5., ...,   4.,   5.,   4.],
        [  4.,   1.,   2., ...,   7.,   1.,   5.],
        [  5.,   2.,   5., ...,   3.,   2.,   4.],
        [  4.,   6.,   3., ...,   2.,   9.,   5.]],

       [[ 14.,  11.,  11., ...,  11.,  13.,  15.],
        [ 11.,   7.,  11., ...,  11.,   6.,   9.],
        [  9.,  15.,  11., ...,  12.,  14.,  11.],
        [ 