In [1]:
import numpy as np
import seaborn as sns
import torch

import umap
import matplotlib.pyplot as plt
import pandas as pd
from community import community_louvain
from torch_geometric.utils import k_hop_subgraph,to_networkx,from_networkx
import matplotlib

import utils
import plots
from model_AE import reduction_AE
from model_GAT import Encoder,SenGAE,train_GAT
from model_Sencell import Sencell

import logging
import os
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser(description='Main program for sencells')

parser.add_argument('--output_dir', type=str, default='./outputs', help='')
parser.add_argument('--exp_name', type=str, default='', help='')

args = parser.parse_args(args=[])

args.exp_name='s5'

if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)

logging.basicConfig(format='%(asctime)s.%(msecs)03d [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
                    datefmt='# %Y-%m-%d %H:%M:%S')

logging.getLogger().setLevel(logging.DEBUG)
logger = logging.getLogger()

# Part 1: load and process data
# cell_cluster_arr在画umap的时候用
adata,cluster_cell_ls,cell_cluster_arr,celltype_names=utils.load_data()
# plots.umapPlot(adata.obsm['X_umap'],clusters=cell_cluster_arr,labels=celltype_names)

new_data,markers_index,\
sen_gene_ls,nonsen_gene_ls,gene_names=utils.process_data(adata,cluster_cell_ls,cell_cluster_arr)

print(f'cell num: {new_data.shape[0]}, gene num: {new_data.shape[1]}')

gene_cell=new_data.X.toarray().T
cell_gene=gene_cell.T
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device: ',device)

# 2022-12-02 18:03:12.871 [DEBUG] [attrs.py:77] Creating converter from 3 to 5


cluster 数量： 21
celltype names: ['Macrophages', 'T cell lineage', 'Unknown', 'B cell lineage', 'Innate lymphoid cell NK', 'AT2', 'Monocytes', 'Multiciliated lineage', 'Dendritic cells', 'EC capillary', 'Mast cells', 'Fibroblasts', 'Secretory', 'EC venous', 'Lymphatic EC mature', 'AT1', 'Basal', 'EC arterial', 'Myofibroblasts', 'None', 'Submucosal Secretory']
-----------------------  ----
Macrophages              6941
T cell lineage            749
Unknown                   618
B cell lineage            374
Innate lymphoid cell NK   327
AT2                       294
Monocytes                 228
Multiciliated lineage     194
Dendritic cells           177
EC capillary              138
Mast cells                100
Fibroblasts                93
Secretory                  86
EC venous                  74
Lymphatic EC mature        68
AT1                        27
Basal                      26
EC arterial                20
Myofibroblasts             17
None                        6
Submucosal

In [None]:
%%time
def get_simi(i,my_dict):
    for j in range(i+1,cell_gene.shape[0]):
        u1=cell_gene[i]
        u2=cell_gene[j]
        # u1,u2必须是numpy.array，not tensor
        nz_u1 = u1.nonzero()[0]
        nz_u2 = u2.nonzero()[0]
        nz_inter = set(nz_u1) & set(nz_u2)
        nz_union = set(nz_u1) | set(nz_u2)
        if len(nz_inter) == 0:
            simi_score = 1 / (len(nz_union) + len(u1))
        elif len(nz_inter) == len(nz_union):
            simi_score = (len(nz_union) + len(u1) - 1) / (len(nz_union) + len(u1))
        else:
            simi_score = len(nz_inter) / len(nz_union)
        my_dict[(i,j)]=simi_score
    print(len(my_dict))

def eucliDistance(v1,v2):
    # 计算欧氏距离
    return F.pairwise_distance(v1.view(1,-1),v2.view(1,-1),p=2)

def loss_exp(v1,v2):
    return torch.exp(-0.1*eucliDistance(v1,v2))

sim1_ls=[]
cell_gene=gene_cell.T    
results_matrix=np.zeros((cell_gene.shape[0],cell_gene.shape[0])) 


from multiprocessing import Pool,Manager
import os, time, random


