# Build positive simulation dataset

In [None]:
import scanpy as sc
import os
import pickle
import numpy as np
from scipy.sparse import csr_matrix
def manipulate(adata,genes,lg2fc):
    '''
    Build a positive simulation: add a stage-dependent shift to fibroblast cells.

    The pristine expression matrix is kept in ``adata.layers['raw']`` (created
    from ``adata.X`` on first use). A dense copy of it is row-shuffled with
    ``np.random.shuffle`` (so the result depends on the global NumPy seed),
    shifted on the target genes, clipped at zero, stored in
    ``adata.layers['simu']`` and written back to ``adata.X`` as a CSR matrix.

    parameters
    -----------
    adata: AnnData-like
        Object exposing ``X``, ``layers``, ``obs`` (with the columns ``'stage'``
        and ``'name.simple'``) and ``var``. It is modified in place; ``obs`` is
        left untouched.
    genes: iterable of str
        Targets written as ``'<gene>:<direction>'``. ``'+'`` adds
        ``(4 - stage) * lg2fc``; any other direction adds ``(stage + 1) * lg2fc``.
    lg2fc: float
        Step size of the additive shift.

    return
    --------
    adata: AnnData-like
        The same object that was passed in.
    '''
    def _dense_copy(matrix):
        # Layers / X may be scipy sparse matrices or plain arrays; always
        # return an independent dense array.
        if hasattr(matrix, 'toarray'):
            return matrix.toarray()
        return np.array(matrix)

    if 'raw' not in adata.layers.keys():
        adata.layers['raw'] = _dense_copy(adata.X)
    simu = _dense_copy(adata.layers['raw'])

    # Shuffle rows in place: expression profiles are permuted across cells.
    np.random.shuffle(simu)

    # Resolve every target to its column once instead of once per stage.
    var_names = adata.var.index.tolist()
    targets = []
    for entry in genes:
        fields = entry.split(':')
        direction = fields[1]
        targets.append((var_names.index(fields[0]), direction))

    stage_labels = adata.obs['stage']
    is_fibroblast = adata.obs['name.simple'].str.startswith('Fibroblast')
    for stage in list(stage_labels.unique()):
        # Positional row numbers of the fibroblast cells of this stage.
        rows = np.flatnonzero(((stage_labels == stage) & is_fibroblast).to_numpy())
        for gene_index, direction in targets:
            if direction == '+':
                simu[rows,gene_index] += (4-int(stage))*(lg2fc)
            else:
                simu[rows,gene_index] += (int(stage)+1)*(lg2fc)

    simu[simu<0] = 0
    adata.layers['simu'] = simu
    adata.X = csr_matrix(simu)
    return adata

# Load the reference dataset plus the candidate / target-gene lists that
# drive the simulation.
adata = sc.read_h5ad('../to_published/mes_4/dataset.h5ad')
candidate = np.load('../data/lowest25_drug_simulation_candidates.npy',allow_pickle=True)
targets = np.load('../data/lowest25_drug_simualtion_targets.npy',allow_pickle=True)

changes = [0.2,0.3,0.4]
random_s = 0
for i in range(0,len(candidate)):

    for change in changes:
        # A distinct seed per (drug, change) pair keeps each dataset reproducible.
        random_s +=1
        np.random.seed(random_s)
        adata_copy_copy = adata.copy()
        out_dir = '../data/drug_simulation_positive/drug_%d_change_%.2f'%(i,change)
        # makedirs(exist_ok=True) creates missing parents and lets the cell be
        # re-run without failing on directories left by a previous run.
        os.makedirs(out_dir, exist_ok=True)
        adata_copy = manipulate(adata_copy_copy,targets[i],change)

        # Drop layers, uns and obsp before writing the simulated dataset.
        del adata_copy.layers
        del adata_copy.uns
        del adata_copy.obsp
        adata_copy.write(out_dir+'/dataset.h5ad',compression='gzip',compression_opts=9)
        np.save(out_dir+'/genes.npy',np.array(targets[i],dtype='object'))
            

# Build negative simulation dataset

