# **Temporal Transformer**
## 1. Motivation

The **Temporal Transformer** extends the **Vision Transformer (ViT)** or **Spatial Transformer** idea into the **time domain**.
If ViT handles *spatial* relationships between image patches in a single frame,
the Temporal Transformer handles *temporal* relationships between frames in a **video sequence**.

In a video, each frame has:

* **Spatial features** — what’s inside the frame (objects, shapes, textures)
* **Temporal features** — how these change over time (motion, actions)

A Temporal Transformer captures **how things evolve over time**, similar to how ViT captures how patches relate within space.

---

## 2. Basic Idea

Suppose you extract features from each frame using a CNN or ViT encoder.
Then you have a sequence of tokens representing time:

$$
X = [x_1, x_2, x_3, \ldots, x_T], \quad x_t \in \mathbb{R}^D
$$

Each $x_t$ encodes the spatial information of frame $t$.

The **Temporal Transformer** applies **self-attention over time**, learning how each frame relates to others.

---

## 3. Temporal Self-Attention

Just like standard attention, we compute queries, keys, and values:

$$
Q = XW^Q, \quad K = XW^K, \quad V = XW^V
$$

Then temporal attention is computed as:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

Here, the attention matrix is of size $T \times T$ — each frame attends to all others.
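
The two formulas above can be sketched directly in a few lines. This is a minimal illustration, not a reference implementation: the sizes and the random projection matrices `W_q`, `W_k`, `W_v` are assumptions chosen for the example (in a real model they are learned).

```python
import math
import torch

torch.manual_seed(0)

T, D, d_k = 4, 8, 8            # illustrative sizes (assumed, not from the text)
X = torch.randn(T, D)          # one feature vector per frame

# Stand-ins for W^Q, W^K, W^V; random here, learned in practice
W_q = torch.randn(D, d_k)
W_k = torch.randn(D, d_k)
W_v = torch.randn(D, d_k)

Q, K, V = X @ W_q, X @ W_k, X @ W_v

scores = Q @ K.T / math.sqrt(d_k)         # [T, T]: frame-to-frame similarity
weights = torch.softmax(scores, dim=-1)   # each row is a distribution over frames
out = weights @ V                         # [T, d_k]: time-mixed frame features

print(weights.shape, out.shape)
```

Row `t` of `weights` says how much frame `t` draws from every frame in the clip, itself included.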

---

## 4. Combining with Spatial Transformers

There are **two main architectures** for handling video with transformers:

### (a) Factorized (Space + Time separately)

* **Step 1:** Apply a **spatial transformer** on each frame independently.
  This captures spatial relationships.
* **Step 2:** Apply a **temporal transformer** across the resulting frame embeddings.
  This captures motion and temporal dynamics.

Mathematically:

$$
X' = \text{SpatialTransformer}(X)
$$

$$
Y = \text{TemporalTransformer}(X')
$$

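Examples: the factorised encoder in **ViViT** (Arnab et al., 2021) follows this two-stage layout. **TimeSformer** (Bertasius et al., 2021) uses a related *divided* scheme that interleaves temporal and spatial attention inside every block.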

---

### (b) Joint Space–Time Attention

The transformer attends jointly to space and time dimensions using 3D tokens.
This is more expensive but captures both simultaneously.

---

## 5. Temporal Positional Encoding

Just as ViT adds **positional embeddings** to patches in space,
Temporal Transformers add **temporal embeddings** to indicate *when* each frame occurs.

$$
\tilde{x}_t = x_t + P_t
$$

where $P_t$ is a learned vector encoding the frame index.
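
A learned temporal embedding could look roughly like this. The class name `TemporalPositionalEmbedding`, the `max_frames` parameter, and the initialization scale are assumptions made for the sketch.

```python
import torch
import torch.nn as nn

class TemporalPositionalEmbedding(nn.Module):
    """Adds one learned vector per frame index (hypothetical helper)."""

    def __init__(self, max_frames, dim):
        super().__init__()
        # One row per frame index; row t plays the role of P_t
        self.pos = nn.Parameter(torch.zeros(max_frames, dim))
        nn.init.trunc_normal_(self.pos, std=0.02)

    def forward(self, x):
        # x: [B, T, D], one feature vector per frame
        T = x.shape[1]
        return x + self.pos[:T]   # broadcasts over the batch dimension

embed = TemporalPositionalEmbedding(max_frames=16, dim=8)
x = torch.randn(2, 4, 8)
y = embed(x)
print(y.shape)
```

Without such an embedding, self-attention is permutation-equivariant: shuffling the frames would only shuffle the outputs, so the model could not tell the orderings apart.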

---

## 6. Example: TimeSformer (2021)

**TimeSformer** uses divided space–time attention:

1. Split the video into **T frames**, each into **N patches**.
2. Flatten them into tokens of size $(T \times N, D)$.
3. In each block:

   * Apply *temporal attention* among the same spatial patch across frames.
   * Then apply *spatial attention* within each frame.

This factorization reduces the attention cost from $O((T \cdot N)^2)$ for joint attention
to $O(T N (T + N))$: each token attends to $T + N$ tokens instead of all $T \cdot N$.
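
To make the saving concrete, here is a rough count of query-key pairs per attention layer. The values $T = 8$ and $N = 196$ are illustrative assumptions (for example, $14 \times 14$ patches per frame), and the classification token is ignored.

$$
\text{joint: } (T N)^2 = (8 \cdot 196)^2 = 1568^2 = 2{,}458{,}624
$$

$$
\text{divided: } T N (T + N) = 1568 \cdot 204 = 319{,}872
$$

Under these assumptions, divided attention scores about $7.7\times$ fewer pairs, because each token attends to $T + N$ tokens rather than $T N$. Both counts are still quadratic in $T$ and in $N$ individually.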

---

## 7. Applications

- ✅ **Action recognition**
- ✅ **Video classification**
- ✅ **Motion understanding**
- ✅ **Video captioning**
- ✅ **Dynamic NeRFs / Temporal 3D Reconstruction**

---

## 8. Minimal PyTorch-style Pseudocode

```python
import torch
import torch.nn as nn

class TemporalTransformer(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # x: [T, B, D]
        attn_out, _ = self.attn(x, x, x)
        x = x + attn_out
        x = self.norm1(x)
        x = x + self.ffn(x)
        x = self.norm2(x)
        return x
```

Here:

* $T$: number of frames (time steps)
* $B$: batch size
* $D$: embedding dimension
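
For comparison, PyTorch's built-in `nn.TransformerEncoderLayer` follows the same post-norm layout (attention, residual, norm, feed-forward, residual, norm). The sketch below configures it to mirror the block above; the sizes are illustrative assumptions, and dropout is set to zero only because the hand-written block has none.

```python
import torch
import torch.nn as nn

T, B, D = 8, 2, 64   # illustrative sizes (assumed)

layer = nn.TransformerEncoderLayer(
    d_model=D,
    nhead=4,
    dim_feedforward=4 * D,   # same 4x expansion as the hand-written FFN
    dropout=0.0,
    activation="gelu",
    batch_first=False,       # input layout [T, B, D], as in the block above
    norm_first=False,        # post-norm: LayerNorm after each residual
)

x = torch.randn(T, B, D)     # one token per frame
y = layer(x)
print(y.shape)               # torch.Size([8, 2, 64])
```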

---

## Comparing **Spatial**, **Temporal**, and **Spatio-Temporal** Transformers


A **Temporal Transformer** fits into the bigger picture of **video transformers**.
We’ll walk through 3 levels:

1. **Spatial Transformer** (within each frame)
2. **Temporal Transformer** (across frames)
3. **Spatio-Temporal Transformer** (joint space+time attention)

---

## 1. Toy Video Setup

Assume a video of **4 frames** (T = 4).
Each frame has **2 × 2 patches**, so **N = 4 patches per frame**.
Each patch is embedded into a D-dimensional vector (say D = 3 for simplicity).

So the total tokens are:

$$
X \in \mathbb{R}^{T \times N \times D} = \mathbb{R}^{4 \times 4 \times 3}
$$

That means:

| Frame | Patch 1 | Patch 2 | Patch 3 | Patch 4 |
| :---- | :------ | :------ | :------ | :------ |
| F₁    | x₁₁     | x₁₂     | x₁₃     | x₁₄     |
| F₂    | x₂₁     | x₂₂     | x₂₃     | x₂₄     |
| F₃    | x₃₁     | x₃₂     | x₃₃     | x₃₄     |
| F₄    | x₄₁     | x₄₂     | x₄₃     | x₄₄     |
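
As a small sketch, the toy tensor and the two ways of slicing it (by frame and by patch location) can be written as follows. The values are just a running counter so that the slices are easy to recognize; they carry no meaning.

```python
import torch

T, N, D = 4, 4, 3
# Dummy values 0..47, laid out as [frame, patch, feature]
X = torch.arange(T * N * D, dtype=torch.float32).reshape(T, N, D)

frame_1 = X[0]                 # a table row: all patches of F1      -> [N, D]
first_patch_track = X[:, 0]    # a table column: first patch over time -> [T, D]
token = X[1, 2]                # second frame, third patch (x_23)     -> [D]

print(frame_1.shape, first_patch_track.shape, token)
```

Rows of the table are what spatial attention operates on, and columns are what temporal attention operates on.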

---

## 2. Spatial Transformer (per-frame attention)

In each frame, we look **within** that frame only:

$$
\text{SpatialAttention}(x_t) = \text{softmax}\left(\frac{Q_t K_t^\top}{\sqrt{d_k}}\right)V_t
$$

where each $x_t \in \mathbb{R}^{N \times D}$.

Each frame’s patches attend to one another — e.g., patch 1 of F₁ attends to all 4 patches of F₁.

After applying this independently to all frames, we obtain:

$$
X' = [x'_1, x'_2, x'_3, x'_4]
$$

---

### Visualization

```
Frame 1: [●───●───●───●]   → Spatial attention among its 4 patches
Frame 2: [●───●───●───●]   → Spatial attention
Frame 3: [●───●───●───●]
Frame 4: [●───●───●───●]
```

Each row is processed separately.
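
One common way to process every row at once is to fold the time axis into the batch axis. The sketch below assumes a batched input of shape `[B, T, N, D]` with illustrative sizes, and uses `batch_first=True` so that the layer reads its input as `[batch, sequence, dim]`.

```python
import torch
import torch.nn as nn

B, T, N, D = 2, 4, 4, 8        # illustrative sizes (assumed)
x = torch.randn(B, T, N, D)

spatial = nn.MultiheadAttention(D, num_heads=2, batch_first=True)

# Treat every frame as its own sequence of N patches
frames = x.reshape(B * T, N, D)
out, weights = spatial(frames, frames, frames)
x_spatial = out.reshape(B, T, N, D)

print(x_spatial.shape)   # torch.Size([2, 4, 4, 8])
print(weights.shape)     # torch.Size([8, 4, 4]): one N x N map per frame
```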

---

## 3. Temporal Transformer (across frames)

Now we look **across frames** for each patch location.
For example, all patch-1 tokens across time:

$$
[x'_1(1), x'_2(1), x'_3(1), x'_4(1)]
$$

This sequence tells how patch 1 evolves over time (motion, brightness, etc.).

Temporal self-attention:

$$
\text{TemporalAttention}(x'(i)) =
\text{softmax}\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right)V_i
$$

---

### Visualization

```
Patch 1 across time: ●───●───●───●  (motion of same spatial location)
Patch 2 across time: ●───●───●───●
Patch 3 across time: ●───●───●───●
Patch 4 across time: ●───●───●───●
```

Each **column** is processed separately.
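
The column-wise case can be batched the same way, by moving the patch axis next to the batch axis before flattening. Again a sketch with assumed sizes; the input stands in for the output of the spatial stage.

```python
import torch
import torch.nn as nn

B, T, N, D = 2, 4, 4, 8        # illustrative sizes (assumed)
x_spatial = torch.randn(B, T, N, D)   # stand-in for the spatial stage output

temporal = nn.MultiheadAttention(D, num_heads=2, batch_first=True)

# [B, T, N, D] -> [B, N, T, D] -> [B*N, T, D]: one sequence per patch location
tracks = x_spatial.permute(0, 2, 1, 3).reshape(B * N, T, D)
out, weights = temporal(tracks, tracks, tracks)

# Undo the reshaping to recover [B, T, N, D]
x_temporal = out.reshape(B, N, T, D).permute(0, 2, 1, 3)

print(x_temporal.shape)  # torch.Size([2, 4, 4, 8])
print(weights.shape)     # torch.Size([8, 4, 4]): one T x T map per patch location
```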

---

## 4. Combined Space–Time Pipeline

Putting it together (similar in spirit to the divided attention in **TimeSformer**, which interleaves the two stages inside every block):

```
[ Video Frames ]
   ↓
Patch embedding (T × N × D)
   ↓
Spatial Transformer  → captures spatial structure
   ↓
Temporal Transformer → captures motion / dynamics
   ↓
Classification / Prediction Head
```

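This factorization (space first, then time) reduces the attention cost from
$O((T \cdot N)^2)$ to $O(T \cdot N^2 + N \cdot T^2)$: spatial attention costs $N^2$ in each of the $T$ frames, and temporal attention costs $T^2$ at each of the $N$ patch locations.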

---

## 5. PyTorch-like Illustration

```python
import torch
import torch.nn as nn

class MiniSpaceTimeTransformer(nn.Module):
    def __init__(self, dim=3, num_heads=1):
        super().__init__()
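        # batch_first=True so inputs are read as [batch, sequence, dim]
        self.spatial = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.temporal = nn.MultiheadAttention(dim, num_heads, batch_first=True)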

    def forward(self, x):
        # x: [B, T, N, D]
        B, T, N, D = x.shape
        # --- Spatial attention (per frame) ---
        spatial_out = []
        for t in range(T):
            xt = x[:, t]            # [B, N, D]
            yt, _ = self.spatial(xt, xt, xt)
            spatial_out.append(yt)
        x_spatial = torch.stack(spatial_out, dim=1)  # [B, T, N, D]

        # --- Temporal attention (per patch index) ---
        temporal_out = []
        for n in range(N):
            xn = x_spatial[:, :, n]                 # [B, T, D]
            yn, _ = self.temporal(xn, xn, xn)
            temporal_out.append(yn)
        x_temporal = torch.stack(temporal_out, dim=2)  # [B, T, N, D]
        return x_temporal
```

This shows the same factorization logic:

* First loop over frames → spatial attention.
* Then loop over patches → temporal attention.

---


Let’s walk through **MiniSpaceTimeTransformer** step by step.

We’ll go through:

1. The **goal** of the model
2. **Input–output** shapes
3. The **forward pass logic** (why and how each line works)
4. How this approximates real **video transformers** like **TimeSformer**

---

## 1. Goal of the Model

The goal of this toy model is to show how to process a **video** represented as a sequence of **frames** (each containing patch embeddings) using **two stages**:

1. **Spatial attention** — learn relationships *within each frame* (e.g., between patches).
2. **Temporal attention** — learn relationships *across frames* (e.g., motion through time).

This is the essence of a **Temporal Transformer** combined with **Spatial Transformer**.

---

## 2. Input and Output

We assume that before entering this transformer, the video has already been divided into patches and projected into embeddings.

So the input tensor:

$$
x \in \mathbb{R}^{B \times T \times N \times D}
$$

where:

* **B** = batch size (number of videos processed at once)
* **T** = number of frames per video
* **N** = number of patches per frame
* **D** = feature dimension (embedding size of each patch)

**Output** has the same shape:
$$
y \in \mathbb{R}^{B \times T \times N \times D}
$$
but now each token (patch embedding) has been refined by both spatial and temporal attention.
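
The patch-embedding step that is assumed to happen beforehand could be sketched as below. A strided `nn.Conv2d` is one common way to cut frames into non-overlapping patches and project them; the raw video layout `[B, T, C, H, W]`, the patch size `P`, and all sizes are assumptions for illustration.

```python
import torch
import torch.nn as nn

B, T, C, H, W = 2, 4, 3, 32, 32   # illustrative raw video shape (assumed)
P, D = 16, 8                       # patch size and embedding dim (assumed)

video = torch.randn(B, T, C, H, W)

# kernel_size = stride = P: each output position sees exactly one P x P patch
patch_embed = nn.Conv2d(C, D, kernel_size=P, stride=P)

frames = video.reshape(B * T, C, H, W)
feat = patch_embed(frames)                  # [B*T, D, H/P, W/P]
tokens = feat.flatten(2).transpose(1, 2)    # [B*T, N, D] with N = (H/P) * (W/P)
x = tokens.reshape(B, T, -1, D)             # [B, T, N, D]

print(x.shape)   # torch.Size([2, 4, 4, 8])
```

With 32 × 32 frames and 16 × 16 patches, each frame yields a 2 × 2 grid, matching the N = 4 of the toy setup.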

---

## 3. The Code (with full explanation)

```python
import torch
import torch.nn as nn

class MiniSpaceTimeTransformer(nn.Module):
    def __init__(self, dim=3, num_heads=1):
        super().__init__()
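        # batch_first=True so inputs are read as [batch, sequence, dim]
        self.spatial = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.temporal = nn.MultiheadAttention(dim, num_heads, batch_first=True)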
```

### Explanation:

* We define a PyTorch module.
* It has **two attention blocks**:

  * `self.spatial`: handles relationships between patches *within* a frame.
  * `self.temporal`: handles relationships between frames *over time*.
* `dim` is the token embedding size (e.g., 3 in our toy example).
* `num_heads` is the number of attention heads.

These attention layers use the standard Transformer **scaled dot-product attention** mechanism:
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

---

### Forward Pass

```python
    def forward(self, x):
        # x: [B, T, N, D]
        B, T, N, D = x.shape
```

This extracts shape dimensions from the input tensor.

---

### Step 1: Spatial Attention

```python
        # --- Spatial attention (per frame) ---
        spatial_out = []
        for t in range(T):
            xt = x[:, t]            # [B, N, D]
            yt, _ = self.spatial(xt, xt, xt)
            spatial_out.append(yt)
        x_spatial = torch.stack(spatial_out, dim=1)  # [B, T, N, D]
```

#### Explanation in detail:

1. We iterate over each **frame** (time index `t`).
2. Extract all **patches** for that frame:

   * `x[:, t]` gives a tensor of shape `[B, N, D]` (for that time step).
3. Feed it into `nn.MultiheadAttention`:

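   * By default `nn.MultiheadAttention` expects `[sequence_length, batch_size, dim]`. To pass `[B, N, D]` directly, the layer must be constructed with `batch_first=True`; otherwise permute the input to `[N, B, D]` first.
     PyTorch does not reorder these dimensions automatically: a `[B, N, D]` tensor given to a default layer is silently read as `B` positions with batch size `N`.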
4. The attention computes **how patches within that frame attend to each other**, returning:

   * `yt`: the new embedding for each patch (after attending to other patches in the same frame).
5. Append this processed frame to a list.
6. Stack all processed frames back along the time dimension to form:
   $$ X' \in \mathbb{R}^{B \times T \times N \times D} $$
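
The layout issue from step 3 is easy to see from the shape of the returned attention weights. This sketch feeds the same `[B, N, D]` tensor to a default layer and to one built with `batch_first=True`; sizes are illustrative.

```python
import torch
import torch.nn as nn

B, N, D = 2, 4, 8
xt = torch.randn(B, N, D)   # patches of one frame for a batch of B videos

seq_first = nn.MultiheadAttention(D, num_heads=1)                      # expects [L, B, D]
batch_first = nn.MultiheadAttention(D, num_heads=1, batch_first=True)  # expects [B, L, D]

_, w_wrong = seq_first(xt, xt, xt)      # reads xt as 2 positions, batch of 4
_, w_right = batch_first(xt, xt, xt)    # reads xt as 4 patches, batch of 2
_, w_perm = seq_first(xt.transpose(0, 1), xt.transpose(0, 1), xt.transpose(0, 1))

print(w_wrong.shape)   # torch.Size([4, 2, 2]): attention across videos, not patches
print(w_right.shape)   # torch.Size([2, 4, 4]): one N x N map per video
print(w_perm.shape)    # torch.Size([2, 4, 4]): permuting the input also works
```

No error is raised in the wrong case, which is what makes this mistake easy to miss.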

At this point, each frame’s patches know about their spatial context,
but not yet about how frames relate over time.

---

### Step 2: Temporal Attention

```python
        # --- Temporal attention (per patch index) ---
        temporal_out = []
        for n in range(N):
            xn = x_spatial[:, :, n]                 # [B, T, D]
            yn, _ = self.temporal(xn, xn, xn)
            temporal_out.append(yn)
        x_temporal = torch.stack(temporal_out, dim=2)  # [B, T, N, D]
        return x_temporal
```

#### Explanation in detail:

1. Now we iterate over **patch index** `n`.
2. Extract that same patch across all frames:

   * `x_spatial[:, :, n]` → `[B, T, D]`
     (e.g., patch 0 from frame 1, patch 0 from frame 2, etc.)
3. Feed it into the **temporal attention** layer.

   * This learns how this patch changes **across time** (motion, brightness changes).
4. Append this processed sequence to a list.
5. After processing all patches, stack them again to get:
   $$ Y \in \mathbb{R}^{B \times T \times N \times D} $$

At this stage, every token (patch at frame *t*) has learned:

* Spatial context (from within its own frame)
* Temporal context (from the same patch location across time)

---

## 4. How it Works Conceptually

### Before the model:

* Each token only knows its *own value* — no interaction between patches or frames.

### After spatial attention:

* Each patch embedding knows how it relates to other patches **in the same frame**.
  For instance, a patch containing part of a “hand” attends to a nearby “arm” patch.

### After temporal attention:

* Each patch embedding learns **how it changes over time**.
  For example, that same “hand” patch learns that it moves upward between frame 2 and frame 4.

Together, the model builds a **rich spatio-temporal representation**.
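
One way to check this flow of information is to perturb a single input token and see which outputs move. The sketch below uses randomly initialized layers and assumed sizes; the helper `space_then_time` is hypothetical and simply chains the two stages.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, N, D = 1, 4, 4, 8        # illustrative sizes (assumed)
spatial = nn.MultiheadAttention(D, num_heads=1, batch_first=True)
temporal = nn.MultiheadAttention(D, num_heads=1, batch_first=True)

def space_then_time(x):
    """Hypothetical helper: returns (after spatial, after spatial + temporal)."""
    b, t, n, d = x.shape
    f = x.reshape(b * t, n, d)
    s = spatial(f, f, f)[0].reshape(b, t, n, d)
    k = s.permute(0, 2, 1, 3).reshape(b * n, t, d)
    y = temporal(k, k, k)[0].reshape(b, n, t, d).permute(0, 2, 1, 3)
    return s, y

x = torch.randn(B, T, N, D)
x_pert = x.clone()
x_pert[0, 0, 0] += 1.0         # perturb one token: first frame, first patch

with torch.no_grad():
    s0, y0 = space_then_time(x)
    s1, y1 = space_then_time(x_pert)

# Boolean [T, N] maps of which tokens changed
changed_after_space = (s0 - s1).abs().amax(dim=-1)[0] > 1e-6
changed_after_time = (y0 - y1).abs().amax(dim=-1)[0] > 1e-6

print(changed_after_space)   # only the first row (the perturbed frame) is affected
print(changed_after_time)    # every token is affected
```

After the spatial stage the perturbation has spread only within its own frame. After the temporal stage it has reached every token, because each patch location in the first frame passes the change along its own track through time.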

---

## 5. Complexity and Relation to Real Models

In full models like **TimeSformer** or **ViViT**:

* Spatial and temporal blocks are stacked in multiple layers.
* Tokens are projected into higher dimensions (e.g., D = 768).
* Positional encodings are added for both spatial and temporal positions.
* Classification is typically done from a [CLS] token (or from pooled tokens) that aggregates context.

This toy model captures the same *core mechanism*:
$$
\text{Video Representation} = f_{\text{Temporal}}\big(f_{\text{Spatial}}(X)\big)
$$

but in a minimal, educational way.

---

## 6. Key Takeaways

| Step | Transformer Type | Input Shape         | Learns                | Effect                 |
| ---- | ---------------- | ------------------- | --------------------- | ---------------------- |
| 1    | Spatial          | [B, N, D] per frame | Patch–Patch relations | Scene layout per frame |
| 2    | Temporal         | [B, T, D] per patch | Frame–Frame relations | Motion / dynamics      |

So each token evolves as:

$$
x_{t,n}^{(\text{out})} = f_{\text{Temporal}}\big(f_{\text{Spatial}}(x_{t,n}^{(\text{in})})\big)
$$

---




## 6. Spatio-Temporal Transformer (joint attention)

If we **don’t** separate space and time, we flatten everything:

$$
X \in \mathbb{R}^{(T \times N) \times D}
$$

and perform attention directly:

$$
\text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

Now each patch in each frame can attend to **every other patch in every frame** —
richer but more expensive ($O((T·N)^2)$).
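
A minimal sketch of joint attention: flatten time and space into one sequence and run a single attention layer over it. Sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

B, T, N, D = 2, 4, 4, 8        # illustrative sizes (assumed)
x = torch.randn(B, T, N, D)

joint = nn.MultiheadAttention(D, num_heads=2, batch_first=True)

tokens = x.reshape(B, T * N, D)             # one long sequence of T*N tokens
out, weights = joint(tokens, tokens, tokens)
y = out.reshape(B, T, N, D)

print(y.shape)         # torch.Size([2, 4, 4, 8])
print(weights.shape)   # torch.Size([2, 16, 16]): a (T*N) x (T*N) map per video
```

The attention map grows with the square of `T * N`, which is where the extra cost comes from.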

---

## 7. Summary Table

| Model Type                | Attention     | Complexity               | Example                                                   |
| ------------------------- | ------------- | ------------------------ | --------------------------------------------------------- |
| Spatial only              | Within frame  | O(N²) per frame          | ViT                                                       |
| Temporal only             | Across frames | O(T²) per patch location | Temporal Transformer                                      |
| Factorized (space → time) | Separate      | O(T·N² + N·T²)           | TimeSformer (divided attention), ViViT (factorised encoder) |
| Joint space-time          | Global        | O((T·N)²)                | ViViT (spatio-temporal attention variant)                 |
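
The counts behind the table can be tallied with a few lines of plain Python. The function name `attention_pairs` is hypothetical, the counts are query-key pairs per attention layer for a whole clip, and the larger setting (T = 8, N = 196) is an assumed example.

```python
def attention_pairs(T, N):
    """Query-key pairs per layer for a clip with T frames and N patches per frame."""
    spatial = T * N * N          # N^2 pairs in each of T frames
    temporal = N * T * T         # T^2 pairs at each of N patch locations
    return {
        "spatial_only": spatial,
        "temporal_only": temporal,
        "factorized": spatial + temporal,
        "joint": (T * N) ** 2,
    }

toy = attention_pairs(T=4, N=4)
larger = attention_pairs(T=8, N=196)

print(toy)      # {'spatial_only': 64, 'temporal_only': 64, 'factorized': 128, 'joint': 256}
print(larger)   # {'spatial_only': 307328, 'temporal_only': 12544, 'factorized': 319872, 'joint': 2458624}
```

In the toy setup the factorized scheme only halves the work, but the gap widens quickly as the number of tokens grows.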

---

## 8. Intuition Recap

* **Spatial transformer:** What is happening *in* each frame?
* **Temporal transformer:** How does it *change over time*?
* **Spatio-temporal:** Combines both in one attention mechanism.

---