# 3. Training a CLIP-Field

In this tutorial, we are going to create a CLIP-Field from our saved data. A CLIP-Field is an implicit neural field that maps 3D XYZ coordinates to higher-dimensional representations such as CLIP visual features and Sentence-BERT semantic embeddings.

In [1]:
import logging
import os
import pprint
import random
from typing import Dict, Union

import hydra
import numpy as np
import torch
import torch.nn.functional as F
import torchmetrics
import tqdm
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Subset

import wandb
import sys
sys.path.append('..')

In [2]:
from dataloaders import (
    R3DSemanticDataset,
    DeticDenseLabelledDataset,
    ClassificationExtractor,
)
from misc import ImplicitDataparallel
from grid_hash_model import GridCLIPModel

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## Load the data and create a model

Now, we will set up the constants and create the models.

In [3]:
# Training configuration shared by the cells below.

# Checkpoint output directory and the device everything is moved to.
SAVE_DIRECTORY = "../clip_implicit_model"
DEVICE = "cuda"

# Default multipliers for the image-side and label-side contrastive losses,
# respectively, when they are summed into the training loss.
IMAGE_TO_LABEL_CLIP_LOSS_SCALE = 1.0
LABEL_TO_IMAGE_LOSS_SCALE = 1.0
# Image-side sample weights are computed as exp(-EXP_DECAY_COEFF * distance).
EXP_DECAY_COEFF = 0.5
# A checkpoint is written whenever the epoch counter is a multiple of this.
SAVE_EVERY = 5

# Metric name -> torchmetrics class; instantiated once per averaging style.
METRICS = dict(accuracy=torchmetrics.Accuracy)

# Dataloader settings.
BATCH_SIZE = 11_000
NUM_WORKERS = 10

# Names of the pretrained encoders handed to the zero-shot classifier.
CLIP_MODEL_NAME = "ViT-B/32"
SBERT_MODEL_NAME = "all-mpnet-base-v2"

In [4]:
# Load the labelled dataset that was created and saved in the previous tutorial notebook.
training_data = torch.load("../detic_labeled_dataset.pt")

# Per-axis maximum and minimum of the labelled points (reduction over dim 0);
# only the values are needed, not the arg-indices.
max_coords = training_data._label_xyz.max(dim=0).values
min_coords = training_data._label_xyz.min(dim=0).values

In [5]:
# Build the implicit model. The two output sizes are read off the last
# dimension of the CLIP vectors stored in the first training sample.
first_sample = training_data[0]

label_model = GridCLIPModel(
    # Output representation sizes.
    image_rep_size=first_sample["clip_image_vector"].shape[-1],
    text_rep_size=first_sample["clip_vector"].shape[-1],
    # MLP settings.
    mlp_depth=1,
    mlp_width=600,
    # Hash-grid settings.
    log2_hashmap_size=20,
    num_levels=18,
    level_dim=8,
    per_level_scale=2,
    # Bounds of the labelled point cloud.
    max_coords=max_coords,
    min_coords=min_coords,
)
label_model = label_model.to(DEVICE)

## Training and evaluation code

Now, we will set up the training and the evaluation code. We will train the model to predict the CLIP/SBert features from the 3D coordinates with a contrastive loss. For evaluation, we will measure the zero-shot label accuracy of the model.