In [None]:
import scanpy as sc
import os
import pickle
import numpy as np
from scipy.sparse import csr_matrix
def manipulate(adata,genes,lg2fc):
    '''
    Build a negative simulation: add zero-mean Gaussian noise to fibroblast cells.

    The pristine expression matrix is kept in ``adata.layers['raw']`` (created
    from ``adata.X`` on first use). A dense copy of it is row-shuffled with
    ``np.random.shuffle``, perturbed on the target genes with
    ``N(0, lg2fc*lg2fc/900)`` noise, clipped at zero, stored in
    ``adata.layers['simu']`` and written back to ``adata.X`` as a CSR matrix.
    All randomness comes from the global NumPy generator.

    parameters
    -----------
    adata: AnnData-like
        Object exposing ``X``, ``layers``, ``obs`` (with the columns ``'stage'``
        and ``'name.simple'``) and ``var``. It is modified in place; ``obs`` is
        left untouched.
    genes: iterable of str
        Targets written as ``'<gene>'`` or ``'<gene>:<direction>'``. The
        direction is ignored: both directions receive the same noise.
    lg2fc: float
        Controls the noise scale (``lg2fc*lg2fc/900``).

    return
    --------
    adata: AnnData-like
        The same object that was passed in.
    '''
    def _dense_copy(matrix):
        # Layers / X may be scipy sparse matrices or plain arrays; always
        # return an independent dense array.
        if hasattr(matrix, 'toarray'):
            return matrix.toarray()
        return np.array(matrix)

    if 'raw' not in adata.layers.keys():
        adata.layers['raw'] = _dense_copy(adata.X)
    simu = _dense_copy(adata.layers['raw'])

    # Shuffle rows in place: expression profiles are permuted across cells.
    np.random.shuffle(simu)

    # Resolve every target to its column once instead of once per stage.
    var_names = adata.var.index.tolist()
    gene_columns = [var_names.index(entry.split(':')[0]) for entry in genes]
    noise_scale = lg2fc*lg2fc/900

    stage_labels = adata.obs['stage']
    is_fibroblast = adata.obs['name.simple'].str.startswith('Fibroblast')
    for stage in list(stage_labels.unique()):
        # The drawn scalar is intentionally discarded; the call only advances
        # the global RNG once per stage so seeded runs keep their random stream.
        np.random.normal(0, 0.1*lg2fc)
        # Positional row numbers of the fibroblast cells of this stage.
        rows = np.flatnonzero(((stage_labels == stage) & is_fibroblast).to_numpy())
        for gene_index in gene_columns:
            simu[rows,gene_index] += np.random.normal(0, noise_scale, rows.shape)

    simu[simu<0] = 0
    adata.layers['simu'] = simu
    adata.X = csr_matrix(simu)
    return adata

# Load the reference dataset plus the candidate / target-gene lists that
# drive the simulation.
# NOTE(review): the two .npy paths are relative to the working directory,
# while the outputs go under '../data/' -- confirm this is intended.
adata = sc.read_h5ad('../to_published/mes_4/dataset.h5ad')
candidate = np.load('lowest25_drug_simulation_candidates.npy',allow_pickle=True)
targets = np.load('lowest25_drug_simualtion_targets.npy',allow_pickle=True)
changes = [0.2,0.3,0.4]
random_s = 0
for i in range(len(candidate)):

    for change in changes:
        # A distinct seed per (drug, change) pair keeps each dataset reproducible.
        random_s +=1
        np.random.seed(random_s)
        adata_copy_copy = adata.copy()

        out_dir = '../data/drug_simulation_negative/drug_%d_change_%.2f'%(i,change)
        # makedirs(exist_ok=True) creates missing parents and lets the cell be
        # re-run without failing on directories left by a previous run.
        os.makedirs(out_dir, exist_ok=True)

        adata_copy = manipulate(adata_copy_copy,targets[i],change)

        # Drop layers, uns and obsp before writing the simulated dataset.
        del adata_copy.layers
        del adata_copy.uns
        del adata_copy.obsp
        adata_copy.write(out_dir+'/dataset.h5ad',compression='gzip',compression_opts=9)
        np.save(out_dir+'/genes.npy',np.array(targets[i],dtype='object'))
        print('drug_simulation_negative')
        

# Calculate simulation results on the positive datasets

In [None]:
from scipy.stats.distributions import norm
import os
import sys
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)
import pickle
import torch
import gc
import scanpy as sc
import pandas as pd
import os
import numpy as np
import random
from gcn_utilis import setup_graph
from UNAGI.model.models import VAE
import warnings
warnings.filterwarnings("ignore")
def translate_direction(input, log2fc):
    '''
    Turn ``'<gene>:<direction>'`` targets into a ``'<gene>:<factor>'`` string.

    parameters
    -----------
    input: iterable of str
        Targets written as ``'<gene>:<direction>'``.
    log2fc: float
        Change factor; a value below 1 is replaced by its reciprocal.

    return
    --------
    out: str
        Comma-separated ``'<gene>:<factor>'`` items. Genes marked ``'+'`` get
        the factor, every other direction gets its reciprocal.
    '''
    magnitude = 1/log2fc if log2fc <1 else log2fc
    pieces = []
    for entry in input:
        fields = entry.split(':')
        factor = magnitude if fields[1] == '+' else 1/magnitude
        pieces.append(fields[0]+':'+str(factor))
    return ','.join(pieces)
def reverse_translate_direction(input, log2fc):
    '''
    Mirror image of ``translate_direction``: the directions are swapped.

    parameters
    -----------
    input: iterable of str
        Targets written as ``'<gene>:<direction>'``.
    log2fc: float
        Change factor; a value below 1 is replaced by its reciprocal.

    return
    --------
    out: str
        Comma-separated ``'<gene>:<factor>'`` items. Genes marked ``'-'`` get
        the factor, every other direction gets its reciprocal.
    '''
    magnitude = 1/log2fc if log2fc <1 else log2fc
    pieces = []
    for entry in input:
        fields = entry.split(':')
        factor = magnitude if fields[1] == '-' else 1/magnitude
        pieces.append(fields[0]+':'+str(factor))
    return ','.join(pieces)
