In [3]:
import sys,os
current_dir = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(current_dir, '..')))


import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import scanpy as sc
import scvelo as scv
from model import NETWORK  # Ensure that model.py is saved in the same directory
from dataloaders import * # Ensure that dataloaders.py is saved in the same directory
from utils import *
from sklearn.manifold import Isomap


# Setup configuration
latent_dim = 64  # Latent dimension size, can be adjusted
hidden_dim = 512  # Hidden dimension size for the encoder and decoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available

n_components = 100
n_knn_search = 10
dataset_name = "pancreas"
cell_type_key = "clusters"
model_name = "VeloFormer"

num_genes = 2000
nhead = 1 #original: 1
embedding_dim = 128*nhead# original: 128
num_encoder_layers = 1 #original: 1
num_bins = 50
batch_size = 64  # Batch size for training
epochs = 10 # Number of epochs for training
learning_rate = 1e-4  # Learning rate for the optimizer
lambda1 = 1 # Weight for heuristic loss
lambda2 = 0 # Weight for discrepancy loss
K = 11  # Number of neighbors for heuristic loss


knn_rep = "ve"
best_key = None
ve_layer = "None"

# Load data
adata = sc.read_h5ad("pancreas-gastr->pancr_transfer_isomap_latest.h5ad")
#adata.obsm["MuMs"] = np.concatenate([adata.layers["Ms"]], axis=1)
#manifold_and_neighbors(adata, n_components, n_knn_search, dataset_name, K, knn_rep, best_key, ve_layer)
adata = color_keys(adata, cell_type_key)

# Initialize model, optimizer, and loss function
model = NETWORK(input_dim=num_genes*2, latent_dim=latent_dim, 
                hidden_dim=hidden_dim, emb_dim = embedding_dim,
                nhead=nhead, num_encoder_layers=num_encoder_layers,
                num_genes=num_genes, num_bins=num_bins).to(device)
                
model.load_state_dict(torch.load('model_10epochs.pth'))

# Reinitialize weights of the derivative_decoder and probabilities_decoder
def reinitialize_weights(layer):
    if isinstance(layer, nn.Linear):
        nn.init.xavier_uniform_(layer.weight)  # Xavier initialization for weights
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)         # Initialize biases to zero

# Apply reinitialization to specific decoders
#model.derivative_decoder.apply(reinitialize_weights)
#model.probabilities_decoder.apply(reinitialize_weights)


# Freeze all layers except for the derivative and probabilities decoders
for name, param in model.named_parameters():
    if "derivative_decoder" not in name and "probabilities_decoder" not in name:
        param.requires_grad = False  # Freeze the parameters
    else:
        param.requires_grad = True   # Keep the decoders' parameters trainable


# Ensure optimizer only updates trainable parameters
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
#optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# Set up data loaders
train_loader, test_loader, full_data_loader = setup_dataloaders_binning_simpler(adata, 
                                                                       batch_size=batch_size, 
                                                                       num_genes=num_genes,
                                                                       num_bins=num_bins)

# Training loop
for epoch in range(epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0
    for batch_idx, (tokens, data, batch_indices) in enumerate(full_data_loader):
        tokens = tokens.to(device)
        data = data.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        """print(tokens.shape)
        print(data.shape)
        print(batch_indices.shape)"""
        
        # Forward pass
        out_dic = model(tokens, data)
        
        # Compute loss
        losses_dic = model.heuristic_loss(
            adata=adata, 
            x=data, 
            batch_indices=batch_indices,
            lambda1=lambda1, 
            lambda2=lambda2, 
            out_dic=out_dic, 
            device=device,
            K=K
        )
        
        # Backward pass and optimization
        loss = losses_dic["total_loss"]
        loss.backward()
        optimizer.step()
        
        # Accumulate loss for monitoring
        running_loss += loss.item()
        
        if batch_idx % 10 == 0:  # Print every 10 batches
            print(f'Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}], Loss: {loss.item()}')
    
    print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {running_loss / len(train_loader)}')

    # Save the model periodically
    """if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')"""

# After training, save final model
torch.save(model.state_dict(), 'linear_probed_model.pth')

computing isomap 1...
computing isomap 2..
ve shape: (3696, 64)
adata shape: (3696, 2000)
n_components: 100
n_neighbors: 10
knn rep used: ve
ve key used
Epoch [1/10], Batch [0], Loss: 0.9556393944869533
Epoch [1/10], Batch [10], Loss: 0.8863926671696456
Epoch [1/10], Batch [20], Loss: 0.8182053883242436
Epoch [1/10], Batch [30], Loss: 0.7568108230413835
Epoch [1/10], Batch [40], Loss: 0.6724526597871585
Epoch [1/10], Batch [50], Loss: 0.6068205909505628
Epoch [1/10], Average Loss: 0.9322851309612108
Epoch [2/10], Batch [0], Loss: 0.48643011245546186
Epoch [2/10], Batch [10], Loss: 0.4159967876730087
Epoch [2/10], Batch [20], Loss: 0.37033921290372457
Epoch [2/10], Batch [30], Loss: 0.34180613047067143
Epoch [2/10], Batch [40], Loss: 0.3192533184159715
Epoch [2/10], Batch [50], Loss: 0.3031333333884139
Epoch [2/10], Average Loss: 0.4512835342184697
Epoch [3/10], Batch [0], Loss: 0.1447963584555705
Epoch [3/10], Batch [10], Loss: 0.16582530783458577
Epoch [3/10], Batch [20], Loss: 0.1909