In [None]:
from copick.impl.filesystem import CopickRootFSSpec

In [None]:
# Relative path to the copick JSON configuration used throughout this notebook.
COPICK_CONFIG_PATH = "../assets/samba_config_jfinder.json"
# Open the copick project described by that configuration; later cells read
# its runs and picks through `root`.
root = CopickRootFSSpec.from_file(COPICK_CONFIG_PATH)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import copy
import random

from scipy.stats import ks_2samp



class SplitDataset:
    """Bucket copick runs by pick-count distribution and split them into datasets.

    For every run the constructor records how many points were picked per
    pickable object.  Runs whose label distributions are judged similar by a
    two-sample Kolmogorov-Smirnov test are merged into buckets
    (``make_buckets``), and each bucket is then randomly divided into one
    training and two test sets (``generate_datasets``).

    Attributes set by ``__init__``:
        root: the copick root loaded from the configuration file.
        tomograms: run names, in the order the runs were read.
        particle_map: pickable-object name -> integer label (its position in
            the configuration's ``pickable_objects``).
        run_stats_list: one ``{object_name: point_count}`` dict per run,
            with keys in sorted order.
        arrs: one list per run holding each object's integer label repeated
            ``point_count`` times (the sample fed to the KS test and plots).
        buckets: list of ``[merged_stats, set_of_run_indices]``; empty until
            ``make_buckets`` has been called.
    """

    def __init__(self, config_file: str):
        """Load the copick project at ``config_file`` and tabulate pick counts.

        Prints one line per run (index, run name, counts) followed by the
        number of runs processed.
        """
        self.root = CopickRootFSSpec.from_file(config_file)
        self.arrs = []
        self.tomograms = []
        self.buckets = []
        self.particle_map = {
            o.name: i for i, o in enumerate(self.root.config.pickable_objects)
        }
        self.run_stats_list = []

        for i, run in enumerate(self.root.runs):
            self.tomograms.append(run.name)
            counter = defaultdict(int)
            for pick in run.picks:
                points = pick.points
                if points is not None:
                    # NOTE(review): this assigns rather than accumulates, so
                    # if a run holds several pick sets for the same object
                    # only the last one is counted -- confirm that is intended.
                    counter[pick.pickable_object_name] = len(points)

            # Every configured object gets an entry, even when nothing was picked.
            for name in self.particle_map:
                if name not in counter:
                    counter[name] = 0

            stats = {name: counter[name] for name in sorted(counter)}
            print(i, run.name, stats)
            self.run_stats_list.append(stats)

        print(f'self.run_stats_list {len(self.run_stats_list)}')
        self.arrs = [self._stats_to_arr(stats) for stats in self.run_stats_list]

    def _stats_to_arr(self, stats):
        """Expand ``{object_name: count}`` into a flat list of integer labels."""
        arr = []
        for name, count in stats.items():
            arr.extend([self.particle_map[name]] * count)
        return arr

    def plot_all_distributions(self):
        """Overlay one label histogram (with KDE) per run on a single figure."""
        # Choose a colormap and sample one colour per run from it.
        colormap = plt.cm.viridis
        colors = [colormap(i / len(self.arrs)) for i in range(len(self.arrs))]

        plt.figure(figsize=(12, 6))
        plt.subplot(1, 1, 1)
        for i, arr in enumerate(self.arrs):
            # No per-run label is passed, so the legend below has no entries.
            sns.histplot(arr, kde=True, color=colors[i])
        plt.legend()
        plt.title('Histogram')

        plt.show()

    def plot_2dist(self, first: int, second: int, arr=False):
        """Plot two label distributions on the same axes.

        Parameters:
            first, second: run indices into ``self.arrs`` when ``arr`` is
                false; otherwise the label arrays themselves.
            arr: when true, ``first`` and ``second`` are plotted as given and
                labelled ``'arr1'`` / ``'arr2'``; when false, they are looked
                up by index and labelled with the run names.
        """
        if arr:
            arr1 = first
            arr2 = second
            label1 = 'arr1'
            label2 = 'arr2'
        else:
            arr1 = self.arrs[first]
            arr2 = self.arrs[second]
            label1 = self.tomograms[first]
            label2 = self.tomograms[second]

        plt.figure(figsize=(12, 6))
        plt.subplot(1, 1, 1)
        sns.histplot(arr1, kde=True, color='blue', label=label1)
        sns.histplot(arr2, kde=True, color='red', label=label2)
        plt.legend()
        plt.title('Histogram')

        plt.show()

    def plot_dist_3arrs(self, arr1, arr2, arr3):
        """Plot the training and both test label distributions together."""
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 1, 1)
        sns.histplot(arr1, kde=True, color='blue', label='training_dataset')
        sns.histplot(arr2, kde=True, color='red', label='test_dataset1')
        sns.histplot(arr3, kde=True, color='green', label='test_dataset2')
        plt.legend()
        plt.title('Histogram')
        plt.show()

    def plot_dist_test(self, arr2, arr3):
        """Plot only the two test label distributions together."""
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 1, 1)
        sns.histplot(arr2, kde=True, color='red', label='test_dataset1')
        sns.histplot(arr3, kde=True, color='green', label='test_dataset2')
        plt.legend()
        plt.title('Histogram')
        plt.show()

    def is_close_dist(self, first: int, second: int, threshold=0.05):
        """Return whether runs ``first`` and ``second`` have similar distributions.

        Delegates to ``is_arr_close_dist`` on the two runs' label arrays.
        """
        return self.is_arr_close_dist(self.arrs[first], self.arrs[second], threshold)

    @staticmethod
    def is_arr_close_dist(arr1, arr2, threshold=0.05):
        """Return True when a two-sample KS test does not separate the samples.

        Parameters:
            arr1, arr2: sequences of integer particle labels.
            threshold: p-value above which the samples count as "close".

        Returns:
            bool: ``p_value > threshold``.  ``ks_2samp`` rejects empty
            samples, so they are handled up front: two empty samples are
            close, an empty and a non-empty sample are not.

        NOTE(review): the labels are indices of particle types, so the KS
        statistic depends on the order of ``pickable_objects`` -- confirm this
        is acceptable for the intended use.
        """
        if len(arr1) == 0 or len(arr2) == 0:
            return len(arr1) == len(arr2)

        # Kolmogorov-Smirnov test
        ks_stat, ks_p_value = ks_2samp(arr1, arr2)

        # A high p-value (e.g. > 0.05) means the test found no significant
        # difference between the two samples' distributions.
        print(f"KS Statistic: {ks_stat}, P-value: {ks_p_value}")
        return bool(ks_p_value > threshold)

    def make_buckets(self, threshold=0.05):
        """Greedily group runs with similar label distributions and plot them.

        Runs are visited in order.  Each run joins the first existing bucket
        whose merged distribution is close to it (``is_arr_close_dist`` with
        ``threshold``); otherwise it starts a new bucket.  The result is
        stored in ``self.buckets`` and rebuilt from scratch on every call.
        Each bucket and the bucket count are printed.
        """
        self.buckets = []  # list of [stats dict, set of run indices]

        for i, run_stats in enumerate(self.run_stats_list):
            run_arr = self.arrs[i]
            for bucket in self.buckets:
                bucket_arr = self._stats_to_arr(bucket[0])
                if self.is_arr_close_dist(run_arr, bucket_arr, threshold):
                    bucket_stats = bucket[0]
                    names = sorted(set(bucket_stats) | set(run_stats))
                    bucket[0] = {
                        name: bucket_stats.get(name, 0) + run_stats.get(name, 0)
                        for name in names
                    }
                    bucket[1].add(i)
                    break
            else:
                # No bucket matched (or none exists yet): start a new one.
                self.buckets.append([copy.deepcopy(run_stats), {i}])

        for bucket in self.buckets:
            print(bucket)

        print(f'{len(self.buckets)} buckets')

        if self.buckets:
            self._plot_buckets()

    def _plot_buckets(self):
        """Draw one labelled histogram per bucket on the current figure."""
        # Choose a colormap and sample one colour per bucket from it.
        colormap = plt.cm.viridis
        colors = [colormap(i / len(self.buckets)) for i in range(len(self.buckets))]

        plt.subplot(1, 1, 1)
        for i, bucket in enumerate(self.buckets):
            # Bucket arrays are plotted only; they are deliberately not added
            # to self.arrs, which must stay aligned with the runs.
            arr = self._stats_to_arr(bucket[0])
            sns.histplot(arr, kde=True, color=colors[i], label=f'cluster {i}')
        plt.legend()
        plt.title('Histogram')

    def random_split_list(self, my_list, ks=(0.6, 0.2, 0.2)):
        """Randomly split ``my_list`` into a training set and two test sets.

        Parameters:
            my_list: items to split; it is copied, not modified.
            ks: fractions for the split.  ``ks[0]`` and ``ks[1]`` size the
                first two sets (rounded with ``round``); the third set
                receives whatever remains, so ``ks[2]`` is not read.

        Returns:
            tuple of three lists: ``(train_set, test_set1, test_set2)``.
        """
        # Shuffle a copy so the caller's list keeps its order.
        shuffled = list(my_list)
        random.shuffle(shuffled)

        # Calculate split indices
        train = round(ks[0] * len(shuffled))
        test1 = round(ks[1] * len(shuffled))

        # Split the list into three parts
        train_set = shuffled[:train]
        test_set1 = shuffled[train:train + test1]
        test_set2 = shuffled[train + test1:]

        return train_set, test_set1, test_set2

    def id2arr(self, ids):
        """Concatenate the label arrays of the runs whose indices are in ``ids``."""
        arr = []
        for i in ids:
            arr.extend(self.arrs[i])
        return arr

    def generate_datasets(self, ks=(0.6, 0.2, 0.2)):
        """Split every bucket with ``random_split_list`` and pool the results.

        ``make_buckets`` must have been called first.  The pooled label
        distributions are plotted (all three sets, then the two test sets).

        Parameters:
            ks: split fractions forwarded to ``random_split_list``.

        Returns:
            tuple of three lists of run names:
            ``(train_dataset, test_dataset1, test_dataset2)``.

        Raises:
            RuntimeError: if no buckets are available.
        """
        if not getattr(self, "buckets", None):
            raise RuntimeError("no buckets available; call make_buckets() first")

        train_dt = []
        test_dt1 = []
        test_dt2 = []
        for bucket in self.buckets:
            train_set, test_set1, test_set2 = self.random_split_list(list(bucket[1]), ks)
            train_dt.extend(train_set)
            test_dt1.extend(test_set1)
            test_dt2.extend(test_set2)

        train_dataset = [self.tomograms[i] for i in train_dt]
        test_dataset1 = [self.tomograms[i] for i in test_dt1]
        test_dataset2 = [self.tomograms[i] for i in test_dt2]

        train_arr = self.id2arr(train_dt)
        test_arr1 = self.id2arr(test_dt1)
        test_arr2 = self.id2arr(test_dt2)
        self.plot_dist_3arrs(train_arr, test_arr1, test_arr2)
        self.plot_dist_test(test_arr1, test_arr2)
        return train_dataset, test_dataset1, test_dataset2


