In [1]:
from transformers import BeitFeatureExtractor, BeitForImageClassification
from PIL import Image
import requests
import torchvision
import torchvision.datasets as datasets
import torch
from torchvision.transforms import Compose
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
import sys
import logging
sys.path.append('/scratch/bf996/vlhub/src')
from training.imagenet_zeroshot_data import *
from tqdm import tqdm

In [12]:
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
# inputs = feature_extractor(images=image, return_tensors="pt")
# outputs = model(**inputs)
# logits = outputs.logits
# # model predicts one of the 1000 ImageNet classes
# predicted_class_idx = logits.argmax(-1).item()
# print("Predicted class:", model.config.id2label[predicted_class_idx])

In [3]:
class SharedEpoch:
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value

@dataclass
class DataInfo:
    dataloader: DataLoader
    sampler: DistributedSampler = None
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)
        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)

In [4]:
def get_imagenet(args, preprocess_fns, split):
    assert split in ["train", "val", "v2", "r", "a", "s"], "Not a recognized ImageNet split, {}".format(split)
    is_train = (split == "train")
    preprocess_train = preprocess_val = preprocess_fns

    if split == "v2":
        from imagenetv2_pytorch import ImageNetV2Dataset
        dataset = ImageNetV2Dataset(location=args["imagenet_v2"], transform=preprocess_val)
    elif is_train:
        data_path = args.imagenet_train
        preprocess_fn = preprocess_train
        dataset = datasets.ImageFolder(data_path, transform=preprocess_train)
    else:
        if split == "val":
            data_path = args['imagenet_val']
        if split == "r":
            data_path = args['imagenet_r']
        if split == "a":
            data_path = args['imagenet_a']
        if split == "s":
            data_path = args['imagenet_s']
        preprocess_fn = preprocess_val
        assert data_path, "No data path found"

        dataset = datasets.ImageFolder(data_path, transform=preprocess_val)
    if is_train:
        idxs = np.zeros(len(dataset.targets))
        target_array = np.array(dataset.targets)
        k = 50
        for c in range(1000):
            m = target_array == c
            n = len(idxs[m])
            arr = np.zeros(n)
            arr[:k] = 1
            np.random.shuffle(arr)
            idxs[m] = arr

        idxs = idxs.astype('int')
        sampler = SubsetRandomSampler(np.where(idxs)[0])
    else:
        sampler = None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args['batch_size'],
        num_workers=args['workers'],
        sampler=sampler
    )

    return DataInfo(dataloader=dataloader, sampler=sampler)

In [5]:
#--imagenet-a "/imagenet-a" --imagenet-r "/imagenet-r" --imagenet-val "/imagenet/val/" --imagenet-v2 "/scratch/bf996/datasets" --imagenet-s "/imagenet-sketch"

args = {"imagenet_v2" : "/scratch/bf996/datasets", "imagenet_r" : "/imagenet-r", "imagenet_val" : "/imagenet/val/", "imagenet_a" : "/imagenet-a", "imagenet_s" : "/imagenet-sketch", "batch_size" : 32, "workers" : 8, 'device' : 'cuda:0'}



In [14]:
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop

def _convert_to_rgb(image):
    return image.convert('RGB')

image_size = 224

transform_l = [Resize(image_size, interpolation=InterpolationMode.BICUBIC), CenterCrop(image_size), _convert_to_rgb, ToTensor(), Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))]

transform_c = Compose(transform_l)

In [15]:
data = {}
data["imagenet-val"] = get_imagenet(args, transform_c, "val")
data["imagenet-v2"] = get_imagenet(args, transform_c, "v2")
data["imagenet-s"] = get_imagenet(args, transform_c, "s")
data["imagenet-r"] = get_imagenet(args, transform_c, "r")
data["imagenet-a"] = get_imagenet(args, transform_c, "a")

In [16]:
def accuracy(output, target, topk=(1,)):
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]