In [6]:
@torch.no_grad()
def zero_shot_eval(
    classifier: ClassificationExtractor,
    predicted_label_latents: torch.Tensor,
    predicted_image_latents: torch.Tensor,
    language_label_index: torch.Tensor,
    metric_calculators: Dict[str, Dict[str, torchmetrics.Metric]],
):
    """Evaluate the model on the zero-shot classification task.

    Runs under ``torch.no_grad()``, so nothing computed here contributes
    gradients.

    Args:
        classifier: Provides ``calculate_classifications`` (per-point class
            probabilities from the predicted latents) and
            ``total_label_classes``.
        predicted_label_latents: Model output passed to the classifier as
            ``model_text_features``.
        predicted_image_latents: Model output passed to the classifier as
            ``model_image_features``.
        language_label_index: Ground-truth class index per point; ``-1`` and
            any index ``>= classifier.total_label_classes`` are ignored.
            The trailing dimension is squeezed away, so an ``(N, 1)`` column
            is accepted.
        metric_calculators: If it has a non-empty ``"semantic"`` entry, every
            metric in that entry is updated with the masked probabilities and
            labels.

    Returns:
        A 0-dim tensor holding the cross-entropy over the labelled points, or
        a 0-dim zero when the batch contains no labelled point.
    """
    class_probs = classifier.calculate_classifications(
        model_text_features=predicted_label_latents,
        model_image_features=predicted_image_latents,
    )
    # Now figure out semantic accuracy and loss.
    # Semseg mask is necessary for the boundary case where all the points in
    # the batch are "unlabeled".
    semseg_mask = torch.logical_and(
        language_label_index != -1,
        language_label_index < classifier.total_label_classes,
    ).squeeze(-1)
    if not torch.any(semseg_mask):
        # Build the zero as a floating-point scalar: zeros_like() on the
        # boolean mask yields a bool tensor, and mean() is not defined for
        # bool, so the fallback itself used to raise.
        return torch.zeros((), dtype=class_probs.dtype, device=class_probs.device)

    # Figure out the right classes.
    masked_class_prob = class_probs[semseg_mask]
    masked_labels = language_label_index[semseg_mask].squeeze(-1).long()
    # cross_entropy applies log_softmax internally; feeding it log-probs
    # renormalises them, which is a no-op if the probabilities already sum
    # to one.
    classification_loss = F.cross_entropy(
        torch.log(masked_class_prob),
        masked_labels,
    )
    if metric_calculators.get("semantic"):
        # Calling a metric updates its running state; the per-batch value it
        # returns is not needed here.
        for calculator in metric_calculators["semantic"].values():
            calculator(masked_class_prob, masked_labels)
    return classification_loss

