In [1]:
import torch
import numpy as np
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import umap.umap_ as umap
import random
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Location of the input .h5ad file that is read with sc.read(); empty here, so set it before running.
path_to_h5ad = ""

In [None]:
# Load the dataset and keep only the cells of the two stages of interest.
# `.copy()` turns the AnnData view produced by the boolean mask into a real
# object, so the in-place scanpy calls below do not have to modify a view.
adata = sc.read(path_to_h5ad)
adata = adata[adata.obs["stage"].isin(["E12","CS7"])].copy()
#adata = adata[random.sample(adata.obs_names.to_list(), 6000)] # if exceeds 6,000 cells

# Cell/gene filtering, per-cell normalisation, log1p and selection of the
# highly variable genes.
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=30)
sc.pp.normalize_per_cell(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=8000)
adata = adata[:, adata.var['highly_variable']].copy()

# Dense float64 expression matrix. `adata.X` can be a sparse matrix or already
# a dense array, so it is only densified when it offers `toarray()`.
expr = adata.X
if hasattr(expr, "toarray"):
    expr = expr.toarray()
data = np.asarray(expr, dtype=np.float64)

In [None]:
# The first 1100 rows of `data` form the training split, the rest the test split.
n_train = 1100
train_data = torch.Tensor(data[:n_train])
test_data = torch.Tensor(data[n_train:])

# Wrap both splits in shuffled mini-batch loaders of 200 rows each.
loader_kwargs = dict(batch_size=200, shuffle=True)
train_loader = DataLoader(TensorDataset(train_data), **loader_kwargs)
test_loader = DataLoader(TensorDataset(test_data), **loader_kwargs)




In [None]:
laten_size=30   # Size of the latent layer
layer1=100   # Size of the first layer of the encoder and decoder
input_size=8000   # Width of the VAE input and reconstruction layers

# Define the variational autoencoder model
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps the input to a hidden code, two linear heads turn that
    code into the mean and log-variance of the latent distribution, and the
    decoder reconstructs the input from a sampled latent vector.

    Args:
        in_features: Width of the input/reconstruction. Defaults to the
            module-level ``input_size``.
        hidden_size: Width of the hidden layer of encoder and decoder.
            Defaults to the module-level ``layer1``.
        latent_dim: Width of the latent layer. Defaults to the module-level
            ``laten_size``.
    """

    def __init__(self, in_features=None, hidden_size=None, latent_dim=None):
        super(VAE, self).__init__()
        # Fall back to the notebook-level constants only when no explicit
        # size is given, so `VAE()` keeps building the same network as before.
        if in_features is None:
            in_features = input_size
        if hidden_size is None:
            hidden_size = layer1
        if latent_dim is None:
            latent_dim = laten_size
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_size),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(hidden_size, in_features),
        )
        self.mu = nn.Linear(latent_dim, latent_dim)
        self.log_var = nn.Linear(latent_dim, latent_dim)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps`` standard normal and ``std = exp(0.5 * log_var)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for the batch ``x``."""
        x = self.encoder(x)
        mu = self.mu(x)
        log_var = self.log_var(x)
        z = self.reparameterize(mu, log_var)
        x = self.decoder(z)
        return x, mu, log_var

    def sample(self, num_samples):
        """Decode ``num_samples`` latent vectors drawn from a standard normal."""
        # Draw on the same device as the model weights so this also works
        # after the model has been moved off the CPU.
        z = torch.randn(num_samples, self.latent_dim, device=self.mu.weight.device)
        return self.decoder(z)

# Build the VAE and an Adam optimiser over all of its parameters
model = VAE()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Loss function. Binary cross-entropy (BCE) was also tried for the VAE
# reconstruction term, but MSE worked better (lower error).
def vae_loss(x, x_recon, mu, log_var):
    """Return the summed squared reconstruction error plus the KL divergence term."""
    reconstruction = nn.functional.mse_loss(x_recon, x, reduction='sum')
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_divergence = -0.5 * kl_terms.sum()
    return reconstruction + kl_divergence

