In [None]:
from threading import Thread
from typing import Iterator, Literal, Optional, TypedDict

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Checkpoint ids. Only `model_id` is loaded here; `assistant_model_id` is
# referenced solely by the commented-out assisted-generation line.
model_id = "julep-ai/samantha-33b-ds-03"
assistant_model_id = "julep-ai/samantha-7b-ds-03"

# Main model in bfloat16, with weight placement left to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_id, torch_dtype=torch.bfloat16, device_map="auto")

# The tokenizer comes from the same checkpoint, with the fast implementation disabled.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)

class StopSequenceCriteria(StoppingCriteria):
    """Stop generation once the newly generated text contains a stop string.

    Only tokens after ``input_length`` (the prompt) are inspected, so stop
    strings that occur in the prompt itself never trigger a stop.

    Args:
        tokenizer: Object whose ``decode`` turns a list of token ids into text.
        stop: Strings that mark the end of a reply.
        input_length: Number of prompt tokens to skip at the start of every
            sequence.
        *args, **kwargs: Forwarded to ``StoppingCriteria.__init__``.
    """

    def __init__(
        self,
        tokenizer,
        stop: list[str],
        input_length,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.stop = stop
        self.tokenizer = tokenizer
        self.input_length = input_length

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        """Return True when every sequence in the batch has produced a stop string.

        Args:
            input_ids: Token ids of prompt plus generated tokens, one row per
                sequence.
            scores: Unused; accepted to match the stopping-criteria call
                signature.

        Returns:
            True if each sequence's generated text contains at least one of
            the stop strings, False otherwise (always False when no stop
            strings are configured).
        """
        if not self.stop:
            return False

        # Decode each continuation exactly once per step; this runs after every
        # generated token, so repeating the decode per stop string is wasteful.
        generated_texts = [
            self.tokenizer.decode(sequence[self.input_length:])
            for sequence in input_ids.long().tolist()
        ]

        # A sequence is finished when ANY stop string appears in it; the batch
        # is finished when ALL sequences are, even if they hit different stops.
        return all(
            any(text in decoded for text in self.stop)
            for decoded in generated_texts
        )

class _ChatMLMessageOptional(TypedDict, total=False):
    """Keys that a ChatML message is allowed to omit."""

    # Speaker / section name; may be left out entirely or set to None.
    name: Optional[str]


class ChatMLMessage(_ChatMLMessageOptional):
    """A single ChatML message.

    ``role`` and ``content`` are required; ``name`` is optional. TypedDict
    fields cannot carry default values, so an omitted ``name`` is simply
    absent from the resulting dict rather than being filled in with ``None``.
    """

    # Which party the message belongs to.
    role: Literal["assistant", "system", "user"]
    # Text of the message.
    content: str


# A conversation is an ordered list of messages.
ChatML = list[ChatMLMessage]

def message_role_to_prefix(message: ChatMLMessage) -> str:
    match message:
        case {"role": "system", "name": name, **rest}:
            return name
        case {"role": "user", "name": name, **rest}:
            return f"person ({name})" if name else "person"
        case {"role": "assistant", "name": name, **rest}:
            return f"me ({name})" if name else "me"

def to_prompt(
    messages: ChatML,
    bos: str = "<|section|>",
    eos: str = "<|endsection|>",
    suffix: str = "\n<|section|>me (Samantha)\n",
) -> str:
    """Render a conversation as the flat text prompt fed to the model.

    Every message becomes ``{bos}{header}\\n{content}{eos}``, where the header
    comes from ``message_role_to_prefix``. Sections are joined with newlines
    and ``suffix`` is appended to open the section the model should complete.

    Args:
        messages: Conversation to render, in order.
        bos: Marker that opens each section.
        eos: Marker that closes each section.
        suffix: Text appended after the last section.

    Returns:
        The complete prompt string.
    """
    sections = []
    for msg in messages:
        header = message_role_to_prefix(msg)
        sections.append(f"{bos}{header}\n{msg['content']}{eos}")

    body = "\n".join(sections)
    return body + suffix

def remove_stops(generator, stop: Optional[list[str]] = None):
    """Yield streamed text chunks up to, but not including, the first stop sequence.

    Each chunk is cut at the earliest position where any stop sequence begins.
    Once a stop sequence has been seen the stream ends, so nothing produced
    after it is passed on. Empty chunks are skipped.

    Each chunk is inspected on its own, so a stop sequence that is split
    across two chunks is not detected.

    Args:
        generator: Iterable of text chunks.
        stop: Stop sequences; ``None`` or an empty list passes chunks through
            unchanged (apart from dropping empty ones).

    Yields:
        Non-empty text chunks preceding the first stop sequence.
    """
    # Empty strings match at every position, so they are never useful as stops.
    stops = [s for s in (stop or []) if s]

    for item in generator:
        # Earliest index at which any stop sequence starts in this chunk.
        cut = min(
            (index for index in (item.find(s) for s in stops) if index != -1),
            default=None,
        )

        if cut is None:
            if item:
                yield item
            continue

        # Emit whatever precedes the stop sequence, then end the stream.
        if cut:
            yield item[:cut]
        return
    
def generate(
    messages: ChatML,
    stop: Optional[list[str]] = None,
    timeout: int = 10,
    stream: bool = False,
    **kwargs
) -> Iterator[str] | str:
    """Generate the next reply for a ChatML conversation.

    Args:
        messages: Conversation so far; rendered to text with ``to_prompt``.
        stop: Strings that end the reply. They are used both as a stopping
            criterion during generation and to trim the returned text.
        timeout: Passed to ``TextIteratorStreamer`` as its ``timeout``
            (streaming only).
        stream: If False, block until generation finishes and return the
            reply as one string. If True, start generation in a background
            thread and return an iterator over text chunks.
        **kwargs: Extra arguments for ``model.generate``. They override the
            defaults below but cannot replace ``stopping_criteria`` or the
            tokenized inputs.

    Returns:
        The reply text (``stream=False``) or an iterator of text chunks
        (``stream=True``).
    """
    # Work on a private list (never a shared default) and drop empty strings:
    # "" is contained in every text and str.split("") raises ValueError.
    stop = [s for s in (stop or []) if s]

    # Prepare input on whatever device the model lives on rather than assuming GPU 0
    prompt = to_prompt(messages)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Prompt length in tokens; read from the shape so a one-token prompt works too
    input_length = inputs["input_ids"].shape[-1]

    # Stopping criteria
    stopping_criteria = (
        StoppingCriteriaList([StopSequenceCriteria(
            tokenizer=tokenizer,
            stop=stop,
            input_length=input_length,
        )])
        if stop else None
    )

    # Generation parameters
    # NOTE(review): temperature normally only matters when sampling is enabled;
    # confirm the model's generation config sets do_sample or pass do_sample=True.
    generation_kwargs = {
        # defaults
        "max_new_tokens": 40,
        "repetition_penalty": 1.1,
        "no_repeat_ngram_size": 4,
        "renormalize_logits": True,
        "temperature": 1.1,
        #
        # overrides
        **kwargs,
        #
        # required params
        "stopping_criteria": stopping_criteria,
        # "assistant_model": assistant_model,
        #
        # add inputs
        **inputs,
    }

    # If not streaming, run directly and return result
    if not stream:
        [output] = model.generate(**generation_kwargs)
        # Decode only the newly generated tokens, not the prompt
        result = tokenizer.decode(output[input_length:])

        # Remove the stop sequence at the end (needed)
        for s in stop:
            result = result.split(s)[0].strip()

        return result

    # If streaming, prepare streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=timeout)
    generation_kwargs["streamer"] = streamer

    # and start generating in new thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # stop sequence filter
    return remove_stops(streamer, stop)

In [19]:
# Seed conversation: a "situation" system section followed by the user's opening line.
chatml = [
    ChatMLMessage(
        role="system",
        name="situation",
        content="I am talking to Diwank",
    ),
    ChatMLMessage(
        role="user",
        name="Diwank",
        content="Hey Samantha!",
    ),
]

In [None]:
# Stream a short reply that ends at the end-of-section marker or the first newline.
r = generate(
    chatml,
    stop=["<|endsection|>", "\n"],
    stream=True,
    max_new_tokens=40,
    temperature=1.2,
)

# Drain the stream so the cell shows every chunk that was emitted.
list(r)

