In [1]:
import os
import pandas
import pyBigWig
from datetime import datetime

os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import time 
import numpy
import torch
import matplotlib.pyplot as plt

from logging_myin25 import Logger

from tqdm import tqdm

In [14]:
# bpnet.py
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

"""
This module contains a reference implementation of BPNet that can be used
or adapted for your own circumstances. The implementation takes in a
stranded control track and makes predictions for stranded outputs.
"""

import time 
import numpy
import torch
import matplotlib.pyplot as plt

from logging_myin25 import Logger

from tqdm import tqdm

torch.backends.cudnn.benchmark = True


class DefinitelyNotBPNet(torch.nn.Module):
    """A basic model model with count prediction.

    This model consists of a single Dense layer and an output layer!

    Parameters
    ----------
    n_filters: int, optional
        The number of filters to use per convolution. Default is 64.

    name: str or None, optional
        The name to save the model to during training.
    """

    def __init__(self, n_marks, name=None, alpha=1):
        super(DefinitelyNotBPNet, self).__init__()

        self.name = name or "/users/myin25/projects/celltype_specificity/models/definitelynotbpnet_{}".format(datetime.now())
        self.alpha = alpha
        '''
        self.fc1 = torch.nn.Linear(8, 32)
        self.relu1 = torch.nn.Sigmoid()
        self.fc2 = torch.nn.Linear(32, 64)
        self.relu2 = torch.nn.Sigmoid()
        self.fc3 = torch.nn.Linear(64, 16)
        self.relu3 = torch.nn.Sigmoid()
        self.fc4 = torch.nn.Linear(16, 1)'''

        # self.convi = torch.nn.Linear(n_marks, 1)
        self.convi = torch.nn.Conv1d(n_marks, 64, kernel_size=21)
        self.linear = torch.nn.Linear(64, 1)

        self.logger = Logger(["Epoch", "Iteration", "Training Time",
            "Training MNLL Loss", "Training Count MSE", "Val MNLL Loss","Validation Count Pearson", 
            "Validation Count MSE", "Saved?"], verbose=False)


    def forward(self, x, X_ctl=None):
        # counts prediction
        '''x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.relu3(x)
        x = self.fc4(x)'''
        
        #print('shape before prediction', x.shape)
        x = self.convi(x)
        x = torch.mean(x, dim=2)
        x = self.linear(x)
        #print('prediction shape', x.shape)
        
        return x


    def predict(self, X, X_ctl=None, batch_size=64, verbose=False):
        """Make predictions for a large number of examples.

        This method will make predictions for a number of examples that exceed
        the batch size. It is similar to the forward method in terms of inputs 
        and outputs, but will run wrapped with `torch.no_grad()` to speed up
        computation and prevent information leakage into the model.


        Parameters
        ----------
        X: torch.tensor, shape=(-1, 4, length)
            The one-hot encoded batch of sequences.

        X_ctl: torch.tensor or None, shape=(-1, n_strands, length)
            A value representing the signal of the control at each position in 
            the sequence. If no controls, pass in None. Default is None.

        batch_size: int, optional
            The number of examples to run at a time. Default is 64.

        verbose: bool
            Whether to print a progress bar during predictions.


        Returns
        -------
        y_profile: torch.tensor, shape=(-1, n_strands, out_length)
            The output predictions for each strand trimmed to the output
            length.
        """

        #print("X shape", X.shape)

        with torch.no_grad():
            starts = numpy.arange(0, X.shape[0], batch_size)
            ends = starts + batch_size

            y_profiles, y_counts = [], []
            for start, end in tqdm(zip(starts, ends), disable=not verbose):
                X_batch = X[start:end].cuda()
                #print("X_batch shape", X_batch.shape)
                y_prof_ = self(X_batch)
                y_prof_ = y_prof_.cpu()

                y_profiles.append(y_prof_)

            y_profiles = torch.cat(y_profiles)
            return y_profiles

    def fit(self, training_data, optimizer, X_valid=None, X_ctl_valid=None, 
        y_valid=None, max_epochs=100, batch_size=64, validation_iter=100, 
        early_stopping=None, verbose=True):
        """Fit the model to data and validate it periodically.

        This method controls the training of a BPNet model. It will fit the
        model to examples generated by the `training_data` DataLoader object
        and, if validation data is provided, will periodically validate the
        model against it and return those values. The periodicity can be
        controlled using the `validation_iter` parameter.

        Two versions of the model will be saved: the best model found during
        training according to the validation measures, and the final model
        at the end of training. Additionally, a log will be saved of the
        training and validation statistics, e.g. time and performance.


        Parameters
        ----------
        training_data: torch.utils.data.DataLoader
            A generator that produces examples to train on. If n_control_tracks
            is greater than 0, must product two inputs, otherwise must produce
            only one input.

        optimizer: torch.optim.Optimizer
            An optimizer to control the training of the model.

        X_valid: torch.tensor or None, shape=(n, 4, 2114)
            A block of sequences to validate on periodically. If None, do not
            perform validation. Default is None.

        X_ctl_valid: torch.tensor or None, shape=(n, n_control_tracks, 2114)
            A block of control sequences to validate on periodically. If
            n_control_tracks is None, pass in None. Default is None.

        y_valid: torch.tensor or None, shape=(n, n_outputs, 1000)
            A block of signals to validate against. Must be provided if
            X_valid is also provided. Default is None.

        max_epochs: int
            The maximum number of epochs to train for, as measured by the
            number of times that `training_data` is exhausted. Default is 100.

        batch_size: int
            The number of examples to include in each batch. Default is 64.

        validation_iter: int
            The number of batches to train on before validating against the
            entire validation set. When the validation set is large, this
            enables the total validating time to be small compared to the
            training time by only validating periodically. Default is 100.

        early_stopping: int or None
            Whether to stop training early. If None, continue training until
            max_epochs is reached. If an integer, continue training until that
            number of `validation_iter` ticks has been hit without improvement
            in performance. Default is None.

        verbose: bool
            Whether to print out the training and evaluation statistics during
            training. Default is True.
        """

        #print("starting fitting")
        
        if X_valid is not None:
            X_valid = X_valid.cuda()
            y_valid = y_valid.sum(dim=(-1, -2)).cuda()
            
            #X_valid = X_valid.sum(dim=-1).cuda()
            #print("y_valid.shape", y_valid.shape)
            #print("y_valid.shape after", y_valid.shape)
            #print("X_valid.shape", X_valid.shape)
            #print("y_valid.shape", y_valid.shape)

        iteration = 0
        early_stop_count = 0
        best_loss = float("inf")
        self.logger.start()

        #print("starting to train")
        for epoch in range(max_epochs):
            tic = time.time()

            for seqs, X, y in training_data:
                X, y = X.cuda(), y.sum(dim=(-1, -2)).cuda()
                #X, y = X.sum(dim=-1).cuda(), y.sum(dim=(-1, -2)).cuda()
                #print("X numpy size", X.numpy().nbytes)
                #print("y numpy size", y.numpy().nbytes)
                #print("X.shape", X.shape)
                #print("y.shape", y.shape)

                # Clear the optimizer and set the model to training mode
                optimizer.zero_grad()
                self.train()

                # Run forward pass
                y_prof = self(X).squeeze()
                #print("y_prof", y_prof.shape)
                #print("y prediction numpy size", y_prof.numpy().nbytes)

                #print("predicted, going to evaluate losses.")
                # Calculate the profile and count losses
                #print('shape comparison')
                #print(y_prof.shape, y.shape)
                t_prof_loss = MNLLLoss(y_prof, y).mean()
                t_count_loss = log1pMSELoss(torch.sum(y_prof, dim=-1), torch.sum(y, dim=-1)).mean()
                
                loss = t_prof_loss + self.alpha * t_count_loss
                loss.backward()
                optimizer.step()
                
                if verbose and iteration % validation_iter == 0:
                    train_time = time.time() - tic

                    with torch.no_grad():
                        self.eval()
                        
                        tic = time.time()
                        y_profs = self.predict(X_valid, X_ctl_valid)
                        # print("y_profs.shape", y_profs.shape)
                        y_counts = torch.sum(y_profs, dim=-1)
                        
                        #print("y val actual profs numpy size", y_valid.numpy().nbytes)
                        #print("y val predicted profs numpy size", y_profs.numpy().nbytes)
                        
                        '''print("validation predicted profs", y_valid)
                        print("validation predicted profs", y_profs)'''
                        
                        # MSE
                        '''print('y_valid', y_valid)
                        print('y_valid min', torch.min(y_valid))
                        print('y_valid max', torch.max(y_valid))
                        print('y_counts', y_counts)
                        print('y_counts min', torch.min(y_counts))
                        print('y_counts max', torch.max(y_counts))'''
                        
                        log_true = torch.log1p(y_valid)
                        
                        #print("log_true before", y_valid.shape)
                        #log_true = torch.sum(torch.log(y_valid+1), dim=-1)
                        log_true = log_true.cpu()
                        #print("log_true after", log_true.shape)
                        
                        '''print('log_true', log_true)'''
                        '''plt.scatter(log_true,y_counts.squeeze())
                        plt.show()'''
                        
                        count_mse = torch.square(log_true[..., None] - y_counts)
                        count_mse = torch.mean(count_mse.squeeze(), dim=-1)
                        
                        '''print(count_mse)
                        print(any(torch.isnan(count_mse)))'''
                        '''plt.hist(count_mse)
                        plt.show()'''
                        
                        # count pearson
                        y_counts = y_counts.squeeze()
                        '''plt.scatter(log_true, y_counts)
                        plt.show()
                        print('comparison of shapes')
                        print(log_true.shape, y_counts.shape)
                        print(torch.min(log_true))
                        print(torch.max(log_true))
                        print(torch.mean(log_true))
                        plt.hist(log_true)
                        plt.show()
                        print(any(torch.isnan(y_counts)))
                        print(torch.min(y_counts))
                        print(torch.max(y_counts))
                        print(torch.mean(y_counts))
                        plt.hist(y_counts)
                        plt.show()'''
                        
                        count_corr = pearson_corr(log_true, y_counts)
                        
                        y_valid = y_valid.cpu()
                        #X_valid = X_valid.cpu()
                        #print('comparison of shapes 2')
                        #print(y_valid.shape, y_counts.shape)
                        #prof_loss = MNLLLoss(y_profs, y_valid).mean().item()
                        count_loss = log1pMSELoss(y_counts, y_valid.sum(dim=-1).mean())
                        prof_loss = count_loss
                        
                        valid_time = time.time() - tic
                        #valid_loss = count_mse.mean()
                        valid_loss = count_loss
                        count_loss = count_loss.item()

                        '''print('count_loss_')
                        print(count_loss)
                        print('count_corr')
                        print(count_corr)
                        print('nan_to_num')
                        print(numpy.nan_to_num(count_corr).mean())
                        print('count_mse')
                        print(count_mse.mean().item())
                        print('saved')
                        print((valid_loss < best_loss).item())'''
                        
                        t_count_loss = t_count_loss.item()
                        
                        t_prof_loss = t_prof_loss.cpu().item()
                        '''print(t_prof_loss.device)
                        print(prof_loss.device)
                        print(count_mse.device)
                        print(valid_loss.device)'''
                        self.logger.add([epoch, iteration, train_time, 
                            0, t_count_loss, 0,
                            numpy.nan_to_num(count_corr).mean(), 
                            count_mse.mean().item(),
                            (valid_loss < best_loss).item()])
                        
                        #print("name", type(self.name))
                        self.logger.save("{}.log".format(self.name))
                        
                        if valid_loss < best_loss:
                            torch.save(self, "{}.torch".format(self.name))
                            best_loss = valid_loss
                            early_stop_count = 0
                        else:
                            early_stop_count += 1

                # return
                
                # Extract the profile loss for logging
                # loss = count_loss.item()

                if early_stopping is not None and early_stop_count >= early_stopping:
                    break

                iteration += 1
                
                '''print()
                print()'''

            if early_stopping is not None and early_stop_count >= early_stopping:
                break

        torch.save(self, "{}.final.torch".format(self.name))

