# Numba Typed Dictionaries

Examples of using Numba's typed dictionaries with JIT-compiled functions.

In [1]:
import numpy as np
import hashlib
from numba import njit, types
from numba.typed import Dict
import time

# Import loader for real data experiments
from specparser.amt import loader


# --- Hash and fix functions (from schedules.py) ---

def _underlying_hash(underlying: str) -> int:
    """
    Compute a deterministic integer hash from an underlying string.
    Uses MD5 to get a consistent hash across runs, then takes modulo
    to get a small integer suitable for the fix calculation.
    """
    h = hashlib.md5(underlying.encode()).hexdigest()
    return int(h[:8], 16) % 1000000


def _fix_value(value: str, assid: int, schcnt: int, schid: int) -> str:
    """Fix a/b/c/d values to computed day numbers."""
    if value not in ("a", "b", "c", "d"):
        return value
    if schcnt > 0:
        day_offset = int(assid % 5 + 1)
        day_stride = int(20 / (schcnt + 1))
        fixed = int(schid - 1) * day_stride + day_offset
    else:
        fixed = int(assid % 5 + 1)
    return str(fixed)

## 1. Basic Typed Dict

Create and populate a typed dictionary from regular (non-JIT) Python code.

In [2]:
# Typed dict: int64 keys -> float64 values
d = Dict.empty(key_type=types.int64, value_type=types.float64)

# Fill it from ordinary Python code
for key, value in ((1, 1.5), (2, 2.5), (3, 3.5)):
    d[key] = value

print(f"Dict: {dict(d)}")
print(f"d[2] = {d[2]}")
print(f"len(d) = {len(d)}")

Dict: {1: 1.5, 2: 2.5, 3: 3.5}
d[2] = 2.5
len(d) = 3


## 2. JIT Function with Dict Input

Pass a typed dict to a JIT-compiled function.

In [3]:
@njit
def sum_dict_values(d):
    """Return the sum of every value stored in the typed dict `d`."""
    acc = 0.0
    for value in d.values():
        acc += value
    return acc

@njit
def lookup_with_default(d, key, default):
    """Return `d[key]` when `key` is present in `d`, otherwise `default`."""
    return d[key] if key in d else default

# Exercise both JIT functions on the typed dict `d` (keys 1, 2, 3).
print(f"Sum of values: {sum_dict_values(d)}")
print(f"d[2] = {lookup_with_default(d, 2, -1.0)}")
# Key 99 was never inserted, so the supplied default (-1.0) comes back.
print(f"d[99] = {lookup_with_default(d, 99, -1.0)}")

Sum of values: 7.5
d[2] = 2.5
d[99] = -1.0


## 3. JIT Function that Creates Dict

Build a typed dict inside a JIT-compiled function.

In [4]:
@njit
def create_squares_dict(n):
    """Build and return a typed dict mapping each i in range(n) to i * i."""
    table = Dict.empty(key_type=types.int64, value_type=types.int64)
    for value in range(n):
        table[value] = value * value
    return table

# Build the typed dict inside JIT code, then copy it into a plain dict for display.
squares = create_squares_dict(10)
print(f"Squares: {dict(squares)}")

Squares: {0: 0, 1: 1, 2: 4, 3: 9, 4: 16, 5: 25, 6: 36, 7: 49, 8: 64, 9: 81}


## 4. Dict with String Keys

Using unicode strings as keys (requires `types.unicode_type`).

In [5]:
# Typed dict keyed by unicode strings, holding float64 values
str_dict = Dict.empty(key_type=types.unicode_type, value_type=types.float64)

for fruit, amount in (("apple", 1.0), ("banana", 2.0), ("cherry", 3.0)):
    str_dict[fruit] = amount

print(f"str_dict['banana'] = {str_dict['banana']}")

str_dict['banana'] = 2.0


In [6]:
@njit
def lookup_string_key(d, key):
    """Return the value stored under string `key`, or -1.0 when it is absent."""
    return d[key] if key in d else -1.0

# 'apple' is a key of str_dict; 'missing' is not, so the second call returns the -1.0 sentinel.
print(f"lookup('apple') = {lookup_string_key(str_dict, 'apple')}")
print(f"lookup('missing') = {lookup_string_key(str_dict, 'missing')}")

lookup('apple') = 1.0
lookup('missing') = -1.0


## 5. Performance Comparison

Compare lookup speed: Numba typed dict vs Python dict.

In [7]:
# Create large dicts
N = 100_000

# Python dict: i -> float(i)
py_dict = {i: float(i) for i in range(N)}

# Numba typed dict with the same contents
nb_dict = Dict.empty(key_type=types.int64, value_type=types.float64)
for i in range(N):
    nb_dict[i] = float(i)

