In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from transformers import AutoTokenizer, Siglip2TextModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DualSiglip2Model(nn.Module):
    def __init__(self, model_name="google/siglip2-base-patch16-224"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoder_a = Siglip2TextModel.from_pretrained(model_name)
        self.encoder_b = deepcopy(self.encoder_a)
        self.bias = nn.Parameter(torch.zeros(1))
        self.to(device)
        self._signature_columns = ["a","b"]

    def embed(self, texts, encoder):
        inputs = self.tokenizer(texts, padding="max_length", truncation=True, return_tensors="pt", max_length=64).to(device)
        outputs = encoder(**inputs)
        return F.normalize(outputs.pooler_output)

    def forward(self, a, b):
        emb_a = self.embed(a, self.encoder_a)
        emb_b = self.embed(b, self.encoder_b)
        return emb_a, emb_b

    def loss(self, emb_a, emb_b):
        sim = emb_a @ emb_b.t() + self.bias
        y = -torch.ones_like(sim)
        y.fill_diagonal_(1)
        return torch.log1p(torch.exp(-y * sim)).mean()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset, DatasetDict

def create_splits(dataset):
    ds_train_devtest = dataset.train_test_split(test_size=0.2, seed=42)
    ds_devtest = ds_train_devtest['test'].train_test_split(test_size=0.5, seed=42)

    ds_splits = DatasetDict({
        'train': ds_train_devtest['train'],
        'valid': ds_devtest['train'],
        'test': ds_devtest['test']
    })
    return ds_splits

def shuffle_pairs(split, key_a='q', key_b='d'):
    shuffled = split.shuffle(seed=42)
    return split.map(
        lambda example, idx: {
            key_a: example[key_a],
            key_b: shuffled[idx][key_b],
        },
        with_indices=True,
    )

dataset = load_dataset("data", split="train")

data = create_splits(dataset)

In [3]:
from transformers import Trainer, TrainingArguments

model = DualSiglip2Model()

training_args = TrainingArguments(
    output_dir="./siglip2-dual",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_steps=10,
    save_steps=50,
    save_total_limit=2,
    report_to="none",
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data["train"],
    eval_dataset=data["valid"],
)

trainer.evaluate()

trainer.train()

trainer.evaluate()

You are using a model of type siglip_text_model to instantiate a model of type siglip2_text_model. This is not supported for all configurations of models and can yield errors.


TypeError: can only join an iterable