def pearson_corr(arr1, arr2):
    """The Pearson correlation between two tensors across the last axis.

    Computes the Pearson correlation in the last dimension of `arr1` and `arr2`.
    `arr1` and `arr2` must be the same shape. For example, if they are both
    A x B x L arrays, then the correlation of corresponding L-arrays will be
    computed and returned in an A x B array.

    Parameters
    ----------
    arr1: torch.tensor
        One of the tensor to correlate.

    arr2: torch.tensor
        The other tensor to correlation.

    Returns
    -------
    correlation: torch.tensor
        The correlation for each element, calculated along the last axis.
    """

    #print(arr1.shape)
    #print(torch.min(arr1), torch.max(arr1))
    #print(arr2.shape)
    #print(torch.min(arr2), torch.max(arr2))
    
    mean1 = torch.mean(arr1, axis=-1).unsqueeze(-1)
    mean2 = torch.mean(arr2, axis=-1).unsqueeze(-1)
    dev1, dev2 = arr1 - mean1, arr2 - mean2

    sqdev1, sqdev2 = torch.square(dev1), torch.square(dev2)
    numer = torch.sum(dev1 * dev2, axis=-1)  # Covariance
    var1, var2 = torch.sum(sqdev1, axis=-1), torch.sum(sqdev2, axis=-1)  # Variances
    denom = torch.sqrt(var1 * var2)
   
    # Divide numerator by denominator, but use 0 where the denominator is 0
    correlation = torch.zeros_like(numer)
    correlation[denom != 0] = numer[denom != 0] / denom[denom != 0]
    return correlation