# Keys to look up. A seeded generator is used so that a re-run of the
# notebook benchmarks exactly the same workload.
rng = np.random.default_rng(42)
keys = rng.integers(0, N, size=1_000_000)

In [8]:
@njit
def numba_lookups(d, keys):
    """Look up every entry of `keys` in the typed dict `d` and return the summed values."""
    acc = 0.0
    for pos in range(len(keys)):
        acc += d[keys[pos]]
    return acc

def python_lookups(d, keys):
    """Look up every entry of `keys` in the plain dict `d` and return the summed values."""
    running = 0.0
    for key in keys:
        running = running + d[key]
    return running

# Warmup: call the JIT function once on a small slice of the keys before the timed run
# (presumably so that compilation time is not counted in the benchmark).
_ = numba_lookups(nb_dict, keys[:100])

In [9]:
# Benchmark Python dict
# Wall-clock time for one pass over all of `keys` through the interpreted loop.
t0 = time.perf_counter()
py_result = python_lookups(py_dict, keys)
py_time = time.perf_counter() - t0
print(f"Python dict: {py_time:.3f}s")

# Benchmark Numba typed dict
# Same keys through the JIT function, which was already called once in the warmup cell.
t0 = time.perf_counter()
nb_result = numba_lookups(nb_dict, keys)
nb_time = time.perf_counter() - t0
print(f"Numba dict:  {nb_time:.3f}s")

# Ratio of the two wall-clock times, plus a check that both paths computed the same sum.
print(f"\nSpeedup: {py_time / nb_time:.1f}x")
print(f"Results match: {abs(py_result - nb_result) < 1e-6}")

Python dict: 0.094s
Numba dict:  0.009s

Speedup: 10.8x
Results match: True


## 6. Real Data: Asset Lookup Dict

Use the loader module to build a lookup dict from AMT data.

In [10]:
# Load AMT data
# The path is relative, so it resolves against the notebook's working directory.
amt_path = "../data/amt.yml"
amt_data = loader.load_amt(amt_path)

# Get all assets
# NOTE(review): `assets_table` is indexed with ['rows'] below, so it is presumably a
# dict-like table holding a 'rows' sequence — confirm against loader.assets.
assets_table = loader.assets(amt_path, live_only=True)
print(f"Found {len(assets_table['rows'])} live assets")

Found 189 live assets


In [11]:
# Typed dict mapping an asset name (first field of each row) to its row position
asset_to_idx = Dict.empty(key_type=types.unicode_type, value_type=types.int64)

for i, row in enumerate(assets_table['rows']):
    asset_to_idx[row[0]] = i

print(f"Built lookup for {len(asset_to_idx)} assets")
print(f"First 5: {list(asset_to_idx.keys())[:5]}")

Built lookup for 189 assets
First 5: ['LA Comdty', 'LP Comdty', 'LX Comdty', 'LA Comdty TKRZ', 'LP Comdty TKRZ']


In [12]:
@njit
def get_asset_index(lookup, name):
    """Return the index stored for `name` in `lookup`, or -1 when the name is unknown."""
    return lookup[name] if name in lookup else -1

# Probe the lookup with one name that exists (the first key) and one that does not
test_asset = next(iter(asset_to_idx.keys()))
print(f"Index of '{test_asset}': {get_asset_index(asset_to_idx, test_asset)}")
print(f"Index of 'UNKNOWN': {get_asset_index(asset_to_idx, 'UNKNOWN')}")

Index of 'LA Comdty': 0
Index of 'UNKNOWN': -1


## 7. Nested Data: Dict of Arrays

Store NumPy arrays as dict values (the value type must be a specific array type, e.g. `types.float64[:]`).

In [13]:
from numba.core import types as nb_types

# Typed dict: int64 key -> 1D float64 array
array_dict = Dict.empty(key_type=types.int64, value_type=types.float64[:])

# Each key holds an array of a different length
for key, values in enumerate(([1.0, 2.0, 3.0], [10.0, 20.0], [100.0])):
    array_dict[key] = np.array(values)

print(f"array_dict[0] = {array_dict[0]}")
print(f"array_dict[1] = {array_dict[1]}")

array_dict[0] = [1. 2. 3.]
array_dict[1] = [10. 20.]


In [54]:
from pathlib import Path
import yaml

# Read the raw YAML directly; `amt_path` is the same file the loader cell uses.
with open(amt_path, "r") as f:
    run_options = yaml.safe_load(f)
amt = run_options.get("amt", {})
expiry_schedules = run_options.get("expiry_schedules")

# Schedule entries of one asset, stored as a fixed-width (32-character) unicode array
s1 = expiry_schedules[amt["EURUSD Curncy"]["Options"]]
s1v = np.array(s1, dtype="<U32")
print(type(s1v))

