In [1]:
import datetime
import json
import logging
import random
from pathlib import Path

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.getLogger('transformers').setLevel(logging.ERROR)

In [2]:
class LLMParticipant:

    def __init__(
            self,
            config: dict
        ) -> None:

        self.config = config
        self.MODEL = config['MODEL']
        self.TEMPERATURE = config['TEMPERATURE']
        self.BATCH_SIZE = config['BATCH_SIZE']
        self.EXAMPLE = config['EXAMPLE']
        self.FORMAT_VERSION = config['FORMAT_VERSION']
        self.SYSTEM_CONTENT = config['SYSTEM_CONTENT']
        self.INPUTS = config['INPUTS']

        self.data = self.load_input_questions()

        print(f"Loading model {self.MODEL}... ", end="")
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL)
        self.model = AutoModelForCausalLM.from_pretrained(self.MODEL,
            # token='hf_xxx'
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # local_files_only=True
        )
        print("done.")

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = 'left'

    def load_input_questions(
            self,
        ) -> dict:

        with Path(self.INPUTS).open(encoding='utf-8') as f:
            data = json.load(f)

        return data

    def assemble_messages(
            self,
            data: dict,
        ) -> list[list[dict]]:

        inputs = []
        for example in data.values():

            user_content = ""
            user_content += f"Question: {example['question']}\n\nProcess: "
            for step in example['explanation']:
                user_content += f"{step['step']}. {step['explanation']}. "

            model_family = self.MODEL.split('/')[0]
            if model_family in [""]:

                messages = [
                    {
                        "role": "system",
                        "content": self.SYSTEM_CONTENT
                    },
                    {
                        "role": "user",
                        "content": user_content
                    }
                ]

            elif model_family in ['mistralai']:
                messages = [
                    {
                        "role": "user",
                        "content": self.SYSTEM_CONTENT + ' ' + user_content
                    },
                ]

            inputs.append(messages)

        return inputs

    def apply_template_generate_response(
            self,
            inputs
        ) -> list[str]:

        responses = []
        batched_inputs = [inputs[i:i+self.BATCH_SIZE] for i in range(0, len(inputs), self.BATCH_SIZE)]
        for batch in batched_inputs:
            input_messages = self.tokenizer.apply_chat_template(batch, padding=True, tokenize=True, return_tensors="pt", add_generation_prompt=True).to(self.device)
            outputs = self.model.generate(input_messages,
                                          max_new_tokens=256,
                                          do_sample=True,
                                          temperature=self.TEMPERATURE
                                          )
            responses += self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        return responses

    def save_results(
            self,
            results: list[dict]
        ):

        output_filename = datetime.datetime.isoformat(datetime.datetime.now())
        output_path = Path("outputs", f"{Path(self.INPUTS).stem}", f"{output_filename}.json")
        log = Path("..", "outputs", f"{Path(self.INPUTS).stem}", 'log.json')
        print(f"Saving results in {output_path}... ", end="")

        outputs = {
            "config": self.config,
            "results": results
        }

        with output_path.open('w', encoding='utf-8') as f:
            json.dump(outputs, f, indent=4)
        print("done.")

        with log.open('r', encoding='utf-8') as f:
            log_data = json.load(f)

        log_data[output_filename] = self.config
        # sort log
        log_data = dict(sorted(log_data.items()))

        with log.open('w', encoding='utf-8') as f:
            json.dump(log_data, f, indent=4)

    def participate(
            self
        ):

        print(f"Assembling messages for {len(self.data)} inputs... ", end="")
        inputs = self.assemble_messages(self.data)
        print("done.")
        print(f"Generating responses for {len(inputs)} inputs... ")
        responses = self.apply_template_generate_response(inputs)
        print("Done.")
        outputs = []
        for q, r in zip(self.data.keys(), responses):
            t = self.data[q]['template_id']
            outputs.append({
                "question_id": q,
                "template_id": t,
                "response": r,
            })

        self.save_results(outputs)

        return outputs


In [5]:
MODEL = 'google/gemma-7b'
TEMPERATURE = 0.3
BATCH_SIZE = 20
EXAMPLE = '[1, 2, 3, 4]'
FORMAT_VERSION = 2
SYSTEM_CONTENT = f"Your role is to select, from a list the steps, those that are most important for inclusion in a summary explanation of that process. Format your output as a list, for example {EXAMPLE}. Output only this short summary paragraph and nothing else.".format(EXAMPLE)
INPUTS = '../resources/data/select.json'
DESCRIPTION = "testing"

config = {
    "MODEL": MODEL,
    "TEMPERATURE": TEMPERATURE,
    "BATCH_SIZE": BATCH_SIZE,
    "EXAMPLE": EXAMPLE,
    "FORMAT_VERSION": FORMAT_VERSION,
    "SYSTEM_CONTENT": SYSTEM_CONTENT,
    "INPUTS": INPUTS,
    "DESCRIPTION": DESCRIPTION
}

In [6]:
LLMParticipant(config=config).participate()

Loading model google/gemma-7b... 

config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.11G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

done.
Assembling messages for 400 inputs... 

UnboundLocalError: local variable 'messages' referenced before assignment