In [None]:
# Build the splitter: loads the copick project and prints per-run pick counts.
datasets = SplitDataset(COPICK_CONFIG_PATH)

In [None]:
# Overlay the particle-label histogram of every run in a single figure.
datasets.plot_all_distributions()

In [None]:
# Group runs with similar label distributions (KS test, default threshold 0.05)
# and plot one histogram per resulting bucket.
datasets.make_buckets()

In [None]:
# Randomly split each bucket into train / test1 / test2 (default 0.6 / 0.2 / 0.2)
# and collect the run names; requires make_buckets() to have been run above.
train_dataset, test_dataset1, test_dataset2 = datasets.generate_datasets()
print(f'train_dataset\n{train_dataset}\ntest_dataset1\n{test_dataset1}\ntest_dataset2\n{test_dataset2}')
# Number of runs that ended up in each split.
print(len(train_dataset), len(test_dataset1), len(test_dataset2))

## Evaluate the datasets

```
for each fake model in [sigma = 0, sigma = 1, sigma = 10]:
    for each tomogram:
        for each particle type:
            newpicks = []
            for each point in ground truth/jfinder picks:
                point[0] += gaussian(sigma)
                point[1] += gaussian(sigma)
                point[2] += gaussian(sigma)
                # TODO add clipping to check that points stay within tomogram dimensions
                newpicks.append(point)
```

