# Custom Codecs

This tutorial covers extending DataJoint's type system. You'll learn:

- **Codec basics** — Encoding and decoding
- **Creating codecs** — Domain-specific types
- **Codec chaining** — Composing codecs (both examples below build on the built-in `<blob>` codec)

In [1]:
import datajoint as dj
import numpy as np

# Every table decorated with @schema below is declared in this schema;
# the final cell drops it again so the tutorial leaves nothing behind.
schema = dj.Schema('tutorial_codecs')

[2026-01-23 17:40:02,849][INFO]: DataJoint 2.1.0a6 connected to root@127.0.0.1:3306


## Creating a Custom Codec

In [2]:
import networkx as nx

class GraphCodec(dj.Codec):
    """Store NetworkX graphs as ``<graph>`` attributes.

    The graph is flattened into a dict holding its kind (Graph, DiGraph,
    MultiGraph or MultiDiGraph), its graph-level attributes, and its node and
    edge lists with their data dicts. That dict is handed to the ``<blob>``
    codec. Decoding rebuilds a graph of the same kind, so directed graphs and
    multigraphs round-trip instead of silently becoming undirected.
    """

    name = "graph"  # Use as <graph>

    def get_dtype(self, is_store: bool) -> str:
        # Chain onto the built-in <blob> codec for the actual serialization.
        return "<blob>"

    def encode(self, value, *, key=None, store_name=None):
        is_multi = value.is_multigraph()
        kind = ("Multi" if is_multi else "") + ("DiGraph" if value.is_directed() else "Graph")
        # Multigraph edges carry a key that distinguishes parallel edges.
        edges = value.edges(keys=True, data=True) if is_multi else value.edges(data=True)
        return {
            'kind': kind,
            'graph': dict(value.graph),
            'nodes': list(value.nodes(data=True)),
            'edges': list(edges),
        }

    def decode(self, stored, *, key=None):
        graph_types = {
            'Graph': nx.Graph,
            'DiGraph': nx.DiGraph,
            'MultiGraph': nx.MultiGraph,
            'MultiDiGraph': nx.MultiDiGraph,
        }
        # Values stored before 'kind'/'graph' were recorded decode as a plain nx.Graph.
        g = graph_types[stored.get('kind', 'Graph')]()
        g.graph.update(stored.get('graph', {}))
        g.add_nodes_from(stored['nodes'])
        g.add_edges_from(stored['edges'])
        return g

    def validate(self, value):
        # nx.DiGraph / nx.MultiGraph / nx.MultiDiGraph are all nx.Graph subclasses.
        if not isinstance(value, nx.Graph):
            raise TypeError(f"Expected nx.Graph, got {type(value).__name__}")

In [3]:
@schema
class Connectivity(dj.Manual):
    """One network graph per ``conn_id``.

    The ``network`` attribute uses the ``<graph>`` type provided by
    GraphCodec (whose ``name`` is "graph").
    """
    definition = """
    conn_id : int32
    ---
    network : <graph>
    """

In [4]:
# Create and insert a triangle graph. skip_duplicates=True keeps this cell
# re-runnable: a second run would otherwise fail on the existing conn_id=1 row.
g = nx.Graph()
g.add_edges_from([(1, 2), (2, 3), (1, 3)])
Connectivity.insert1({'conn_id': 1, 'network': g}, skip_duplicates=True)

# Fetch: the <graph> codec decodes the stored value back into a NetworkX graph
result = (Connectivity & {'conn_id': 1}).fetch1('network')
print(f"Type: {type(result)}")
print(f"Edges: {list(result.edges())}")

Type: <class 'networkx.classes.graph.Graph'>
Edges: [(1, 2), (1, 3), (2, 3)]


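## Codec Structure

A codec subclasses `dj.Codec`, sets `name` (the type is then referenced as `<name>` in table definitions), and implements the methods below. Returning another codec type such as `"<blob>"` from `get_dtype` chains the two codecs: the output of your `encode` is handed on to that codec for storage.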

```python
class MyCodec(dj.Codec):
    name = "mytype"  # Use as <mytype>
    
    def get_dtype(self, is_store: bool) -> str:
        return "<blob>"  # Storage type
    
    def encode(self, value, *, key=None, store_name=None):
        return serializable_data
    
    def decode(self, stored, *, key=None):
        return python_object
    
    def validate(self, value):  # Optional
        pass
```

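## Example: Spike Train

Codecs can also wrap your own classes. Here, a dataclass bundles an array of spike times with the unit's ID and a quality label, and is stored as a `<spike_train>` attribute.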

In [5]:
from dataclasses import dataclass

@dataclass
class SpikeTrain:
    """Spike times of a single unit, together with its ID and a quality label."""

    times: np.ndarray  # spike timestamps, one entry per spike
    unit_id: int  # identifier of the unit the spikes belong to
    quality: str  # free-form quality label, e.g. 'good'

class SpikeTrainCodec(dj.Codec):
    """Store ``SpikeTrain`` objects as ``<spike_train>`` attributes.

    The dataclass is flattened into a dict of its three fields and passed to
    the ``<blob>`` codec; decoding rebuilds a ``SpikeTrain`` from that dict.
    """

    name = "spike_train"  # Use as <spike_train>

    def get_dtype(self, is_store: bool) -> str:
        # Chain onto the built-in <blob> codec for the actual serialization.
        return "<blob>"

    def encode(self, value, *, key=None, store_name=None):
        return {'times': value.times, 'unit_id': value.unit_id, 'quality': value.quality}

    def decode(self, stored, *, key=None):
        return SpikeTrain(times=stored['times'], unit_id=stored['unit_id'], quality=stored['quality'])

    def validate(self, value):
        # Without this check a wrong type would only surface as an
        # AttributeError inside encode(); mirrors GraphCodec.validate.
        if not isinstance(value, SpikeTrain):
            raise TypeError(f"Expected SpikeTrain, got {type(value).__name__}")

In [6]:
@schema
class Unit(dj.Manual):
    """One spike train per ``unit_id``.

    The ``spikes`` attribute uses the ``<spike_train>`` type provided by
    SpikeTrainCodec (whose ``name`` is "spike_train").
    """
    definition = """
    unit_id : int32
    ---
    spikes : <spike_train>
    """

# Seeded generator so the stored spike times are reproducible across runs:
# 50 sorted spike times drawn uniformly from [0, 100).
rng = np.random.default_rng(0)
train = SpikeTrain(times=np.sort(rng.uniform(0, 100, 50)), unit_id=1, quality='good')
# skip_duplicates=True keeps the cell re-runnable (unit_id=1 already exists on a second run).
Unit.insert1({'unit_id': 1, 'spikes': train}, skip_duplicates=True)

result = (Unit & {'unit_id': 1}).fetch1('spikes')
print(f"Type: {type(result)}, Spikes: {len(result.times)}")

Type: <class '__main__.SpikeTrain'>, Spikes: 50


In [7]:
# Clean up: drop the tutorial schema (and its tables) without a confirmation prompt.
schema.drop(prompt=False)