In [1]:
import numpy as np
import pickle
import anndata as ad
from sklearn.model_selection import train_test_split
import warnings
import copy

from data.data_process import data_process
from model.deconv_model_with_stage_2 import MBdeconv
from model.utils import *
from model.stage2 import *

seed = 2021
torch.manual_seed(seed)
np.random.seed(seed)

# 在使用GPU时，还可以设置以下代码来确保结果的一致性
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
warnings.filterwarnings("ignore")

# data 

In [2]:
# Define the cell types of interest and read the corresponding single-cell matrix data.
type_list = ['Luminal_Macrophages', 'Type 2 alveolar', 'Fibroblasts', 'Dendritic cells']
noise = ['Neutrophils']
train_data_file = 'data/lung_rna/296C_train.h5ad'
test_data_file = 'data/lung_rna/302C_test.h5ad'
train_data = ad.read_h5ad(train_data_file)
test_data = ad.read_h5ad(test_data_file)

In [3]:
# Select the corresponding cells based on the cell types of interest.
if noise:
    data_h5ad_noise = test_data[test_data.obs['CellType'].isin(noise)]
    data_h5ad_noise.obs.reset_index(drop=True, inplace=True)
# extract selected cells 
train_data = train_data[train_data.obs['CellType'].isin(type_list)]
train_data.obs.reset_index(drop=True, inplace=True)
test_data = test_data[test_data.obs['CellType'].isin(type_list)]
test_data.obs.reset_index(drop=True, inplace=True)
print('selected cells:', train_data)
print('noise cells:', data_h5ad_noise)

selected cells: View of AnnData object with n_obs × n_vars = 3601 × 3346
    obs: 'Sample', 'Donor', 'Source', 'Location', 'CellType', 'BroadCellType'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'leiden', 'neighbors_hm', 'pca'
    obsm: 'X_umap_hm'
    varm: 'PCs'
noise cells: View of AnnData object with n_obs × n_vars = 293 × 3346
    obs: 'Sample', 'Donor', 'Source', 'Location', 'CellType', 'BroadCellType'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'leiden', 'neighbors_hm', 'pca'
    obsm: 'X_umap_hm'
    varm: 'PCs'


In [4]:
# Define the key parameters in the simulated experiment, 
# including the number of training and testing data entries and 
# the capacity of pseudo-organized cells. The number of artificial noise cells 
# used in stage three of the mixing phase is typically set to be the same as that of the pseudotissue cells.

dp = data_process(type_list, train_sample_num=6000, tissue_name='lung_rna', 
                  test_sample_num=1000, sample_size=30, num_artificial_cells=30)

In [5]:
# data_h5ad_noise is a dataset used to add unknown cell types to the test dataset
dp.fit(train_data, test_data, data_h5ad_noise)

Generating artificial cells...
Generating train pseudo_bulk samples...


train Samples: 100%|██████████| 6000/6000 [01:38<00:00, 61.18it/s]


Generating test pseudo_bulk samples...


test Samples: 100%|██████████| 1000/1000 [00:13<00:00, 75.05it/s]


The data processing is complete


In [6]:
# Read the dataset, where train is used for training, test is a mixed test set from different donors,
# and test_with_noise contains unseen cells from train mixed in different proportions, 
# with the same labels as the test set

with open(f'data/lung_rna/lung_rna{len(type_list)}cell.pkl', 'rb') as f:
    train = pickle.load(f)
    test = pickle.load(f)
    test_with_noise = pickle.load(f)

In [7]:
train_x_sim, train_with_noise_1, train_with_noise_2, train_y = train
test_x_sim, test_y = test

# Partition a portion of the test dataset for evaluating performance to serve the early stopping mechanism.
valid_size = 1000  

# 切片操作  
valid_x_sim = train_x_sim[:valid_size]  
valid_with_noise_1 = train_with_noise_1[:valid_size]  
valid_with_noise_2 = train_with_noise_2[:valid_size]  
valid_y = train_y[:valid_size]  

train_x_sim = train_x_sim[valid_size:]  
train_with_noise_1 = train_with_noise_1[valid_size:]  
train_with_noise_2 = train_with_noise_2[valid_size:]  
train_y = train_y[valid_size:]  

test_dataset = TestCustomDataset(test_x_sim, test_y)
valid_dataset = TestCustomDataset(valid_x_sim, valid_y)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

train_dataset = TrainCustomDataset(train_x_sim, train_with_noise_1, train_with_noise_2, train_y)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)


