In [7]:
import sys
# ^^^ pyforest auto-imports - don't write above this line
import numpy as np
import pandas as pd
import scanpy as sc

from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score, normalized_mutual_info_score
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.cluster import SpectralClustering

from sklearn.decomposition import PCA, SparsePCA, KernelPCA
from sklearn.manifold import TSNE

from rpy2.robjects import r, pandas2ri
from rpy2.robjects.vectors import StrVector

pandas2ri.activate()

import magic
import scprep

%matplotlib inline

# from sklearnex import patch_sklearn
# patch_sklearn()

import warnings

from sklearn.cluster import KMeans
from tqdm import tqdm

sys.path.insert(0, '../../imputation2/notebooks/repos/GNNImpute/')
from GNNImpute.api import GNNImpute

In [2]:
def get_cluster_metrics(pred, labels):
    ari_res = []
    ami_res = []
    nmi_res = []

#     try:
#         r.assign("data", pred.T)
#         seurat = r('''
#             countsData = data.frame(data)
#             pbmc <- CreateSeuratObject(counts = countsData, project = "thal_single_cell", min.cells = 1, min.features = 1)
#             pbmc <- FindVariableFeatures(pbmc, selection.method = "vst", verbose=FALSE)
#             all.genes <- rownames(pbmc)
#             pbmc <- ScaleData(pbmc, features = all.genes, verbose=FALSE)
#             pbmc <- RunPCA(pbmc, verbose=FALSE)
#             pbmc <- FindNeighbors(pbmc, verbose=FALSE)
#             pbmc <- FindClusters(pbmc, verbose=FALSE)
#             Idents(pbmc)
#         ''')
#         ari_res.append(adjusted_rand_score(labels, seurat))
#         ami_res.append(adjusted_mutual_info_score(labels, seurat))
#         nmi_res.append(normalized_mutual_info_score(labels, seurat))
#     except Exception as e:
#         pass

    pred_ = KMeans(n_clusters=len(np.unique(labels)), random_state=42).fit_predict(pred)

    ari_res.append(adjusted_rand_score(labels, pred_))
    ami_res.append(adjusted_mutual_info_score(labels, pred_))
    nmi_res.append(normalized_mutual_info_score(labels, pred_))

    warnings.filterwarnings("error")

    affinities = ['cosine', 'linear', 'poly']

    for i in affinities:
        try:
            pred_ = SpectralClustering(
                n_clusters=len(np.unique(labels)), 
                random_state=42, 
                affinity=i
            ).fit_predict(pred)
            ari_res.append(adjusted_rand_score(labels, pred_))
            ami_res.append(adjusted_mutual_info_score(labels, pred_))
            nmi_res.append(normalized_mutual_info_score(labels, pred_))
        except:
            ari_res.append(0)
            ami_res.append(0)
            nmi_res.append(0)

    warnings.resetwarnings()
    
    return max(ari_res), max(ami_res), max(nmi_res)

In [3]:
def get_data(i):
    df = pd.read_csv('../data/{}/data.csv.gz'.format(i), index_col=0)
    tmp = np.sign(df)
    cols = (np.sum(tmp) > int((df.shape[0])*0.05))
    rows = (np.sum(tmp, axis=1) > int((df.shape[1])*0.05))
    df = np.log(df.loc[rows, cols] + 1)
    df_norm = df.copy()
    df_norm = scprep.normalize.library_size_normalize(df_norm)    
    df_norm = scprep.transform.sqrt(df_norm)
    X_norm = pd.DataFrame(df_norm, columns=df.columns)
    labels = df.index
    return X_norm, labels

In [4]:
dir_list = !ls ../data/ | grep -v zip
dir_list

['baron',
 'bmcite',
 'brosens',
 'carey',
 'cbmc',
 'chang',
 'Fujii',
 'hcabm40k',
 'hrvatin',
 'jakel',
 'jiang',
 'loureiro',
 'manno',
 'mingyao',
 'pbmc3k',
 'Selewa',
 'Xu']

In [5]:
len(dir_list)

17

