In [1]:
from PIL import Image

In [2]:
import torch

In [3]:
import time

In [4]:
import psutil

In [6]:
import os

In [7]:
from text_image_token_processor_1 import PaliGemmaProcessor

In [8]:
from decoder_1 import KVCache,PaliGemmaForConditionalGeneration

In [9]:
from utils import load_hf_model

In [10]:
def move_inputs_to_device(model_inputs: dict, device: str):
    """Return a new dict whose values are the inputs moved to ``device``.

    Args:
        model_inputs: Mapping from input name to an object exposing ``.to(device)``.
        device: Target device string passed straight to ``.to``.

    Returns:
        A fresh dict with the same keys (in the same order); the dict passed
        in is not modified.
    """
    moved = {}
    for name, value in model_inputs.items():
        moved[name] = value.to(device)
    return moved

In [18]:
def get_model_inputs(
    processor: PaliGemmaProcessor,
    prompt: str,
    image_file_path: str,
    device: str
):
    """Build the model inputs for a single prompt/image pair on ``device``.

    Args:
        processor: Callable invoked as ``processor(text=[...], images=[...])``.
        prompt: Text prompt; wrapped in a one-element list for the processor.
        image_file_path: Path of the image, opened with ``PIL.Image.open``.
        device: Device every processor output value is moved to.

    Returns:
        The processor's output mapping with each value moved to ``device``.
    """
    # The processor is called with lists, so the single prompt and image are
    # each wrapped in a one-element batch.
    batch = processor(text=[prompt], images=[Image.open(image_file_path)])
    return move_inputs_to_device(batch, device)

