In [1]:
import numpy as np
import torch
import scanpy as sc
import torch.nn.functional as F
from configure import get_default_config
from model_brca import Model
from tools import normalize_type
import random
from sklearn.model_selection import KFold

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dataset = {
    0: "dro",
    1: "MERFISH",
    2: "STARmap",
    4: "BRCA",
    5: "ELSE",
}
data_name = dataset[4]
shuffle = False
random.seed(48)
ktimes = 10
cv = KFold(n_splits=ktimes, shuffle=False)

In [2]:
def filter_with_overlap_gene(adata, adata_sc):
    # remove all-zero-valued genes
    # sc.pp.filter_genes(adata, min_cells=1)
    # sc.pp.filter_genes(adata_sc, min_cells=1)
    if 'highly_variable' not in adata.var.keys():
        raise ValueError("'highly_variable' are not existed in adata!")
    else:
        adata = adata[:, adata.var['highly_variable']]

    if 'highly_variable' not in adata_sc.var.keys():
        raise ValueError("'highly_variable' are not existed in adata_sc!")
    else:
        adata_sc = adata_sc[:, adata_sc.var['highly_variable']]

        # Refine `marker_genes` so that they are shared by both adatas
    genes = list(set(adata.var.index) & set(adata_sc.var.index))
    genes.sort()
    print('Number of overlap genes:', len(genes))
    adata.uns["overlap_genes"] = genes
    adata_sc.uns["overlap_genes"] = genes

    adata = adata[:, genes]
    adata_sc = adata_sc[:, genes]

    return adata, adata_sc

In [3]:
rna_path = 'data/BRCA/scRNA.h5ad'
adata_sc = sc.read(rna_path)
adata_sc.var_names_make_unique()
sc.pp.highly_variable_genes(adata_sc, flavor="seurat_v3", n_top_genes=3000, inplace=True, subset=True)
sc.pp.normalize_total(adata_sc, target_sum=1e4)
sc.pp.log1p(adata_sc)
# load BRCA ST data
adata_st = sc.read_visium(path=r"./data/BRCA/",
                               count_file="filtered_feature_bc_matrix.h5", library_id="BRCA",
                               load_images=True, source_image_path="/spatial/")
adata_st.var_names_make_unique()
sc.pp.highly_variable_genes(adata_st, flavor="seurat_v3", n_top_genes=3000, inplace=True, subset=True)
sc.pp.normalize_total(adata_st, target_sum=1e4)
sc.pp.log1p(adata_st)
adata_st, adata_sc = filter_with_overlap_gene(adata_st, adata_sc)
atlas_genes = adata_st.uns["overlap_genes"]
x1_cell, x2_rna = adata_st.X, adata_sc.X
x1_cell = x1_cell.todense()

  utils.warn_names_duplicates("var")


Number of overlap genes: 921


  self.data[key] = value


In [4]:
def pearsonr(x, y):
    mean_x = torch.mean(x)
    mean_y = torch.mean(y)
    xm = x.sub(mean_x)
    ym = y.sub(mean_y)
    r_num = xm.dot(ym)
    r_den = torch.norm(xm, 2) * torch.norm(ym, 2)
    r_val = r_num / (r_den + 1e-8)
    r_val = torch.nan_to_num(r_val,nan=-1)
    return r_val

def correlationMetric(x, y):
    corr = 0
    for idx in range(x.size(1)):
        corr += pearsonr(x[:,idx], y[:,idx])
    corr /= (idx + 1)
    return (1 - corr).mean()

In [5]:
class MinMaxScaler_torch:
    def __init__(self, feature_range=(0, 1)):
        self.min_val = None
        self.max_val = None
        self.feature_range = feature_range

    def fit(self, data):
        self.min_val, _ = torch.min(data, dim=0)
        self.max_val, _ = torch.max(data, dim=0)

    def transform(self, data):
        data_minmax = (data - self.min_val) / (self.max_val - self.min_val)
        data_minmax = data_minmax * (self.feature_range[1] - self.feature_range[0]) + self.feature_range[0]
        return data_minmax

    def inverse_transform(self, data_minmax):
        data = (data_minmax - self.feature_range[0]) / (self.feature_range[1] - self.feature_range[0])
        data = data * (self.max_val - self.min_val) + self.min_val
        return data