In [10]:
res = []
for i in tqdm(dir_list):
    X_norm, labels = get_data(i)

    adata = sc.AnnData(X_norm.values)
    adata = GNNImpute(
        adata=adata, layer='GATConv',
        no_cuda=False,
        d = '/export/scratch/inoue019/gnn/',
        epochs=3000, 
        lr=0.001, weight_decay=0.0005,
        hidden=50, patience=200,
        fastmode=True, heads=3,
        use_raw=False,
        verbose=True
    )
    
    pred = adata.X
    pred = pd.DataFrame(pred, columns=X_norm.columns, index=X_norm.index)

    res.append(get_cluster_metrics(pred, labels))

  0%|          | 0/17 [00:00<?, ?it/s]

Epoch: 0010 loss_train: 1.0006 loss_val: 0.9993
Epoch: 0020 loss_train: 0.9998 loss_val: 0.9985
Epoch: 0030 loss_train: 0.9981 loss_val: 0.9970
Epoch: 0040 loss_train: 0.9917 loss_val: 0.9909
Epoch: 0050 loss_train: 0.9834 loss_val: 0.9846
Epoch: 0060 loss_train: 0.9782 loss_val: 0.9811
Epoch: 0070 loss_train: 0.9730 loss_val: 0.9785
Epoch: 0080 loss_train: 0.9687 loss_val: 0.9730
Epoch: 0090 loss_train: 0.9667 loss_val: 0.9717
Epoch: 0100 loss_train: 0.9636 loss_val: 0.9708
Epoch: 0110 loss_train: 0.9637 loss_val: 0.9695
Epoch: 0120 loss_train: 0.9616 loss_val: 0.9689
Epoch: 0130 loss_train: 0.9609 loss_val: 0.9686
Epoch: 0140 loss_train: 0.9597 loss_val: 0.9676
Epoch: 0150 loss_train: 0.9594 loss_val: 0.9674
Epoch: 0160 loss_train: 0.9597 loss_val: 0.9678
Epoch: 0170 loss_train: 0.9598 loss_val: 0.9671
Epoch: 0180 loss_train: 0.9587 loss_val: 0.9680
Epoch: 0190 loss_train: 0.9588 loss_val: 0.9679
Epoch: 0200 loss_train: 0.9589 loss_val: 0.9681
Epoch: 0210 loss_train: 0.9588 loss_val:

  6%|▌         | 1/17 [00:32<08:38, 32.38s/it]

Epoch: 0010 loss_train: 0.9988 loss_val: 0.9972
Epoch: 0020 loss_train: 0.9937 loss_val: 0.9926
Epoch: 0030 loss_train: 0.9857 loss_val: 0.9849
Epoch: 0040 loss_train: 0.9754 loss_val: 0.9754
Epoch: 0050 loss_train: 0.9650 loss_val: 0.9652
Epoch: 0060 loss_train: 0.9554 loss_val: 0.9572
Epoch: 0070 loss_train: 0.9493 loss_val: 0.9509
Epoch: 0080 loss_train: 0.9451 loss_val: 0.9484
Epoch: 0090 loss_train: 0.9428 loss_val: 0.9460
Epoch: 0100 loss_train: 0.9418 loss_val: 0.9453
Epoch: 0110 loss_train: 0.9404 loss_val: 0.9439
Epoch: 0120 loss_train: 0.9399 loss_val: 0.9432
Epoch: 0130 loss_train: 0.9395 loss_val: 0.9435
Epoch: 0140 loss_train: 0.9391 loss_val: 0.9425
Epoch: 0150 loss_train: 0.9387 loss_val: 0.9423
Epoch: 0160 loss_train: 0.9383 loss_val: 0.9419
Epoch: 0170 loss_train: 0.9377 loss_val: 0.9420
Epoch: 0180 loss_train: 0.9378 loss_val: 0.9424
Epoch: 0190 loss_train: 0.9381 loss_val: 0.9425
Epoch: 0200 loss_train: 0.9376 loss_val: 0.9419
Epoch: 0210 loss_train: 0.9377 loss_val:

 12%|█▏        | 2/17 [02:02<16:31, 66.12s/it]

