
Here we would like to perform some analysis with the IBL pipeline.

First things first: let's **import the IBL pipeline package** and a few other useful packages.

In [None]:
from ibl_pipeline import subject, acquisition, action, behavior, reference, ephys
import datajoint as dj
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from uuid import UUID
import datetime

## Analyzing existing data

**A simple example: compute the firing rate of each cluster across one session**

Let's take a quick look at the ephys schema:

In [None]:
dj.Diagram(ephys)

How many ephys sessions do we have?

In [None]:
ephys.Ephys()

How many ephys sessions have clustering results?

In [None]:
ephys.Ephys & ephys.Cluster

Next, let's pick one ephys session to focus on:

In [None]:
ephys_session = ephys.Ephys & {'subject_uuid': UUID('077d4b11-c784-4cb9-983c-5a596815434f')}

An overview of clusters in this session:

In [None]:
ephys.Cluster & ephys_session

Pick one cluster:

In [None]:
cluster = ephys.Cluster & ephys_session & 'cluster_id=0'

In [None]:
cluster

What do we need to compute the firing rate of each cluster?  
1. Total spike number  
2. Recording duration
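In other words, the mean firing rate is the spike count divided by the recording duration in seconds. A minimal sketch with synthetic spike times; the helper `firing_rate` and the data are hypothetical and not part of the IBL pipeline:

```python
import numpy as np

def firing_rate(spike_times, duration_s):
    """Mean firing rate in Hz: number of spikes divided by the recording duration in seconds."""
    if duration_s <= 0:
        raise ValueError("duration_s must be positive")
    return len(spike_times) / duration_s

# Hypothetical spike train: 500 spike times spread over a 100 s recording
rng = np.random.default_rng(0)
spike_times = np.sort(rng.uniform(0, 100, size=500))

rate = firing_rate(spike_times, 100.0)
print(rate)  # 5.0
```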

In [None]:
# Total spike number
spk_times = cluster.fetch1('cluster_spike_times')
total_spk_num = len(spk_times)

In [None]:
# Session duration in seconds
session_duration = (acquisition.Session & cluster).proj(
    session_duration='TIMESTAMPDIFF(SECOND, session_start_time, session_end_time)').fetch1('session_duration')
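A caveat about the duration: when MySQL uses a DATETIME value in arithmetic, it converts it to a number of the form YYYYMMDDhhmmss, so a plain `session_end_time - session_start_time` does not in general give seconds, whereas `TIMESTAMPDIFF(SECOND, session_start_time, session_end_time)` does. The plain-Python sketch below reproduces the pitfall with made-up timestamps; `as_mysql_number` is a hypothetical helper that mimics MySQL's numeric conversion. Alternatively, you can fetch both timestamps (DataJoint returns them as `datetime` objects), subtract them, and call `.total_seconds()`.

```python
from datetime import datetime

# Hypothetical session boundaries, one minute apart
start = datetime(2019, 8, 14, 12, 0, 0)
end = datetime(2019, 8, 14, 12, 1, 0)

# Python: subtracting datetimes gives a timedelta; total_seconds() converts it to seconds
duration_s = (end - start).total_seconds()

def as_mysql_number(dt):
    """Mimic MySQL's numeric conversion of a DATETIME: the integer YYYYMMDDhhmmss."""
    return int(dt.strftime("%Y%m%d%H%M%S"))

# Subtracting the numeric forms does not give seconds
naive_diff = as_mysql_number(end) - as_mysql_number(start)
print(duration_s, naive_diff)  # 60.0 100
```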

In [None]:
# compute firing rate 
fr = total_spk_num/session_duration

In [None]:
fr

Cool! We got the firing rate!  
The next question is: how do we save it in the database?  
We put the entry in a table!

## Create your own schema and tables

The first thing we would like to do is to create a schema with `dj.schema`.  
**Note**: the name of the schema you create must start with either `user_{your user name}`, which only you can access, or `group_shared_`, which the entire group can access. Here we use our user name.  
**Note 2**: if your user name contains a `.`, such as `miles.wells`, remove it (`mileswells`) when naming the schema.
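The naming rule for a personal schema can be captured in a small helper. This is a hypothetical convenience function, not part of DataJoint or the IBL pipeline; it follows the `user_{username}_{suffix}` pattern used in `user_shan_tutorial` below and strips dots from the user name:

```python
def user_schema_name(user_name, suffix):
    """Build a personal schema name of the form user_{username}_{suffix}.

    Dots are removed from the user name, e.g. 'miles.wells' -> 'mileswells'.
    """
    clean_user = user_name.replace(".", "")
    return f"user_{clean_user}_{suffix}"

print(user_schema_name("miles.wells", "tutorial"))  # user_mileswells_tutorial
print(user_schema_name("shan", "tutorial"))         # user_shan_tutorial
```

Assuming your database user name is stored in `dj.config['database.user']`, you could then write `schema = dj.schema(user_schema_name(dj.config['database.user'], 'tutorial'))`.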

In [None]:
schema = dj.schema('user_shan_tutorial')

Let's check if the new schema is there:

In [None]:
dj.list_schemas()

Now let's define a **manual** table to save the firing rate result.  
A class created with DataJoint corresponds to a table in the database.

In [None]:
@schema
class FiringRateManual(dj.Manual):
    definition = """
    -> ephys.Cluster         # Each cluster has a firing rate
    ---
    firing_rate:     float   # Hz
    """

Let's take a look at the brand-new table we just created.

In [None]:
FiringRateManual()

Of course it's empty: we haven't inserted anything into it yet.  
Now let's insert the firing rate we just computed into this table.  
Each entry must contain all the fields defined in the table, usually in the form of a dictionary.

In [None]:
# firing rate entry needs to inherit all primary keys from ephys.Cluster
cluster_key = cluster.fetch1('KEY')

In [None]:
cluster_key

In [None]:
firing_rate_entry = dict(
    **cluster_key,
    firing_rate=fr,
)

In [None]:
firing_rate_entry

Now insert it!

In [None]:
FiringRateManual.insert1(firing_rate_entry)

Let's check the table again to see what happened:

In [None]:
FiringRateManual()

Cool, the entry is there!

We could write a for loop that computes the firing rate of each cluster and inserts it with `insert1`, one entry at a time, but that is slow. Instead, we can compute all the results first and insert them all at once.
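Each `insert1` call is a separate query to the database server, while `insert` with a list of dictionaries can send many entries per query. The sketch below contrasts the two calling patterns using Python's built-in `sqlite3` module. It is only a stand-in: SQLite runs in-process, so it shows the pattern rather than the speedup, and it is not how DataJoint talks to MySQL. The table and values are hypothetical.

```python
import sqlite3

# In-memory SQLite database as a stand-in for the pipeline's MySQL server (illustration only)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE firing_rate (cluster_id INTEGER PRIMARY KEY, firing_rate REAL)")

# Hypothetical firing rates for clusters 1-60
entries = [(cluster_id, 0.5 * cluster_id) for cluster_id in range(1, 61)]

# One by one: one INSERT statement per entry (like calling insert1 in a loop)
for entry in entries[:30]:
    conn.execute("INSERT INTO firing_rate VALUES (?, ?)", entry)

# Batch: hand the whole list to the driver in a single call (like insert with a list)
conn.executemany("INSERT INTO firing_rate VALUES (?, ?)", entries[30:])
conn.commit()

n_rows = conn.execute("SELECT COUNT(*) FROM firing_rate").fetchone()[0]
print(n_rows)  # 60
```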

In [None]:
# loop through the first 30 clusters, insert them one by one, and measure the elapsed time
import time
start_time = time.time()

for icluster in (ephys.Cluster & 'cluster_id between 1 and 30' & ephys_session).fetch('KEY'):
    spk_times = (ephys.Cluster & icluster).fetch1('cluster_spike_times')
    fr_entry = dict(**icluster,
                    firing_rate=len(spk_times)/session_duration)
    FiringRateManual.insert1(fr_entry)
    
print("--- %s seconds ---" % (time.time() - start_time))

In [None]:
# loop through the next 30 clusters and insert all at once as a list of dictionaries!
start_time = time.time()

fr_entries = []
for icluster in (ephys.Cluster & 'cluster_id between 31 and 60' & ephys_session).fetch('KEY'):
    spk_times = (ephys.Cluster & icluster).fetch1('cluster_spike_times')
    fr_entry = dict(**icluster,
                    firing_rate=len(spk_times)/session_duration)
    fr_entries.append(fr_entry)
    
FiringRateManual.insert(fr_entries)
print("--- %s seconds ---" % (time.time() - start_time))

With this approach, we have to keep track of which clusters have already been computed and inserted, because inserting the same entry twice raises an error. For example, try rerunning the cell above. We could get around this by passing `skip_duplicates=True` to `.insert()` or `.insert1()`, but that is not a very elegant solution.  
A better approach is to use a **Computed** table. It has the same definition as the manual table above, plus a special `make` method that tells DataJoint how to compute each entry:

In [None]:
@schema
class FiringRateComputed(dj.Computed):
    definition = """
    -> ephys.Cluster         # Each cluster has a firing rate
    ---
    firing_rate:     float   # Hz
    """
    def make(self, key):
        session_duration = (acquisition.Session & key).proj(
            session_duration='TIMESTAMPDIFF(SECOND, session_start_time, session_end_time)').fetch1('session_duration')
        
        spk_times = (ephys.Cluster & key).fetch1('cluster_spike_times')
        firing_rate_entry = dict(**key, firing_rate=len(spk_times)/session_duration)
        self.insert1(firing_rate_entry)

Now we can `populate` the table:

In [None]:
FiringRateComputed.populate(display_progress=True)

In [None]:
FiringRateComputed()

**What does `populate` do?** 

It does two main things:  
1. It determines which keys still need to be computed. The candidates come from the table's `key_source`, which by default is the join of the parent tables referenced in the primary key (here, `ephys.Cluster`); keys that already have entries in the table are skipped.  
2. It calls the `make` method defined in the class once for each remaining key.
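The two steps can be modeled in a few lines of plain Python. This is a simplified toy, not DataJoint's implementation (which also handles transactions, error reporting, and job reservation); the `populate` function, the list-of-dicts "table", and the placeholder `make` are hypothetical:

```python
def populate(key_source, table, make):
    """Toy model of populate: call make(key) for every key in key_source not yet in table.

    key_source: list of primary-key dicts that could be computed
    table: list of row dicts already in the target table (each row contains its key fields)
    make: function that computes and inserts the entry for one key
    """
    def already_in_table(key):
        return any(all(row[f] == v for f, v in key.items()) for row in table)

    # Step 1: keys still to compute = key_source minus keys already in the table
    pending = [key for key in key_source if not already_in_table(key)]
    # Step 2: one make() call per remaining key
    for key in pending:
        make(key)
    return pending


# Usage with hypothetical data: cluster 0 was computed earlier
table = [{"cluster_id": 0, "firing_rate": 2.0}]

def make(key):
    table.append(dict(**key, firing_rate=1.0))  # placeholder value

key_source = [{"cluster_id": i} for i in range(4)]
first_run = populate(key_source, table, make)
second_run = populate(key_source, table, make)
print([k["cluster_id"] for k in first_run], second_run)  # [1, 2, 3] []
```

Running it a second time does nothing, which is why rerunning `populate` is safe, unlike re-inserting into a manual table.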

Here `make` still inserts one entry per call, which is slow. How can we insert the firing rates of all clusters in a session at once?

We can redefine `key_source` at a coarser level, so that each key is one ephys session and each `make` call handles all clusters of that session.
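The effect of a session-level `key_source` can be sketched in plain Python: there is one key per session, and each call computes every cluster in that session and stores the results as one batch. The data, the `durations` lookup, and the `make_session` helper are hypothetical:

```python
# Hypothetical cluster data: session identifier, cluster_id, and spike count
clusters = [
    {"session": "s1", "cluster_id": 0, "n_spikes": 120},
    {"session": "s1", "cluster_id": 1, "n_spikes": 300},
    {"session": "s2", "cluster_id": 0, "n_spikes": 50},
]
durations = {"s1": 60.0, "s2": 25.0}  # session durations in seconds

batches = []  # each element stands for one batched insert

def make_session(session_key):
    """One call per session: compute all cluster rates, then insert them together."""
    entries = [
        {"session": c["session"], "cluster_id": c["cluster_id"],
         "firing_rate": c["n_spikes"] / durations[session_key]}
        for c in clusters
        if c["session"] == session_key
    ]
    batches.append(entries)  # stands in for self.insert(entries)

# Session-level key source: one key per session, not per cluster
session_keys = sorted({c["session"] for c in clusters})
for session_key in session_keys:
    make_session(session_key)

print(len(batches), [e["firing_rate"] for e in batches[0]])  # 2 [2.0, 5.0]
```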

In [None]:
@schema
class FiringRateComputedFromSession(dj.Computed):
    definition = """
    -> ephys.Cluster         # Each cluster has a firing rate
    ---
    firing_rate:     float   # Hz
    """
    key_source = ephys.Ephys & ephys.Cluster  # populate for each ephys data set where clustering is available.
    
    def make(self, key): # the key here is now the primary key of ephys.Ephys, instead of ephys.Cluster
        session_duration = (acquisition.Session & key).proj(
            session_duration='TIMESTAMPDIFF(SECOND, session_start_time, session_end_time)').fetch1('session_duration')
        
        fr_entries = []
        for icluster in (ephys.Cluster & key).fetch('KEY'):
            spk_times = (ephys.Cluster & icluster).fetch1('cluster_spike_times')
            fr_entry = dict(**icluster,
                            firing_rate=len(spk_times)/session_duration)
            fr_entries.append(fr_entry)
    
        self.insert(fr_entries)

In [None]:
FiringRateComputedFromSession.populate(display_progress=True)

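## Delete entries and drop a table

`delete()` removes the entries that match a restriction, together with any dependent entries in downstream tables; `drop()` removes the table itself, including all its data. By default, both ask for confirmation before proceeding.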

In [None]:
(FiringRateManual & 'cluster_id=0').delete() # any restriction works here

In [None]:
FiringRateManual.drop()

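## Bonus: How to work with data when you don't have the code that defines the classes?

List the schemas available to you, then load one with `dj.create_virtual_module`. It creates a Python module whose classes mirror the tables of an existing schema, so you can query them without the original class definitions. The first argument is the module name and the second is the schema name in the database.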

In [None]:
dj.list_schemas()

In [None]:
anne_analyses = dj.create_virtual_module('analyses', 'group_shared_anneurai_analyses')

In [None]:
anne_analyses.Age()