# CLIP + REINFORCE based Image Captioning with Semantic Visual Tokens (SVT)

In [None]:
%cd ..

In [1]:
!pip install open_clip_torch peft accelerate open-clip-torch fire ipywidgets==8.1.2 jupyter_bbox_widget

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import textwrap

import open_clip
import peft

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

import torch

from torch import optim
from torch.nn import functional as F
from tqdm.auto import tqdm

from SemCLIP.semclip import get_segments_embeddings

In [3]:
print = tqdm.external_write_mode()(print)

def endless_range(start=0, step=1):
    """An endless range generator."""
    i = start
    while True:
        yield i
        i += step

def logp_completion(logits, tokens, mask):
    """Compute the log probabilities of completions given their prompts.

    Args:
        tokens: The tokens input to the model. Shape: (..., T).
        logits: The logits output from the model. Shape: (..., T, V).
        mask: A mask indicating which tokens should be included in the log probabilities. It should
            exclude prompt tokens and padding tokens. Shape: (..., T).

    Returns:
        The log probabilities of the completions given their prompts. Shape: (...).
    """
    logits = F.log_softmax(logits, dim=-1)
    logp_tokens = logits[..., :-1, :].gather(-1, tokens[..., 1:, None])[..., 0]
    return torch.sum(logp_tokens * mask[..., 1:], dim=-1)

In [4]:
prompt = "The theme of this image is" 
prompt = "What do you see in this picture?"
length = 80
batch_size = 64 
kl_weight = 4e-3
temperature = 0.8

In [5]:
clip_name = "ViT-L-14-336"
clip_pretrained = "openai"
# model_name = "EleutherAI/pythia-160m-deduped"
# model_name = "mistralai/Mistral-7B-v0.1"
# model_name = "meta-llama/Llama-2-7b-hf"
model_name = "microsoft/phi-2"
device = torch.device("cuda:0")

In [6]:
# Load CLIP
# # clip_tokenizer = open_clip.get_tokenizer(clip_name)
# clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
#     clip_name, pretrained=clip_pretrained, device=device
# )
# clip_model.eval().requires_grad_(False)

# Load language model
tokenizer = AutoTokenizer.from_pretrained(model_name, token='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')
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # device_map={"": device}
    token='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',
    trust_remote_code=True # for phi-2
)

# Move the model to the GPU
model = model.to(device)

# Prepare LoRA
peft_config = peft.LoraConfig(
    peft.TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,
    lora_alpha=8,
    lora_dropout=0.0,
    target_modules=[
        # For NeoX and Pythia
        # "attention.query_key_value",
        # "attention.dense",
        # "mlp.dense_h_to_4h",
        # "mlp.dense_4h_to_h",
        # For Llama and Mistral 7B
        # "self_attn.q_proj",
        # "self_attn.k_proj",
        # "self_attn.v_proj",
        # "self_attn.o_proj",
        # "mlp.gate_proj",
        # "mlp.up_proj",
        # "mlp.down_proj",
        # For Phi:
        'Wqkv',
        'fc1',
        'fc2',
    ],
)
model = peft.get_peft_model(model, peft_config)

# Finish preparing model
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.train()
model.print_trainable_parameters()


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 26,214,400 || all params: 2,805,898,240 || trainable%: 0.9342605382581515


### Optimize the LLM with REINFORCE

In [7]:
import open_clip

import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from PIL import Image


# Assuming the CLIP model is loaded as shown in the reference code provided
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    clip_name, pretrained=clip_pretrained, device=device
)

clip_tokenizer = open_clip.get_tokenizer(clip_name)

In [8]:
added_embeddings_tensor = get_segments_embeddings(image_name='dog.jpeg', data_name='test_images', projection_dim=768)

In [9]:
added_embeddings_tensor = added_embeddings_tensor.to(device)

print(added_embeddings_tensor.shape)
print(added_embeddings_tensor)