print('Parent process %s.' % os.getpid())
p = Pool()
manager = Manager()
my_dict = manager.dict()
for i in tqdm(range(cell_gene.shape[0])):
    p.apply_async(get_simi, args=(i,my_dict,))

print('Waiting for all subprocesses done...')
p.close()
p.join()
print('All subprocesses done.')

Parent process 55869.


100%|██████████| 10558/10558 [00:00<00:00, 41460.50it/s]


Waiting for all subprocesses done...
401573
403936
404288
404810
405413
406726
408220
416491
418637
420335
421053
421789
421931
422176
423045
423553
423626
423692
423763
423972
423673
424035
424363
424432
424534
424815
425189
425371
425532
425341
425731
426391
426816
427000
427067
427471
428642
428980
430127
434473
818048
821849
823665
824163
826087
831373
833510
834664
836438
836813
837457
837691
838155
838169
839141
839197
839208
840763
840956
840968
841000
841311
841568
842480
842503
842747
842959
844104
844254
845227
845728
847308
849024
850475
850786
851426
857699
861909
861998
864344
1239788
1240226
1245725
1246647
1251272
1251522
1251710
1253294
1253496
1253709
1254835
1254892
1255050
1255427
1255446
1255565
1255591
1256405
1256480
1257366
1257955
1257990
1259741
1259815
1260968
1262256
1262466
1262522
1262628
1262939
1263440
1264419
1264575
1267075
1267265
1271943
1278068
1280080
1281759
1284189
1657047
1659791
1661181
1664500
1664737
1666626
1667416
1667504
1667607
1668890
166

10438790
10439784
10440085
10441797
10444310
10445668
10446747
10450540
10450685
10453374
10454065
10454769
10459990
10463727
10464086
10465411
10468531
10474794
10477701
10781069
10782519
10786996
10792786
10793957
10799100
10800062
10802644
10805583
10806356
10807380
10809196
10810113
10811071
10811410
10813240
10813474
10814379
10814953
10816039
10816327
10816568
10816940
10817410
10818001
10823561
10823922
10824404
10829035
10829068
10836048
10836893
10838841
10840982
10842480
10842637
10847553
10847782
10855664
10870109
11157414
11160146
11162157
11170533
11173702
11177642
11179254
11181784
11182391
11184228
11186966
11187576
11188472
11189815
11190098
11190425
11190464
11191670
11193334
11193791
11194773
11195775
11196668
11198622
11200263
11200372
11203143
11203363
11204958
11205779
11215347
11216188
11216291
11216455
11220107
11222403
11223679
11224341
11233520
11250936
11533570
11537743
11538473
11550269
11550659
11556033
11557978
11560593
11562863
11563160
11564232
11564264
1

18755573
18757547
18758273
18759262
18759836
18760997
18761719
18761934
18765346
18768785
18769827
18770585
18772638
18772817
18773742
18774547
18776116
18777205
18781374
18783496
18783888
18788787
18789305
18791907
18793636
18819201
18827300
18827640
18873507
19074346
19083855
19084789
19088289
19089928
19090336
19092405
19093883
19095609
19097623
19099246
19099909
19100995
19101144
19101541
19102698
19103260
19104806
19105808
19107371
19111270
19112575
19113055
19114755
19116214
19116385
19118390
19119890
19120036
19120186
19124117
19128338
19128377
19130782
19135910
19137938
19162010
19171192
19173812
19215934
19414460
19424409
19427950
19429694
19430172
19433339
19434994
19436521
19438445
19439415
19439595
19439989
19440131
19442623
19444802
19446535
19446567
19448765
19452008
19452571
19453946
19454190
19454368
19456514
19457179
19458211
19458542
19458571
19460728
19460908
19463958
19469381
19469780
19471217
19473096
19481721
19501534
19513719
19514158
19555860
19753507
19761384
1