In [7]:
def train(
    clip_train_loader: DataLoader,
    labelling_model: Union[GridCLIPModel, ImplicitDataparallel],
    optim: torch.optim.Optimizer,
    epoch: int,
    classifier: ClassificationExtractor,
    device: Union[str, torch.device] = DEVICE,
    exp_decay_coeff: float = EXP_DECAY_COEFF,
    image_to_label_loss_ratio: float = IMAGE_TO_LABEL_CLIP_LOSS_SCALE,
    label_to_image_loss_ratio: float = LABEL_TO_IMAGE_LOSS_SCALE,
    disable_tqdm: bool = False,
    metric_calculators: Union[Dict[str, Dict[str, torchmetrics.Metric]], None] = None,
):
    """
    Train the model for one epoch.

    Args:
        clip_train_loader: Yields dict batches with the keys ``"xyz"``,
            ``"clip_vector"``, ``"clip_image_vector"``, ``"distance"``,
            ``"semantic_weight"``, ``"img_idx"`` and ``"label"``.
        labelling_model: Called on the XYZ batch; must return the pair
            ``(label_latents, image_latents)`` and expose ``compute_loss``
            and ``temperature``.
        optim: Optimizer stepped once per batch.
        epoch: Only used for the progress-bar description.
        classifier: Zero-shot classifier forwarded to ``zero_shot_eval``.
        device: Device every batch tensor is moved to.
        exp_decay_coeff: Image-side sample weights are
            ``exp(-exp_decay_coeff * distance)``.
        image_to_label_loss_ratio: Multiplier on the image contrastive loss.
        label_to_image_loss_ratio: Multiplier on the label contrastive loss.
        disable_tqdm: Turn the progress bar off.
        metric_calculators: ``{group: {metric_name: metric}}``. Metrics are
            updated inside ``zero_shot_eval``, then computed, logged and
            reset at the end of the epoch. ``None`` means no metrics.

    Returns:
        The contrastive loss summed (not averaged) over all batches.
    """
    # A `{}` default would be one dict object shared by every call.
    if metric_calculators is None:
        metric_calculators = {}

    total_loss = 0
    label_loss = 0
    image_loss = 0
    total_classification_loss = 0
    num_batches = 0
    labelling_model.train()
    for clip_data_dict in tqdm.tqdm(
        clip_train_loader,
        total=len(clip_train_loader),
        disable=disable_tqdm,
        desc=f"Training epoch {epoch}",
    ):
        xyzs = clip_data_dict["xyz"].to(device)
        clip_labels = clip_data_dict["clip_vector"].to(device)
        clip_image_labels = clip_data_dict["clip_image_vector"].to(device)
        image_weights = torch.exp(-exp_decay_coeff * clip_data_dict["distance"]).to(
            device
        )
        label_weights = clip_data_dict["semantic_weight"].to(device)
        # Column vectors, so comparing against the transpose gives an N x N
        # pairwise matrix.
        image_label_index: torch.Tensor = (
            clip_data_dict["img_idx"].to(device).reshape(-1, 1)
        )
        language_label_index: torch.Tensor = (
            clip_data_dict["label"].to(device).reshape(-1, 1)
        )

        (predicted_label_latents, predicted_image_latents) = labelling_model(xyzs)

        # Pairwise masks: entry (i, j) is 1 when samples i and j carry
        # different indices, plus 1 on the diagonal.
        # NOTE(review): presumably this keeps same-index pairs from being
        # treated as negatives inside compute_loss -- confirm there.
        batch_size = len(image_label_index)
        image_label_mask: torch.Tensor = (
            image_label_index != image_label_index.t()
        ).float() + torch.eye(batch_size, device=device)
        language_label_mask: torch.Tensor = (
            language_label_index != language_label_index.t()
        ).float() + torch.eye(batch_size, device=device)

        # The masks are constants; make explicit that they carry no gradient.
        image_label_mask.requires_grad = False
        language_label_mask.requires_grad = False
        contrastive_loss_labels = labelling_model.compute_loss(
            predicted_label_latents,
            clip_labels,
            label_mask=language_label_mask,
            weights=label_weights,
        )
        contrastive_loss_images = labelling_model.compute_loss(
            predicted_image_latents,
            clip_image_labels,
            label_mask=image_label_mask,
            weights=image_weights,
        )
        # Drop the N x N masks as soon as the losses are computed.
        del (
            image_label_mask,
            image_label_index,
            language_label_mask,
        )

        # Mostly for evaluation purposes, calculate the classification loss.
        # It is logged only and never back-propagated.
        classification_loss = zero_shot_eval(
            classifier,
            predicted_label_latents,
            predicted_image_latents,
            language_label_index,
            metric_calculators,
        )

        contrastive_loss = (
            image_to_label_loss_ratio * contrastive_loss_images
            + label_to_image_loss_ratio * contrastive_loss_labels
        )

        optim.zero_grad(set_to_none=True)
        contrastive_loss.backward()
        optim.step()
        # Clip the temperature term for stability
        labelling_model.temperature.data = torch.clamp(
            labelling_model.temperature.data, max=np.log(100.0)
        )
        label_loss += contrastive_loss_labels.item()
        image_loss += contrastive_loss_images.item()
        total_classification_loss += classification_loss.item()
        total_loss += contrastive_loss.item()
        num_batches += 1

    if num_batches == 0:
        logging.warning("Training loader yielded no batches in epoch %s.", epoch)
    # Avoid a ZeroDivisionError on an empty loader; all sums are 0 then.
    denom = max(num_batches, 1)

    to_log = {
        "train_avg/contrastive_loss_labels": label_loss / denom,
        "train_avg/contrastive_loss_images": image_loss / denom,
        "train_avg/semseg_loss": total_classification_loss / denom,
        "train_avg/loss_sum": total_loss / denom,
        # The exponential of the stored temperature is what gets logged.
        "train_avg/labelling_temp": torch.exp(labelling_model.temperature.data.detach())
        .cpu()
        .item(),
    }
    for metric_dict in metric_calculators.values():
        for metric_name, metric in metric_dict.items():
            try:
                to_log[f"train_avg/{metric_name}"] = (
                    metric.compute().detach().cpu().item()
                )
            except RuntimeError:
                # Best-effort: a metric that cannot be computed is logged as
                # 0.0 (presumably when it received no updates this epoch).
                to_log[f"train_avg/{metric_name}"] = 0.0
            metric.reset()
    wandb.log(to_log)
    logging.debug(pprint.pformat(to_log, indent=4, width=1))
    return total_loss