def log1pMSELoss(log_predicted_counts, true_counts):
	"""A MSE loss on the log(x+1) of the inputs.

	This loss will accept tensors of predicted counts and a vector of true
	counts and return the MSE on the log of the labels. The squared error
	is calculated for each position in the tensor and then averaged, regardless
	of the shape.

	Note: The predicted counts are in log space but the true counts are in the
	original count space.

	Parameters
	----------
	log_predicted_counts: torch.tensor, shape=(n, ...)
		A tensor of log predicted counts where the first axis is the number of
		examples. Important: these values are already in log space.

	true_counts: torch.tensor, shape=(n, ...)
		A tensor of the true counts where the first axis is the number of
		examples.

	Returns
	-------
	loss: torch.tensor, shape=(n, 1)
		The MSE loss on the log of the two inputs, averaged over all examples
		and all other dimensions.
	"""

	log_true = torch.log(true_counts+1)
	return torch.mean(torch.square(log_true - log_predicted_counts), dim=-1)

def MNLLLoss(logps, true_counts):
	"""A loss function based on the multinomial negative log-likelihood.

	This loss function takes in a tensor of normalized log probabilities such
	that the sum of each row is equal to 1 (e.g. from a log softmax) and
	an equal sized tensor of true counts and returns the probability of
	observing the true counts given the predicted probabilities under a
	multinomial distribution. Can accept tensors with 2 or more dimensions
	and averages over all except for the last axis, which is the number
	of categories.

	Adapted from Alex Tseng.

	Parameters
	----------
	logps: torch.tensor, shape=(n, ..., L)
		A tensor with `n` examples and `L` possible categories. 

	true_counts: torch.tensor, shape=(n, ..., L)
		A tensor with `n` examples and `L` possible categories.

	Returns
	-------
	loss: float
		The multinomial log likelihood loss of the true counts given the
		predicted probabilities, averaged over all examples and all other
		dimensions.
	"""

	log_fact_sum = torch.lgamma(torch.sum(true_counts, dim=-1) + 1)
	log_prod_fact = torch.sum(torch.lgamma(true_counts + 1), dim=-1)
	log_prod_exp = torch.sum(true_counts * logps, dim=-1)
	return -log_fact_sum + log_prod_fact - log_prod_exp