torch.Size([1, 768])
tensor([[ 5.0610e-01,  3.0874e-01,  2.2853e-01, -2.6581e-01, -3.1092e-01,
          1.6562e-01, -7.2087e-02, -6.4041e-02,  3.6988e-01, -2.8280e-01,
         -1.7887e-01, -9.9637e-02,  5.3308e-02, -5.8022e-01,  6.6427e-02,
          8.0651e-02,  4.9946e-02,  1.8755e-01,  2.6872e-01, -1.2413e-01,
          1.0862e-02,  2.1680e-01, -4.6288e-01, -8.6663e-02,  4.9156e-01,
          7.9659e-01, -3.6737e-01,  4.6936e-01,  3.6757e-01,  1.0165e+00,
          3.1140e-01,  6.7719e-01, -9.3775e-02,  1.0950e+00, -2.5884e-02,
         -4.2328e-01, -8.3998e-01, -2.8377e-01,  1.1994e-02, -8.0993e-01,
         -3.2794e-01,  6.1093e-01,  2.8080e-01,  1.1519e-01,  2.5315e-01,
         -8.9761e-03,  1.2594e-01, -1.8304e-01,  1.9305e-01,  3.7651e-01,
         -3.1843e-01, -2.0094e-01,  6.0299e-01, -1.5161e-01,  7.1777e-01,
         -2.1114e-01, -6.3804e-02,  6.7453e-01,  2.6493e-01,  2.4697e-01,
         -4.3676e-01, -5.1152e-01,  8.4315e-01,  2.9186e-01, -2.7764e-01,
          2.3009e

In [10]:
# torch.save(added_embeddings_tensor, 'added_embeddings_tensor.pt')

# torch.load('added_embeddings_tensor.pt')

In [None]:
# Settings
torch.set_float32_matmul_precision("high")

added_embeddings_tensor = added_embeddings_tensor.to(device)

# Prepare the prompt
inputs = tokenizer([prompt] * batch_size, return_tensors="pt").to(device)
input_len = inputs.input_ids.shape[1]
logp_mask = torch.tensor(
    [[False] * input_len + [True] * length] * batch_size, device=device
)

# Optimize the LLM
opt = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))

try:
    for i in tqdm(endless_range()):
        # Generate a batch of samples from the model
        model.eval()
        tokens = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            do_sample=True,
            min_new_tokens=length,
            max_new_tokens=length,
            pad_token_id=tokenizer.eos_token_id,
            temperature=temperature,
            top_k=0,
        )

        # Get the logits of the samples from the model and the reference model
        attention_mask = torch.cat(
            (inputs.attention_mask, torch.ones_like(tokens[:, input_len:])), dim=1
        )
        with torch.no_grad(), model.disable_adapter():
            outputs_ref = model(tokens, attention_mask=attention_mask, use_cache=False)
        model.train()
        outputs = model(tokens, attention_mask=attention_mask, use_cache=False)

        # Compute the log probability of the samples under the model and the reference model
        logp = logp_completion(outputs.logits / temperature, tokens, logp_mask)
        logp_ref = logp_completion(outputs_ref.logits / temperature, tokens, logp_mask)

        # Compute the CLIP loss
        texts = [tokenizer.decode(t, skip_special_tokens=True) for t in tokens]
        # print(tokens)
        clip_tokens = clip_tokenizer(texts).to(device)
        with torch.cuda.amp.autocast():
            text_embeds = clip_model.encode_text(clip_tokens).float()
        cost_clip = torch.cosine_similarity(text_embeds, added_embeddings_tensor, dim=-1).arccos()

        # Compute the KL penalty
        cost_kl = logp.detach() - logp_ref

        # REINFORCE
        cost = cost_clip + kl_weight * cost_kl
        baseline = (cost.sum() - cost) / (cost.numel() - 1)
        box = torch.exp(logp - logp.detach())
        loss = torch.mean(box * cost + (1 - box) * baseline)

        # Update the model
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Print statistics and the best sample in the batch
        grad_norm = torch.cat(
            [p.grad.flatten() for p in model.parameters() if p.grad is not None]
        ).norm()

        print(
            f"step: {i}, loss: {loss.item():g}, clip: {cost_clip.mean().item():g}, kl: {cost_kl.mean().item():g}, grad: {grad_norm.item():g}"
        )
        best_text = texts[torch.argmin(cost).item()]
        print(textwrap.fill(best_text, width=80))
        print()

except KeyboardInterrupt:
    pass

0it [00:00, ?it/s]

step: 0, loss: 1.50796, clip: 1.52056, kl: -3.14999, grad: 0.0243376
What do you see in this picture? I see a vase of flowers, a bowl of fruit, and a
plate of cookies. I decide to make a story based on this picture.  Once upon a
time, there was a girl named Lily who loved flowers. She had a beautiful vase of
flowers on her table, and she liked to arrange them every day. She also had a
bowl of fruit and a plate of cookies

step: 1, loss: 1.50926, clip: 1.51892, kl: -2.41535, grad: 0.0230939
What do you see in this picture? A) A family of deer B) A piccolo concerto C) A
beautiful mountain D) A group of friends playing soccer  Answer: B) A piccolo
concerto  Great job! Now, let's see if you can apply this knowledge to real life
situations.  3) Your friend tells you they want to learn how to play the piccolo

step: 2, loss: 1.50703, clip: 1.51938, kl: -3.08827, grad: 0.0235732
What do you see in this picture? Animal, fruit, vegetable, star, or flower?" The
mother and child looked at the pic