Epoch: 0010 loss_train: 0.9990 loss_val: 1.0027
Epoch: 0020 loss_train: 0.9981 loss_val: 1.0020
Epoch: 0030 loss_train: 0.9926 loss_val: 0.9970
Epoch: 0040 loss_train: 0.9808 loss_val: 0.9862
Epoch: 0050 loss_train: 0.9721 loss_val: 0.9784
Epoch: 0060 loss_train: 0.9659 loss_val: 0.9732
Epoch: 0070 loss_train: 0.9631 loss_val: 0.9711
Epoch: 0080 loss_train: 0.9614 loss_val: 0.9703
Epoch: 0090 loss_train: 0.9614 loss_val: 0.9704
Epoch: 0100 loss_train: 0.9605 loss_val: 0.9692
Epoch: 0110 loss_train: 0.9597 loss_val: 0.9692
Epoch: 0120 loss_train: 0.9597 loss_val: 0.9684
Epoch: 0130 loss_train: 0.9590 loss_val: 0.9687
Epoch: 0140 loss_train: 0.9584 loss_val: 0.9683
Epoch: 0150 loss_train: 0.9583 loss_val: 0.9679
Epoch: 0160 loss_train: 0.9582 loss_val: 0.9683
Epoch: 0170 loss_train: 0.9578 loss_val: 0.9678
Epoch: 0180 loss_train: 0.9576 loss_val: 0.9676
Epoch: 0190 loss_train: 0.9576 loss_val: 0.9672
Epoch: 0200 loss_train: 0.9573 loss_val: 0.9676
Epoch: 0210 loss_train: 0.9576 loss_val:

 18%|█▊        | 3/17 [03:03<14:58, 64.15s/it]

Epoch: 0010 loss_train: 1.0010 loss_val: 1.0035
Epoch: 0020 loss_train: 0.9996 loss_val: 1.0023
Epoch: 0030 loss_train: 0.9939 loss_val: 0.9972
Epoch: 0040 loss_train: 0.9804 loss_val: 0.9838
Epoch: 0050 loss_train: 0.9701 loss_val: 0.9741
Epoch: 0060 loss_train: 0.9629 loss_val: 0.9670
Epoch: 0070 loss_train: 0.9580 loss_val: 0.9624
Epoch: 0080 loss_train: 0.9545 loss_val: 0.9594
Epoch: 0090 loss_train: 0.9527 loss_val: 0.9585
Epoch: 0100 loss_train: 0.9516 loss_val: 0.9567
Epoch: 0110 loss_train: 0.9510 loss_val: 0.9566
Epoch: 0120 loss_train: 0.9504 loss_val: 0.9558
Epoch: 0130 loss_train: 0.9496 loss_val: 0.9555
Epoch: 0140 loss_train: 0.9488 loss_val: 0.9552
Epoch: 0150 loss_train: 0.9488 loss_val: 0.9546
Epoch: 0160 loss_train: 0.9488 loss_val: 0.9547
Epoch: 0170 loss_train: 0.9482 loss_val: 0.9549
Epoch: 0180 loss_train: 0.9483 loss_val: 0.9545
Epoch: 0190 loss_train: 0.9481 loss_val: 0.9541
Epoch: 0200 loss_train: 0.9482 loss_val: 0.9545
Epoch: 0210 loss_train: 0.9477 loss_val:

 24%|██▎       | 4/17 [05:07<18:57, 87.46s/it]

Epoch: 0010 loss_train: 0.9992 loss_val: 0.9997
Epoch: 0020 loss_train: 0.9988 loss_val: 0.9994
Epoch: 0030 loss_train: 0.9972 loss_val: 0.9977
Epoch: 0040 loss_train: 0.9895 loss_val: 0.9912
Epoch: 0050 loss_train: 0.9816 loss_val: 0.9836
Epoch: 0060 loss_train: 0.9736 loss_val: 0.9763
Epoch: 0070 loss_train: 0.9667 loss_val: 0.9705
Epoch: 0080 loss_train: 0.9626 loss_val: 0.9674
Epoch: 0090 loss_train: 0.9604 loss_val: 0.9652
Epoch: 0100 loss_train: 0.9575 loss_val: 0.9648
Epoch: 0110 loss_train: 0.9562 loss_val: 0.9626
Epoch: 0120 loss_train: 0.9551 loss_val: 0.9623
Epoch: 0130 loss_train: 0.9549 loss_val: 0.9623
Epoch: 0140 loss_train: 0.9541 loss_val: 0.9614
Epoch: 0150 loss_train: 0.9534 loss_val: 0.9614
Epoch: 0160 loss_train: 0.9535 loss_val: 0.9609
Epoch: 0170 loss_train: 0.9537 loss_val: 0.9610
Epoch: 0180 loss_train: 0.9519 loss_val: 0.9604
Epoch: 0190 loss_train: 0.9522 loss_val: 0.9601
Epoch: 0200 loss_train: 0.9519 loss_val: 0.9598
Epoch: 0210 loss_train: 0.9528 loss_val:

 29%|██▉       | 5/17 [05:58<14:53, 74.46s/it]

