In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

### Tutorial B7: Cartesian Product

Sometimes, when using multi-input models, one wants to run a function on the cartesian product of two inputs. Put another way, given a sequence input $X$ and some other input $Y$, one wants to make predictions for $(X_0, Y_0), (X_0, Y_1), \dots, (X_{n-1}, Y_{m-1})$, where $X$ has $n$ elements and $Y$ has $m$ elements. Of course, one could simply run the function for a fixed value of $Y$ across all $X$ (or vice versa) and then change the value of $Y$ each time. However, coding this yourself is inconvenient and easy to do inefficiently, particularly if you will sometimes have a small number of $X$ and other times a small number of $Y$.

Instead of having to implement this yourself, `tangermeme` provides `apply_pairwise` and `apply_product` to make applying functions across products like this easy and memory efficient. `apply_pairwise` yields examples from the pairwise product between `X` and a set of arguments whose ordering is paired. For example, if you have $X$ with $5$ elements and arguments $a$ and $b$ that each have $3$, you would get the $5 \times 3 = 15$ examples $(X_0, a_0, b_0), (X_0, a_1, b_1), \dots, (X_4, a_1, b_1), (X_4, a_2, b_2)$. In contrast, `apply_product` applies a function to the cartesian product of the sequences and each of the arguments provided. This means that you would instead get the $5 \times 3 \times 3 = 45$ examples $(X_0, a_0, b_0), (X_0, a_0, b_1), (X_0, a_0, b_2), \dots, (X_2, a_0, b_1), \dots, (X_4, a_2, b_1), (X_4, a_2, b_2)$. Although `apply_product` is more general, in the sense that it can be applied across any number of independent arguments, it is not the right function to run when you have paired inputs like cell information.
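To make the two orderings concrete, here is a small sketch that enumerates only the index tuples each function would visit for the example above. It uses plain Python and `itertools` rather than `tangermeme`, and the names `pairwise_idxs` and `product_idxs` are made up for illustration.

```python
from itertools import product

n_X, n_args = 5, 3  # 5 sequences; a and b each have 3 entries

# Pairwise: a and b advance together, so the same index j is used for both.
pairwise_idxs = [(i, j, j) for i in range(n_X) for j in range(n_args)]

# Product: every combination of (X index, a index, b index).
product_idxs = list(product(range(n_X), range(n_args), range(n_args)))

print(len(pairwise_idxs), pairwise_idxs[:3])  # 15 [(0, 0, 0), (0, 1, 1), (0, 2, 2)]
print(len(product_idxs), product_idxs[:4])    # 45 [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)]
```

Every pairwise tuple also appears in the full product; the pairwise version simply skips the combinations where the indices into `a` and `b` differ.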

Conceptually, the simplest way to set up this function is to unravel the entire product into CPU memory and then run the provided function on the entire thing. However, this can take a huge amount of memory, particularly if the product is over several arguments. In practice, it's better to construct each batch iteratively and only run one batch at a time through the model. That way, only the model predictions are stored in CPU memory as opposed to the (usually much larger) inputs.
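As a rough illustration of the batch-at-a-time idea (not the library's actual implementation), the sketch below lazily walks the product of indices and only ever holds one batch of index tuples. The function name `batched_product` and its parameters are hypothetical; in a real setting each batch of indices would be used to gather the corresponding tensors right before calling the model.

```python
from itertools import islice, product

def batched_product(n_X, arg_lengths, batch_size):
    """Yield lists of index tuples from the product, one batch at a time."""
    # `product` is lazy: tuples are generated on demand, not stored up front.
    idxs = product(range(n_X), *(range(n) for n in arg_lengths))
    while True:
        batch = list(islice(idxs, batch_size))
        if not batch:
            return
        yield batch

# 5 sequences crossed with two arguments of length 3 gives 45 examples.
batch_sizes = [len(batch) for batch in batched_product(5, (3, 3), batch_size=16)]
print(batch_sizes)  # [16, 16, 13]
```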

Let's see all this in action with a toy model that takes an input, flattens it, passes it through a dense layer, and then optionally scales and shifts the output.

In [2]:
import torch

class FlattenDense(torch.nn.Module):
	def __init__(self, length=10):
		super(FlattenDense, self).__init__()
		self.dense = torch.nn.Linear(length*4, 3)

	def forward(self, X, alpha=0, beta=1):
		X = X.reshape(X.shape[0], -1)
		return self.dense(X) * beta + alpha

This model has two optional inputs: `alpha`, which is an additive constant on the output from the dense layer, and `beta`, which is a multiplicative factor. Yes, it's redundant to have these factors after a dense layer which is doing a pretty similar thing, but this is meant just to demonstrate how to use the functions and to confirm that it's doing the expected thing.

Let's start off by generating some random one-hot encodings and running the model on them. 

In [3]:
from tangermeme.utils import random_one_hot
torch.manual_seed(0)

X = random_one_hot((5, 4, 10), random_state=0).float()
model = FlattenDense()

y = model(X)
y

tensor([[-0.3154, -0.1625, -0.3183],
        [-0.0866,  0.5461, -0.0244],
        [ 0.3089, -0.2828, -0.1485],
        [ 0.1671, -0.1341, -0.3094],
        [-0.0627,  0.0088,  0.3471]], grad_fn=<AddBackward0>)

#### Apply Pairwise

`apply_pairwise` is the correct function to use if you have data that has two axes, where one of the axes is sequences, and the other axis contains multiple tensors of paired information. As an example, if you have a DragoNNFruit model which makes predictions for chromatin accessibility for each cell in a single-cell ATAC-seq experiment, the inputs are sequences, a vector representing the state of the cell, and the read depth of the cell. Because cell state and read depth are paired -- both come from the same cell -- you want to do the product between `X` and `(cell_state, read_depth)` such that you get $(X_0, c_0, r_0), (X_0, c_1, r_1), (X_0, c_2, r_2)...$. Importantly, you do not want to do the full cross product because that will create examples where the read depths and cell states come from different cells.

##### Predict

We can begin by checking what the predictions would be when using this function with arguments that only have a batch size of 1. Conceptually, this should be identical to just running the predict function, and we can compare our results here to the predictions that we got above.

In [4]:
from tangermeme.predict import predict
from tangermeme.product import apply_pairwise

torch.manual_seed(0)
alpha = torch.zeros(1, 1)
beta = torch.ones(1, 1)

y_product = apply_pairwise(predict, model, X, args=(alpha, beta))[:, 0]
y_product

tensor([[-0.3154, -0.1625, -0.3183],
        [-0.0866,  0.5461, -0.0244],
        [ 0.3089, -0.2828, -0.1485],
        [ 0.1671, -0.1341, -0.3094],
        [-0.0627,  0.0088,  0.3471]])

Looks like the values are identical, although we do have to index into the result because it has an additional axis whose length is that of the argument tensors (here, 1).

Next, we can look at what happens when we set `alpha` and `beta` to be more than just one example.

In [5]:
alpha = torch.zeros(2, 1)
beta = torch.ones(2, 1)

y_product = apply_pairwise(predict, model, X, args=(alpha, beta))
y_product

tensor([[[-0.3154, -0.1625, -0.3183],
         [-0.3154, -0.1625, -0.3183]],

        [[-0.0866,  0.5461, -0.0244],
         [-0.0866,  0.5461, -0.0244]],

        [[ 0.3089, -0.2828, -0.1485],
         [ 0.3089, -0.2828, -0.1485]],

        [[ 0.1671, -0.1341, -0.3094],
         [ 0.1671, -0.1341, -0.3094]],

        [[-0.0627,  0.0088,  0.3471],
         [-0.0627,  0.0088,  0.3471]]])

Here, we see that the results are the same for adjacent predictions, which makes sense because `alpha` is just zeros in both cases and `beta` is just ones in both cases. Next, we can see that changing the values of `alpha` and `beta` will lead to different predictions.

In [6]:
alpha = torch.randn(2, 1)
beta = torch.randn(2, 1)

y_product = apply_pairwise(predict, model, X, args=(alpha, beta))
y_product

tensor([[[ 2.2283,  1.8950,  2.2344],
         [-0.4727, -0.3858, -0.4743]],

        [[ 1.7297,  0.3512,  1.5941],
         [-0.3427,  0.0170, -0.3073]],

        [[ 0.8680,  2.1571,  1.8646],
         [-0.1178, -0.4542, -0.3779]],

        [[ 1.1769,  1.8331,  2.2151],
         [-0.1984, -0.3696, -0.4693]],

        [[ 1.6775,  1.5218,  0.7847],
         [-0.3290, -0.2884, -0.0961]]])

As mentioned repeatedly, `tangermeme` tries to be as assumption-free as possible. This means that `alpha` and `beta` can be any shape that works with the math provided in the implementation. Because three outputs are generated for each example, each row of our `alpha` and `beta` tensors can also have three values, one per output.

In [7]:
alpha = torch.zeros(1, 3)
beta = torch.ones(1, 3)

y_product = apply_pairwise(predict, model, X, args=(alpha, beta))
y_product.shape

torch.Size([5, 1, 3])

##### Attributions

In addition to working with the `predict` function, these product functions can take in any other tangermeme function and apply it to the respective product of examples. This means that we can apply `deep_lift_shap` just as easily as we apply `predict`.

In [8]:
from tangermeme.deep_lift_shap import deep_lift_shap

y_attr = apply_pairwise(deep_lift_shap, model, X, args=(alpha, beta))
y_attr.shape

torch.Size([5, 1, 4, 10])

The shape follows from the previous examples: the first dimension is the size of `X`, the second dimension is the size of `alpha` and `beta`, and the remaining dimensions are those from the function being applied.

##### Marginalize

Next, we can apply `marginalize` just as easily as we can apply `predict`. A major difference in the output here is that two tensors are returned: one before making the substitution, and one after. Importantly, when using `apply_pairwise` and `apply_product`, additional arguments can be passed into the inner function simply by adding them as keyword arguments. Note the `motif="TGA"` below.

In [9]:
from tangermeme.marginalize import marginalize

y_before, y_after = apply_pairwise(marginalize, model, X, motif="TGA", args=(alpha, beta))
y_before[:, 0], y_after[:, 0]

(tensor([[-0.3154, -0.1625, -0.3183],
         [-0.0866,  0.5461, -0.0244],
         [ 0.3089, -0.2828, -0.1485],
         [ 0.1671, -0.1341, -0.3094],
         [-0.0627,  0.0088,  0.3471]]),
 tensor([[-0.0615, -0.2536, -0.1744],
         [-0.1973,  0.6584,  0.2584],
         [ 0.2046,  0.1125, -0.0750],
         [ 0.0317,  0.0328, -0.1166],
         [ 0.0374,  0.1503,  0.4602]]))

If we wanted to also pass in an argument for `start` we could just keep adding in arguments.

In [10]:
y_before, y_after = apply_pairwise(marginalize, model, X, motif="TGA", start=0, args=(alpha, beta))
y_before[:, 0], y_after[:, 0]

(tensor([[-0.3154, -0.1625, -0.3183],
         [-0.0866,  0.5461, -0.0244],
         [ 0.3089, -0.2828, -0.1485],
         [ 0.1671, -0.1341, -0.3094],
         [-0.0627,  0.0088,  0.3471]]),
 tensor([[-0.3603, -0.2071, -0.2159],
         [-0.1005,  0.3231, -0.1471],
         [ 0.1569, -0.3900, -0.1329],
         [ 0.1721, -0.2478, -0.2496],
         [-0.1870, -0.0630,  0.1463]]))

Note that the values before making the substitution are the same, but the values after are different.

Naturally, being able to pass in any function, e.g., `marginalize`, and being able to pass in any arguments to those functions makes it possible to nest functions even further! After all, `marginalize` itself defaults to predictions but can apply other functions just as easily. Although the signature will be a little bit messy, we can easily use `apply_pairwise` with the `marginalize` function that itself is applying `deep_lift_shap` instead of `predict`! All we have to do is use the `additional_func_kwargs` argument, which is a dictionary of arguments that get passed directly into the provided func. This is somewhat redundant with passing in arguments directly, but circumvents issues where you want to pass an argument into `func` that has the same name as an argument needed by `apply_pairwise`.
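The name clash described above is a general Python issue, and the pattern can be sketched without `tangermeme` at all. In the toy below, `apply_outer`, `inner`, and `negate` are hypothetical stand-ins: the outer wrapper has its own `func` parameter, so an inner argument that is also called `func` has to travel through a dictionary.

```python
def apply_outer(func, data, additional_func_kwargs=None, **kwargs):
    """Hypothetical wrapper: call `func` on each item, forwarding keyword arguments."""
    extra = additional_func_kwargs or {}
    return [func(item, **kwargs, **extra) for item in data]

def inner(item, func=abs, scale=1):
    """Hypothetical inner function that also has an argument named `func`."""
    return func(item) * scale

def negate(value):
    return -value

# `scale` does not clash with anything in `apply_outer`, so it passes straight through.
default_out = apply_outer(inner, [-1, 2], scale=2)

# Passing `func=negate` directly would clash with `apply_outer`'s own `func`
# parameter, so it is routed through the dictionary instead.
nested_out = apply_outer(inner, [-1, 2], scale=2, additional_func_kwargs={'func': negate})

print(default_out, nested_out)  # [2, 4] [2, -4]
```

Calling `apply_outer(inner, [-1, 2], func=negate)` in this sketch raises a `TypeError`, because Python sees two values for the wrapper's `func` parameter.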

In [11]:
y_before, y_after = apply_pairwise(marginalize, model, X, motif="TGA", alphabet=['A', 'C', 'G', 'T'],
                                   additional_func_kwargs={'func': deep_lift_shap}, args=(alpha, beta))
y_before.shape, y_after.shape

(torch.Size([5, 1, 4, 10]), torch.Size([5, 1, 4, 10]))

Even though it is a little messy to define the signature, look at how easy it is to do marginalized attributions across a product of examples, and you have the power to change any of the arguments in any of the functions called along the way. You can now do it in a single line instead of having to think of how to efficiently do each of the parts.

#### Apply Product

In contrast to `apply_pairwise`, `apply_product` is a more general function that will construct examples from the product of any number of arguments that have been passed in. If you have a model that takes in many inputs, each corresponding to an independent kind of value (e.g., a model that takes in DNA sequence, protein sequence, and some sort of conditions, and predicts something like binding structure), this would be the function for you. The signature is identical to `apply_pairwise` except the function is applied to more constructed examples.

Let's start off by seeing this in action with the same prediction as before.

In [12]:
from tangermeme.product import apply_product

alpha = torch.zeros(1, 1)
beta = torch.ones(1, 1)

y_product = apply_product(predict, model, X, args=(alpha, beta))[:, 0, 0]
y_product

tensor([[-0.3154, -0.1625, -0.3183],
        [-0.0866,  0.5461, -0.0244],
        [ 0.3089, -0.2828, -0.1485],
        [ 0.1671, -0.1341, -0.3094],
        [-0.0627,  0.0088,  0.3471]])

Looks like we are getting the same thing as before, except that there is an additional axis that needs to be indexed into because one of the axes corresponds to `alpha` and one of them corresponds to `beta`.
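The shape bookkeeping can be summarized in a small helper. This is only a sketch inferred from the output shapes shown in this tutorial, not from the library's source, and the function `expected_shape` and its parameters are hypothetical.

```python
def expected_shape(n_X, arg_lengths, out_shape, mode):
    """Predict the output shape from the lengths of the inputs (illustrative only)."""
    if mode == "pairwise":
        # Paired arguments share one axis, so they must all have the same length.
        if len(set(arg_lengths)) != 1:
            raise ValueError("paired arguments must have equal length")
        return (n_X, arg_lengths[0]) + tuple(out_shape)
    if mode == "product":
        # Each argument gets its own axis, in the order the arguments were given.
        return (n_X,) + tuple(arg_lengths) + tuple(out_shape)
    raise ValueError(f"unknown mode: {mode}")

print(expected_shape(5, (1, 1), (3,), "pairwise"))  # (5, 1, 3)
print(expected_shape(5, (1, 1), (3,), "product"))   # (5, 1, 1, 3)
print(expected_shape(5, (3, 1), (3,), "product"))   # (5, 3, 1, 3)
```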

Since all we are doing is adding a value in a broadcasted manner, we can easily check the result by adding in the appropriate dimensions and doing the addition outside the context of this function.

In [13]:
alpha = torch.randn(3, 1)

y_product = apply_product(predict, model, X, args=(alpha,))
y_product

tensor([[[-1.4000, -1.2470, -1.4028],
         [-1.7140, -1.5611, -1.7168],
         [ 0.0879,  0.2409,  0.0851]],

        [[-1.1711, -0.5384, -1.1089],
         [-1.4852, -0.8525, -1.4230],
         [ 0.3167,  0.9494,  0.3790]],

        [[-0.7756, -1.3673, -1.2330],
         [-1.0897, -1.6814, -1.5471],
         [ 0.7123,  0.1206,  0.2548]],

        [[-0.9174, -1.2186, -1.3939],
         [-1.2315, -1.5327, -1.7080],
         [ 0.5704,  0.2693,  0.0940]],

        [[-1.1472, -1.0757, -0.7374],
         [-1.4613, -1.3898, -1.0515],
         [ 0.3407,  0.4122,  0.7505]]])

In [14]:
y.unsqueeze(1) + alpha.unsqueeze(0)

tensor([[[-1.4000, -1.2470, -1.4028],
         [-1.7140, -1.5611, -1.7168],
         [ 0.0879,  0.2409,  0.0851]],

        [[-1.1711, -0.5384, -1.1089],
         [-1.4852, -0.8525, -1.4230],
         [ 0.3167,  0.9494,  0.3790]],

        [[-0.7756, -1.3673, -1.2330],
         [-1.0897, -1.6814, -1.5471],
         [ 0.7123,  0.1206,  0.2548]],

        [[-0.9174, -1.2186, -1.3939],
         [-1.2315, -1.5327, -1.7080],
         [ 0.5704,  0.2693,  0.0940]],

        [[-1.1472, -1.0757, -0.7374],
         [-1.4613, -1.3898, -1.0515],
         [ 0.3407,  0.4122,  0.7505]]], grad_fn=<AddBackward0>)

Same values. If we add in a `beta` value, we see the same thing.

In [15]:
torch.manual_seed(0)
alpha = torch.randn(3, 1)
beta = torch.randn(1, 1)

y_product = apply_product(predict, model, X, args=(alpha, beta))[:, :, 0]
y_product

tensor([[[ 1.3617,  1.4486,  1.3601],
         [-0.4727, -0.3858, -0.4743],
         [-2.3581, -2.2711, -2.3597]],

        [[ 1.4918,  1.8514,  1.5271],
         [-0.3427,  0.0170, -0.3073],
         [-2.2280, -1.8684, -2.1926]],

        [[ 1.7166,  1.3803,  1.4566],
         [-0.1178, -0.4542, -0.3779],
         [-2.0032, -2.3395, -2.2632]],

        [[ 1.6360,  1.4648,  1.3651],
         [-0.1984, -0.3696, -0.4693],
         [-2.0838, -2.2550, -2.3546]],

        [[ 1.5054,  1.5460,  1.7383],
         [-0.3290, -0.2884, -0.0961],
         [-2.2144, -2.1738, -1.9815]]])

In [16]:
y.unsqueeze(1) * beta.unsqueeze(0) + alpha.unsqueeze(0)

tensor([[[ 1.3617,  1.4486,  1.3601],
         [-0.4727, -0.3858, -0.4743],
         [-2.3581, -2.2711, -2.3597]],

        [[ 1.4918,  1.8514,  1.5271],
         [-0.3427,  0.0170, -0.3073],
         [-2.2280, -1.8684, -2.1926]],

        [[ 1.7166,  1.3803,  1.4566],
         [-0.1178, -0.4542, -0.3779],
         [-2.0032, -2.3395, -2.2632]],

        [[ 1.6360,  1.4648,  1.3651],
         [-0.1984, -0.3696, -0.4693],
         [-2.0838, -2.2550, -2.3546]],

        [[ 1.5054,  1.5460,  1.7383],
         [-0.3290, -0.2884, -0.0961],
         [-2.2144, -2.1738, -1.9815]]], grad_fn=<AddBackward0>)
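The manual checks above rely on `beta` having a single row. As a sketch of how the same check might look when both arguments have several rows, each argument can be given its own broadcast axis. This assumes, as the indexing in the cells above suggests, that the argument axes follow the order of `args`; the random `y` here is a stand-in for the model output, and nothing below calls `tangermeme`.

```python
import torch

torch.manual_seed(0)
y = torch.randn(5, 3)      # stand-in for raw model outputs: 5 examples, 3 outputs
alpha = torch.randn(3, 1)  # 3 additive offsets
beta = torch.randn(2, 1)   # 2 multiplicative factors

# Target axes: (example, alpha index, beta index, output).
expected = (
    y[:, None, None, :] * beta[None, None, :, :]
    + alpha[None, :, None, :]
)

# Spot-check one entry against the scalar formula used by the toy model.
i, j, k = 4, 2, 1
manual = y[i] * beta[k] + alpha[j]
print(expected.shape, torch.allclose(expected[i, j, k], manual))
```

Comparing a tensor like `expected` against the output of `apply_product` with `torch.allclose` is usually a safer check than comparing printed values by eye, since printed tensors are rounded.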