In [6]:
def train(model, optimizer, x1, x2, config, model_path):
    print('\n===========> Training... <===========')
    model.train()
    if not isinstance(x1, torch.Tensor):
        x1 = torch.from_numpy(x1).to(device)
    if not isinstance(x2, torch.Tensor):
        x2 = torch.from_numpy(x2).to(device)
    for epoch in range(config['pretrain_epochs']):
        x1_hat = model(x2)
        mask_zero = x1 != 0
        x1_masked = torch.masked_select(x1, mask_zero)
        x1_masked_pred = torch.masked_select(x1_hat, mask_zero)
        loss_ae = F.mse_loss(x1, x1_hat, reduction='mean')
        loss_mask_ae = F.mse_loss(x1_masked, x1_masked_pred, reduction='mean')
        corrloss = correlationMetric(x1_hat, x1)
        loss = loss_ae + 0.001 * loss_mask_ae + 0.01 * corrloss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0:
            print ('train epoch %d: loss_ae: %.6f loss_mask_ae: %.6f loss_corr: %.6f' \
                   % (epoch, loss.item(), loss_mask_ae.item(), corrloss.item()))
    torch.save(model.state_dict(), model_path)


In [7]:
cnts = 0
for train_idx, test_idx in cv.split(atlas_genes):
    x1_train_cell = x1_cell[:, train_idx]
    x1_test_cell = x1_cell[:, test_idx]
    x2_train_cell = x2_rna[:, train_idx]
    x2_test_cell = x2_rna[:, test_idx]
    print(x1_cell.shape, x2_rna.shape)
    x1 = x1_train_cell.astype(np.float32)
    x2 = x2_train_cell.astype(np.float32)
    print (x1.max(), x1.min(), x2.max(), x2.min())
    count_zerox1 = np.sum(x1 == 0)
    count_zerox2 = np.sum(x2 == 0)
    print('Before Norm ST:', float(count_zerox1) / float(x1.shape[0] * x1.shape[1]), 'Before Norm RNA:',
          float(count_zerox2) / float(x2.shape[0] * x2.shape[1]))
    x1, x2 = normalize_type(x1, x2, type='all')
    print ('Xrna shape:', x2.shape,  'Xst shape:', x1.shape)
    count_zerox1 = np.sum(x1 == 0)
    count_zerox2 = np.sum(x2 == 0)
    print ('ST:', float(count_zerox1) / float(x1.shape[0]*x1.shape[1]), 'RNA:', float(count_zerox2) / float(x2.shape[0]*x2.shape[1]))
    print (x1.max(), x1.min(), x1.mean(), x2.max(), x2.min(), x2.mean())
    config = get_default_config(data_name)
    config['num_sample1'] = x1.shape[0]
    config['num_sample2'] = x2.shape[0]
    train_path = "train/%s_%d.pkl" % (data_name, cnts)
    print (train_path)
    tmp_dims = x1.shape[1]
    model = Model(config)
    model.to(device)
    optimizer_pre = torch.optim.Adam(model.parameters(), lr=config['pre_lr'])
    train(model, optimizer_pre, x1, x2, config, train_path)
    print("Finished " + str(cnts) + ' Times Training and Testing...')
    cnts += 1


(3798, 921) (45647, 921)
8.3757105 0.0 6.1131206 0.0
Before Norm ST: 0.6720480268028176 Before Norm RNA: 0.7785909122610616




Xrna shape: (45647, 828) Xst shape: (3798, 828)
ST: 0.6720483447937257 RNA: 0.7785909122610616
1.0 0.0 0.1413657 1.0 0.0 0.115089804
train/BRCA_0.pkl

train epoch 0: loss_ae: 0.081008 loss_mask_ae: 0.215522 loss_corr: 1.004733
train epoch 10: loss_ae: 0.063611 loss_mask_ae: 0.143707 loss_corr: 0.971637
train epoch 20: loss_ae: 0.053623 loss_mask_ae: 0.076622 loss_corr: 0.920570
train epoch 30: loss_ae: 0.049898 loss_mask_ae: 0.071980 loss_corr: 0.887965
train epoch 40: loss_ae: 0.045990 loss_mask_ae: 0.070037 loss_corr: 0.876791
train epoch 50: loss_ae: 0.046238 loss_mask_ae: 0.056039 loss_corr: 0.874176
train epoch 60: loss_ae: 0.040302 loss_mask_ae: 0.076589 loss_corr: 0.868123
train epoch 70: loss_ae: 0.037731 loss_mask_ae: 0.067067 loss_corr: 0.856886
train epoch 80: loss_ae: 0.036295 loss_mask_ae: 0.061961 loss_corr: 0.850242
train epoch 90: loss_ae: 0.035487 loss_mask_ae: 0.053463 loss_corr: 0.845099
Finished 0 Times Training and Testing...
(3798, 921) (45647, 921)
8.3757105 0.0 