In [19]:
def test_inference(
    model: PaliGemmaForConditionalGeneration,
    processor: PaliGemmaProcessor,
    device: str,
    prompt: str,
    image_file_path: str,
    max_tokens_to_generate: int,
    temperature: float,
    top_p: float,
    do_sample: bool
):
    """Generate a continuation for one prompt/image pair and print the result.

    Tokens are produced one at a time, feeding only the newest token back into
    the model together with the KV cache, until the tokenizer's EOS id is
    produced or ``max_tokens_to_generate`` tokens have been generated.

    Args:
        model: Called with ``input_ids``, ``pixel_values``, ``attention_mask``
            and ``kv_cache``; must return a mapping with ``'logits'`` and
            ``'kv_cache'``.
        processor: Builds the model inputs and exposes ``tokenizer``.
        device: Device the inputs are moved to; ``'cuda'`` additionally
            triggers the GPU memory report.
        prompt: Text prompt, also echoed in front of the decoded output.
        image_file_path: Path of the image to condition on.
        max_tokens_to_generate: Upper bound on the number of generated tokens.
        temperature: Divisor applied to the logits before softmax; only used
            (and must be non-zero) when ``do_sample`` is True.
        top_p: Nucleus threshold passed to ``_sample_top_p`` when sampling.
        do_sample: Sample with temperature/top-p if True, otherwise take the
            argmax (greedy decoding).

    Returns:
        None. The decoded text, latency and memory figures are printed.
    """
    model_inputs = get_model_inputs(processor, prompt, image_file_path, device)
    input_ids = model_inputs['input_ids']
    attention_mask = model_inputs['attention_mask']
    pixel_values = model_inputs['pixel_values']
    kv_cache = KVCache()
    stop_token = processor.tokenizer.eos_token_id
    generated_tokens = []
    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by system clock adjustments.
    start_time = time.perf_counter()

    for _ in range(max_tokens_to_generate):
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            kv_cache=kv_cache
        )

        kv_cache = outputs['kv_cache']
        # Only the logits at the last position decide the next token.
        next_token_logits = outputs['logits'][:, -1, :]
        if do_sample:
            next_token_probs = torch.softmax(next_token_logits / temperature, dim=-1)
            next_token = _sample_top_p(next_token_probs, top_p)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        # This loop only supports a single sequence (batch size 1).
        assert next_token.size() == (1, 1)

        next_token = next_token.squeeze(0)
        generated_tokens.append(next_token)

        if next_token.item() == stop_token:
            break

        # From now on only the new token is fed in; earlier context is
        # carried by the KV cache.
        input_ids = next_token.unsqueeze(-1)
        # Extend the mask by one position, keeping its dtype/device so that
        # concatenation does not silently promote it to float.
        attention_mask = torch.cat(
            [
                attention_mask,
                torch.ones(
                    (1, 1),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
            ],
            dim=-1,
        )

    end_time = time.perf_counter()
    latency = end_time - start_time

    if generated_tokens:
        generated_tokens = torch.cat(generated_tokens, dim=-1)
        decoded = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    else:
        # Nothing was generated (max_tokens_to_generate <= 0); torch.cat would
        # raise on an empty list.
        decoded = ''

    print('Result!!!')
    print(prompt + decoded)
    print(f'infernence latency: {latency:.2f} seconds')

    bytes_per_gb = 1024 ** 3
    process = psutil.Process(os.getpid())
    memory = process.memory_info().rss / bytes_per_gb
    print(f'memory usage: {memory:.2f} GB')

    if device == 'cuda':
        # max_memory_allocated() is in bytes; convert to GB to match the label.
        gpu_memory = torch.cuda.max_memory_allocated() / bytes_per_gb
        print(f'gpu memory usage: {gpu_memory:.2f} GB')
        torch.cuda.reset_peak_memory_stats()

        

In [20]:
def _sample_top_p(probs: torch.Tensor, p: float):
    """Draw one index per row with top-p (nucleus) sampling.

    Entries are ranked by probability; an entry is kept only while the total
    probability of the entries ranked before it does not exceed ``p``. The
    kept entries are renormalised and one of them is sampled.

    Args:
        probs: Probabilities along the last dimension (1-D or 2-D, as
            required by ``torch.multinomial``). Not modified.
        p: Cumulative-probability threshold. The highest-probability entry is
            always kept for any ``p >= 0``.

    Returns:
        Tensor of sampled indices into the last dimension of ``probs``, with
        that dimension of size 1.
    """
    probs_sort, prob_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # cumsum - own probability == mass of all strictly higher-ranked entries;
    # once that already exceeds p the entry lies outside the nucleus.
    mask = probs_sum - probs_sort > p
    # Zero out everything outside the nucleus so it can never be sampled.
    probs_sort = probs_sort.masked_fill(mask, 0.0)
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the position in the sorted order back to the original index.
    next_token = torch.gather(prob_idx, -1, next_token)
    return next_token

In [21]:
def main(
    model_path: str = None,
    prompt: str = None,
    image_file_path: str = None,
    max_tokens_to_generate: int = 100,
    temperature: float = 0.8,
    top_p: float = 0.9,
    do_sample: bool = False,
    only_cpu: bool = False
):
    """Load the model and run one prompt/image inference, printing the result.

    Args:
        model_path: Location handed to ``load_hf_model``. Required.
        prompt: Text prompt for the model. Required.
        image_file_path: Path of the image to condition on. Required.
        max_tokens_to_generate: Upper bound on generated tokens.
        temperature: Sampling temperature, forwarded to ``test_inference``.
        top_p: Nucleus threshold, forwarded to ``test_inference``.
        do_sample: Sample instead of greedy decoding when True.
        only_cpu: Force the CPU even if CUDA or MPS is available.

    Raises:
        ValueError: If ``model_path``, ``prompt`` or ``image_file_path`` is
            left as None.
    """
    # Fail fast with a clear message instead of an obscure error deep inside
    # model loading or image opening.
    required = {
        'model_path': model_path,
        'prompt': prompt,
        'image_file_path': image_file_path,
    }
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise ValueError(
            f"main() missing required argument(s): {', '.join(missing)}"
        )

    # Device preference: CUDA, then Apple MPS, otherwise CPU.
    device = 'cpu'
    if not only_cpu:
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'

    print('device in use:', device)
    print('loading model')
    start_time = time.time()
    model, tokenizer = load_hf_model(model_path, device)
    model = model.to(device).eval()

    # The processor is configured from the vision tower settings stored in
    # the model config.
    num_image_tokens = model.config.vision_config.num_image_tokens
    image_size = model.config.vision_config.image_size

    processor = PaliGemmaProcessor(tokenizer, num_image_tokens, image_size)

    print(f'model loaded in {time.time() - start_time:.2f} seconds')
    print('running inference')
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        test_inference(
            model,
            processor,
            device,
            prompt,
            image_file_path,
            max_tokens_to_generate,
            temperature,
            top_p,
            do_sample
        )
    
    

In [None]:
# Demo run. With do_sample=False the decoder picks the argmax token at every
# step, so temperature and top_p are forwarded but do not affect the output.
# NOTE(review): both paths are absolute and machine-specific; point them at
# your own checkpoint directory and test image before running.
demo_kwargs = dict(
    model_path="/Users/liuchu/vision-launguage-model-from-scratch/paligemma-3b-pt-224/",
    prompt="describe the building:",
    image_file_path="/Users/liuchu/vision-launguage-model-from-scratch/test_images/image.png",
    max_tokens_to_generate=100,
    temperature=0.8,
    top_p=0.9,
    do_sample=False,
    only_cpu=False,
)
main(**demo_kwargs)

device in use: mps
loading model
