# Conversation Templates

> Defines structures and templates for handling conversations in LLaVA-style models.

In [1]:
#| default_exp conversation

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
import sys
from pathlib import Path
import os

# Assumes the notebook is run from the project root or one level down (e.g., nbs/)
# Navigate up to the project root (where settings.ini or .git likely exists)
project_root = Path(os.getcwd())
# Simple check: If settings.ini is not in cwd, assume we are in nbs/ and go up one level
if not (project_root / 'settings.ini').exists() and (project_root.parent / 'settings.ini').exists():
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
else:
    print(f"Project root already in sys.path: {project_root_str}")

Adding project root to sys.path: /workspace/llava


In [4]:
#| export
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Dict, Any

Project root already in sys.path: /workspace/llava


This module defines the conversation structures and templates used for formatting model inputs, similar to the reference LLaVA implementation. It includes the `SeparatorStyle` enum and the `Conversation` dataclass.

In [5]:
#| export
class SeparatorStyle(Enum):
    """Different separator styles for conversations."""
    # Generic styles
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto() # Special case for simple image-caption pairs
    LLAMA_2 = auto()
    # Specific model styles (add as needed)
    VICUNA = auto() # Equivalent to TWO for older versions
    CHATML = auto()
    CHATGLM = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()
    FALCON_CHAT = auto()

In [6]:
SeparatorStyle

An enumeration.

