# OCR-GAN Video Model Testing

This notebook contains essential tests for the OCR-GAN video model with tqdm progress bars.

## 📄 **Research Paper Foundation: Omni-frequency Channel-selection Representations**

### 🔬 **Theoretical Background**

The OCR-GAN Video model builds on the research paper **"Omni-frequency Channel-selection Representations for Unsupervised Anomaly Detection"**, which introduced OCR-GAN for unsupervised image anomaly detection using **multi-frequency feature learning** and **adaptive channel selection**. The video model applies the same ideas to video anomaly detection.

#### **🎯 Core Research Contributions:**

1. **🌊 Omni-frequency Feature Learning**
   - **Laplacian Stream**: Captures **high-frequency details** (edges, textures)
   - **Residual Stream**: Captures **low-frequency structures** (global patterns, shapes)
   - **Multi-scale Processing**: Combines both streams for comprehensive representation

2. **🔀 Channel Selection Mechanism (CS)**
   - **Adaptive weighting** between Laplacian and Residual streams
   - **Attention-based selection** using Global Average Pooling (GAP)
   - **Context-aware** feature mixing based on input characteristics

3. **🎬 Temporal Video Processing** (the video extension implemented here)
   - **16-frame snippets** for temporal context modeling
   - **Frame-wise processing** with temporal consistency
   - **Video-level anomaly scoring** through reconstruction error
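The high/low-frequency split behind the two streams can be sketched with a Gaussian low-pass filter: the blurred frame plays the role of the residual (low-frequency) component, and the difference from the original frame is the Laplacian (high-frequency) component. This is a minimal sketch that assumes a single-level Gaussian split with hypothetical `size` and `sigma` values; the paper and this code base may use a different filter or a multi-level pyramid.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel(size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    """Normalized 2D Gaussian kernel of shape (size, size)."""
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()
    return torch.outer(g, g)


def frequency_split(frames: torch.Tensor, size: int = 5, sigma: float = 1.0):
    """Split frames (N, C, H, W) into (laplacian, residual) so that frames == laplacian + residual."""
    c = frames.shape[1]
    # One depthwise Gaussian filter per channel: weight shape (C, 1, size, size)
    kernel = gaussian_kernel(size, sigma)[None, None].repeat(c, 1, 1, 1)
    padded = F.pad(frames, [size // 2] * 4, mode="reflect")
    residual = F.conv2d(padded, kernel, groups=c)  # low-frequency structure
    laplacian = frames - residual                  # high-frequency details
    return laplacian, residual


# A 16-frame RGB snippet, with the frames flattened into the batch dimension
snippet = torch.rand(16, 3, 64, 64)
lap, res = frequency_split(snippet)
print(lap.shape, res.shape)
```

Defining the Laplacian as the difference makes the split exactly invertible, so no information is lost between the two streams.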

---

### 🧮 **Mathematical Framework**

#### **1. Omni-frequency Decomposition:**

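For an input video frame $I \in \mathbb{R}^{C \times H \times W}$, the frame is split into a high-frequency Laplacian component $I_{\text{lap}}$ and a low-frequency residual component $I_{\text{res}}$, which feed the two streams.

#### **2. Channel Selection:**

For Laplacian-stream features $x_1$ and residual-stream features $x_2$:

```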
Fused Features:        F = x₁ + x₂
Global Context:        G = GAP(F) ∈ ℝᶜ
Attention Weights:     α = Softmax(FC(G)) ∈ ℝ²ˣᶜ
Output:               y₁ = α₁ ⊙ x₁,  y₂ = α₂ ⊙ x₂
```
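To make the shapes in these channel-selection formulas concrete, the sketch below evaluates them on random tensors. The shared reduction layer plus two linear heads stand in for FC; the sizes and this layer layout are assumptions modeled on the `CS` class described later, not a fixed specification.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, d = 2, 8, 4, 4, 4

x1 = torch.rand(B, C, H, W)  # Laplacian-stream features
x2 = torch.rand(B, C, H, W)  # Residual-stream features

fused = x1 + x2                      # F = x1 + x2
g = fused.mean(dim=(2, 3))           # G = GAP(F), shape (B, C)
reduce = nn.Linear(C, d)             # shared reduction layer
heads = nn.ModuleList([nn.Linear(d, C) for _ in range(2)])  # one head per stream

z = reduce(g)
logits = torch.stack([h(z) for h in heads], dim=1)  # (B, 2, C)
alpha = torch.softmax(logits, dim=1)                 # softmax across the 2 streams

y1 = alpha[:, 0, :, None, None] * x1  # y1 = alpha1 ⊙ x1
y2 = alpha[:, 1, :, None, None] * x2  # y2 = alpha2 ⊙ x2
print(alpha.sum(dim=1))               # each channel's two weights sum to 1
```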

#### **3. Video-level Loss Function:**
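As a minimal sketch, assuming a GANomaly-style objective, the generator loss for a batch of video snippets could be a weighted sum of an adversarial term, a contextual (reconstruction) term, and a latent feature-matching term. The function name, the term choices, and the weights `w_adv`, `w_con`, `w_lat` are hypothetical, not taken from this model's code.

```python
import torch
import torch.nn.functional as F


def generator_loss(real, fake, feat_real, feat_fake, d_fake_logits,
                   w_adv=1.0, w_con=50.0, w_lat=1.0):
    """Hypothetical weighted generator loss for a batch of video snippets.

    real, fake:           (B, T, C, H, W) input snippets and their reconstructions
    feat_real, feat_fake: discriminator features of real and reconstructed snippets
    d_fake_logits:        discriminator logits for the reconstructions
    """
    # Adversarial term: reward reconstructions the discriminator labels as real (1)
    adv = F.binary_cross_entropy_with_logits(d_fake_logits, torch.ones_like(d_fake_logits))
    # Contextual term: L1 reconstruction error over all frames and pixels
    con = F.l1_loss(fake, real)
    # Latent term: L2 distance between discriminator features
    lat = F.mse_loss(feat_fake, feat_real)
    return w_adv * adv + w_con * con + w_lat * lat


# Perfect reconstruction: only the adversarial term remains
real = torch.rand(1, 16, 3, 64, 64)
feats = torch.rand(1, 128)
loss = generator_loss(real, real.clone(), feats, feats.clone(), torch.zeros(1, 1))
print(loss.item())  # ≈ 0.6931, i.e. log(2), because a logit of 0 means p = 0.5
```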

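The generator objective combines several weighted loss components that balance reconstruction fidelity against adversarial realism.

#### **5. Training with tqdm Progress Bars:**

```python
from tqdm import tqdm

# Excerpt from the model's training method; `self` is the OCR-GAN video model
# and self.opt.niter is the number of epochs (the --niter flag)
epoch_bar = tqdm(range(self.opt.niter), desc="🚀 Training Progress")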
for epoch in epoch_bar:
    train_bar = tqdm(enumerate(self.data.train, 0), 
                    total=len(self.data.train),
                    desc=f"📈 Training Epoch {epoch+1}")
    
    for i, data in train_bar:
        self.set_input(data)
        self.optimize_params()
        errors = self.get_errors()
        
        # Update progress bar with loss info
        train_bar.set_postfix({
            'G_loss': f"{errors['err_g']:.4f}",
            'D_loss': f"{errors['err_d']:.4f}"
        })
```
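The excerpt above depends on the model's own data loader and optimizer. A self-contained sketch of the same nested progress-bar pattern follows, with a hypothetical `fake_step` function standing in for `optimize_params()` and `get_errors()`:

```python
import random

from tqdm import tqdm


def fake_step():
    """Hypothetical stand-in for one optimization step; returns dummy losses."""
    return {"err_g": random.random(), "err_d": random.random()}


num_epochs, num_batches = 2, 5
history = []
epoch_bar = tqdm(range(num_epochs), desc="Training Progress")
for epoch in epoch_bar:
    # leave=False clears each finished inner bar so only the epoch bar stays visible
    train_bar = tqdm(range(num_batches), desc=f"Epoch {epoch + 1}", leave=False)
    for i in train_bar:
        errors = fake_step()
        history.append(errors)
        train_bar.set_postfix(G_loss=f"{errors['err_g']:.4f}",
                              D_loss=f"{errors['err_d']:.4f}")
    # Show the last generator loss on the outer bar
    epoch_bar.set_postfix(last_G=f"{history[-1]['err_g']:.4f}")
```

Passing the losses as keyword arguments to `set_postfix` is equivalent to passing a dict; `leave=False` keeps long runs from stacking one finished bar per epoch.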

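#### **6. Channel Selection in the U-Net Generator:**

The **Channel Selection (CS)** module in `UnetGenerator_CS` reweights the Laplacian and residual streams inside selected U-Net blocks:

```python
import torch.nn as nn

class UnetSkipConnectionBlock_CS(nn.Module):
    def __init__(self, layer_id, outer_nc, inner_nc, ...):
        super().__init__()
        # Channel selection applied only at specific layers; decided once at
        # construction (`cs_layers` is the list of layers that get a CS module,
        # and self.cs = CS(...) is created for those layers)
        self.apply_cs = layer_id in cs_layers

    def forward(self, x):
        x_lap, x_res = x  # Laplacian and residual stream features
        if self.apply_cs:
            x_lap, x_res = self.cs((x_lap, x_res))  # Reweight each stream's channels
        # ... rest of U-Net block processing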
```

#### **7. Video-Specific Anomaly Detection:**

```python
import torch

def test(self):
    with torch.no_grad():  # No gradients needed for scoring
        # Process each batch of video snippets
        for i, data in enumerate(self.data.valid):
            self.set_input(data)
            fake = self.netg(self.input_lap)  # Reconstruct video snippet

            # Reconstruction error per snippet: mean over (frames, channels, H, W)
            error = torch.mean((fake - self.input_lap) ** 2, dim=[1, 2, 3, 4])
            start = i * self.opt.batchsize
            self.an_scores[start:start + error.size(0)] = error  # Anomaly scores

    # Compute metrics
    auc = roc(self.gt_labels, self.an_scores)
    return auc
```
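The scoring step can be sketched without the model: per-snippet mean squared error, min-max normalization (a common convention in GANomaly-style code, assumed here), and ROC AUC computed as the probability that a random anomalous snippet outscores a random normal one. All function names are hypothetical, and this sketch does not reproduce the project's `roc` helper.

```python
import numpy as np


def snippet_scores(real: np.ndarray, fake: np.ndarray) -> np.ndarray:
    """Mean squared reconstruction error per snippet; inputs are (N, T, C, H, W)."""
    return ((fake - real) ** 2).mean(axis=(1, 2, 3, 4))


def min_max(scores: np.ndarray) -> np.ndarray:
    """Rescale scores to [0, 1]; the ranking, and hence the AUC, is unchanged."""
    return (scores - scores.min()) / (scores.max() - scores.min() + 1e-12)


def roc_auc(labels: np.ndarray, scores: np.ndarray) -> float:
    """Fraction of (anomalous, normal) pairs ranked correctly; ties count 1/2."""
    pos, neg = scores[labels == 1], scores[labels == 0]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((greater + 0.5 * ties) / (len(pos) * len(neg)))


rng = np.random.default_rng(0)
real = rng.random((4, 16, 3, 8, 8))
noise = np.array([0.01, 0.02, 0.3, 0.4])[:, None, None, None, None]
fake = real + noise                # the last two snippets reconstruct poorly
labels = np.array([0, 0, 1, 1])    # 1 = anomalous
scores = min_max(snippet_scores(real, fake))
print(roc_auc(labels, scores))     # 1.0: every anomaly scores above every normal snippet
```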

### 🔍 **Key Technical Features:**

- **Temporal Consistency**: 16-frame processing maintains temporal relationships
- **Multi-Stream Processing**: Separate Laplacian and residual feature streams
- **Adaptive Channel Selection**: CS attention reweights the Laplacian and residual streams channel by channel
- **Weighted Loss Components**: Multiple weighted loss terms balance reconstruction fidelity against realism
- **Real-Time Monitoring**: tqdm integration provides training visibility

---

## 🔀 **Channel Selection (CS) Module Deep Dive**

The **Channel Selection (CS)** module is a key component of OCR-GAN, carried over into this video architecture. It adaptively reweights the features of the dual-stream processing (Laplacian and Residual streams). Let's break down exactly how it works:

### 🏗️ **CS Module Architecture**

```python
import torch
import torch.nn as nn

class CS(nn.Module):
    def __init__(self, features, WH, r, L=32):
        super().__init__()
        d = max(int(features / r), L)  # Reduction dimension
        self.gap = nn.AvgPool2d(int(WH))  # Global Average Pooling over a WH x WH map
        self.fc = nn.Linear(int(features), d)  # Shared reduction layer
        self.fcs = nn.ModuleList([])  # One FC head per stream
        for i in range(2):
            self.fcs.append(nn.Linear(d, features))
        self.softmax = nn.Softmax(dim=1)  # Normalize attention across the two streams
```

### 🔍 **How CS Works - Step by Step**

#### **1. Input Processing**
The CS module takes **two feature maps** as input:
- `x1`: Laplacian stream features (edge/boundary information)
- `x2`: Residual stream features (low-frequency structure)

#### **2. Feature Fusion & Global Context**
```python
def forward(self, x):
    x1, x2 = x  # Separate the two streams
    x = x1 + x2  # Element-wise addition for global context
    
    # Global Average Pooling - reduces spatial dimensions to 1x1
    fea_s = self.gap(x).flatten(1)  # Shape: (batch_size, features); batch dim kept even when batch_size == 1
```

#### **3. Attention Weight Generation**
```python
    # First reduction layer
    fea_z = self.fc(fea_s)  # Shape: (batch_size, d); stays on the same device as the layer weights
    
    # Generate attention vectors for each stream
    for i, fc in enumerate(self.fcs):
        vector = fc(fea_z).unsqueeze_(dim=1)
        if i == 0:
            attention_vec = vector
        else:
            attention_vec = torch.cat([attention_vec, vector], dim=1)
    
    # Apply softmax to create normalized attention weights
    attention_vec = self.softmax(attention_vec)  # Shape: (batch_size, 2, features)
```

#### **4. Adaptive Feature Weighting**
```python
    # Reshape for broadcasting
    attention_vec = attention_vec.unsqueeze(-1).unsqueeze(-1)
    attention_vec = attention_vec.transpose(0, 1)  # Shape: (2, batch_size, features, 1, 1), already on x's device
    
    # Apply attention weights to each stream
    out_x1 = x1 * attention_vec[0]  # Weighted Laplacian features
    out_x2 = x2 * attention_vec[1]  # Weighted Residual features
    
    return (out_x1, out_x2)
```
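Putting the fragments together, a self-contained version of the module might look like the sketch below. It assumes square feature maps of size `WH` x `WH`, keeps every tensor on the input's device, and flattens the pooled features with `flatten(1)` so the batch dimension survives a batch size of 1. It stacks the two heads with `torch.stack` instead of the concatenation loop; the result is the same `(batch_size, 2, features)` tensor.

```python
import torch
import torch.nn as nn


class CS(nn.Module):
    """Channel selection between two feature streams (sketch)."""

    def __init__(self, features: int, WH: int, r: int, L: int = 32):
        super().__init__()
        d = max(int(features / r), L)       # reduction dimension
        self.gap = nn.AvgPool2d(int(WH))    # global average pooling over a WH x WH map
        self.fc = nn.Linear(features, d)    # shared reduction layer
        self.fcs = nn.ModuleList([nn.Linear(d, features) for _ in range(2)])  # one head per stream
        self.softmax = nn.Softmax(dim=1)    # normalize across the two streams

    def forward(self, x):
        x1, x2 = x                                   # Laplacian and residual streams, (B, C, H, W)
        fea_s = self.gap(x1 + x2).flatten(1)         # (B, C)
        fea_z = self.fc(fea_s)                       # (B, d)
        attention = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)  # (B, 2, C)
        attention = self.softmax(attention)[..., None, None]            # (B, 2, C, 1, 1)
        return x1 * attention[:, 0], x2 * attention[:, 1]


cs = CS(features=64, WH=16, r=2)
x1, x2 = torch.rand(1, 64, 16, 16), torch.rand(1, 64, 16, 16)
y1, y2 = cs((x1, x2))
print(y1.shape, y2.shape)
```

Because each channel's two weights sum to 1, feeding the same tensor into both streams returns two parts that add back up to that tensor.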

### 🎯 **Key Benefits of CS Module**

#### **1. 🧠 Adaptive Feature Selection**
- **Learns which stream is more important** for each channel
- **Dynamic weighting** based on input content
- **Context-aware** feature mixing

#### **2. 🔄 Cross-Stream Information Exchange**
- Combines information from **both streams** (x1 + x2) to generate attention
- Each stream gets **informed by the other** through shared attention computation
- Maintains **stream identity** while enabling **cross-pollination**

#### **3. 📏 Spatial Invariance**
- Uses **Global Average Pooling** to capture overall feature statistics
- Attention weights are **spatially uniform** but **channel-specific**
- Focuses on **what features matter** rather than **where they are**

### 🔧 **Technical Parameters**

| Parameter | Purpose | Typical Value |
|-----------|---------|---------------|
| `features` | Number of input channels | 64, 128, 256, 512 |
| `WH` | Spatial size (height = width) of the input feature map, used as the GAP kernel size | 4, 8, 16, 32, 64 |
| `r` | Reduction ratio | 2 (reduces complexity) |
| `L` | Minimum reduction dimension | 32 (prevents over-reduction) |
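With the values in the table (`r = 2`, `L = 32`), the reduction dimension `d = max(features / r, L)` computed in the `CS` constructor works out as follows for the listed channel counts. The helper name is hypothetical; the formula is the one in `CS.__init__`.

```python
def reduction_dim(features: int, r: int = 2, L: int = 32) -> int:
    """Hidden size of the CS attention bottleneck, as computed in CS.__init__."""
    return max(int(features / r), L)


for features in (64, 128, 256, 512):
    print(features, "->", reduction_dim(features))
# 64 -> 32, 128 -> 64, 256 -> 128, 512 -> 256
```

The floor `L` only matters once `features / r` drops below 32: for example, 32 input channels still get `d = 32` rather than 16.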

### 🌊 **Integration in U-Net Flow**

```
Input Streams (Laplacian, Residual)
        ↓
   Downsampling Layers
        ↓
┌─────────────────────────┐
│      CS Module          │
│ ┌─────────────────────┐ │
│ │ Global Average Pool │ │  Captures global context
│ └─────────────────────┘ │
│           ↓             │
│ ┌─────────────────────┐ │
│ │ Attention Generation│ │  Learns importance weights
│ └─────────────────────┘ │
│           ↓             │
│ ┌─────────────────────┐ │
│ │ Adaptive Weighting  │ │  Applies learned weights
│ └─────────────────────┘ │
└─────────────────────────┘
        ↓
Enhanced Feature Streams
        ↓
   Upsampling Layers
```

### 💡 **Why This Design Works**

1. **🎭 Complementary Streams**: Laplacian (edges, textures) and Residual (low-frequency structure) capture different aspects
2. **🤝 Intelligent Fusion**: CS learns when to emphasize each stream
3. **🎯 Context Awareness**: Global pooling provides scene-level understanding
4. **⚖️ Balanced Learning**: Softmax across the two streams makes each channel's pair of weights sum to 1
5. **🔄 Feedback Loop**: Each stream's weights depend on the other stream's features through the fused input

### 📊 **Example Attention Behavior**

- **Scenes dominated by large, smooth structures**: CS might emphasize the residual stream (low-frequency content)
- **Detailed textures**: CS might emphasize the Laplacian stream (edge information)
- **Uniform regions**: CS might weight both streams more evenly
- **Complex scenes**: CS adapts the weights to the globally pooled feature statistics of each input

This **Channel Selection mechanism** lets the OCR-GAN Video model balance **fine spatial details** against **global structure** in each frame, complementing the temporal context provided by multi-frame snippets in video anomaly detection! 🚀

---

In [None]:
# Simple test with progress bars - 2 epochs for quick validation
print("🚀 Testing OCR-GAN Video Training with Progress Bars 🚀")
print("=" * 60)
print("Running 2 epochs with progress bars...")
%run train_video.py --model ocr_gan_video --dataset ucsd2 --dataroot data/ucsd2 --num_frames 8 --batchsize 1 --niter 2 --lr 0.0002 --gpu_ids -1 --ngpu 0 --device cpu --name ucsd2_quick_test

In [None]:
# Full training test - more epochs for complete validation
print("🚀 Full OCR-GAN Video Training Test 🚀")
print("=" * 60)
print("Running 5 epochs for complete training test...")
%run train_video.py --model ocr_gan_video --dataset ucsd2 --dataroot data/ucsd2 --num_frames 16 --batchsize 1 --niter 5 --lr 0.0002 --gpu_ids -1 --ngpu 0 --device cpu --name ucsd2_full_test

In [None]:
# Test tqdm progress bars functionality
import time
try:
    import ipywidgets  # noqa: F401 (tqdm.notebook needs ipywidgets to draw its bars)
    from tqdm.notebook import tqdm
    print("✅ Using notebook version of tqdm")
except ImportError:
    from tqdm import tqdm
    print("⚠️ Using standard tqdm (works in terminal)")

print("\nTesting progress bar...")
for i in tqdm(range(10), desc="🧪 Testing Progress"):
    time.sleep(0.1)
print("\n✅ Progress bar test completed!")