Epoch: 0010 loss_train: 0.9972 loss_val: 1.0078
Epoch: 0020 loss_train: 0.9947 loss_val: 1.0060
Epoch: 0030 loss_train: 0.9871 loss_val: 1.0005
Epoch: 0040 loss_train: 0.9777 loss_val: 0.9916
Epoch: 0050 loss_train: 0.9628 loss_val: 0.9773
Epoch: 0060 loss_train: 0.9494 loss_val: 0.9642
Epoch: 0070 loss_train: 0.9375 loss_val: 0.9530
Epoch: 0080 loss_train: 0.9264 loss_val: 0.9419
Epoch: 0090 loss_train: 0.9164 loss_val: 0.9350
Epoch: 0100 loss_train: 0.9087 loss_val: 0.9273
Epoch: 0110 loss_train: 0.8995 loss_val: 0.9217
Epoch: 0120 loss_train: 0.8931 loss_val: 0.9134
Epoch: 0130 loss_train: 0.8866 loss_val: 0.9067
Epoch: 0140 loss_train: 0.8822 loss_val: 0.9057
Epoch: 0150 loss_train: 0.8788 loss_val: 0.9017
Epoch: 0160 loss_train: 0.8754 loss_val: 0.8996
Epoch: 0170 loss_train: 0.8735 loss_val: 0.8971
Epoch: 0180 loss_train: 0.8700 loss_val: 0.8946
Epoch: 0190 loss_train: 0.8678 loss_val: 0.8948
Epoch: 0200 loss_train: 0.8663 loss_val: 0.8940
Epoch: 0210 loss_train: 0.8661 loss_val:

Total time elapsed: 7.5360s


 35%|███▌      | 6/17 [06:37<11:27, 62.47s/it]

Epoch: 0010 loss_train: 1.0021 loss_val: 1.0032
Epoch: 0020 loss_train: 1.0001 loss_val: 1.0015
Epoch: 0030 loss_train: 0.9946 loss_val: 0.9966
Epoch: 0040 loss_train: 0.9854 loss_val: 0.9883
Epoch: 0050 loss_train: 0.9783 loss_val: 0.9821
Epoch: 0060 loss_train: 0.9733 loss_val: 0.9779
Epoch: 0070 loss_train: 0.9694 loss_val: 0.9747
Epoch: 0080 loss_train: 0.9670 loss_val: 0.9727
Epoch: 0090 loss_train: 0.9660 loss_val: 0.9730
Epoch: 0100 loss_train: 0.9653 loss_val: 0.9724
Epoch: 0110 loss_train: 0.9645 loss_val: 0.9726
Epoch: 0120 loss_train: 0.9645 loss_val: 0.9713
Epoch: 0130 loss_train: 0.9634 loss_val: 0.9709
Epoch: 0140 loss_train: 0.9631 loss_val: 0.9710
Epoch: 0150 loss_train: 0.9629 loss_val: 0.9709
Epoch: 0160 loss_train: 0.9620 loss_val: 0.9704
Epoch: 0170 loss_train: 0.9616 loss_val: 0.9707
Epoch: 0180 loss_train: 0.9623 loss_val: 0.9705
Epoch: 0190 loss_train: 0.9613 loss_val: 0.9708
Epoch: 0200 loss_train: 0.9617 loss_val: 0.9702
Epoch: 0210 loss_train: 0.9617 loss_val:

 41%|████      | 7/17 [07:13<08:56, 53.66s/it]