26216920
26222506
26226747
26226758
26227838
26232800
26235114
26235778
26236907
26237227
26237914
26238543
26241193
26242466
26243222
26247869
26248924
26249049
26252567
26252698
26254120
26254883
26255229
26262346
26264915
26265828
26266252
26266952
26269960
26270578
26271911
26278043
26289622
26290690
26291021
26294669
26309112
26340695
26360618
26519620
26528674
26530241
26533023
26534380
26534450
26539298
26539426
26542433
26542837
26542915
26543020
26543892
26547647
26548898
26550138
26552613
26554893
26555481
26555522
26557458
26557918
26559046
26561243
26568854
26568869
26569577
26570487
26573003
26577642
26578151
26578247
26590382
26595872
26599497
26599672
26604431
26614558
26649547
26664565
26830823
26834933
26835462
26836917
26839884
26840273
26842521
26843198
26844682
26845242
26845710
26847303
26850226
26852894
26853592
26858126
26858567
26858580
26858736
26859599
26864197
26864420
26864842
26870520
26870741
26873356
26875439
26875854
26876442
26878174
26879226
26879536
2

32648915
32650506
32671377
32672450
32672782
32673216
32680416
32730518
32735575
32849736
32851697
32859341
32861778
32865509
32870599
32870882
32873556
32873871
32876088
32876608
32879263
32880189
32881291
32882227
32883586
32883596
32885095
32885960
32886604
32887512
32888356
32888499
32888634
32892186
32899630
32900304
32901079
32903056
32905246
32912618
32921819
32924437
32940124
32942578
32942687
32948686
32955716
33006015
33010119
33119175
33119970
33128135
33129803
33135334
33136336
33140772
33142661
33143350
33144835
33145803
33146118
33146258
33150029
33150158
33151860
33151912
33154774
33156090
33156273
33156856
33156873
33157353
33157458
33160750
33170351
33170715
33171321
33173473
33175253
33188895
33189211
33193085
33207730
33210878
33218759
33219133
33228446
33274993
33280494
33386662
33391730
33394544
33400653
33401222
33403321
33410194
33410515
33411242
33412970
33414071
33414418
33414864
33415555
33416473
33416803
33417517
33421517
33422334
33422397
33422879
33424158
3

38445137
38445570
38446251
38450287
38452896
38457515
38460270
38464609
38470043
38473509
38474137
38476694
38490590
38504685
38507371
38519846
38530768
38586033
38591436
38648069
38652537
38654738
38655999
38656501
38657605
38657872
38659544
38665432
38667446
38667861
38668490
38669988
3867082038670824

38673728
38675615
38675659
38675874
38677754
38678204
38679101
38679173
38684056
38684355
38686224
38692461
38694443
38697203
38703307
38706875
38709798
38712256
38724756
38738871
38740945
38752308
38765280
38818503
38828963
38884682
38885422
38885983
38886701
38888921
38890205
38891698
38894176
38895657
38898910
38899224
38900358
38902032
38903501
38903721
38905125
38907868
38908617
38909335
38909487
38909582
38911950
38914310
38915013
38916766
38916970
38926285
38928559
38930886
38937719
38939133
38941219
38942717
38956956
38971813
38978907
38985706
38996158
39051677
39062507
39113709
39115528
39115638
39115956
39120053
39120561
39122962
39124609
39127039
39127880
39130861
39131471
3

43418422
43418766
43418908
43419314
43419500
43420430
43420902
43423309
43424083
43424178
43426549
43428196
43430560
43434042
43434728
43441097
43447493
43456153
43457653
43459672
43461687
43463805
43466938
43497939
43503247
43512130
43513985
43571910
43573341
43593741
43596694
43596932
43605940
43607833
43608030
43608229
43609811
43611393
43613264
43613689
43615181
43615384
43616639
43617749
43618624
43618886
43619575
43620236
43621282
43621588
43621712
43623656
43628133
43629238
43629558
43639114
43643990
43652033
43652772
43655577
43656502
43657667
43666642
43697992
43700564
43712052
43712271
43770898
43778112
43789736
43791045
43792172
43802306
43804163
43804492
43804924
43806013
43807371
43810736
43810762
43810987
43811206
43812069
43812283
43812414
43814161
43814362
43815227
43815647
43818597
43818848
43819735
43824110
43824980
43825063
43836318
43837081
43845287
43849145
43852877
43853210
43854102
43861418
43894432
43895483
43908311
43909826
43961431
43978026
43984884
43985221
4

