# Generate stories using LLMs through local inference

In [None]:
import os
import json
from transformers import pipeline

In [None]:
# --- Experiment parameters ---
# Results are written under
# <results_base_path>/<dataset_name>/<experiment_name>/<gen_model>/<lang>/.
experiment_name = "exp1"
dataset_name = "stories"
results_base_path = "../../results"

# Prompts are read from <dataset_base_path>/<dataset_name>/<lang>.json.
dataset_base_path = "data"

# Language of the prompt file to process (another option: "cs").
lang = "sl"

# Hugging Face model id used as the story generator
# (another option: "cjvt/GaMS-1B-Chat").
gen_model = "utter-project/EuroLLM-1.7B-Instruct"

# Optional system message put at the start of every conversation;
# None (or any falsy value) means no system message is added.
system_prompt = None

In [None]:
# Build the text-generation pipeline for the selected model.
# The factory is reached as `transformers.pipeline` so that assigning the
# resulting object to the module-level name `pipeline` (read later by
# add_conversation_round) does not depend on the bare imported factory.
# Previously `pipeline = pipeline(...)` replaced the factory with the pipeline
# object, so re-running this cell called the pipeline object instead and failed.
import transformers

pipeline = transformers.pipeline(
    "text-generation",
    model=gen_model,
    device_map="auto",
)

def inference(pipeline, messages, max_length=500, max_new_tokens=None):
    """Run one chat-style generation and return the reply text.

    Args:
        pipeline: Callable text-generation pipeline that accepts a list of
            chat messages plus generation keyword arguments.
        messages: Conversation so far, as a list of ``{"role", "content"}``
            dicts; passed to ``pipeline`` unchanged.
        max_length: Limit on total sequence length (prompt plus generated
            tokens). Used only when ``max_new_tokens`` is None; the default
            keeps the original behaviour.
        max_new_tokens: If given, limits only the newly generated tokens, so
            long prompts do not reduce the output budget.

    Returns:
        The ``content`` of the last message in the first result's
        ``generated_text`` (for chat input this is expected to be the newly
        generated assistant turn).
    """
    # Pass exactly one length limit; supplying both makes transformers pick
    # max_new_tokens and emit a conflict warning.
    if max_new_tokens is None:
        length_kwargs = {"max_length": max_length}
    else:
        length_kwargs = {"max_new_tokens": max_new_tokens}
    response = pipeline(messages, **length_kwargs)
    return response[0]["generated_text"][-1]["content"]

In [None]:
def add_conversation_round(conversation, new_prompt, pipe=None):
    """Add a user turn and the model's reply to ``conversation``.

    Args:
        conversation: List of ``{"role", "content"}`` message dicts; it is
            extended in place.
        new_prompt: Text of the new user message.
        pipe: Generation pipeline to use. Defaults to the module-level
            ``pipeline`` object, so existing calls behave as before.

    Returns:
        The same ``conversation`` list, now ending with the user message
        followed by the assistant reply.
    """
    if pipe is None:
        pipe = pipeline
    user_message = {"role": "user", "content": new_prompt}
    # Generate on an extended copy so that, if inference raises, the caller's
    # list is not left ending in an unanswered user turn.
    response = inference(pipe, conversation + [user_message])
    conversation.append(user_message)
    conversation.append({"role": "assistant", "content": response})
    return conversation

In [None]:
# Read the prompt list for the selected language from
# <dataset_base_path>/<dataset_name>/<lang>.json (UTF-8 encoded JSON).
dataset_file = os.path.join(dataset_base_path, dataset_name, f"{lang}.json")
with open(dataset_file, "r", encoding="utf-8") as f:
    dataset = json.load(f)

In [None]:
# Output directory for this experiment/model/language. A Hub id of the form
# "org/name" in gen_model produces two nested folders here.
output_dir = os.path.join(results_base_path, dataset_name, experiment_name, gen_model, lang)
os.makedirs(output_dir, exist_ok=True)

# Generate one single-round conversation per prompt and save each sample
# to its own JSON file. `sample_id` avoids shadowing the builtin `id`.
for sample_id, sample in enumerate(dataset):
    prompt = sample["prompt"]

    # Start every conversation fresh, seeded with the system prompt if set.
    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []

    # Add the user prompt and the model's reply.
    conversation = add_conversation_round(conversation, prompt)

    result = {
        "id": sample_id,
        "conversation": conversation,
    }

    # One zero-padded file per sample, e.g. 000042.json.
    output_file = os.path.join(output_dir, f"{sample_id:06d}.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)

    # Progress report.
    print("id:", sample_id)
    print("Prompt:", prompt)
    print("Response:", conversation[-1]["content"])
    print("-" * 50)