source_data = data2h5ad(train_x_sim, train_y, type_list)
target_data = data2h5ad(test_x_sim, test_y, type_list)
valid_data = data2h5ad(valid_x_sim, valid_y, type_list)

AnnData object with n_obs × n_vars = 5000 × 3346
    obs: 'Luminal_Macrophages', 'Type 2 alveolar', 'Fibroblasts', 'Dendritic cells'
    uns: 'cell_types'
AnnData object with n_obs × n_vars = 1000 × 3346
    obs: 'Luminal_Macrophages', 'Type 2 alveolar', 'Fibroblasts', 'Dendritic cells'
    uns: 'cell_types'
AnnData object with n_obs × n_vars = 1000 × 3346
    obs: 'Luminal_Macrophages', 'Type 2 alveolar', 'Fibroblasts', 'Dendritic cells'
    uns: 'cell_types'


# model

In [8]:
num_feat = 3346
feat_map_w = 256
feat_map_h = 10
num_cell_type = len(type_list)
patience = 10
epoches = 200
Alpha = 1
Beta = 1
model_save_name = 'lung_rna'

In [9]:
# Train stage 2, returning the training loss and the best encoder parameters.
model_da = DANN(epoches, 50, 0.0001)
pred_loss, disc_loss, disc_loss_DA, best_model_weights = model_da.train(source_data, target_data, valid_data, patience = 3) 