In [3]:
# io.py
# Author: Jacob Schreiber <jmschreiber91@gmail.com>
# Code adapted from Alex Tseng, Avanti Shrikumar, and Ziga Avsec

import numpy
import torch
import pandas

import pyfaidx
import pyBigWig

from tqdm import tqdm


def read_meme(filename):
    """Read a MEME file and return a dictionary of PWMs.

    This method takes in the filename of a MEME-formatted file to read in
    and returns a dictionary of the PWMs where the keys are the metadata
    line and the values are the PWMs.


    Parameters
    ----------
    filename: str
        The filename of the MEME-formatted file to read in


    Returns
    -------
    motifs: dict
        A dictionary of the motifs in the MEME file.
    """

    motifs = {}

    with open(filename, "r") as infile:
        motif, width, i = None, None, 0

        for line in infile:
            if motif is None:
                if line[:5] == 'MOTIF':
                    motif = line.split()[1]
                else:
                    continue

            elif width is None:
                if line[:6] == 'letter':
                    width = int(line.split()[5])
                    pwm = numpy.zeros((width, 4))

            elif i < width:
                pwm[i] = list(map(float, line.split()))
                i += 1

            else:
                motifs[motif] = pwm
                motif, width, i = None, None, 0

    return motifs


def one_hot_encode(sequence, alphabet=['A', 'C', 'G', 'T'], dtype='int8', 
    desc=None, verbose=False, **kwargs):
    """Converts a string or list of characters into a one-hot encoding.

    This function will take in either a string or a list and convert it into a
    one-hot encoding. If the input is a string, each character is assumed to be
    a different symbol, e.g. 'ACGT' is assumed to be a sequence of four 
    characters. If the input is a list, the elements can be any size.

    Although this function will be used here primarily to convert nucleotide
    sequences into one-hot encoding with an alphabet of size 4, in principle
    this function can be used for any types of sequences.

    Parameters
    ----------
    sequence : str or list
        The sequence to convert to a one-hot encoding.

    alphabet : set or tuple or list
        A pre-defined alphabet where the ordering of the symbols is the same
        as the index into the returned tensor, i.e., for the alphabet ['A', 'B']
        the returned tensor will have a 1 at index 0 if the character was 'A'.
        Characters outside the alphabet are ignored and none of the indexes are
        set to 1. Default is ['A', 'C', 'G', 'T'].

    dtype : str or numpy.dtype, optional
        The data type of the returned encoding. Default is int8.

    desc : str or None, optional
        The title to display in the progress bar.

    verbose : bool or str, optional
        Whether to display a progress bar. If a string is passed in, use as the
        name of the progressbar. Default is False.

    kwargs : arguments
        Arguments to be passed into tqdm. Default is None.

    Returns
    -------
    ohe : numpy.ndarray
        A binary matrix of shape (alphabet_size, sequence_length) where
        alphabet_size is the number of unique elements in the sequence and
        sequence_length is the length of the input sequence.
    """

    d = verbose is False
    alphabet_lookup = {char: i for i, char in enumerate(alphabet)}

    ohe = numpy.zeros((len(sequence), len(alphabet)), dtype=dtype)
    for i, char in tqdm(enumerate(sequence), disable=d, desc=desc, **kwargs):
        idx = alphabet_lookup.get(char, -1)
        if idx != -1:
            ohe[i, idx] = 1

    return ohe


class DataGenerator(torch.utils.data.Dataset):
    """A data generator for BPNet inputs.

    This generator takes in an extracted set of sequences, output signals,
    and control signals, and will return a single element with random
    jitter and reverse-complement augmentation applied. Jitter is implemented
    efficiently by taking in data that is wider than the in/out windows by
    two times the maximum jitter and windows are extracted from that.
    Essentially, if an input window is 1000 and the maximum jitter is 128, one
    would pass in data with a length of 1256 and a length 1000 window would be
    extracted starting between position 0 and 256. This  generator must be 
    wrapped by a PyTorch generator object.

    Parameters
    ----------
    sequences: torch.tensor, shape=(n, 4, in_window+2*max_jitter)
        A one-hot encoded tensor of `n` example sequences, each of input 
        length `in_window`. See description above for connection with jitter.

    signals: torch.tensor, shape=(n, t, out_window+2*max_jitter)
        The signals to predict, usually counts, for `n` examples with
        `t` output tasks (usually 2 if stranded, 1 otherwise), each of 
        output length `out_window`. See description above for connection 
        with jitter.

    controls: torch.tensor, shape=(n, t, out_window+2*max_jitter) or None, optional
        The control signal to take as input, usually counts, for `n`
        examples with `t` strands and output length `out_window`. If
        None, does not return controls.

    in_window: int, optional
        The input window size. Default is 2114.

    out_window: int, optional
        The output window size. Default is 1000.

    max_jitter: int, optional
        The maximum amount of jitter to add, in either direction, to the
        midpoints that are passed in. Default is 0.

    reverse_complement: bool, optional
        Whether to reverse complement-augment half of the data. Default is False.

    random_state: int or None, optional
        Whether to use a deterministic seed or not.
    """

    def __init__(self, sequences, signals, controls=None, in_window=2114, 
        out_window=1000, max_jitter=0, reverse_complement=False, 
        random_state=None):
        self.in_window = in_window
        self.out_window = out_window
        self.max_jitter = max_jitter

        self.reverse_complement = reverse_complement
        self.random_state = numpy.random.RandomState(random_state)

        self.signals = signals
        self.controls = controls
        self.sequences = sequences	

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        i = self.random_state.choice(len(self.sequences))
        j = 0 if self.max_jitter == 0 else self.random_state.randint(self.max_jitter*2) 

        X = self.sequences[i]
        y = self.signals[i]

        if self.controls is not None:
            X_ctl = self.controls[i]

        if self.controls is not None:
            return X, X_ctl, y

        return X, y


