In [1]:
import pickle
import random
import warnings

import anndata as ad
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from data.data_process import data_process
from model.deconv_model import MBdeconv
from model.utils import *

# Seed every random number generator the notebook may draw from, so that a
# fresh "Restart Kernel -> Run All" reproduces the same numbers.
seed = 42
random.seed(seed)       # Python's stdlib RNG (was previously left unseeded)
torch.manual_seed(seed)
np.random.seed(seed)

# When running on a GPU, the settings below can additionally be used to keep
# results consistent between runs: force deterministic cuDNN kernels and turn
# off the cuDNN auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# NOTE(review): this silences *every* warning for the rest of the session,
# including deprecation warnings from pandas/anndata; consider narrowing it.
warnings.filterwarnings("ignore")

In [2]:
# Input AnnData (.h5ad) file; the path is relative to the notebook's working directory.
data_file = 'data/bone_marrow_mb/blood.h5ad'
data_h5ad = ad.read_h5ad(data_file)
# Display how many cells carry each 'CellType' label before any relabelling or filtering.
data_h5ad.obs['CellType'].value_counts()

CellType
B                 143
GMP               140
CMP               134
Myeloid           130
HSC (catulin+)    127
Erythroid         124
CLP               117
MEP               116
T                 116
HSC (catulin-)    114
MPP                92
HPC                75
Name: count, dtype: int64

In [3]:
# Collapse the two HSC sub-populations into a single 'HSC' label.
hsc_relabel = dict.fromkeys(('HSC (catulin+)', 'HSC (catulin-)'), 'HSC')
data_h5ad.obs['CellType'] = data_h5ad.obs['CellType'].replace(hsc_relabel)

In [4]:
# Cell types kept for the train/test split (order is passed on to downstream calls).
type_list = [
    'Erythroid',
    'T',
    'B',
    'GMP',
    'Myeloid',
]
# Cell types that are extracted separately and handed to the data processor as noise.
noise = [
    'HSC',
]

In [5]:
print("Loading data...")
# Extract the cells whose type is listed in `noise`.
# The name is always bound (None when `noise` is empty) so that the later
# `dp.fit(..., data_h5ad_noise)` call neither raises NameError on a fresh
# kernel nor silently reuses a stale value left over from an earlier run.
# NOTE(review): confirm that data_process.fit accepts None for the noise set.
data_h5ad_noise = None
if noise:
    data_h5ad_noise = data_h5ad[data_h5ad.obs['CellType'].isin(noise)]
    data_h5ad_noise.obs.reset_index(drop=True, inplace=True)
# Extract the selected cell types.
# NOTE: this rebinds `data_h5ad` to the filtered subset, so re-running this
# cell without reloading the file first would find no noise cells any more
# (unless a noise type is also listed in `type_list`).
data_h5ad = data_h5ad[data_h5ad.obs['CellType'].isin(type_list)]
# Replace the obs index with 0..n-1; the split cell uses these index values
# to subset `data_h5ad`.
data_h5ad.obs.reset_index(drop=True, inplace=True)
print('selected cells:', data_h5ad)

Loading data...
selected cells: View of AnnData object with n_obs × n_vars = 653 × 107
    obs: 'CellType'


In [6]:
# Split the selected cells 50/50 into train and test, one cell type at a time,
# so that every type contributes to both subsets.
train_idx, test_idx = [], []
cell_labels = data_h5ad.obs['CellType']

for label in cell_labels.unique():
    # Index values of all cells annotated with the current type.
    members = cell_labels[cell_labels == label].index.tolist()
    train_part, test_part = train_test_split(members, test_size=0.5, random_state=42)
    train_idx += train_part
    test_idx += test_part

print("Selected cells split into train and test datasets.")
train_data, test_data = data_h5ad[train_idx], data_h5ad[test_idx]

Selected cells split into train and test datasets.


