In [1]:
from ott.neural.flows.models import VelocityField
from ott.neural.flows.flows import ConstantNoiseFlow
from ott.neural.flows.samplers import sample_uniformly
from ott.neural.flows.otfm import OTFlowMatching
from ott.solvers.linear import sinkhorn
from ott.tools import sinkhorn_divergence
from ott.geometry import pointcloud
from ott.neural.data.dataloaders import OTDataLoader, ConditionalDataLoader

import optax
import scanpy as sc

import jax.numpy as jnp

import numpy as np
import pandas as pd

from ott.tools import sinkhorn_divergence
from ott.geometry import pointcloud

from fractions import Fraction
import anndata as ad

In [11]:
def dataGeneration(data: ad.AnnData, weights: pd.DataFrame): #adjustment of data according to weights wrt to cellcluster_moscot
    """Resample the cells of ``data`` cluster by cluster according to ``weights``.

    Every cluster found in ``data.obs['cellcluster_moscot']`` gets a raw weight:
    the first matching ``weights['weight']`` entry (whitespace-separated
    fractions such as ``"1/16"`` or ``"1 1/2"`` are summed), or ``1.0`` when the
    cluster has no row in ``weights``. The raw weights are divided by the
    mass-weighted mean weight, so the total number of cells stays approximately
    ``data.n_obs`` (up to truncation by ``int``). Each cluster is then sampled
    with replacement to ``int(cluster_size * normalised_weight)`` cells.

    Sampling uses NumPy's global random state (``np.random.choice``); call
    ``np.random.seed`` beforehand for reproducible output. Because cells are
    drawn with replacement, the result can contain duplicated ``obs`` names.

    Parameters
    ----------
    data
        AnnData holding a ``cellcluster_moscot`` column in ``obs``.
    weights
        Table with at least the columns ``celltype`` and ``weight``.

    Returns
    -------
    The concatenation (``ad.concat``) of the resampled per-cluster subsets.
    """
    labels = data.obs['cellcluster_moscot']
    # Series.unique() works for categorical and plain columns alike.
    clusters = list(labels.unique())
    weighted_celltypes = weights['celltype'].values

    raw_weight = []
    cluster_size = []
    for cluster in clusters:
        if cluster in weighted_celltypes:
            entry = weights.loc[weights['celltype'] == cluster, 'weight'].values[0]
            # str() lets already-numeric cells through the same Fraction parser.
            raw_weight.append(float(sum(Fraction(part) for part in str(entry).split())))
        else:
            raw_weight.append(1.)
        cluster_size.append(data[labels == cluster].n_obs)

    raw_weight = np.array(raw_weight)
    mass_rel = np.array(cluster_size) / data.n_obs

    # Mass-weighted mean of the raw weights; dividing by it keeps the total cell count.
    mass = np.sum(mass_rel * raw_weight)
    weight_adjusted = raw_weight / mass

    adata_collection = []
    for cluster, factor in zip(clusters, weight_adjusted):
        data_cluster = data[labels == cluster]
        old_n_obs = data_cluster.n_obs
        obs_indices = np.random.choice(
            old_n_obs,
            size=int(old_n_obs * factor),
            replace=True)
        # The indices are positions inside the cluster subset, so they must index
        # `data_cluster`; indexing `data` would pick unrelated leading rows.
        adata_collection.append(data_cluster[obs_indices])

    return ad.concat(adata_collection) #concatenate all cluster

def dataGenerator(data: ad.AnnData, day_source: float, day_target: float, weights: pd.DataFrame): #prepare source and target for dataGeneration
    """Build the resampled source and target populations for one pair of days.

    The rows of ``weights`` whose ``timepoint`` equals ``day_source`` are split
    by their ``marginal`` label into source and target weights; the cells of
    ``data`` at ``day_source`` / ``day_target`` are then passed, together with
    the matching weights, to ``dataGeneration``.

    NOTE(review): the target weights are selected by ``day_source`` as well,
    not by ``day_target`` -- confirm this matches how the weight table is keyed.

    Returns
    -------
    ``(data_source, data_target)`` as returned by ``dataGeneration``.
    """
    at_source_day = weights['timepoint'] == day_source

    def _weights_for(marginal_label):
        # Rows of the source-day slice carrying the requested marginal label.
        return weights[at_source_day & (weights['marginal'] == marginal_label)]

    cells_source = data[data.obs.day == day_source]
    cells_target = data[data.obs.day == day_target]

    # Source is resampled before target, so the random draws happen in that order.
    resampled_source = dataGeneration(cells_source, _weights_for('source_marginals'))
    resampled_target = dataGeneration(cells_target, _weights_for('target_marginals'))

    return resampled_source, resampled_target