In [7]:
#| export
@dataclasses.dataclass
class Conversation:
    """A class that manages prompt generation and conversation history for different models."""
    # The system prompt message
    system: str
    # Roles for user and assistant
    roles: List[str]
    # The conversation history messages. List of lists, where each inner list contains [role, message]. Role is string, message is string or None.
    messages: List[List[str]]
    # Message offset
    offset: int
    # Separator style
    sep_style: SeparatorStyle
    # Separator token(s)
    sep: str
    # Optional second separator token
    sep2: str = None
    # Stop criteria (list of stop strings or token IDs)
    stop_str: Union[str, List[str]] = None
    # Stop token IDs (list of token IDs)
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Generates the prompt string based on the conversation history and style."""
        ret = ""
        # Handle system prompt based on style
        if self.sep_style == SeparatorStyle.PLAIN:
            # Plain style only concatenates messages with sep
            # Often used for image-caption tasks where system prompt is empty
            pass # System prompt usually ignored or empty
        elif self.system:
            ret += self.system
            # Add separator after system prompt if needed by the style
            if self.sep_style not in [SeparatorStyle.CHATML]: # CHATML includes system in messages
                 ret += self.sep

        # Format messages
        for i, (role, message) in enumerate(self.messages):
            if message:
                if self.sep_style == SeparatorStyle.PLAIN:
                     # Append role (often empty) and message, followed by separator
                     ret += role + message + self.sep
                elif self.sep_style == SeparatorStyle.CHATML:
                     # CHATML format: <|im_start|>role\nmessage<|im_end|>sep
                     ret += role + "\n" + message + self.sep + ("\n" if i < len(self.messages)-1 else "")
                elif self.sep_style == SeparatorStyle.TWO or self.sep_style == SeparatorStyle.VICUNA:
                     # Vicuna V1 format: ROLE: message<sep>
                     # System prompt is handled above. 'sep' is space, 'sep2' is EOS.
                     ret += role + ": " + message + self.sep # Add space after message
                else:
                     # Default/Fallback: ROLE message SEP
                     ret += role + message + self.sep 
            else:
                # Handle cases where message is None (e.g., assistant prompt marker)
                if self.sep_style == SeparatorStyle.PLAIN:
                     ret += role + self.sep # Append role and separator (role might be empty)
                elif self.sep_style == SeparatorStyle.CHATML:
                     ret += role + "\n" # Append role marker only, like <|im_start|>assistant\n
                elif self.sep_style == SeparatorStyle.TWO or self.sep_style == SeparatorStyle.VICUNA:
                     ret += role + ":" # Append role marker like ASSISTANT:
                else:
                     ret += role # Append role only
                     
        # Add final separator if required (e.g., Vicuna v1 EOS)
        if self.sep_style == SeparatorStyle.TWO or self.sep_style == SeparatorStyle.VICUNA:
            if ret.endswith(self.sep):
                 # Remove the trailing space added by the loop if sep2 exists
                 ret = ret[:-len(self.sep)] 
            if self.sep2:
                ret += self.sep2
        
        # Remove trailing separators if style demands
        # if self.sep_style != SeparatorStyle.PLAIN and ret.endswith(self.sep):
        #      ret = ret[:-len(self.sep)]
              
        # Handle plain specifically: it should only have one final separator if needed
        if self.sep_style == SeparatorStyle.PLAIN:
             # Often expects a single newline at the end if sep is newline
             # Let's ensure it ends with exactly one sep if sep is defined
             if self.sep:
                  if ret.endswith(self.sep):
                       pass # Already ends with sep
                  else:
                       ret += self.sep
             pass # Keep it simple, usually just image\ncaption\n
             # Strip extra whitespace at the end for plain
             ret = ret.rstrip()
             # Re-add the single newline if it was the separator
             if self.sep == "\n": ret += "\n"

        return ret

    def append_message(self, role: str, message: str | None):
        """Appends a new message to the conversation history."""
        self.messages.append([role, message])

    def copy(self):
        """Creates a deep copy of the conversation object."""
        return Conversation(
            system=self.system,
            roles=list(self.roles), # Shallow copy roles list
            messages=[list(msg) for msg in self.messages], # Deep copy messages
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=list(self.stop_token_ids) if self.stop_token_ids else None,
        )

    def dict(self):
        """Converts the conversation object to a dictionary."""
        # Use dataclasses.asdict if needed, or manually construct
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep_style": self.sep_style.name, # Store enum name
            "sep": self.sep,
            "sep2": self.sep2,
            "stop_str": self.stop_str,
            "stop_token_ids": self.stop_token_ids,
        }

In [8]:
show_doc(Conversation)

---

### Conversation

>      Conversation (system:str, roles:List[str], messages:List[List[str]],
>                    offset:int, sep_style:__main__.SeparatorStyle, sep:str,
>                    sep2:str=None, stop_str:Union[str,List[str]]=None,
>                    stop_token_ids:List[int]=None)

*A class that manages prompt generation and conversation history for different models.*

## Conversation Templates

Define specific conversation templates for different models/stages.

In [9]:
#| export
# Template for Stage 1 Pretraining ('plain' style)
conv_llava_plain = Conversation(
    system="", # No system prompt
    roles=("", ""), # Roles are often ignored, but placeholders needed
    messages=[], # History is built dynamically
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n", # Separator is a newline
    stop_str=None,
    stop_token_ids=None
)

# Template for Vicuna v1 (Instruction Tuning - Stage 2)
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.TWO, # Uses two separators
    sep=" ", # Space separator between turns
    sep2="</s>", # EOS token as the second separator (end of conversation)
    stop_str="</s>",
    # stop_token_ids=[2], # Example, assuming tokenizer.eos_token_id is 2
    # Note: stop_token_ids should be set based on the actual tokenizer used later
    stop_token_ids=None
)

# Add other templates as needed (e.g., 'v0', 'llama_2', 'chatml') based on LLaVA reference
# ...

# --- Template Dictionary --- 
conv_templates = {
    "plain": conv_llava_plain,
    "v1": conv_vicuna_v1,
    # Add other templates here
    # "v0": conv_vicuna_v0, 
    # "vicuna_v1": conv_vicuna_v1, # Alias
}

# --- Default Conversation --- 
# Set a default conversation template (can be overridden)
default_conversation = conv_vicuna_v1 

# --- Functions to get conversation templates ---
def get_conv_template(name: str) -> Conversation:
    """Gets a conversation template by name.
    
    Args:
        name: The name of the conversation template.
        
    Returns:
        A deep copy of the requested conversation template.
        
    Raises:
        ValueError: If the template name is not found.
    """
    if name not in conv_templates:
        raise ValueError(f"Unknown conversation template: {name}. Available templates: {list(conv_templates.keys())}")
    return conv_templates[name].copy()

## Tests

In [10]:
#| test
# --- Test Plain Template --- 
plain_conv = get_conv_template("plain")
plain_conv.append_message("", "<image>") # Role is empty for plain image part
plain_conv.append_message("", "A red block.") # Role is empty for plain caption part
plain_prompt = plain_conv.get_prompt()

print("--- Plain Template Test ---")
print("Formatted Prompt:")
print(plain_prompt)
assert plain_prompt == "<image>\nA red block.\n", f"Plain template mismatch: '{plain_prompt}'"

# --- Test Vicuna V1 Template --- 
v1_conv = get_conv_template("v1")
v1_conv.append_message(v1_conv.roles[0], "<image>\nDescribe it.") # User turn
v1_conv.append_message(v1_conv.roles[1], "It is red.") # Assistant turn
v1_conv.append_message(v1_conv.roles[0], "What shape?") # User turn
v1_conv.append_message(v1_conv.roles[1], "It is square.") # Assistant turn
v1_conv.append_message(v1_conv.roles[1], None) # Add prompt for next assistant turn
v1_prompt = v1_conv.get_prompt()

print("\n--- Vicuna V1 Template Test ---")
print("Formatted Prompt:")
print(v1_prompt)

# Expected output construction
expected_v1 = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions." + v1_conv.sep +
    "USER: <image>\nDescribe it." + v1_conv.sep + 
    "ASSISTANT: It is red." + v1_conv.sep2 + # EOS after assistant turn
    "USER: What shape?" + v1_conv.sep + 
    "ASSISTANT: It is square." + v1_conv.sep2 + # EOS after assistant turn
    "ASSISTANT:" # Prompt for next assistant turn
)

assert v1_prompt.strip() == expected_v1.strip(), f"Vicuna v1 template mismatch. Expected:\n'{expected_v1}'\nGot:\n'{v1_prompt}'"

--- Plain Template Test ---
Formatted Prompt:
<image>
A red block.

--- Vicuna V1 Template Test ---
Formatted Prompt:
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>
Describe it. ASSISTANT: It is red.</s>USER: What shape? ASSISTANT: It is square.</s>ASSISTANT:


In [11]:
#| hide
import nbdev; nbdev.nbdev_export()