# Step 5: Multimodal Fusion (Pilot)

This notebook trains a **small multimodal fusion model** on the packed shards from **Step 3**, reusing the **text expert** and **vision expert** trained in **Step 4**.

For now we run on the 50-example pilot shard (Hateful Memes) to validate the full pipeline.


In [1]:
# Install required packages for Step 5 (run once per environment).
# You can skip this cell if everything is already installed.

%pip install --upgrade pip

# Core libraries
%pip install torch torchvision torchaudio

# NLP / vision / training utilities
%pip install transformers datasets webdataset accelerate timm sentencepiece


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from pathlib import Path

import json

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

import webdataset as wds

from transformers import AutoTokenizer, AutoImageProcessor, AutoModel, AutoModelForImageClassification

# Detect project root so this works whether you start Jupyter in the repo root
# or from inside Step_5/.
cwd = Path.cwd().resolve()
if (cwd / "Step_3").is_dir():
    root = cwd
else:
    root = cwd.parent

step3 = root / "Step_3"
shards_dir = step3 / "shards" / "train"
shard_pattern = str(shards_dir / "shard-000000.tar")  # 50-example pilot shard

models_root = root / "models"
text_expert_dir = models_root / "text_expert"
vision_expert_dir = models_root / "vision_expert"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
print("Shard pattern:", shard_pattern)
print("Text expert dir:", text_expert_dir)
print("Vision expert dir:", vision_expert_dir)


Using device: cpu
Shard pattern: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Step_3/shards/train/shard-000000.tar
Text expert dir: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/text_expert
Vision expert dir: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/vision_expert


In [3]:
from typing import Dict, Any, List, Optional


