In [1]:
# Imports. The parent of the notebook's working directory is appended to
# sys.path FIRST so the project-local modules below (model, dataloaders, utils)
# can be found there; this must stay ahead of those imports.
import sys,os
current_dir = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(current_dir, '..')))


import torch
import torch.optim as optim
import numpy as np
import scanpy as sc
import scvelo as scv  # not used in the visible cells; presumably used further down
from model import NETWORK  # project-local model.py; must be importable via sys.path (e.g. the parent dir appended above)
from dataloaders import * # project-local dataloaders.py; presumably provides setup_dataloaders_binning
from utils import *  # NOTE(review): star import hides which names come from here; prefer explicit imports
from sklearn.manifold import Isomap  # not used in the visible cells


# Setup configuration
# All tunable values for this run are collected in this cell.
latent_dim = 64  # Latent dimension size, can be adjusted
hidden_dim = 512  # Hidden dimension size for the encoder and decoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available

# Not used in the visible cells; presumably consumed by later analysis cells.
n_components = 100
n_knn_search = 10
dataset_name = "gastrulation"  # dataset identifier
cell_type_key = "celltype"  # adata.obs column holding the cell-type labels
model_name = "VeloFormer"  # label for this model

num_genes = 2000  # gene count handed to the data loaders
nhead = 1  # passed to NETWORK as `nhead` (presumably attention heads); original value: 1
embedding_dim = 128 * nhead  # scales with nhead (128 per head if the model splits emb_dim across heads); original: 128
num_encoder_layers = 1  # passed to NETWORK (presumably transformer encoder depth); original value: 1
num_bins = 50  # passed to both NETWORK and the loaders (presumably expression bins for tokenization)
batch_size = 24  # Batch size for training
epochs = 1 # Number of epochs for training
learning_rate = 1e-4  # Learning rate for the optimizer
lambda1 = 1e-1  # Weight for heuristic loss
lambda2 = 1 # Weight for discrepancy loss
K = 11  # Number of neighbors for heuristic loss

# Seed the RNGs so model initialization (and any shuffling the loaders do) is
# reproducible under Restart & Run All. torch.manual_seed seeds every device.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load data
# The file name is derived from `dataset_name` so the config cell stays the
# single place to switch datasets ("gastrulation" -> "gastrulation_processed.h5ad").
adata = sc.read_h5ad(f"{dataset_name}_processed.h5ad")

# Initialize model and optimizer. The loss is not a separate object: it is
# computed by model.heuristic_loss inside the training loop.
# NOTE(review): input_dim is twice the gene count, presumably spliced and
# unspliced features concatenated; confirm against model.py.
# NOTE(review): the model is sized from adata.shape[1], while the loaders get
# `num_genes` from the config; confirm these agree for the processed file.
model = NETWORK(
    input_dim=adata.shape[1] * 2,
    latent_dim=latent_dim,
    hidden_dim=hidden_dim,
    emb_dim=embedding_dim,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_genes=adata.shape[1],
    num_bins=num_bins,
).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Set up data loaders (train / test / full splits).
train_loader, test_loader, full_data_loader = setup_dataloaders_binning(
    adata,
    batch_size=batch_size,
    num_genes=num_genes,
    num_bins=num_bins,
)

# Training loop
# NOTE(review): batches come from full_data_loader, not train_loader, so the
# cells in test_loader are presumably trained on too; confirm this is intended.
checkpoint_every = 0  # save a checkpoint every N epochs; 0 disables it (the original behavior)