Epoch: 0010 loss_train: 1.0011 loss_val: 0.9944
Epoch: 0020 loss_train: 0.9938 loss_val: 0.9880
Epoch: 0030 loss_train: 0.9831 loss_val: 0.9783
Epoch: 0040 loss_train: 0.9709 loss_val: 0.9672
Epoch: 0050 loss_train: 0.9596 loss_val: 0.9570
Epoch: 0060 loss_train: 0.9499 loss_val: 0.9484
Epoch: 0070 loss_train: 0.9421 loss_val: 0.9412
Epoch: 0080 loss_train: 0.9371 loss_val: 0.9377
Epoch: 0090 loss_train: 0.9332 loss_val: 0.9347
Epoch: 0100 loss_train: 0.9313 loss_val: 0.9332
Epoch: 0110 loss_train: 0.9311 loss_val: 0.9330
Epoch: 0120 loss_train: 0.9299 loss_val: 0.9318
Epoch: 0130 loss_train: 0.9293 loss_val: 0.9314
Epoch: 0140 loss_train: 0.9286 loss_val: 0.9306
Epoch: 0150 loss_train: 0.9279 loss_val: 0.9307
Epoch: 0160 loss_train: 0.9274 loss_val: 0.9306
Epoch: 0170 loss_train: 0.9267 loss_val: 0.9300
Epoch: 0180 loss_train: 0.9271 loss_val: 0.9304
Epoch: 0190 loss_train: 0.9268 loss_val: 0.9302
Epoch: 0200 loss_train: 0.9267 loss_val: 0.9301
Epoch: 0210 loss_train: 0.9269 loss_val:

 47%|████▋     | 8/17 [08:29<09:06, 60.76s/it]

Epoch: 0010 loss_train: 0.9950 loss_val: 0.9989
Epoch: 0020 loss_train: 0.9871 loss_val: 0.9916
Epoch: 0030 loss_train: 0.9720 loss_val: 0.9757
Epoch: 0040 loss_train: 0.9546 loss_val: 0.9574
Epoch: 0050 loss_train: 0.9432 loss_val: 0.9482
Epoch: 0060 loss_train: 0.9356 loss_val: 0.9409
Epoch: 0070 loss_train: 0.9295 loss_val: 0.9351
Epoch: 0080 loss_train: 0.9268 loss_val: 0.9317
Epoch: 0090 loss_train: 0.9230 loss_val: 0.9287
Epoch: 0100 loss_train: 0.9209 loss_val: 0.9274
Epoch: 0110 loss_train: 0.9200 loss_val: 0.9269
Epoch: 0120 loss_train: 0.9186 loss_val: 0.9263
Epoch: 0130 loss_train: 0.9182 loss_val: 0.9245
Epoch: 0140 loss_train: 0.9176 loss_val: 0.9247
Epoch: 0150 loss_train: 0.9171 loss_val: 0.9238
Epoch: 0160 loss_train: 0.9167 loss_val: 0.9235
Epoch: 0170 loss_train: 0.9168 loss_val: 0.9240
Epoch: 0180 loss_train: 0.9163 loss_val: 0.9243
Epoch: 0190 loss_train: 0.9166 loss_val: 0.9236
Epoch: 0200 loss_train: 0.9161 loss_val: 0.9237
Epoch: 0210 loss_train: 0.9162 loss_val:

 53%|█████▎    | 9/17 [10:11<09:48, 73.60s/it]

Epoch: 0010 loss_train: 0.9976 loss_val: 0.9998
Epoch: 0020 loss_train: 0.9958 loss_val: 0.9981
Epoch: 0030 loss_train: 0.9829 loss_val: 0.9859
Epoch: 0040 loss_train: 0.9684 loss_val: 0.9719
Epoch: 0050 loss_train: 0.9605 loss_val: 0.9641
Epoch: 0060 loss_train: 0.9540 loss_val: 0.9594
Epoch: 0070 loss_train: 0.9498 loss_val: 0.9561
Epoch: 0080 loss_train: 0.9477 loss_val: 0.9549
Epoch: 0090 loss_train: 0.9464 loss_val: 0.9536
Epoch: 0100 loss_train: 0.9453 loss_val: 0.9519
Epoch: 0110 loss_train: 0.9441 loss_val: 0.9520
Epoch: 0120 loss_train: 0.9432 loss_val: 0.9510
Epoch: 0130 loss_train: 0.9424 loss_val: 0.9515
Epoch: 0140 loss_train: 0.9418 loss_val: 0.9516
Epoch: 0150 loss_train: 0.9423 loss_val: 0.9504
Epoch: 0160 loss_train: 0.9416 loss_val: 0.9512
Epoch: 0170 loss_train: 0.9414 loss_val: 0.9511
Epoch: 0180 loss_train: 0.9413 loss_val: 0.9503
Epoch: 0190 loss_train: 0.9410 loss_val: 0.9500
Epoch: 0200 loss_train: 0.9413 loss_val: 0.9515
Epoch: 0210 loss_train: 0.9407 loss_val:

 59%|█████▉    | 10/17 [11:20<08:26, 72.32s/it]