In [3]:
# Marginal weight table (columns used later: timepoint, weight, marginal, celltype).
# NOTE(review): absolute, user-specific path -- the notebook only runs where it exists.
weights = pd.read_csv(
    '/home/icb/jonas.flor/gastrulation_atlas/moscot/weights_together.csv'
)
weights

Unnamed: 0.1,Unnamed: 0,timepoint,weight,marginal,celltype
0,0,8.50,1/16,target_marginals,Mesoderm
1,1,8.50,1/16,target_marginals,Neuroectoderm_and_glia
2,2,8.50,1/4,target_marginals,Endothelium
3,3,8.50,1/4,target_marginals,Epithelial_cells
4,4,8.50,1/4,target_marginals,Neural_crest_PNS_glia
...,...,...,...,...,...
1013,1013,18.75,100,target_marginals,Olidendrocytes
1014,1014,18.75,500,target_marginals,Primordial_germ_cells
1015,1015,18.75,100,target_marginals,T_cells
1016,1016,18.75,1000,target_marginals,Testis_and_adrenal


In [4]:
# Read the AnnData file and collect the distinct `day` values in ascending order.
# NOTE(review): absolute, user-specific path -- the notebook only runs where it exists.
adata = sc.read_h5ad(
    '/home/icb/jonas.flor/gastrulation_atlas/scvi/training/10k/2k_genes/1/128/integrated_adata.h5ad'
)
day_sort = adata.obs.day.unique()
day_sort.sort()  # sorts in place
adata

AnnData object with n_obs × n_vars = 10000 × 2000
    obs: 'cell_id', 'keep', 'day', 'embryo_id', 'experimental_batch', 'batch', 'cell_cluster', 'celltype', 'cellcluster_moscot', '_scvi_batch', '_scvi_labels', 'dpt_pseudotime'
    uns: 'diffmap_evals', 'iroot', 'log1p', 'neighbors'
    obsm: 'X_diffmap', 'X_emb'
    obsp: 'connectivities', 'distances'

In [6]:
# One OTDataLoader per pair of neighbouring entries in `day_sort`, keyed by the earlier day.
dataloaders = {}
for day_ind, day_from in enumerate(day_sort[:-1]):
    day_to = day_sort[day_ind + 1]
    source, target = dataGenerator(adata, day_from, day_to, weights)

    # Each source cell gets its `day` value as a single-column condition.
    day_condition = np.expand_dims(source.obs.day.values, axis=1)

    dataloaders[day_from] = OTDataLoader(
        1024,  # presumably the batch size -- TODO confirm against OTDataLoader
        source_lin=source.obsm['X_emb'],
        target_lin=target.obsm['X_emb'],
        source_conditions=day_condition,
    )

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
 

In [114]:
# NOTE(review): the commented-out lines below look like an earlier inline draft of the
# per-cluster resampling (only clusters listed in the weight table are resampled, the
# rest are kept unchanged). They are dead code; the draft also indexes
# `data_target_prelim` with indices drawn for `data_cluster`. Consider deleting them.
#data_target = dataGeneration(data_target_prelim, weights_target)
#need_adjustment = []
#no_adjustment = []
#for x in data_target_prelim.obs.cellcluster_moscot.values.unique():
#    if x in weights_target['celltype'].values:
#        need_adjustment.append(x)
#    else:
#        no_adjustment.append(x)
#
#adata_collection = [data_target_prelim[[x in no_adjustment for x in data_target_prelim.obs.cellcluster_moscot]]]
#
#for cluster in need_adjustment:
#    data_cluster = data_target_prelim[data_target_prelim.obs.cellcluster_moscot==cluster]
#    old_n_obs = data_cluster.n_obs
#    obs_indices = np.random.choice(old_n_obs, size=int(old_n_obs*float(sum(Fraction(s) for s in weights_target.loc[weights_target['celltype']==cluster, 'weight'].values[0].split()))), replace=True)
#    adata_collection.append(data_target_prelim[obs_indices])
#data_target = ad.concat(adata_collection)
# Show whatever `target` is currently bound to in the kernel (this cell's execution
# count is out of sequence, so the producing run is not evident from the notebook).
target