def extract_loci(loci, sequences, signals=None, controls=None, chroms=None, 
    in_window=2114, out_window=1000, max_jitter=0, min_counts=None,
    max_counts=None, n_loci=None, verbose=False):
    """Extract sequences and signals at coordinates from a locus file.

    This function will take in genome-wide sequences, signals, and optionally
    controls, and extract the values of each at the coordinates specified in
    the locus file/s and return them as tensors.

    Signals and controls are both lists with the length of the list, n_s
    and n_c respectively, being the middle dimension of the returned
    tensors. Specifically, the returned tensors of size 
    (len(loci), n_s/n_c, (out_window/in_wndow)+max_jitter*2).

    The values for sequences, signals, and controls, can either be filepaths
    or dictionaries of numpy arrays or a mix of the two. When a filepath is 
    passed in it is loaded using pyfaidx or pyBigWig respectively.   

    Parameters
    ----------
    loci: str or pandas.DataFrame or list/tuple of such
        Either the path to a bed file or a pandas DataFrame object containing
        three columns: the chromosome, the start, and the end, of each locus
        to train on. Alternatively, a list or tuple of strings/DataFrames where
        the intention is to train on the interleaved concatenation, i.e., when
        you want to train on peaks and negatives.

    sequences: str or dictionary
        Either the path to a fasta file to read from or a dictionary where the
        keys are the unique set of chromosoms and the values are one-hot
        encoded sequences as numpy arrays or memory maps.

    signals: list of strs or list of dictionaries or None, optional
        A list of filepaths to bigwig files, where each filepath will be read
        using pyBigWig, or a list of dictionaries where the keys are the same
        set of unique chromosomes and the values are numpy arrays or memory
        maps. If None, no signal tensor is returned. Default is None.

    controls: list of strs or list of dictionaries or None, optional
        A list of filepaths to bigwig files, where each filepath will be read
        using pyBigWig, or a list of dictionaries where the keys are the same
        set of unique chromosomes and the values are numpy arrays or memory
        maps. If None, no control tensor is returned. Default is None. 

    chroms: list or None, optional
        A set of chromosomes to extact loci from. Loci in other chromosomes
        in the locus file are ignored. If None, all loci are used. Default is
        None.

    in_window: int, optional
        The input window size. Default is 2114.

    out_window: int, optional
        The output window size. Default is 1000.

    max_jitter: int, optional
        The maximum amount of jitter to add, in either direction, to the
        midpoints that are passed in. Default is 0.

    min_counts: float or None, optional
        The minimum number of counts, summed across the length of each example
        and across all tasks, needed to be kept. If None, no minimum. Default 
        is None.

    max_counts: float or None, optional
        The maximum number of counts, summed across the length of each example
        and across all tasks, needed to be kept. If None, no maximum. Default 
        is None.  

    n_loci: int or None, optional
        A cap on the number of loci to return. Note that this is not the
        number of loci that are considered. The difference is that some
        loci may be filtered out for various reasons, and those are not
        counted towards the total. If None, no cap. Default is None.

    verbose: bool, optional
        Whether to display a progress bar while loading. Default is False.

    Returns
    -------
    seqs: torch.tensor, shape=(n, 4, in_window+2*max_jitter)
        The extracted sequences in the same order as the loci in the locus
        file after optional filtering by chromosome.

    signals: torch.tensor, shape=(n, len(signals), out_window+2*max_jitter)
        The extracted signals where the first dimension is in the same order
        as loci in the locus file after optional filtering by chromosome and
        the second dimension is in the same order as the list of signal files.
        If no signal files are given, this is not returned.

    controls: torch.tensor, shape=(n, len(controls), out_window+2*max_jitter)
        The extracted controls where the first dimension is in the same order
        as loci in the locus file after optional filtering by chromosome and
        the second dimension is in the same order as the list of control files.
        If no control files are given, this is not returned.
    """

    seqs, signals_, controls_ = [], [], []
    in_width, out_width = in_window // 2, out_window // 2

    # Load the sequences
    if isinstance(sequences, str):
        sequences = pyfaidx.Fasta(sequences)

    names = ['chrom', 'start', 'end']
    if not isinstance(loci, (tuple, list)):
        loci = [loci]

    loci_dfs = []
    for i, df in enumerate(loci):
        if isinstance(df, str):
            df = pandas.read_csv(df, sep='\t', usecols=[0, 1, 2], 
                header=None, index_col=False, names=names)
        elif isinstance(df, pandas.DataFrame):
            df = df.iloc[:, [0, 1, 2]].copy()

        df['idx'] = numpy.arange(len(df)) * len(loci) + i
        loci_dfs.append(df)

    loci = pandas.concat(loci_dfs).set_index("idx").sort_index().reset_index(drop=True)

    if chroms is not None:
        loci = loci[numpy.isin(loci['chrom'], chroms)]
    print("loci shape", loci.shape)

    # Load the signal and optional control tracks if filenames are given
    _signals = []
    if signals is not None:
        for i, signal in enumerate(signals):
            if isinstance(signal, str):
                signal = pyBigWig.open(signal)
            _signals.append(signal)

        signals = _signals


    _controls = []
    if controls is not None:
        for i, control in enumerate(controls):
            if control == "":
                _controls.append("")
            else:
                if isinstance(control, str):
                    control = pyBigWig.open(control, "r")
                _controls.append(control)

        controls = _controls
    print("done iterating through controls")

    desc = "Loading Loci"
    d = not verbose

    max_width = max(in_width, out_width)
    loci_count = 0
    for chrom, start, end in tqdm(loci.values, disable=d, desc=desc):
        mid = start + (end - start) // 2

        if start - max_width - max_jitter < 0:
            continue

        if end + max_width + max_jitter >= len(sequences[chrom]):
            continue

        if n_loci is not None and loci_count == n_loci:
            break 

        start = mid - out_width - max_jitter
        end = mid + out_width + max_jitter

        # Extract the signal from each of the signal files
        if signals is not None:
            signals_.append([])
            for signal in signals:
                if isinstance(signal, dict):
                    signal_ = signal[chrom][start:end]
                else:
                    signal_ = signal.values(chrom, start, end, numpy=True)
                    signal_ = numpy.nan_to_num(signal_)

                signals_[-1].append(signal_)

        # For the sequences and controls extract a window the size of the input
        '''start = mid - in_width - max_jitter
        end = mid + in_width + max_jitter'''

        # Extract the controls from each of the control files
        if controls is not None:
            controls_.append([])
            for control in controls:
                if isinstance(control, dict):
                    control_ = control[chrom][start:end]
                elif control == "":
                    control_ = numpy.zeros(end - start)
                else:
                    control_ = control.values(chrom, start, end, numpy=True)
                    control_ = numpy.nan_to_num(control_)

                controls_[-1].append(control_)

        # Extract the sequence
        if isinstance(sequences, dict):
            seq = sequences[chrom][start:end].T
        else:
            seq = one_hot_encode(sequences[chrom][start:end].seq.upper(),
                alphabet=['A', 'C', 'G', 'T']).T

        seqs.append(seq)
        loci_count += 1
    print('done with chrom, start, end')
    seqs = torch.tensor(numpy.array(seqs), dtype=torch.float32)

    if signals is not None:
        signals_ = torch.tensor(numpy.array(signals_), dtype=torch.float32)

        idxs = torch.ones(signals_.shape[0], dtype=torch.bool)
        if max_counts is not None:
            idxs = (idxs) & (signals_.sum(dim=(1, 2)) < max_counts)
        if min_counts is not None:
            idxs = (idxs) & (signals_.sum(dim=(1, 2)) > min_counts)

        if controls is not None:
            controls_ = torch.tensor(numpy.array(controls_), dtype=torch.float32)
            return seqs[idxs], signals_[idxs], controls_[idxs]

        return seqs[idxs], signals_[idxs]
    else:
        if controls is not None:
            controls_ = torch.tensor(numpy.array(controls_), dtype=torch.float32)
            return seqs, controls_

        return seqs

