# Ablations and embedding quality evaluation

## UNAGI without GAN and GCN

In [None]:
import warnings
import torch
import numpy as np
import random
import os
import shutil
warnings.filterwarnings('ignore')
from UNAGI import UNAGI

# Ablation: train UNAGI with both the GCN (GCN=False) and the adversarial
# component (adversarial=False) disabled, repeated for seeds 0..14.
for i in range(15):
    # Seed the torch, NumPy and Python RNGs so each seed's run is reproducible.
    torch.manual_seed(i)
    np.random.seed(i)
    random.seed(i)
    unagi = UNAGI()

    print('seed %d'%i)
    data_dir = '../data/plain_ziln_CPO_seed_%d'%i
    # Give each seed its own copy of the raw data. Unlike `cp -r`, copytree with
    # dirs_exist_ok=True does not nest mes_raw inside an already-existing
    # directory on re-runs, and it raises instead of silently continuing when
    # the copy fails.
    shutil.copytree('../data/mes_raw', data_dir, dirs_exist_ok=True)
    unagi.setup_data(data_dir,total_stage=4,stage_key='stage')
    unagi.setup_training(task='plain_ziln_CPO_seed_%d'%i,dist='ziln',device='cuda:0',GPU=True,epoch_iter=0,epoch_initial=10,max_iter=1,BATCHSIZE=1024,GCN=False,adversarial=False)
    # Replace the placeholder with the path to your iDREM installation.
    unagi.run_UNAGI(idrem_dir = 'PATH_TO_IDREM')

## UNAGI without GCN

In [None]:
import warnings
import torch
import numpy as np
import random
import os
import shutil
warnings.filterwarnings('ignore')
from UNAGI import UNAGI

# Ablation: train UNAGI with the GCN disabled (GCN=False) but the adversarial
# component enabled (adversarial=True), repeated for seeds 0..14.
for i in range(15):
    # Seed the torch, NumPy and Python RNGs so each seed's run is reproducible.
    torch.manual_seed(i)
    np.random.seed(i)
    random.seed(i)
    unagi = UNAGI()

    print('seed %d'%i)
    data_dir = '../data/gan_ziln_CPO_seed_%d'%i
    # Give each seed its own copy of the raw data. Unlike `cp -r`, copytree with
    # dirs_exist_ok=True does not nest mes_raw inside an already-existing
    # directory on re-runs, and it raises instead of silently continuing when
    # the copy fails.
    shutil.copytree('../data/mes_raw', data_dir, dirs_exist_ok=True)
    unagi.setup_data(data_dir,total_stage=4,stage_key='stage')
    unagi.setup_training(task='gan_ziln_CPO_seed_%d'%i,dist='ziln',device='cuda:0',GPU=True,epoch_iter=0,epoch_initial=10,max_iter=1,BATCHSIZE=1024,GCN=False,adversarial=True)
    # NOTE(review): this is the only ablation cell that passes CPO=False; the
    # others use run_UNAGI's default, and this task name still contains "CPO".
    # Confirm this difference is intended for the comparison.
    # Replace the placeholder with the path to your iDREM installation.
    unagi.run_UNAGI(idrem_dir = 'PATH_TO_IDREM',CPO=False)

## UNAGI without GAN

In [None]:
import warnings
import torch
import numpy as np
import random
import os
import shutil
warnings.filterwarnings('ignore')
from UNAGI import UNAGI

# Ablation: train UNAGI with the GCN enabled (GCN=True) but the adversarial
# component disabled (adversarial=False), repeated for seeds 0..14.
for i in range(15):
    # Seed the torch, NumPy and Python RNGs so each seed's run is reproducible.
    torch.manual_seed(i)
    np.random.seed(i)
    random.seed(i)
    unagi = UNAGI()

    print('seed %d'%i)
    data_dir = '../data/gcn_ziln_CPO_seed_%d'%i
    # Give each seed its own copy of the raw data. Unlike `cp -r`, copytree with
    # dirs_exist_ok=True does not nest mes_raw inside an already-existing
    # directory on re-runs, and it raises instead of silently continuing when
    # the copy fails.
    shutil.copytree('../data/mes_raw', data_dir, dirs_exist_ok=True)
    unagi.setup_data(data_dir,total_stage=4,stage_key='stage')
    unagi.setup_training(task='gcn_ziln_CPO_seed_%d'%i,dist='ziln',device='cuda:0',GPU=True,epoch_iter=0,epoch_initial=10,max_iter=1,BATCHSIZE=1024,GCN=True,adversarial=False)
    # Replace the placeholder with the path to your iDREM installation
    # (previously a machine-specific absolute path was hard-coded here).
    unagi.run_UNAGI(idrem_dir = 'PATH_TO_IDREM')

