In [1]:
import numpy as np

import torch
use_cuda = torch.cuda.is_available()

import altair as alt
alt.data_transformers.enable('default', max_rows=None)

import noise2self_sc as n2s

In [2]:
n_classes = 2
n_latent = 32
n_features = 512

n_cells_per_class = 1024
n_cells = n_cells_per_class * n_classes

exp, class_labels, programs, lib_size, umis = n2s.simulate.make_dataset(
    n_classes, n_latent, n_cells, n_features, 
    prog_kw=dict(scale=4 * np.sqrt(1 / n_features), sparsity=0.5),
    class_kw=dict(scale=1 / np.sqrt(n_latent), sparsity=1.0),
    library_kw=dict(loc=np.log(n_features * 0.1), scale=0.5)
)

lib_size.min(), lib_size.max(), exp.shape, (umis > 0).sum() / umis.size

(31, 255, (2048, 32), 0.11116504669189453)

In [3]:
X = torch.from_numpy(umis).to(torch.float)
loglib = torch.log(X.sum(1, keepdim=True))
Y = torch.sqrt(X)

batch_size = 1024

data_loader_train, data_loader_test = n2s.training.split_dataset(
    Y, loglib, X, batch_size=batch_size, train_p=0.5, use_cuda=use_cuda
)

In [5]:
   
dca_model = n2s.autoencoder.CountAutoencoder(
    n_input=X.shape[-1], 
    n_latent=32, 
    layers=[128, 128], 
    use_cuda=use_cuda
)

criterion = n2s.training.NegativeBinomialNLLoss(eps=1e-6)

num_epochs = 50
test_iter = 5
train_loss = []
test_loss = []

In [11]:
optimizer = torch.optim.SGD(
    dca_model.parameters(), lr=0.5, momentum=0.999, weight_decay=0.0001,
)


In [12]:
for epoch in range(num_epochs):
    train_loss.append(n2s.training.train_loop(dca_model, criterion, optimizer, data_loader_train, use_cuda))
    if epoch % test_iter == 0:
        print(f"[epoch {epoch:03d}]  average training loss: {train_loss[-1]:.5f}")

        test_loss.append(n2s.training.test_loop(dca_model, criterion, optimizer, data_loader_test, use_cuda))
        
print(f"final average training loss: {train_loss[-1]:.5f}")

[epoch 000]  average training loss: 0.38109
[epoch 005]  average training loss: 0.38103
[epoch 010]  average training loss: 0.38101
[epoch 015]  average training loss: 0.38080
[epoch 020]  average training loss: 0.38047
[epoch 025]  average training loss: 0.38006
[epoch 030]  average training loss: 0.37972
[epoch 035]  average training loss: 0.37922
[epoch 040]  average training loss: 0.37886
[epoch 045]  average training loss: 0.37822
final average training loss: 0.37775