In [None]:
import math

def gaussian_function(x, mu=0.0, sigma=1.0):
    """
    Evaluate the Gaussian (normal) probability density at x.

    Computes exp(-(x - mu)^2 / (2 * sigma^2)) / (sigma * sqrt(2 * pi)).
    This is the density value at x, not a random sample from the distribution.

    Parameters:
    x (float): The value at which to evaluate the Gaussian function.
    mu (float): The mean of the Gaussian distribution.
    sigma (float): The standard deviation of the Gaussian distribution; must be non-zero.

    Returns:
    float: The value of the Gaussian function at x.

    Raises:
    ZeroDivisionError: If sigma is 0.
    """
    # Normalisation constant 1 / (sigma * sqrt(2 * pi)); the square root is
    # taken of 2 * pi as a whole so the density integrates to 1.
    coefficient = 1.0 / (sigma * math.sqrt(2 * math.pi))
    exponent = -((x - mu) ** 2) / (2 * sigma ** 2)
    return coefficient * math.exp(exponent)

# Example usage: evaluate the function at a single point using the default
# mu=0.0 and sigma=1.0, then print the result.
x_value = 1.0

result = gaussian_function(x_value)
print("The value of the Gaussian function at x =", x_value, "is", result)

In [None]:
# Skeleton of the evaluation loop: for each noise level, walk the first ten
# runs and report how many points each pick set holds. `sigma` is not applied
# to the points yet.
for sigma in [0.0, 1.0, 10.0]:
    for i, run in enumerate(root.runs[:10]):
        for pick in run.picks:
            # Pick sets without points have nothing to report.
            if pick.points is None:
                continue
            print(pick.pickable_object_name, len(pick.points))
    