<a href="https://colab.research.google.com/github/karen-pal/notebook/blob/main/Image_To_Text_(LLM_%2B_CLIP).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image To Text (LLM + CLIP)

Notebook by Katherine Crowson (https://twitter.com/RiversHaveWings)

This notebook uses reinforcement learning to fine-tune a large language model ([Pythia 160M](https://github.com/EleutherAI/pythia) by default) to interpret a single image, optimizing a [CLIP](https://arxiv.org/abs/2103.00020)-based image/text matching loss.

In [None]:
#@title Licensed under the Apache License, Version 2.0 { display-mode: "form" }

# Copyright 2024 Katherine Crowson

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
#@title Check GPU

# Show the attached GPU(s). The "Load models" cell places every model on
# cuda:0, so this notebook needs a CUDA GPU runtime.
!nvidia-smi

In [None]:
#@title Install dependencies

!pip install open_clip_torch peft

In [None]:
#@title Import libraries

import builtins
import textwrap

from google.colab import files
import open_clip
import peft
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch import optim
from torch.nn import functional as F
from tqdm.auto import tqdm

In [None]:
#@title Define necessary functions

# Make print() cooperate with tqdm progress bars: each call runs inside
# tqdm.external_write_mode(), which temporarily clears active bars and redraws
# them afterwards. Wrapping builtins.print (rather than the current `print`)
# keeps this cell idempotent: re-running it replaces the wrapper instead of
# stacking a new wrapper on top of the previous one.
print = tqdm.external_write_mode()(builtins.print)


def endless_range(start=0, step=1):
    """Count from ``start`` in increments of ``step``, never stopping.

    Behaves like an unbounded ``range``: yields ``start``, then
    ``start + step``, ``start + 2 * step``, ... until the caller stops
    iterating (e.g. by breaking out of the loop or interrupting the cell).
    """
    value = start
    while True:
        yield value
        value += step


def logp_completion(logits, tokens, mask):
    """Sum the log probabilities of the selected tokens of each sequence.

    The logits at position t score the token at position t + 1, so the
    logits are shifted one step relative to ``tokens`` and ``mask``. The
    last position's logits therefore predict nothing and are dropped, and
    the first token is never scored.

    Args:
        logits: The logits output from the model. Shape: (..., T, V).
        tokens: The tokens input to the model. Shape: (..., T).
        mask: Which tokens to include in the sum; it should exclude prompt
            and padding tokens. Shape: (..., T).

    Returns:
        The summed log probability of the masked tokens of each sequence,
        i.e. of each completion given its prompt. Shape: (...).
    """
    predicting_logits = logits[..., :-1, :]
    next_tokens = tokens[..., 1:].unsqueeze(-1)
    token_logps = (
        F.log_softmax(predicting_logits, dim=-1).gather(-1, next_tokens).squeeze(-1)
    )
    return (token_logps * mask[..., 1:]).sum(dim=-1)


In [None]:
#@title Upload image

# Ask for a single image file via Colab's upload dialog. The image is opened
# by the uploaded file's name, which assumes upload() also wrote the file into
# the working directory.
uploaded = files.upload()
assert len(uploaded) == 1, "Please upload exactly one image."
image = Image.open(next(iter(uploaded)))

In [None]:
#@title Set parameters { display-mode: "form" }

#@markdown Generations from the LLM will be prefixed by the prompt:
prompt = "The theme of this image is"  #@param {type: 'string'}

#@markdown The number of tokens to sample from the LLM:
length = 50  #@param {type: 'integer'}

#@markdown The batch size:
batch_size = 64  #@param {type: 'integer'}

#@markdown The strength of the KL divergence penalty vs the original LLM:
#@markdown <br><small>The KL divergence penalty specifies the rate at which the optimizer will trade off a decrease in the angle (in radians) between the CLIP text and image embeddings and a decrease in the KL divergence between the model and the reference model.</small>
kl_weight = 4e-3  #@param {type: 'number'}

#@markdown The temperature at which to sample from the LLM:
temperature = 0.9  #@param {type: 'number'}

# Fail fast on settings that would otherwise break the optimization cell:
# - its leave-one-out REINFORCE baseline divides by (batch_size - 1), which
#   gives a NaN loss when batch_size == 1;
# - it divides the logits by the temperature, and samples with it.
assert length >= 1, "length must be at least 1."
assert batch_size >= 2, "batch_size must be at least 2 (leave-one-out baseline)."
assert temperature > 0, "temperature must be positive."


In [None]:
#@title Load models

# Model choices. The LoRA target_modules below name GPT-NeoX/Pythia layers;
# switching model_name to another architecture presumably also requires
# switching target_modules (see the commented-out Llama/Mistral list).
clip_name = "ViT-L-14-336"
clip_pretrained = "openai"
model_name = "EleutherAI/pythia-160m-deduped"
device = torch.device("cuda:0")

# Load CLIP. It is put in eval mode with gradients disabled: it only scores
# samples and is never trained.
clip_tokenizer = open_clip.get_tokenizer(clip_name)
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    clip_name, pretrained=clip_pretrained, device=device
)
clip_model.eval().requires_grad_(False)

# Load language model
# Reuse EOS as the pad token and pad on the left (the usual setup for batched
# generation with a decoder-only model). The optimization cell tokenizes the
# same prompt batch_size times, so no padding is actually inserted there.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map={"": device}
)

# Prepare LoRA: wrap the model with low-rank adapters on the listed modules so
# the adapter weights are what gets trained. The un-adapted model stays
# reachable through model.disable_adapter(), which the optimization cell uses
# as the KL reference model.
peft_config = peft.LoraConfig(
    peft.TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,
    lora_alpha=8,
    lora_dropout=0.0,
    target_modules=[
        # For NeoX and Pythia
        "attention.query_key_value",
        "attention.dense",
        "mlp.dense_h_to_4h",
        "mlp.dense_4h_to_h",
        # For Llama and Mistral 7B
        # "self_attn.q_proj",
        # "self_attn.k_proj",
        # "self_attn.v_proj",
        # "self_attn.o_proj",
        # "mlp.gate_proj",
        # "mlp.up_proj",
        # "mlp.down_proj",
    ],
)
model = peft.get_peft_model(model, peft_config)

# Finish preparing model
# Gradient checkpointing trades recomputation for activation memory.
# enable_input_require_grads() is the usual companion when the embedding layer
# is frozen (as it is under LoRA), so that gradients can still flow through
# the checkpointed layers into the adapters.
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
# The optimization loop switches to eval() for generation and back to train()
# for the gradient step.
model.train()
model.print_trainable_parameters()


In [None]:
#@title Optimize the LLM

# Settings
# Allow faster (e.g. TF32) kernels for float32 matmuls where the GPU supports them.
torch.set_float32_matmul_precision("high")

# Prepare the input image: one CLIP embedding that every sample is scored against.
# torch.autocast replaces the deprecated torch.cuda.amp.autocast (same default dtype).
image_for_clip = preprocess(image).unsqueeze(0).to(device)
with torch.autocast(device_type=device.type):
    image_embed = clip_model.encode_image(image_for_clip).float()

# Prepare the prompt: the same prompt repeated once per batch element.
inputs = tokenizer([prompt] * batch_size, return_tensors="pt").to(device)
input_len = inputs.input_ids.shape[1]
# Selects only the `length` generated tokens (not the prompt) when summing
# log-probs. generate() below always emits exactly `length` new tokens
# because min_new_tokens == max_new_tokens.
logp_mask = torch.tensor(
    [[False] * input_len + [True] * length] * batch_size, device=device
)

# Optimize the LLM. Only parameters that require grad (the LoRA adapters) are
# handed to Adam; the frozen base weights would never receive gradients anyway.
trainable_params = [p for p in model.parameters() if p.requires_grad]
opt = optim.Adam(trainable_params, lr=1e-4, betas=(0.9, 0.95))

# The loop never ends on its own: interrupt the cell to stop training.
try:
    for i in tqdm(endless_range()):
        # Generate a batch of samples from the model
        model.eval()
        tokens = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            do_sample=True,
            min_new_tokens=length,
            max_new_tokens=length,
            pad_token_id=tokenizer.eos_token_id,
            temperature=temperature,
            top_k=0,
        )

        # Get the logits of the samples from the model and the reference model
        # (the reference model is the same network with the LoRA adapter disabled).
        attention_mask = torch.cat(
            (inputs.attention_mask, torch.ones_like(tokens[:, input_len:])), dim=1
        )
        with torch.no_grad(), model.disable_adapter():
            outputs_ref = model(tokens, attention_mask=attention_mask, use_cache=False)
        model.train()
        outputs = model(tokens, attention_mask=attention_mask, use_cache=False)

        # Compute the log probability of the samples under the model and the
        # reference model, at the same temperature the samples were drawn with.
        logp = logp_completion(outputs.logits / temperature, tokens, logp_mask)
        logp_ref = logp_completion(outputs_ref.logits / temperature, tokens, logp_mask)

        # Compute the CLIP loss: the angle (radians) between each sample's text
        # embedding and the image embedding. The cosine is clamped so that
        # rounding just outside [-1, 1] cannot turn arccos into NaN.
        texts = [tokenizer.decode(t, skip_special_tokens=True) for t in tokens]
        clip_tokens = clip_tokenizer(texts).to(device)
        with torch.autocast(device_type=device.type):
            text_embeds = clip_model.encode_text(clip_tokens).float()
        cosine = torch.cosine_similarity(text_embeds, image_embed, dim=-1)
        cost_clip = cosine.clamp(-1.0, 1.0).arccos()

        # Compute the KL penalty: a per-sample estimate log p(x) - log p_ref(x)
        # used as a cost (logp is detached; logp_ref has no graph).
        cost_kl = logp.detach() - logp_ref

        # REINFORCE with a leave-one-out baseline: each sample's baseline is the
        # mean cost of the other samples in the batch (requires batch_size >= 2).
        cost = cost_clip + kl_weight * cost_kl
        baseline = (cost.sum() - cost) / (cost.numel() - 1)
        # box equals 1 in value but carries the gradient of logp, so the gradient
        # of the loss is mean(grad(logp) * (cost - baseline)).
        box = torch.exp(logp - logp.detach())
        loss = torch.mean(box * cost + (1 - box) * baseline)

        # Update the model
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Print statistics and the best (lowest-cost) sample in the batch
        grad_norm = torch.cat(
            [p.grad.flatten() for p in model.parameters() if p.grad is not None]
        ).norm()
        print(
            f"step: {i}, loss: {loss.item():g}, clip: {cost_clip.mean().item():g}, kl: {cost_kl.mean().item():g}, grad: {grad_norm.item():g}"
        )
        best_text = texts[torch.argmin(cost).item()]
        print(textwrap.fill(best_text, width=80))
        print()

except KeyboardInterrupt:
    # Interrupting the cell is the intended way to stop; the trained adapter
    # stays in `model`.
    pass