In [1]:
import pandas as pd
import numpy as np
from helper import (get_channels,get_channel_clusters,
                    get_channel_cluster_signals,spike_heatmap,
                    get_cluster_label)


In [2]:
# the main function for generating density plot as input
def create_density_csv(fpaths,csv_path,x_tot_std = 30.225052516123352,
                      n_bin=101,include_fpath=False):
    """ Create a csv file containing flattened density plot data.

        Every (file, channel, cluster) combination found in fpaths becomes
        one row: the flattened density plot followed by the cluster label.

        Input: - fpaths is a list of paths of the data sets to convert
               - csv_path is where the results are stored (the file is
                 opened in "w" mode, so existing content is overwritten)
               - x_tot_std is the value the spikes are divided by before
                 the density plot is computed; the default is the standard
                 deviation of all used spikes of the train+dev data
               - n_bin is the number of bins for the signal amplitude; it
                 is NOT passed on to spike_heatmap, it only fixes the
                 expected row width of n_bin*64 values
               - include_fpath: if True, the source path is appended to
                 every row after the label

        Raises: ValueError if a flattened density plot does not contain
                exactly n_bin*64 values.
    """

    # spike img has shape (n_bin,64), i.e. n_bin*64 values per row
    n_values = n_bin*64

    # open file to store results
    with open(csv_path,"w") as f:

        # create header
        # NOTE: "DataSetName" is always declared, even when include_fpath
        # is False and the rows therefore carry one field less than the
        # header. This is kept as-is so that existing files and readers
        # keep working (pandas.read_csv is expected to fill the missing
        # trailing field with NaN - confirm for other readers).
        header = [str(i) for i in range(n_values)]
        header += ["unitClass","DataSetName"]
        f.write(",".join(header) + "\n")

        # load, generate, and write actual data
        # loop over files to consider
        for ifpath,fpath in enumerate(fpaths):

            # print progress
            print(f"{ifpath+1} of {len(fpaths)}...")

            # load new data set
            data = pd.read_csv(fpath)

            # loop over the channels of the data set
            for channel_id in get_channels(data):

                # loop over the clusters of this channel
                for cluster_id in get_channel_clusters(data, channel_id):

                    # get all spikes
                    cluster_array = get_channel_cluster_signals(data,channel_id,cluster_id)

                    # convert to heatmap with normalized spikes (w.r.t. the std)
                    cluster_heatmap = spike_heatmap(cluster_array/x_tot_std)

                    # flatten density plot to 1-D
                    flat_heatmap = np.asarray(cluster_heatmap).reshape(-1)

                    # a size mismatch would otherwise silently truncate the
                    # density plot or fail half-way through writing a row
                    if flat_heatmap.size != n_values:
                        raise ValueError(
                            f"density plot of channel {channel_id}, cluster "
                            f"{cluster_id} in {fpath} has {flat_heatmap.size} "
                            f"values, expected n_bin*64 = {n_values}")

                    # get the label of this cluster
                    label = get_cluster_label(data,channel_id,cluster_id)

                    # assemble the complete row and write it in one go
                    fields = [f"{value}" for value in flat_heatmap]
                    fields.append(f"{label}")
                    if include_fpath: fields.append(f"{fpath}")
                    f.write(",".join(fields) + "\n")


In [3]:
# Train+dev conversion: the density plots of every recording listed
# below end up together in the single csv file named by csv_path.
csv_path = "./data/density_train_dev.csv"

# recordings (one csv per data set) that form the train+dev set
fpaths = [
    'data/078e09sniff1.csv',
    'data/079e02sniff1.csv',
    'data/079exxsniff2.csv',
    'data/080e02sniff1.csv',
    'data/082e02sniff1.csv',
    'data/083e02sniff1.csv',
    'data/083e37sniff2.csv',
    'data/085e04sniff1.csv',
    'data/085e08sniff2.csv',
    'data/086e20sniff1.csv',
    'data/086e23sniff2.csv',
    'data/086e34sniff3.csv',
    'data/087e02sniff1.csv',
    'data/087e34sniff2.csv',
    'data/089e39sniff1.csv',
    'data/089e58sniff2.csv',
    'data/090e02sniff1.csv',
]

create_density_csv(fpaths, csv_path=csv_path)

1 of 17...
2 of 17...
3 of 17...
4 of 17...
5 of 17...
6 of 17...
7 of 17...
8 of 17...
9 of 17...
10 of 17...
11 of 17...
12 of 17...
13 of 17...
14 of 17...
15 of 17...
16 of 17...
17 of 17...


In [4]:
# Test data: each recording is converted on its own and stored in a
# separate density csv file (source path, destination path).
test_jobs = [
    ("data/084e02sniff1.csv", "./data/084e02sniff1_density.csv"),
    ("data/088e29sniff1.csv", "./data/088e29sniff1_density.csv"),
    ("data/090e27sniff2.csv", "./data/090e27sniff2_density.csv"),
    ("data/089e72sniff3.csv", "./data/089e72sniff3_density.csv"),
]

for source_path, target_path in test_jobs:
    # create_density_csv expects a list of input paths, hence the
    # one-element list; fpaths/csv_path keep the last job's values
    fpaths = [source_path]
    csv_path = target_path
    create_density_csv(fpaths,csv_path)

1 of 1...
1 of 1...
1 of 1...
1 of 1...