Xrna shape: (45647, 829) Xst shape: (3798, 829)
ST: 0.6695070289676937 RNA: 0.7782712530729932
1.0 0.0 0.14286344 1.0 0.0 0.114488915
train/BRCA_1.pkl

train epoch 0: loss_ae: 0.081783 loss_mask_ae: 0.216400 loss_corr: 0.998322
train epoch 10: loss_ae: 0.072879 loss_mask_ae: 0.177274 loss_corr: 0.988577
train epoch 20: loss_ae: 0.054041 loss_mask_ae: 0.114291 loss_corr: 0.929493
train epoch 30: loss_ae: 0.049813 loss_mask_ae: 0.071088 loss_corr: 0.889236
train epoch 40: loss_ae: 0.047202 loss_mask_ae: 0.067792 loss_corr: 0.876503
train epoch 50: loss_ae: 0.042510 loss_mask_ae: 0.066684 loss_corr: 0.874758
train epoch 60: loss_ae: 0.038970 loss_mask_ae: 0.068585 loss_corr: 0.870011
train epoch 70: loss_ae: 0.039218 loss_mask_ae: 0.076920 loss_corr: 0.861308
train epoch 80: loss_ae: 0.037091 loss_mask_ae: 0.050715 loss_corr: 0.855245
train epoch 90: loss_ae: 0.035787 loss_mask_ae: 0.051967 loss_corr: 0.849791
Finished 1 Times Training and Testing...
(3798, 921) (45647, 921)
8.3757105 0.0



ST: 0.6768374060120526 RNA: 0.7808339778881643
1.0 0.0 0.13857847 1.0 0.0 0.11407149
train/BRCA_2.pkl

train epoch 0: loss_ae: 0.079235 loss_mask_ae: 0.213454 loss_corr: 0.997719
train epoch 10: loss_ae: 0.064095 loss_mask_ae: 0.154997 loss_corr: 0.970128
train epoch 20: loss_ae: 0.049520 loss_mask_ae: 0.093493 loss_corr: 0.911815
train epoch 30: loss_ae: 0.047773 loss_mask_ae: 0.075076 loss_corr: 0.887190
train epoch 40: loss_ae: 0.044638 loss_mask_ae: 0.065813 loss_corr: 0.881880
train epoch 50: loss_ae: 0.040367 loss_mask_ae: 0.062594 loss_corr: 0.886272
train epoch 60: loss_ae: 0.037598 loss_mask_ae: 0.061242 loss_corr: 0.873992
train epoch 70: loss_ae: 0.036289 loss_mask_ae: 0.060890 loss_corr: 0.861581
train epoch 80: loss_ae: 0.035657 loss_mask_ae: 0.059854 loss_corr: 0.852607
train epoch 90: loss_ae: 0.035314 loss_mask_ae: 0.060843 loss_corr: 0.846736
Finished 2 Times Training and Testing...
(3798, 921) (45647, 921)
8.3757105 0.0 6.1131206 0.0
Before Norm ST: 0.6651507269078831



ST: 0.6651510445152073 RNA: 0.7760074604078082
1.0 0.0 0.14487197 1.0 0.0 0.11577417
train/BRCA_3.pkl

train epoch 0: loss_ae: 0.082867 loss_mask_ae: 0.216753 loss_corr: 1.000662
train epoch 10: loss_ae: 0.064089 loss_mask_ae: 0.139510 loss_corr: 0.966398
train epoch 20: loss_ae: 0.052139 loss_mask_ae: 0.100648 loss_corr: 0.911101
train epoch 30: loss_ae: 0.048331 loss_mask_ae: 0.080435 loss_corr: 0.881214
train epoch 40: loss_ae: 0.054402 loss_mask_ae: 0.122816 loss_corr: 0.891889
train epoch 50: loss_ae: 0.044822 loss_mask_ae: 0.075106 loss_corr: 0.888613
train epoch 60: loss_ae: 0.041698 loss_mask_ae: 0.054213 loss_corr: 0.878094
train epoch 70: loss_ae: 0.038271 loss_mask_ae: 0.057755 loss_corr: 0.870410
train epoch 80: loss_ae: 0.036845 loss_mask_ae: 0.061908 loss_corr: 0.858468
train epoch 90: loss_ae: 0.035936 loss_mask_ae: 0.059166 loss_corr: 0.849747
Finished 3 Times Training and Testing...
(3798, 921) (45647, 921)
7.801908 0.0 5.955568 0.0
Before Norm ST: 0.6807878059114345 B



