# Sample intersection between genomes, function test code 

In this notebook, we'll test the sample intersection code to make sure it's functioning properly, since it's essential for training/testing with multiple datasets.

In [4]:
import numpy as np
import anndata as an
import pandas as pd 
import torch
from typing import *

In [20]:
def clean_sample(
    sample: Union[np.ndarray, torch.Tensor],
    refgenes: List[str],
    currgenes: List[str],
) -> np.ndarray:
    """Restrict ``sample`` to the genes it shares with a reference gene list.

    The entries of ``sample`` along its gene axis are assumed to line up
    one-to-one with ``currgenes``. Only the entries whose gene also appears in
    ``refgenes`` are kept, and they are returned in the order of the *sorted*
    common gene names, so samples from different datasets cleaned against the
    same reference end up with identically ordered columns.

    Parameters
    ----------
    sample:
        Array-like of values. For 2-D input, genes are taken to run along
        axis 1; otherwise along axis 0.
    refgenes:
        Gene names of the reference dataset.
    currgenes:
        Gene names labelling the gene axis of ``sample``, in the same order as
        the values.

    Returns
    -------
    np.ndarray
        The selected values as a NumPy array (the input is converted with
        ``np.asarray``), ordered by sorted common gene name.
    """
    # With return_indices=True, np.intersect1d returns
    # (sorted common values, indices into the first array, indices into the second).
    # The middle element indexes the ORIGINAL (unsorted) currgenes, so it can be
    # applied to the sample directly. The sample values themselves must never be
    # sorted, since that would destroy the gene <-> value correspondence.
    _, curr_indices, _ = np.intersect1d(currgenes, refgenes, return_indices=True)

    values = np.asarray(sample)
    gene_axis = 1 if values.ndim == 2 else 0

    return np.take(values, curr_indices, axis=gene_axis)

Let's write a bunch of tests!

In [18]:
# Case: every current gene also appears in the (larger) reference list,
# so no value should be dropped.
refgenes = ['a', 'b', 'c', 'd']
currgenes = ['a', 'b', 'c']

sample = np.array([1, 2, 3])
expected = np.array([1, 2, 3])

res = clean_sample(sample, refgenes, currgenes)

# assert_array_equal checks shape as well as values and prints a diff on failure;
# `all(res == expected)` can pass via broadcasting when `res` has the wrong length.
np.testing.assert_array_equal(res, expected)

In [19]:
# Case: the reference list is a strict subset of the current genes,
# so the values for 'b' and 'e' should be dropped.
refgenes = ['a', 'c', 'd']
currgenes = ['a', 'b', 'c', 'd', 'e']

sample = np.array([1, 2, 3, 4, 5])
expected = np.array([1, 3, 4])

res = clean_sample(sample, refgenes, currgenes)

# assert_array_equal checks shape as well as values and prints a diff on failure;
# `all(res == expected)` can pass via broadcasting when `res` has the wrong length.
np.testing.assert_array_equal(res, expected)

In [22]:
# Case: the reference list is given in arbitrary order; the expected result
# lists the shared genes' values in sorted gene-name order (a, b, c, d),
# and the value for 'e' is dropped.
refgenes = ['c', 'a', 'd', 'b']
currgenes = ['a', 'b', 'c', 'd', 'e']

sample = np.array([1, 2, 3, 4, 5])
expected = np.array([1, 2, 3, 4])

res = clean_sample(sample, refgenes, currgenes)

# assert_array_equal checks shape as well as values and prints a diff on failure;
# `all(res == expected)` can pass via broadcasting when `res` has the wrong length.
np.testing.assert_array_equal(res, expected)