def extract_signals(loci, sequences, signals=None, controls=None, chroms=None, 
    in_window=2114, out_window=1000, max_jitter=0, min_counts=None,
    max_counts=None, n_loci=None, verbose=False):

    # print(loci)
    # print(signals)

    seqs, signals_, controls_ = [], [], []
    in_width, out_width = in_window // 2, out_window // 2


    names = ['chrom', 'start', 'end']
    if not isinstance(loci, (tuple, list)):
        loci = [loci]

    loci_dfs = []
    for i, df in enumerate(loci):
        if isinstance(df, str):
            try:
                df = pandas.read_csv(df, sep='\t', usecols=[0, 1, 2], 
                    header=None, index_col=False, names=names)
            except:
                print("File Doesn't Exist!")
                return
            df['idx'] = numpy.arange(len(df)) * len(loci) + i
        loci_dfs.append(df)

    loci = pandas.concat(loci_dfs).set_index("idx").sort_index().reset_index(drop=True)

    if chroms is not None:
        loci = loci[numpy.isin(loci['chrom'], chroms)]

    # Load the signal and optional control tracks if filenames are given
    _signals = []
    if signals is not None:
        for i, signal in enumerate(signals):
            if isinstance(signal, str):
                try:
                    signal = pyBigWig.open(signal)
                except:
                    print("Null File")
                    return
            _signals.append(signal)

        signals = _signals

    _controls = []
    if controls is not None:
        for i, control in enumerate(controls):
            if isinstance(control, str):
                control = pyBigWig.open(control, "r")
            _controls.append(control)

        controls = _controls

    desc = "Loading Loci"
    d = not verbose

    max_width = max(in_width, out_width)
    loci_count = 0

    # print(loci)
    # print(loci.values)
    for chrom, start, end in tqdm(loci.values, disable=d, desc=desc):
        mid = start + (end - start) // 2

        if start - max_width - max_jitter < 0:
            continue

        if n_loci is not None and loci_count == n_loci:
            break 

        start = mid - out_width - max_jitter
        end = mid + out_width + max_jitter

        # Extract the signal from each of the signal files
        if signals is not None:
            signals_.append([])
            for signal in signals:
                if isinstance(signal, dict):
                    signal_ = signal[chrom][start:end]
                else:
                    try:
                        signal_ = signal.values(chrom, start, end, numpy=True)
                        signal_ = numpy.nan_to_num(signal_)
                    except:
                        print("error with interval bounds")
                        print(signal)
                        print(chrom)
                        print(start)
                        print(end)
                        print(signals_)

                signals_[-1].append(signal_)

        # For the sequences and controls extract a window the size of the input
        start = mid - in_width - max_jitter
        end = mid + in_width + max_jitter

        # Extract the controls from each of the control files
        if controls is not None:
            controls_.append([])
            for control in controls:
                if isinstance(control, dict):
                    control_ = control[chrom][start:end]
                else:
                    control_ = control.values(chrom, start, end, numpy=True)
                    control_ = numpy.nan_to_num(control_)

                controls_[-1].append(control_)

        loci_count += 1

    if signals is not None:
        signals_ = torch.tensor(numpy.array(signals_), dtype=torch.float32)

        idxs = torch.ones(signals_.shape[0], dtype=torch.bool)
        if max_counts is not None:
            idxs = (idxs) & (signals_.sum(dim=(1, 2)) < max_counts)
        if min_counts is not None:
            idxs = (idxs) & (signals_.sum(dim=(1, 2)) > min_counts)

        return signals_[idxs]