ST: 0.6807878059114345 RNA: 0.7858623644185332
1.0 0.0 0.13504948 1.0 0.0 0.10999393
train/BRCA_4.pkl

train epoch 0: loss_ae: 0.076294 loss_mask_ae: 0.206862 loss_corr: 0.999013
train epoch 10: loss_ae: 0.060205 loss_mask_ae: 0.138070 loss_corr: 0.959072
train epoch 20: loss_ae: 0.049623 loss_mask_ae: 0.088356 loss_corr: 0.912338
train epoch 30: loss_ae: 0.047491 loss_mask_ae: 0.097286 loss_corr: 0.879013
train epoch 40: loss_ae: 0.056704 loss_mask_ae: 0.139266 loss_corr: 0.896642
train epoch 50: loss_ae: 0.043935 loss_mask_ae: 0.058736 loss_corr: 0.873719
train epoch 60: loss_ae: 0.039509 loss_mask_ae: 0.066373 loss_corr: 0.871702
train epoch 70: loss_ae: 0.037943 loss_mask_ae: 0.057959 loss_corr: 0.860641
train epoch 80: loss_ae: 0.036584 loss_mask_ae: 0.058688 loss_corr: 0.850130
train epoch 90: loss_ae: 0.035917 loss_mask_ae: 0.063158 loss_corr: 0.841852
Finished 4 Times Training and Testing...
(3798, 921) (45647, 921)
8.3757105 0.0 6.1131206 0.0
Before Norm ST: 0.6681251194997557



ST: 0.6681254371070801 RNA: 0.7795189354040973
1.0 0.0 0.14310195 1.0 0.0 0.11382983
train/BRCA_5.pkl

train epoch 0: loss_ae: 0.081712 loss_mask_ae: 0.215325 loss_corr: 0.997204
train epoch 10: loss_ae: 0.061703 loss_mask_ae: 0.121295 loss_corr: 0.961805
train epoch 20: loss_ae: 0.050967 loss_mask_ae: 0.075758 loss_corr: 0.903685
train epoch 30: loss_ae: 0.058576 loss_mask_ae: 0.059730 loss_corr: 0.882914
train epoch 40: loss_ae: 0.050076 loss_mask_ae: 0.062279 loss_corr: 0.885362
train epoch 50: loss_ae: 0.044930 loss_mask_ae: 0.062752 loss_corr: 0.883589
train epoch 60: loss_ae: 0.040092 loss_mask_ae: 0.063784 loss_corr: 0.873173
train epoch 70: loss_ae: 0.041798 loss_mask_ae: 0.048662 loss_corr: 0.873078
train epoch 80: loss_ae: 0.037019 loss_mask_ae: 0.063391 loss_corr: 0.863394
train epoch 90: loss_ae: 0.035915 loss_mask_ae: 0.055762 loss_corr: 0.855291
Finished 5 Times Training and Testing...
(3798, 921) (45647, 921)
8.3757105 0.0 6.1131206 0.0
Before Norm ST: 0.6753751418910722



ST: 0.6753754594983964 RNA: 0.7792705828275794
1.0 0.0 0.13982828 1.0 0.0 0.11420379
train/BRCA_6.pkl

train epoch 0: loss_ae: 0.080174 loss_mask_ae: 0.215242 loss_corr: 1.002069
train epoch 10: loss_ae: 0.060992 loss_mask_ae: 0.113495 loss_corr: 0.955892
train epoch 20: loss_ae: 0.058187 loss_mask_ae: 0.071740 loss_corr: 0.925059
train epoch 30: loss_ae: 0.048689 loss_mask_ae: 0.088160 loss_corr: 0.885663
train epoch 40: loss_ae: 0.045981 loss_mask_ae: 0.067494 loss_corr: 0.875839
train epoch 50: loss_ae: 0.040723 loss_mask_ae: 0.069050 loss_corr: 0.874012
train epoch 60: loss_ae: 0.045294 loss_mask_ae: 0.097066 loss_corr: 0.900008
train epoch 70: loss_ae: 0.040028 loss_mask_ae: 0.065355 loss_corr: 0.881291
train epoch 80: loss_ae: 0.038392 loss_mask_ae: 0.072191 loss_corr: 0.870481
train epoch 90: loss_ae: 0.036715 loss_mask_ae: 0.054952 loss_corr: 0.857840
Finished 6 Times Training and Testing...
(3798, 921) (45647, 921)
8.3757105 0.0 6.1131206 0.0
Before Norm ST: 0.6715930738735579