for epoch in range(epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0
    num_batches = 0
    for batch_idx, (tokens, data, batch_indices) in enumerate(full_data_loader):
        tokens = tokens.to(device)
        data = data.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        out_dic = model(tokens, data)

        # Compute loss. batch_indices is passed as-is (not moved to `device`)
        # alongside adata, presumably to look up these cells' neighbors in
        # adata; confirm in model.py.
        losses_dic = model.heuristic_loss(
            adata=adata,
            x=data,
            batch_indices=batch_indices,
            lambda1=lambda1,
            lambda2=lambda2,
            out_dic=out_dic,
            device=device,
            K=K,
        )

        # Backward pass and optimization
        loss = losses_dic["total_loss"]
        loss.backward()
        optimizer.step()

        # Read the scalar once (.item() forces a device sync) and reuse it.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        if batch_idx % 10 == 0:  # Print every 10 batches
            print(f'Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}], Loss: {batch_loss}')

    # Average over the batches actually iterated above. The previous
    # len(train_loader) counted a different loader and gave a wrong average.
    # max(..., 1) guards against an empty loader.
    avg_loss = running_loss / max(num_batches, 1)
    print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss}')

    # Optional periodic checkpoint (disabled unless checkpoint_every > 0).
    if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
        torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')

# After training, save final model
torch.save(model.state_dict(), 'model.pth')

Epoch [1/1], Batch [0], Loss: 0.11655070632696152
Epoch [1/1], Batch [10], Loss: 0.08587443828582764
Epoch [1/1], Batch [20], Loss: 0.06725024431943893
Epoch [1/1], Batch [30], Loss: 0.04546913504600525
Epoch [1/1], Batch [40], Loss: 0.07347115129232407
Epoch [1/1], Batch [50], Loss: 0.04695751890540123
Epoch [1/1], Batch [60], Loss: 0.044837117195129395
Epoch [1/1], Batch [70], Loss: 0.06530892848968506
Epoch [1/1], Batch [80], Loss: 0.05635891854763031
Epoch [1/1], Batch [90], Loss: 0.05474396422505379
Epoch [1/1], Batch [100], Loss: 0.06033838540315628
Epoch [1/1], Batch [110], Loss: 0.043794821947813034
Epoch [1/1], Batch [120], Loss: 0.05035460740327835
Epoch [1/1], Batch [130], Loss: 0.045936789363622665
Epoch [1/1], Batch [140], Loss: 0.047553952783346176
Epoch [1/1], Batch [150], Loss: 0.04587051272392273
Epoch [1/1], Batch [160], Loss: 0.03747443109750748
Epoch [1/1], Batch [170], Loss: 0.032691504806280136
Epoch [1/1], Batch [180], Loss: 0.022634292021393776
Epoch [1/1], Batc

In [7]:
# Distinct cell-type labels, read through the configured obs key rather than a
# hard-coded column name.
list(adata.obs[cell_type_key].unique())

['Epiblast',
 'Primitive Streak',
 'Visceral endoderm',
 'Nascent mesoderm',
 'Rostral neurectoderm',
 'Blood progenitors 2',
 'Mixed mesoderm',
 'ExE mesoderm',
 'Intermediate mesoderm',
 'Pharyngeal mesoderm',
 'Caudal epiblast',
 'PGC',
 'Mesenchyme',
 'Haematoendothelial progenitors',
 'Blood progenitors 1',
 'Surface ectoderm',
 'Gut',
 'Paraxial mesoderm',
 'Caudal neurectoderm',
 'Notochord',
 'Somitic mesoderm',
 'Caudal Mesoderm',
 'Erythroid1',
 'Def. endoderm',
 'Allantois',
 'Anterior Primitive Streak',
 'Endothelium',
 'Forebrain/Midbrain/Hindbrain',
 'Spinal cord',
 'Cardiomyocytes',
 'Erythroid2',
 'NMP',
 'Erythroid3',
 'Neural crest']

In [3]:
# Per-cell developmental stage label (categorical; the output below lists nine
# stages from E6.5 to E8.5).
adata.obs.loc[:, "stage"]

index
cell_1         E6.5
cell_2         E6.5
cell_6         E6.5
cell_8         E6.5
cell_9         E6.5
               ... 
cell_139326    E8.5
cell_139327    E8.5
cell_139329    E8.5
cell_139330    E8.5
cell_139331    E8.5
Name: stage, Length: 89267, dtype: category
Categories (9, object): ['E6.5', 'E6.75', 'E7.0', 'E7.25', ..., 'E7.75', 'E8.0', 'E8.25', 'E8.5']