def make_fusion_examples(shard_pattern: str, max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
    """Create a list of {text, image, label} from WebDataset shards.

    Uses:
    - '<id>.txt'  for combined text
    - '<id>.png'  for image (already resized in Step 3)
    - '<id>.json' for labels (expects labels.abuse_hate in {0,1})
    """

    ds = (
        wds.WebDataset(shard_pattern, shardshuffle=False)
        .decode("pil")
        .to_tuple("txt", "png", "json")
    )

    out: List[Dict[str, Any]] = []
    for text_obj, img, meta_obj in ds:
        # text may be bytes or str
        if isinstance(text_obj, (bytes, bytearray)):
            text = text_obj.decode("utf-8", errors="replace")
        else:
            text = str(text_obj)

        # meta may be dict or bytes
        if isinstance(meta_obj, (bytes, bytearray)):
            meta = json.loads(meta_obj.decode("utf-8"))
        else:
            meta = meta_obj

        labels = (meta or {}).get("labels", {})
        y = labels.get("abuse_hate")
        if y is None:
            continue

        out.append({
            "text": text,
            "image": img,
            "label": int(y),
        })

        if max_samples is not None and len(out) >= max_samples:
            break

    return out


fusion_examples = make_fusion_examples(shard_pattern, max_samples=1000)
print(f"Loaded {len(fusion_examples)} multimodal examples.")


Loaded 50 multimodal examples.


In [4]:
from torchvision import transforms


class FusionDataset(Dataset):
    def __init__(
        self,
        data: List[Dict[str, Any]],
        text_tokenizer,
        image_processor,
        max_length: int = 256,
        train: bool = True,
    ) -> None:
        self.data = data
        self.text_tokenizer = text_tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        self.train = train
        self.aug = transforms.Compose([
            transforms.RandomHorizontalFlip(),
        ])

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.data[idx]
        text = item["text"]
        img = item["image"]
        label = item["label"]

        if self.train:
            img = self.aug(img)

        text_enc = self.text_tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        text_enc = {k: v.squeeze(0) for k, v in text_enc.items()}

        img_enc = self.image_processor(images=img, return_tensors="pt")
        pixel_values = img_enc["pixel_values"].squeeze(0)

        return {
            "input_ids": text_enc["input_ids"],
            "attention_mask": text_enc["attention_mask"],
            "pixel_values": pixel_values,
            "labels": torch.tensor(label, dtype=torch.long),
        }


# Load tokenizers / processors from unimodal experts
text_tokenizer = AutoTokenizer.from_pretrained(text_expert_dir)
image_processor = AutoImageProcessor.from_pretrained(vision_expert_dir)

dataset_fusion = FusionDataset(
    fusion_examples,
    text_tokenizer=text_tokenizer,
    image_processor=image_processor,
    max_length=256,
    train=True,
)

loader_fusion = DataLoader(dataset_fusion, batch_size=8, shuffle=True)
print("Batches per epoch:", len(loader_fusion))


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Batches per epoch: 7


In [5]:
# Build fusion model on top of frozen text and vision encoders

# Load encoders from the fine-tuned unimodal experts
text_encoder = AutoModel.from_pretrained(text_expert_dir)
vision_encoder = AutoModelForImageClassification.from_pretrained(vision_expert_dir)

text_encoder.to(device)
vision_encoder.to(device)

# Freeze encoder parameters for this pilot (only train fusion head)
for p in text_encoder.parameters():
    p.requires_grad = False
for p in vision_encoder.parameters():
    p.requires_grad = False

# Determine feature dimensions
# Text: use hidden_size from the backbone
t_dim = text_encoder.config.hidden_size
# Vision: we will use the classifier logits as features (dim = num_labels)
v_dim = vision_encoder.config.num_labels

print("Text feature dim:", t_dim)
print("Vision feature dim:", v_dim)

fusion_hidden = 512
num_labels = 2


class FusionModel(nn.Module):
    def __init__(
        self,
        text_encoder: nn.Module,
        vision_encoder: nn.Module,
        t_dim: int,
        v_dim: int,
        hidden_dim: int,
        num_labels: int,
    ) -> None:
        super().__init__()
        self.text_encoder = text_encoder
        self.vision_encoder = vision_encoder
        self.mlp = nn.Sequential(
            nn.Linear(t_dim + v_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_labels),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # Extract CLS / pooled representations without updating encoder weights
        with torch.no_grad():
            text_out = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            # Prefer pooler_output if available, else CLS token
            if hasattr(text_out, "pooler_output") and text_out.pooler_output is not None:
                t_repr = text_out.pooler_output
            else:
                t_repr = text_out.last_hidden_state[:, 0, :]

            # For the vision expert, use classification logits as a compact feature
            vision_out = self.vision_encoder(pixel_values=pixel_values)
            v_repr = vision_out.logits

        h = torch.cat([t_repr, v_repr], dim=-1)
        logits = self.mlp(h)
        return logits


fusion_model = FusionModel(
    text_encoder=text_encoder,
    vision_encoder=vision_encoder,
    t_dim=t_dim,
    v_dim=v_dim,
    hidden_dim=fusion_hidden,
    num_labels=num_labels,
).to(device)

optimizer = torch.optim.AdamW(fusion_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

EPOCHS = 1  # pilot run

for epoch in range(EPOCHS):
    fusion_model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader_fusion:
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch.pop("labels")

        logits = fusion_model(**batch)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / max(1, len(loader_fusion))
    acc = correct / total if total > 0 else 0.0
    print(f"[Fusion] Epoch {epoch + 1}/{EPOCHS} - loss: {avg_loss:.4f} - acc: {acc:.3f}")

# Save fusion model weights (pilot artifact)
mm_dir = models_root / "mm_fusion"
mm_dir.mkdir(exist_ok=True)
fusion_path = mm_dir / "fusion_model.pt"
torch.save(fusion_model.state_dict(), fusion_path)
print("Saved fusion model weights to", fusion_path)


Text feature dim: 768
Vision feature dim: 2
[Fusion] Epoch 1/1 - loss: 0.4944 - acc: 0.840
Saved fusion model weights to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/mm_fusion/fusion_model.pt
