In [13]:
import numpy as np
import torch
import scipy.io as sio
import torch.nn.functional as F
from configure import get_default_config
from model_starmap import Model
from tools import normalize_type
import random
from sklearn.model_selection import KFold

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dataset = {
    0: "dro",
    1: "MERFISH",
    2: "STARmap",
    4: "BRCA",
    5: "ELSE",
}
data_name = dataset[2]
shuffle = False
random.seed(48)
ktimes = 10
cv = KFold(n_splits=ktimes, shuffle=False)
mat_name = 'data/' + data_name + '.mat'
data = sio.loadmat(mat_name)
genes = [i for item in data['genes'].flatten() for i in item]

In [14]:
def pearsonr(x, y):
    mean_x = torch.mean(x)
    mean_y = torch.mean(y)
    xm = x.sub(mean_x)
    ym = y.sub(mean_y)
    r_num = xm.dot(ym)
    r_den = torch.norm(xm, 2) * torch.norm(ym, 2)
    r_val = r_num / (r_den + 1e-8)
    r_val = torch.nan_to_num(r_val,nan=-1)
    return r_val

def correlationMetric(x, y):
    corr = 0
    for idx in range(x.size(1)):
        corr += pearsonr(x[:,idx], y[:,idx])
    corr /= (idx + 1)
    return (1 - corr).mean()

In [15]:
class MinMaxScaler_torch:
    def __init__(self, feature_range=(0, 1)):
        self.min_val = None
        self.max_val = None
        self.feature_range = feature_range

    def fit(self, data):
        self.min_val, _ = torch.min(data, dim=0)
        self.max_val, _ = torch.max(data, dim=0)

    def transform(self, data):
        data_minmax = (data - self.min_val) / (self.max_val - self.min_val)
        data_minmax = data_minmax * (self.feature_range[1] - self.feature_range[0]) + self.feature_range[0]
        return data_minmax

    def inverse_transform(self, data_minmax):
        data = (data_minmax - self.feature_range[0]) / (self.feature_range[1] - self.feature_range[0])
        data = data * (self.max_val - self.min_val) + self.min_val
        return data

In [16]:
def train(model, optimizer, x1, x2, config, model_path):
    print('\n===========> Training... <===========')
    model.train()
    if not isinstance(x1, torch.Tensor):
        x1 = torch.from_numpy(x1).to(device)
    if not isinstance(x2, torch.Tensor):
        x2 = torch.from_numpy(x2).to(device)
    for epoch in range(config['pretrain_epochs']):
        x1_hat = model(x2)
        mask_zero = x1 != 0
        x1_masked = torch.masked_select(x1, mask_zero)
        x1_masked_pred = torch.masked_select(x1_hat, mask_zero)
        loss_ae = F.mse_loss(x1, x1_hat, reduction='mean')
        loss_mask_ae = F.mse_loss(x1_masked, x1_masked_pred, reduction='mean')
        corrloss = correlationMetric(x1_hat, x1)
        loss = loss_ae + 0.001 * loss_mask_ae + 0.01 * corrloss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0:
            print ('train epoch %d: loss_ae: %.6f loss_mask_ae: %.6f loss_corr: %.6f' \
                   % (epoch, loss.item(), loss_mask_ae.item(), corrloss.item()))
    torch.save(model.state_dict(), model_path)


In [17]:
x1_loc = data['locations']
atlas_genes = data['genes']
atlas_genes = atlas_genes.reshape(-1)
atlas_genes = [i for item in atlas_genes for i in item]
x1_cell = data['ST']
x2_rna = data['rna']
cnts = 0
tmp_dims = 0
result = np.empty((x1_cell.shape[0], 0))