Epoch: 0010 loss_train: 0.9811 loss_val: 1.0306
Epoch: 0020 loss_train: 0.9774 loss_val: 1.0270
Epoch: 0030 loss_train: 0.9688 loss_val: 1.0204
Epoch: 0040 loss_train: 0.9546 loss_val: 1.0085
Epoch: 0050 loss_train: 0.9437 loss_val: 0.9988
Epoch: 0060 loss_train: 0.9332 loss_val: 0.9880
Epoch: 0070 loss_train: 0.9263 loss_val: 0.9802
Epoch: 0080 loss_train: 0.9223 loss_val: 0.9805
Epoch: 0090 loss_train: 0.9201 loss_val: 0.9777
Epoch: 0100 loss_train: 0.9195 loss_val: 0.9727
Epoch: 0110 loss_train: 0.9187 loss_val: 0.9739
Epoch: 0120 loss_train: 0.9182 loss_val: 0.9743
Epoch: 0130 loss_train: 0.9138 loss_val: 0.9731
Epoch: 0140 loss_train: 0.9130 loss_val: 0.9685
Epoch: 0150 loss_train: 0.9123 loss_val: 0.9729
Epoch: 0160 loss_train: 0.9100 loss_val: 0.9704
Epoch: 0170 loss_train: 0.9115 loss_val: 0.9708
Epoch: 0180 loss_train: 0.9115 loss_val: 0.9760
Epoch: 0190 loss_train: 0.9085 loss_val: 0.9689
Epoch: 0200 loss_train: 0.9091 loss_val: 0.9738
Epoch: 0210 loss_train: 0.9104 loss_val:

 65%|██████▍   | 11/17 [11:44<05:44, 57.46s/it]

Epoch: 0010 loss_train: 1.0007 loss_val: 0.9961
Epoch: 0020 loss_train: 0.9970 loss_val: 0.9928
Epoch: 0030 loss_train: 0.9922 loss_val: 0.9883
Epoch: 0040 loss_train: 0.9844 loss_val: 0.9807
Epoch: 0050 loss_train: 0.9786 loss_val: 0.9751
Epoch: 0060 loss_train: 0.9745 loss_val: 0.9716
Epoch: 0070 loss_train: 0.9725 loss_val: 0.9700
Epoch: 0080 loss_train: 0.9717 loss_val: 0.9692
Epoch: 0090 loss_train: 0.9711 loss_val: 0.9688
Epoch: 0100 loss_train: 0.9705 loss_val: 0.9688
Epoch: 0110 loss_train: 0.9701 loss_val: 0.9683
Epoch: 0120 loss_train: 0.9697 loss_val: 0.9679
Epoch: 0130 loss_train: 0.9697 loss_val: 0.9675
Epoch: 0140 loss_train: 0.9692 loss_val: 0.9676
Epoch: 0150 loss_train: 0.9690 loss_val: 0.9670
Epoch: 0160 loss_train: 0.9688 loss_val: 0.9671
Epoch: 0170 loss_train: 0.9685 loss_val: 0.9671
Epoch: 0180 loss_train: 0.9685 loss_val: 0.9668
Epoch: 0190 loss_train: 0.9687 loss_val: 0.9670
Epoch: 0200 loss_train: 0.9682 loss_val: 0.9669
Epoch: 0210 loss_train: 0.9682 loss_val:

 71%|███████   | 12/17 [13:19<05:45, 69.05s/it]

