In [None]:
import torch
import random
from PIL import Image
import os
from torchvision.datasets import CIFAR10, CIFAR100, CocoCaptions, ImageNet
import slip_models
from tokenizer import SimpleTokenizer

import torchvision.transforms as T

from tqdm import tqdm
%pip install ipywidgets
%pip install update tqdm


In [None]:
# Deterministic evaluation-time image pipeline shared by every cell below:
# resize, centre-crop to 224x224, convert to a tensor, then normalise the three
# channels (the mean/std values are the commonly used ImageNet statistics).
preprocess = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Instantiate the model under test and restore its weights from a checkpoint.
# The commented-out pairs are alternative model/checkpoint configurations;
# enable exactly one model/ckpt pair at a time.

# model = slip_models.CLIP_VITB16(embed_dim=8)
# ckpt = torch.load('ckpts/clip_embed_dim_8_epoch_35.ckpt', map_location='cpu')

model = slip_models.VisionParallelTextStandard()
# map_location='cpu' loads the checkpoint tensors on the CPU; the model is moved to the GPU below.
ckpt = torch.load('ckpts/test_experimental_epoch_2.ckpt', map_location='cpu')

# model = slip_models.CLIP_VITB16(num_prompt_tokens=64, num_text_outputs=1000)
# ckpt = torch.load('ckpts/epoch_30_prompted_clip_may_27.ckpt', map_location='cpu')


# Every 'module.' substring is removed from the checkpoint keys before loading
# (presumably the prefix left by a DataParallel/DistributedDataParallel wrapper — TODO confirm).
# NOTE(review): str.replace removes the substring anywhere in the key, not only as a prefix.
model.load_state_dict({k.replace('module.',''):v for k,v in ckpt["model"].items()})
model = model.cuda()

In [None]:
# Shared text tokenizer; the cells below call it on prompt strings and move the resulting tokens to the GPU.
tokenizer = SimpleTokenizer()

In [None]:
# Sanity check: three images against three prompts. Images are stacked in the
# same order as `base_text`, so a well-behaved model puts most of each row's
# probability mass on the diagonal of `probs`.
base_text = ["a diagram", "a dog", "a cat"]
model.eval()

# Force 3-channel RGB for every image: a PNG/JPEG may be RGBA, palette or
# greyscale, which would not match the 3-channel Normalize step in `preprocess`.
dog_image = preprocess(Image.open("pics/golden-retriever.png").convert("RGB"))
diagram_image = preprocess(Image.open("pics/CLIP.png").convert("RGB"))
cat_image = preprocess(Image.open("pics/cat.jpg").convert("RGB"))
images = torch.stack([diagram_image, dog_image, cat_image]).cuda()

text = tokenizer([f"a picture of {s}" for s in base_text]).cuda()

with torch.no_grad():
    image_features = model.encode_image(images)
    text_features = model.encode_text(text)

    # L2-normalise both sides so the scaled dot product is a scaled cosine
    # similarity, matching what the evaluation cells below do with the outputs
    # of encode_image/encode_text. (A no-op if the encoders already normalise.)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    logits_per_image = model.logit_scale.exp() * image_features @ text_features.t()
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# 3x3 matrix: row i is image i (diagram, dog, cat), column j is prompt j.
print("Label probs:", probs)

# Basic ImageNet and CIFAR checks

In [None]:
# Zero-shot classification on the CIFAR-10 test split (train=False): one text
# prompt per class; an image is counted correct when the prompt with the highest
# similarity to its embedding is the target class.
# Download the dataset
cifar10 = CIFAR10(root="/tmp/", transform=preprocess, download=True, train=False)
loader = torch.utils.data.DataLoader(cifar10, batch_size=64, shuffle=True)
# text_inputs = torch.stack([tokenizer(f"a photo of a {c}") for c in cifar10.classes]).cuda()


# One tokenized prompt per class name, stacked and moved to the GPU.
text_inputs = torch.stack([tokenizer(f"a photo of a {c}") for c in cifar10.classes]).cuda()


num_correct = 0
num_seen = 0
with torch.no_grad():
    #  text_features = model.encode_text(text_inputs)
    for imgs, targets in loader:
        imgs = imgs.cuda()
        targets = targets.cuda()
        
        # image_features = model.encode_image(imgs)
        # results = model(imgs, text_inputs, lang_prompt_viz=True, sharded_computation=False)
        # Joint forward pass: the class prompts are passed through the model again for
        # every batch (the commented-out encode_text/encode_image lines are the alternative path).
        results = model(imgs, text_inputs)
        image_features = results['image_embed']
        # L2-normalise so the matrix product below is a cosine similarity.
        image_features = image_features / image_features.norm(dim=1, keepdim=True)

        text_features = results['text_embed']
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        
        # Softmax over classes. No logit scale is applied here, so the printed
        # max-probability statistics are computed on unscaled cosine similarities;
        # the argmax used for accuracy does not depend on the scale.
        similarity = (image_features @ text_features.T).softmax(dim=-1)
        max_sims = similarity.max(dim=-1)[0]
        print(max_sims.min(), max_sims.max(), max_sims.mean())
        num_correct += (similarity.argmax(dim=-1)==targets).sum().item()  
        num_seen += imgs.shape[0]
        curr_acc = num_correct / num_seen
        print(f"Current Acc: {curr_acc}")