def PeakGenerator(loci, sequences, signals, controls=None, chroms=None, 
    in_window=2114, out_window=1000, max_jitter=0, reverse_complement=True, 
    min_counts=None, max_counts=None, random_state=None, pin_memory=True, 
    num_workers=0, batch_size=32, verbose=False):
    """This is a constructor function that handles all IO.

    This function will extract signal from all signal and control files,
    pass that into a DataGenerator, and wrap that using a PyTorch data
    loader. This is the only function that needs to be used.

    Parameters
    ----------
    loci: str or pandas.DataFrame or list/tuple of such
        Either the path to a bed file or a pandas DataFrame object containing
        three columns: the chromosome, the start, and the end, of each locus
        to train on. Alternatively, a list or tuple of strings/DataFrames where
        the intention is to train on the interleaved concatenation, i.e., when
        you want ot train on peaks and negatives.

    sequences: str or dictionary
        Either the path to a fasta file to read from or a dictionary where the
        keys are the unique set of chromosoms and the values are one-hot
        encoded sequences as numpy arrays or memory maps.

    signals: list of strs or list of dictionaries
        A list of filepaths to bigwig files, where each filepath will be read
        using pyBigWig, or a list of dictionaries where the keys are the same
        set of unique chromosomes and the values are numpy arrays or memory
        maps.

    controls: list of strs or list of dictionaries or None, optional
        A list of filepaths to bigwig files, where each filepath will be read
        using pyBigWig, or a list of dictionaries where the keys are the same
        set of unique chromosomes and the values are numpy arrays or memory
        maps. If None, no control tensor is returned. Default is None. 

    chroms: list or None, optional
        A set of chromosomes to extact loci from. Loci in other chromosomes
        in the locus file are ignored. If None, all loci are used. Default is
        None.

    in_window: int, optional
        The input window size. Default is 2114.

    out_window: int, optional
        The output window size. Default is 1000.

    max_jitter: int, optional
        The maximum amount of jitter to add, in either direction, to the
        midpoints that are passed in. Default is 128.

    reverse_complement: bool, optional
        Whether to reverse complement-augment half of the data. Default is True.

    min_counts: float or None, optional
        The minimum number of counts, summed across the length of each example
        and across all tasks, needed to be kept. If None, no minimum. Default 
        is None.

    max_counts: float or None, optional
        The maximum number of counts, summed across the length of each example
        and across all tasks, needed to be kept. If None, no maximum. Default 
        is None.  

    random_state: int or None, optional
        Whether to use a deterministic seed or not.

    pin_memory: bool, optional
        Whether to pin page memory to make data loading onto a GPU easier.
        Default is True.

    num_workers: int, optional
        The number of processes fetching data at a time to feed into a model.
        If 0, data is fetched from the main process. Default is 0.

    batch_size: int, optional
        The number of data elements per batch. Default is 32.

    verbose: bool, optional
        Whether to display a progress bar while loading. Default is False.

    Returns
    -------
    X: torch.utils.data.DataLoader
        A PyTorch DataLoader wrapped DataGenerator object.
    """

    X = extract_loci(loci=loci, sequences=sequences, signals=signals, 
        controls=controls, chroms=chroms, in_window=in_window, 
        out_window=out_window, max_jitter=max_jitter, min_counts=min_counts,
        max_counts=max_counts, verbose=verbose)

    if controls is not None:
        sequences, signals_, controls_ = X
    else:
        sequences, signals_ = X
        controls_ = None

    X_gen = DataGenerator(sequences, signals_, controls=controls_, 
        in_window=in_window, out_window=out_window, max_jitter=max_jitter,
        reverse_complement=reverse_complement, random_state=random_state)

    X_gen = torch.utils.data.DataLoader(X_gen, pin_memory=pin_memory,
        num_workers=num_workers, batch_size=batch_size) 

    return X_gen