def in_silico_perturbation(adata,direction):
    '''
    Scale the expression of selected genes and store the dense result in ``adata.X``.

    parameters
    -----------
    adata: AnnData-like
        Object whose ``X`` supports ``.copy().toarray()`` and whose
        ``var.index`` lists the gene names. Modified in place.
    direction: str
        Comma-separated ``'<gene>:<factor>'`` items; each gene column ``x``
        becomes ``x + x * (factor - 1)``.

    return
    --------
    adata: AnnData-like
        The same object, with ``X`` replaced by the perturbed dense array.
    '''
    dense = adata.X.copy().toarray()
    gene_names = list(adata.var.index)
    for spec in direction.split(','):
        fields = spec.split(':')
        delta = float(fields[1])-1
        column = gene_names.index(fields[0])
        dense[:,column] += dense[:,column]*delta
    adata.X = dense
    # Drop the local reference before collecting garbage.
    dense = None
    gc.collect()
    return adata
def getDescendants(tempcluster,stage,edges):
    '''
    Collect the second element of every edge whose first element is in ``tempcluster``.

    parameters
    -----------
    tempcluster: iterable
        Source identifiers to look up.
    stage: int
        Edges are read from ``edges[str(stage-1)]``.
    edges: dict
        Maps a stage key (str) to a sequence of ``(source, target)`` pairs.

    return
    --------
    out: list
        Matching targets, ordered by source and then by edge order.
    '''
    return [
        edge[1]
        for source in tempcluster
        for edge in edges[str(stage-1)]
        if source == edge[0]
    ]

def getTrack(idrem_dir):
    '''
    Parse track definitions from the file names found in ``idrem_dir``.

    The part of each file name before the first ``'.'`` is split on ``'-'``;
    every resulting item is split on ``'n'``.

    parameters
    -----------
    idrem_dir: str
        Directory whose entries are listed with ``os.listdir``.

    return
    --------
    tracks: list
        One entry per file name: a list of lists of strings.
    '''
    tracks = []
    for filename in os.listdir(idrem_dir):
        stem = filename.split('.')[0]
        tracks.append([node.split('n') for node in stem.split('-')])
    return tracks
    
def unagi_background_perturabtion(adata_in,model):
    '''
    Encode the unperturbed cells of stages '0'..'3' with the model.

    parameters
    -----------
    adata_in: AnnData
        Dataset with ``obs['stage']`` and ``obsp['gcn_connectivities']``.
    model: torch module
        Put in eval mode; the first of the five values returned by
        ``model.getZ`` is kept for every stage.

    return
    --------
    new_control, new_stage1, new_stage2, new_stage3: AnnData
        One AnnData per stage whose ``X`` is that first ``getZ`` output and
        whose ``obs`` is the stage's ``obs``.
    '''
    stage_views = [adata_in[adata_in.obs['stage'] == label] for label in ('0', '1', '2', '3')]

    model.eval()
    # Build every stage graph first, then run the encoder stage by stage.
    graphs = [
        setup_graph(view.obsp['gcn_connectivities'].asformat('coo')).to('cuda:2')
        for view in stage_views
    ]
    latents = []
    for view, graph in zip(stage_views, graphs):
        features = torch.tensor(np.array(view.X.toarray())).to('cuda:2')
        z,_, _,_,_ = model.getZ(features,graph,1,0,len(view),test=False)
        latents.append(z)
    new_control, new_stage1, new_stage2, new_stage3 = [
        sc.AnnData(X = z.cpu().detach().numpy(),obs = view.obs)
        for z, view in zip(latents, stage_views)
    ]

    return new_control, new_stage1, new_stage2, new_stage3
def unagi_stage_perturabtion(adata_in,model, direction1, direction2,stage):
    '''
    Perturb the cells of one stage in two ways and encode both results.

    parameters
    -----------
    adata_in: AnnData
        Dataset with ``obs['stage']`` and ``obsp['gcn_connectivities']``.
    model: torch module
        Put in eval mode; the first of the five values returned by
        ``model.getZ`` is kept.
    direction1, direction2: str
        Perturbation strings passed to ``in_silico_perturbation``.
    stage: int
        Stage to perturb; compared with ``obs['stage']`` as ``str(stage)``.

    return
    --------
    updated_ipf1, updated_ipf2: AnnData
        Encoded cells for ``direction1`` and ``direction2`` with the stage's ``obs``.
    '''
    stage_cells = adata_in[adata_in.obs['stage'] == str(stage)]
    model.eval()
    # Each perturbation works on its own copy of the stage subset.
    perturbed = [in_silico_perturbation(stage_cells.copy(),each) for each in (direction1, direction2)]

    graphs = [
        setup_graph(cells.obsp['gcn_connectivities'].asformat('coo')).to('cuda:2')
        for cells in perturbed
    ]
    latents = []
    for cells, graph in zip(perturbed, graphs):
        z,_, _,_,_ = model.getZ(torch.tensor(np.array(cells.X)).to('cuda:2'),graph,1,0,len(cells),test=False)
        latents.append(z)

    updated_ipf1, updated_ipf2 = [
        sc.AnnData(X = z.cpu().detach().numpy(),obs = cells.obs)
        for z, cells in zip(latents, perturbed)
    ]

    # Release the encoder outputs before collecting garbage.
    del latents, z
    gc.collect()
    return updated_ipf1, updated_ipf2
