In [1]:
from cub_dataset import Cub2011
from torch.utils.data import DataLoader
import torch
import numpy as np
import pandas as pd
import torchvision
import torchvision.transforms

import os
import pandas as pd
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torch.utils.data import Dataset
from collections import defaultdict
import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms


class Cub2011(Dataset):
    base_folder = 'CUB_200_2011/images'
    url = 'https://data.deepai.org/CUB200\(2011\).zip'
    filename = 'CUB200\(2011\).zip'
    tarfile = 'CUB_200_2011.tgz'
    directory = 'CUB_200_2011'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, args, train=True, loader=default_loader, download=False,):
        self.root = os.path.expanduser(root)
        #self.transform = transform
        self.loader = default_loader
        self.train = train
        self.num_attributes = 312

        if download:
            self._download()


        #initialize the attributes
        with open(os.path.join(root, 'CUB_200_2011', 'attributes', 'image_attribute_labels.txt'), 'r') as f:
            lines = f.readlines()

        attributes = defaultdict(list)
        certainty = defaultdict(list)
        for line in lines:
            line = line.strip().split()
            attributes[line[0]].append(line[2])
            certainty[line[0]].append(line[3])

        self.attributes = pd.DataFrame.from_dict(attributes, orient='index').reset_index()
        self.certainty = pd.DataFrame.from_dict(certainty, orient='index').reset_index()
        #certaintys 1 = not visible, 2 = guessing, 3 = probably, 4 = definite
        
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.transform = torchvision.transforms.Compose([
                                transforms.RandomApply(
                                    [transforms.ColorJitter(args['brightness'], args['contrast'], args['saturation'], args['hue'])],
                                    p=args['color_jitter_prob'],
                                ),
                            #torchvision.transforms.Resize((224, 224)),
                            transforms.RandomResizedCrop((args['crop_size'], args['crop_size']), scale=(args['min_scale'], args['max_scale']),
                                    interpolation=transforms.InterpolationMode.BICUBIC),
                            torchvision.transforms.ToTensor(),
                            torchvision.transforms.RandomHorizontalFlip(p=args['horizontal_flip_prob']),
                            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                std=[0.229, 0.224, 0.225])
                            ])  


    def _load_metadata(self):
        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        if self.train:
            truth = self.data.is_training_img == 1
            self.data = self.data[truth]
            self.attributes = self.attributes[truth]
            self.certainty = self.certainty[truth]
        else:
            truth = self.data.is_training_img == 0
            self.data = self.data[truth]
            self.attributes = self.attributes[truth]
            self.certainty = self.certainty[truth]

    def _check_integrity(self):
        try:
            self._load_metadata()
        except Exception:
            return False
        for index, row in self.data.iterrows():
            filepath = os.path.join(self.root, self.base_folder, row.filepath)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        if not os.path.exists(self.root): os.mkdir(self.root)
        os.system('wget ' + self.url)
        os.system('unzip ' + self.filename)
        os.system('rm ' + self.filename)
        os.system('tar -zxf ' + self.tarfile)
        os.system('rm ' + self.tarfile)
        os.system('mv ' + self.directory + ' ' + self.root + '/CUB_200_2011')
        os.system('mv attributes.txt ' + self.root + '/CUB_200_2011')
        os.system('rm segmentations.tgz')

    def get_attribute(self, idx):
        return self.attributes[idx]


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        attributes = torch.Tensor(self.attributes.iloc[idx, 1:].values.astype(np.int8))
        certainty = torch.Tensor(self.certainty.iloc[idx, 1:].values.astype(np.int8))
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1  # Targets start at 1 by default, so shift to 0
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        batch = {}
        batch['image'] = img
        batch['class'] = target
        batch['attributes'] = attributes
        batch['certainty'] = certainty

        return batch


transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
])
dataset_args = {
                 'root': '../data',
                 'crop_size': 224,
                 'brightness': 0.4, 
                 'contrast': 0.4, 
                 'saturation': .2, 
                 'hue': .1, 
                 'color_jitter_prob': .4, 
                 'gray_scale_prob': 0.2, 
                 'horizontal_flip_prob': 0.5, 
                 'gaussian_prob': .5, 
                 'min_scale': 0.6, 
                 'max_scale': 0.95}

dataset = Cub2011(dataset_args['root'], dataset_args, download=False)

dataloader = DataLoader(dataset, batch_size=32,  shuffle=False, num_workers=0)
for i, data in enumerate(dataloader):
    break



In [2]:
total = 0
certain = 0
for i, data in enumerate(dataloader):
    total += data['image'].shape[0]
    certain += (torch.sum(data['certainty'] >= 2))/312