In [4]:
import torch

# from datafunctions_withprofile import extract_loci, PeakGenerator
#from DefinitelyNotBPNet import DefinitelyNotBPNet

root = "/users/myin25/projects/celltype_specificity/"
cell_types = ['K562', 'CACO2', 'A673', 'HUVEC']
histone_types = ['H3K9me3', 'H3K4me1', 'H3K27me3', 'H3K27ac', 'H3K36me3', 'H3K4me3', 'H3K9ac', 'H3K79me2']
sequences = root + "refs/hg38.fasta"
peaks = root + "data/procap/union_peaks_fold1_train.bed.gz"
peaks_val = root + "data/procap/union_peaks_fold1_val.bed.gz"

histonefolders = {celltype: root + "data/{}".format(celltype) for celltype in cell_types}

def getfoldchangebigwig(cell, histone):
    if ((cell == 'CACO2' and (histone == 'H3K27ac' or histone == 'H3K9ac' or histone == 'H3K79me2')) or
    (cell == 'A673' and (histone == 'H3K9ac' or histone == 'H3K79me2'))):
        return ""
    return histonefolders[cell] + '/{}/foldchange.bigWig'.format(histone)

# Paths to different histones


levels_actual_path = {'K562' : [root + "data/procap/observed/K562/5prime.neg.bigWig", root + "data/procap/observed/K562/5prime.pos.bigWig"],
                       'CACO2' : [root + "data/procap/observed/CACO2/5prime.neg.bigWig", root + "data/procap/observed/CACO2/5prime.pos.bigWig"],
                       'A673' : [root + "data/procap/observed/A673/5prime.neg.bigWig", root + "data/procap/observed/A673/5prime.pos.bigWig"],
                       'HUVEC' : [root + "data/procap/observed/HUVEC/5prime.neg.bigWig", root + "data/procap/observed/HUVEC/5prime.pos.bigWig"]}

# peaks = root + "data/procap/observed/K562/peaks_fold1_train.bed.gz"
# peaks_val = root + "data/procap/observed/K562/peaks_fold1_val.bed.gz"

training_chroms = ['chr{}'.format(i) for i in range(1, 23)]
training_chroms.append('chrX')
training_chroms.append('chrY')
valid_chroms = ['chr{}'.format(i) for i in range(1, 23)]
valid_chroms.append('chrX')
valid_chroms.append('chrY')


In [5]:
cell = 'K562'
y_train = levels_actual_path[cell]

signalpaths = []
for histone in histone_types:
    if not ((cell == 'CACO2' and (histone == 'H3K27ac' or histone == 'H3K9ac' or histone == 'H3K79me2')) or
    (cell == 'A673' and (histone == 'H3K9ac' or histone == 'H3K79me2'))):
        signalpaths.append(getfoldchangebigwig(cell, histone))
    else:
        signalpaths.append("")



In [6]:
training_data = PeakGenerator(peaks, sequences, y_train, signalpaths, chroms=training_chroms)

loci shape (77109, 3)
done iterating through controls
done with chrom, start, end


In [7]:
def getallhistonesignals(signalpaths):
    _, y_valid, controls = extract_loci(loci=peaks_val, sequences=sequences, signals=y_train, controls=signalpaths, chroms=valid_chroms)

    return y_valid, controls


In [8]:
y_valid, X_valid = getallhistonesignals(signalpaths)

loci shape (14269, 3)
done iterating through controls
done with chrom, start, end


In [15]:
model = DefinitelyNotBPNet(8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)

In [None]:
model.fit(training_data, optimizer, X_valid=X_valid, y_valid=y_valid)

In [None]:
# Check the expected MSE when predicting just using the average
_, train_signalss__ = extract_loci(loci=peaks, sequences=sequences, signals=y_train, controls=None, chroms=training_chroms)


# Read in signals, take sum across axis to get average log observed value
train_signalss__ = extract_signals(loci=peaks, sequences=sequences, signals=y_train, controls=None, chroms=training_chroms)

train_signalss__ = torch.sum(train_signalss__, (1, 2))
print(train_signalss__)

train_signalss__ = torch.log(train_signalss__ + 1)
print(train_signalss__)

plt.hist(train_signalss__, bins=20)
print(torch.mean(train_signalss__))

# Calculate the mse
avgmse = torch.fill_(torch.empty((14269)), 2.6262)

# Compare the logs of the two datasets
count_mse = torch.square(torch.log(y_valid + 1) - avgmse)
count_mse = torch.mean(count_mse.squeeze(), dim=-1)
print(count_mse)