The purpose of this notebook is to indentify what cells were used for testing.

In [1]:
# open test_results.pkl

import pickle
import numpy as np
from core.backend import open_minian

with open('test_results.pkl', 'rb') as f:
    results = pickle.load(f)

In [2]:
# Iterate through experiments
dict = {"cross_session_same_day": [["./data/PL010/PL010_D1S1", "./data/PL010/PL010_D1S4"], ["./data/AA058/AA058_D1S1", "./data/AA058/AA058_D1S4"], ["./data/AA036/AA036_D2S1", "./data/AA036/AA036_D2S4"], ["./data/AA034/AA034_D1S1", "./data/AA034/AA034_D1S4"]],
        "cross_day_same_session": [["./data/PL010/PL010_D1S1", "./data/PL010/PL010_D8S1"], ["./data/AA058/AA058_D1S1", "./data/AA058/AA058_D5S1"], ["./data/AA036/AA036_D2S1", "./data/AA036/AA036_D6S1"], ["./data/AA034/AA034_D1S1", "./data/AA034/AA034_D7S1"]],
        "cross_day_cross_session": [["./data/PL010/PL010_D1S1", "./data/PL010/PL010_D8S4"], ["./data/AA058/AA058_D1S1", "./data/AA058/AA058_D5S4"], ["./data/AA036/AA036_D2S1", "./data/AA036/AA036_D6S4"], ["./data/AA034/AA034_D1S1", "./data/AA034/AA034_D7S4"]],
        "cross_animal": [["./data/PL010/PL010_D1S1", "./data/AA058/AA058_D1S1"], ["./data/AA058/AA058_D1S1", "./data/AA036/AA036_D2S1"], ["./data/AA036/AA036_D2S1", "./data/AA034/AA034_D1S1"], ["./data/AA034/AA034_D1S1", "./data/PL010/PL010_D1S1"]],
        "within_session": [["./data/PL010/PL010_D1S1"], ["./data/AA058/AA058_D1S1"], ["./data/AA036/AA036_D2S1"], ["./data/AA034/AA034_D1S1"]]}

experiments = results.keys()

def find_test_set(experiment, training_set):
        arr = dict[experiment]
        if len(arr[0]) == 1:
                for train in arr:
                        if train[0][-len(training_set):] == training_set:
                                return train[0]
        else:
                for train, test in arr:
                        if train[-len(training_set):] == training_set:
                                return test
        print(f"Error: Test set not found for {experiment} and {training_set}")

def get_ground_truth(arr):
        # Get rows 0, 2, 4 etc. from the array
        return arr[::2]
        

In [3]:
results["cross_session_same_day"]["AA034_D1S1"]

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 1, 1],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 1, 1]], dtype=int8)

In [15]:
test_cells_dict = {}

for experiment in experiments:
    test_cells_dict[experiment] = {}
    training_sets = results[experiment].keys()
    for training_set in training_sets:
        test_cells_dict[experiment][training_set] = {}
        if experiment == "within_session" and training_set == "AA034_D1S1":
            print(training_set + " " + experiment)
        test_set = find_test_set(experiment, training_set)
        E = open_minian(test_set)["E"].load()
        test_set = test_set.split("/")[-1]
        test_cells_dict[experiment][training_set][test_set] = []
        all_unit_ids = E.unit_id.values
        verified = E.verified.values.astype(int)
        unit_ids = all_unit_ids[verified==1]
        E = E.sel(unit_id=unit_ids)
        ground_truth = get_ground_truth(results[experiment][training_set])
        # We need to match the ground truth with the results in E and return the name of the cell
        for row in ground_truth:
            # This took me a while to figure out but I made a small mistake when trying to generate the data, I used 26999 frames instead of 27000
            # Split the row into 5 equal parts
            mini_rows = np.array_split(row, 5)
            mini_indices = []
            for mini_r in mini_rows:
                found = False
                for unit_id in unit_ids:
                    if (E.sel(unit_id=unit_id).values[:-1] == mini_r).all():
                        mini_indices.append(unit_id)
                        found = True
                        break
            
                if not found:
                    raise Exception(f"Error: Cell not found in test set {test_set} for experiment {experiment} and training set {training_set}")
                
            test_cells_dict[experiment][training_set][test_set].append(mini_indices) 
    
    

1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.
  arr = list(xr.open_zarr(arr_path).values())[0]
1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.
  arr = list(xr.open_zarr(arr_path).values())[0]
1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try r

AA034_D1S1 within_session


1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.
  arr = list(xr.open_zarr(arr_path).values())[0]
  verified = E.verified.values.astype(int)
1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.
  arr = list(xr.open_zarr(arr_path).values())[0]


In [16]:
test_cells_dict["cross_session_same_day"]["AA034_D1S1"]



{'AA034_D1S4': [[80, 131, 50, 73, 40],
  [131, 70, 17, 26, 109],
  [50, 96, 97, 40, 43],
  [70, 40, 50, 26, 43],
  [17, 50, 70, 111, 131],
  [26, 96, 131, 50, 73],
  [97, 26, 54, 80, 109],
  [26, 43, 96, 50, 80],
  [97, 50, 26, 73, 40],
  [109, 73, 70, 131, 80],
  [111, 26, 73, 80, 43],
  [131, 80, 97, 111, 17],
  [111, 96, 40, 97, 43],
  [109, 40, 43, 70, 96],
  [40, 70, 54, 113, 97],
  [43, 131, 54, 111, 70],
  [97, 17, 70, 131, 43],
  [17, 96, 70, 111, 113],
  [111, 109, 50, 131, 96],
  [96, 70, 113, 54, 131],
  [26, 109, 50, 17, 40],
  [96, 17, 109, 73, 131],
  [40, 70, 109, 17, 80],
  [70, 43, 113, 96, 17],
  [50, 131, 70, 109, 26],
  [17, 131, 40, 54, 43],
  [40, 113, 80, 109, 111],
  [97, 70, 96, 17, 26],
  [97, 80, 113, 43, 111],
  [96, 70, 17, 40, 26]]}

In [12]:
results

{'cross_animal': {'AA034_D1S1': array([[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 1],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]], dtype=int8),
  'AA036_D2S1': array([[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 1],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 1, 1]], dtype=int8),
  'AA058_D1S1': array([[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 1],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 1, 1]], dtype=int8),
  'PL010_D1S1': array([[0, 0, 0, ..., 0, 1, 1],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 1],
         [0, 0, 0, ..., 0, 0, 1]], dtype=int8)},