# Train the model
# Per-epoch average losses, appended once per epoch for the plot below.
train_losses = []
test_losses = []
num_epochs = 150
for epoch in range(num_epochs):
    # --- optimisation pass over the training split ---
    model.train()
    train_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        # The loaders wrap a single-tensor TensorDataset, so the data is batch[0].
        x = batch[0]
        x_recon, mu, log_var = model(x)
        loss = vae_loss(x, x_recon, mu, log_var)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # The accumulated batch losses are divided by the number of training rows.
    train_losses.append(train_loss / len(train_data))

    # --- evaluation pass over the test split (no gradients, eval mode) ---
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch in test_loader:
            x = batch[0]
            x_recon, mu, log_var = model(x)
            loss = vae_loss(x, x_recon, mu, log_var)
            test_loss += loss.item()
        test_losses.append(test_loss / len(test_data))

    print(f"Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}")

# Plot the training and validation loss
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.legend()
plt.show()

# Convert the data to a PyTorch tensor
data_tensor = torch.tensor(data).float()

# Compute the latent layer for the data
# Only `model.encoder` is applied here; the `mu` / `log_var` heads are not used,
# so `latent_layer` holds the encoder output for every row of `data`.
with torch.no_grad():
    model.eval()
    latent_layer = model.encoder(data_tensor)

In [8]:
# Standardise the latent codes with one global mean and one global standard deviation
latent_layer = latent_layer.sub(latent_layer.mean()).div(latent_layer.std())

In [None]:
# Integer-encode the stage of every cell so it can colour the scatter plot.
# NOTE(review): `numeric_labels` was not defined anywhere in the notebook, so the
# cell failed on a fresh kernel; it is assumed to be the encoded `stage` column —
# confirm that this is the labelling that was intended.
numeric_labels = LabelEncoder().fit_transform(adata.obs["stage"])

# Visualize the latent layer using UMAP
umap_embedding = umap.UMAP(n_neighbors=100, min_dist=0.3, random_state=42).fit_transform(latent_layer)
plt.scatter(umap_embedding[:, 0], umap_embedding[:, 1], c=numeric_labels, s=5, cmap='viridis')
plt.show()
# Float tensor of the latent codes, kept under the name `dataset`
dataset = torch.Tensor(latent_layer).float()

In [None]:
def make_beta_schedule(schedule='linear', num_steps=1000, start=1e-5, end=1e-2):
    """Build a 1-D tensor of ``num_steps`` beta values running from ``start`` to ``end``.

    Args:
        schedule: ``'linear'`` (evenly spaced), ``'quad'`` (evenly spaced in
            square-root space, then squared) or ``'sigmoid'`` (a sigmoid over
            ``[-6, 6]`` rescaled to ``[start, end]``).
        num_steps: Number of values to generate.
        start: Lower end of the beta range.
        end: Upper end of the beta range.

    Returns:
        A tensor of shape ``(num_steps,)``.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """
    if schedule == 'linear':
        return torch.linspace(start, end, num_steps)
    if schedule == "quad":
        return torch.linspace(start ** 0.5, end ** 0.5, num_steps) ** 2
    if schedule == "sigmoid":
        betas = torch.linspace(-6, 6, num_steps)
        return torch.sigmoid(betas) * (end - start) + start
    # Previously an unknown name fell through to an UnboundLocalError on `betas`.
    raise ValueError(
        "Unknown beta schedule {!r}; expected 'linear', 'quad' or 'sigmoid'".format(schedule)
    )

def extract(input, t, x):
    """Gather ``input`` at the indices ``t`` and reshape the result to ``(len(t), 1, ..., 1)``.

    The number of trailing singleton axes is one less than the rank of ``x``.
    """
    gathered = input.gather(0, t.to(input.device))
    trailing_ones = (1,) * (len(x.shape) - 1)
    return gathered.reshape(t.shape[0], *trailing_ones)

