In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from transformers import AutoTokenizer, Siglip2TextModel

device = "cuda" if torch.cuda.is_available() else "cpu"

class DualSiglip2Model(nn.Module):
    def __init__(self, model_name="google/siglip2-base-patch16-224"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoder_a = Siglip2TextModel.from_pretrained(model_name)
        self.encoder_b = deepcopy(self.encoder_a)
        self.bias = nn.Parameter(torch.zeros(1))
        self.to(device)

    def tokenize(self, texts):
        return self.tokenizer(texts, padding="max_length", truncation=True, return_tensors="pt", max_length=64).to(device)

    def embed(self, texts, encoder):
        inputs = self.tokenize(texts)
        outputs = encoder(**inputs)
        return F.normalize(outputs.pooler_output)

# https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip2/modeling_siglip2.py#L952
    def forward(self, a, b):
        a = self.encoder_a(a).pooler_output
        b = self.encoder_b(b).pooler_output
        a = a / a.norm(p=2, dim=-1, keepdim=True)
        b = b / b.norm(p=2, dim=-1, keepdim=True)
        loss = self.loss(a, b)
        logits = a @ b.t() + self.bias
        return {"loss": loss, "logits": logits}

    def loss(self, emb_a, emb_b):
        sim = emb_a @ emb_b.t() + self.bias
        eye = torch.eye(sim.size(0), device=sim.device)
        y = -torch.ones_like(sim)
        y = y + 2 * eye
        loglik = F.logsigmoid(y * sim)
        nll = -torch.sum(loglik, dim=-1)
        loss = nll.mean()
        return loss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset, DatasetDict

def create_splits(dataset):
    ds_train_devtest = dataset.train_test_split(test_size=0.2, seed=42)
    ds_devtest = ds_train_devtest['test'].train_test_split(test_size=0.5, seed=42)

    ds_splits = DatasetDict({
        'train': ds_train_devtest['train'],
        'valid': ds_devtest['train'],
        'test': ds_devtest['test']
    })
    return ds_splits

def shuffle_pairs(split, key_a='q', key_b='d'):
    shuffled = split.shuffle(seed=42)
    return split.map(
        lambda example, idx: {
            key_a: example[key_a],
            key_b: shuffled[idx][key_b],
        },
        with_indices=True,
    )

dataset = load_dataset("data", split="train")

data = dataset.train_test_split(test_size=0.2, seed=42)

model = DualSiglip2Model()

def preprocess(example):
    t1 = model.tokenize(example["q"])
    t2 = model.tokenize(example["d"])
    return {"q": t1, "d": t2}

data = data.map(preprocess)

You are using a model of type siglip_text_model to instantiate a model of type siglip2_text_model. This is not supported for all configurations of models and can yield errors.
Map: 100%|██████████| 6100/6100 [00:06<00:00, 899.63 examples/s] 
Map: 100%|██████████| 1525/1525 [00:01<00:00, 867.51 examples/s]


In [3]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./siglip2",
    per_device_train_batch_size=1,
    num_train_epochs=3,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_steps=100,
    save_steps=100,
    save_total_limit=2,
    remove_unused_columns=False,
    max_steps=1000,
    # bf16=True,
    # optim="adamw_bnb_8bit",
    # torch_compile=True,
    # torch_compile_backend="inductor"
)

def collate_fn(batch):
    d = {
        "a": torch.tensor([ex["q"]["input_ids"] for ex in batch]),
        "b": torch.tensor([ex["d"]["input_ids"] for ex in batch]),
    }
    return d

trainer = Trainer(
    model,
    training_args,
    train_dataset=data["train"],
    eval_dataset=data["test"],
    data_collator=collate_fn,
)

print(trainer.evaluate())

trainer.train()

print(trainer.evaluate())

{'eval_model_preparation_time': 0.0025, 'eval_runtime': 3.2022, 'eval_samples_per_second': 476.239, 'eval_steps_per_second': 59.647}


Step,Training Loss
100,0.314
200,0.3115
300,0.3104
400,0.3094
500,0.3086
600,0.308
700,0.3074
800,0.307
900,0.3068
1000,0.3066


{'eval_model_preparation_time': 0.0025, 'eval_runtime': 3.2995, 'eval_samples_per_second': 462.191, 'eval_steps_per_second': 57.888, 'epoch': 0.16393442622950818}
