# DLFeat - Custom Model Registration & Feature Extraction Examples

This notebook demonstrates:
1. The DLFeat registration API for image, video, audio, text, and multimodal models.
2. How to plug pre-trained PyTorch (and HF-style) models into DLFeat via the `register_*_model` functions.
3. Extracting features with `DLFeatExtractor` and using them in simple scikit-learn classifiers.
4. Running quick sanity checks and self-tests to validate registered models.


In [1]:
pip install git+https://github.com/emanuelegaliano/DLFeat.git

Defaulting to user installation because normal site-packages is not writeable
Collecting git+https://github.com/emanuelegaliano/DLFeat.git
  Cloning https://github.com/emanuelegaliano/DLFeat.git to /tmp/pip-req-build-rn3dqo4f
  Running command git clone --filter=blob:none --quiet https://github.com/emanuelegaliano/DLFeat.git /tmp/pip-req-build-rn3dqo4f
  Resolved https://github.com/emanuelegaliano/DLFeat.git to commit 294731d2925c93bdaa9515d7bb84b26f083ba040
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: dlfeat
  Building wheel for dlfeat (pyproject.toml) ... [?25ldone
[?25h  Created wheel for dlfeat: filename=dlfeat-0.6.0-py3-none-any.whl size=18813 sha256=6905e26960a892ec1252699244b056b5f3be71343a6a7728c6ce17ccf62a52cc
  Stored in directory: /tmp/pip-ephem-wheel-cache-77pm2p9_/wheels/fe/42/16/fb7f39fa13eccf41da6e563c4bb972ada2

## Resolving imports of DLFeat

In [13]:
# note: fix import
from dlfeat_lib import (
    DLFeatExtractor, 
    register_video_model, 
    register_image_model, 
    register_audio_model,
    register_text_model,
    register_multimodal_image_text_model,
    register_multimodal_video_text_model
    )

## Video custom model example

This example shows how to define a tiny 3D CNN for videos, train it on a small synthetic dataset, register it in DLFeat with `register_video_model`, and then extract fixed-size feature vectors from raw `.mp4` files using `DLFeatExtractor`.


In [None]:
# Video custom model example: tiny 3D CNN + DLFeat registration

import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision

# ---------------------------
# Hyperparameters
# ---------------------------
clip_len = 8          # number of frames per clip
frame_size = 64       # spatial resolution (H = W)
feature_dim = 128     # output feature dimension for DLFeat
num_classes = 4
num_train_videos = 32
num_val_videos = 8
batch_size = 4
epochs = 3
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------
# 1. Create a tiny synthetic video dataset on disk
# ---------------------------
tmp_root = tempfile.mkdtemp(prefix="dlfeat_tiny_video_")

def make_random_video(path, num_frames=clip_len, size=frame_size):
    """Create a simple 'moving square' RGB video and save it as MP4."""
    video = torch.zeros(num_frames, size, size, 3, dtype=torch.uint8)  # [T, H, W, C]
    for t in range(num_frames):
        x0 = (t * 2) % (size - 8)
        y0 = (t * 3) % (size - 8)
        video[t, y0:y0+8, x0:x0+8, :] = torch.randint(
            128, 255, (8, 8, 3), dtype=torch.uint8
        )
    torchvision.io.write_video(path, video, fps=8)

def generate_split(n_samples, split_name):
    paths, labels = [], []
    for i in range(n_samples):
        cls = i % num_classes
        filename = os.path.join(tmp_root, f"{split_name}_{i:03d}_class{cls}.mp4")
        make_random_video(filename)
        paths.append(filename)
        labels.append(cls)
    return paths, torch.tensor(labels, dtype=torch.long)

train_paths, train_labels = generate_split(num_train_videos, "train")
val_paths, val_labels = generate_split(num_val_videos, "val")

class VideoFileDataset(Dataset):
    """Simple dataset that loads .mp4 files and returns (C, T, H, W) tensors."""
    def __init__(self, paths, labels, clip_len, frame_size):
        self.paths = paths
        self.labels = labels
        self.clip_len = clip_len
        self.frame_size = frame_size

    def _load_video_tensor(self, path):
        # video: [T, H, W, C]
        video, _, _ = torchvision.io.read_video(path, pts_unit="sec")
        num_frames = video.size(0)

        # Sample or pad to a fixed number of frames
        if num_frames < self.clip_len:
            pad = video[-1:].repeat(self.clip_len - num_frames, 1, 1, 1)
            video = torch.cat([video, pad], dim=0)
        else:
            idx = torch.linspace(0, num_frames - 1, steps=self.clip_len).long()
            video = video[idx]

        # To [T, C, H, W]
        video = video.permute(0, 3, 1, 2)

        # Resize frames and normalize to [0, 1]
        video = torchvision.transforms.functional.resize( # type: ignore
            video, [self.frame_size, self.frame_size], antialias=True
        )
        video = video.float() / 255.0

        # Final shape [C, T, H, W]
        return video.permute(1, 0, 2, 3)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        video_tensor = self._load_video_tensor(self.paths[idx])
        label = self.labels[idx]
        return video_tensor, label

train_ds = VideoFileDataset(train_paths, train_labels, clip_len, frame_size)
val_ds   = VideoFileDataset(val_paths,   val_labels,   clip_len, frame_size)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)

# ---------------------------
# 2. Define a tiny 3D CNN backbone + classifier head
# ---------------------------

class TinyVideoBackbone(nn.Module):
    """Very small 3D CNN that maps (C, T, H, W) to a feature vector."""
    def __init__(self, feature_dim=feature_dim):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d((1, 2, 2)),           # pool only spatial dims
            nn.Conv3d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool3d((None, 1, 1)) # keep time, pool H,W -> [B, 32, T, 1, 1]
        )
        self.proj = nn.Linear(32, feature_dim)

    def forward(self, x):
        # x: [B, C, T, H, W]
        x = self.features(x)          # [B, 32, T, 1, 1]
        x = x.mean(dim=2)             # temporal average -> [B, 32, 1, 1]
        x = x.view(x.size(0), 32)     # [B, 32]
        return self.proj(x)           # [B, feature_dim]