# Dict mapping str -> 1D contiguous array of fixed-width U32 strings
val_type = types.Array(types.UnicodeCharSeq(32), 1, 'C')

# Kept under its own name: `array_dict` (int64 -> float64[:]) is still needed
# by the sum_array_at_key cell and must not be overwritten here.
str_array_dict = Dict.empty(
    key_type=types.unicode_type,
    value_type=val_type,
)

# Populate with arrays of different sizes
str_array_dict["0"] = s1v
str_array_dict["1"] = np.array(["that"], dtype="<U32")
str_array_dict["2"] = np.array(["there", "1", "123", "0123456789"], dtype="<U32")

# Single quotes inside the f-strings keep this cell valid on Python < 3.12.
print(f"array_dict[0] = {str_array_dict['0']}")
print(f"array_dict[1] = {str_array_dict['1']}")
print(f"array_dict[2] = {str_array_dict['2']}")


<class 'numpy.ndarray'>
array_dict[0] = ['N0_BDa_25' 'N0_BDb_25' 'N0_BDc_25' 'N0_BDd_25']
array_dict[1] = ['that']
array_dict[2] = ['there' '1' '123' '0123456789']


In [14]:
@njit
def sum_array_at_key(d, key):
    """Return the sum of the array stored under `key`, or 0.0 when `key` is absent."""
    acc = 0.0
    if key in d:
        for element in d[key]:
            acc += element
    return acc

# Integer keys: these calls expect `array_dict` to map int64 keys to float64 arrays.
print(f"Sum at key 0: {sum_array_at_key(array_dict, 0)}")
print(f"Sum at key 1: {sum_array_at_key(array_dict, 1)}")
# Key 99 is not present, so the function falls through to its 0.0 return.
print(f"Sum at key 99: {sum_array_at_key(array_dict, 99)}")

Sum at key 0: 6.0
Sum at key 1: 30.0
Sum at key 99: 0.0


## 8. Tuple Keys

Using tuples as dictionary keys.

In [15]:
# `types` is already imported in the notebook's first cell, so it is not re-imported here.

# Key type: a homogeneous 2-tuple of int64, i.e. (int64, int64)
tuple_key_type = types.UniTuple(types.int64, 2)

coord_dict = Dict.empty(
    key_type=tuple_key_type,
    value_type=types.float64,
)

# Populate the four corners of a 2x2 grid
coord_dict[(0, 0)] = 1.0
coord_dict[(1, 0)] = 2.0
coord_dict[(0, 1)] = 3.0
coord_dict[(1, 1)] = 4.0

print(f"coord_dict[(1, 1)] = {coord_dict[(1, 1)]}")

coord_dict[(1, 1)] = 4.0


In [16]:
@njit
def sparse_matrix_lookup(d, row, col, default):
    """Return the value stored at (row, col), or `default` when that cell is not in `d`."""
    cell = (row, col)
    return d[cell] if cell in d else default

# (1, 1) was populated in coord_dict; (2, 2) was not, so the 0.0 default is returned for it.
print(f"(1, 1) = {sparse_matrix_lookup(coord_dict, 1, 1, 0.0)}")
print(f"(2, 2) = {sparse_matrix_lookup(coord_dict, 2, 2, 0.0)}")

(1, 1) = 4.0
(2, 2) = 0.0


## 9. Asset & Schedule Data Structures

Two separate data structures for asset data:
1. **asset_data_nb**: Numba `Dict[str, str]` - flattened asset fields with dot-paths
2. **schedule_data**: Python dict of numpy StringDType arrays - columnar schedule data

In [17]:
def flatten_dict(d, prefix=""):
    """
    Flatten a nested dict to dot-separated keys, all values as strings.

    Dicts contribute their keys to the path and lists/tuples contribute the
    element index. Containers are expanded at any depth, including a list
    nested directly inside another list. Empty containers produce no entries.

    Args:
        d: The dict to flatten.
        prefix: Path prepended (with a dot) to every top-level key.

    Returns:
        A flat dict mapping each dot-path to ``str(leaf_value)``.

    Examples:
        {"a": 1, "b": {"c": 2}} -> {"a": "1", "b.c": "2"}
        {"x": [10, 20, 30]} -> {"x.0": "10", "x.1": "20", "x.2": "30"}
        {"m": [[1, 2]]} -> {"m.0.0": "1", "m.0.1": "2"}
    """
    flat = {}

    def _walk(node, path):
        # Record a leaf, or descend into a container, extending the dot-path.
        if isinstance(node, dict):
            for key, child in node.items():
                _walk(child, f"{path}.{key}" if path else str(key))
        elif isinstance(node, (list, tuple)):
            for idx, child in enumerate(node):
                _walk(child, f"{path}.{idx}")
        else:
            flat[path] = str(node)

    # The top level must be a mapping; iterate it here so non-dicts still fail loudly.
    for key, value in d.items():
        _walk(value, f"{prefix}.{key}" if prefix else str(key))
    return flat


