In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from transformers import AutoTokenizer, Siglip2TextModel
from safetensors.torch import load_file

device = "cuda" if torch.cuda.is_available() else "cpu"

class DualSiglip2Model(nn.Module):
    def __init__(self, model_name="google/siglip2-base-patch16-224"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoder_bool = Siglip2TextModel.from_pretrained(model_name)
        self.encoder_text = deepcopy(self.encoder_bool)
        self.bias = nn.Parameter(torch.zeros(1))
        self.to(device)

    def tokenize(self, texts):
        return self.tokenizer(texts, padding="max_length", truncation=True, return_tensors="pt", max_length=64).to(device)

# https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip2/modeling_siglip2.py#L952
    def forward(self, in_bool, in_text):
        tok_bool = self.tokenize(in_bool)
        tok_text = self.tokenize(in_text)
        out_bool = self.encoder_bool(**tok_bool).pooler_output
        out_text = self.encoder_text(**tok_text).pooler_output
        out_bool = out_bool / out_bool.norm(p=2, dim=-1, keepdim=True)
        out_text = out_text / out_text.norm(p=2, dim=-1, keepdim=True)
        loss = self.loss(out_bool, out_text)
        logits = out_bool @ out_text.t() + self.bias
        return {"loss": loss, "logits": logits}

    def loss(self, emb_a, emb_b):
        sim = emb_a @ emb_b.t() + self.bias
        eye = torch.eye(sim.size(0), device=sim.device)
        y = -torch.ones_like(sim)
        y = y + 2 * eye
        loglik = F.logsigmoid(y * sim)
        nll = -torch.sum(loglik, dim=-1)
        loss = nll.mean()
        return loss

    def load(self, path):
        state_dict = load_file(path, device)
        self.load_state_dict(state_dict, strict=False)
        return self

    def evaluate(self, in_bool, in_text):
        self.eval()
        with torch.no_grad():
            outputs = self(in_bool, in_text)
            logits = outputs["logits"]
            probs = torch.sigmoid(logits)
        return probs.cpu().numpy()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import random
from datasets import Dataset, IterableDataset
import numpy as np

class RandomAccessMismatchedPairs:
    def __init__(self, dataset, key_a='q', key_b='d'):
        self.dataset = dataset
        self.key_a = key_a
        self.key_b = key_b
        self.size = len(dataset)
        self.total_pairs = self.size * (self.size - 1)

    def __len__(self):
        return self.total_pairs

    def _index_to_pair(self, index):
        i = index // (self.size - 1)
        j = index % (self.size - 1)
        if j >= i: j += 1
        return i, j

    def __getitem__(self, index):
        if index < 0 or index >= self.total_pairs:
            raise IndexError("Index out of bounds")
        i, j = self._index_to_pair(index)
        return {
            self.key_a: self.dataset[i][self.key_a],
            self.key_b: self.dataset[j][self.key_b],
        }

    def random_sample(self, k=1, seed=42):
        rng = random.Random(seed)
        indices = rng.sample(range(self.total_pairs), k)
        return [self[i] for i in indices]

    def get_n_pairs(self, n, random_order=False, seed=42):
        if n > self.total_pairs:
            raise ValueError("Requested more pairs than available.")
        if random_order:
            return self.random_sample(n, seed)
        else:
            return [self[i] for i in range(n)]

    def to_dataset(self, n=None, random_order=False, seed=42):
        if n is None:
            n = self.total_pairs
        data = self.get_n_pairs(n, random_order, seed)
        return Dataset.from_list(data)

def create_mismatched_dataset(dataset: Dataset) -> Dataset:
    qs = np.array(dataset["q"])
    ds = np.array(dataset["d"])

    n = len(qs)
    idx_q = np.repeat(np.arange(n), n - 1)
    idx_d = np.array([np.delete(np.arange(n), i) for i in range(n)]).flatten()

    mismatched_qs = qs[idx_q]
    mismatched_ds = ds[idx_d]

    return Dataset.from_dict({"q": mismatched_qs.tolist(), "d": mismatched_ds.tolist()})

def mismatched_generator(dataset):
    data = list(dataset)
    for i, item_q in enumerate(data):
        for j, item_d in enumerate(data):
            if i != j:
                yield {"q": item_q["q"], "d": item_d["d"]}

def create_lazy_mismatched_dataset(dataset: Dataset) -> IterableDataset:
    return IterableDataset.from_generator(lambda: mismatched_generator(dataset))


In [3]:
from datasets import load_dataset

dataset = load_dataset("data", split="train")

neg_data = RandomAccessMismatchedPairs(dataset, key_a="q", key_b="d").to_dataset(1000, random_order=True)
print(neg_data)

data = dataset.train_test_split(test_size=0.2, seed=42)
print(data)

model = DualSiglip2Model()

# def preprocess(example):
#     t1 = model.tokenize(example["q"])
#     t2 = model.tokenize(example["d"])
#     return {"q": t1, "d": t2}

# data = data.map(preprocess)

Dataset({
    features: ['q', 'd'],
    num_rows: 1000
})
DatasetDict({
    train: Dataset({
        features: ['q', 'd'],
        num_rows: 6100
    })
    test: Dataset({
        features: ['q', 'd'],
        num_rows: 1525
    })
})


You are using a model of type siglip_text_model to instantiate a model of type siglip2_text_model. This is not supported for all configurations of models and can yield errors.


In [4]:
print("Positive Untrained", model.evaluate(data["test"]["q"], data["test"]["d"]))
print("Negative Untrained", model.evaluate(neg_data["q"], neg_data["d"]))
model = model.load(r"./siglip2/checkpoint-18300/model.safetensors")
print("Positive Trained", model.evaluate(data["test"]["q"], data["test"]["d"]))
print("Negative Trained", model.evaluate(neg_data["q"], neg_data["d"]))

Positive Untrained [[0.6791712  0.66508985 0.6837426  ... 0.67505753 0.68886775 0.6758862 ]
 [0.6769943  0.6663372  0.6767399  ... 0.6728213  0.6815574  0.6702494 ]
 [0.6878812  0.67477584 0.6922052  ... 0.6883122  0.6967495  0.6881189 ]
 ...
 [0.6876252  0.67707604 0.69353455 ... 0.68361294 0.69489396 0.68127614]
 [0.68463874 0.6735706  0.68839514 ... 0.6820059  0.6871852  0.67682475]
 [0.6744925  0.6638663  0.6794575  ... 0.6737547  0.68450654 0.6801185 ]]
Negative Untrained [[0.68646175 0.6822819  0.6894692  ... 0.6853466  0.68375593 0.7064092 ]
 [0.68411094 0.68023825 0.68619466 ... 0.6847603  0.6834683  0.695462  ]
 [0.68516934 0.6828417  0.68859494 ... 0.68543124 0.6832921  0.7063381 ]
 ...
 [0.7043689  0.705027   0.70461035 ... 0.70472145 0.7098322  0.7035266 ]
 [0.69072527 0.686063   0.6929386  ... 0.68947005 0.6872002  0.70067495]
 [0.6893828  0.6842517  0.6929224  ... 0.6896274  0.68589926 0.7108423 ]]
Positive Trained [[0.7640605  0.7523464  0.76783586 ... 0.7606517  0.77205

In [5]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./siglip2",
    per_device_train_batch_size=1,
    num_train_epochs=3,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_steps=100,
    save_steps=100,
    save_total_limit=2,
    remove_unused_columns=False,
    # max_steps=1000,
    # bf16=True,
    # optim="adamw_bnb_8bit",
    # torch_compile=True,
    # torch_compile_backend="inductor"
)

def collate_fn(batch):
    d = {
        "a": torch.tensor([ex["q"]["input_ids"] for ex in batch]),
        "b": torch.tensor([ex["d"]["input_ids"] for ex in batch]),
    }
    return d

trainer = Trainer(
    model,
    training_args,
    train_dataset=data["train"],
    eval_dataset=data["test"],
    data_collator=collate_fn,
)

# print(trainer.evaluate())

trainer.train(resume_from_checkpoint=True)

# print(trainer.evaluate())

AttributeError: 'DualSiglip2Model' object has no attribute '_keys_to_ignore_on_save'