class TinyVideoClassifier(nn.Module):
    """Backbone + linear head for quick supervised training."""
    def __init__(self, backbone, num_classes):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        feats = self.backbone(x)
        return self.head(feats)

backbone = TinyVideoBackbone(feature_dim=feature_dim)
model = TinyVideoClassifier(backbone, num_classes=num_classes).to(device)

# ---------------------------
# 3. Quick training loop (few epochs on tiny synthetic data)
# ---------------------------

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for videos, labels in train_loader:
        videos = videos.to(device)   # [B, C, T, H, W]
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(videos)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * videos.size(0)

    avg_loss = running_loss / len(train_ds)
    print(f"Epoch {epoch+1}/{epochs} - train loss: {avg_loss:.4f}")

# ---------------------------
# 4. Register the trained backbone with DLFeat
# ---------------------------

register_video_model(
    model_name="tiny_video_cnn",
    dim=feature_dim,
    model=backbone,      # pass the trained backbone instance
    clip_len=clip_len,
    input_size=frame_size,
    overwrite=True,
)

# ---------------------------
# 5. Use DLFeatExtractor to get features for a list of video paths
# ---------------------------

extractor = DLFeatExtractor("tiny_video_cnn", device=device)

# Here we just reuse a few validation paths, but any list of .mp4 files works
video_paths = val_paths[:4]
features = extractor.transform(video_paths, batch_size=2)

print("Extracted feature shape:", features.shape)  # (N_videos, feature_dim)



Epoch 1/3 - train loss: 1.3968
Epoch 2/3 - train loss: 1.3906
Epoch 3/3 - train loss: 1.3890
Extracted feature shape: (4, 128)




## Image custom model example

In this example we define a tiny convolutional backbone for RGB images, train it for a few epochs on a small synthetic dataset, and then register it in DLFeat with `register_image_model`. Once registered, the trained backbone can be reused via `DLFeatExtractor` to obtain fixed-size feature vectors from any list of images.


In [6]:
# Image custom model example: tiny CNN + DLFeat registration

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageDraw

# ---------------------------
# Hyperparameters
# ---------------------------
input_size   = 64   # H = W of input images
feature_dim  = 64   # feature dimension exposed to DLFeat
num_classes  = 4
num_train    = 256
num_val      = 64
batch_size   = 32
epochs       = 3
device       = "cuda" if torch.cuda.is_available() else "cpu"

# This transform will be used BOTH for training and when registering the model
image_transform_for_model = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# ---------------------------
# 1. Synthetic image dataset
# ---------------------------

class RandomSquaresDataset(Dataset):
    """
    Simple synthetic dataset:
    - each sample is a dark background with a colored square
    - the square color encodes the class (0..num_classes-1)
    """
    def __init__(self, num_samples, image_size, num_classes, split="train", transform=None):
        self.num_samples = num_samples
        self.image_size = image_size
        self.num_classes = num_classes
        self.split = split
        self.transform = transform
        # Separate RNG for train/val so they don't share the same sequence
        seed = 42 if split == "train" else 123
        self.rng = torch.Generator().manual_seed(seed)

        # Fixed palette: one color per class
        self.palette = [
            (220, 60, 60),
            (60, 220, 60),
            (60, 60, 220),
            (220, 180, 60),
        ]

    def __len__(self):
        return self.num_samples

    def _randint(self, low, high):
        return int(torch.randint(low, high, (1,), generator=self.rng).item())

    def _make_image(self, label):
        img = Image.new("RGB", (self.image_size, self.image_size), color=(10, 10, 10))
        draw = ImageDraw.Draw(img)

        square_size = self.image_size // 3
        x0 = self._randint(0, self.image_size - square_size)
        y0 = self._randint(0, self.image_size - square_size)
        x1 = x0 + square_size
        y1 = y0 + square_size

        color = self.palette[label % len(self.palette)]
        draw.rectangle([x0, y0, x1, y1], fill=color)

        return img

    def __getitem__(self, idx):
        # Simple deterministic label: cycle through classes
        label = idx % self.num_classes
        img = self._make_image(label)

        if self.transform is not None:
            img = self.transform(img)

        return img, label


train_ds = RandomSquaresDataset(
    num_samples=num_train,
    image_size=input_size,
    num_classes=num_classes,
    split="train",
    transform=image_transform_for_model,
)
val_ds = RandomSquaresDataset(
    num_samples=num_val,
    image_size=input_size,
    num_classes=num_classes,
    split="val",
    transform=image_transform_for_model,
)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)

# ---------------------------
# 2. Tiny CNN backbone + classifier head
# ---------------------------

class TinyImageBackbone(nn.Module):
    """Very small 2D CNN that maps (B, 3, H, W) to (B, feature_dim)."""
    def __init__(self, feature_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                  # 16 x 32 x 32

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                  # 32 x 16 x 16

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),         # 64 x 1 x 1
        )
        self.proj = nn.Linear(64, feature_dim)

    def forward(self, x):
        x = self.conv(x)                      # [B, 64, 1, 1]
        x = x.view(x.size(0), 64)             # [B, 64]
        return self.proj(x)                   # [B, feature_dim]


class TinyImageClassifier(nn.Module):
    """Backbone + linear classification head."""
    def __init__(self, backbone, num_classes, feature_dim):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        feats = self.backbone(x)
        logits = self.head(feats)
        return logits


backbone = TinyImageBackbone(feature_dim=feature_dim)
model = TinyImageClassifier(backbone, num_classes=num_classes, feature_dim=feature_dim).to(device)

