In [None]:
import json
import logging
import zlib
from pathlib import Path

import pandas as pd
import yaml
from llama_cpp import (
    Llama,
    LlamaGrammar
)  # same speed as llama_cpp_cuda, but with tensor cores


In [None]:
# Few-shot example messages, read from the JSON file given as a Snakemake input
few_shot_messages = json.loads(Path(snakemake.input.few_shot_messages).read_text())

In [None]:
# System message (the generation instruction), kept as raw text
with open(snakemake.input.instruction) as instruction_file:
    system_prompt = instruction_file.read()

In [None]:
# Request messages for this split: a JSON object mapping sample id -> request message
samples_request_messages = json.loads(Path(snakemake.input.json_split).read_text())

In [None]:
# Load the model once; the generation helper below samples from this global `llm`.
llm = Llama(
    model_path=snakemake.input.model,
    n_ctx=32000,  # Max sequence length; longer contexts need considerably more memory
    n_threads=25,  # Number of CPU threads to use
    n_gpu_layers=-1,  # -1 offloads all layers to the GPU, whatever the model's depth
    verbose=True,
)

In [None]:
def messages_to_prompt(messages):
    """Render a chat history into one Mixtral/Mistral-instruct style prompt string.

    Built by hand because `llm.create_chat_completion` did not work in this
    setup. The prompt starts with a literal ``<s>``. A message whose role is
    ``"user"`` becomes ``[INST] <content> [/INST]``. Every other role is
    emitted verbatim and followed by ``</s>``.

    Args:
        messages: iterable of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        The concatenated prompt string.
    """
    pieces = []
    for msg in messages:
        if msg['role'] != "user":
            pieces.append(msg['content'] + "</s>")
        else:
            pieces.append('[INST] ' + msg['content'] + ' [/INST]')
    return "<s>" + "".join(pieces)

def llm_run(messages, num_retries=3):
    """Generate a JSON response constrained by the configured JSON schema.

    The prompt is built with ``messages_to_prompt`` and sampled from the global
    ``llm``. A grammar is derived from the file ``snakemake.input.json_schema``.
    Each attempt uses a different seed and looser sampling than the previous
    one (higher temperature, top_p and top_k).

    Args:
        messages: chat history as a list of ``{"role", "content"}`` dicts.
            The last message's content determines the base seed.
        num_retries: maximum number of generation attempts.

    Returns:
        The decoded JSON value from the first attempt whose text parses.
        If no attempt parses, a warning is logged and ``[]`` is returned.
    """
    # Built-in hash() of a str is salted per interpreter (PYTHONHASHSEED), so it
    # gave a different seed on every run; crc32 is stable across processes.
    seed = zlib.crc32(messages[-1]["content"].encode("utf-8")) % 2**30
    # Prompt and schema text do not change between attempts; compute them once.
    prompt = messages_to_prompt(messages)
    json_schema = Path(snakemake.input.json_schema).read_text()
    output = None  # keeps the final warning valid even when num_retries == 0
    for attempt in range(num_retries):
        output = llm(
            prompt,
            max_tokens=2048,
            stop=["</s>"],  # stop token for Mixtral
            seed=seed + attempt,
            # Fresh grammar object per attempt, as before, in case it carries parse state.
            grammar=LlamaGrammar.from_json_schema(json_schema),
            temperature=snakemake.params.temperature + attempt * (1.0 / num_retries),
            # Step scaled by num_retries (was a hard-coded 3) so top_p stays below 1.0.
            top_p=0.9 + attempt * (0.1 / num_retries),
            top_k=50 + attempt * (100 // num_retries),
        )

        try:
            result_serial = output["choices"][0]["text"].strip()
            logging.debug(result_serial)
            return json.loads(result_serial)
        except Exception as e:
            # Best effort: log the malformed output and try again with new sampling.
            logging.warning(f"Failed to decode : {type(e)}:{e}, {output}. Retrying {attempt}")

    logging.warning(
        f"Prompt template continuously failed: \n{messages[-1]}\n{output}"
    )
    return []

In [None]:
# TODO load chunked
# The instruction is sent as a user turn, followed by a canned assistant
# acknowledgement and then the few-shot examples. This prefix is the same
# for every sample, so it is built once.
conversation_prefix = [
    {"role": "user", "content": system_prompt},
    {"role": "assistant", "content": "Understood. Please provide the sample information, so I can generate a corresponding conversation."},
    *few_shot_messages,
]
# Optional reminder appended to each request; absent param -> no reminder.
prompt_reminder = getattr(snakemake.params, "prompt_reminder", None)

responses = {}
for sample_id, sample_request_message in samples_request_messages.items():
    # Copy instead of mutating the loaded input. This way, re-running the cell
    # does not append the reminder a second time.
    request_message = dict(sample_request_message)
    if prompt_reminder is not None:
        request_message["content"] = request_message["content"] + prompt_reminder

    responses[sample_id] = llm_run([*conversation_prefix, request_message])

In [None]:
# Persist every sample's generated conversation as one JSON object
Path(snakemake.output.generated_conversations).write_text(json.dumps(responses))