In [None]:
import os

import anndata as ad
import scanpy as sc
import torch
from torch import nn, optim
from tqdm import tqdm

import dataloader as dl
import model
from paste_ import get_pi
from utils.get_edge_index import get_edge_index

In [None]:
# Specify the ID of the slice
# Root directory of the Human DLPFC dataset. Set the DLPFC_DATA_PATH environment
# variable to point at a local copy instead of editing this machine-specific default.
# NOTE(review): the default keeps its trailing '/', in case `dl.load_data` joins the
# path by plain string concatenation — TODO confirm before relying on a value without it.
data_path = os.environ.get('DLPFC_DATA_PATH', '/home/sxa/Datasets/Human_DLPFC/')
target_slice_id = 151676    # slice whose expression is reconstructed/enhanced below
adjacent_slice_id = 151675  # neighbouring slice passed to the model as auxiliary input

In [None]:
# get data
# Load both slices with the same project loader, target first, then its neighbour.
target_slice, adjacent_slice = (
    dl.load_data(data_path, slice_id)
    for slice_id in (target_slice_id, adjacent_slice_id)
)

In [None]:
# get the numbers of spots and genes
# Rows of `.X` are counted as spots and columns as genes, for each slice.
# NOTE: `starget_slice_genes_num` has a stray leading "s", but the model-definition
# cell refers to it by this exact name, so the spelling is kept.
target_shape = target_slice.X.shape
adjacent_shape = adjacent_slice.X.shape
target_slice_spots_num = target_shape[0]
starget_slice_genes_num = target_shape[1]
adjacent_slice_spots_num = adjacent_shape[0]
adjacent_slice_genes_num = adjacent_shape[1]

In [None]:
# get bipartite graph
# `get_pi` compares the target slice against the adjacent one (argument order matters);
# its result is presumably a spot-to-spot alignment matrix — TODO confirm in paste_.
# `get_edge_index` turns that matrix into the edge index the model consumes during training.
slice1_slice2_pi = get_pi(
    target_slice,
    adjacent_slice,
)
slice1_slice2_pi_edge_index = get_edge_index(slice1_slice2_pi)

In [None]:
# hyperparameter setting
hidden_dim = 1000
epochs = 1000
mse_loss = nn.MSELoss()
device = torch.device('cuda:0')
ae_optim = optim.Adam(model_.parameters(), lr=0.0001)

In [None]:
# definition model
model_ = model.model(target_slice_spots_num, adjacent_slice_spots_num, starget_slice_genes_num, adjacent_slice_genes_num, hidden_dim).to(device)

In [None]:
# train
def _to_device_tensor(matrix):
    """Return `matrix` as a dense float32 tensor on `device`.

    `.X` may be a scipy sparse matrix depending on how `dl.load_data` builds the
    AnnData (torch cannot convert those directly), so it is densified first. The
    cast to float32 matches the default dtype of the model's parameters.
    """
    if hasattr(matrix, 'toarray'):  # scipy.sparse matrices/views expose .toarray()
        matrix = matrix.toarray()
    return torch.as_tensor(matrix, dtype=torch.float32).to(device)


target_slice_X = _to_device_tensor(target_slice.X)
adjacent_slice_X = _to_device_tensor(adjacent_slice.X)
# The graph does not change between epochs, so move it to the device only once.
edge_index = slice1_slice2_pi_edge_index.to(device)

# Re-running this cell after the evaluation cell must restore training mode.
model_.train()
loss_history = []
progress = tqdm(range(epochs), desc="Training", unit="epoch")
for epoch in progress:
    # Forward pass: only the second output (reconstruction of the target slice) feeds the loss.
    _, recreated_slice1_X = model_(edge_index, target_slice_X, adjacent_slice_X)

    # Compute the loss
    loss = mse_loss(recreated_slice1_X, target_slice_X)

    # Backward pass and optimization
    ae_optim.zero_grad()
    loss.backward()
    ae_optim.step()

    # Keep the full loss curve for inspection and show the latest value on the bar
    # instead of printing one line per epoch.
    loss_history.append(loss.item())
    progress.set_postfix(loss=f"{loss_history[-1]:.4f}")

In [None]:
# get enhanced data and write
# `model_.encodermodel_` does not match anything the model is shown to expose.
# Use the forward pass (same inputs as in training) and keep its first output,
# which training discards. NOTE(review): this assumes the first output is the
# enhanced representation with one row per target spot (required by `.obsm`)
# — TODO confirm against model.model.
model_.eval()
with torch.no_grad():  # inference only: no autograd graph, so no .detach() needed
    enhanced_data, _ = model_(slice1_slice2_pi_edge_index.to(device), target_slice_X, adjacent_slice_X)
target_slice.obsm['enhanced_data'] = enhanced_data.cpu().numpy()

output_dir = './tmp_data/Human_DLPFC_enhanced_data'
os.makedirs(output_dir, exist_ok=True)  # write_h5ad does not create missing directories
# Name the file after the configured slice instead of a hardcoded ID.
target_slice.write_h5ad(os.path.join(output_dir, f'{target_slice_id}.h5ad'))