In [1]:
# Install required packages for Step 4.
# You can re-run this cell if you create a new environment.

%pip install --upgrade pip

# Core libraries
%pip install torch torchvision torchaudio

# NLP / vision / training utilities
%pip install transformers datasets webdataset accelerate timm sentencepiece


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


# Step 4: Train Unimodal Experts (Text & Image)\n
\n
This notebook uses the packed shards from **Step 3** to train:\n
\n
- A **text-only expert** over the combined text (OCR + captions).\n
- An **image-only expert** over meme images.\n
\n
You can start with the 50-example pilot shards created in Step 3 and later point this\n
notebook at larger shards once you pack the full datasets.\n

In [2]:
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader

import webdataset as wds
import json

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoImageProcessor,
    AutoModelForImageClassification,
)

# Detect project root so this works whether you start Jupyter in the repo root
# or from inside Step_4/.
cwd = Path.cwd().resolve()
if (cwd / "Step_3").is_dir():
    root = cwd
else:
    root = cwd.parent

step3 = root / "Step_3"
shards_dir = step3 / "shards" / "train"

# For the 50-example pilot, we expect a single shard here
shard_pattern = str(shards_dir / "shard-000000.tar")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Model names (can be adjusted later)
TEXT_MODEL_NAME = "microsoft/mdeberta-v3-base"
VISION_MODEL_NAME = "google/siglip-base-patch16-384"  # or another SigLIP-compatible model

Using device: cpu


> **Note:** If you have not installed the required libraries yet, run the following\n
> in a separate cell (uncomment as needed):\n
\n
> ```bash\n
> pip install torch torchvision torchaudio\n
> pip install transformers datasets webdataset accelerate timm sentencepiece\n
> ```\n

In [3]:
from typing import Dict, Any, List, Optional

def make_text_dataset(shard_pattern: str, max_samples: Optional[int] = None):
    ds = (
        wds.WebDataset(shard_pattern, shardshuffle=False)
        .to_tuple("txt", "json")
    )

    out: List[Dict[str, Any]] = []
    for text_obj, meta_obj in ds:
        # text: bytes or str
        if isinstance(text_obj, (bytes, bytearray)):
            text = text_obj.decode("utf-8", errors="replace")
        else:
            text = str(text_obj)

        # meta: dict or bytes
        if isinstance(meta_obj, (bytes, bytearray)):
            meta = json.loads(meta_obj.decode("utf-8"))
        else:
            meta = meta_obj

        labels = (meta or {}).get("labels", {})
        y = labels.get("abuse_hate")
        if y is None:
            continue
        out.append({"text": text, "label": int(y)})
        if max_samples is not None and len(out) >= max_samples:
            break
    return out


def make_image_dataset(shard_pattern: str, max_samples: Optional[int] = None):
    ds = (
        wds.WebDataset(shard_pattern, shardshuffle=False)
        .decode("pil")
        .to_tuple("png", "json")
    )

    out: List[Dict[str, Any]] = []
    for img, meta_obj in ds:
        if isinstance(meta_obj, (bytes, bytearray)):
            meta = json.loads(meta_obj.decode("utf-8"))
        else:
            meta = meta_obj

        labels = (meta or {}).get("labels", {})
        y = labels.get("abuse_hate")
        if y is None:
            continue
        out.append({"image": img, "label": int(y)})
        if max_samples is not None and len(out) >= max_samples:
            break
    return out

text_examples = make_text_dataset(shard_pattern, max_samples=1000)
image_examples = make_image_dataset(shard_pattern, max_samples=1000)
print(f"Loaded {len(text_examples)} text examples and {len(image_examples)} image examples.")

Loaded 50 text examples and 50 image examples.


In [4]:
from torch.utils.data import Dataset
from typing import Dict, Any, List


