# 03a · Using `InfoNCELoss` with `transformers.Trainer`

This notebook shows how to integrate the core `InfoNCELoss` module with the standard Hugging Face `Trainer`.

We demonstrate:

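1. Preparing the SFT dataset with `tokenizer.apply_chat_template` so that only assistant tokens contribute to the loss (every other position is masked with `-100`), and tokenising premise/hypothesis pairs for the contrastive dataset.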
2. Extending `Trainer` to combine SFT and InfoNCE losses via `combine_craft_losses`.
3. Logging basic metrics such as contrastive accuracy.

> ⚠️ Update checkpoints, dataset sizes, and output directories before running.


## 0. Optional environment setup


In [None]:
# !pip install -U "contrastive-ft @ git+https://github.com/omarkamali/craft"
# !pip install -U "datasets>=2.19" "transformers>=4.43" "accelerate>=0.30"

## 1. Imports


In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

from craft.data import CRAFTCollator, make_craft_datasets, CRAFTMixedDataLoader
from craft.losses import InfoNCELoss, combine_craft_losses
from craft.metrics import compute_contrastive_accuracy


## 2. Load demo datasets


In [None]:
# Keep the demo tiny: a few hundred rows per corpus is enough to smoke-test the loop.
# The `datasets` slice syntax only accepts integer bounds (absolute rows or whole
# percentages), so a fractional percentage such as "0.2%" is rejected; absolute row
# counts of roughly the same size are used instead.
# UltraChat 200k publishes `train_sft` / `test_sft` / `train_gen` / `test_gen` splits
# rather than a plain `train` split.
sft_raw = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:400]")
# all-nli ships several configurations; "pair-class" is the one exposing the
# `premise` / `hypothesis` columns that `encode_contrastive` reads.
# NOTE(review): "pair-class" rows also carry an NLI `label`; every pair is treated as a
# positive here — consider filtering to entailment pairs for a real run.
contrastive_raw = load_dataset("sentence-transformers/all-nli", "pair-class", split="train[:2000]")


## 3. Tokeniser helpers


In [None]:
# Sequence budget shared by both encoders: every example is padded/truncated to it.
MAX_LENGTH = 1024

# Load the Zephyr tokenizer and pad with the end-of-sequence token.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
tokenizer.pad_token = tokenizer.eos_token

def encode_sft(example):
    """Tokenise one chat example and build SFT labels restricted to assistant tokens.

    Args:
        example: Dataset row whose ``"messages"`` entry is passed to
            ``tokenizer.apply_chat_template``.

    Returns:
        Dict of plain Python lists (``input_ids``, ``attention_mask``, ``labels``,
        ``assistant_mask``), padded/truncated to ``MAX_LENGTH``. ``labels`` holds
        ``-100`` at every position that is not supervised.
    """
    encoded = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=True,
        add_generation_prompt=False,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    # A single conversation comes back with a leading batch dimension of 1.
    input_ids = encoded["input_ids"][0]
    attention_mask = encoded["attention_mask"][0]
    # Normalise to a tensor so the masking below is elementwise even if the mask is
    # handed back as a plain list.
    # NOTE(review): whether `assistant_masks` is converted to a tensor appears to depend
    # on the installed transformers version — confirm.
    assistant_mask = torch.as_tensor(encoded["assistant_masks"][0], dtype=input_ids.dtype)

    supervised = assistant_mask.bool()
    if not supervised.any():
        # The assistant mask is only populated for chat templates that mark assistant
        # turns with the `{% generation %}` keyword; otherwise it is all zeros, every
        # label would be -100 and the example would carry no training signal.
        # Fall back to supervising all real (non-padding) tokens.
        # NOTE(review): this also triggers when truncation drops every assistant token;
        # confirm the Zephyr template's generation markers before relying on the mask.
        supervised = attention_mask.bool()

    labels = input_ids.clone().masked_fill(~supervised, -100)
    return {
        "input_ids": input_ids.tolist(),
        "attention_mask": attention_mask.tolist(),
        "labels": labels.tolist(),
        "assistant_mask": assistant_mask.tolist(),
    }

def encode_contrastive(example):
    """Tokenise a premise/hypothesis pair into anchor and target (``*_tgt``) fields.

    Both texts are padded/truncated to ``MAX_LENGTH``; the returned values are plain
    Python lists taken from the first (only) row of each encoding.
    """
    shared_kwargs = {
        "padding": "max_length",
        "truncation": True,
        "max_length": MAX_LENGTH,
        "return_tensors": "pt",
    }
    premise_enc = tokenizer(example["premise"], **shared_kwargs)
    hypothesis_enc = tokenizer(example["hypothesis"], **shared_kwargs)

    record = {}
    # Anchor side (premise).
    record["input_ids"] = premise_enc["input_ids"][0].tolist()
    record["attention_mask"] = premise_enc["attention_mask"][0].tolist()
    # Target side (hypothesis).
    record["input_ids_tgt"] = hypothesis_enc["input_ids"][0].tolist()
    record["attention_mask_tgt"] = hypothesis_enc["attention_mask"][0].tolist()
    return record


## 4. Tokenise datasets


In [None]:
# Run each encoder over its raw split, dropping the original columns so that only
# the fields returned by the encoder remain.
tokenized_sft = sft_raw.map(
    encode_sft,
    remove_columns=sft_raw.column_names,
)
tokenized_contrastive = contrastive_raw.map(
    encode_contrastive,
    remove_columns=contrastive_raw.column_names,
)


## 5. Build bundle and collator


In [None]:
# Pair the tokenised SFT split with the tokenised contrastive split in one bundle.
bundle = make_craft_datasets(
    tokenized_sft,
    strategy="paired_dataset",
    contrastive_dataset=tokenized_contrastive,
)

# Collator built with its default settings.
collator = CRAFTCollator()


## 6. Model and loss objects


In [None]:
# Mixing weight later handed to the trainer as `craft_alpha`.
CRAFT_ALPHA = 0.6

# Load the causal LM and switch off `use_cache` in its config for training.
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model.config.use_cache = False

# Contrastive objective, sized to the model's hidden width.
craft_loss = InfoNCELoss(
    hidden_size=model.config.hidden_size,
    pooling="last_token",
    temperature=0.08,
)


## 7. Custom Trainer


In [None]:
class CraftTrainer(Trainer):
    """`Trainer` subclass that mixes SFT batches with InfoNCE contrastive batches.

    The training dataloader is wrapped in a `CRAFTMixedDataLoader`, and
    `compute_loss` dispatches on the `craft_batch_type` entry of each batch:
    `"craft"` batches go through the contrastive loss, everything else through the
    model's own language-modelling loss.

    Keyword Args:
        craft_bundle: Object exposing a `contrastive_dataset` attribute.
        craft_loss: Contrastive loss module; moved to the model's device.
        craft_alpha: Weight forwarded to `combine_craft_losses` as `alpha`.
        length_strategy: Forwarded to `CRAFTMixedDataLoader` as `length_strategy`.
        craft_beta: Forwarded to `CRAFTMixedDataLoader` as `beta` (defaults to the
            previously hard-coded 0.5).
    """

    def __init__(
        self,
        *args,
        craft_bundle,
        craft_loss,
        craft_alpha=0.5,
        length_strategy="oversample",
        craft_beta=0.5,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.craft_bundle = craft_bundle
        self.craft_loss = craft_loss.to(self.model.device)
        self.craft_alpha = craft_alpha
        self.craft_length_strategy = length_strategy
        self.craft_beta = craft_beta
        # `compute_loss` accepts `num_items_in_batch` for signature compatibility but
        # does not use it to normalise the loss, so ask the Trainer to keep applying
        # its own gradient-accumulation scaling.
        # NOTE(review): attribute is only consulted by recent transformers releases —
        # confirm against the installed version.
        self.model_accepts_loss_kwargs = False

    def get_train_dataloader(self):
        """Return the SFT dataloader interleaved with a contrastive dataloader."""
        base_loader = super().get_train_dataloader()
        # Fall back to the configured batch size if the base loader does not carry one.
        batch_size = getattr(base_loader, "batch_size", None) or self.args.train_batch_size
        contrastive_loader = self._build_contrastive_loader(batch_size)
        return CRAFTMixedDataLoader(
            base_loader,
            contrastive_loader,
            beta=self.craft_beta,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            length_strategy=self.craft_length_strategy,
        )

    def _build_contrastive_loader(self, batch_size):
        """Build a shuffled `DataLoader` over the bundle's contrastive dataset."""
        from torch.utils.data import DataLoader

        return DataLoader(
            self.craft_bundle.contrastive_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
        )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss for one batch, dispatching on `craft_batch_type`.

        Args:
            model: Model to run.
            inputs: Batch dict; `craft_batch_type` is popped from it (default `"sft"`).
            return_outputs: When true, return `(loss, outputs)`; `outputs` is `None`
                for contrastive batches.
            num_items_in_batch: Accepted because newer `Trainer` versions pass it;
                unused here.

        Returns:
            The combined loss, or `(loss, outputs)` when `return_outputs` is true.
        """
        batch_type = inputs.pop("craft_batch_type", "sft")
        if batch_type == "craft":
            loss = self._compute_craft_loss(model, inputs)
            outputs = None
        else:
            outputs = model(**inputs)
            # Keep the raw SFT loss under its own name so the two logged values are
            # not both the combined loss.
            sft_loss = outputs.loss
            loss = combine_craft_losses(
                sft_loss=sft_loss, contrastive_loss=None, alpha=self.craft_alpha
            ).total_loss
            self.log({
                "loss/craft_sft": float(sft_loss.detach()),
                "loss/craft_total": float(loss.detach()),
            })
        return (loss, outputs) if return_outputs else loss

    def _compute_craft_loss(self, model, inputs):
        """Run anchor and positive through the model and return the combined contrastive loss.

        Also logs the raw contrastive loss, the combined loss and the contrastive
        accuracy computed from the embeddings returned by the loss module.
        """
        anchor_ids = inputs["input_ids"]
        anchor_mask = inputs["attention_mask"]
        positive_ids = inputs["input_ids_tgt"]
        positive_mask = inputs["attention_mask_tgt"]

        # Two forward passes; the last hidden-state layer of each feeds the loss.
        anchor_outputs = model(input_ids=anchor_ids, attention_mask=anchor_mask, output_hidden_states=True)
        positive_outputs = model(input_ids=positive_ids, attention_mask=positive_mask, output_hidden_states=True)

        loss, details = self.craft_loss(
            anchor_outputs.hidden_states[-1],
            positive_outputs.hidden_states[-1],
            anchor_mask,
            positive_mask,
            return_details=True,
        )
        total = combine_craft_losses(sft_loss=None, contrastive_loss=loss, alpha=self.craft_alpha)
        accuracy = compute_contrastive_accuracy(details["anchor_embeddings"], details["positive_embeddings"])
        self.log({
            "loss/craft_contrast": float(loss.detach()),
            "loss/craft_total": float(total.total_loss.detach()),
            "metrics/craft_contrastive_accuracy": float(accuracy.detach()),
        })
        return total.total_loss


## 8. Training arguments


In [None]:
# Settings for a short, single-epoch demo run.
training_args = TrainingArguments(
    # Where checkpoints and the final model are written.
    output_dir="./outputs/craft-transformers-trainer",
    # Optimisation: one example per device, accumulated over 8 steps.
    num_train_epochs=1,
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    # Reporting and checkpoint cadence, in optimiser steps.
    logging_steps=10,
    save_steps=50,
)


## 9. Instantiate and train


In [None]:
# Wire the model, the SFT split and the CRAFT-specific pieces into the custom trainer.
trainer = CraftTrainer(
    # Standard `Trainer` arguments.
    args=training_args,
    model=model,
    train_dataset=tokenized_sft,
    data_collator=collator,
    # CRAFT-specific keyword arguments.
    craft_bundle=bundle,
    craft_loss=craft_loss,
    craft_alpha=CRAFT_ALPHA,
    length_strategy="oversample",
)

# Start the training loop.
trainer.train()


## 10. Inspect logs


In [None]:
# Show the five most recent entries recorded in the trainer's log history
# (left as a bare expression so the notebook renders the value).
trainer.state.log_history[-5:]


## 11. Save artefacts


In [None]:
# Write to the output directory the trainer was configured with, so that changing
# `output_dir` in the training arguments is the only edit needed.
trainer.save_model(trainer.args.output_dir)
# The tokenizer was not passed to the trainer, so store it (including the pad-token
# override) next to the model weights to make the directory reloadable on its own.
tokenizer.save_pretrained(trainer.args.output_dir)
