In [None]:
from copy import deepcopy
from pathlib import Path

In [None]:
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch

# Load the pretrained SigLIP checkpoint and its matching pre-processor.
model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

# Download the sample image. The timeout keeps the cell from hanging forever on
# a stalled connection, and raise_for_status() turns an HTTP error into a clear
# exception instead of an opaque image-decoding failure further down.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# NOTE(review): these captions look like placeholder strings — presumably
# deliberate probes of how the model scores nonsense text; confirm.
texts = ["a photo of 2 dsadsadcats",'dsadasd']
# important: we pass `padding=max_length` since the model was trained with this
# The same image is passed twice so there is one image per caption.
inputs = processor(text=texts, images=[image,image], padding="max_length", return_tensors="pt")

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image) # these are the probabilities
# Only the (image 0, text 0) pair is reported.
print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")

In [None]:
from transformers import SiglipModel, SiglipConfig
from typing import Optional, Tuple, Union
import torch
class Siglip(SiglipModel):
    """`SiglipModel` subclass whose forward pass returns embeddings, not logits.

    `forward` runs the vision and text sub-models, L2-normalises element ``[1]``
    of each output, and returns those two tensors together with the
    exponentiated logit scale and the logit bias.
    """

    config_class = SiglipConfig

    def __init__(self, config: SiglipConfig):
        super().__init__(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode images and texts and return unit-norm features.

        Args:
            input_ids: Passed to the text model.
            pixel_values: Passed to the vision model.
            attention_mask: Passed to the text model.
            position_ids: Passed to the text model.
            return_loss: Accepted for signature compatibility; not used here.
            output_attentions: Forwarded unchanged to both sub-models.
            output_hidden_states: Forwarded unchanged to both sub-models.
            return_dict: Forwarded unchanged to both sub-models.

        Returns:
            A 4-tuple ``(image_embeds, text_embeds, logit_scale.exp(), logit_bias)``
            where both embedding tensors are L2-normalised along the last axis.
        """

        def unit_norm(features):
            # Divide by the L2 norm taken over the last dimension.
            return features / features.norm(p=2, dim=-1, keepdim=True)

        # Output-control flags shared by the two sub-models.
        shared_flags = dict(
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        vision_result = self.vision_model(pixel_values=pixel_values, **shared_flags)
        text_result = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **shared_flags,
        )

        # Element [1] of each output is used as the embedding
        # (presumably the pooled output — TODO confirm against the library).
        image_features = unit_norm(vision_result[1])
        text_features = unit_norm(text_result[1])

        return image_features, text_features, self.logit_scale.exp(), self.logit_bias
# NOTE(review): this hands a single `.pt` checkpoint file to `from_pretrained`;
# presumably that API expects a Hub model id or a saved-model directory, so this
# call likely fails — confirm. The following cells take a different route:
# they build `Siglip` from a config and load the weights via `load_state_dict`.
# NOTE(review): the path is an absolute, machine-specific location.
wrapper_model = Siglip.from_pretrained("/home/mila/l/le.zhang/scratch/github_clone/Enhance-FineGrained/src/Outputs/test_07-Apr-2024-22-52-56/checkpoints/epoch_1.pt")


In [4]:
# Fetch only the configuration of the base SigLIP checkpoint.
config = SiglipConfig.from_pretrained("google/siglip-base-patch16-224")
# Build the wrapper from the config alone; the weights are filled in by the
# `model.load_state_dict(...)` call in a later cell.
# Note: this rebinds `model`, replacing the AutoModel instance created earlier.
model = Siglip(config)

In [9]:
def load_state_dict(checkpoint_path: str, map_location='cpu'):
    """Read a checkpoint file and return its state dict with clean key names.

    Args:
        checkpoint_path: Path to a file written by ``torch.save``.
        map_location: Forwarded to ``torch.load``; defaults to ``'cpu'``.

    Returns:
        The mapping stored under the ``'state_dict'`` key when the checkpoint is
        a dict containing one, otherwise the loaded object itself. Any key that
        starts with ``'module.'`` has that prefix removed; other keys are left
        untouched. When no key carries the prefix the mapping is returned as
        loaded.
    """
    # NOTE(security): torch.load unpickles the file — only open trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    prefix = 'module.'
    # Check every key (not just the first) and match the full 'module.' prefix,
    # so keys such as 'modules.0.weight' are not mangled and an empty
    # checkpoint does not raise StopIteration.
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Load the fine-tuned weights from the training run's checkpoint (onto the CPU,
# the default `map_location` of `load_state_dict`).
# NOTE(review): absolute, machine-specific path — consider a configurable location.
state_dict = load_state_dict("/home/mila/l/le.zhang/scratch/github_clone/Enhance-FineGrained/src/Outputs/test_07-Apr-2024-22-52-56/checkpoints/epoch_1.pt")

In [15]:
# Copy the checkpoint weights into the wrapper. `strict=True` requests an exact
# key match; the recorded output below reports that all keys matched.
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [None]:
# `checkpoint` only exists as a local inside `load_state_dict`, so referring to
# it here raised NameError on a fresh kernel; use the `state_dict` returned by
# that helper instead.
# NOTE(review): this repeats the load done in the previous cell — consider
# deleting this cell.
model.load_state_dict(state_dict)

In [None]:
import torch
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n = 128  # side length of the square toy matrices
# Create the random logits on `device` too; left on the CPU they cannot be
# combined with `labels` on a CUDA machine (device-mismatch error).
logits = torch.randn(n, n, device=device)
# +1 on the diagonal and -1 everywhere else: the length-n vector of ones
# broadcasts across the rows of the n x n identity.
labels = 2 * torch.eye(n, device=device) - torch.ones(n, device=device)
labels

In [None]:
import torch.nn.functional as F
# Negative mean of log-sigmoid(label * logit), taken over every entry.
# NOTE(review): this averages over all entries of the matrix — confirm that is
# the intended normalisation for the loss being prototyped.
-F.logsigmoid(labels * logits).mean()

In [None]:
torchvision.transforms(