In [18]:
for train_idx, test_idx in cv.split(atlas_genes):
    x1_train_cell = x1_cell[:, train_idx]
    x1_test_cell = x1_cell[:, test_idx]
    x2_train_cell = x2_rna[:, train_idx]
    x2_test_cell = x2_rna[:, test_idx]
    print(x1_loc.shape, x1_cell.shape, x2_rna.shape)
    x1 = x1_train_cell.astype(np.float32)
    x2 = x2_train_cell.astype(np.float32)
    print (x1.max(), x1.min(), x2.max(), x2.min())
    count_zerox1 = np.sum(x1 == 0)
    count_zerox2 = np.sum(x2 == 0)
    print('Before Norm ST:', float(count_zerox1) / float(x1.shape[0] * x1.shape[1]), 'Before Norm RNA:',
          float(count_zerox2) / float(x2.shape[0] * x2.shape[1]))
    x1, x2 = normalize_type(x1, x2, type='all')
    print ('Xrna shape:', x2.shape,  'Xst shape:', x1.shape)
    count_zerox1 = np.sum(x1 == 0)
    count_zerox2 = np.sum(x2 == 0)
    print ('ST:', float(count_zerox1) / float(x1.shape[0]*x1.shape[1]), 'RNA:', float(count_zerox2) / float(x2.shape[0]*x2.shape[1]))
    print (x1.max(), x1.min(), x1.mean(), x2.max(), x2.min(), x2.mean())
    config = get_default_config(data_name)
    config['num_sample1'] = x1.shape[0]
    config['num_sample2'] = x2.shape[0]
    train_path = "train/%s_%d.pkl" % (data_name, cnts)
    print (train_path)
    tmp_dims = x1.shape[1]
    model = Model(config)
    model.to(device)
    optimizer_pre = torch.optim.Adam(model.parameters(), lr=config['pre_lr'])
    train(model, optimizer_pre, x1, x2, config, train_path)
    print("Finished " + str(cnts) + ' Times Training and Testing...')
    cnts += 1


(1549, 2) (1549, 996) (15413, 996)
333.0 0.0 411165.0 0.0
Before Norm ST: 0.7901958636908605 Before Norm RNA: 0.5123166841997943
Xrna shape: (15413, 896) Xst shape: (1549, 896)
ST: 0.7901965842017892 RNA: 0.5123166841997943
1.0 0.0 0.05263017 1.0 0.0 0.036019653
train/STARmap_0.pkl
{'dims': [2222, 512, 256], 'pretrain_epochs': 400, 'epochs': 15, 'pre_lr': 0.0001, 'lr': 0.0001, 'batch_size': 256, 'weight_map': 10, 'weight_coef': 0, 'weight_mmd': 0, 'weight_ent': 0, 'alpha': 1, 'ot': {'epochs': 300, 'lr': 0.05, 'step_size': 300, 'tau': 1, 'it': 3, 'epsilon': 1, 'num_iter': 5}, 'num_sample1': 1549, 'num_sample2': 15413}

train epoch 0: loss_ae: 0.028833 loss_mask_ae: 0.087751 loss_corr: 1.005404
train epoch 10: loss_ae: 0.023509 loss_mask_ae: 0.064130 loss_corr: 0.804869
train epoch 20: loss_ae: 0.022450 loss_mask_ae: 0.056524 loss_corr: 0.772336
train epoch 30: loss_ae: 0.022311 loss_mask_ae: 0.057325 loss_corr: 0.766850
train epoch 40: loss_ae: 0.022258 loss_mask_ae: 0.057390 loss_corr:


train epoch 0: loss_ae: 0.028779 loss_mask_ae: 0.087262 loss_corr: 1.006526
train epoch 10: loss_ae: 0.023559 loss_mask_ae: 0.064189 loss_corr: 0.806980
train epoch 20: loss_ae: 0.022475 loss_mask_ae: 0.056684 loss_corr: 0.771737
train epoch 30: loss_ae: 0.022351 loss_mask_ae: 0.057663 loss_corr: 0.767149
train epoch 40: loss_ae: 0.022298 loss_mask_ae: 0.057733 loss_corr: 0.765096
train epoch 50: loss_ae: 0.022266 loss_mask_ae: 0.057373 loss_corr: 0.764223
train epoch 60: loss_ae: 0.022236 loss_mask_ae: 0.056924 loss_corr: 0.763687
train epoch 70: loss_ae: 0.022190 loss_mask_ae: 0.056734 loss_corr: 0.762828
train epoch 80: loss_ae: 0.022199 loss_mask_ae: 0.059591 loss_corr: 0.762820
train epoch 90: loss_ae: 0.022069 loss_mask_ae: 0.054668 loss_corr: 0.760804
train epoch 100: loss_ae: 0.021974 loss_mask_ae: 0.056222 loss_corr: 0.758667
train epoch 110: loss_ae: 0.021870 loss_mask_ae: 0.055415 loss_corr: 0.757757
train epoch 120: loss_ae: 0.021711 loss_mask_ae: 0.055165 loss_corr: 0.757

