# Function

In [1]:
import numpy as np
import scipy
from scipy import spatial
import matplotlib.pyplot as plt
import anndata
import pandas as pd
from tqdm import tqdm,trange
import warnings
# Suppress every Python warning for the rest of the session (this also hides
# NumPy runtime warnings such as divide-by-zero raised during the simulation).
warnings.filterwarnings('ignore')


class SimulateSpatialDataGenerator:
    """Simulate spatial expression data on a square grid of cells.

    Cells sit on an ``n_obs_sqrt x n_obs_sqrt`` integer lattice and every cell
    receives a random cluster label. "Positive" genes are produced by
    repeatedly propagating Poisson counts through a transfer matrix built from
    pairwise grid distances and cluster labels. The remaining ("negative")
    genes go through the same propagation, but with the label (Dice) distance
    matrix used in place of the transfer matrix, i.e. without the grid
    distance term.

    All random draws use the global NumPy RNG (``np.random``).
    """

    def __init__(
        self,
        n_obs_sqrt,n_var,
        sig_ratio=0.1,randomness=0.1,sparsity=2,cell_cluster=10,
        seed=None
        ) -> None:
        """Store the simulation settings and draw the per-cell cluster labels.

        Args:
            n_obs_sqrt: Side length of the grid; the number of cells is
                ``n_obs_sqrt**2``.
            n_var: Number of genes (columns) to simulate.
            sig_ratio: Fraction of genes generated as positive; the first
                ``int(n_var*sig_ratio)`` genes are positive.
            randomness: Scale of the uniform noise added to the pairwise
                distances before they are turned into transfer weights.
            sparsity: Shape parameter of the gamma distribution (scale 5) from
                which the per-gene ``max_val`` is drawn.
            cell_cluster: Number of distinct cluster labels.
            seed: If not None, passed to ``np.random.seed`` before any random
                draw is made.
        """
        # Seed first: the label draw below consumes the global RNG, so seeding
        # afterwards would leave the labels non-reproducible.
        if seed is not None:
            np.random.seed(seed)

        self.n_obs = n_obs_sqrt  # grid side length, not the total cell count
        self.n_var = n_var
        self.randomness = randomness # controls the randomness of the data
        self.sparsity = sparsity # controls the sparsity of genes; less sparse genes have higher max_vals
        self.rsig = sig_ratio # controls the ratio of significant genes
        self.n_cluster = cell_cluster
        self.cell_label = np.random.choice(self.n_cluster,self.n_obs**2)

    def intmax_scale(self,x,max_val):
        """Rescale ``x`` so its 75% quantile maps to ``max_val``; truncate to int."""
        x = ((x/np.quantile(x,q=0.75))*max_val).astype(int)
        return x

    def _RWR_Sim(
        self,
        transfer_matrix,lam=5,alpha=0.2,
        n_iters=10,
    ):
        """Propagate Poisson counts through ``transfer_matrix``.

        Args:
            transfer_matrix: Square matrix with ``n_obs**2`` rows and columns.
            lam: Rate of the Poisson distribution for the initial counts.
            alpha: Weight of the current values retained at each step.
            n_iters: Number of propagation steps.

        Returns:
            Array of shape ``(1, n_obs**2)`` when ``n_iters >= 1``; the raw
            1-D Poisson draw when ``n_iters == 0``.
        """
        # simulate data
        data = np.random.poisson(lam, self.n_obs**2)
        for _ in range(n_iters):
            # The alpha term mixes in the *current* values, not the initial draw.
            row = data.reshape((1,-1))
            data = (1-alpha)*(row @ transfer_matrix) + alpha * row
        return data

    def generate_single_positive(self,max_val,transfer_matrix,**kwargs):
        """Generate one gene by propagation and integer scaling.

        ``max_val`` is raised to at least 3. ``kwargs`` are forwarded to
        ``_RWR_Sim`` (``lam``, ``alpha``, ``n_iters``). Returns a 1-D int array
        with one value per cell.
        """
        max_val = max(max_val,3)
        x = self._RWR_Sim(transfer_matrix,**kwargs)
        x = self.intmax_scale(x,max_val)
        return x.reshape(-1)

    def generate_single_negative(self,max_val,lam=5,**kwargs):
        """Generate one gene from independent Poisson counts (no propagation).

        Extra ``kwargs`` are accepted and ignored. ``generate`` does not call
        this method.
        """
        x = np.random.poisson(lam, self.n_obs**2)
        x = self.intmax_scale(x,max_val)
        return x

    def generate(self,verbose=True,**kwargs):
        """Run the simulation and return it as an ``anndata.AnnData``.

        Args:
            verbose: Show a progress bar while the positive genes are built.
            **kwargs: Forwarded to ``_RWR_Sim`` (``lam``, ``alpha``,
                ``n_iters``) and recorded in ``uns['simulation_params']``.

        Returns:
            AnnData with ``n_obs**2`` cells and ``n_var`` genes. It carries
            ``obs['cell_label']``, ``obsm['spatial']``, ``var['positive']``,
            ``uns['transfer_matrix']`` and ``uns['simulation_params']``. The
            object is also stored on ``self.adata``.
        """
        n_cells = self.n_obs**2
        n_pos = int(self.n_var*self.rsig)

        locations = np.array([[i,j] for i in range(self.n_obs) for j in range(self.n_obs)])
        dist_mat = scipy.spatial.distance.cdist(locations,locations)
        max_vals = np.random.gamma(self.sparsity,5,self.n_var).astype(int)

        # Weights decay exponentially with (noisy) grid distance.
        noise = self.randomness*np.random.random((n_cells,n_cells))
        transfer_matrix = np.exp(-(dist_mat + noise))
        # NOTE(review): the 1-D row-sum vector broadcasts along the last axis,
        # so column j is divided by the sum of row j. If per-row normalisation
        # was intended this needs keepdims=True; left unchanged so simulated
        # output stays comparable with earlier runs.
        transfer_matrix /= transfer_matrix.sum(axis=1)

        # from cell label to one-hot
        cell_label = np.zeros((n_cells,self.n_cluster))
        cell_label[np.arange(n_cells),self.cell_label] = 1
        # Dice distance between one-hot rows: 0 for equal labels, 1 otherwise.
        cell_type_dist_mat = scipy.spatial.distance.cdist(cell_label,cell_label,metric='dice')
        # make transfer matrix with cell type
        # same cell type will have higher probability to transfer
        transfer_matrix = transfer_matrix * (1-0.5*cell_type_dist_mat)

        pos_rows = [
            self.generate_single_positive(max_val,transfer_matrix,**kwargs)
            for max_val in tqdm(
                max_vals[:n_pos],
                desc = 'Generating positive data',
                disable = not verbose
                )
            ]
        # NOTE(review): negatives reuse generate_single_positive with the label
        # distance matrix as the transfer matrix; generate_single_negative is
        # not used here. Kept as in the original.
        neg_rows = [
            self.generate_single_positive(max_val,transfer_matrix=cell_type_dist_mat,**kwargs)
            for max_val in max_vals[n_pos:]
            ]
        # Stack once and force a 2-D shape so an empty positive or negative
        # group (e.g. int(n_var*sig_ratio) == 0) no longer breaks the stacking.
        data = np.array(pos_rows + neg_rows,dtype=int).reshape(-1,n_cells).T

        adata = anndata.AnnData(
            data,
            obs = pd.DataFrame(index=[f"cell_{i}" for i in range(n_cells)]),
            var = pd.DataFrame(index=[f"gene_{i}" for i in range(self.n_var)])
            )

        adata.obs['cell_label'] = self.cell_label
        adata.obsm['spatial'] = locations
        adata.var['positive'] = np.concatenate([np.ones(n_pos),np.zeros(self.n_var-n_pos)]).astype(bool)
        adata.uns['transfer_matrix'] = transfer_matrix
        # Defaults mirror the _RWR_Sim signature; caller overrides win.
        simulation_params = {
            'n_obs_sqrt':self.n_obs,
            'n_var':self.n_var,
            'randomness':self.randomness,
            'sparsity':self.sparsity,
            'rsig':self.rsig,
            'lam':5,
            'alpha':0.2,
            'n_iters':10,
        }
        simulation_params.update(kwargs)
        adata.uns['simulation_params'] = simulation_params

        self.adata = adata
        return adata

    def save(self,file_dir):
        """Write ``self.adata`` to ``file_dir`` as h5ad; call ``generate`` first."""
        self.adata.write_h5ad(file_dir)

    def plot(self,value,**kwargs):
        """Scatter-plot the cell locations coloured by ``value``.

        ``kwargs`` are forwarded to ``matplotlib.pyplot.scatter``. Reads
        ``self.adata``, so ``generate`` must have been called.
        """
        coords = self.adata.obsm['spatial']
        plt.scatter(coords[:,0],coords[:,1],c=value,**kwargs)


# Generate

In [None]:
import os
save_dir = './data/Simulations/'
# Make sure the output directory exists before the first file is written.
os.makedirs(save_dir, exist_ok=True)

# Sweep over grid sizes and noise levels, writing one h5ad file per setting.
# NOTE: generate() builds several dense (n_obs_sqrt**2 x n_obs_sqrt**2) arrays,
# i.e. 90000 x 90000 entries each at n_obs_sqrt=300, so the larger grid sizes
# need a very large amount of memory.
for n_obs_sqrt in [10,30,50,80,100,150,200,250,300]:
    for randomness in np.linspace(0,0.3,4):
        # Fresh seed in [0, 1000) per run; it is embedded in the file name so
        # the run can be identified later.
        seed = int(np.random.random()*1000)
        generator = SimulateSpatialDataGenerator(
            n_obs_sqrt=n_obs_sqrt,n_var=1000,randomness=randomness,seed=seed
        )
        generator.generate()
        generator.save(
            os.path.join(save_dir,f"Sim_{n_obs_sqrt}_{randomness:.1f}_{seed}.h5ad")
        )