def dict_to_numba(py_dict):
    """Copy `py_dict` into a new Numba typed dict with unicode keys and unicode values."""
    typed = Dict.empty(key_type=types.unicode_type, value_type=types.unicode_type)
    for key in py_dict:
        typed[key] = py_dict[key]
    return typed

In [28]:
# Get asset data for a specific asset
test_asset = "EURUSD Curncy"
asset_data = loader.get_asset(amt_path, test_asset)

# Get schedule directly from AMT data (inline, no schedules module)
schedule_name = asset_data.get("Options")
expiry_schedules = amt_data.get("expiry_schedules", {})
raw_schedule = expiry_schedules.get(schedule_name, []) if schedule_name else []

# Compute asset hash for _fix_value
assid = _underlying_hash(test_asset)
schcnt = len(raw_schedule)
print(f"Asset: {test_asset}")
print(f"Asset hash (assid): {assid}")
print(f"Schedule count: {schcnt}")
print(f"Schedule name: {schedule_name}")
print(f"Raw schedule: {raw_schedule}")
print()

# Parse schedule into columnar format: one string vector per field, one slot per entry.
# Entry layout as parsed below: "<ntrc><ntrv>_<xprc><xprv>_<wgt>", e.g. "N0_BDa_25"
schcnt_vec = np.full(schcnt, str(schcnt), dtype=np.dtypes.StringDType())
schid_vec  = np.full(schcnt, "", dtype=np.dtypes.StringDType())
ntrc_vec   = np.full(schcnt, "", dtype=np.dtypes.StringDType())
ntrv_vec   = np.full(schcnt, "", dtype=np.dtypes.StringDType())
xprc_vec   = np.full(schcnt, "", dtype=np.dtypes.StringDType())
xprv_vec   = np.full(schcnt, "", dtype=np.dtypes.StringDType())
wgt_vec    = np.full(schcnt, "", dtype=np.dtypes.StringDType())

for i, entry in enumerate(raw_schedule):
    schid_vec[i] = str(i)
    schid = i + 1  # 1-based index for fix_value
    print(f"{entry}")
    parts = entry.split("_")
    if len(parts) < 3:
        # Malformed entry: keep its schid but leave the remaining columns empty.
        continue

    # Entry part: first character is ntrc (a letter like "N"), the rest is ntrv
    ntrc_vec[i] = parts[0][:1]
    ntrv_vec[i] = parts[0][1:]

    # Expiry part: split into a code (xprc) and a value (xprv)
    expiry_part = parts[1]
    if expiry_part == "OVERRIDE":
        xprc_vec[i] = "OVERRIDE"
        xprv_vec[i] = ""
    elif expiry_part[:2] == "BD":
        # BD schedule: xprc="BD"; the a/b/c/d letter after "BD" is resolved to a day number
        xprc_vec[i] = "BD"
        xprv_vec[i] = _fix_value(expiry_part[2:], assid, schcnt, schid)
    elif expiry_part[:1] in ("F", "R", "W"):
        # Slicing (not indexing) so an empty expiry part falls through instead of raising
        xprc_vec[i] = expiry_part[0]
        xprv_vec[i] = expiry_part[1:]
    else:
        xprc_vec[i] = expiry_part
        xprv_vec[i] = ""

    wgt_vec[i] = parts[2]

# Print the columns as a pipe-separated table; the header names exactly the fields printed.
print()
print("schcnt|schid|ntrc|ntrv|xprc|xprv|wgt")
print("-" * 60)
for i in range(schcnt):
    print(
        f"{schcnt_vec[i]}|"
        f"{schid_vec[i]}|"
        f"{ntrc_vec[i]}|"
        f"{ntrv_vec[i]}|"
        f"{xprc_vec[i]}|"
        f"{xprv_vec[i]}|"
        f"{wgt_vec[i]}"
    )

Asset: EURUSD Curncy
Asset hash (assid): 479570
Schedule count: 4
Schedule name: schedule1
Raw schedule: ['N0_BDa_25', 'N0_BDb_25', 'N0_BDc_25', 'N0_BDd_25']

N0_BDa_25
N0_BDb_25
N0_BDc_25
N0_BDd_25

schcnt|schid|asset|ntrc|ntrv|xprc|xprv|wgt
------------------------------------------------------------
4|0|N|0|BD|1|25
4|1|N|0|BD|5|25
4|2|N|0|BD|9|25
4|3|N|0|BD|13|25