# The loader visits every sample once (no drop_last), so len(cifar10) equals num_seen here.
acc = num_correct / len(cifar10)
print(f"Final Acc {acc}")

In [None]:
# Show the per-image maximum class probability and its class index for the
# global `similarity`, i.e. the last batch of whichever evaluation cell ran most recently.
similarity.max(dim=-1)

In [None]:
# Zero-shot classification on the ImageNet validation split: class prompts are
# encoded once, then every image batch is scored against them.
imagenet = ImageNet(root="/export/share/datasets/vision/imagenet", transform=preprocess, split='val')
loader = torch.utils.data.DataLoader(imagenet, batch_size=32, num_workers=4)

# torchvision's ImageNet exposes `classes` as tuples of synonyms, e.g.
# ('tench', 'Tinca tinca'). Formatting the tuple directly would put its repr
# (parentheses, quotes, commas) into the prompt, so use the first name only.
imagenet_class_names = [c[0] if isinstance(c, (tuple, list)) else c for c in imagenet.classes]
text_inputs = torch.stack([tokenizer(f"a photo of a {c}") for c in imagenet_class_names]).cuda()

num_correct = 0
num_seen = 0
# Defined up front so the final print cannot pick up a stale value from an
# earlier cell if the loader yields no batches.
curr_acc = 0.0
with torch.no_grad():
    # Text embeddings do not depend on the image batch: encode and L2-normalise once.
    text_features = model.encode_text(text_inputs)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    for imgs, targets in loader:
        imgs = imgs.cuda()
        targets = targets.cuda()

        image_features = model.encode_image(imgs)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)

        # Softmax of cosine similarities over classes; argmax picks the predicted class.
        similarity = (image_features @ text_features.T).softmax(dim=-1)
        num_correct += (similarity.argmax(dim=-1)==targets).sum().item()
        num_seen += imgs.shape[0]
        curr_acc = num_correct / num_seen
        print(f"Current Acc: {curr_acc}")

print(f"Final Acc {curr_acc}")

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
# Overlaid histograms of the flattened `text_sims` and `image_sims` tensors.
# NOTE(review): neither `text_sims` nor `image_sims` is assigned anywhere in the
# visible notebook, so this cell raises NameError on a fresh Restart & Run All —
# confirm where they are meant to come from.
plt.hist( (text_sims.cpu()).flatten().detach().cpu().numpy(), label='text sims', alpha=0.5)
plt.hist( (image_sims.cpu()).flatten().detach().cpu().numpy(), label='image_sims', alpha=0.5)
plt.legend()

In [None]:
from torchvision import datasets, transforms