# ---------------------------
# 3. Quick training loop
# ---------------------------

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    # Train
    model.train()
    running_loss = 0.0
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * imgs.size(0)

    avg_loss = running_loss / len(train_ds)

    # Simple validation accuracy (optional but nice to see)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            logits = model(imgs)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    val_acc = correct / max(total, 1)

    print(f"Epoch {epoch+1}/{epochs} - train loss: {avg_loss:.4f} - val acc: {val_acc:.3f}")

# ---------------------------
# 4. Register the trained backbone with DLFeat
# ---------------------------

register_image_model(
    model_name="tiny_cnn_image",
    dim=feature_dim,
    model=backbone,                 # pass the trained backbone instance
    input_size=input_size,
    image_transform=image_transform_for_model,
    overwrite=True,
)

# ---------------------------
# 5. Use DLFeatExtractor to get features for arbitrary images
# ---------------------------

extractor = DLFeatExtractor("tiny_cnn_image", device=device)

# Create a few random PIL images and extract features
def make_random_pil_image(size=input_size):
    arr = torch.randint(0, 255, (size, size, 3), dtype=torch.uint8).numpy()
    return Image.fromarray(arr)

test_images = [make_random_pil_image() for _ in range(4)]
features = extractor.transform(test_images, batch_size=2)

print("Extracted feature shape:", features.shape)  # (N_images, feature_dim)


Epoch 1/3 - train loss: 1.3659 - val acc: 0.500
Epoch 2/3 - train loss: 1.2388 - val acc: 0.500
Epoch 3/3 - train loss: 0.9011 - val acc: 0.609
Extracted feature shape: (4, 64)


## Audio custom model example
This section shows how to define, train, and register a tiny 1D CNN for raw audio waveforms using DLFeat.  
We create a synthetic dataset of short sine-wave signals, train a lightweight classifier, then register the backbone with `register_audio_model` and use `DLFeatExtractor` to obtain fixed-size embeddings.

In [10]:
# Audio custom model example: tiny 1D CNN + DLFeat registration

import math
import types
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ---------------------------
# Hyperparameters
# ---------------------------
sample_rate = 16_000          # Hz
duration_s = 0.5              # seconds
num_samples = int(sample_rate * duration_s)

num_classes = 3
feature_dim = 64

num_train_signals = 300
num_val_signals = 60

batch_size = 32
epochs = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------
# 1. Synthetic audio dataset
#    (sine waves at different frequencies + noise)
# ---------------------------

class SineWaveDataset(Dataset):
    """Synthetic dataset: each class = sine wave with a different base frequency."""
    def __init__(self, n_samples, num_classes, sample_rate, num_samples):
        self.n_samples = n_samples
        self.num_classes = num_classes
        self.sample_rate = sample_rate
        self.num_samples = num_samples

        # Choose one frequency per class
        self.class_freqs = torch.tensor([220.0, 440.0, 880.0])[:num_classes]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        cls = idx % self.num_classes
        freq = self.class_freqs[cls]

        t = torch.linspace(0, duration_s, self.num_samples, dtype=torch.float32)
        phase = 2 * math.pi * torch.rand(1).item()
        amplitude = 0.5 + 0.5 * torch.rand(1).item()

        clean = amplitude * torch.sin(2 * math.pi * freq * t + phase)
        noise = 0.05 * torch.randn_like(clean)
        waveform = clean + noise  # shape [T]

        return waveform, cls

train_ds = SineWaveDataset(num_train_signals, num_classes, sample_rate, num_samples)
val_ds   = SineWaveDataset(num_val_signals,   num_classes, sample_rate, num_samples)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)

# ---------------------------
# 2. Tiny 1D CNN encoder + classifier
# ---------------------------

class TinyAudioEncoder(nn.Module):
    """Very small 1D CNN that maps (B, 1, T) to feature vectors of size feature_dim."""
    def __init__(self, feature_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(4),

            nn.Conv1d(16, 32, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(4),

            nn.AdaptiveAvgPool1d(1)  # -> (B, 32, 1)
        )
        self.proj = nn.Linear(32, feature_dim)

    def forward(self, x):
        # x: (B, 1, T)
        x = self.net(x)          # (B, 32, 1)
        x = x.squeeze(-1)        # (B, 32)
        x = self.proj(x)         # (B, feature_dim)
        return x

class TinyAudioClassifier(nn.Module):
    """Encoder + linear classification head."""
    def __init__(self, encoder, feature_dim, num_classes):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        # x: (B, 1, T)
        feats = self.encoder(x)         # (B, feature_dim)
        logits = self.head(feats)       # (B, num_classes)
        return logits

encoder = TinyAudioEncoder(feature_dim=feature_dim)
classifier = TinyAudioClassifier(encoder, feature_dim, num_classes).to(device)

# ---------------------------
# 3. Quick training loop
# ---------------------------

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

for epoch in range(epochs):
    classifier.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for waveforms, labels in train_loader:
        # waveforms: (B, T)
        waveforms = waveforms.to(device)
        labels = labels.to(device)

        # Add channel dimension -> (B, 1, T)
        waveforms = waveforms.unsqueeze(1)

        optimizer.zero_grad()
        logits = classifier(waveforms)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * waveforms.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = running_loss / len(train_ds)
    acc = correct / total if total > 0 else 0.0
    print(f"Epoch {epoch+1}/{epochs} - train loss: {avg_loss:.4f} - acc: {acc:.3f}")

# ---------------------------
# 4. Wrap encoder in a HF-style module for DLFeat
#    (so that .transform() can use .last_hidden_state)
# ---------------------------

class HFStyleAudioWrapper(nn.Module):
    """
    Wraps TinyAudioEncoder so that forward(input_values) returns an object
    with a .last_hidden_state attribute, as expected by DLFeat's audio path.
    """
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_values):
        # input_values: (B, T)
        x = input_values.unsqueeze(1)      # (B, 1, T)
        feats = self.encoder(x)            # (B, feature_dim)
        # DLFeat will do outputs.last_hidden_state.mean(dim=1),
        # so we expose an extra "sequence" dimension of length 1.
        return types.SimpleNamespace(last_hidden_state=feats.unsqueeze(1))  # (B, 1, D)