class TextDataset(Dataset):
    def __init__(self, data: List[Dict[str, Any]], tokenizer, max_length: int = 256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.data[idx]
        text = item["text"]
        label = item["label"]
        enc = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        enc = {k: v.squeeze(0) for k, v in enc.items()}
        enc["labels"] = torch.tensor(label, dtype=torch.long)
        return enc


# Initialize tokenizer and model for text expert
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
model_text = AutoModelForSequenceClassification.from_pretrained(
    TEXT_MODEL_NAME,
    num_labels=2,
)
model_text.to(device)

# Use all loaded text examples as training data for this pilot
dataset_text = TextDataset(text_examples, tokenizer)
loader_text = DataLoader(dataset_text, batch_size=8, shuffle=True)

optimizer = torch.optim.AdamW(model_text.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

EPOCHS = 1  # pilot run; can be increased later

model_text.train()
for epoch in range(EPOCHS):
    total_loss = 0.0
    for batch in loader_text:
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch.pop("labels")
        outputs = model_text(**batch)
        logits = outputs.logits
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"[Text] Epoch {epoch + 1}/{EPOCHS} - loss: {total_loss / len(loader_text):.4f}")

# Save text expert
models_dir = root / "models"
models_dir.mkdir(exist_ok=True)
model_text.save_pretrained(models_dir / "text_expert")
tokenizer.save_pretrained(models_dir / "text_expert")
print("Saved text expert to", models_dir / "text_expert")

tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


spm.model:   0%|          | 0.00/4.31M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


pytorch_model.bin:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/mdeberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

[Text] Epoch 1/1 - loss: 0.6173
Saved text expert to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/text_expert


In [5]:
from torchvision import transforms
from torch.utils.data import Dataset
from typing import Dict, Any, List


class ImageDataset(Dataset):
    def __init__(self, data: List[Dict[str, Any]], image_processor, train: bool = True):
        self.data = data
        self.image_processor = image_processor
        self.train = train
        self.aug = transforms.Compose([
            transforms.RandomHorizontalFlip(),
        ])

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.data[idx]
        img = item["image"]
        label = item["label"]
        if self.train:
            img = self.aug(img)
        inputs = self.image_processor(images=img, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)
        return {
            "pixel_values": pixel_values,
            "labels": torch.tensor(label, dtype=torch.long),
        }


# Initialize image processor and model for vision expert
image_processor = AutoImageProcessor.from_pretrained(VISION_MODEL_NAME)
model_vision = AutoModelForImageClassification.from_pretrained(
    VISION_MODEL_NAME,
    num_labels=2,
)
model_vision.to(device)

# Dataset and loader for image examples
dataset_img = ImageDataset(image_examples, image_processor, train=True)
loader_img = DataLoader(dataset_img, batch_size=8, shuffle=True)

optimizer_v = torch.optim.AdamW(model_vision.parameters(), lr=1e-4)
criterion_v = nn.CrossEntropyLoss()

EPOCHS_V = 1  # pilot

model_vision.train()
for epoch in range(EPOCHS_V):
    total_loss = 0.0
    for batch in loader_img:
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch.pop("labels")
        outputs = model_vision(**batch)
        logits = outputs.logits
        loss = criterion_v(logits, labels)
        optimizer_v.zero_grad()
        loss.backward()
        optimizer_v.step()
        total_loss += loss.item()
    print(f"[Vision] Epoch {epoch + 1}/{EPOCHS_V} - loss: {total_loss / len(loader_img):.4f}")

# Save vision expert
models_dir = root / "models"
models_dir.mkdir(exist_ok=True)
model_vision.save_pretrained(models_dir / "vision_expert")
image_processor.save_pretrained(models_dir / "vision_expert")
print("Saved vision expert to", models_dir / "vision_expert")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Some weights of SiglipForImageClassification were not initialized from the model checkpoint at google/siglip-base-patch16-384 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Vision] Epoch 1/1 - loss: 0.7791
Saved vision expert to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/vision_expert
