In [1]:
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from torch.optim.lr_scheduler import StepLR

torch.__version__

'1.11.0+cu102'

In [2]:
class SingleCellDataset(Dataset):
    def __init__(self, file_name):
        self.adata = sc.read(file_name)
        self.x = self.adata.X
 
    def __len__(self):
        return len(self.x)
   
    def __getitem__(self,idx):
        return self.x[idx]


dataset = SingleCellDataset('./brain_normalized.h5ad')

In [3]:
batch_size = 64
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
len(dataloader)

51

In [4]:
example_x = next(iter(dataloader))
example_x.shape

torch.Size([64, 18585])

In [5]:
example_x

tensor([[-0.0809,  0.7513, -0.5763,  ..., -0.5655, -0.0438, -0.6812],
        [-0.0809, -0.0248, -0.5763,  ..., -0.0751, -0.0438, -0.6812],
        [-0.0809, -1.4846, -0.5763,  ...,  1.2355, -0.0438,  2.0596],
        ...,
        [-0.0809,  0.5327, -0.5763,  ..., -0.5655, -0.0438, -0.6812],
        [-0.0809, -1.4846, -0.5763,  ...,  2.6709, -0.0438, -0.6812],
        [-0.0809,  0.4098, -0.3287,  ...,  2.2217, -0.0438,  1.1114]])

In [6]:
dim = int(example_x.shape[-1])

to_split = []
while dim > 1024:
    to_split.append(1024)
    dim -= 1024
to_split.append(dim)
to_split

[1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 1024,
 153]

In [7]:
device = 'cuda'

In [8]:
class Encoder(nn.Module):
    def __init__(self, to_split, embedding_size=256, nhead=8, dropout=0.1, num_layers=6, latent_size=3):
        super().__init__()
        
        self.embedding_size = embedding_size
        self.nhead = nhead
        self.to_split = to_split
        
        for i, dim in enumerate(to_split):
            e = nn.Sequential(
                nn.Linear(dim, embedding_size),
                nn.ReLU(),
                nn.Linear(embedding_size, embedding_size),
            )
            setattr(self, f'embedding_{i}', e)
        
        d_model = embedding_size
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dropout=dropout)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ave = nn.AvgPool1d(embedding_size)
        self.latent = nn.Linear(len(self.to_split), latent_size)

    def forward(self, x):
        s = torch.split(x, self.to_split, dim=1)
        z = []
        for i in range(len(self.to_split)):
            e = getattr(self, f'embedding_{i}')
            z.append(e(s[i]))
        z = torch.stack(z, dim=1)
        encoded = self.encoder(z)
        z = self.ave(encoded)
        z = torch.reshape(z, [-1, len(self.to_split)])
        z = self.latent(z)
        return z, encoded


encoder = Encoder(to_split).to(device)
z, encoded = encoder(example_x.to(device))
print(z.shape)
print(encoded.shape)

torch.Size([64, 3])
torch.Size([64, 19, 256])


In [9]:
class Decoder(nn.Module):
    def __init__(self, to_split, embedding_size=256, nhead=8, dropout=0.1, num_layers=6, latent_size=3):
        super().__init__()
        
        self.to_split = to_split
        self.embedding_size = embedding_size
        
        self.latent = nn.Linear(latent_size, len(self.to_split) * embedding_size)
        decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_size, nhead=nhead, dropout=dropout)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        
        for i, dim in enumerate(to_split):
            o = nn.Sequential(
                nn.Linear(embedding_size, embedding_size),
                nn.ReLU(),
                nn.Linear(embedding_size, dim),
            )
            setattr(self, f'output_{i}', o)
        
    def forward(self, x, encoded):
        z = self.latent(x)
        z = torch.reshape(z, [-1, len(self.to_split), self.embedding_size])
        z = self.decoder(z, encoded)
        
        s = torch.split(z, 1, dim=1)
        output = []
        for i in range(len(self.to_split)):
            o = getattr(self, f'output_{i}')
            s2 = torch.reshape(s[i], [-1, self.embedding_size])
            output.append(o(s2))
        output = torch.concat(output, dim=1)
        return output


decoder = Decoder(to_split).to(device)
y = decoder(z, encoded)
print(y.shape)

torch.Size([64, 18585])


In [11]:
class AutoEncoder(nn.Module):
    def __init__(self, to_split, embedding_size=512, nhead=8, dropout=0.0, num_layers=8, latent_size=3):
        super().__init__()
        
        self.encoder = Encoder(to_split, embedding_size, nhead, dropout, num_layers, latent_size)
        self.decoder = Decoder(to_split, embedding_size, nhead, dropout, num_layers, latent_size)
        
    def forward(self, x):
        z, encoded = self.encoder(x)
        y = self.decoder(z, encoded)
        return y
        

model = AutoEncoder(to_split).to(device)
y = model(example_x.to(device))
y.shape

torch.Size([64, 18585])

In [12]:
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

In [13]:
def train(dataloader):
    size = len(dataloader.dataset)
    model.train()
    for batch, X in enumerate(dataloader):
        X = X.to(device)

        X_pred = model(X)
        loss = loss_fn(X_pred, X)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss


for epoch in range(1, 101):
    loss = train(dataloader)
    print(f'loss: {loss:>7f}')

loss: 1.063663
loss: 1.194027
loss: 1.116905
loss: 0.776303
loss: 0.811574
loss: 0.982197
loss: 1.116492
loss: 1.082717
loss: 0.884367
loss: 1.373050


KeyboardInterrupt: 