# Inference notebook for [CLIP Instagram Captioning](https://github.com/olafbrah/cs182proj)

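Before running, download the weights from [here](https://drive.google.com/drive/folders/1z68jSlSbBZ6mHuqmcpcO3-aEKLUKb72X?usp=sharing) and add them to your Google Drive in a folder named `182projweights` at the top level of `MyDrive`. This notebook loads the model weights and the `test_data` split from `/content/drive/MyDrive/182projweights`.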

In [None]:
#@title Install
!pip install transformers
! pip install git+https://github.com/openai/CLIP.git
!pip install datasets

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-s1_ngcry
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-s1_ngcry
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.1.3-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369497 sha256=a5a68136d3b764010995c93b64d0222d639469c002cfd844bc5a3671ed35b51e
  Stored in directory: /tmp/pip-ephem-wheel-cache-qb901oo3/wheels/da/2b/4c/d6691fa9597aac8bb85d2ac13b112deb897d5b50f5ad9a37e4
Successfully built clip
Inst

In [None]:
# Mount Google Drive; the weights and test data used below live on it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Mounted at /content/drive


In [None]:
#@title Imports

import clip
import os
from torch import nn
import numpy as np
import torch
import torch.nn.functional as nnf
import sys
from typing import Tuple, List, Union, Optional
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
from google.colab import files
import skimage.io as io
import PIL.Image
from IPython.display import Image
from enum import Enum

class MappingType(Enum):
    """Choice of prefix mapping network (MLP or transformer)."""
    MLP = 'mlp'
    Transformer = 'transformer'


class PromptType(Enum):
    """Prompting strategy for PromptedCaptionModel.

    ``Orginal`` is the historical misspelling; it is kept as an alias of
    ``Original`` (same value) so existing references keep working.
    """
    Empty = "empty"
    Original = 'original'
    Orginal = 'original'  # backward-compatible alias of Original
    OriginalWithWords = 'originalplus'
# ---- type-hint shorthands: numpy -------------------------------------------
N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]

# ---- type-hint shorthands: torch -------------------------------------------
T = torch.Tensor
TN = Optional[T]
TS = Union[Tuple[T, ...], List[T]]
TSN = Optional[TS]
TNS = Union[Tuple[TN, ...], List[TN]]
TA = Union[T, ARRAY]

# ---- devices ---------------------------------------------------------------
D = torch.device
CPU = torch.device('cpu')


def get_device(device_id: int) -> D:
    """Return ``cuda:<device_id>``, clamped to the last visible GPU; CPU if CUDA is unavailable."""
    if torch.cuda.is_available():
        last_gpu = torch.cuda.device_count() - 1
        return torch.device(f'cuda:{min(last_gpu, device_id)}')
    return CPU


CUDA = get_device

# Weights and the `test_data` split are read from this folder on the mounted Drive.
save_path = os.path.join("/content/drive/MyDrive", "182projweights")
data_path = os.path.join(save_path, "test_data")

current_directory = os.getcwd()
print(current_directory)
os.makedirs(save_path, exist_ok=True)
print(save_path)


/content
/content/drive/MyDrive/182projweights


In [None]:
class MLP(nn.Module):
    """Feed-forward stack: one Linear per consecutive pair in ``sizes``,
    with ``act()`` inserted after every Linear except the last one."""

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        n_linear = len(sizes) - 1
        stack = []
        for idx, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            stack.append(nn.Linear(d_in, d_out, bias=bias))
            is_last = idx == n_linear - 1
            if not is_last:
                stack.append(act())
        # Kept as `self.model` (an nn.Sequential) so state_dict keys stay "model.<i>.*".
        self.model = nn.Sequential(*stack)

    def forward(self, x: T) -> T:
        return self.model(x)


class CaptionModel(nn.Module):
    """ClipCap-style captioner: a projected CLIP embedding is prepended to GPT-2's input.

    ``clip_project`` maps a ``prefix_size``-d vector to ``prefix_length`` GPT-2
    token embeddings; they are concatenated in front of the caption's token
    embeddings before running GPT-2.
    """

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Zero token ids of shape (batch_size, prefix_length), one per prefix slot."""
        return torch.zeros(
            batch_size, self.prefix_length, dtype=torch.int64, device=device
        )

    def forward(
        self,
        tokens: torch.Tensor,
        prefix: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """Run GPT-2 on ``[prefix projection, token embeddings]``.

        Args:
            tokens: caption token ids, (batch, seq).
            prefix: CLIP embeddings; projected and viewed as (batch, prefix_length, hidden).
            mask: optional attention mask over prefix + tokens, passed to GPT-2.
            labels: if not None, ``tokens`` are used as LM targets (the value itself
                is only used as a flag). Prefix slots get -100, and so do positions
                where ``mask`` is 0 when ``mask`` has the labels' shape. -100 is the
                ignore index of the loss GPT-2 computes from ``labels``.

        Returns:
            The GPT-2 model output.
        """
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(
            -1, self.prefix_length, self.gpt_embedding_size
        )
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            # Prefix slots have no target token; labelling them 0 would train the
            # model to predict token 0 there.
            ignored_prefix = torch.full(
                (tokens.shape[0], self.prefix_length), -100,
                dtype=torch.int64, device=tokens.device,
            )
            labels = torch.cat((ignored_prefix, tokens), dim=1)
            if mask is not None and mask.shape == labels.shape:
                # Padding ids were replaced by 0; don't supervise them either.
                labels = labels.masked_fill(mask == 0, -100)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super(CaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained("gpt2")
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if prefix_length > 10:  # not enough memory
            self.clip_project = nn.Linear(
                prefix_size, self.gpt_embedding_size * prefix_length
            )
        else:
            self.clip_project = MLP(
                (
                    prefix_size,
                    (self.gpt_embedding_size * prefix_length) // 2,
                    self.gpt_embedding_size * prefix_length,
                )
            )

def generate(
    model,
    tokenizer,
    tokens=None,
    prompt=None,
    embed=None,
    entry_count=1,
    entry_length=67,  # maximum number of words
    top_p=0.8,
    temperature=1.0,
    stop_token: str = ".",
):
    """Decode text from ``model.gpt`` starting from embeddings, token ids, or a prompt.

    The starting point is ``embed`` if given; otherwise ``tokens``; otherwise
    ``prompt`` encoded with ``tokenizer``. Every step takes the arg-max token.
    The top-p filter never removes the arg-max token and temperature scaling
    does not change it, so ``top_p`` and ``temperature`` do not affect the result.
    Assumes a batch of one (``next_token.item()`` needs a single token).

    Args:
        model: object with a GPT-2 LM head at ``.gpt`` and ``.parameters()``.
        tokenizer: tokenizer used to encode ``prompt``/``stop_token`` and decode output.
        tokens: optional starting token ids, shape (1, seq).
        prompt: optional text, used when neither ``embed`` nor ``tokens`` is given.
        embed: optional starting embeddings, shape (1, seq, hidden).
        entry_count: number of decodes; each restarts from the same inputs.
        entry_length: maximum number of tokens generated per decode.
        top_p: nucleus-filter threshold.
        temperature: logit divisor; non-positive values are treated as 1.0.
        stop_token: decoding stops after emitting the first token of this string.

    Returns:
        Decoded text of the first entry. The starting ``tokens``/``prompt`` ids are
        included in it; ``embed`` contributes no ids, so only generated tokens appear.
    """
    model.eval()
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device

    # Resolve the prompt once so every entry starts from the same ids.
    start_tokens = tokens
    if embed is None and start_tokens is None:
        start_tokens = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

    with torch.no_grad():
        for _ in range(entry_count):
            # Fresh accumulator per entry (previously tokens leaked across entries).
            entry_tokens = start_tokens
            if embed is not None:
                generated = embed
            else:
                generated = model.gpt.transformer.wte(entry_tokens)

            for _ in range(entry_length):
                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    nn.functional.softmax(sorted_logits, dim=-1), dim=-1
                )
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the token that crosses top_p is kept, and always keep the best.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if entry_tokens is None:
                    entry_tokens = next_token
                else:
                    entry_tokens = torch.cat((entry_tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            # reshape(-1) rather than squeeze(): a single generated token must
            # still yield a 1-element list, not a 0-d array.
            output_list = [] if entry_tokens is None else entry_tokens.reshape(-1).tolist()
            generated_list.append(tokenizer.decode(output_list))

    return generated_list[0]

In [None]:
class PromptedCaptionModel(nn.Module):
    """Captioner whose GPT-2 input is a CLIP prefix followed by a text prompt.

    In ``PromptType.OriginalWithWords`` mode a frozen ``CaptionModel``
    (``original_model``, loaded from ``weights_path``) first generates a plain
    description. The prompt
    " is a picture of <description> and a social media post would caption it <caption>"
    is then placed after the prefix. Other modes do not create ``original_model``,
    so ``forward`` currently only works in OriginalWithWords mode.
    """

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Zero token ids of shape (batch_size, prefix_length)."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def pad_tokens(self, tokens):
        """Pad with 0 / truncate 1-D ``tokens`` to ``max_seq_len`` and build the attention mask.

        Returns ``(tokens, mask)``. ``mask`` is float, has ``prefix_length`` leading
        ones, then 1 for real tokens and 0 for padding. Both are on ``tokens.device``.
        """
        device = tokens.device
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64, device=device) - 1))
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
        mask = tokens.ge(0)  # padding was marked with -1
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_length, device=device), mask), dim=0)  # adding prefix mask
        return tokens, mask

    def forward(self, caption, prefix: torch.Tensor,
                labels: Optional[torch.Tensor] = None):
        """Run GPT-2 on ``[prefix projection, prompt tokens]`` for a single image.

        Args:
            caption: caption text appended to the prompt.
            prefix: CLIP embedding for ONE image. The prompt is built from one
                generated description, so batches larger than 1 are not supported.
            labels: if not None, the prompt ids become LM targets (the value itself
                is only a flag). Prefix and padding positions get -100, the ignore
                index of the loss GPT-2 computes from ``labels``.

        Returns:
            The GPT-2 model output.
        """
        # Read the device from the weights: ``self.device`` is fixed at construction
        # and becomes stale after ``.to(...)``.
        device = self.gpt.transformer.wte.weight.device
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)

        # NOTE(review): this feeds *this* model's projection into original_model's GPT-2;
        # confirm whether original_model.clip_project(prefix) was intended instead.
        description = generate(self.original_model, self.tokenizer, embed=prefix_projections)
        prompt_text = f" is a picture of {description} and a social media post would caption it {caption}"
        tokens = torch.tensor(self.tokenizer.encode(prompt_text), dtype=torch.int64, device=device)
        tokens, mask = self.pad_tokens(tokens)
        # Add the batch dimension (batch of one) that GPT-2 and the labels expect.
        tokens = tokens.unsqueeze(0)
        mask = mask.unsqueeze(0)
        embedding_text = self.gpt.transformer.wte(tokens)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)

        if labels is not None:
            ignored_prefix = torch.full(
                (tokens.shape[0], self.prefix_length), -100, dtype=torch.int64, device=device
            )
            labels = torch.cat((ignored_prefix, tokens), dim=1)
            labels = labels.masked_fill(mask == 0, -100)  # don't supervise padding
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8, mapping_type: MappingType = MappingType.MLP,
                 prompt_mode: PromptType = PromptType.OriginalWithWords, weights_path: str = os.path.join(save_path, 'coco_weights.pt'),
                 device = "cpu"):
        """Build GPT-2, the CLIP->prefix projection and (OriginalWithWords) the frozen helper.

        ``clip_length``, ``num_layers`` and ``mapping_type`` are accepted but unused.
        ``device`` is where the modules are placed initially.
        """
        super(PromptedCaptionModel, self).__init__()
        self.device = device
        self.max_seq_len = 77
        self.prompt_mode = prompt_mode
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if prefix_length > 10:  # not enough memory
            self.clip_project = nn.Linear(
                prefix_size, self.gpt_embedding_size * prefix_length
            )
        else:
            self.clip_project = MLP(
                (
                    prefix_size,
                    (self.gpt_embedding_size * prefix_length) // 2,
                    self.gpt_embedding_size * prefix_length,
                )
            )
        self.clip_project = self.clip_project.to(device)
        # Created in every mode so forward never hits a missing attribute for it.
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

        if self.prompt_mode == PromptType.OriginalWithWords:
            self.original_model = CaptionModel(prefix_length)
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
            # strict=False: mismatching checkpoint keys are skipped without any warning.
            self.original_model.load_state_dict(state_dict, strict=False)
            # Freeze the model parameters
            for param in self.original_model.parameters():
                param.requires_grad = False
            self.original_model = self.original_model.to(device)


In [None]:
#@title CLIP model + GPT2 tokenizer

CLIP_BACKBONE = "ViT-B/32"
device = CUDA(0)
# jit=False loads the regular PyTorch model rather than the TorchScript build.
clip_model, preprocess = clip.load(CLIP_BACKBONE, device=device, jit=False)
# GPT-2 BPE tokenizer: encodes captions and decodes generated ids.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

100%|███████████████████████████████████████| 338M/338M [00:06<00:00, 53.8MiB/s]


vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [None]:

from datasets import load_from_disk
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms



class InstagramDataset(Dataset):
    """Image/caption pairs loaded with ``load_from_disk(path)``.

    Each item is ``(tokens, prefix, mask, img)``:
    - tokens: caption ids from ``tokenizer``, padded with 0 / truncated to ``max_seq_len``.
    - prefix: ``clip.encode_image`` output for the preprocessed image (on ``device``,
      computed without gradients).
    - mask: float attention mask of length ``prefix_len + max_seq_len``; prefix slots are 1.
    - img: the raw image as a tensor (``transforms.ToTensor()``), for display.

    If ``path`` holds a dict of splits (e.g. a DatasetDict), ``split`` selects one;
    otherwise ``split`` is ignored.
    """
    def __init__(self, clip, preprocessor, tokenizer, path=data_path, split="train", device="cuda"):
        self.clip_model = clip
        self.preprocess = preprocessor
        self.tokenizer = tokenizer
        data = load_from_disk(path)
        if isinstance(data, dict):  # multi-split save: pick the requested split
            data = data[split]
        self.data_dict = data
        self.max_seq_len = 77 # clip max sequence length
        self.prefix_len = 10  # presumably must equal the caption model's prefix_length
        self.device = device

    def __len__(self):
        return self.data_dict.num_rows

    def __getitem__(self, idx):
        entry = self.data_dict[idx]
        pil_image = entry["image"]
        img = transforms.ToTensor()(pil_image)
        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
        # Inference only: don't build an autograd graph for every item.
        with torch.no_grad():
            prefix = self.clip_model.encode_image(image)

        # Explicit int64 so an empty caption still gives integer ids (not float32).
        tokens = torch.tensor(self.tokenizer.encode(entry["caption"]), dtype=torch.int64)
        tokens, mask = self.pad_tokens(tokens)

        return tokens, prefix, mask, img

    def pad_tokens(self, tokens):
        """Pad with 0 / truncate to ``max_seq_len``; return ``(tokens, mask)`` with prefix ones prepended to the mask."""
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
        mask = tokens.ge(0)  # mask is zero where we out of sequence
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_len), mask), dim=0)  # adding prefix mask
        return tokens, mask



class PromptedInstagramDataset(Dataset):
    """Image/caption pairs loaded with ``load_from_disk(path)``, for the prompted model.

    Each item is ``(tokens, prefix, mask, caption)``:
    - tokens: caption ids from ``tokenizer``, padded with 0 / truncated to ``max_seq_len``.
    - prefix: ``clip.encode_image`` output for the preprocessed image (on ``device``,
      computed without gradients).
    - mask: float attention mask of length ``prefix_len + max_seq_len``; prefix slots are 1.
    - caption: the raw caption string.

    If ``path`` holds a dict of splits (e.g. a DatasetDict), ``split`` selects one;
    otherwise ``split`` is ignored.
    """
    def __init__(self, clip, preprocessor, tokenizer, path=data_path, split="train", device="cuda"):
        self.clip_model = clip
        self.preprocess = preprocessor
        self.tokenizer = tokenizer
        data = load_from_disk(path)
        if isinstance(data, dict):  # multi-split save: pick the requested split
            data = data[split]
        self.data_dict = data
        self.max_seq_len = 77 # clip max sequence length
        self.prefix_len = 10  # presumably must equal the caption model's prefix_length
        self.device = device

    def __len__(self):
        return self.data_dict.num_rows

    def __getitem__(self, idx):
        entry = self.data_dict[idx]
        pil_image = entry["image"]
        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
        # Inference only: don't build an autograd graph for every item.
        with torch.no_grad():
            prefix = self.clip_model.encode_image(image)

        caption = entry["caption"]
        # Explicit int64 so an empty caption still gives integer ids (not float32).
        tokens = torch.tensor(self.tokenizer.encode(caption), dtype=torch.int64)
        tokens, mask = self.pad_tokens(tokens)

        return tokens, prefix, mask, caption

    def pad_tokens(self, tokens):
        """Pad with 0 / truncate to ``max_seq_len``; return ``(tokens, mask)`` with prefix ones prepended to the mask."""
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
        mask = tokens.ge(0)  # mask is zero where we out of sequence
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_len), mask), dim=0)  # adding prefix mask
        return tokens, mask

In [None]:
# Display name -> loaded caption model; filled by the loading cells and iterated by the rating UI.
models = dict()

In [None]:
prefix_length = 10
model_path = os.path.join(save_path, 'base_weights_final.pt')

# Build the architecture, then load the fine-tuned checkpoint into it (strict key match).
base_model = CaptionModel(prefix_length)
base_model.load_state_dict(torch.load(model_path, map_location=CPU))

device = "cuda"
base_model = base_model.eval().to(device)
models["base"] = base_model


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [None]:
model_path = os.path.join(save_path, 'base_5epochs_weights.pt')

prefix_length = 10

base_5eps_model = CaptionModel(prefix_length)

base_5eps_model.load_state_dict(torch.load(model_path, map_location=CPU))

device = "cuda"
# Assign back to base_5eps_model so that `base_model` keeps referring to the
# "base" checkpoint loaded earlier.
base_5eps_model = base_5eps_model.eval().to(device)
models["base 5 epochs"] = base_5eps_model

In [None]:
prefix_length = 10
model_path = os.path.join(save_path, 'frozen_weights.pt')

# Strict load: checkpoint keys must match CaptionModel's state_dict exactly.
frozen_model = CaptionModel(prefix_length)
frozen_model.load_state_dict(torch.load(model_path, map_location=CPU))

device = "cuda"
frozen_model = frozen_model.eval().to(device)
models["frozen"] = frozen_model

In [None]:
prefix_length = 10
model_path = os.path.join(save_path, 'prompted_weights.pt')

# Construction with default arguments also loads the frozen helper captioner
# before the prompted checkpoint is applied on top.
prompted_model = PromptedCaptionModel(prefix_length)
prompted_model.load_state_dict(torch.load(model_path, map_location=CPU))

device = "cuda"
prompted_model = prompted_model.eval().to(device)
models["prompted"] = prompted_model

In [None]:
model_path = os.path.join(save_path, 'lora_4_weights.pt')

prefix_length = 10

lora_model = CaptionModel(prefix_length)

# strict=False skips checkpoint keys CaptionModel has no slot for (presumably the
# LoRA adapter tensors) and leaves missing ones at their initial values. Report
# both so a checkpoint that did not really load is visible instead of silent.
load_result = lora_model.load_state_dict(torch.load(model_path, map_location=CPU), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"lora checkpoint: {len(load_result.missing_keys)} missing keys, "
          f"{len(load_result.unexpected_keys)} unexpected keys")
    print("  first missing:", load_result.missing_keys[:5])
    print("  first unexpected:", load_result.unexpected_keys[:5])

device = "cuda"
lora_model = lora_model.eval().to(device)
models["lora"] = lora_model

In [None]:
from torch.utils.data import Subset

test_data = InstagramDataset(clip_model, preprocess, tokenizer, split="test", device="cuda")
# For a quick smoke test, restrict to a handful of items:
# test_data = Subset(test_data, indices=range(5))

# One image per batch: the rating UI shows and scores images one at a time.
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output



How to rate: for each image you will see one caption from each of the 5 models. Rank the captions from 1 to 5, where higher is better. If a model returns gibberish, give it a 0, then rank the remaining captions from 5 downward. For example, if one caption is gibberish, the other four get 5, 4, 3 and 2.

In [None]:
def image_generator():
    """Yield ``(tokens, prefix, mask, image)`` from ``test_loader``; all but ``image`` are moved to ``device``."""
    for tokens, prefix, mask, image in test_loader:
        tokens, prefix, mask = tokens.to(device), prefix.to(device), mask.to(device)
        yield tokens, prefix, mask, image

gen = image_generator()  # Create the generator instance

# Ratings collected so far, one list per model name
all_ratings = {model_name: [] for model_name in models}

def show_next_or_finish():
    """Show the next test image, or print the per-model averages once the loader is exhausted."""
    try:
        display_next_image()
    except StopIteration:
        calculate_averages()

def on_button_clicked(b):
    """Submit handler: store every model's current rating, then advance."""
    current_ratings = {model_name: inputs[model_name].value for model_name in models}
    for model_name, rating in current_ratings.items():
        all_ratings[model_name].append(rating)

    # Clear the current output
    clear_output(wait=True)
    show_next_or_finish()

def calculate_averages():
    """Print the mean rating per model (0 for a model with no ratings)."""
    print("All images processed. Calculating averages...")
    for model_name, ratings in all_ratings.items():
        average_rating = sum(ratings) / len(ratings) if ratings else 0
        print(f"Average rating for {model_name}: {average_rating:.2f}")

def display_next_image():
    """Show the next image, each model's caption, a rating box per model, and a Submit button.

    Raises:
        StopIteration: when ``gen`` is exhausted.
    """
    global inputs

    _tokens, prefix, _mask, image = next(gen)
    image = transforms.ToPILImage()(image.squeeze())

    plt.imshow(image)
    plt.axis('off')
    plt.show()

    # Loop-invariant: the projection layers take float32, and the CLIP prefix is
    # presumably half precision on GPU.
    prefix = prefix.to(torch.float32)
    outputs = {}
    with torch.no_grad():
        for key, model in models.items():
            # Use each model's own prefix length rather than a hard-coded 10.
            prefix_embed = model.clip_project(prefix).reshape(1, model.prefix_length, -1)
            outputs[key] = generate(model, tokenizer, embed=prefix_embed)

    inputs = {}
    for k, v in outputs.items():
        print(f"{k} Output: {v}")
        # Scores range from 0 (gibberish) to the number of models (best).
        inputs[k] = widgets.BoundedFloatText(description=f'{k}:', min=0, max=len(models))

    for input_widget in inputs.values():
        display(input_widget)

    # Display submit button
    button = widgets.Button(description="Submit")
    button.on_click(on_button_clicked)
    display(button)

# Start with the first image (or print the averages right away if there are no images)
show_next_or_finish()

All images processed. Calculating averages...
Average rating for base: 2.59
Average rating for base 5 epochs: 2.32
Average rating for frozen: 2.87
Average rating for prompted: 0.84
Average rating for lora: 0.56


In [None]:
# Raw per-model rating lists collected by the rating widget.
print(repr(all_ratings))

{'base': [2.0, 1.0, 2.0, 1.0, 2.0], 'frozen': [1.0, 2.0, 1.0, 1.0, 1.0]}


In [None]:
#@title Upload Image


uploaded = files.upload()

# Accept zero or one file; several files at once is a user error.
if len(uploaded) > 1:
    raise AssertionError('Please upload one image at a time')
UPLOADED_FILE = next(iter(uploaded), '')

print(UPLOADED_FILE)

Conceptual captions examples:
https://drive.google.com/file/d/1mzH3b0LQrGEWjEva4hI6HE_fIYRIgtBT/view?usp=sharing