def plot_schedule(num_steps,schedule):
    """Plot the module-level ``betas`` and the two square-root curves derived from ``alphas_prod``.

    ``schedule`` is only used for the figure title.
    """
    steps = list(range(num_steps))
    labels = ['betas','sqrt_alphas_prod','sqrt_one_minus_alphas_prod']
    plt.plot(steps, betas.numpy(), label=labels[0])
    plt.plot(steps, alphas_prod.sqrt().numpy(), label=labels[1])
    plt.plot(steps, (1-alphas_prod).sqrt().numpy(), label=labels[2])
    plt.legend(labels, loc='upper left')
    plt.xlabel('steps')
    plt.ylabel('value')
    plt.title(f'{schedule} schedule')
    plt.show()

# Number of diffusion steps and the beta schedule in use
num_steps=1000
schedule='sigmoid'

betas = make_beta_schedule(schedule=schedule, num_steps=num_steps, start=1e-5, end=1e-2)
alphas = 1-betas
# Running product of the alphas, and the same sequence shifted right by one with a leading 1
alphas_prod = alphas.cumprod(dim=0)
alphas_prod_p = torch.cat((torch.ones(1), alphas_prod[:-1]), dim=0)
# Square roots and log derived from the running product
alphas_bar_sqrt = alphas_prod.sqrt()
one_minus_alphas_bar_log = (1 - alphas_prod).log()
one_minus_alphas_bar_sqrt = (1 - alphas_prod).sqrt()
plot_schedule(num_steps,schedule)



In [11]:
def q_x(x_0,t):
    """Return a noised copy of ``x_0`` for step ``t``.

    Uses the module-level ``alphas_bar_sqrt`` / ``one_minus_alphas_bar_sqrt``
    tensors and the module-level ``device``.
    """
    eps = torch.randn_like(x_0).to(device)
    signal_scale = alphas_bar_sqrt[t].to(device)
    noise_scale = one_minus_alphas_bar_sqrt[t].to(device)
    return signal_scale * x_0 + noise_scale * eps

In [13]:
class MLPDiffusion(nn.Module):
    """MLP that maps a batch of vectors and their step indices to a same-width output.

    Three Linear+ReLU stages each add a learned per-step embedding to the
    Linear output before the ReLU; a final Linear projects back to the
    input width.

    Args:
        n_steps: Number of distinct step indices (size of each embedding table).
        num_units: Hidden width of the MLP and of the step embeddings.
        latent_dim: Width of the input and output vectors. Defaults to the
            module-level ``laten_size``.
    """

    def __init__(self,n_steps, num_units=512, latent_dim=None):
        super(MLPDiffusion,self).__init__()
        # Only fall back to the notebook-level constant when no width is given,
        # so `MLPDiffusion(n_steps)` keeps building the same network as before.
        if latent_dim is None:
            latent_dim = laten_size

        # The layout (Linear at even indices, ReLU at odd ones) is kept as is so
        # that the state-dict keys stay `linears.<i>.*` / `step_embeddings.<i>.*`.
        self.linears = nn.ModuleList(
            [
                nn.Linear(latent_dim,num_units),
                nn.ReLU(),
                nn.Linear(num_units,num_units),
                nn.ReLU(),
                nn.Linear(num_units,num_units),
                nn.ReLU(),
                nn.Linear(num_units,latent_dim),
            ]
        )
        self.step_embeddings = nn.ModuleList(
            [
                nn.Embedding(n_steps,num_units),
                nn.Embedding(n_steps,num_units),
                nn.Embedding(n_steps,num_units),
            ]
        )

    def forward(self,x,t):
        """Return the network output for inputs ``x`` and integer step indices ``t``."""
        for idx,embedding_layer in enumerate(self.step_embeddings):
            x = self.linears[2*idx](x)
            # Inject the step information before the non-linearity.
            x = x + embedding_layer(t)
            x = self.linears[2*idx+1](x)

        return self.linears[-1](x)