In [31]:
def run(model, classifier, dataloader, args, idx=None, split=None, caption_subset=""):
    with torch.no_grad():
        top1, top5, n = 0., 0., 0.
        for images, target in tqdm(dataloader, unit_scale=args['batch_size']):
            if caption_subset != "":
                if split == "r":
                    ir_idx = get_ir_idx().tolist()
                    match_idx = sum(target==ir_idx.index(i) for i in idx).bool().nonzero(as_tuple=True)[0]
                elif split == "a":
                    ia_idx = get_ia_idx().tolist()
                    #keep only the samples which are in passed-in class subset, using correct imagenet-a indices 
                    match_idx = sum(target==ia_idx.index(i) for i in idx).bool().nonzero(as_tuple=True)[0]
                else:
                    match_idx = sum(target==i for i in idx).bool().nonzero(as_tuple=True)[0]
                #shave down target and images size so we skip irrelevant samples
                target = target[match_idx]
                images = images[match_idx]
            if images.size(0) == 0:
                continue
            model = model.to("cuda:0")
            images = images.to("cuda:0")
            target = target.to("cuda:0")
            logits = model(images).logits
            #zero out logits which are not being evaluated (in VL this is handled by changing the size of the classification problem)
            if caption_subset != "":
                icap_idx = get_icap_idx(caption_subset)
                not_icap_idx = [i for i in range(1000) if i not in icap_idx]
                logits[:, not_icap_idx] = float("-inf")
            if split == 'r':
                ir_idx = get_ir_idx()
                not_ir_idx = [i for i in range(1000) if i not in ir_idx]
                logits[:, not_ir_idx] = float("-inf")
            if split == 'a':
                ia_idx = get_ia_idx()
                not_ia_idx = [i for i in range(1000) if i not in ia_idx]
                logits[:, not_ia_idx] = float("-inf")

            acc1, acc5 = accuracy(logits, target, topk=(1, min(5, len(icap_idx))))
            n += images.size(0)
            top1 += acc1
            top5 += acc5
            #print("top1", top1, "n", n)

    top1 = (top1 / n)
    top5 = (top5 / n)
    #TODO: debug integer labels for extended metrics
    return top1, top5

In [32]:
logging.info('Starting zero-shot imagenet.')
caption_subset = "in100"
if caption_subset != "":
    logging.info("Using caption subset {}".format(caption_subset))
    get_icap_idx(caption_subset)
    get_common_ir_idx()
    get_common_ir_idx_zeroindexed()
    get_common_ia_idx()
    get_common_ia_idx_zeroindexed()
    get_common_obj_idx()
    get_common_obj_idx_zeroindexed()

In [33]:
classifier = None
imagenets = []
results = {}
if 'imagenet-val' in data:            
    top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args, get_icap_idx(caption_subset) if caption_subset != "" else None, caption_subset=caption_subset)
    results['imagenet-zeroshot-val-top1'] = top1
    imagenets.append(top1)
    results['imagenet-zeroshot-val-top5'] = top5
    print('Finished zero-shot val. Top1 was {}, top5 was {}'.format(top1, top5))
if 'imagenet-v2' in data:
    top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args, get_icap_idx(caption_subset) if caption_subset != "" else None, caption_subset=caption_subset)
    results['imagenetv2-zeroshot-val-top1'] = top1
    imagenets.append(top1)
    results['imagenetv2-zeroshot-val-top5'] = top5
    print('Finished zero-shot v2. Top1 was {}, top5 was {}'.format(top1, top5))
if 'imagenet-s' in data:
    top1, top5 = run(model, classifier, data['imagenet-s'].dataloader, args, get_icap_idx(caption_subset) if caption_subset != "" else None, caption_subset=caption_subset)
    results['imagenets-zeroshot-val-top1'] = top1
    imagenets.append(top1)
    results['imagenets-zeroshot-val-top5'] = top5
    print('Finished zero-shot sketch. Top1 was {}, top5 was {}'.format(top1, top5))
if 'imagenet-r' in data:
    top1, top5 = run(model, classifier, data['imagenet-r'].dataloader, args, get_common_ir_idx() if caption_subset != "" else get_ir_idx(), "r", caption_subset=caption_subset)
    results['imagenetr-zeroshot-val-top1'] = top1
    imagenets.append(top1)
    results['imagenetr-zeroshot-val-top5'] = top5
    print('Finished zero-shot imagenet-r. Top1 was {}, top5 was {}'.format(top1, top5))
if 'imagenet-a' in data:
    top1, top5 = run(model, classifier, data['imagenet-a'].dataloader, args, get_common_ia_idx() if caption_subset != "" else get_ia_idx(), "a", caption_subset=caption_subset)
    results['imageneta-zeroshot-val-top1'] = top1
    imagenets.append(top1)
    results['imageneta-zeroshot-val-top5'] = top5
    print('Finished zero-shot imagenet-a. Top1 was {}, top5 was {}'.format(top1, top5))  
