In [1]:

from __future__ import print_function, division
import argparse
import random
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
from sklearn.metrics import adjusted_rand_score as ari_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn import Linear
from utils import load_data, load_graph
from GNN import GNNLayer
from evaluation import eva
from collections import Counter
from utils_func import adata_preprocess
from args_parser import set_parser
from utils import convert_str_to_int

In [2]:
class AE(nn.Module):

    def __init__(self, n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3,
                 n_input, n_z):
        super(AE, self).__init__()
        self.enc_1 = Linear(n_input, n_enc_1)        #2000->500
        self.enc_2 = Linear(n_enc_1, n_enc_2)        #500->500
        self.enc_3 = Linear(n_enc_2, n_enc_3)        #500->2000
        self.z_layer = Linear(n_enc_3, n_z)          #2000->10

        self.dec_1 = Linear(n_z, n_dec_1)            #10->2000
        self.dec_2 = Linear(n_dec_1, n_dec_2)        #2000->500
        self.dec_3 = Linear(n_dec_2, n_dec_3)        #500->500
        self.x_bar_layer = Linear(n_dec_3, n_input)  #500->2000

    def forward(self, x):
        enc_h1 = F.relu(self.enc_1(x))
        enc_h2 = F.relu(self.enc_2(enc_h1))
        enc_h3 = F.relu(self.enc_3(enc_h2))
        z = self.z_layer(enc_h3)

        dec_h1 = F.relu(self.dec_1(z))
        dec_h2 = F.relu(self.dec_2(dec_h1))
        dec_h3 = F.relu(self.dec_3(dec_h2))
        x_bar = self.x_bar_layer(dec_h3)

        return x_bar, enc_h1, enc_h2, enc_h3, z  #x_bar是重构的表达矩阵  z是隐藏空间纬度


class SDCN(nn.Module):

    def __init__(self, n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3, 
                n_input, n_z, n_clusters, v=1):  #n_input是表达矩阵，AE用来重构表达矩阵？
        super(SDCN, self).__init__()

        # autoencoder for intra information
        self.ae = AE(
            n_enc_1=n_enc_1,
            n_enc_2=n_enc_2,
            n_enc_3=n_enc_3,
            n_dec_1=n_dec_1,
            n_dec_2=n_dec_2,
            n_dec_3=n_dec_3,
            n_input=n_input,
            n_z=n_z)
        #self.ae.load_state_dict(torch.load(args.pretrain_path, map_location='cpu'))

        # GCN for inter information
        self.gnn_1 = GNNLayer(n_input, n_enc_1)   #2000->500
        self.gnn_2 = GNNLayer(n_enc_1, n_enc_2)   #500->500
        self.gnn_3 = GNNLayer(n_enc_2, n_enc_3)   #500->2000
        self.gnn_4 = GNNLayer(n_enc_3, n_z)       #2000->10
        self.gnn_5 = GNNLayer(n_z, n_clusters)    #10->8

        # cluster layer
        self.cluster_layer = Parameter(torch.Tensor(n_clusters, n_z))
        torch.nn.init.xavier_normal_(self.cluster_layer.data)

        # degree
        self.v = v

    def forward(self, x, adj):
        # DNN Module
        x_bar, tra1, tra2, tra3, z = self.ae(x)
        
        sigma = 0.5

        # GCN Module
        h = self.gnn_1(x, adj)
        h = self.gnn_2((1-sigma)*h + sigma*tra1, adj)
        h = self.gnn_3((1-sigma)*h + sigma*tra2, adj)
        h = self.gnn_4((1-sigma)*h + sigma*tra3, adj)
        h = self.gnn_5((1-sigma)*h + sigma*z, adj, active=False)
        predict = F.softmax(h, dim=1)

        # Dual Self-supervised Module
        q = 1.0 / (1.0 + torch.sum(torch.pow(z.unsqueeze(1) - self.cluster_layer, 2), 2) / self.v)
        q = q.pow((self.v + 1.0) / 2.0)
        q = (q.t() / torch.sum(q, 1)).t()

        return x_bar, q, predict, z


def target_distribution(q):
    weight = q**2 / q.sum(0)
    return (weight.t() / weight.sum(1)).t()


