In [2]:
import json
import uuid

import torch
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread

from loguru import logger 
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

GB = 2 ** 30  # bytes in one gibibyte

# Worker bookkeeping (not read by the visible cells): a short random id taken
# from the front of a UUID4, plus counter/semaphore state that starts empty.
worker_id = uuid.uuid4().hex[:6]
global_counter = 0
model_semaphore = None

# Checkpoint identifier, display name, and (absent) separate base model.
model_path, model_name, model_base = "SkunkworksAI/BakLLaVA-1", "BakLLaVA-1", None

# NOTE(review): the two booleans and `device` are passed positionally --
# presumably (load_8bit=True, load_4bit=False, <device / device_map slot>);
# verify against the installed llava `load_pretrained_model` signature.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    model_base,
    model_name,
    True,
    False,
    device,
)
    



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at SkunkworksAI/BakLLaVA-1 were not used when initializing LlavaMistralForCausalLM: ['model.vision_tower.vision_tower.vision_model.encoder.layers.5.self_attn.v_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.18.self_attn.q_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.19.mlp.fc1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.12.self_attn.q_proj.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.1.layer_norm2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.22.self_attn.out_proj.weight', 'model.vision_tower.vision_tower.vision_model.embeddings.class_embedding', 'model.vision_tower.vision_tower.vision_model.encoder.layers.10.layer_norm1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.10.mlp.fc1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.6.layer_norm1.weight', 'model.vision_tower.vision_tower.vision

In [6]:

# path/filename: params_dataclass.py
from dataclasses import dataclass, fields
from typing import Optional, List
from queue import Queue 
from threading import Thread

@dataclass
class StreamParams:
    """
    Parameters for a single ``generate_stream`` call.

    ``prompt`` is required; every other field falls back to the default
    declared below.
    """
    prompt: str
    # Base64-encoded images; the count must match the image tokens in ``prompt``.
    images: Optional[List[str]] = None
    temperature: float = 1.0
    top_p: float = 1.0
    max_new_tokens: int = 256
    # Optional stop string; it is also stripped from the end of streamed text.
    stop: Optional[str] = None

    @staticmethod
    def from_dict(params_dict: dict) -> 'StreamParams':
        """
        Build an instance from a dictionary.

        Only keys that name a dataclass field are used. Unknown keys are
        ignored, and missing keys take the field's declared default. Keys are
        forwarded only when present, rather than looking up class attributes
        as defaults, because required fields such as ``prompt`` have no
        class-level default to look up.

        :param params_dict: mapping of field names to values.
        :return: a new ``StreamParams``.
        :raises TypeError: if ``prompt`` is missing from ``params_dict``.
        """
        known_fields = {f.name for f in fields(StreamParams)}
        return StreamParams(**{k: v for k, v in params_dict.items() if k in known_fields})

    def to_dict(self) -> dict:
        """
        Convert the instance to a plain ``{field_name: value}`` dictionary.
        """
        return {f.name: getattr(self, f.name) for f in fields(self)}
    
# path/filename: generate_stream_functions.py
from dataclasses import dataclass
from typing import Generator, Any
import torch
import json
from threading import Thread

@torch.inference_mode()
def generate_stream(
    params: StreamParams,
    tokenizer: Any,
    model: Any,
    image_processor: Any,
    output_queue: Queue
) -> None:
    """
    Generate text for ``params`` and push the streamed chunks onto ``output_queue``.

    Every chunk is a JSON object encoded to bytes and terminated by a NUL byte.
    - Progress chunks carry the original prompt plus all text generated so
      far under ``"text"``, with ``"error_code": 0``.
    - If anything raises, one chunk with ``"error"`` (the exception message)
      and ``"error_code": 1`` is pushed instead.

    In every case (normal completion, the early "exceeds max token length"
    exit, or an exception), a ``None`` sentinel is pushed last. Consumers such
    as ``generate_stream_gate`` rely on it to stop waiting.

    :param params: StreamParams object containing parameters for generation.
    :param tokenizer: Tokenizer to use for text processing.
    :param model: The model used for generating text.
    :param image_processor: Processor for handling image inputs.
    :param output_queue: Queue receiving the encoded chunks, then ``None``.
    :return: None; all output goes through ``output_queue``.
    """
    generation_thread = None
    try:
        prompt = params.prompt
        ori_prompt = prompt
        images = params.images or []
        num_image_tokens = 0

        if len(images) > 0:
            # One image is required per image placeholder in the prompt.
            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                raise ValueError("Number of images does not match number of <image> tokens in prompt")

            images = [load_image_from_base64(image) for image in images]
            images = process_images(images, image_processor, model.config)

            if type(images) is list:
                images = [image.to(model.device, dtype=torch.float16) for image in images]
            else:
                images = images.to(model.device, dtype=torch.float16)

            replace_token = DEFAULT_IMAGE_TOKEN
            if getattr(model.config, 'mm_use_im_start_end', False):
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

            # Budget num_patches context positions per image; this is
            # subtracted from max_new_tokens below.
            num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
        else:
            images = None
        image_args = {"images": images}

        temperature = params.temperature
        top_p = params.top_p
        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(params.max_new_tokens, 1024)  # hard per-request cap
        stop_str = params.stop
        do_sample = temperature > 0.001  # sample only above a near-zero temperature

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        keywords = [stop_str] if stop_str is not None else []
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        # Leave room in the context window for the prompt and the image features.
        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            output_queue.put(json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0")
            return

        # model.generate blocks until done, so run it in its own thread and
        # consume the streamer here as text arrives.
        generation_thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        ))
        generation_thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            # Skip when no (or an empty) stop string is set. str.endswith(None)
            # raises TypeError, and an empty stop would slice the text away.
            if stop_str and generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            output_queue.put(json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0")

    except Exception as e:
        logger.error(f"Error in generate_stream: {e}")
        output_queue.put(json.dumps({"error": str(e), "error_code": 1}).encode() + b"\0")
    finally:
        # Signal completion first so the consumer is released even if the
        # generation thread is still finishing.
        output_queue.put(None)
        if generation_thread is not None:
            generation_thread.join()

def generate_stream_gate(
    params: 'StreamParams',
    tokenizer,
    model,
    image_processor,
    poll_interval: float = 0.1,
) -> Generator[bytes, None, None]:
    """
    Run ``generate_stream`` in a background thread and yield its chunks.

    Chunks are yielded in the order the worker enqueues them, until the
    ``None`` sentinel arrives. As a safety net, iteration also ends once the
    worker thread has exited and the queue is empty. This stops a worker
    that returns without enqueuing the sentinel from blocking this generator
    forever.

    :param params: generation parameters forwarded to ``generate_stream``.
    :param tokenizer: tokenizer forwarded to ``generate_stream``.
    :param model: model forwarded to ``generate_stream``.
    :param image_processor: image processor forwarded to ``generate_stream``.
    :param poll_interval: seconds to wait on the queue before re-checking
        whether the worker is still alive.
    :yield: the byte chunks the worker puts on the queue.
    """
    from queue import Empty

    output_queue = Queue()
    worker = Thread(target=generate_stream, args=(params, tokenizer, model, image_processor, output_queue))
    worker.start()

    while True:
        try:
            result = output_queue.get(timeout=poll_interval)
        except Empty:
            # Once the worker has exited, nothing more can be enqueued, so an
            # empty queue at that point means the stream is over.
            if not worker.is_alive() and output_queue.empty():
                break
            continue
        if result is None:  # sentinel: the worker signalled completion
            break
        yield result

    worker.join()

from llava.conversation import default_conversation

# Conversation template; its separator doubles as the stop string.
conv = default_conversation.copy()
logger.info(conv)
stop = conv.sep

# Example request dict -- illustrative only, it is NOT used below. Its
# "images" entries are placeholders rather than real base64 images, and its
# prompt has no image tokens, so generate_stream would reject it as-is.
params_dict = {
    "prompt": "Example prompt",
    "temperature": 0.7,
    "top_p": 0.9,
    "images": ["image1_base64", "image2_base64"],
    "max_new_tokens": 300
}
prompt = "Caption this image"
params = StreamParams(prompt=prompt, stop=stop)

# Each streamed chunk is a JSON document terminated by a NUL byte. Decode it
# so the log shows the generated text (or the error), not raw bytes.
for gen in generate_stream_gate(params, tokenizer, model, image_processor):
    payload = json.loads(gen.rstrip(b"\0"))
    if "error" in payload:
        logger.error(payload["error"])
    else:
        logger.info(payload.get("text"))

[32m2023-12-01 11:34:00.737[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m145[0m - [1mConversation(system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=('Human', 'Assistant'), messages=[['Human', 'What are the key differences between renewable and non-renewable energy sources?'], ['Assistant', 'Renewable energy sources are those that can be replenished naturally in a relatively short amount of time, such as solar, wind, hydro, geothermal, and biomass. Non-renewable energy sources, on the other hand, are finite and will eventually be depleted, such as coal, oil, and natural gas. Here are some key differences between renewable and non-renewable energy sources:\n1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable energy sources are finite and will eventually run out.\n2. Environmental impact: Renewable energy 

Generating


In [None]:
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
import torch
import os
import json
from tqdm import tqdm
import shortuuid


from llava.utils import disable_torch_init


# new stopping implementation
class KeywordsStoppingCriteria(StoppingCriteria):
    """
    Stop generation once any keyword appears in the newly generated text.

    NOTE(review): this class has the same name as the one imported from
    ``llava.mm_utils`` in the first cell. Running this cell replaces that
    name in the shared notebook namespace, so later calls that look it up
    (including ``generate_stream``) will use this version. Rename it if both
    are needed.

    :param keywords: strings whose appearance in the decoded continuation
        stops generation.
    :param tokenizer: tokenizer used to decode the generated ids.
    :param input_ids: prompt ids; only positions at or after
        ``input_ids.shape[1]`` are decoded and checked.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.input_ids = input_ids
        # Record the prompt length up front so the first generated token is
        # already checked. Initialising it lazily on the first call skipped
        # the check for that token.
        self.start_len = input_ids.shape[1]

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Only the first sequence of the batch is decoded and checked.
        outputs = self.tokenizer.batch_decode(
            output_ids[:, self.start_len :], skip_special_tokens=True
        )[0]
        return any(keyword in outputs for keyword in self.keywords)


# Question for the human turn. It was previously referenced but never defined.
qs = "Caption this image"

# Start from a fresh copy of the template, so re-running this cell does not
# keep appending turns to a conversation left over from earlier cells.
conv = default_conversation.copy()
conv.append_message(conv.roles[0], qs)
# Empty assistant turn: presumably rendered as an open "assistant" slot that
# cues the model to answer -- verify against llava.conversation.get_prompt.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
# Follow the model's device instead of hard-coding .cuda(), so this also works
# when the first cell fell back to the CPU.
input_ids = torch.as_tensor(inputs.input_ids).to(model.device)
stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)