audio_backbone = HFStyleAudioWrapper(encoder).to(device)

# ---------------------------
# 5. Custom audio preprocessing for DLFeat
#    Here we assume inputs to DLFeatExtractor.transform() are 1D tensors.
# ---------------------------

def tiny_audio_preprocess(audio_input):
    """
    audio_input: a 1D torch.Tensor (waveform).
    Returns a dict of tensors like a HF processor would.
    """
    if not isinstance(audio_input, torch.Tensor):
        raise TypeError("This example expects a 1D torch.Tensor as input.")
    # shape -> (1, T) so that DLFeat builds batches correctly
    waveform = audio_input.float().unsqueeze(0)
    return {"input_values": waveform}

# ---------------------------
# 6. Register the audio model in DLFeat
# ---------------------------

register_audio_model(
    model_name="tiny_audio_cnn",
    dim=feature_dim,
    model=audio_backbone,         # pre-trained backbone instance
    sampling_rate=sample_rate,    # only for metadata; not used by our custom preprocess
    audio_preprocess=tiny_audio_preprocess,
    overwrite=True,
)

# ---------------------------
# 7. Use DLFeatExtractor to compute features
# ---------------------------

extractor = DLFeatExtractor("tiny_audio_cnn", device=device)

# Take a few validation waveforms
audio_examples = [val_ds[i][0] for i in range(8)]  # list of 1D tensors

features = extractor.transform(audio_examples, batch_size=4)
print("Extracted feature shape:", features.shape)   # (N_signals, feature_dim)


Epoch 1/5 - train loss: 1.0562 - acc: 0.707
Epoch 2/5 - train loss: 0.9180 - acc: 0.920
Epoch 3/5 - train loss: 0.6874 - acc: 1.000
Epoch 4/5 - train loss: 0.4264 - acc: 1.000
Epoch 5/5 - train loss: 0.2126 - acc: 1.000
Extracted feature shape: (8, 64)


## Text custom model example
This section shows how to define, train, and register a tiny text encoder using DLFeat.  
We build a synthetic toy dataset, train a lightweight RNN-based classifier, then register the backbone with `register_text_model` and use `DLFeatExtractor` to obtain fixed-size embeddings from raw text.


In [12]:
# Text custom model example: tiny RNN encoder + DLFeat registration

import types
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ---------------------------
# Hyperparameters
# ---------------------------
num_classes = 3
feature_dim = 32       # output embedding dimension for DLFeat
embed_dim = 32
hidden_dim = 32

num_train_samples = 300
num_val_samples = 60

batch_size = 32
epochs = 5
max_len = 8            # fixed max sequence length for tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------
# 1. Tiny toy text dataset
#    (3 classes, each with a simple template sentence)
# ---------------------------

class TinyTextDataset(Dataset):
    """
    Toy dataset:
      class 0 -> "this is class zero sample"
      class 1 -> "this is class one example"
      class 2 -> "this is class two item"
    """
    def __init__(self, n_samples, num_classes=3):
        assert num_classes <= 3, "This toy dataset supports up to 3 classes."
        self.n_samples = n_samples
        self.num_classes = num_classes

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        cls = idx % self.num_classes
        if cls == 0:
            text = "this is class zero sample"
        elif cls == 1:
            text = "this is class one example"
        else:
            text = "this is class two item"
        return text, cls

train_ds = TinyTextDataset(num_train_samples, num_classes=num_classes)
val_ds   = TinyTextDataset(num_val_samples,   num_classes=num_classes)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)

# ---------------------------
# 2. Simple vocabulary + tokenizer
# ---------------------------

VOCAB_LIST = [
    "[PAD]", "[UNK]",
    "this", "is", "class", "zero", "one", "two",
    "sample", "example", "item"
]
VOCAB = {w: i for i, w in enumerate(VOCAB_LIST)}
PAD_ID = VOCAB["[PAD]"]
UNK_ID = VOCAB["[UNK]"]

def tiny_tokenizer(
    texts,
    padding=True,
    truncation=True,
    return_tensors=None,
    max_length=max_len,
    **kwargs
):
    """
    Minimal tokenizer with a HuggingFace-like signature.
    - texts: str or list[str]
    Returns:
      {
        "input_ids": LongTensor [B, L],
        "attention_mask": LongTensor [B, L]
      }
    """
    if isinstance(texts, str):
        texts = [texts]

    batch_ids = []
    for t in texts:
        tokens = t.lower().split()
        ids = [VOCAB.get(tok, UNK_ID) for tok in tokens]

        if truncation and len(ids) > max_length:
            ids = ids[:max_length]

        if padding:
            while len(ids) < max_length:
                ids.append(PAD_ID)

        batch_ids.append(ids)

    input_ids = torch.tensor(batch_ids, dtype=torch.long)
    attention_mask = (input_ids != PAD_ID).long()

    return {"input_ids": input_ids, "attention_mask": attention_mask}

vocab_size = len(VOCAB_LIST)

# ---------------------------
# 3. Tiny RNN text backbone + classifier
# ---------------------------

