# NNEinFact: Near-Universal Nonnegative Einsum Factorization
This notebook demonstrates how to fit a custom spatiotemporal tensor factorization model to **Uber pickup data** in New York City. Using **NNEinFact**, we approximate a 5-mode count tensor $Y \in \mathbb{N}_0^{27 \times 24 \times 7 \times 100 \times 100}$ indexed by week, hour of day, day of week, and spatial indices $i, j$ corresponding to latitude and longitude.

### Model Specification
We specify a custom factorization with $R$ temporal components (indexed by $r$). Each temporal component has its own spatial factorization with $K$ components (indexed by $k$):

**Model String:** `'wr,hr,dr,ikr,jkr->whdij'`

This corresponds to the element-wise form:
$$\hat{y}_{whdij} = \sum_{r=1}^{R} \theta_{wr}^{(1)} \theta_{hr}^{(2)} \theta_{dr}^{(3)} \sum_{k=1}^{K} \theta_{ikr}^{(4)} \theta_{jkr}^{(5)}$$
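To see that the element-wise form and the model string describe the same contraction, the NumPy sketch below contracts small random nonnegative factors with `np.einsum` in one call, and again in two stages: a temporal product per component $r$, followed by a spatial map per $r$ with rank at most $K$. The sizes are arbitrary illustrative values, not the Uber dimensions, and the sketch does not call NNEinFact.

```python
import numpy as np

# Small, hypothetical sizes chosen only for illustration (not the Uber data).
W, H, D, I, J, K, R = 3, 4, 2, 5, 6, 2, 3
rng = np.random.default_rng(0)

# One nonnegative factor per input term of the model string.
th1 = rng.random((W, R))     # theta^(1)_{wr}
th2 = rng.random((H, R))     # theta^(2)_{hr}
th3 = rng.random((D, R))     # theta^(3)_{dr}
th4 = rng.random((I, K, R))  # theta^(4)_{ikr}
th5 = rng.random((J, K, R))  # theta^(5)_{jkr}

# One-shot contraction straight from the model string.
y_einsum = np.einsum('wr,hr,dr,ikr,jkr->whdij', th1, th2, th3, th4, th5)

# Two-stage version of the same sum.
spatial = np.einsum('ikr,jkr->ijr', th4, th5)          # inner sum over k, shape (I, J, R)
temporal = np.einsum('wr,hr,dr->whdr', th1, th2, th3)  # temporal product, shape (W, H, D, R)
y_two_stage = np.einsum('whdr,ijr->whdij', temporal, spatial)  # outer sum over r

print(np.allclose(y_einsum, y_two_stage))
# Each per-component spatial map is a sum of K outer products, so its rank is at most K.
print([int(np.linalg.matrix_rank(spatial[:, :, r])) for r in range(R)])
```

Splitting the sum this way makes the structure explicit: every temporal component $r$ carries its own low-rank spatial map, rather than all components sharing one spatial basis.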


```python
import numpy as np
from einfact import NNEinFact

Y = np.load('data/Y.npz')['Y']  # Uber data: Y.shape == (27, 24, 7, 100, 100)

model_str = 'wr,hr,dr,ikr,jkr->whdij'  # example model from the paper

# Sizes of the observed modes come from Y.shape; the latent modes k and r are set by hand.
shape_dict = {**dict(zip(model_str.split('->')[-1], Y.shape)), 'k': 6, 'r': 10}

model = NNEinFact(model_str, shape_dict=shape_dict, device='cuda', alpha=0.5, beta=0.5)

model.fit(Y, verbose=True)

# Estimated factors, unpacked in the order of the input terms: wr, hr, dr, ikr, jkr
Theta_WK, Theta_HK, Theta_DK, Theta_IKR, Theta_JKR = model.get_params()
```
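The model string and `shape_dict` together determine the shape of every factor. The pure-Python sketch below derives those shapes from the string. It assumes that `get_params()` returns the factors in the order of the input terms, as the unpacking above implies; this order has not been checked against the NNEinFact source. The sketch also compares the number of free parameters with the number of tensor entries.

```python
import math

model_str = 'wr,hr,dr,ikr,jkr->whdij'
Y_shape = (27, 24, 7, 100, 100)  # shape of the Uber tensor stated above

# Same construction as in the cell above: observed sizes from the output term, latent sizes by hand.
shape_dict = {**dict(zip(model_str.split('->')[-1], Y_shape)), 'k': 6, 'r': 10}

inputs, output = model_str.split('->')
terms = inputs.split(',')
factor_shapes = [tuple(shape_dict[c] for c in term) for term in terms]

n_params = sum(math.prod(s) for s in factor_shapes)
n_entries = math.prod(shape_dict[c] for c in output)

for term, s in zip(terms, factor_shapes):
    print(term, s)
print(f'{n_params} parameters vs {n_entries} tensor entries')
# -> 12580 parameters vs 45360000 tensor entries
```

Almost all of the parameters sit in the two spatial factors ($100 \times 6 \times 10$ each), so the spatial choice of `k` largely determines the model's capacity.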

### Visualizing Latent Structure
After training, we can visualize the estimated parameters $\{\Theta^{(l)}\}_{l=1}^L$ (here $L = 5$) to extract interpretable patterns. For example, the factor matrix for the **Hour of Day** mode shows which hours load on the same latent components, so hours with similar pickup patterns appear grouped together.

`get_params()` retrieves these factors directly from the device (CPU/GPU) as NumPy arrays.

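```python
from matplotlib import pyplot as plt

R = Theta_HK.shape[1]  # number of temporal components r (10 here)

plt.imshow(Theta_HK, aspect='auto')  # rows: hours 0-23, columns: components r
plt.xticks(range(R), labels=[f'{i}' for i in range(1, R + 1)])
plt.xlabel('Latent components (r)')
plt.ylabel('Hour of day')
# Label every third hour, starting at 1 AM
plt.yticks(range(1, 24, 3), labels=[f'{i} AM' if i < 12 else (f'{i-12} PM' if i > 12 else '12 PM') for i in range(1, 24, 3)])
plt.colorbar(label='Loading')
plt.show()
```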
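Column scales in the hour factor are not identifiable. Multiplying column $r$ of the hour factor by $c > 0$ and column $r$ of the week factor by $1/c$ leaves $\hat{y}$ unchanged. A simple way to make the grouping of hours explicit is therefore to normalize each column first and then assign every hour to its dominant component. The sketch below does this with NumPy. It uses a synthetic stand-in array (`theta_h`, a hypothetical name) so that it runs on its own. The argmax assignment is one heuristic among many and is not part of NNEinFact.

```python
import numpy as np

# Synthetic stand-in for Theta_HK (24 hours x R components); substitute the fitted array.
rng = np.random.default_rng(1)
theta_h = rng.random((24, 10))

# Remove the arbitrary per-column scale: each component's loadings sum to 1 over hours.
theta_h_norm = theta_h / theta_h.sum(axis=0, keepdims=True)

# Assign each hour to the component with the largest normalized loading.
dominant = theta_h_norm.argmax(axis=1)

# Group hours by their dominant component (components are 1-indexed in the printout).
clusters = {int(r): np.flatnonzero(dominant == r).tolist() for r in np.unique(dominant)}
for r, hours in clusters.items():
    print(f'component {r + 1}: hours {hours}')
```

With the fitted factor, contiguous runs of hours that share a dominant component correspond to the horizontal bands visible in the heatmap.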