def unagi_perturabtion(adata, model,direction1, direction2,tracks):
    '''
    Score a pair of opposite perturbations over every track and stage.

    parameters
    -----------
    adata: AnnData
        Dataset handed to the background and stage encoders.
    model: torch module
        Model used by both encoders.
    direction1, direction2: str
        Perturbation strings for the two opposite directions.
    tracks: list
        Per track, one list of ``leiden`` cluster labels per stage.

    return
    --------
    total_score: float
        Per-track mean score over stages, summed with weights equal to each
        track's share of the total number of cells.
    '''
    stage_latents = list(unagi_background_perturabtion(adata,model))
    track_members = {}
    track_cells = {}
    track_scores = {}
    for current in range(len(adata.obs['stage'].unique())):

        pert1, pert2 = unagi_stage_perturabtion(adata,model, direction1, direction2,current)
        for track_id, track in enumerate(tracks):
            track_cells.setdefault(track_id, 0)
            members = track_members.setdefault(track_id, [])
            history = track_scores.setdefault(track_id, [])
            for stage, clusters in enumerate(track):
                reference = stage_latents[stage]
                # NOTE: the list keeps growing on every pass of the outer
                # loop; only its first four entries are read below.
                members.append(reference[reference.obs['leiden'].isin(clusters)])
                if stage == current:
                    subset_pert1 = pert1[pert1.obs['leiden'].isin(clusters)]
                    subset_pert2 = pert2[pert2.obs['leiden'].isin(clusters)]
            track_cells[track_id] = len(members[0])+len(members[1])+len(members[2])+len(members[3])
            history.append(
                calcualte_distance_changes(members[0], members[1], members[2], members[3],subset_pert1,subset_pert2,current)
            )
    per_track = np.mean(np.array(list(track_scores.values())),axis=1)

    cell_counts = np.array(list(track_cells.values()))
    total_cells = np.sum(cell_counts)

    total_score = np.sum(per_track*(cell_counts/total_cells))
    return total_score
def getDistance(rep, cluster):
    '''
    Euclidean distance between the column-wise means of two arrays.

    parameters
    -----------
    rep: array-like
        First array; averaged over axis 0.
    cluster: array-like
        Second array; averaged over axis 0.

    return
    --------
    out: float
        ``np.linalg.norm`` of the difference of the two mean vectors.
    '''
    centroid_gap = np.mean(rep,axis=0) - np.mean(cluster,axis=0)
    return np.linalg.norm(centroid_gap)
def calculateScore(delta,flag,weight=100):
    '''
    Calculate the perturbation score.

    parameters
    -----------
    delta: sequence of float
        Per-stage change of distance (perturbed distance minus original
        distance); the entry at position ``flag`` is skipped.
    flag: int
        The stage of the time-series single-cell data that was perturbed.
    weight: float
        The weight to control the perturbation score.

    return
    --------
    out: float
        Sum of the rescaled sigmoid terms of all other stages, divided by
        ``len(delta) - 1``.
    '''
    total = 0
    for position, change in enumerate(delta):
        if position == flag:
            continue
        # The sign flips the contribution for stages before the perturbed one.
        exponent = weight*change*np.sign(position-flag)
        total += (1-1/(1+np.exp(exponent))-0.5)/0.5

    return total/(len(delta)-1)

def calcualte_distance_changes(control, stage1,stage2,stage3, updated_direction1, updated_direction2,stage):
    '''
    Score how two opposite perturbations move one stage relative to all stages.

    parameters
    -----------
    control, stage1, stage2, stage3: AnnData-like
        Reference cells of the four stages; only ``.X`` is read.
    updated_direction1, updated_direction2: AnnData-like
        Perturbed cells of the stage under test; only ``.X`` is read.
    stage: int
        Index (0-3) of the perturbed stage.

    return
    --------
    final_score: float
        ``abs(score1 - score2) / 2`` where each score comes from
        ``calculateScore`` applied to the change in centroid distances.
    '''
    def _dense(matrix):
        # ``.X`` may be a scipy sparse matrix or already a dense array;
        # accept both instead of assuming one of them.
        if hasattr(matrix, 'toarray'):
            return np.array(matrix.toarray())
        return np.array(matrix)

    reps = [_dense(each.X) for each in (control, stage1, stage2, stage3)]
    raw_distance = np.array([getDistance(reps[stage], each) for each in reps])

    perturbed1 = _dense(updated_direction1.X)
    perturbed2 = _dense(updated_direction2.X)
    direction1_distance = np.array([getDistance(perturbed1, each) for each in reps])
    direction2_distance = np.array([getDistance(perturbed2, each) for each in reps])

    score1 = calculateScore(direction1_distance - raw_distance,stage)
    score2 = calculateScore(direction2_distance - raw_distance,stage)
    final_score = np.abs(score1-score2)/2

    return final_score