In [3]:
certain / total

tensor(0.8910)

In [8]:
certain / total

tensor(0.5205)

In [8]:
from PIL import Image
import requests
import torch

from transformers import CLIPProcessor, CLIPModel
import torch.nn as nn

def create_text_prompts(root):
    with open(root + '/CUB_200_2011/attributes.txt', 'r') as f:
        lines = f.readlines()
    text_prompts = []
    for line in lines:

        #base that will be used for every image
        start = 'The bird has a '

        #get the words before seeing the descriptor
        beginning = ''
        seen = False


        for i in line.split()[1].split('_'):

            #:: signigifies that the attribute value is on the other side
            if '::' in i:
                first_half = i.split('::')[0]
                second_half = i.split('::')[1]
                seen = True

            #if we have seen the descriptor, we are done and ( signifies that
            if '(' in i:
                break
            if i != 'has':
                if '::' in i: continue
                if seen: second_half += ' ' + i
                else: beginning += i + ' '
        start += second_half + ' ' + beginning  + first_half
        text_prompts.append(start)
    return text_prompts

text_prompts = create_text_prompts('../data/')

class CLIP_text_image_concat(nn.Module):

    def __init__(self, args = None):
        super().__init__()
        self.args = args
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.linear1 = nn.ModuleList([nn.Linear(1024, 1) for i in range(312)])
        self.linear2 = nn.ModuleList([nn.Linear(768 + 512, 1) for i in range(312)])
        self.classifier1 = nn.ModuleList([nn.Sequential(nn.Linear(1024, 312), nn.ReLU(), nn.BatchNorm1d(312), nn.Linear(312, 1)) for i in range(312)])
        self.classifier2 = nn.ModuleList([nn.Sequential(nn.Linear(768 + 512, 312), nn.ReLU(), nn.BatchNorm1d(312), nn.Linear(312, 1)) for i in range(312)])

    def forward(self, prompts, images):
        text = prompts
        inputs = self.processor(text=text, return_tensors="pt", padding=True)
        for i in inputs:
            inputs[i] = inputs[i].cuda()
        inputs['pixel_values'] = images.cuda()
        outputs = self.clip(**inputs)

        image_embed = outputs.image_embeds
        image_out = outputs.vision_model_output['pooler_output']
        text_embed = outputs.text_embeds
        text_out = outputs.text_model_output['pooler_output']

        classifications = []

        for i in range(text_embed.size(0)):
            new_text_embed = text_embed[i].unsqueeze(0).repeat(image_embed.size(0), 1)
            new_text_out = text_out[i].unsqueeze(0).repeat(image_out.size(0), 1)
            final_embed = torch.cat((image_embed, new_text_embed), dim=1)
            final_out = torch.cat((image_out, new_text_out), dim=1)

            lin1 = self.linear1[i](final_embed)
            class1 = self.classifier1[i](final_embed)

            lin2 = self.linear2[i](final_out)
            class2 = self.classifier2[i](final_out)
            print(class2.size())

            inter_class = [lin1, lin2, class1, class2]
            if i == 0:
                classifications = inter_class

            else:
                classifications[0] = torch.cat((classifications[0], inter_class[0]), dim=1)
                classifications[1] = torch.cat((classifications[1], inter_class[1]), dim=1)
                classifications[2] = torch.cat((classifications[2], inter_class[2]), dim=1)
                classifications[3] = torch.cat((classifications[3], inter_class[3]), dim=1)

        return classifications, outputs.logits_per_image

model = CLIP_text_image_concat()
model.cuda()
images, attributes, certainty = data['image'].cuda(), data['attributes'].cuda(), data['certainty'].cuda()

truth = certainty >= 3
classification_out, clip_image_logits = model(text_prompts, images)

torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])
torch.Size([32, 1])


In [9]:
classification_out[0].shape

torch.Size([32, 312])

In [None]:
#indicates where a multivalued attribute starts and ends
attributes = data['attributes'][0]
certainty = data['certainty'][0]
indices = [9, 24, 39, 54, 58, 73, 79, 94, 105, 120, 135, 149, 152, 167, 182, 197, 212, 217, 222, 236, 240, 244, 248, 263, 278, 293, 308, 312]
multi = [0 for i in range(len(indices))]
start = 0
for f in range(len(data['attributes'])):
    attributes = data['attributes'][f]
    certainty = data['certainty'][f]
    start = 0
    for val, i in enumerate(indices):
        if certainty[start] >= 3:
            if attributes[start:i - 1].sum() > 1:
                #print('multivalued attribute', start, i)
                multi[val] += 1
        start = i
print(multi)