[36m===== Starting Training (Total Epochs: 200) =====
Patience for early stopping: 3 epochs
Batch size: 50, Learning rate: 0.0001[0m



Epoch 1/200: 100%|██████████| 100/100 batches


[36m[Ep 1] | Pred: [32m0.0200[0m | Disc: 1.3869 | Disc_DA: 1.3870 | Valid RMSE: [32m0.1388[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 2/200: 100%|██████████| 100/100 batches


[36m[Ep 2] | Pred: [32m0.0168[0m | Disc: 1.3874 | Disc_DA: 1.3863 | Valid RMSE: [32m0.1049[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 3/200: 100%|██████████| 100/100 batches


[36m[Ep 3] | Pred: [32m0.0095[0m | Disc: 1.3876 | Disc_DA: 1.3862 | Valid RMSE: [32m0.0686[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 4/200: 100%|██████████| 100/100 batches


[36m[Ep 4] | Pred: [32m0.0058[0m | Disc: 1.3889 | Disc_DA: 1.3853 | Valid RMSE: [32m0.0518[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 5/200: 100%|██████████| 100/100 batches


[36m[Ep 5] | Pred: [32m0.0028[0m | Disc: 1.3878 | Disc_DA: 1.3860 | Valid RMSE: [32m0.0381[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 6/200: 100%|██████████| 100/100 batches


[36m[Ep 6] | Pred: [32m0.0023[0m | Disc: 1.3885 | Disc_DA: 1.3858 | Valid RMSE: [32m0.0351[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 7/200: 100%|██████████| 100/100 batches


[36m[Ep 7] | Pred: [32m0.0020[0m | Disc: 1.3878 | Disc_DA: 1.3856 | Valid RMSE: [32m0.0327[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 8/200: 100%|██████████| 100/100 batches


[36m[Ep 8] | Pred: [32m0.0017[0m | Disc: 1.3884 | Disc_DA: 1.3851 | Valid RMSE: [32m0.0287[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 9/200: 100%|██████████| 100/100 batches


[36m[Ep 9] | Pred: [32m0.0016[0m | Disc: 1.3878 | Disc_DA: 1.3853 | Valid RMSE: [32m0.0280[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 10/200: 100%|██████████| 100/100 batches


[36m[Ep 10] | Pred: [32m0.0015[0m | Disc: 1.3878 | Disc_DA: 1.3857 | Valid RMSE: [32m0.0268[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 11/200: 100%|██████████| 100/100 batches


[36m[Ep 11] | Pred: [32m0.0014[0m | Disc: 1.3872 | Disc_DA: 1.3852 | Valid RMSE: [32m0.0258[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 12/200: 100%|██████████| 100/100 batches


[36m[Ep 12] | Pred: [32m0.0012[0m | Disc: 1.3878 | Disc_DA: 1.3858 | Valid RMSE: [32m0.0269[0m
  [33m↯ No improvement (1/3)[0m


Epoch 13/200: 100%|██████████| 100/100 batches


[36m[Ep 13] | Pred: [32m0.0012[0m | Disc: 1.3876 | Disc_DA: 1.3857 | Valid RMSE: [32m0.0263[0m
  [33m↯ No improvement (2/3)[0m


Epoch 14/200: 100%|██████████| 100/100 batches


[36m[Ep 14] | Pred: [32m0.0010[0m | Disc: 1.3878 | Disc_DA: 1.3854 | Valid RMSE: [32m0.0223[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 15/200: 100%|██████████| 100/100 batches


[36m[Ep 15] | Pred: [32m0.0010[0m | Disc: 1.3876 | Disc_DA: 1.3853 | Valid RMSE: [32m0.0216[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 16/200: 100%|██████████| 100/100 batches


[36m[Ep 16] | Pred: [32m0.0010[0m | Disc: 1.3881 | Disc_DA: 1.3854 | Valid RMSE: [32m0.0232[0m
  [33m↯ No improvement (1/3)[0m


Epoch 17/200: 100%|██████████| 100/100 batches


[36m[Ep 17] | Pred: [32m0.0009[0m | Disc: 1.3874 | Disc_DA: 1.3852 | Valid RMSE: [32m0.0209[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 18/200: 100%|██████████| 100/100 batches


[36m[Ep 18] | Pred: [32m0.0010[0m | Disc: 1.3878 | Disc_DA: 1.3853 | Valid RMSE: [32m0.0221[0m
  [33m↯ No improvement (1/3)[0m


Epoch 19/200: 100%|██████████| 100/100 batches


[36m[Ep 19] | Pred: [32m0.0009[0m | Disc: 1.3881 | Disc_DA: 1.3854 | Valid RMSE: [32m0.0210[0m
  [33m↯ No improvement (2/3)[0m


Epoch 20/200: 100%|██████████| 100/100 batches


[36m[Ep 20] | Pred: [32m0.0009[0m | Disc: 1.3874 | Disc_DA: 1.3860 | Valid RMSE: [32m0.0231[0m
  [33m↯ No improvement (3/3)[0m
[36m
Early stopping triggered at epoch 20!
Best RMSE achieved: 0.0209[0m


[36m===== Training Complete! =====
Total epochs: 20/200
Best RMSE: [35m0.0209[0m
Final losses: Pred=0.0009, Disc=1.3874, Disc_DA=1.3860
[0m


In [10]:
model = MBdeconv(num_feat, feat_map_w, feat_map_h, num_cell_type, epoches, Alpha, Beta, train_dataloader, test_dataloader)

In [11]:
# Train stage 3, reading the parameters of stage 2 encoder before training.
device = torch.device('cuda')
if model.gpu_available:
    model = model.to(model.gpu)
model_da.encoder_da.load_state_dict(best_model_weights['encoder'])
encoder_params = copy.deepcopy(model_da.encoder_da.state_dict())
model.encoder.load_state_dict(encoder_params)
loss1_list, loss2_list, nce_loss_list = model.train_model(model_save_name, True, patience)


[36m===== Starting Training (Total Epochs: 200) =====
Patience for early stopping: 10 epochs[0m



Epoch 1/200: 100%|██████████| 79/79 batches


[36m[Ep 1] 1.8s | Loss: [32m4.1186[0m (L1: 0.0194, L2: 0.0194, NCE: 8.1495) | Test: RMSE=[32m0.0458[0m, MAE=[32m0.0368[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 2/200: 100%|██████████| 79/79 batches


[36m[Ep 2] 3.1s | Loss: [32m3.7292[0m (L1: 0.0063, L2: 0.0062, NCE: 7.4296) | Test: RMSE=[32m0.0371[0m, MAE=[32m0.0281[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 3/200: 100%|██████████| 79/79 batches


[36m[Ep 3] 4.5s | Loss: [32m3.4641[0m (L1: 0.0012, L2: 0.0012, NCE: 6.9214) | Test: RMSE=[32m0.0308[0m, MAE=[32m0.0225[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 4/200: 100%|██████████| 79/79 batches


[36m[Ep 4] 5.8s | Loss: [32m3.4263[0m (L1: 0.0009, L2: 0.0009, NCE: 6.8473) | Test: RMSE=[32m0.0297[0m, MAE=[32m0.0213[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 5/200: 100%|██████████| 79/79 batches


[36m[Ep 5] 7.2s | Loss: [32m3.4335[0m (L1: 0.0008, L2: 0.0009, NCE: 6.8620) | Test: RMSE=[32m0.0290[0m, MAE=[32m0.0210[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 6/200: 100%|██████████| 79/79 batches


[36m[Ep 6] 8.6s | Loss: [32m3.3147[0m (L1: 0.0008, L2: 0.0008, NCE: 6.6248) | Test: RMSE=[32m0.0291[0m, MAE=[32m0.0209[0m
  [33m↯ No improvement (1/10)[0m


Epoch 7/200: 100%|██████████| 79/79 batches


[36m[Ep 7] 10.1s | Loss: [32m3.3577[0m (L1: 0.0008, L2: 0.0008, NCE: 6.7108) | Test: RMSE=[32m0.0308[0m, MAE=[32m0.0222[0m
  [33m↯ No improvement (2/10)[0m


Epoch 8/200: 100%|██████████| 79/79 batches


[36m[Ep 8] 11.5s | Loss: [32m3.3573[0m (L1: 0.0008, L2: 0.0007, NCE: 6.7103) | Test: RMSE=[32m0.0295[0m, MAE=[32m0.0225[0m
  [33m↯ No improvement (3/10)[0m


Epoch 9/200: 100%|██████████| 79/79 batches


[36m[Ep 9] 12.9s | Loss: [32m3.2899[0m (L1: 0.0008, L2: 0.0008, NCE: 6.5753) | Test: RMSE=[32m0.0285[0m, MAE=[32m0.0211[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 10/200: 100%|██████████| 79/79 batches


[36m[Ep 10] 14.4s | Loss: [32m3.3256[0m (L1: 0.0007, L2: 0.0007, NCE: 6.6472) | Test: RMSE=[32m0.0294[0m, MAE=[32m0.0211[0m
  [33m↯ No improvement (1/10)[0m


Epoch 11/200: 100%|██████████| 79/79 batches


[36m[Ep 11] 15.9s | Loss: [32m3.3107[0m (L1: 0.0007, L2: 0.0007, NCE: 6.6175) | Test: RMSE=[32m0.0290[0m, MAE=[32m0.0209[0m
  [33m↯ No improvement (2/10)[0m


Epoch 12/200: 100%|██████████| 79/79 batches


[36m[Ep 12] 17.3s | Loss: [32m3.2924[0m (L1: 0.0007, L2: 0.0007, NCE: 6.5808) | Test: RMSE=[32m0.0305[0m, MAE=[32m0.0217[0m
  [33m↯ No improvement (3/10)[0m


Epoch 13/200: 100%|██████████| 79/79 batches


[36m[Ep 13] 18.8s | Loss: [32m3.2862[0m (L1: 0.0007, L2: 0.0007, NCE: 6.5684) | Test: RMSE=[32m0.0336[0m, MAE=[32m0.0233[0m
  [33m↯ No improvement (4/10)[0m


Epoch 14/200: 100%|██████████| 79/79 batches


[36m[Ep 14] 20.3s | Loss: [32m3.2864[0m (L1: 0.0007, L2: 0.0007, NCE: 6.5689) | Test: RMSE=[32m0.0291[0m, MAE=[32m0.0206[0m
  [33m↯ No improvement (5/10)[0m


Epoch 15/200: 100%|██████████| 79/79 batches


[36m[Ep 15] 21.7s | Loss: [32m3.2532[0m (L1: 0.0006, L2: 0.0006, NCE: 6.5026) | Test: RMSE=[32m0.0287[0m, MAE=[32m0.0205[0m
  [33m↯ No improvement (6/10)[0m


Epoch 16/200: 100%|██████████| 79/79 batches


[36m[Ep 16] 23.1s | Loss: [32m3.2771[0m (L1: 0.0007, L2: 0.0006, NCE: 6.5504) | Test: RMSE=[32m0.0303[0m, MAE=[32m0.0217[0m
  [33m↯ No improvement (7/10)[0m


Epoch 17/200: 100%|██████████| 79/79 batches


[36m[Ep 17] 24.5s | Loss: [32m3.2360[0m (L1: 0.0006, L2: 0.0006, NCE: 6.4683) | Test: RMSE=[32m0.0278[0m, MAE=[32m0.0197[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 18/200: 100%|██████████| 79/79 batches


[36m[Ep 18] 26.0s | Loss: [32m3.2438[0m (L1: 0.0006, L2: 0.0006, NCE: 6.4838) | Test: RMSE=[32m0.0316[0m, MAE=[32m0.0220[0m
  [33m↯ No improvement (1/10)[0m


Epoch 19/200: 100%|██████████| 79/79 batches


[36m[Ep 19] 27.4s | Loss: [32m3.2001[0m (L1: 0.0006, L2: 0.0006, NCE: 6.3965) | Test: RMSE=[32m0.0277[0m, MAE=[32m0.0198[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 20/200: 100%|██████████| 79/79 batches


[36m[Ep 20] 28.8s | Loss: [32m3.2008[0m (L1: 0.0006, L2: 0.0006, NCE: 6.3980) | Test: RMSE=[32m0.0276[0m, MAE=[32m0.0196[0m
  [35m★ New best RMSE! Model saved.[0m


Epoch 21/200: 100%|██████████| 79/79 batches


[36m[Ep 21] 30.2s | Loss: [32m3.1928[0m (L1: 0.0006, L2: 0.0006, NCE: 6.3820) | Test: RMSE=[32m0.0281[0m, MAE=[32m0.0202[0m
  [33m↯ No improvement (1/10)[0m


Epoch 22/200: 100%|██████████| 79/79 batches


[36m[Ep 22] 31.6s | Loss: [32m3.0786[0m (L1: 0.0006, L2: 0.0006, NCE: 6.1539) | Test: RMSE=[32m0.0289[0m, MAE=[32m0.0211[0m
  [33m↯ No improvement (2/10)[0m


Epoch 23/200: 100%|██████████| 79/79 batches


[36m[Ep 23] 33.0s | Loss: [32m3.0355[0m (L1: 0.0006, L2: 0.0006, NCE: 6.0675) | Test: RMSE=[32m0.0280[0m, MAE=[32m0.0204[0m
  [33m↯ No improvement (3/10)[0m


Epoch 24/200: 100%|██████████| 79/79 batches


[36m[Ep 24] 34.4s | Loss: [32m3.0097[0m (L1: 0.0006, L2: 0.0006, NCE: 6.0159) | Test: RMSE=[32m0.0292[0m, MAE=[32m0.0218[0m
  [33m↯ No improvement (4/10)[0m


Epoch 25/200: 100%|██████████| 79/79 batches


[36m[Ep 25] 35.9s | Loss: [32m2.8963[0m (L1: 0.0006, L2: 0.0006, NCE: 5.7890) | Test: RMSE=[32m0.0298[0m, MAE=[32m0.0221[0m
  [33m↯ No improvement (5/10)[0m


Epoch 26/200: 100%|██████████| 79/79 batches


[36m[Ep 26] 37.3s | Loss: [32m2.8430[0m (L1: 0.0006, L2: 0.0006, NCE: 5.6825) | Test: RMSE=[32m0.0286[0m, MAE=[32m0.0201[0m
  [33m↯ No improvement (6/10)[0m


Epoch 27/200: 100%|██████████| 79/79 batches


[36m[Ep 27] 38.7s | Loss: [32m2.6727[0m (L1: 0.0006, L2: 0.0006, NCE: 5.3420) | Test: RMSE=[32m0.0277[0m, MAE=[32m0.0199[0m
  [33m↯ No improvement (7/10)[0m


Epoch 28/200: 100%|██████████| 79/79 batches


[36m[Ep 28] 40.2s | Loss: [32m2.8701[0m (L1: 0.0006, L2: 0.0006, NCE: 5.7368) | Test: RMSE=[32m0.0279[0m, MAE=[32m0.0203[0m
  [33m↯ No improvement (8/10)[0m


Epoch 29/200: 100%|██████████| 79/79 batches


[36m[Ep 29] 41.5s | Loss: [32m3.1984[0m (L1: 0.0006, L2: 0.0006, NCE: 6.3936) | Test: RMSE=[32m0.0291[0m, MAE=[32m0.0204[0m
  [33m↯ No improvement (9/10)[0m


Epoch 30/200: 100%|██████████| 79/79 batches


[36m[Ep 30] 42.9s | Loss: [32m3.1860[0m (L1: 0.0006, L2: 0.0006, NCE: 6.3685) | Test: RMSE=[32m0.0289[0m, MAE=[32m0.0216[0m
  [33m↯ No improvement (10/10)[0m
[36m
Early stopping triggered at epoch 30!
Best RMSE achieved: 0.0276[0m


[36m===== Training Complete! =====
Total training time: 42.9 seconds
Final losses: Total=3.1860, L1=0.0006, L2=0.0006, NCE=6.3685
[0m


In [15]:
model_test = MBdeconv(num_feat, feat_map_w, feat_map_h, num_cell_type, epoches, Alpha, Beta, train_dataloader, test_dataloader)

In [16]:
# Perform inference on the test dataset in Stage 4 and obtain the overall CCC, RMSE, and Correlation values.
model_test.load_state_dict(torch.load('save_models/3346/lung_rna.pt'))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_test.to(device)
model_test.eval()
CCC, RMSE, Corr, pred, gt = predict(test_dataloader, type_list, model_test, True)

In [17]:
CCC, RMSE, Corr 

(0.9784067211573461, 0.027694766252091465, 0.9813103117779443)