In [1]:
import torch
import numpy as np
import torch.nn as nn

In [2]:
%pip install torchinfo
from torchinfo import summary

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [3]:
class Sampler(nn.Module):

    def __init__(self, latent_features):
        super().__init__()
        self.latent_features = latent_features

    def forward(self, mean, std):
        epsilon = np.random.standard_normal()

        sample = mean + epsilon * std

        return sample


In [4]:
class VarAutoencoder(nn.Module):

    def __init__(self, input_features, latent_features):
        super().__init__()

        self.input_features = input_features
        self.latent_features = latent_features

        relu = nn.ReLU()
        sigmoid = nn.Sigmoid()

        hidden_features = 2*latent_features
        self.encoder = nn.Sequential(
            nn.Linear(input_features, hidden_features),
            relu,
            nn.Linear(hidden_features, 2*latent_features),
            relu,
        )

        self.sampler = Sampler(latent_features)

        self.decoder = nn.Sequential(
            nn.Linear(latent_features, hidden_features),
            relu,
            nn.Linear(hidden_features, input_features),
            sigmoid,
        )

    def forward(self, x):

        x = self.encoder(x)

        mean, logvar = x[:, :self.latent_features], x[:, self.latent_features:]
        std = torch.exp(0.5 * logvar)

        z = self.sampler(mean, std)

        x = self.decoder(z)

        return x, mean, std, z



In [5]:
class KLDivergence(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, mean, std):
        return (std ** 2 + mean ** 2 - torch.log(std) - 0.5).sum()

In [6]:
class MSELoss(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, y_true, y_pred):
        return torch.mean((y_true-y_pred)**2)

In [7]:
from torch.utils.data import Dataset, DataLoader

In [8]:
class SCRNASeqDataset(Dataset):

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        sample = self.data[idx]

        return sample

In [9]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

In [10]:
def train_epoch(dataloader, model, recon_loss_fn, regular_loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    sample_counter = 0
    for batch, x in enumerate(dataloader):
        optimizer.zero_grad()

        x = x.to(device)

        # Reconstruction and regularization losses
        xx, μ, σ, z = model(x)
        regular_loss = regular_loss_fn(μ, σ)
        recon_loss = recon_loss_fn(x, xx)

        # Backpropagation
        loss = regular_loss + recon_loss
        loss.backward()
        optimizer.step()

        sample_counter += len(x)
        if batch % size//10 == 0:
            loss, current = loss.item()/len(x), batch * (len(x)+1)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [11]:
def test(dataloader, model, recon_loss_fn, regular_loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()

    regular_loss = 0
    recon_loss = 0
    with torch.no_grad():
        for x in dataloader:
            x = x.to(device)

            # Reconstruction and regularization losses
            xx, μ, σ, z = model(x)
            regular_loss += regular_loss_fn(μ, σ)
            recon_loss += recon_loss_fn(x, xx)

    avg_regular = regular_loss / num_batches
    avg_recon = recon_loss / num_batches
    avg_total = avg_regular + avg_recon
    print(f"loss total: {(avg_total):>8f}\t regular: {(avg_total):>8f}\t recon: {(avg_total):>8f}\n")

In [12]:
model = VarAutoencoder(1000, 64)
model.to(device)
summary(model, input_size=(69, 1000))

Layer (type:depth-idx)                   Output Shape              Param #
VarAutoencoder                           [69, 1000]                --
├─Sequential: 1-1                        [69, 128]                 16,512
│    └─Linear: 2-1                       [69, 128]                 128,128
├─Sequential: 1-4                        --                        (recursive)
│    └─ReLU: 2-2                         [69, 128]                 --
├─Sequential: 1-3                        --                        (recursive)
│    └─Linear: 2-3                       [69, 128]                 16,512
├─Sequential: 1-4                        --                        (recursive)
│    └─ReLU: 2-4                         [69, 128]                 --
├─Sampler: 1-5                           [69, 64]                  --
├─Sequential: 1-6                        [69, 1000]                --
│    └─Linear: 2-5                       [69, 128]                 8,320
│    └─ReLU: 2-6                         [

In [13]:
#placeholder random data
scrnaseq_data = torch.tensor(np.random.standard_normal((69, 1000)).astype(np.float32))
scrnaseq_dataset = SCRNASeqDataset(scrnaseq_data)

scrnaseq_dataloader = DataLoader(scrnaseq_dataset, batch_size=16, shuffle=True)

In [14]:
regular_loss_fn = KLDivergence()
recon_loss_fn = MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5E-3)

In [15]:
num_epochs = 5

for e in range(num_epochs):
    print(f"Epoch {e+1}\n-------------------------------")
    train_epoch(scrnaseq_dataloader, model, recon_loss_fn, regular_loss_fn, optimizer)
    test(scrnaseq_dataloader, model, regular_loss_fn, recon_loss_fn)

Epoch 1
-------------------------------
loss: 37.800728  [    0/   69]
loss: 34.076790  [   17/   69]
loss: 32.742268  [   34/   69]
loss: 32.239079  [   51/   69]
loss: 32.287088  [   24/   69]
loss total: 21970.658203	 regular: 21970.658203	 recon: 21970.658203

Epoch 2
-------------------------------
loss: 32.451221  [    0/   69]
loss: 32.169357  [   17/   69]
loss: 32.462616  [   34/   69]
loss: 32.218540  [   51/   69]
loss: 32.291174  [   24/   69]
loss total: 93253.257812	 regular: 93253.257812	 recon: 93253.257812

Epoch 3
-------------------------------
loss: 32.062057  [    0/   69]
loss: 32.071823  [   17/   69]
loss: 32.062698  [   34/   69]
loss: 32.062908  [   51/   69]
loss: 32.200491  [   24/   69]
loss total: 64799.542969	 regular: 64799.542969	 recon: 64799.542969

Epoch 4
-------------------------------
loss: 32.062027  [    0/   69]
loss: 32.065777  [   17/   69]
loss: 32.064735  [   34/   69]
loss: 32.062687  [   51/   69]
loss: 32.197891  [   24/   69]
loss total