def unagi_random(adata,model,times=1000,tracks=None):
    '''
    Build a background score distribution from random two-gene perturbations.

    parameters
    -----------
    adata: AnnData
        Dataset whose ``var.index`` supplies the candidate genes.
    model: torch module
        Model handed to ``unagi_perturabtion``.
    times: int
        Number of random perturbations to score.
    tracks: list
        Tracks (as returned by ``getTrack``) required by ``unagi_perturabtion``.

    return
    --------
    None. The scores are saved to ``'../data/unagi_background_score.npy'``.

    raises
    --------
    ValueError
        If ``tracks`` is not provided.
    '''
    if tracks is None:
        raise ValueError('unagi_random needs `tracks` (see getTrack) to score a perturbation')
    background_score = []
    for _ in range(times):
        random_genes = random.sample(list(adata.var.index),2)
        # Give every sampled gene a random direction.
        random_perturbed_genes = [gene+':'+random.choice(['+','-']) for gene in random_genes]
        extend = random.uniform(2,5)
        direction1 = translate_direction(random_perturbed_genes, extend)
        direction2 = reverse_translate_direction(random_perturbed_genes, extend)
        # unagi_perturabtion already returns the aggregated score.
        score = unagi_perturabtion(adata,model, direction1, direction2,tracks)
        background_score.append(score)
    np.save('../data/unagi_background_score.npy',background_score)
def unagi_real():
    '''
    Score every positive simulation that has a fine-tuned model.

    For each group directory containing ``fine_tune.pth`` the dataset, target
    genes and model weights are loaded, the targets are scored for several
    change factors, p-values against the background scores are printed, and
    the accumulated scores are saved. Reads the module-level ``dimZ`` and
    ``dimG`` when building the VAE.
    '''
    real_scores = []
    drug_groups = os.listdir('/mnt/md0/yumin/UNAGI_revision/data/drug_simulation_positive/')
    all_adata = sc.read('/mnt/md0/yumin/to_published/mes_4/dataset.h5ad')

    background = np.load('/mnt/md0/yumin/UNAGI_revision/data/unagi_background_score_feb27.npy',allow_pickle=True)
    tracks = getTrack('/mnt/md0/yumin/to_published/mes_4/idrem')
    for group in sorted(drug_groups):

        # Groups without fine-tuned weights are skipped.
        if 'fine_tune.pth' not in os.listdir('/mnt/md0/yumin/UNAGI_revision/data/drug_simulation_positive/%s/'%(group)):
            continue
        print(group)
        adata = sc.read('/mnt/md0/yumin/UNAGI_revision/data/drug_simulation_positive/%s/dataset.h5ad'%(group))
        # The simulated file was written without obsp; reuse the reference graph.
        adata.obsp['gcn_connectivities'] = all_adata.obsp['gcn_connectivities']
        target = np.load('/mnt/md0/yumin/UNAGI_revision/data/drug_simulation_positive/%s/genes.npy'%(group),allow_pickle=True)
        adata.obs['condition'] = None
        model = VAE(len(adata.var.index), dimZ, dimG, 0.5).to('cuda:2')
        model.load_state_dict(torch.load('/mnt/md0/yumin/UNAGI_revision/data/drug_simulation_positive/%s/fine_tune.pth'%(group),map_location='cuda:2'))

        for extend in np.arange(1.1,2.1,0.2):
            print(extend)
            scores = unagi_perturabtion(
                adata,model,
                translate_direction(target, extend),
                reverse_translate_direction(target, extend),
                tracks,
            )
            real_scores.append(scores)
            print('real_scores:',scores)
        # p-values of all scores collected so far under a normal fit of the background.
        pval = 1-norm.cdf(np.array(real_scores), loc=np.mean(background), scale=np.std(background))
        print(pval)
        np.save('/mnt/md0/yumin/UNAGI_revision/data/unagi_positive_perturbation_score.npy',real_scores)


if __name__ == '__main__':
    # unagi_real reads dimZ and dimG as module-level globals when it builds the VAE.
    dimZ, dimG = 64, 0
    unagi_real()
 

# Calculate simulation results on the negative datasets

In [None]:
import pickle
from scipy.stats import norm
import os
import sys
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)
import torch
import gc
import scanpy as sc
import pandas as pd
import os
import numpy as np
import random
from gcn_utilis import setup_graph
from UNAGI.model.models import VAE
import warnings

warnings.filterwarnings("ignore")
def translate_direction(input, log2fc):
    '''
    Turn ``'<gene>:<direction>'`` targets into a ``'<gene>:<factor>'`` string.

    parameters
    -----------
    input: iterable of str
        Targets written as ``'<gene>:<direction>'``.
    log2fc: float
        Change factor; a value below 1 is replaced by its reciprocal.

    return
    --------
    out: str
        Comma-separated ``'<gene>:<factor>'`` items. Genes marked ``'+'`` get
        the factor, every other direction gets its reciprocal.
    '''
    magnitude = 1/log2fc if log2fc <1 else log2fc
    pieces = []
    for entry in input:
        fields = entry.split(':')
        factor = magnitude if fields[1] == '+' else 1/magnitude
        pieces.append(fields[0]+':'+str(factor))
    return ','.join(pieces)