def dataset_constructor(name, split, transform):
    """
    Build one of the supported torchvision datasets.

    inputs
        name: string that we'll run if/else on; one of 'cifar10', 'cifar100',
              'svhn', 'food101', 'merced', 'dtd', 'cub', 'places365',
              'imagenet', 'imagenet_val'
        split: "train" or "val" for now
        transform: forwarded unchanged as the dataset's `transform` argument
    returns
        the constructed dataset object
    raises
        AssertionError: `split` is not 'train'/'val', or is not the single
                        split that a one-split dataset supports
        NotImplementedError: `name` is not one of the supported keys
    """
    assert split in ['train', 'val']

    # CIFAR and SVHN are downloaded on demand into the user's cache directory.
    cache_root = os.path.expanduser("~/.cache")

    # Need to figure out way to override the getitem call for strong weak augmentatin
    if name == 'cifar10':
        # Has class names
        return datasets.CIFAR10(root=cache_root,
                                download=True,
                                train=(split == 'train'),
                                transform=transform)
    elif name == 'cifar100':
        # Has class names
        return datasets.CIFAR100(root=cache_root,
                                 download=True,
                                 train=(split == 'train'),
                                 transform=transform)
    elif name == 'svhn':
        # Doesn't have class names
        # torchvision's SVHN only accepts 'train', 'test' or 'extra'; forwarding
        # 'val' raised ValueError, so the held-out 'test' split serves as 'val'.
        return datasets.SVHN(root=cache_root,
                             download=True,
                             transform=transform,
                             split='train' if split == 'train' else 'test')
    elif name == 'food101':
        # name is in .classes
        # Single image folder, only usable as the evaluation set.
        assert split == 'val'
        return datasets.ImageFolder(root="/export/share/bwallace/datasets/food101/images/",
                                    transform=transform)
    elif name == 'merced':
        # name is in .classes
        # Single image folder, only usable as the evaluation set.
        assert split == 'val'
        return datasets.ImageFolder(root="/export/share/bwallace/datasets/UCMerced_LandUse/Images/",
                                    transform=transform)
    elif name == 'dtd':
        # name is in .classes
        return datasets.ImageFolder(root=f"/export/share/bwallace/datasets/dtd/{split}",
                                    transform=transform)
    elif name == 'cub':
        # name is in .classes
        return datasets.ImageFolder(root=f"/export/share/bwallace/datasets/CUB_2011_formatted/{split}",
                                    transform=transform)
    elif name == 'places365':
        # class names works in
        # 'train' is mapped to the 'train-standard' split; 'val' is passed through.
        return datasets.Places365(root="/export/share/datasets/vision/Places365/",
                                  split='train-standard' if split == 'train' else split,
                                  transform=transform)
    elif name == 'imagenet':
        # Doesn't have class names, have separate call
        return datasets.ImageNet(root="/export/share/datasets/vision/imagenet/",
                                 split=split,
                                 transform=transform)
    elif name == 'imagenet_val':
        assert split == 'train'
        # Doesn't have class names, only use is having Imagenet val be trainset
        return datasets.ImageNet(root="/export/share/datasets/vision/imagenet/",
                                 split='val',
                                 transform=transform)
    else:
        raise NotImplementedError


def get_imagenet_class_dict():
    """
    Load 'imagenet_class_index.json' (relative to the current working
    directory) and invert it.

    The file maps a class-index string to a [wordnet_id, name] pair; the
    result is keyed by wordnet id instead.

    returns
        dict mapping wordnet_id -> (name, int class index),
        e.g. n0023923939 to ('unicorn', 123)
    raises
        FileNotFoundError / OSError: the JSON file cannot be opened
    """
    # Local import: json is not imported by the notebook's import cell, so the
    # original body failed with NameError on its first line.
    import json

    # Context manager so the file handle is closed deterministically.
    with open('imagenet_class_index.json') as f:
        idx_to_word_id_and_name_tuple = json.load(f)

    word_id_to_name_and_idx = {v[0]: (v[1], int(k))
                               for k, v in idx_to_word_id_and_name_tuple.items()}
    return word_id_to_name_and_idx

In [None]:
def test_on_dataset(model, dataset_name, transform,
                   prompt_template="a photo of a {}",
                   test_components_mode='standard',
                   normalization_from_all_components=False):
    """
    Zero-shot accuracy of `model` on the 'val' split of `dataset_name`.

    One prompt is built per class via `prompt_template`, encoded once and
    L2-normalised; each image is assigned the class whose text embedding has
    the highest dot product with its (L2-normalised) image embedding.

    Options for test components mode
        standard : standard (use every embedding component)
        test_all : scale through; accuracy is measured using the first
                   1, 2, ..., D components, taken in order of decreasing
                   variance across the class text embeddings
        random_all : As all but random sorting (a random permutation of the
                     components instead of the variance ordering)

    Normalization is whether to normalize from all components or just ones in computation:
    when `normalization_from_all_components` is False the selected components
    are re-normalised to unit length before the similarity is computed.

    inputs
        model: object exposing encode_text / encode_image; the caller is
               responsible for putting it in eval mode and on the GPU
        dataset_name: key understood by `dataset_constructor`
        transform: image transform forwarded to the dataset
    returns
        float accuracy in 'standard' mode; otherwise a 1-D CPU tensor whose
        entry k-1 is the accuracy using the first k components
    raises
        NotImplementedError: unknown `test_components_mode`
        ValueError: the dataset yielded no samples
    """
    # Reject a bad mode before any dataset or GPU work is done.
    if test_components_mode not in ('standard', 'test_all', 'random_all'):
        raise NotImplementedError(
            f"Unknown test_components_mode: {test_components_mode!r}")
    use_all_components = test_components_mode == 'standard'

    dataset = dataset_constructor(dataset_name, 'val', transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)

    def _class_label(c):
        # Some torchvision datasets (e.g. ImageNet) store each class as a tuple
        # of synonyms; use the first one. Apostrophes are stripped from the name.
        label = c[0] if isinstance(c, (tuple, list)) else c
        return str(label).replace("'", "")

    class_names = [prompt_template.format(_class_label(c)) for c in dataset.classes]
    text_inputs = torch.stack([tokenizer(txt) for txt in class_names]).cuda()

    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        num_components_total = text_features.shape[1]

        if use_all_components:
            sorted_variance_idx = None
            num_correct = 0
        else:
            _, sorted_variance_idx = sorted_feature_component_variances(text_features)
            if test_components_mode == 'random_all':
                # Same device as the features so the repeated indexing below
                # does not mix CPU indices with GPU tensors.
                sorted_variance_idx = torch.randperm(sorted_variance_idx.shape[0],
                                                     device=sorted_variance_idx.device)
            # One correct-count per "number of components kept".
            num_correct = torch.zeros(num_components_total)

        num_seen = 0
        for imgs, targets in tqdm(loader):
            imgs = imgs.cuda()
            targets = targets.cuda()

            image_features = model.encode_image(imgs)
            image_features = image_features / image_features.norm(dim=1, keepdim=True)

            if use_all_components:
                # argmax is unchanged by a row-wise softmax, so none is applied.
                similarity = image_features @ text_features.T
                num_correct += (similarity.argmax(dim=-1) == targets).sum().item()
            else:
                for num_components in range(1, num_components_total + 1):
                    idx_to_use = sorted_variance_idx[:num_components]
                    masked_text_features = text_features[:, idx_to_use]
                    masked_image_features = image_features[:, idx_to_use]
                    if not normalization_from_all_components:
                        masked_text_features = masked_text_features / masked_text_features.norm(dim=1, keepdim=True)
                        masked_image_features = masked_image_features / masked_image_features.norm(dim=1, keepdim=True)
                    masked_similarity = masked_image_features @ masked_text_features.T
                    num_correct[num_components - 1] += (masked_similarity.argmax(dim=-1) == targets).sum().item()

            num_seen += imgs.shape[0]

    # Previously an empty loader surfaced as UnboundLocalError on `curr_acc`.
    if num_seen == 0:
        raise ValueError(f"Dataset {dataset_name!r} yielded no samples")
    return num_correct / num_seen