47543878
47550965
47553463
47553781
47554563
47555077
47555749
47556034
47556056
47556475
47556566
47556931
47557183
47557205
47558092
47558383
47559143
47559545
47560590
47561636
47562988
47563161
47563726
47565608
47572209
47582473
47582593
47587766
47588598
47589720
47591442
47591558
47605770
47639633
47642088
47645930
47650736
47694688
47698293
47700941
47705137
47710441
47714158
47714389
47714804
47715335
47715362
47715752
47716060
47716539
47717430
47717592
47718155
47718250
47718295
47719467
47720315
47721047
47721928
47722187
47723826
47724147
47724506
47726009
47734128
47741078
47744822
47747819
47750109
47750255
47752398
47754833
47764278
47801266
47803471
47805301
47812790
47855831
47858639
47860998
47862979
47872036
47872587
47873451
47873636
47874105
47874265
47874417
47874948
47874966
47876108
47876456
47878832
47878902
47878953
47879309
47879602
47879807
47880508
47881300
47882450
47883298
47884734
47886059
47893360
47900876
47901904
47905840
47906700
47909815
47910175
4

50758540
50763476
50766391
50802004
50802144
50803683
50813466
50841164
50842139
50843156
50843301
50843692
50843825
50845937
50845979
50846902
50847157
50847219
50847332
50847840
50848152
50848576
50849315
50849493
50849540
50850237
50852196
50854029
50857694
50858032
50859681
50859722
50861844
50864985
50866199
50869724
50873506
50876050
50876530
50879853
50883832
50891186
50893775
50926961
50928880
50929010
50934882
50963931
50967212
50967738
50967998
50968814
50969154
50969590
50969875
50970156
50970498
50971029
50971483
50971529
50972066
50973329
50973605
50974245
50975173
50975481
50975644
50977431
50980752
50982535
50983651
50985674
50986004
50989602
50991395
50994128
50996016
50998844
50999872
51004443
51006622
51014642
51018447
51051387
51051948
51052402
51058189
51085947
51089263
51090085
51091649
51091956
51092072
51092473
51092519
51092643
51093075
51093426
51093800
51093888
51095063
51095405
51095693
51096501
51096556
51096633
51097930
51099409
51102806
51105140
51106201
5

53214221
53215553
53221496
53221784
53224268
53226108
53228233
53232428
53233895
53234066
53235661
53236583
53240524
53272611
53280184
53280697
53281525
53286431
53291752
53291941
53297007
53297326
53297983
53298189
53298483
53298654
53299161
53299550
53299755
53299892
53299971
5330000653300009

53300092
53300493
53300575
53300748
53300773
53300821
53302262
53302738
53302759
53310192
53310867
53311680
53315395
53316847
53320323
53323172
53323936
53325110
53326588
53329522
53362140
53366589
53368571
53368683
53373770
53377639
53379187
53382699
53384625
53384948
53385059
53385141
53386190
53386570
53386675
53387033
53387178
53387410
53387520
53387543
53387828
53388098
53388113
53388812
53388955
53388965
53389861
53390779
53392025
53397498
53398901
53399991
53402535
53402579
53408651
53408680
53411490
53412335
53413233
53417485
53448860
53453187
53454138
53454247
53458728
53464538
53467393
53467766
53468430
53469631
53470765
53471241
53471322
53471344
53471913
53472591
53472984
53473004
5

In [None]:
# 18:01