if results.get('imagenet-zeroshot-val-top1'):
    logging.info("computing effective robustness on imagenet")
    logging.info("len imagenets {}".format(len(imagenets)))
    try:
        imagenet_shifts = []
        for shift in ['imagenetr-zeroshot-val-top1', 'imageneta-zeroshot-val-top1', 'imagenets-zeroshot-val-top1', 'imagenetv2-zeroshot-val-top1']:
            if results.get(shift):
                imagenet_shifts.append(results[shift])
        if len(imagenet_shifts) > 0:
            results['imagenet-average-robustness'] = np.average(imagenet_shifts)
            results['imagenet-effective-robustness'] = np.divide(np.average(imagenet_shifts), results['imagenet-zeroshot-val-top1'])
            print("Average robustness over {} ImageNet shifts: {}".format(len(imagenet_shifts), results['imagenet-average-robustness']))
    except Exception as e:
        logging.info("error calculating effective robustness: ")
        logging.info(e)

100%|██████████| 50016/50016 [00:55<00:00, 904.77it/s] 
100%|██████████| 10016/10016 [00:18<00:00, 546.92it/s]
100%|██████████| 50912/50912 [01:43<00:00, 493.77it/s]
100%|██████████| 30016/30016 [00:48<00:00, 620.71it/s] 
100%|██████████| 7520/7520 [00:10<00:00, 696.31it/s] 


In [34]:
print(results)

{'imagenet-zeroshot-val-top1': 0.4228, 'imagenet-zeroshot-val-top5': 0.6152, 'imagenetv2-zeroshot-val-top1': 0.352, 'imagenetv2-zeroshot-val-top5': 0.526, 'imagenets-zeroshot-val-top1': 0.02242770017706079, 'imagenets-zeroshot-val-top5': 0.046429274050757426, 'imagenetr-zeroshot-val-top1': 0.0001585791309863622, 'imagenetr-zeroshot-val-top5': 0.0009514747859181732, 'imageneta-zeroshot-val-top1': 0.0006246096189881324, 'imageneta-zeroshot-val-top5': 0.0037476577139287947, 'imagenet-average-robustness': 0.09380272223175881, 'imagenet-effective-robustness': 0.22186074321608043}


## BEIT Eval from timm

In [1]:
import pandas as pd

In [3]:
IN100_TRUE_IDX = [386, 928, 931, 704, 907, 291, 454, 76, 952, 788, 245, 937, 924, 8, 983, 816, 920, 379, 204, 396, 929, 619, 815, 88, 84, 217, 118, 935, 987, 642, 950, 951, 954, 557, 18, 967, 945, 6, 440, 348, 22, 571, 23, 963, 104, 958, 579, 312, 534, 620, 115, 298, 284, 552, 373, 997, 182, 422, 308, 839, 13, 489, 805, 832, 85, 695, 2, 863, 310, 565, 886, 455, 988, 347, 580, 425, 99, 424, 105, 107, 343, 658, 721, 443, 421, 679, 19, 825, 130, 309, 849, 879, 496, 971, 922, 985, 286, 625, 637, 943]
ir_idx = [1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107, 113, 122, 
125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199, 203, 207, 208, 219, 
231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277, 281, 288, 289, 291, 292, 293, 
296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347, 353, 355, 361, 
362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447, 448, 457, 462, 
463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609, 613, 617, 621, 629, 
637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, 866, 875, 883, 
889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 
983, 988]
IN100_DOGS_IDX = [151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250]
ia_idx = [6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107, 108, 110, 
113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301, 306, 307, 308, 309, 
310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363, 372, 378, 386, 397, 400, 401, 
402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461, 462, 470, 472, 483, 486, 488, 492, 496, 
514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572, 573, 575, 579, 589, 606, 607, 609, 614, 626, 627, 
640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704, 719, 736, 746, 749, 752, 758, 763, 765, 768, 773, 774, 776, 
779, 780, 786, 792, 797, 802, 803, 804, 813, 815, 820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870, 879, 880, 
888, 890, 897, 900, 907, 913, 924, 932, 933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980, 981, 984, 
986, 987, 988]
common_ia_idx = [n for n in ia_idx if n in IN100_TRUE_IDX]
common_ir_idx = [n for n in ir_idx if n in IN100_TRUE_IDX]
common_dir_idx = [n for n in ir_idx if n in IN100_DOGS_IDX]
common_dia_idx = [n for n in ia_idx if n in IN100_DOGS_IDX]

In [226]:
df = pd.read_csv("/scratch/bf996/caption-paper-ICLR/timm_results_tfeffs-da.csv")
df = df[df['targets'].isin(common_dia_idx)]
round(len(df[df['index'] == df['targets']]) / len(df), 3)

0.659