train epoch 70: loss_ae: 0.022080 loss_mask_ae: 0.056394 loss_corr: 0.760123
train epoch 80: loss_ae: 0.022097 loss_mask_ae: 0.052654 loss_corr: 0.759855
train epoch 90: loss_ae: 0.021944 loss_mask_ae: 0.057702 loss_corr: 0.757686
train epoch 100: loss_ae: 0.021775 loss_mask_ae: 0.056147 loss_corr: 0.755802
train epoch 110: loss_ae: 0.021673 loss_mask_ae: 0.052661 loss_corr: 0.755059
train epoch 120: loss_ae: 0.021735 loss_mask_ae: 0.056508 loss_corr: 0.755290
train epoch 130: loss_ae: 0.021629 loss_mask_ae: 0.051372 loss_corr: 0.754022
train epoch 140: loss_ae: 0.021434 loss_mask_ae: 0.054499 loss_corr: 0.751108
train epoch 150: loss_ae: 0.021328 loss_mask_ae: 0.054307 loss_corr: 0.747942
train epoch 160: loss_ae: 0.021182 loss_mask_ae: 0.053997 loss_corr: 0.743024
train epoch 170: loss_ae: 0.021040 loss_mask_ae: 0.053448 loss_corr: 0.738374
train epoch 180: loss_ae: 0.021012 loss_mask_ae: 0.051124 loss_corr: 0.735925
train epoch 190: loss_ae: 0.020905 loss_mask_ae: 0.052284 loss_corr

train epoch 150: loss_ae: 0.021465 loss_mask_ae: 0.055888 loss_corr: 0.746043
train epoch 160: loss_ae: 0.021378 loss_mask_ae: 0.054059 loss_corr: 0.742635
train epoch 170: loss_ae: 0.021284 loss_mask_ae: 0.050722 loss_corr: 0.739193
train epoch 180: loss_ae: 0.021134 loss_mask_ae: 0.052846 loss_corr: 0.735703
train epoch 190: loss_ae: 0.021043 loss_mask_ae: 0.052563 loss_corr: 0.732632
train epoch 200: loss_ae: 0.020971 loss_mask_ae: 0.051355 loss_corr: 0.729913
train epoch 210: loss_ae: 0.020894 loss_mask_ae: 0.052506 loss_corr: 0.726907
train epoch 220: loss_ae: 0.020912 loss_mask_ae: 0.052345 loss_corr: 0.727751
train epoch 230: loss_ae: 0.020842 loss_mask_ae: 0.051215 loss_corr: 0.724574
train epoch 240: loss_ae: 0.020770 loss_mask_ae: 0.051959 loss_corr: 0.721265
train epoch 250: loss_ae: 0.020722 loss_mask_ae: 0.051682 loss_corr: 0.718879
train epoch 260: loss_ae: 0.020685 loss_mask_ae: 0.051166 loss_corr: 0.716794
train epoch 270: loss_ae: 0.020659 loss_mask_ae: 0.051632 loss_c

train epoch 230: loss_ae: 0.020781 loss_mask_ae: 0.052796 loss_corr: 0.723584
train epoch 240: loss_ae: 0.020740 loss_mask_ae: 0.051058 loss_corr: 0.721367
train epoch 250: loss_ae: 0.020695 loss_mask_ae: 0.052655 loss_corr: 0.719085
train epoch 260: loss_ae: 0.020663 loss_mask_ae: 0.052038 loss_corr: 0.717378
train epoch 270: loss_ae: 0.020629 loss_mask_ae: 0.051944 loss_corr: 0.715525
train epoch 280: loss_ae: 0.020680 loss_mask_ae: 0.050941 loss_corr: 0.716490
train epoch 290: loss_ae: 0.020577 loss_mask_ae: 0.051533 loss_corr: 0.712199
train epoch 300: loss_ae: 0.020522 loss_mask_ae: 0.050471 loss_corr: 0.708885
train epoch 310: loss_ae: 0.020482 loss_mask_ae: 0.051104 loss_corr: 0.706959
train epoch 320: loss_ae: 0.020493 loss_mask_ae: 0.052562 loss_corr: 0.706507
train epoch 330: loss_ae: 0.020418 loss_mask_ae: 0.050694 loss_corr: 0.702893
train epoch 340: loss_ae: 0.020362 loss_mask_ae: 0.051199 loss_corr: 0.700556
train epoch 350: loss_ae: 0.020344 loss_mask_ae: 0.050640 loss_c