Epoch: 0010 loss_train: 0.9979 loss_val: 1.0081
Epoch: 0020 loss_train: 0.9973 loss_val: 1.0074
Epoch: 0030 loss_train: 0.9957 loss_val: 1.0066
Epoch: 0040 loss_train: 0.9866 loss_val: 1.0004
Epoch: 0050 loss_train: 0.9756 loss_val: 0.9903
Epoch: 0060 loss_train: 0.9698 loss_val: 0.9855
Epoch: 0070 loss_train: 0.9663 loss_val: 0.9833
Epoch: 0080 loss_train: 0.9637 loss_val: 0.9814
Epoch: 0090 loss_train: 0.9616 loss_val: 0.9809
Epoch: 0100 loss_train: 0.9607 loss_val: 0.9806
Epoch: 0110 loss_train: 0.9598 loss_val: 0.9799
Epoch: 0120 loss_train: 0.9591 loss_val: 0.9790
Epoch: 0130 loss_train: 0.9587 loss_val: 0.9797
Epoch: 0140 loss_train: 0.9584 loss_val: 0.9785
Epoch: 0150 loss_train: 0.9578 loss_val: 0.9789
Epoch: 0160 loss_train: 0.9570 loss_val: 0.9779
Epoch: 0170 loss_train: 0.9572 loss_val: 0.9789
Epoch: 0180 loss_train: 0.9568 loss_val: 0.9780
Epoch: 0190 loss_train: 0.9564 loss_val: 0.9784
Epoch: 0200 loss_train: 0.9569 loss_val: 0.9780
Epoch: 0210 loss_train: 0.9564 loss_val:

 76%|███████▋  | 13/17 [13:59<04:00, 60.05s/it]

Epoch: 0010 loss_train: 1.0006 loss_val: 0.9996
Epoch: 0020 loss_train: 0.9987 loss_val: 0.9976
Epoch: 0030 loss_train: 0.9968 loss_val: 0.9959
Epoch: 0040 loss_train: 0.9947 loss_val: 0.9938
Epoch: 0050 loss_train: 0.9919 loss_val: 0.9913
Epoch: 0060 loss_train: 0.9895 loss_val: 0.9891
Epoch: 0070 loss_train: 0.9864 loss_val: 0.9865
Epoch: 0080 loss_train: 0.9844 loss_val: 0.9847
Epoch: 0090 loss_train: 0.9834 loss_val: 0.9843
Epoch: 0100 loss_train: 0.9828 loss_val: 0.9839
Epoch: 0110 loss_train: 0.9821 loss_val: 0.9836
Epoch: 0120 loss_train: 0.9819 loss_val: 0.9833
Epoch: 0130 loss_train: 0.9814 loss_val: 0.9833
Epoch: 0140 loss_train: 0.9814 loss_val: 0.9834
Epoch: 0150 loss_train: 0.9809 loss_val: 0.9831
Epoch: 0160 loss_train: 0.9808 loss_val: 0.9830
Epoch: 0170 loss_train: 0.9805 loss_val: 0.9827
Epoch: 0180 loss_train: 0.9802 loss_val: 0.9830
Epoch: 0190 loss_train: 0.9799 loss_val: 0.9830
Epoch: 0200 loss_train: 0.9802 loss_val: 0.9828
Epoch: 0210 loss_train: 0.9804 loss_val:

 82%|████████▏ | 14/17 [15:14<03:13, 64.55s/it]

Epoch: 0010 loss_train: 0.9990 loss_val: 1.0018
Epoch: 0020 loss_train: 0.9987 loss_val: 1.0015
Epoch: 0030 loss_train: 0.9975 loss_val: 1.0007
Epoch: 0040 loss_train: 0.9905 loss_val: 0.9944
Epoch: 0050 loss_train: 0.9850 loss_val: 0.9895
Epoch: 0060 loss_train: 0.9786 loss_val: 0.9846
Epoch: 0070 loss_train: 0.9740 loss_val: 0.9806
Epoch: 0080 loss_train: 0.9707 loss_val: 0.9785
Epoch: 0090 loss_train: 0.9681 loss_val: 0.9779
Epoch: 0100 loss_train: 0.9675 loss_val: 0.9772
Epoch: 0110 loss_train: 0.9664 loss_val: 0.9772
Epoch: 0120 loss_train: 0.9655 loss_val: 0.9765
Epoch: 0130 loss_train: 0.9651 loss_val: 0.9762
Epoch: 0140 loss_train: 0.9642 loss_val: 0.9770
Epoch: 0150 loss_train: 0.9645 loss_val: 0.9770
Epoch: 0160 loss_train: 0.9642 loss_val: 0.9765
Epoch: 0170 loss_train: 0.9634 loss_val: 0.9768
Epoch: 0180 loss_train: 0.9633 loss_val: 0.9763
Epoch: 0190 loss_train: 0.9630 loss_val: 0.9769
Epoch: 0200 loss_train: 0.9633 loss_val: 0.9764
Epoch: 0210 loss_train: 0.9624 loss_val:

 88%|████████▊ | 15/17 [15:31<01:40, 50.24s/it]