class TinyTextBackbone(nn.Module):
    """
    Tiny encoder that mimics the HF output interface:
    forward(input_ids, attention_mask=None) -> object with .last_hidden_state
    where last_hidden_state has shape [B, L, D].
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, feature_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.proj = nn.Linear(hidden_dim, feature_dim)

    def forward(self, input_ids, attention_mask=None):
        # input_ids: [B, L]
        x = self.embedding(input_ids)      # [B, L, E]
        # simple GRU
        output, _ = self.rnn(x)            # [B, L, H]
        feats = self.proj(output)          # [B, L, D]
        # DLFeat expects .last_hidden_state
        return types.SimpleNamespace(last_hidden_state=feats)

class TinyTextClassifier(nn.Module):
    """
    Backbone + head for quick supervised training.
    We use the representation at position 0 (like a [CLS] token).
    """
    def __init__(self, backbone, feature_dim, num_classes):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        cls_emb = outputs.last_hidden_state[:, 0, :]  # [B, D]
        logits = self.head(cls_emb)                   # [B, num_classes]
        return logits

backbone = TinyTextBackbone(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    hidden_dim=hidden_dim,
    feature_dim=feature_dim,
    pad_idx=PAD_ID,
)
classifier = TinyTextClassifier(backbone, feature_dim, num_classes).to(device)

# ---------------------------
# 4. Quick training loop
# ---------------------------

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

for epoch in range(epochs):
    classifier.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for texts, labels in train_loader:
        labels = labels.to(device)
        # tokenize list of strings
        batch_tokens = tiny_tokenizer(texts, return_tensors="pt")
        input_ids = batch_tokens["input_ids"].to(device)
        attention_mask = batch_tokens["attention_mask"].to(device)

        optimizer.zero_grad()
        logits = classifier(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * input_ids.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = running_loss / len(train_ds)
    acc = correct / total if total > 0 else 0.0
    print(f"Epoch {epoch+1}/{epochs} - train loss: {avg_loss:.4f} - acc: {acc:.3f}")

# ---------------------------
# 5. Register the trained backbone with DLFeat
# ---------------------------

register_text_model(
    model_name="tiny_text_rnn",
    dim=feature_dim,
    model=backbone,          # pre-trained backbone instance
    tokenizer=tiny_tokenizer,
    overwrite=True,
)

# ---------------------------
# 6. Use DLFeatExtractor to obtain text features
# ---------------------------

extractor = DLFeatExtractor("tiny_text_rnn", device=device)

# Take a few validation sentences
val_texts = [val_ds[i][0] for i in range(6)]
features = extractor.transform(val_texts, batch_size=3)

print("Validation texts:", val_texts)
print("Extracted feature shape:", features.shape)  # (N_texts, feature_dim)


Epoch 1/5 - train loss: 1.1003 - acc: 0.290
Epoch 2/5 - train loss: 1.1012 - acc: 0.333
Epoch 3/5 - train loss: 1.0993 - acc: 0.333
Epoch 4/5 - train loss: 1.0993 - acc: 0.333
Epoch 5/5 - train loss: 1.1005 - acc: 0.333
Validation texts: ['this is class zero sample', 'this is class one example', 'this is class two item', 'this is class zero sample', 'this is class one example', 'this is class two item']
Extracted feature shape: (6, 32)


## Multimodal image–text custom model example
This section shows how to define, train, and register a tiny image–text encoder with DLFeat.  
We build a synthetic toy dataset of colored squares with simple text descriptions, train a lightweight multimodal classifier, then register the backbone with `register_multimodal_image_text_model` and use `DLFeatExtractor` to obtain aligned image and text embeddings.


In [15]:
# Multimodal image–text custom model example: tiny encoders + DLFeat registration

import types
import random

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from PIL import Image, ImageDraw
from torchvision import transforms as T

# ---------------------------
# Hyperparameters
# ---------------------------
num_classes  = 3               # e.g. red / green / blue
feature_dim  = 64              # shared embedding dim for image & text
img_size     = 64
embed_dim    = 32
hidden_dim   = 32
max_len      = 5               # max tokens in text

num_train_samples = 300
num_val_samples   = 60

batch_size = 32
epochs     = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------
# 1. Tiny synthetic image–text dataset
# ---------------------------

COLORS = {
    0: ("red",   (220, 40, 40)),
    1: ("green", (40, 180, 60)),
    2: ("blue",  (40, 80, 220)),
}

class TinyImageTextDataset(Dataset):
    """
    Each sample:
      - image: solid colored square (red/green/blue)
      - text:  "a red square", "a green square", etc.
      - label: 0, 1, or 2
    """
    def __init__(self, n_samples, num_classes=3, img_size=64):
        assert num_classes <= 3
        self.n_samples = n_samples
        self.num_classes = num_classes
        self.img_size = img_size

    def _make_image(self, color_rgb):
        img = Image.new("RGB", (self.img_size, self.img_size), color=color_rgb)
        # Optional: draw a little border just to make it less trivial
        draw = ImageDraw.Draw(img)
        draw.rectangle([4, 4, self.img_size - 5, self.img_size - 5], outline=(0, 0, 0), width=2)
        return img

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        cls = idx % self.num_classes
        color_name, rgb = COLORS[cls]
        img = self._make_image(rgb)
        text = f"a {color_name} square"
        return img, text, cls

train_ds = TinyImageTextDataset(num_train_samples, num_classes=num_classes, img_size=img_size)
val_ds   = TinyImageTextDataset(num_val_samples,   num_classes=num_classes, img_size=img_size)

def collate_fn(batch):
    """
    Keep PIL images and strings as lists (no default tensor collation),
    stack labels into a tensor.
    """
    images, texts, labels = zip(*batch)
    labels = torch.tensor(labels, dtype=torch.long)
    return list(images), list(texts), labels

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# ---------------------------
# 2. Tiny vocabulary & tokenizer (text side)
# ---------------------------

VOCAB_LIST = [
    "[PAD]", "[UNK]",
    "a", "red", "green", "blue", "square"
]
VOCAB = {w: i for i, w in enumerate(VOCAB_LIST)}
PAD_ID = VOCAB["[PAD]"]
UNK_ID = VOCAB["[UNK]"]
vocab_size = len(VOCAB_LIST)

def tiny_text_tokenizer(
    texts,
    padding=True,
    truncation=True,
    return_tensors="pt",
    max_length=max_len,
    **kwargs,
):
    """
    Minimal tokenizer with HF-like signature.
    Returns:
      {
        "input_ids": LongTensor [B, L],
        "attention_mask": LongTensor [B, L]
      }
    """
    if isinstance(texts, str):
        texts = [texts]

    batch_ids = []
    for t in texts:
        tokens = t.lower().split()
        ids = [VOCAB.get(tok, UNK_ID) for tok in tokens]

        if truncation and len(ids) > max_length:
            ids = ids[:max_length]

        if padding:
            while len(ids) < max_length:
                ids.append(PAD_ID)

        batch_ids.append(ids)

    input_ids = torch.tensor(batch_ids, dtype=torch.long)
    attention_mask = (input_ids != PAD_ID).long()
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# ---------------------------
# 3. Image preprocessing transform
# ---------------------------

img_transform = T.Compose([
    T.Resize((img_size, img_size)),
    T.ToTensor(),   # [C, H, W] in [0,1]
])

def preprocess_images(pil_images):
    tensors = [img_transform(img) for img in pil_images]  # list of [C,H,W]
    return torch.stack(tensors, dim=0)                    # [B,C,H,W]

# ---------------------------
# 4. Multimodal processor for DLFeat (image + text)
# ---------------------------

def tiny_multimodal_processor(
    text,
    images,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=max_len,
    **kwargs,
):
    """
    Processor with HF CLIP-like API:
      - text: list[str]
      - images: list[PIL.Image.Image]
    Returns a dict with keys:
      "pixel_values", "input_ids", "attention_mask"
    """
    # Tokenize text
    text_tokens = tiny_text_tokenizer(
        text,
        padding=padding,
        truncation=truncation,
        return_tensors=return_tensors,
        max_length=max_length,
    )
    # Preprocess images
    pixel_values = preprocess_images(images)  # [B,C,H,W]

    return {
        "pixel_values": pixel_values,
        "input_ids": text_tokens["input_ids"],
        "attention_mask": text_tokens["attention_mask"],
    }

# ---------------------------
# 5. Tiny image & text encoders + multimodal backbone
# ---------------------------

class TinyImageEncoder(nn.Module):
    """Simple CNN: (B,3,H,W) -> (B, feature_dim)."""
    def __init__(self, feature_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # H/2, W/2
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),  # [B, 32, 1, 1]
        )
        self.proj = nn.Linear(32, feature_dim)

    def forward(self, x):
        x = self.net(x)              # [B,32,1,1]
        x = x.view(x.size(0), 32)    # [B,32]
        return self.proj(x)          # [B,feature_dim]

class TinyTextEncoder(nn.Module):
    """GRU text encoder: (B,L) -> (B, feature_dim)."""
    def __init__(self, vocab_size, embed_dim, hidden_dim, feature_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.proj = nn.Linear(hidden_dim, feature_dim)

    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)         # [B,L,E]
        outputs, _ = self.rnn(x)              # [B,L,H]

        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()  # [B,L,1]
            outputs = outputs * mask
            lengths = mask.sum(dim=1).clamp(min=1.0)     # [B,1]
            pooled = outputs.sum(dim=1) / lengths        # [B,H]
        else:
            pooled = outputs.mean(dim=1)

        return self.proj(pooled)              # [B,feature_dim]

class TinyImageTextBackbone(nn.Module):
    """
    Multimodal backbone with HF-like output:
      forward(pixel_values, input_ids, attention_mask=None)
    -> object with .image_embeds and .text_embeds (both [B,feature_dim]).
    """
    def __init__(self, feature_dim, vocab_size, embed_dim, hidden_dim, pad_idx):
        super().__init__()
        self.image_encoder = TinyImageEncoder(feature_dim)
        self.text_encoder  = TinyTextEncoder(vocab_size, embed_dim, hidden_dim, feature_dim, pad_idx)

    def forward(self, pixel_values, input_ids, attention_mask=None):
        img_embeds  = self.image_encoder(pixel_values)
        text_embeds = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        return types.SimpleNamespace(
            image_embeds=img_embeds,
            text_embeds=text_embeds,
        )

class TinyImageTextClassifier(nn.Module):
    """
    Backbone + linear head for quick supervised training.
    We simply average image and text embeddings and classify.
    """
    def __init__(self, backbone, feature_dim, num_classes):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, pixel_values, input_ids, attention_mask=None):
        outputs = self.backbone(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        fused = 0.5 * (outputs.image_embeds + outputs.text_embeds)  # [B,feature_dim]
        logits = self.head(fused)
        return logits

backbone = TinyImageTextBackbone(
    feature_dim=feature_dim,
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    hidden_dim=hidden_dim,
    pad_idx=PAD_ID,
)
classifier = TinyImageTextClassifier(backbone, feature_dim, num_classes).to(device)

# ---------------------------
# 6. Quick training loop
# ---------------------------

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

for epoch in range(epochs):
    classifier.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, texts, labels in train_loader:
        labels = labels.to(device)

        # Preprocess with our multimodal processor
        proc = tiny_multimodal_processor(text=texts, images=images, return_tensors="pt")
        pixel_values   = proc["pixel_values"].to(device)
        input_ids      = proc["input_ids"].to(device)
        attention_mask = proc["attention_mask"].to(device)

        optimizer.zero_grad()
        logits = classifier(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * labels.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = running_loss / len(train_ds)
    acc = correct / total if total > 0 else 0.0
    print(f"Epoch {epoch+1}/{epochs} - train loss: {avg_loss:.4f} - acc: {acc:.3f}")

# ---------------------------
# 7. Register the trained multimodal backbone with DLFeat
# ---------------------------

register_multimodal_image_text_model(
    model_name="tiny_image_text_encoder",
    dim=feature_dim,
    model=backbone,                  # pre-trained backbone instance
    processor=tiny_multimodal_processor,
    overwrite=True,
)

# ---------------------------
# 8. Use DLFeatExtractor to get image & text features
# ---------------------------

extractor = DLFeatExtractor("tiny_image_text_encoder", device=device)

# Build a small list of (PIL.Image, text) pairs from the validation set
eval_pairs = []
for i in range(6):
    img, text, _ = val_ds[i]
    eval_pairs.append((img, text))

features = extractor.transform(eval_pairs, batch_size=3)

print("Number of eval pairs:", len(eval_pairs))
print("Image features shape:", features["image_features"].shape)  # (N, feature_dim)
print("Text  features shape:", features["text_features"].shape)   # (N, feature_dim)


Epoch 1/5 - train loss: 1.0809 - acc: 0.410
Epoch 2/5 - train loss: 0.9158 - acc: 0.933
Epoch 3/5 - train loss: 0.6145 - acc: 1.000
Epoch 4/5 - train loss: 0.2440 - acc: 1.000
Epoch 5/5 - train loss: 0.0508 - acc: 1.000
Number of eval pairs: 6
Image features shape: (6, 64)
Text  features shape: (6, 64)


## Multimodal video–text custom model example
In this example we build a tiny multimodal encoder that jointly processes short video clips and text descriptions.  
We train it on a synthetic dataset and then register the backbone with `register_multimodal_video_text_model` so that `DLFeatExtractor` can return aligned video and text features.

In [None]:
# Multimodal video–text custom model example: tiny encoder + DLFeat registration

import os
import tempfile
from types import SimpleNamespace

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from PIL import Image

# ---------------------------
# Hyperparameters
# ---------------------------
num_frames   = 8      # frames per clip
frame_size   = 64     # spatial resolution (H = W)
feature_dim  = 64     # output feature dimension for DLFeat
num_classes  = 3
num_train    = 24
num_val      = 6
batch_size   = 4
epochs       = 3
device       = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# 1. Create a tiny synthetic video–text dataset on disk
# ---------------------------

tmp_root = tempfile.mkdtemp(prefix="dlfeat_tiny_mv_")

COLORS = [
    ("red",   (255, 64, 64)),
    ("green", (64, 255, 64)),
    ("blue",  (64, 64, 255)),
]

def make_colored_video(path, rgb, num_frames=num_frames, size=frame_size):
    """
    Create a very simple RGB video: a solid-colored frame with a small moving square.
    """
    video = torch.zeros(num_frames, size, size, 3, dtype=torch.uint8)  # [T, H, W, C]
    base_color = torch.tensor(rgb, dtype=torch.uint8)

    for t in range(num_frames):
        video[t] = base_color
        # small darker square moving diagonally
        x0 = (t * 3) % (size - 10)
        y0 = (t * 2) % (size - 10)
        video[t, y0:y0+10, x0:x0+10, :] = base_color // 2

    # Requires PyAV installed
    torchvision.io.write_video(path, video, fps=8)

def generate_split(n_samples, split_name):
    paths, texts, labels = [], [], []
    for i in range(n_samples):
        cls = i % num_classes
        color_name, rgb = COLORS[cls]
        filename = os.path.join(tmp_root, f"{split_name}_{i:03d}_class{cls}.mp4")
        make_colored_video(filename, rgb)
        paths.append(filename)
        texts.append(f"a short video of a {color_name} moving square")
        labels.append(cls)
    return paths, texts, torch.tensor(labels, dtype=torch.long)

train_paths, train_texts, train_labels = generate_split(num_train, "train")
val_paths,   val_texts,   val_labels   = generate_split(num_val,   "val")


class VideoTextFileDataset(Dataset):
    """
    Dataset that stores (video_path, text, label).
    We will decode frames inside the training loop via a helper.
    """
    def __init__(self, paths, texts, labels):
        self.paths = paths
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return self.paths[idx], self.texts[idx], int(self.labels[idx])


def collate_video_text(batch):
    video_paths, texts, labels = zip(*batch)
    labels = torch.tensor(labels, dtype=torch.long)
    return list(video_paths), list(texts), labels


train_ds = VideoTextFileDataset(train_paths, train_texts, train_labels)
val_ds   = VideoTextFileDataset(val_paths,   val_texts,   val_labels)

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_video_text,
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_video_text,
)

# ---------------------------
# 2. Tiny video–text encoders
# ---------------------------

class TinyVideoEncoder(nn.Module):
    """Very small 3D CNN: (B, C, T, H, W) -> (B, feature_dim)."""
    def __init__(self, feature_dim=feature_dim):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d((1, 2, 2)),           # pool spatial dims
            nn.Conv3d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool3d((None, 1, 1)),  # keep time, pool H,W
        )
        self.proj = nn.Linear(32, feature_dim)

    def forward(self, videos):
        # videos: [B, C, T, H, W]
        x = self.features(videos)          # [B, 32, T, 1, 1]
        x = x.mean(dim=2)                  # temporal average -> [B, 32, 1, 1]
        x = x.view(x.size(0), 32)          # [B, 32]
        return self.proj(x)                # [B, feature_dim]


class TinyTextEncoder(nn.Module):
    """Simple GRU-based text encoder: token ids -> (B, feature_dim)."""
    def __init__(self, vocab_size=128, emb_dim=64, hidden_dim=feature_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)

    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)  # [B, L, E]
        outputs, _ = self.gru(x)
        if attention_mask is not None:
            attn = attention_mask.unsqueeze(-1).float()
            summed = (outputs * attn).sum(dim=1)
            lengths = attn.sum(dim=1).clamp(min=1.0)
            return summed / lengths
        else:
            return outputs.mean(dim=1)


class TinyVideoTextBackbone(nn.Module):
    """
    Multimodal backbone returning two aligned embeddings:
    - video_embeds: (B, feature_dim)
    - text_embeds:  (B, feature_dim)
    This is what DLFeatExtractor will call.
    """
    def __init__(self, feature_dim=feature_dim):
        super().__init__()
        self.video_encoder = TinyVideoEncoder(feature_dim=feature_dim)
        self.text_encoder  = TinyTextEncoder(hidden_dim=feature_dim)

    def forward(self, videos, input_ids, attention_mask=None):
        # videos: [B, C, T, H, W]
        v_emb = self.video_encoder(videos)
        t_emb = self.text_encoder(input_ids, attention_mask)
        return SimpleNamespace(video_embeds=v_emb, text_embeds=t_emb)


class TinyVideoTextClassifier(nn.Module):
    """
    Backbone + linear head for quick supervised training.
    We fuse video and text embeddings by simple averaging.
    """
    def __init__(self, backbone, num_classes):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, videos, input_ids, attention_mask=None):
        outs = self.backbone(videos=videos, input_ids=input_ids, attention_mask=attention_mask)
        fused = 0.5 * (outs.video_embeds + outs.text_embeds)
        return self.head(fused)


# ---------------------------
# 3. Tiny tokenizer + processor (DLFeat-style)
# ---------------------------

def tiny_text_tokenizer(texts, max_length=16):
    """
    Toy tokenizer:
    - lowercase
    - keep only printable ASCII (32..126)
    - map chars to ids in [1..127], 0 reserved for padding.
    """
    if isinstance(texts, str):
        texts = [texts]

    encoded = []
    for t in texts:
        t = t.lower()
        ids = []
        for ch in t:
            if " " <= ch <= "~":
                ids.append(ord(ch) - 31)  # 1..96-ish
        ids = ids[:max_length]
        if len(ids) < max_length:
            ids = ids + [0] * (max_length - len(ids))
        encoded.append(ids)

    input_ids = torch.tensor(encoded, dtype=torch.long)
    attention_mask = (input_ids != 0).long()
    return {"input_ids": input_ids, "attention_mask": attention_mask}


frame_transform = transforms.Compose(
    [
        transforms.Resize((frame_size, frame_size)),
        transforms.ToTensor(),  # [C,H,W], in [0,1]
    ]
)

def tiny_multimodal_video_processor(text, videos, return_tensors="pt", padding=True, **kwargs):
    """
    Processor compatible with DLFeatExtractor for 'multimodal_video_text'.

    Expected input (from DLFeatExtractor):
      - text: list of strings (one per video)
      - videos: list of lists of PIL.Image (frames for each clip)
    Returns a dict of tensors ready for the model(**inputs).
    """
    # 1) Process videos -> [B, C, T, H, W]
    video_tensors = []
    for clip in videos:  # clip: list[PIL.Image]
        frame_tensors = [frame_transform(f) for f in clip]           # list [C,H,W]
        clip_tensor = torch.stack(frame_tensors, dim=1)              # [C,T,H,W]
        video_tensors.append(clip_tensor)
    videos_batch = torch.stack(video_tensors, dim=0)                 # [B,C,T,H,W]

    # 2) Process text -> token ids + mask
    text_tokens = tiny_text_tokenizer(text)
    input_ids = text_tokens["input_ids"]
    attention_mask = text_tokens["attention_mask"]

    if return_tensors != "pt":
        raise ValueError("This tiny processor only supports return_tensors='pt'.")

    return {
        "videos": videos_batch,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }


# ---------------------------
# 4. Helper to read videos as lists of PIL frames for training
# ---------------------------

def read_video_as_pil_list(path, num_frames=num_frames):
    """
    Decode a video file into a list of PIL frames, uniformly sampled.
    """
    frames_tensor, _, _ = torchvision.io.read_video(path, pts_unit="sec")
    if frames_tensor.numel() == 0:
        raise ValueError(f"No frames found in {path}")

    total = frames_tensor.shape[0]
    if total < num_frames:
        pad = frames_tensor[-1:].repeat(num_frames - total, 1, 1, 1)
        frames_tensor = torch.cat([frames_tensor, pad], dim=0)
    else:
        idx = torch.linspace(0, total - 1, steps=num_frames).long()
        frames_tensor = frames_tensor[idx]

    frames = []
    for t in range(num_frames):
        frame_np = frames_tensor[t].numpy()  # [H,W,C]
        frames.append(Image.fromarray(frame_np))
    return frames


# ---------------------------
# 5. Quick supervised training on the synthetic dataset
# ---------------------------

backbone = TinyVideoTextBackbone(feature_dim=feature_dim)
model = TinyVideoTextClassifier(backbone=backbone, num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for video_paths, texts, labels in train_loader:
        # Decode videos to lists of PIL frames
        batch_videos = [read_video_as_pil_list(p) for p in video_paths]

        # Use the same processor we will give to DLFeat
        proc_out = tiny_multimodal_video_processor(text=texts, videos=batch_videos, return_tensors="pt")
        videos_tensor = proc_out["videos"].to(device)
        input_ids     = proc_out["input_ids"].to(device)
        attention_mask = proc_out["attention_mask"].to(device)

        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(videos=videos_tensor, input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * labels.size(0)

    avg_loss = running_loss / len(train_ds)
    print(f"Epoch {epoch+1}/{epochs} - train loss: {avg_loss:.4f}")

# ---------------------------
# 6. Register the trained backbone with DLFeat
# ---------------------------

register_multimodal_video_text_model(
    model_name="xclip_tiny_video_text",   # prefix 'xclip' so DLFeat uses frame lists
    dim=feature_dim,
    model=backbone,                       # we register only the backbone (features)
    processor=tiny_multimodal_video_processor,
    num_frames=num_frames,
    overwrite=True,
)

# ---------------------------
# 7. Use DLFeatExtractor to get aligned video/text features
# ---------------------------

extractor = DLFeatExtractor("xclip_tiny_video_text", device=device) # type: ignore

# Build a small list of (video_path, text) pairs (could be any external data)
eval_pairs = list(zip(val_paths[:4], val_texts[:4]))
features_dict = extractor.transform(eval_pairs, batch_size=2)

print("Video features shape:", features_dict["video_features"].shape)
print("Text  features shape:", features_dict["text_features"].shape)




Epoch 1/3 - train loss: 1.0771
Epoch 2/3 - train loss: 0.9520
Epoch 3/3 - train loss: 0.6978
Video features shape: (4, 64)
Text  features shape: (4, 64)