def reverse_translate_direction(input, log2fc):
    '''
    Mirror image of ``translate_direction``: the directions are swapped.

    parameters
    -----------
    input: iterable of str
        Targets written as ``'<gene>:<direction>'``.
    log2fc: float
        Change factor; a value below 1 is replaced by its reciprocal.

    return
    --------
    out: str
        Comma-separated ``'<gene>:<factor>'`` items. Genes marked ``'-'`` get
        the factor, every other direction gets its reciprocal.
    '''
    magnitude = 1/log2fc if log2fc <1 else log2fc
    pieces = []
    for entry in input:
        fields = entry.split(':')
        factor = magnitude if fields[1] == '-' else 1/magnitude
        pieces.append(fields[0]+':'+str(factor))
    return ','.join(pieces)
def in_silico_perturbation(adata,direction):
    '''
    Scale the expression of selected genes and store the dense result in ``adata.X``.

    parameters
    -----------
    adata: AnnData-like
        Object whose ``X`` supports ``.copy().toarray()`` and whose
        ``var.index`` lists the gene names. Modified in place.
    direction: str
        Comma-separated ``'<gene>:<factor>'`` items; each gene column ``x``
        becomes ``x + x * (factor - 1)``.

    return
    --------
    adata: AnnData-like
        The same object, with ``X`` replaced by the perturbed dense array.
    '''
    dense = adata.X.copy().toarray()
    gene_names = list(adata.var.index)
    for spec in direction.split(','):
        fields = spec.split(':')
        delta = float(fields[1])-1
        column = gene_names.index(fields[0])
        dense[:,column] += dense[:,column]*delta
    adata.X = dense
    # Drop the local reference before collecting garbage.
    dense = None
    gc.collect()
    return adata
def getDescendants(tempcluster,stage,edges):
    '''
    Collect the second element of every edge whose first element is in ``tempcluster``.

    parameters
    -----------
    tempcluster: iterable
        Source identifiers to look up.
    stage: int
        Edges are read from ``edges[str(stage-1)]``.
    edges: dict
        Maps a stage key (str) to a sequence of ``(source, target)`` pairs.

    return
    --------
    out: list
        Matching targets, ordered by source and then by edge order.
    '''
    return [
        edge[1]
        for source in tempcluster
        for edge in edges[str(stage-1)]
        if source == edge[0]
    ]

def getTrack(idrem_dir):
    '''
    Parse track definitions from the file names found in ``idrem_dir``.

    The part of each file name before the first ``'.'`` is split on ``'-'``;
    every resulting item is split on ``'n'``.

    parameters
    -----------
    idrem_dir: str
        Directory whose entries are listed with ``os.listdir``.

    return
    --------
    tracks: list
        One entry per file name: a list of lists of strings.
    '''
    tracks = []
    for filename in os.listdir(idrem_dir):
        stem = filename.split('.')[0]
        tracks.append([node.split('n') for node in stem.split('-')])
    return tracks
    
def unagi_background_perturabtion(adata_in,model):
    '''
    Encode the unperturbed cells of stages '0'..'3' with the model.

    parameters
    -----------
    adata_in: AnnData
        Dataset with ``obs['stage']`` and ``obsp['gcn_connectivities']``.
    model: torch module
        Put in eval mode; the first of the five values returned by
        ``model.getZ`` is kept for every stage.

    return
    --------
    new_control, new_stage1, new_stage2, new_stage3: AnnData
        One AnnData per stage whose ``X`` is that first ``getZ`` output and
        whose ``obs`` is the stage's ``obs``.
    '''
    stage_views = [adata_in[adata_in.obs['stage'] == label] for label in ('0', '1', '2', '3')]

    model.eval()
    # Build every stage graph first, then run the encoder stage by stage.
    graphs = [
        setup_graph(view.obsp['gcn_connectivities'].asformat('coo')).to('cuda:2')
        for view in stage_views
    ]
    latents = []
    for view, graph in zip(stage_views, graphs):
        features = torch.tensor(np.array(view.X.toarray())).to('cuda:2')
        z,_, _,_,_ = model.getZ(features,graph,1,0,len(view),test=False)
        latents.append(z)
    new_control, new_stage1, new_stage2, new_stage3 = [
        sc.AnnData(X = z.cpu().detach().numpy(),obs = view.obs)
        for z, view in zip(latents, stage_views)
    ]

    return new_control, new_stage1, new_stage2, new_stage3