Epoch: 0010 loss_train: 0.9968 loss_val: 1.0028
Epoch: 0020 loss_train: 0.9965 loss_val: 1.0026
Epoch: 0030 loss_train: 0.9960 loss_val: 1.0029
Epoch: 0040 loss_train: 0.9940 loss_val: 1.0029
Epoch: 0050 loss_train: 0.9883 loss_val: 0.9999
Epoch: 0060 loss_train: 0.9798 loss_val: 0.9978
Epoch: 0070 loss_train: 0.9737 loss_val: 0.9921
Epoch: 0080 loss_train: 0.9692 loss_val: 0.9912
Epoch: 0090 loss_train: 0.9647 loss_val: 0.9896
Epoch: 0100 loss_train: 0.9614 loss_val: 0.9914
Epoch: 0110 loss_train: 0.9577 loss_val: 0.9900
Epoch: 0120 loss_train: 0.9553 loss_val: 0.9915
Epoch: 0130 loss_train: 0.9526 loss_val: 0.9916
Epoch: 0140 loss_train: 0.9498 loss_val: 0.9925
Epoch: 0150 loss_train: 0.9460 loss_val: 0.9921
Epoch: 0160 loss_train: 0.9479 loss_val: 0.9910
Epoch: 0170 loss_train: 0.9460 loss_val: 0.9930
Epoch: 0180 loss_train: 0.9458 loss_val: 0.9900
Epoch: 0190 loss_train: 0.9428 loss_val: 0.9936
Epoch: 0200 loss_train: 0.9437 loss_val: 0.9919
Epoch: 0210 loss_train: 0.9453 loss_val:

 94%|█████████▍| 16/17 [16:13<00:47, 47.75s/it]

Epoch: 0010 loss_train: 1.0009 loss_val: 0.9967
Epoch: 0020 loss_train: 0.9993 loss_val: 0.9953
Epoch: 0030 loss_train: 0.9901 loss_val: 0.9874
Epoch: 0040 loss_train: 0.9691 loss_val: 0.9659
Epoch: 0050 loss_train: 0.9549 loss_val: 0.9516
Epoch: 0060 loss_train: 0.9452 loss_val: 0.9435
Epoch: 0070 loss_train: 0.9394 loss_val: 0.9386
Epoch: 0080 loss_train: 0.9366 loss_val: 0.9354
Epoch: 0090 loss_train: 0.9347 loss_val: 0.9343
Epoch: 0100 loss_train: 0.9332 loss_val: 0.9334
Epoch: 0110 loss_train: 0.9318 loss_val: 0.9323
Epoch: 0120 loss_train: 0.9309 loss_val: 0.9314
Epoch: 0130 loss_train: 0.9303 loss_val: 0.9300
Epoch: 0140 loss_train: 0.9295 loss_val: 0.9305
Epoch: 0150 loss_train: 0.9292 loss_val: 0.9302
Epoch: 0160 loss_train: 0.9292 loss_val: 0.9288
Epoch: 0170 loss_train: 0.9283 loss_val: 0.9311
Epoch: 0180 loss_train: 0.9283 loss_val: 0.9290
Epoch: 0190 loss_train: 0.9286 loss_val: 0.9279
Epoch: 0200 loss_train: 0.9279 loss_val: 0.9293
Epoch: 0210 loss_train: 0.9282 loss_val:

100%|██████████| 17/17 [17:25<00:00, 61.50s/it]


In [11]:
pd.DataFrame(res)

Unnamed: 0,0,1,2
0,0.516488,0.664125,0.671594
1,0.704505,0.686349,0.686679
2,0.3746,0.543351,0.546318
3,0.702301,0.726596,0.727581
4,0.563174,0.625798,0.632598
5,0.095591,0.194107,0.219933
6,0.304139,0.427557,0.431296
7,0.036971,0.049754,0.052649
8,0.628163,0.685629,0.693886
9,0.4562,0.574223,0.576181