ST: 0.6715933914808823 RNA: 0.7811551872484086
1.0 0.0 0.14178143 1.0 0.0 0.11343517
train/BRCA_7.pkl

train epoch 0: loss_ae: 0.081134 loss_mask_ae: 0.215777 loss_corr: 0.999044
train epoch 10: loss_ae: 0.064084 loss_mask_ae: 0.147117 loss_corr: 0.968628
train epoch 20: loss_ae: 0.050987 loss_mask_ae: 0.101794 loss_corr: 0.909930
train epoch 30: loss_ae: 0.053585 loss_mask_ae: 0.121723 loss_corr: 0.887206
train epoch 40: loss_ae: 0.048104 loss_mask_ae: 0.097077 loss_corr: 0.877878
train epoch 50: loss_ae: 0.043268 loss_mask_ae: 0.068906 loss_corr: 0.874849
train epoch 60: loss_ae: 0.040315 loss_mask_ae: 0.062267 loss_corr: 0.875507
train epoch 70: loss_ae: 0.037721 loss_mask_ae: 0.065606 loss_corr: 0.868804
train epoch 80: loss_ae: 0.036669 loss_mask_ae: 0.064098 loss_corr: 0.860343
train epoch 90: loss_ae: 0.035556 loss_mask_ae: 0.054339 loss_corr: 0.851740
Finished 7 Times Training and Testing...
(3798, 921) (45647, 921)
8.3757105 0.0 6.1131206 0.0
Before Norm ST: 0.6735670033939518



ST: 0.6735673210012761 RNA: 0.7800958702253933
1.0 0.0 0.13999219 1.0 0.0 0.11338467
train/BRCA_8.pkl

train epoch 0: loss_ae: 0.079913 loss_mask_ae: 0.213403 loss_corr: 0.997659
train epoch 10: loss_ae: 0.060584 loss_mask_ae: 0.120158 loss_corr: 0.961184
train epoch 20: loss_ae: 0.050569 loss_mask_ae: 0.099348 loss_corr: 0.902109
train epoch 30: loss_ae: 0.048586 loss_mask_ae: 0.095377 loss_corr: 0.876824
train epoch 40: loss_ae: 0.046253 loss_mask_ae: 0.091611 loss_corr: 0.877667
train epoch 50: loss_ae: 0.041326 loss_mask_ae: 0.056664 loss_corr: 0.871530
train epoch 60: loss_ae: 0.038350 loss_mask_ae: 0.058586 loss_corr: 0.864259
train epoch 70: loss_ae: 0.038736 loss_mask_ae: 0.076597 loss_corr: 0.859641
train epoch 80: loss_ae: 0.038240 loss_mask_ae: 0.075022 loss_corr: 0.850833
train epoch 90: loss_ae: 0.035732 loss_mask_ae: 0.054896 loss_corr: 0.849034
Finished 8 Times Training and Testing...
(3798, 921) (45647, 921)
8.3757105 0.0 6.1131206 0.0
Before Norm ST: 0.6785181839721369



ST: 0.6785185015794613 RNA: 0.7835552329338665
1.0 0.0 0.13808104 1.0 0.0 0.1121006
train/BRCA_9.pkl

train epoch 0: loss_ae: 0.079105 loss_mask_ae: 0.214153 loss_corr: 0.997699
train epoch 10: loss_ae: 0.062840 loss_mask_ae: 0.155482 loss_corr: 0.931539
train epoch 20: loss_ae: 0.054973 loss_mask_ae: 0.126566 loss_corr: 0.923705
train epoch 30: loss_ae: 0.049046 loss_mask_ae: 0.088776 loss_corr: 0.892152
train epoch 40: loss_ae: 0.047220 loss_mask_ae: 0.078958 loss_corr: 0.876901
train epoch 50: loss_ae: 0.043446 loss_mask_ae: 0.075487 loss_corr: 0.872055
train epoch 60: loss_ae: 0.040113 loss_mask_ae: 0.069214 loss_corr: 0.870668
train epoch 70: loss_ae: 0.042475 loss_mask_ae: 0.047241 loss_corr: 0.875449
train epoch 80: loss_ae: 0.038313 loss_mask_ae: 0.066235 loss_corr: 0.877425
train epoch 90: loss_ae: 0.036313 loss_mask_ae: 0.064111 loss_corr: 0.858517
Finished 9 Times Training and Testing...