In [None]:
def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps):
    """Sampling and calculating loss at any time t

    Draws one step index per row of ``x_0``, noises the rows with the given
    schedule tensors and returns the mean squared error between the drawn
    noise and ``model``'s output.

    Args:
        model: Callable taking ``(x, t)`` with ``t`` a 1-D tensor of step indices.
        x_0: Batch of clean inputs; its first axis is the batch axis.
        alphas_bar_sqrt: 1-D tensor indexed by step, scaling ``x_0``.
        one_minus_alphas_bar_sqrt: 1-D tensor indexed by step, scaling the noise.
        n_steps: Number of steps; indices are drawn from ``[0, n_steps)``.

    Returns:
        Scalar tensor with the mean squared error.
    """
    batch_size = x_0.shape[0]
    # Everything is kept on the device of `x_0` rather than on a notebook-level
    # `device` global, so the function does not depend on hidden state.
    target_device = x_0.device

    # Draw indices for half of the batch (rounded up) and mirror them as
    # `n_steps - 1 - t` for the other half; trim the surplus for odd batches.
    half_size = (batch_size + 1) // 2
    t = torch.randint(0, n_steps, size=(half_size,), device=target_device)
    t = torch.cat([t, n_steps-1-t], dim=0)
    t = t[:batch_size].unsqueeze(-1)

    # Per-row coefficients with a trailing singleton axis for broadcasting.
    a = alphas_bar_sqrt.to(target_device)[t]
    aml = one_minus_alphas_bar_sqrt.to(target_device)[t]

    e = torch.randn_like(x_0)

    x = x_0 * a + e * aml

    output = model(x, t.squeeze(-1))

    return (e - output).square().mean()

In [15]:
def p_sample_loop(model, shape, n_steps, betas, one_minus_alphas_bar_sqrt):
    """Restore x[T-1], x[T-2]|...x[0] from x[T]

    Starts from standard-normal noise of the given ``shape`` and returns the
    list of all intermediate samples, the starting noise first.
    """
    cur_x = torch.randn(shape).to(device)
    x_seq = [cur_x]
    for step in range(n_steps - 1, -1, -1):
        next_x = p_sample(model, cur_x, step, betas, one_minus_alphas_bar_sqrt)
        cur_x = next_x.to(device)
        x_seq.append(cur_x)
    return x_seq


def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt):
    """Sampling the reconstructed value at time t from x[T]

    Args:
        model: Callable taking ``(x, t)`` that returns the predicted noise.
        x: Current sample.
        t: Integer step index.
        betas: 1-D tensor of betas indexed by step.
        one_minus_alphas_bar_sqrt: 1-D tensor indexed by step.

    Returns:
        The sample for the previous step, on the same device as ``x``.
    """
    # Work on the device of `x` instead of a hard-coded "cpu", so the same code
    # serves CPU and GPU sampling.
    device = x.device
    t = torch.tensor([t], device=device)
    betas = betas.to(device)
    one_minus_alphas_bar_sqrt = one_minus_alphas_bar_sqrt.to(device)
    coeff = betas[t] / one_minus_alphas_bar_sqrt[t]

    # Sampling needs no gradients; without this every step would extend an
    # autograd graph through the model.
    with torch.no_grad():
        eps_theta = model(x, t)

    mean = (1 / (1 - betas[t]).sqrt()) * (x - (coeff * eps_theta))

    # NOTE(review): noise is also added when t == 0; confirm this is intended.
    z = torch.randn_like(x)
    sigma_t = betas[t].sqrt()

    sample = mean + sigma_t * z

    return sample

In [None]:
class EMA():
    """Keeps an exponential moving average per name in the ``shadow`` dict."""

    def __init__(self,mu=0.001):
        self.shadow = {}
        self.mu = mu

    def register(self,name,val):
        """Store a clone of ``val`` as the starting average for ``name``."""
        self.shadow[name] = val.clone()

    def __call__(self,name,x):
        """Blend ``x`` into the average stored for ``name`` and return the new average."""
        assert name in self.shadow
        previous = self.shadow[name]
        blended = self.mu * x + (1.0-self.mu)*previous
        self.shadow.update({name: blended.clone()})
        return blended