## Evaluate embedding quality

In [None]:
import scanpy as sc
from UNAGI.utils import evaluate
import pickle
import concurrent.futures

# Function to process each seed
def process_seed(i, prefix='plain_ziln_CPO', target_dir='../data/'):
    """Compute embedding-quality metrics for one seed's run and save them to a text file.

    Reads ``<target_dir>/<prefix>_seed_<i>/0/stagedata/org_dataset.h5ad`` with
    ``sc.read``. It scores the data with ``evaluate.run_metrics(adata, 'ident', 'stage')``.
    The keys are presumably the cell-type and stage annotations in ``adata.obs``
    (confirm against ``evaluate.run_metrics``). Results go to
    ``<target_dir>/<prefix>_seed_<i>.txt`` as one tab-separated line per metric.

    Args:
        i: Seed index used in the run directory name.
        prefix: Run-name prefix of the ablation to evaluate. The default,
            ``'plain_ziln_CPO'``, is the GAN- and GCN-free ablation.
        target_dir: Directory containing the per-seed run directories.

    Returns:
        The 10 values returned by ``evaluate.run_metrics``, as a tuple in the
        order (ari, nmi, dbi, label_score, silhouette, isolated_asw,
        cell_type_asw, isolated_labels_f1, clisi_graph, overall_scib).

    Raises:
        ValueError: if ``evaluate.run_metrics`` does not return exactly 10 values.
    """
    import os

    # Labels written to the report, in the order run_metrics returns its values.
    metric_names = (
        'ari', 'nmi', 'dbi', 'label_score', 'sil', 'isolated_asw',
        'cell_type_asw', 'isolated_labels_f1', 'clisi_graphs', 'overall_scibs',
    )

    file_name = prefix + '_seed_' + str(i)
    adata = sc.read(os.path.join(target_dir, file_name, '0', 'stagedata', 'org_dataset.h5ad'))

    metrics = tuple(evaluate.run_metrics(adata, 'ident', 'stage'))
    # Guard against a silent label/value misalignment in the report below.
    if len(metrics) != len(metric_names):
        raise ValueError('evaluate.run_metrics returned %d values, expected %d'
                         % (len(metrics), len(metric_names)))

    with open(os.path.join(target_dir, file_name + '.txt'), 'w') as f:
        for name, value in zip(metric_names, metrics):
            f.write(name + '\t ' + str(value) + '\n')

    return metrics


def main():
    """Evaluate all 15 seeds of the ``plain_ziln_CPO`` ablation and write an aggregate report.

    Each seed is scored by ``process_seed`` in a worker process. The per-seed
    metric tuples are regrouped into one list per metric, ordered by seed. The
    lists are written to ``../data/plain_ziln_CPO_metrics.txt`` as one
    tab-separated ``name`` / ``[values...]`` line per metric.

    Raises:
        ValueError: if a seed's result does not contain exactly 10 metrics.
    """
    prefix = 'plain_ziln_CPO'
    target_dir = '../data/'
    # Must follow the order of the tuple returned by process_seed.
    metric_names = (
        'ari', 'nmi', 'dbi', 'label_score', 'sil', 'isolated_asw',
        'cell_type_asw', 'isolated_labels_f1', 'clisi_graphs', 'overall_scibs',
    )

    # Score the seeds in parallel, three worker processes at a time.
    # NOTE(review): workers must be able to pickle process_seed; with the
    # 'spawn' start method (default on macOS/Windows) functions defined in a
    # notebook cell usually cannot be — confirm this runs under 'fork'.
    with concurrent.futures.ProcessPoolExecutor(max_workers=3) as executor:
        results = list(executor.map(process_seed, range(15)))

    # Transpose [(ari, nmi, ...) per seed] into [[ari per seed], [nmi per seed], ...].
    per_metric = [[] for _ in metric_names]
    for seed, result in enumerate(results):
        if len(result) != len(metric_names):
            raise ValueError('seed %d returned %d metrics, expected %d'
                             % (seed, len(result), len(metric_names)))
        for column, value in zip(per_metric, result):
            column.append(value)

    with open(target_dir + prefix + '_metrics.txt', 'w') as f:
        for name, values in zip(metric_names, per_metric):
            f.write(name + '\t ' + str(values) + '\n')

if __name__ == '__main__':
    main()