In [3]:
def train_sdcn(dataset):
    model = SDCN(300, 300, 500, 500, 300, 300,
                n_input=args.n_input,               #default==500   
                n_z=args.n_z,                       #default==30
                n_clusters=args.n_clusters,         #default==8
                v=1.0).to(device)
    print(model)

    optimizer = Adam(model.parameters(), lr=args.lr)

    params = set_parser()

    # KNN Graph
    adata = load_data(151673)
    adj_dict = load_graph(adata)   #adj是一个稀疏格式的矩阵，to_dense()查看原样
    adj = adj_dict["adj_norm"]
    adj = adj.cuda()

    # cluster parameter initiate
    adata_X = adata_preprocess(adata, min_cells=5, pca_n_comps=params.cell_feat_dim)
    adata_X = adata_X.copy()
    adata_X = torch.Tensor(adata_X).to(device)
    y,_ = convert_str_to_int(adata)                               ###把字符串类变成数子类，20211214
    y = np.array(y)
    with torch.no_grad():
        _, _, _, _, z = model.ae(adata_X)   #z是十维

    kmeans = KMeans(n_clusters=args.n_clusters, n_init=20)
    y_pred = kmeans.fit_predict(z.data.cpu().numpy())
    y_pred_last = y_pred
    model.cluster_layer.data = torch.tensor(kmeans.cluster_centers_).to(device)
    eva(y, y_pred, 'pae')


    for epoch in range(200):
        if epoch % 1 == 0:
        # update_interval
            _, tmp_q, pred, _ = model(adata_X, adj)
            tmp_q = tmp_q.data
            p = target_distribution(tmp_q)
        
            res1 = tmp_q.cpu().numpy().argmax(1)       #Q
            res2 = pred.data.cpu().numpy().argmax(1)   #Z
            res3 = p.data.cpu().numpy().argmax(1)      #P
            eva(y, res1, str(epoch) + 'Q')
            eva(y, res2, str(epoch) + 'Z')
            eva(y, res3, str(epoch) + 'P')

        x_bar, q, pred, _ = model(adata_X, adj)

        kl_loss = F.kl_div(q.log(), p, reduction='batchmean')
        ce_loss = F.kl_div(pred.log(), p, reduction='batchmean')
        re_loss = F.mse_loss(x_bar, adata_X)

        loss = 0.1 * kl_loss + 0.01 * ce_loss + re_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [4]:
if __name__ == "__main__":
    
    args = set_parser()
    args.cuda = torch.cuda.is_available()
    print("use cuda: {}".format(args.cuda))
    device = torch.device("cuda" if args.cuda else "cpu")

    
    dataset = load_data(151673)


    print(args)
    train_sdcn(dataset)



use cuda: True


Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.


adata: (3639, 33538)
Namespace(cell_feat_dim=500, cuda=True, dec_interval=20, dec_kl_w=100, dec_tol=0.0, epochs=300, eval_graph_n=20, eval_resolution=1, feat_hidden1=100, feat_hidden2=20, feat_w=10, gcn_decay=0.01, gcn_hidden1=32, gcn_hidden2=8, gcn_lr=0.01, gcn_w=0.1, k=10, knn_distanceType='euclidean', lr=0.001, n_clusters=8, n_input=500, n_z=10, p_drop=0.2, pretrain_path='pkl')
SDCN(
  (ae): AE(
    (enc_1): Linear(in_features=500, out_features=300, bias=True)
    (enc_2): Linear(in_features=300, out_features=300, bias=True)
    (enc_3): Linear(in_features=300, out_features=500, bias=True)
    (z_layer): Linear(in_features=500, out_features=10, bias=True)
    (dec_1): Linear(in_features=10, out_features=500, bias=True)
    (dec_2): Linear(in_features=500, out_features=300, bias=True)
    (dec_3): Linear(in_features=300, out_features=300, bias=True)
    (x_bar_layer): Linear(in_features=300, out_features=500, bias=True)
  )
  (gnn_1): GNNLayer()
  (gnn_2): GNNLayer()
  (gnn_3): GNNLa

Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.


adata: (3639, 33538)
===== Preprocessing Data 
pae :acc 0.3537 , nmi 0.1580 , ari 0.1317 , f1 0.1949
0Q :acc 0.3537 , nmi 0.1580 , ari 0.1317 , f1 0.1949
0Z :acc 0.3594 , nmi 0.1690 , ari 0.0901 , f1 0.1436
0P :acc 0.3441 , nmi 0.1540 , ari 0.1315 , f1 0.2001
1Q :acc 0.3100 , nmi 0.1349 , ari 0.0924 , f1 0.1522
1Z :acc 0.4026 , nmi 0.2382 , ari 0.1602 , f1 0.1488
1P :acc 0.2660 , nmi 0.1279 , ari 0.0635 , f1 0.1915
2Q :acc 0.3168 , nmi 0.1391 , ari 0.0688 , f1 0.1537
2Z :acc 0.2817 , nmi 0.0242 , ari 0.0020 , f1 0.0708
2P :acc 0.2899 , nmi 0.1331 , ari 0.0573 , f1 0.1713
3Q :acc 0.2998 , nmi 0.1411 , ari 0.0584 , f1 0.1405
3Z :acc 0.2808 , nmi 0.0412 , ari 0.0024 , f1 0.0825
3P :acc 0.2957 , nmi 0.1433 , ari 0.0529 , f1 0.1879
4Q :acc 0.2960 , nmi 0.1428 , ari 0.0526 , f1 0.1405
4Z :acc 0.3215 , nmi 0.1592 , ari 0.0903 , f1 0.1221
4P :acc 0.2877 , nmi 0.1399 , ari 0.0481 , f1 0.1854
5Q :acc 0.2863 , nmi 0.1420 , ari 0.0462 , f1 0.1414
5Z :acc 0.3721 , nmi 0.2053 , ari 0.1421 , f1 0.175