# DataJoint pipeline design - on "merging" of pipelines

For DataJoint users working with existing pipelines or designing new ones, a common design question is how to join or merge different "branches" of a pipeline at one common node/table.

To elaborate, suppose your workflow has multiple sources of data that need to go through different processing/analysis routines, but these routines ultimately arrive at a point where the data format is identical and can be processed downstream in the same manner.

In this notebook, we will go through one approach to address this design question.

## Let's consider one example scenario

To be more concrete, let me start with an example. Let's say that we are interested in tracking the position of an animal during a freely behaving experiment. Ultimately, we want the `(x, y)` coordinates of the animal over time. In this example, for any one experimental session, we track the animal's position using one of the two methods below:
1. Placing a marker on the body of the animal and tracking this marker with a set of cameras
2. Using a computer vision approach to extract the position of the animal from the video recording of a camera

With each of the two tracking methods above, the processing and analysis will be different, and being DataJoint users, we'll design a set of tables to define the processing/analysis routine for each method.

So there will likely be two pipeline branches running in parallel, which need to be merged at the point where the extraction of `(x, y)` coordinates over time is complete. This is because a set of downstream analyses will be run on the extracted animal position, regardless of which tracking method a particular experimental session used.


## The pipeline for this scenario

Let's put together an example DataJoint pipeline describing this scenario:

In [None]:
import datajoint as dj
import numpy as np
import hashlib
import uuid
import itertools

In [None]:
dj.conn()

In [None]:
schema = dj.Schema('ttngu207_pipeline_merging_2')

In [None]:
@schema
class Session(dj.Manual):
    definition = """
    animal_name: varchar(16)
    session_number: int
    """

In [None]:
@schema
class MethodOneTrackingRaw(dj.Imported):
    definition = """
    -> Session
    ---
    tracking_data: longblob
    """
    
    
@schema
class MethodOneProcessing(dj.Computed):
    definition = """
    -> MethodOneTrackingRaw
    ---
    tracking_data: longblob
    """
    
    
@schema
class MethodOneTracking(dj.Computed):
    definition = """
    -> MethodOneProcessing
    ---
    x: longblob
    y: longblob
    t: longblob
    """    

In [None]:
@schema
class MethodTwoTrackingRaw(dj.Imported):
    definition = """
    -> Session
    ---
    tracking_data: longblob
    """
    
    
@schema
class MethodTwoProcessing(dj.Computed):
    definition = """
    -> MethodTwoTrackingRaw
    ---
    tracking_data: longblob
    """
    

@schema
class FilterParam(dj.Lookup):
    definition = """
    param_id: int
    ---
    sigma: float
    """
    
    contents = [(0, 1), (1, 10)]

    
@schema
class MethodTwoFiltering(dj.Computed):
    definition = """
    -> MethodTwoProcessing
    -> FilterParam
    ---
    filtered_tracking_data: longblob
    """
    
    
@schema
class MethodTwoTracking(dj.Computed):
    definition = """
    -> MethodTwoFiltering
    ---
    x: longblob
    y: longblob
    t: longblob
    """    

In [None]:
dj.Diagram(schema)

## How to "merge" these two branches?

The next step in our pipeline is to run a number of analysis routines on the animal position data, using the `x, y, t` arrays as inputs. And we don't particularly care if the animal position data from a session is from method one or two, as long as we can work with the `x`, `y` and `t` arrays. 
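
As an illustration of such a method-agnostic routine, here is a minimal sketch of a speed computation that only needs the `x`, `y` and `t` values. `compute_speed` is a hypothetical helper that is not part of the pipeline above, and it works on plain Python sequences purely to keep the example self-contained.

```python
import math


def compute_speed(x, y, t):
    """Return the speed between consecutive samples of a 2D trajectory."""
    if not (len(x) == len(y) == len(t)):
        raise ValueError("x, y and t must have the same length")
    speeds = []
    for i in range(1, len(t)):
        dt = t[i] - t[i - 1]
        if dt <= 0:
            raise ValueError("t must be strictly increasing")
        # Distance travelled between two samples divided by the elapsed time
        speeds.append(math.hypot(x[i] - x[i - 1], y[i] - y[i - 1]) / dt)
    return speeds


# Hypothetical trajectory: the animal moves 3 units in x and 4 in y per time step
x = [0.0, 3.0, 6.0]
y = [0.0, 4.0, 8.0]
t = [0, 1, 2]
speeds = compute_speed(x, y, t)
```

Nothing in this function depends on whether the coordinates came from method one or method two, which is exactly why a single downstream entry point is desirable.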

Here, I will propose a table-merging design approach. I opt for the term "merging" to describe this design to avoid confusion with DataJoint's `join`.

Consider the design below:

In [None]:
@schema
class MergedTracking(dj.Manual):
    definition = """
    merged_tracking: uuid
    """
    
    class MethodOneTracking(dj.Part):
        definition = """
        -> master
        ---
        -> MethodOneTracking
        """
        
    class MethodTwoTracking(dj.Part):
        definition = """
        -> master
        ---
        -> MethodTwoTracking
        """
    
    
@schema
class Speed(dj.Computed):
    definition = """
    -> MergedTracking
    ---
    speed: longblob
    """

In [None]:
dj.Diagram(schema)

In the prototype above, the `MergedTracking` table is a `dj.Manual` table allowing for the merging of the two different branches of tracking data. 
The primary key is a single uuid-type attribute, with no non-primary attribute.
The connection to the upstream tables to be merged is done via part-tables, with one-to-one relationship to the master.

One uuid entry here should uniquely specify one "tracking" for a session, from either method one ***or*** method two. 
The keyword is ***or***: for each entry in the master table, exactly one part-table must have a corresponding entry, and the other part-tables must have none.

This design allows merging tables from different branches of the pipeline (or from different pipelines), and it is fairly easy to extend. For example, if another tracking method, `MethodThreeTracking`, is added in the future, it can be included in `MergedTracking` by introducing another part-table.

## What's the catch?

There are a few caveats to this design; I list the two major ones below. However, I'd say these are inconveniences rather than design flaws or drawbacks.

1. `UUID`-type primary attribute. The fact that `MergedTracking` has a single primary attribute of type `uuid` causes somewhat of a "disconnection" between this merging table and the upstream tables. The connection to upstream is established by the non-primary foreign keys in the part-tables. Three inconveniences follow from this:
    + To link to the upstream tables, we always have to do a `join (*)` with this table and its part-tables in queries
    + Cannot use this as a `dj.Imported` or `dj.Computed` - DataJoint native `autopopulate` would not work
    + `.insert()` is hard to use, as the `uuid` has to be generated somehow
    
2. From the database perspective, this table design does not guarantee mutual exclusivity of the member tables to be merged. This means just purely from the table definition, one can have an entry in `MergedTracking` with corresponding entries in both the `MethodOneTracking` and `MethodTwoTracking` part-tables, violating the "***or***" intention. 
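
Because the table definition alone cannot enforce the "***or***", one option is to audit the contents after the fact. The sketch below shows the check in plain Python, under the assumption that the master keys and the master keys held by each part-table have already been fetched into ordinary Python containers. `audit_exclusivity` and the sample data are hypothetical and are not DataJoint API.

```python
from collections import Counter


def audit_exclusivity(master_keys, part_keys):
    """Return (orphans, duplicated) master keys.

    master_keys: iterable of master primary-key values.
    part_keys: dict mapping a part-table name to the master keys it holds.
    """
    # Count in how many distinct part-tables each master key appears
    counts = Counter(k for keys in part_keys.values() for k in set(keys))
    orphans = sorted(k for k in master_keys if counts[k] == 0)
    duplicated = sorted(k for k in master_keys if counts[k] > 1)
    return orphans, duplicated


# Hypothetical contents: 'c' is claimed by both parts, 'd' by none
master_keys = ["a", "b", "c", "d"]
part_keys = {
    "MethodOneTracking": ["a", "c"],
    "MethodTwoTracking": ["b", "c"],
}
orphans, duplicated = audit_exclusivity(master_keys, part_keys)
```

A valid merge table yields two empty lists; any key in `duplicated` violates the "***or***" intention, and any key in `orphans` is a master entry that points to no tracking at all.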

#### Can we enhance the experience?

To enhance the usage experience, we can override the `.insert()` method to:
1. auto-generate the ***uuid*** 
2. insert also to the part-table
3. ensure mutual exclusivity of member tables to be merged
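
For the first item, one option is to derive the uuid deterministically from the primary key of the upstream entry, so that the same upstream entry always maps to the same master entry. The sketch below uses only the standard `hashlib` and `uuid` modules. `key_to_uuid` is a hypothetical helper, and its hashing scheme is an assumption made for illustration rather than DataJoint's own routine.

```python
import hashlib
import uuid


def key_to_uuid(key):
    """Derive a deterministic UUID from a primary-key dictionary."""
    hashed = hashlib.md5()
    # Sort by attribute name so the result does not depend on dict ordering
    for name, value in sorted(key.items()):
        hashed.update(f"{name}={value};".encode())
    # An MD5 digest is 128 bits, the same size as a UUID
    return uuid.UUID(hashed.hexdigest())


key_one = {"animal_name": "subject1", "session_number": 1}
key_one_reordered = {"session_number": 1, "animal_name": "subject1"}
key_two = {"animal_name": "subject1", "session_number": 3, "param_id": 0}

same = key_to_uuid(key_one) == key_to_uuid(key_one_reordered)
different = key_to_uuid(key_one) != key_to_uuid(key_two)
```

With a deterministic uuid, inserting the same upstream entry twice collides on the master's primary key instead of silently creating a second master entry.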

We can also introduce a convenience property, `.all_joined`, that:
1. returns the master table joined with whichever part holds each entry, built as a union of the parts with `NULL` filling the attributes a given part lacks
2. lets downstream queries reference the merged upstream tables simply by joining with `MergedTracking().all_joined`
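
One way to picture the result of `.all_joined`: each part-table contributes its rows, and any attribute that a part does not have is filled with an empty value. The plain-Python sketch below mimics that padding on lists of dictionaries. `padded_union` and the shortened placeholder uuids are hypothetical, and `None` stands in for SQL `NULL`.

```python
def padded_union(*tables):
    """Union rows from several tables, filling attributes a row lacks with None."""
    # Collect attribute names in first-seen order across all tables
    attrs = list(dict.fromkeys(name for rows in tables for row in rows for name in row))
    return [{a: row.get(a) for a in attrs} for rows in tables for row in rows]


# Hypothetical part-table contents, with shortened placeholder uuids
part_one = [{"merged_tracking": "u1", "animal_name": "subject1", "session_number": 1}]
part_two = [{"merged_tracking": "u3", "animal_name": "subject1", "session_number": 3, "param_id": 0}]

joined = padded_union(part_one, part_two)
```

Here the method-one row gains a `param_id` of `None`, so both rows share one heading and can be restricted by `animal_name` or `session_number` in the same way.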


In [None]:
@schema
class MergedTracking(dj.Manual):
    definition = """
    merged_tracking: uuid
    """
    
    class MethodOneTracking(dj.Part):
        definition = """
        -> master
        ---
        -> MethodOneTracking
        """
        
    class MethodTwoTracking(dj.Part):
        definition = """
        -> master
        ---
        -> MethodTwoTracking
        """
    
    @property
    def all_joined(self):
        parts = self.parts(as_objects=True)
        # Every attribute appearing in any part-table, in first-seen order and without duplicates
        all_attrs = list(dict.fromkeys(itertools.chain.from_iterable(p.heading.names for p in parts)))

        def pad(part):
            # Add the attributes this part lacks as NULL, then promote all attributes to the
            # primary key so that the parts share a heading and can be combined with a union
            missing = {a: 'NULL' for a in all_attrs if a not in part.heading.names}
            return dj.U(*all_attrs) * part.proj(..., **missing)

        query = pad(parts[0])
        for part in parts[1:]:
            query += pad(part)

        return query
    
    @classmethod
    def insert(cls, rows, **kwargs):
        """
        :param rows: An iterable of dictionaries, each identifying one entry in one of the parent tables.
        """
        
        # Materialize the input first so that validating it cannot exhaust a generator
        rows = list(rows)
        if not all(isinstance(r, dict) for r in rows):
            raise TypeError('Input "rows" must be a list of dictionaries')
        
        parts = cls.parts(as_objects=True)
        master_entries = []
        parts_entries = {p: [] for p in parts}
        for row in rows:
            key = {}
            for part in parts:
                parent = part.parents(as_objects=True)[-1]
                if parent & row:
                    if not key:
                        key = (parent & row).fetch1('KEY')
                        master_key = {cls.primary_key[0]: dj.hash.key_hash(key)}
                        parts_entries[part].append({**master_key, **key})
                        master_entries.append(master_key)
                    else:
                        raise ValueError(f'Mutual Exclusivity Error! Entry exists in more than one parent table - Entry: {row}')
            
            if not key:
                raise ValueError(f'Non-existing entry in any of the parent tables - Entry: {row}')
        
        with cls.connection.transaction:
            super().insert(cls(), master_entries, **kwargs)
            for part, part_entries in parts_entries.items():
                part.insert(part_entries, **kwargs)
        

## Pipeline in action

#### First, let's populate these tables with some mock data

Let's create 4 sessions with mock data.

Sessions 1 and 2 will use method one for tracking, and sessions 3 and 4 will use method two.

In [None]:
Session.insert([('subject1', 1), ('subject1', 2), ('subject1', 3), ('subject1', 4)])

In [None]:
MethodOneTrackingRaw.insert([('subject1', 1, np.random.randn(10)), ('subject1', 2, np.random.randn(10))], allow_direct_insert=True)
MethodOneProcessing.insert([('subject1', 1, np.random.randn(10)), ('subject1', 2, np.random.randn(10))], allow_direct_insert=True)

In [None]:
MethodTwoTrackingRaw.insert([('subject1', 3, np.random.randn(10)), ('subject1', 4, np.random.randn(10))], allow_direct_insert=True)
MethodTwoProcessing.insert([('subject1', 3, np.random.randn(10)), ('subject1', 4, np.random.randn(10))], allow_direct_insert=True)

In [None]:
MethodTwoFiltering.insert([('subject1', 3, 0, np.random.randn(10)), 
                           ('subject1', 4, 0, np.random.randn(10))], allow_direct_insert=True)

In [None]:
MethodOneTracking.insert([('subject1', 1, np.random.randn(10), np.random.randn(10), np.arange(10)), 
                          ('subject1', 2, np.random.randn(10), np.random.randn(10), np.arange(10))], allow_direct_insert=True)

In [None]:
MethodTwoTracking.insert([('subject1', 3, 0, np.random.randn(10), np.random.randn(10), np.arange(10)), 
                           ('subject1', 4, 0, np.random.randn(10), np.random.randn(10), np.arange(10))], allow_direct_insert=True)

In [None]:
MethodOneTracking()

In [None]:
MethodTwoTracking()

#### Now, let's generate the corresponding entries in the `MergedTracking` table

In [None]:
method_one_entries = MethodOneTracking.fetch('KEY')
method_two_entries = MethodTwoTracking.fetch('KEY')

In [None]:
MergedTracking.insert(method_one_entries)

In [None]:
MergedTracking()

In [None]:
MergedTracking.insert(method_two_entries)

In [None]:
MergedTracking()

#### Using `.all_joined`

In [None]:
MergedTracking().all_joined

#### A few more example queries

In [None]:
MergedTracking().all_joined & 'animal_name = "subject1"' & 'session_number = 3'

In [None]:
MethodOneTrackingRaw & (MergedTracking().all_joined & 'animal_name = "subject1"' & 'session_number = 3')

In [None]:
MethodTwoTrackingRaw & (MergedTracking().all_joined & 'animal_name = "subject1"' & 'session_number = 3')

In [None]:
Speed & (MergedTracking().all_joined & 'animal_name = "subject1"' & 'session_number = 3')