# Train on the GPU when one is available instead of assuming CUDA is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
total_loss=[]
print('Training model...')
batch_size = 1200
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
num_epoch = 1500
plt.rc('text',color='blue')
model2 = MLPDiffusion(num_steps)
# The network is wrapped in DataParallel, so its state-dict keys carry a
# `module.` prefix.
model2 = torch.nn.DataParallel(model2).to(device)
alphas_bar_sqrt = alphas_bar_sqrt.to(device)
one_minus_alphas_bar_sqrt = one_minus_alphas_bar_sqrt.to(device)
optimizer = torch.optim.Adam(model2.parameters(),lr=1e-3)

for t in range(num_epoch):
    epoch_loss = 0.0
    for idx,batch_x in enumerate(dataloader):
        batch_x = batch_x.to(device)
        loss = diffusion_loss_fn(model2, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps)
        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1 before the update.
        torch.nn.utils.clip_grad_norm_(model2.parameters(),1.)
        optimizer.step()
        epoch_loss += loss.item()
    # Record the mean batch loss of the epoch; `total_loss` was never filled before.
    total_loss.append(epoch_loss / len(dataloader))
    # Report after the epoch: printing before the inner loop read a `loss` that
    # this cell had not assigned yet.
    print("The values are: {} and {}".format(loss, t))

In [None]:
from collections import OrderedDict

# Specify a path to save to
PATH = "model_interpolate.pth" # Choose whatever you like
# Save
torch.save(model2.state_dict(), PATH)
# Load
device = torch.device('cpu')
model4 = MLPDiffusion(num_steps)

# Read the checkpoint back from the same PATH that was just written, mapping
# the tensors onto `device` so it also loads on a machine without a GPU.
state_dict = torch.load(PATH, map_location=device)
# create new OrderedDict that does not contain `module.`
# The prefix is only stripped from keys that really start with it, so a
# checkpoint saved without DataParallel is not corrupted.
prefix = "module."
new_state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in state_dict.items()
)
# load params
model4.load_state_dict(new_state_dict)

In [None]:
def p_sample_loop_optimized(model, shape, n_steps, betas, one_minus_alphas_bar_sqrt):
    """Run ``p_sample`` from step ``n_steps - 1`` down to 0 and return only the last sample.

    Starts from standard-normal noise of the given ``shape`` and prints each
    step index after it has been processed.
    """
    sample = torch.randn(shape)

    for step in range(n_steps - 1, -1, -1):
        sample = p_sample(model, sample, step, betas, one_minus_alphas_bar_sqrt)
        print(step)
    return sample

# Generate 1000 latent vectors; the width follows `laten_size` instead of a hard-coded 30.
x_final = p_sample_loop_optimized(model4, torch.Size([1000, laten_size]), num_steps, betas, one_minus_alphas_bar_sqrt)

In [None]:
# Project the generated latent vectors to 2-D with UMAP and scatter-plot them
cur_x = x_final.detach().cpu().numpy()
reducer = umap.UMAP(n_neighbors=100, min_dist=0.2)
umap_emb = reducer.fit_transform(cur_x)
xs, ys = umap_emb[:, 0], umap_emb[:, 1]
plt.scatter(xs, ys, s=10);

In [None]:
# `cur_x` is a NumPy array at this point, but the decoder's Linear layers need a
# tensor: convert it back, switch off dropout and decode without gradients.
# NOTE(review): the generated vectors live in the space of the standardised
# encoder outputs, while the decoder was trained on reparameterised samples —
# confirm whether they need to be rescaled before decoding.
model.eval()
with torch.no_grad():
    generated_expr = model.decoder(torch.as_tensor(cur_x, dtype=torch.float32))