In [1]:
from cub_dataset import Cub2011
from torch.utils.data import DataLoader
import torch
import numpy as np
import pandas as pd
import torchvision
import torchvision.transforms
from models.clip_pretrained import CLIP_image

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
])
dataset_args = {
                 'root': '../data',
                 'crop_size': 224,
                 'brightness': 0.4, 
                 'contrast': 0.4, 
                 'saturation': .2, 
                 'hue': .1, 
                 'color_jitter_prob': .4, 
                 'gray_scale_prob': 0.2, 
                 'horizontal_flip_prob': 0.5, 
                 'gaussian_prob': .5, 
                 'min_scale': 0.6, 
                 'max_scale': 0.95}

dataset = Cub2011(dataset_args['root'], dataset_args, download=False)

dataloader = DataLoader(dataset, batch_size=32,  shuffle=False, num_workers=0)
for i, data in enumerate(dataloader):
    break



In [2]:
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

In [3]:
from transformers import CLIPProcessor, CLIPModel
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
out = processor(text="a photo of a cat", images=image, return_tensors="pt", padding=True)
out['pixel_values'].shape

torch.Size([1, 3, 224, 224])

In [4]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=["a photo of a cat", "a photo of a dog"], images = image, return_tensors="pt", padding=True)

outputs = model(**inputs)


In [6]:
outputs.vision_model_output.keys()

odict_keys(['last_hidden_state', 'pooler_output'])

In [39]:
data[0].shape

torch.Size([32, 3, 224, 224])

In [47]:
from transformers import CLIPProcessor, CLIPModel
import torch.nn as nn

class CLIP_text_image(nn.Module):

    def __init__(self, args = None):
        super().__init__()
        self.args = args
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.linear1 = nn.Linear(312, 312)
        self.classifier1 = nn.Sequential(nn.Linear(312, 312), nn.ReLU(), nn.BatchNorm1d(312), nn.Linear(312, 312))

    def forward(self, prompts, images):
        text = prompts
        inputs = self.processor(text=text, return_tensors="pt", padding=True)
        for i in inputs:
            inputs[i] = inputs[i].cuda()
        inputs['pixel_values'] = images.cuda()
        outputs = self.clip(**inputs)

        binary_pred1 = self.linear1(outputs.logits_per_image)
        classifier_pred1 = self.classifier1(outputs.logits_per_image)

        print(outputs.logits_per_image.shape)


        return (binary_pred1, classifier_pred1), outputs.logits_per_image

with open('../data/CUB_200_2011/attributes.txt', 'r') as f:
    lines = f.readlines()
text_prompts = []
for line in lines:

    #base that will be used for every image
    start = 'The bird has a '

    #get the words before seeing the descriptor
    beginning = ''
    seen = False


    for i in line.split()[1].split('_'):

        #:: signigifies that the attribute value is on the other side
        if '::' in i:
            first_half = i.split('::')[0]
            second_half = i.split('::')[1]
            seen = True

        #if we have seen the descriptor, we are done and ( signifies that
        if '(' in i:
            break
        if i != 'has':
            if '::' in i: continue
            if seen: second_half += ' ' + i
            else: beginning += i + ' '
    start += second_half + ' ' + beginning  + first_half
    text_prompts.append(start)
model = CLIP_text_image(None)
model = model.cuda()
output = model(text_prompts, data[0][:2])


torch.Size([2, 312])


In [45]:
output[0][0].shape

torch.Size([2, 312])

In [27]:
i.replace('_', ' ')

'pattern::multi-colored'

In [17]:
lines[0]

'1 has_bill_shape::curved_(up_or_down)\n'