In [1]:
import sys,os
current_dir = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(current_dir, '..')))

import torch
import torch.optim as optim
import numpy as np
import scanpy as sc
import scvelo as scv
from model import NETWORK  # Ensure that model.py is saved in the same directory
from dataloaders import * # Ensure that dataloaders.py is saved in the same directory
from utils import *
from sklearn.manifold import Isomap


# Setup configuration
latent_dim = 64  # Latent dimension size, can be adjusted
hidden_dim = 512  # Hidden dimension size for the encoder and decoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available

n_components = 100
n_knn_search = 10
dataset_name = "gastrulation_erythroid"
cell_type_key = "clusters"
model_name = "imVelo"

num_genes = 2000
nhead = 1 #original: 1
embedding_dim = 128*nhead# original: 128
num_encoder_layers = 1 #original: 1
num_bins = 50
batch_size = 64  # Batch size for training
epochs = 20  # Number of epochs for training
learning_rate = 1e-4  # Learning rate for the optimizer
lambda1 = 1e-1  # Weight for heuristic loss
lambda2 = 1 # Weight for discrepancy loss
K = 11  # Number of neighbors for heuristic loss

# Load data
adata = sc.read_h5ad("pancreas_minmax.h5ad")
# Initialize model, optimizer, and loss function
model = NETWORK(input_dim=adata.shape[1]*2, latent_dim=latent_dim, 
                hidden_dim=hidden_dim, emb_dim = embedding_dim,
                nhead=nhead, num_encoder_layers=num_encoder_layers,
                num_genes=adata.shape[1], num_bins=num_bins).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Set up data loaders
train_loader, test_loader, full_data_loader = setup_dataloaders_binning(adata, 
                                                                       batch_size=batch_size, 
                                                                       num_genes=num_genes,
                                                                       num_bins=num_bins)

# Training loop
for epoch in range(epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0
    for batch_idx, (tokens, data, batch_indices) in enumerate(full_data_loader):
        tokens = tokens.to(device)
        data = data.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        """print(tokens.shape)
        print(data.shape)
        print(batch_indices.shape)"""
        
        # Forward pass
        out_dic = model(tokens, data)
        
        # Compute loss
        losses_dic = model.heuristic_loss(
            adata=adata, 
            x=data, 
            batch_indices=batch_indices,
            lambda1=lambda1, 
            lambda2=lambda2, 
            out_dic=out_dic, 
            device=device,
            K=K
        )
        
        # Backward pass and optimization
        loss = losses_dic["total_loss"]
        loss.backward()
        optimizer.step()
        
        # Accumulate loss for monitoring
        running_loss += loss.item()
        
        if batch_idx % 10 == 0:  # Print every 10 batches
            print(f'Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}], Loss: {loss.item()}')
    
    print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {running_loss / len(train_loader)}')

    # Save the model periodically
    """if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')"""

# After training, save final model
torch.save(model.state_dict(), 'final_model.pth')

Epoch [1/100], Batch [0], Loss: 0.10275401920080185
Epoch [1/100], Batch [10], Loss: 0.07818262279033661
Epoch [1/100], Batch [20], Loss: 0.07436554878950119
Epoch [1/100], Batch [30], Loss: 0.06524176150560379
Epoch [1/100], Batch [40], Loss: 0.06573805958032608
Epoch [1/100], Batch [50], Loss: 0.05343299359083176
Epoch [1/100], Average Loss: 0.08445390060226968
Epoch [2/100], Batch [0], Loss: 0.05369982868432999
Epoch [2/100], Batch [10], Loss: 0.04870948567986488
Epoch [2/100], Batch [20], Loss: 0.04629693552851677
Epoch [2/100], Batch [30], Loss: 0.03722342476248741
Epoch [2/100], Batch [40], Loss: 0.04258904978632927
Epoch [2/100], Batch [50], Loss: 0.03694283589720726
Epoch [2/100], Average Loss: 0.0508017436145468
Epoch [3/100], Batch [0], Loss: 0.03649960085749626
Epoch [3/100], Batch [10], Loss: 0.035724736750125885
Epoch [3/100], Batch [20], Loss: 0.035620056092739105
Epoch [3/100], Batch [30], Loss: 0.030085865408182144
Epoch [3/100], Batch [40], Loss: 0.036147989332675934
E

KeyboardInterrupt: 