In [73]:
import torch
from torch.autograd import Variable
import loompy
from types import SimpleNamespace
from typing import *
import logging
from tqdm import trange

In [46]:
with loompy.connect("/Users/stelin/velocity_inference/DentateGyrus.loom") as ds:
    genes = ds.ra.Selected == 1
    gene_names = ds.ra.Gene[genes]
    cells = ds.ca.Selected == 1
    expression = ds[genes, :][:, cells]
    velocity = ds["velocity"][genes, :][:, cells]

In [135]:
class VelocityInference:
    def __init__(self, lr: float = 0.0001, n_epochs: int = 10, l1_factor: float = 1, l2_factor: float = 1) -> None:
        self.lr = lr
        self.n_epochs = n_epochs
        self.loss: float = 0
        self.W: np.ndarray = None
        self.l1_factor = l1_factor
        self.l2_factor = l2_factor

    def fit(self, X: np.ndarray, Y: np.ndarray) -> Any:
        """
        Args:
            x      (n_cells, n_genes)
            y      (n_cells, n_genes)
        """
        (n_cells, n_genes) = x.shape

        # Set up the optimization problem
        logging.info("Setting up the optimization problem")
        dt = torch.float32
        self.model = SimpleNamespace(
            x=Variable(torch.tensor(X, dtype=dt)),
            y=Variable(torch.tensor(Y, dtype=dt)),
            W=Variable(torch.tensor(np.random.gamma(1, 1, size=(n_genes, n_genes)), dtype=dt), requires_grad=True)
        )
        logging.info("Optimizing")
        self.epochs(self.n_epochs)
        return self

    def epochs(self, n_epochs: int) -> Any:
        m = self.model
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.SGD([m.W], lr=self.lr)

        for epoch in trange(n_epochs):
            optimizer.zero_grad()
            prediction = m.x @ m.W
            l1_regularization = self.l1_factor * torch.norm(m.W, 1)
            l2_regularization = self.l2_factor * torch.norm(m.W, 2)
            loss_out = loss_fn(prediction, m.y) + l1_regularization + l2_regularization
            loss_out.backward()
            optimizer.step()
            m.W.data.clamp_(min=0)

        self.loss = float(loss_out)
        self.W = m.W.detach().numpy()
        return self

In [163]:
vi = VelocityInference(l1_factor = 1, n_epochs=100)
vi.fit(expression.T, velocity.T)

100%|██████████| 100/100 [05:12<00:00,  3.13s/it]


<__main__.VelocityInference at 0x10ac67748>

In [164]:
#vi.epochs(10)
gene_names[np.where(vi.W > 12)[0]]

array(['Acadl', 'Ptprn', 'Adora1', 'Ramp1', 'Brinp3', 'Syne1', 'Slc41a2',
       'Iyd', 'Gadd45b', 'Timeless', 'Gm12224', 'Guk1', 'Nfe2l1', 'Gfap',
       'Pitpna', 'Vtn', 'Dio2', 'Hist1h1b', 'Vcan', 'Cenph', 'Ntrk2',
       'Arsb', 'Prkcd', 'Otx2', 'Cmtm5', 'Ywhaz', 'Rac2', 'Shisa8',
       'Them6', 'Pdzrn4', 'Snn', 'Opa1', 'Olig2', 'Tcf19', 'Strn',
       'Psat1', 'Ostf1', 'Slc1a1', 'Atp5g3', 'Mal', 'Kif5c', 'Stmn2',
       'Rhoc', 'Wdr47', 'Gnb1', 'Reln', 'Kit', 'Slc4a4', 'Apobec1',
       'Usp5', 'Ndnf', 'Slc6a13', 'Rabac1', 'Dkkl1', 'Slc17a7', 'Nlrx1',
       'Htr1b', 'Arpp21', 'Esam', 'Tmem42', 'Kif4', 'Fam199x'],
      dtype=object)

In [165]:
density = np.count_nonzero(vi.W) / (vi.W.shape[0] * vi.W.shape[1])
density

0.9697340695387442

In [21]:
# The input and output dimensions of the neural network are the number of genes
(n_genes, n_cells) = expression.shape

# x is the the expression matrix
x = Variable(torch.tensor(expression.T, dtype=torch.float32))
# y is the velocity matrix
y = Variable(torch.tensor(velocity.T, dtype=torch.float32))

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(n_genes, n_genes)
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algoriths. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    l1_regularization, l2_regularization = torch.tensor(0, dtype=torch.float32), torch.tensor(0, dtype=torch.float32)
    for param in model.parameters():
        l1_regularization += torch.norm(param, 1)
        l2_regularization += torch.norm(param, 2)

    loss = loss_fn(y_pred, y) + l1_regularization + l2_regularization

    # Compute and print loss.
    print(t, loss.item(), l1_regularization.item(), l2_regularization.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers( i.e, not overwritten) whenever .backward()
    # is called. Checkout docs of torch.autograd.backward for more details.
    optimizer.zero_grad()
    
    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

0 223379792.0 82201.375 32.20370864868164
1 183328096.0 82190.40625 32.20002746582031
2 152645168.0 82180.421875 32.197513580322266
3 129659248.0 82171.2890625 32.195960998535156
4 112787896.0 82162.640625 32.1950569152832
5 100621464.0 82154.140625 32.19450378417969
6 91959960.0 82145.4765625 32.19405746459961
7 85821696.0 82136.3671875 32.193511962890625
8 81432552.0 82126.609375 32.1927375793457
9 78203248.0 82116.1015625 32.19163513183594
10 75702408.0 82104.75 32.190147399902344
11 73627840.0 82092.515625 32.18824768066406
12 71778472.0 82079.421875 32.185935974121094
13 70028648.0 82065.484375 32.18321990966797
14 68306424.0 82050.7265625 32.18012237548828
15 66577556.0 82035.2109375 32.17667770385742
16 64834244.0 82019.0 32.172916412353516
17 63086280.0 82002.1328125 32.16887283325195
18 61353080.0 81984.6640625 32.16457748413086
19 59657456.0 81966.6484375 32.160057067871094
20 58021772.0 81948.109375 32.15533447265625
21 56465708.0 81929.109375 32.150428771972656
22 55004300.

In [40]:
beta = model[0].weight.detach().numpy()

In [61]:
import numpy as np
np.max(beta)

0.06284521

In [51]:
gene_names[513]

'Gm16141'

In [63]:
model[0].data.clamp_(min=0)

AttributeError: 'Linear' object has no attribute 'data'