In [7]:
# Configure the data processor for the selected cell types.
# NOTE(review): the numeric settings are passed through unchanged; their exact
# meaning is defined in data.data_process — confirm there before tuning.
dp = data_process(
    type_list,
    tissue_name='bone_marrow_mb',
    test_sample_num=1000,
    sample_size=30,
    num_artificial_cells=70,
)

In [8]:
# Run the preprocessing on the train/test splits together with the extracted noise cells.
# NOTE(review): the following cell reads a pickle from data/bone_marrow_mb/, presumably
# written by this call — confirm in data_process.fit.
dp.fit(train_data, test_data, data_h5ad_noise)

The data processing is complete


In [9]:
# The pickle file holds three objects stored back to back; read them in order.
pkl_path = f'data/bone_marrow_mb/bone_marrow_mb{len(type_list)}cell.pkl'
with open(pkl_path, 'rb') as fh:
    train, test, test_with_noise = (pickle.load(fh) for _ in range(3))

In [10]:
# Unpack the loaded splits: the training split carries four parts, the test split two.
(train_x_sim,
 train_with_noise_1,
 train_with_noise_2,
 train_y) = train
(test_x_sim, test_y) = test

# Wrap the arrays in the project's dataset classes.
train_dataset = TrainCustomDataset(train_x_sim, train_with_noise_1, train_with_noise_2, train_y)
test_dataset = TestCustomDataset(test_x_sim, test_y)

# Same batch size for both loaders; only the training loader shuffles.
loader_kwargs = {'batch_size': 64}
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [11]:
# Model / training configuration, passed positionally to MBdeconv below.
# NOTE(review): 107 equals the n_vars shown for the loaded AnnData above —
# presumably the input feature count; confirm before changing the dataset.
num_MB = 107
feat_map_w, feat_map_h = 256, 10
num_cell_type = len(type_list)   # one output per selected cell type
epoches = 171
Alpha = Beta = 1

In [12]:
# Build the deconvolution model from the configuration and both data loaders.
model = MBdeconv(
    num_MB,
    feat_map_w,
    feat_map_h,
    num_cell_type,
    epoches,
    Alpha,
    Beta,
    train_dataloader,
    test_dataloader,
)

In [13]:
# Pick CUDA only when it is actually available, matching the fallback used in
# the evaluation cell; the previous unconditional 'cuda' was wrong on
# CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The model exposes its own GPU flag and device handle; keep using them for
# the move, since this notebook cannot see how the model uses them elsewhere.
if model.gpu_available:
    model = model.to(model.gpu)

In [14]:
# Train the model. train_model() returns three values; the names suggest they are
# per-loss histories — TODO confirm in MBdeconv.train_model.
loss1_list, loss2_list, nce_loss_list = model.train_model()

[2.42s] ep 0, loss 4.1324
[47.07s] ep 20, loss 3.4455
[94.34s] ep 40, loss 2.8763
[135.50s] ep 60, loss 2.8077
[177.47s] ep 80, loss 2.7241
[225.82s] ep 100, loss 2.6508
[268.03s] ep 120, loss 2.6315
[315.42s] ep 140, loss 2.6124
[360.71s] ep 160, loss 2.6128


In [15]:
# Choose the device first so the checkpoint can be mapped onto it while
# loading; without map_location, a checkpoint saved from GPU cannot be loaded
# on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fresh model instance with the same configuration as the trained one.
model_test = MBdeconv(num_MB, feat_map_w, feat_map_h, num_cell_type, epoches, Alpha, Beta, train_dataloader, test_dataloader)
# NOTE(review): the checkpoint path is hard-coded; presumably it is written by
# train_model() — confirm where the model saves its weights.
model_test.load_state_dict(torch.load('save_models/107/last.pt', map_location=device))
model_test.to(device)
model_test.eval()  # switch to evaluation mode before predicting

# Evaluate on the test loader; the three returned metrics are displayed as the cell output.
CCC, RMSE, Corr = predict(test_dataloader, type_list, model_test, 'bone_marrow_mb', True)
CCC, RMSE, Corr

(0.7818052133984204, 0.06848808161221741, 0.7977945635618106)