AnnData object with n_obs × n_vars = 304 × 2000
    obs: 'cell_id', 'keep', 'day', 'embryo_id', 'experimental_batch', 'batch', 'cell_cluster', 'celltype', 'cellcluster_moscot', '_scvi_batch', '_scvi_labels', 'dpt_pseudotime'
    obsm: 'X_diffmap', 'X_emb'

In [17]:
# Resample a single day pair (8.75 -> 9.0) so the result can be inspected in the next cells.
# This rebinds the `source` / `target` names also used by the dataloader loop.
source, target = dataGenerator(adata, 8.75, 9.0, weights)

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


In [18]:
# Original (not resampled) cells at day 8.75 -- presumably shown to compare with `source`.
adata[adata.obs.day==8.75]

View of AnnData object with n_obs × n_vars = 32 × 2000
    obs: 'cell_id', 'keep', 'day', 'embryo_id', 'experimental_batch', 'batch', 'cell_cluster', 'celltype', 'cellcluster_moscot', '_scvi_batch', '_scvi_labels', 'dpt_pseudotime'
    uns: 'diffmap_evals', 'iroot', 'log1p', 'neighbors'
    obsm: 'X_diffmap', 'X_emb'
    obsp: 'connectivities', 'distances'

In [19]:
# Resampled source population from the dataGenerator(adata, 8.75, 9.0, weights) call.
source

AnnData object with n_obs × n_vars = 32 × 2000
    obs: 'cell_id', 'keep', 'day', 'embryo_id', 'experimental_batch', 'batch', 'cell_cluster', 'celltype', 'cellcluster_moscot', '_scvi_batch', '_scvi_labels', 'dpt_pseudotime'
    obsm: 'X_diffmap', 'X_emb'

In [20]:
# Original (not resampled) cells at day 9.0 -- presumably shown to compare with `target`.
adata[adata.obs.day==9.0]

View of AnnData object with n_obs × n_vars = 91 × 2000
    obs: 'cell_id', 'keep', 'day', 'embryo_id', 'experimental_batch', 'batch', 'cell_cluster', 'celltype', 'cellcluster_moscot', '_scvi_batch', '_scvi_labels', 'dpt_pseudotime'
    uns: 'diffmap_evals', 'iroot', 'log1p', 'neighbors'
    obsm: 'X_diffmap', 'X_emb'
    obsp: 'connectivities', 'distances'

In [21]:
# Resampled target population from the dataGenerator(adata, 8.75, 9.0, weights) call.
target

AnnData object with n_obs × n_vars = 87 × 2000
    obs: 'cell_id', 'keep', 'day', 'embryo_id', 'experimental_batch', 'batch', 'cell_cluster', 'celltype', 'cellcluster_moscot', '_scvi_batch', '_scvi_labels', 'dpt_pseudotime'
    obsm: 'X_diffmap', 'X_emb'

In [75]:
# NOTE(review): `weights_target` is never assigned at notebook level in the visible
# cells (only as a local inside `dataGenerator` and in commented-out code), so this
# cell depends on leftover kernel state and will fail on a fresh Restart & Run All.
weights_target['celltype'].values

array(['Mesoderm', 'Neuroectoderm_and_glia', 'Endothelium',
       'Epithelial_cells', 'Neural_crest_PNS_glia', 'Primitive_erythroid',
       'CNS_Neurons', 'Hepatocytes', 'Primordial_germ_cells',
       'Neural_crest_PNS_neurons', 'Muscle_cells'], dtype=object)