array([[0.        , 0.07983193, 0.05084746, ..., 0.07734807, 0.08571429,
        0.06965174],
       [0.        , 0.        , 0.25698324, ..., 0.3220339 , 0.14102564,
        0.19248826],
       [0.        , 0.        , 0.        , ..., 0.24427481, 0.08888889,
        0.11585366],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [None]:
for i in tqdm(range(cell_gene.shape[0])):
    for j in range(i+1,cell_gene.shape[0]):
        results_matrix[i][j]=my_dict[(i,j)]

In [None]:
torch.save(sim1_ls,"./outputs/sim1_ls")

In [4]:
%%time
import torch
from torch import nn, optim
from torch.nn import functional as F
import torch.utils.data as Data
import numpy as np


class AE(nn.Module):
    def __init__(self, dim, emb_dim=128):
        super(AE, self).__init__()
        self.dim = dim
        self.fc1 = nn.Linear(dim, 512)
        self.fc2 = nn.Linear(512, emb_dim)
        self.fc3 = nn.Linear(emb_dim, 512)
        self.fc4 = nn.Linear(512, dim)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return F.relu(self.fc2(h1))

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.relu(self.fc4(h3))

    def forward(self, x):
        z = self.encode(x.view(-1, self.dim))
        return self.decode(z), z
    
    
feature=torch.tensor(gene_cell.T)
feature=feature.to(device)
model = AE(dim=feature.shape[1]).to(device)
ba=feature.shape[0]
loader = Data.DataLoader(feature, ba)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

loss_func = nn.MSELoss()
EPOCH_AE = 2000
for epoch in range(EPOCH_AE):
    embeddings = []
    # loss_ls=[]
    for _, batch_x in enumerate(loader)	:
        decoded, encoded = model(batch_x)
        break
        loss1 = loss_func(batch_x, decoded)
        loss2 = 0
        for i in range(batch_x.shape[0]):
            for j in range(i+1,batch_x.shape[0]):
                v1=batch_x[i]
                v2=batch_x[j]
                h1=encoded[i]
                h2=encoded[j]
                sim1=get_simi(np.array(v1.cpu()),np.array(v2.cpu()))
                sim2=loss_exp(h1,h2)
                loss2+=sim1*(sim2-sim1).abs()
        print(loss1,loss2)
        loss=loss1+loss2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        embeddings.append(encoded)
    #     loss_ls.append(loss.item())
    # scheduler.step(np.mean(loss_ls))
#     print('Epoch :', epoch, '|', 'train_loss:%.12f' % loss.data)
    break

CPU times: user 102 ms, sys: 29.1 ms, total: 131 ms
Wall time: 129 ms


In [6]:
encoded.shape

torch.Size([10558, 128])

In [3]:
%%time
import torch
from torch import nn, optim
from torch.nn import functional as F
import torch.utils.data as Data
import numpy as np


class AE(nn.Module):
    def __init__(self, dim, emb_dim=128):
        super(AE, self).__init__()
        self.dim = dim
        self.fc1 = nn.Linear(dim, 512)
        self.fc2 = nn.Linear(512, emb_dim)
        self.fc3 = nn.Linear(emb_dim, 512)
        self.fc4 = nn.Linear(512, dim)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return F.relu(self.fc2(h1))

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.relu(self.fc4(h3))

    def forward(self, x):
        z = self.encode(x.view(-1, self.dim))
        return self.decode(z), z
    
    
feature=torch.tensor(gene_cell.T)
feature=feature.to(device)
model = AE(dim=feature.shape[1]).to(device)
ba=5000
loader = Data.DataLoader(feature, ba)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

loss_func = nn.MSELoss()
EPOCH_AE = 2000
for epoch in range(EPOCH_AE):
    embeddings = []
    # loss_ls=[]
    for _, batch_x in enumerate(loader)	:
        decoded, encoded = model(batch_x)
        loss1 = loss_func(batch_x, decoded)
        loss2 = 0
        for i in range(batch_x.shape[0]):
            for j in range(i+1,batch_x.shape[0]):
                v1=batch_x[i]
                v2=batch_x[j]
                h1=encoded[i]
                h2=encoded[j]
                sim1=get_simi(np.array(v1.cpu()),np.array(v2.cpu()))
                sim2=loss_exp(h1,h2)
                loss2+=sim1*(sim2-sim1).abs()
        print(loss1,loss2)
        loss=loss1+loss2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        embeddings.append(encoded)
    #     loss_ls.append(loss.item())
    # scheduler.step(np.mean(loss_ls))
    print('Epoch :', epoch, '|', 'train_loss:%.12f' % loss.data)
    break

KeyboardInterrupt: 

25000000