def get_dataset_text_features(model, dataset_name):
    """
    Encode one "a photo of a <class>" prompt per class of the dataset's 'val'
    split and return the embeddings scaled to unit L2 norm along dim 1.

    inputs
        model: object exposing encode_text
        dataset_name: key understood by `dataset_constructor`
    returns
        tensor of normalised text embeddings, one row per class
    """
    # Only the class names are needed, so no image transform is supplied.
    val_dataset = dataset_constructor(dataset_name, 'val', None)

    token_rows = []
    for class_name in val_dataset.classes:
        token_rows.append(tokenizer(f"a photo of a {class_name}"))
    prompt_tokens = torch.stack(token_rows).cuda()

    with torch.no_grad():
        embeddings = model.encode_text(prompt_tokens)
        row_norms = embeddings.norm(dim=1, keepdim=True)
        unit_embeddings = embeddings / row_norms
    return unit_embeddings

def sorted_feature_component_variances(tensor):
    """
    Variance of each component (taken over dim 0), ordered largest first.

    returns
        (variances sorted in descending order, indices that produce that order)
    """
    per_component_var = torch.var(tensor, dim=0)
    # Ascending argsort reversed, so position [0] is the highest-variance component.
    ascending_order = torch.argsort(per_component_var)
    descending_order = torch.flip(ascending_order, dims=(0,))
    return per_component_var[descending_order], descending_order


In [None]:
# Zero-shot accuracy on the DTD 'val' image folder, default 'standard' mode (all embedding components).
test_on_dataset(model, 'dtd', preprocess)

In [None]:
# CUB 'val' accuracy as a function of how many embedding components are kept,
# adding components in order of decreasing variance across the class text embeddings.
test_on_dataset(model, 'cub', preprocess, test_components_mode='test_all')

In [None]:
# Same sweep as 'test_all', but the components are added in a random permutation order.
test_on_dataset(model, 'cub', preprocess, test_components_mode='random_all')

In [None]:
# Identical call to the previous cell. NOTE(review): no seed is set in the visible
# notebook, so this presumably samples a second random component ordering — confirm the repeat is intentional.
test_on_dataset(model, 'cub', preprocess, test_components_mode='random_all')

In [None]:
# L2-normalised text embeddings of the "a photo of a <class>" prompts for the food101 classes.
tf = get_dataset_text_features(model, 'food101')

In [None]:
# Presumably (num_classes, embed_dim): one row per class prompt — TODO confirm against the output.
tf.shape

In [None]:
# Standard deviation of each embedding component across the class prompts (reduction over dim 0).
tf.std(dim=0)

In [None]:
# NOTE(review): this reads the global `text_features` left behind by an earlier
# top-level evaluation cell, not `tf` from the cells just above — confirm which tensor was intended.
text_features.std(dim=0).shape