def unagi_stage_perturabtion(adata_in,model, direction1, direction2,stage):
    '''
    Perturb the cells of one stage in two ways and encode both results.

    parameters
    -----------
    adata_in: AnnData
        Dataset with ``obs['stage']`` and ``obsp['gcn_connectivities']``.
    model: torch module
        Put in eval mode; the first of the five values returned by
        ``model.getZ`` is kept.
    direction1, direction2: str
        Perturbation strings passed to ``in_silico_perturbation``.
    stage: int
        Stage to perturb; compared with ``obs['stage']`` as ``str(stage)``.

    return
    --------
    updated_ipf1, updated_ipf2: AnnData
        Encoded cells for ``direction1`` and ``direction2`` with the stage's ``obs``.
    '''
    stage_cells = adata_in[adata_in.obs['stage'] == str(stage)]
    model.eval()
    # Each perturbation works on its own copy of the stage subset.
    perturbed = [in_silico_perturbation(stage_cells.copy(),each) for each in (direction1, direction2)]

    graphs = [
        setup_graph(cells.obsp['gcn_connectivities'].asformat('coo')).to('cuda:2')
        for cells in perturbed
    ]
    latents = []
    for cells, graph in zip(perturbed, graphs):
        z,_, _,_,_ = model.getZ(torch.tensor(np.array(cells.X)).to('cuda:2'),graph,1,0,len(cells),test=False)
        latents.append(z)

    updated_ipf1, updated_ipf2 = [
        sc.AnnData(X = z.cpu().detach().numpy(),obs = cells.obs)
        for z, cells in zip(latents, perturbed)
    ]

    # Release the encoder outputs before collecting garbage.
    del latents, z
    gc.collect()
    return updated_ipf1, updated_ipf2
def unagi_perturabtion(adata, model,direction1, direction2,tracks):
    '''
    Score a pair of opposite perturbations over every track and stage.

    parameters
    -----------
    adata: AnnData
        Dataset handed to the background and stage encoders.
    model: torch module
        Model used by both encoders.
    direction1, direction2: str
        Perturbation strings for the two opposite directions.
    tracks: list
        Per track, one list of ``leiden`` cluster labels per stage.

    return
    --------
    total_score: float
        Per-track mean score over stages, summed with weights equal to each
        track's share of the total number of cells.
    '''
    stage_latents = list(unagi_background_perturabtion(adata,model))
    track_members = {}
    track_cells = {}
    track_scores = {}
    for current in range(len(adata.obs['stage'].unique())):

        pert1, pert2 = unagi_stage_perturabtion(adata,model, direction1, direction2,current)
        for track_id, track in enumerate(tracks):
            track_cells.setdefault(track_id, 0)
            members = track_members.setdefault(track_id, [])
            history = track_scores.setdefault(track_id, [])
            for stage, clusters in enumerate(track):
                reference = stage_latents[stage]
                # NOTE: the list keeps growing on every pass of the outer
                # loop; only its first four entries are read below.
                members.append(reference[reference.obs['leiden'].isin(clusters)])
                if stage == current:
                    subset_pert1 = pert1[pert1.obs['leiden'].isin(clusters)]
                    subset_pert2 = pert2[pert2.obs['leiden'].isin(clusters)]
            track_cells[track_id] = len(members[0])+len(members[1])+len(members[2])+len(members[3])
            history.append(
                calcualte_distance_changes(members[0], members[1], members[2], members[3],subset_pert1,subset_pert2,current)
            )
    per_track = np.mean(np.array(list(track_scores.values())),axis=1)

    cell_counts = np.array(list(track_cells.values()))
    total_cells = np.sum(cell_counts)

    total_score = np.sum(per_track*(cell_counts/total_cells))
    return total_score
def getDistance(rep, cluster):
    '''
    Euclidean distance between the column-wise means of two arrays.

    parameters
    -----------
    rep: array-like
        First array; averaged over axis 0.
    cluster: array-like
        Second array; averaged over axis 0.

    return
    --------
    out: float
        ``np.linalg.norm`` of the difference of the two mean vectors.
    '''
    centroid_gap = np.mean(rep,axis=0) - np.mean(cluster,axis=0)
    return np.linalg.norm(centroid_gap)
def calculateScore(delta,flag,weight=100):
    '''
    Calculate the perturbation score.

    parameters
    -----------
    delta: sequence of float
        Per-stage change of distance (perturbed distance minus original
        distance); the entry at position ``flag`` is skipped.
    flag: int
        The stage of the time-series single-cell data that was perturbed.
    weight: float
        The weight to control the perturbation score.

    return
    --------
    out: float
        Sum of the rescaled sigmoid terms of all other stages, divided by
        ``len(delta) - 1``.
    '''
    total = 0
    for position, change in enumerate(delta):
        if position == flag:
            continue
        # The sign flips the contribution for stages before the perturbed one.
        exponent = weight*change*np.sign(position-flag)
        total += (1-1/(1+np.exp(exponent))-0.5)/0.5

    return total/(len(delta)-1)

def calcualte_distance_changes(control, stage1,stage2,stage3, updated_direction1, updated_direction2,stage):
    '''
    Score how two opposite perturbations move one stage relative to all stages.

    parameters
    -----------
    control, stage1, stage2, stage3: AnnData-like
        Reference cells of the four stages; only ``.X`` is read.
    updated_direction1, updated_direction2: AnnData-like
        Perturbed cells of the stage under test; only ``.X`` is read.
    stage: int
        Index (0-3) of the perturbed stage.

    return
    --------
    final_score: float
        ``abs(score1 - score2) / 2`` where each score comes from
        ``calculateScore`` applied to the change in centroid distances.
    '''
    def _dense(matrix):
        # ``.X`` may be a scipy sparse matrix or already a dense array;
        # accept both instead of assuming one of them.
        if hasattr(matrix, 'toarray'):
            return np.array(matrix.toarray())
        return np.array(matrix)

    reps = [_dense(each.X) for each in (control, stage1, stage2, stage3)]
    raw_distance = np.array([getDistance(reps[stage], each) for each in reps])

    perturbed1 = _dense(updated_direction1.X)
    perturbed2 = _dense(updated_direction2.X)
    direction1_distance = np.array([getDistance(perturbed1, each) for each in reps])
    direction2_distance = np.array([getDistance(perturbed2, each) for each in reps])

    score1 = calculateScore(direction1_distance - raw_distance,stage)
    score2 = calculateScore(direction2_distance - raw_distance,stage)
    final_score = np.abs(score1-score2)/2

    return final_score