In [8]:
def save(
    labelling_model: Union[ImplicitDataparallel, GridCLIPModel],
    optim: torch.optim.Optimizer,
    epoch: int,
    save_directory: str = SAVE_DIRECTORY,
    saving_dataparallel: bool = False,
):
    """Write the latest training checkpoint to ``save_directory``.

    The checkpoint is a dict with the keys ``"model"`` (model state dict),
    ``"optim"`` (optimizer state dict) and ``"epoch"``. It is always written
    to ``implicit_scene_label_model_latest.pt``, overwriting any previous
    checkpoint.

    Args:
        labelling_model: Model to checkpoint.
        optim: Optimizer whose state is stored next to the model.
        epoch: Epoch counter stored in the checkpoint.
        save_directory: Target directory; created if it does not exist.
        saving_dataparallel: If True, the state dict is taken from
            ``labelling_model.module`` instead of the wrapper itself.

    Returns:
        Always ``0``.
    """
    to_save = labelling_model.module if saving_dataparallel else labelling_model
    checkpoint = {
        "model": to_save.state_dict(),
        "optim": optim.state_dict(),
        "epoch": epoch,
    }
    # torch.save does not create missing parent directories.
    os.makedirs(save_directory, exist_ok=True)
    torch.save(
        checkpoint,
        os.path.join(save_directory, "implicit_scene_label_model_latest.pt"),
    )
    return 0

## Set up the auxiliary classes

These include the zero-shot classifier, the dataloader, the metric evaluators, the optimizer, etc.

In [9]:
# Zero-shot classifier built from the dataset's class names and the two
# pretrained encoder names.
train_classifier = ClassificationExtractor(
    class_names=training_data._all_classes,
    clip_model_name=CLIP_MODEL_NAME,
    sentence_model_name=SBERT_MODEL_NAME,
    device=DEVICE,
)

INFO - 2022-10-11 10:25:47,753 - SentenceTransformer - Load pretrained SentenceTransformer: all-mpnet-base-v2


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

In [10]:
# Set up our metrics on this dataset.
# Result layout: {group: {"<group>_<metric>_<average>": metric_instance}}.
average_style = ["micro", "macro", "weighted"]
train_class_count = {"semantic": train_classifier.total_label_classes}
train_metric_calculators = {}
for group, num_classes in train_class_count.items():
    # One instance per (accuracy-type metric, averaging style) pair; metrics
    # whose name does not contain "accuracy" are skipped.
    train_metric_calculators[group] = {
        f"{group}_{name}_{avg}": metric_cls(
            num_classes=num_classes, average=avg, multiclass=True
        ).to(DEVICE)
        for name, metric_cls in METRICS.items()
        if "accuracy" in name
        for avg in average_style
    }


In [11]:
# No dataparallel for now, so the per-step batch size is not scaled up.
batch_multiplier = 1

loader_options = dict(
    batch_size=batch_multiplier * BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)
clip_train_loader = DataLoader(training_data, **loader_options)
logging.debug(f"Total train dataset sizes: {len(training_data)}")

In [12]:
# Set up the optimizer: Adam over all model parameters.
adam_options = dict(
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=0.003,
)
optim = torch.optim.Adam(label_model.parameters(), **adam_options)

## Model training

Now we run our training loop and save the model occasionally. We ran this for 5 epochs just to validate that everything is working properly; to train a full model, you should train for longer.

In [13]:
# Start the Weights & Biases run that the training function logs to.
wandb.init(project="clipfields")
# Set the extra parameters: record the dataset size in the run config.
wandb.config.web_labelled_points = len(training_data)

ERROR - 2022-10-11 10:25:49,099 - jupyter - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: Currently logged in as: [33mmahi[0m. Use [1m`wandb login --relogin`[0m to force relogin


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.03333655993143717, max=1.0)…

In [14]:
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Just to reduce excessive logging from sbert

epoch = 0
NUM_EPOCHS = 5

# Strict `<`: with `<=` the loop would run NUM_EPOCHS + 1 epochs, and the
# extra epoch would train the model without ever being checkpointed.
while epoch < NUM_EPOCHS:
    train(
        clip_train_loader,
        label_model,
        optim,
        epoch,
        train_classifier,
        metric_calculators=train_metric_calculators,
    )
    # From here on `epoch` is the number of completed epochs.
    epoch += 1
    # Checkpoint every SAVE_EVERY epochs, and always after the final epoch so
    # the last weights are kept even if NUM_EPOCHS is not a multiple of it.
    if epoch % SAVE_EVERY == 0 or epoch == NUM_EPOCHS:
        save(label_model, optim, epoch)

Training epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.25it/s]
Training epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.34it/s]
Training epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.33it/s]
Training epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.35it/s]
Training epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

The model is already saved in `../clip_implicit_model`, so we don't have to save our trained model again. You can see the run data on [Weights & Biases](https://wandb.ai/mahi/clipfields/runs/j12j175e). In the next tutorial, we will evaluate our model.