def unagi_random(adata,model,times=1000,tracks=None):
    '''
    Build a background score distribution from random two-gene perturbations.

    parameters
    -----------
    adata: AnnData
        Dataset whose ``var.index`` supplies the candidate genes.
    model: torch module
        Model handed to ``unagi_perturabtion``.
    times: int
        Number of random perturbations to score.
    tracks: list
        Tracks (as returned by ``getTrack``) required by ``unagi_perturabtion``.

    return
    --------
    None. The scores are saved to ``'../data/unagi_background_score.npy'``.

    raises
    --------
    ValueError
        If ``tracks`` is not provided.
    '''
    if tracks is None:
        raise ValueError('unagi_random needs `tracks` (see getTrack) to score a perturbation')
    background_score = []
    for _ in range(times):
        random_genes = random.sample(list(adata.var.index),2)
        # Give every sampled gene a random direction.
        random_perturbed_genes = [gene+':'+random.choice(['+','-']) for gene in random_genes]
        extend = random.uniform(2,5)
        direction1 = translate_direction(random_perturbed_genes, extend)
        direction2 = reverse_translate_direction(random_perturbed_genes, extend)
        # unagi_perturabtion already returns the aggregated score.
        score = unagi_perturabtion(adata,model, direction1, direction2,tracks)
        background_score.append(score)
    np.save('../data/unagi_background_score.npy',background_score)
def unagi_real():
    '''
    Score every negative simulation that has a tuned model.

    For each group directory containing ``neg_tune.pth`` the dataset, target
    genes and model weights are loaded, the targets are scored for several
    change factors, p-values against the background scores are printed, and
    the accumulated scores are saved. Reads the module-level ``dimZ`` and
    ``dimG`` when building the VAE.
    '''
    real_scores = []
    drug_groups = os.listdir('/mnt/md0/yumin/UNAGI_revision/data/drug_simulation_negative/')
    all_adata = sc.read('/mnt/md0/yumin/to_published/mes_4/dataset.h5ad')
    tracks = getTrack('/mnt/md0/yumin/to_published/mes_4/idrem')
    background = np.load('/mnt/md0/yumin/UNAGI_revision/data/unagi_background_score_feb27_prob.npy',allow_pickle=True)
    drug_groups = sorted(drug_groups)
    for each in drug_groups:
        # Check for the checkpoint first so that no dataset or model is
        # loaded for groups that will be skipped anyway.
        if 'neg_tune.pth' not in os.listdir('/mnt/md0/yumin/UNAGI_revision/data/drug_simulation_negative/%s/'%(each)):
            continue
        adata = sc.read('/mnt/md0/yumin/UNAGI_revision/data/drug_simulation_negative/%s/dataset.h5ad'%(each))
        # The simulated file was written without obsp; reuse the reference graph.
        adata.obsp['gcn_connectivities'] = all_adata.obsp['gcn_connectivities']
        model = VAE(len(adata.var.index), dimZ, dimG, 0.5)
        # The perturbation helpers place every input tensor and graph on
        # 'cuda:2', so the model has to live on that device as well.
        model = model.to('cuda:2')
        print(each)
        model.load_state_dict(torch.load('/mnt/md0/yumin/UNAGI_revision/data/drug_simulation_negative/%s/neg_tune.pth'%(each),map_location='cuda:2'))

        old_target = np.load('/mnt/md0/yumin/UNAGI_revision/data/drug_simulation_negative/%s/genes.npy'%(each),allow_pickle=True)
        # NOTE(review): if the stored entries already look like
        # '<gene>:<direction>', the suffix appended here becomes a third
        # field and translate_direction only reads the second one, so the
        # random direction would have no effect -- confirm the intent.
        target = [gene+':'+random.choice(['+','-']) for gene in old_target]

        adata.obs['condition'] = None
        for extend in np.arange(1.1,2.1,0.2):

            direction1 = translate_direction(target, extend)
            direction2 = reverse_translate_direction(target, extend)
            scores = unagi_perturabtion(adata,model, direction1, direction2,tracks)
            real_scores.append(scores)

        # p-values of all scores collected so far under a normal fit of the background.
        pval = 1-norm.cdf(np.array(real_scores), loc=np.mean(background), scale=np.std(background))
        print(pval)
        np.save('/mnt/md0/yumin/UNAGI_revision/data/unagi_negative_perturbation_score.npy',real_scores)

if __name__ == '__main__':
    # unagi_real reads dimZ and dimG as module-level globals when it builds the VAE.
    dimZ, dimG = 64, 0
    unagi_real()