In [None]:
import json
import logging

import rich
import dotenv

import networkx

import random

from pathlib import Path

In [None]:
from twon_lss.simulations.twon_base import (
    Simulation,
    SimulationArgs,
    RankerArgs,
    SemanticSimilarityRanker,
    Agent,
    AgentInstructions,
)

from twon_lss.schemas import Post, User, Feed, Network
from twon_lss.utility import LLM, Message

In [None]:
# Verbose output for the whole run: lower the root logger's threshold to DEBUG.
logging.root.setLevel(logging.DEBUG)

In [None]:
# Fixed across experiments
LENGTH_AGENT_MEMORY: int = 21 # actions (42 prompt-completion pairs) => Past 7 Rounds (2 Reads + 1 Write)
PERSISTENCE: int = 4
NUM_POSTS_TO_INTERACT_WITH: int = 2
STEPS: int = 5

# Varies across experiments
NUM_AGENTS: int = 8

# Fix seed
random.seed(42)

In [None]:
# Load secrets from ../../.env and the agent instruction/persona configuration files.
ENV = dotenv.dotenv_values("../" * 2 + ".env")
# Fail fast with a clear message instead of a KeyError in the LLM cells further down.
assert ENV.get("HF_TOKEN"), "HF_TOKEN is missing from ../../.env"

# `with` closes the files; JSON is UTF-8 by spec, so don't rely on the platform default encoding.
with open("./data/agents.instructions.json", encoding="utf-8") as instructions_file:
    AGENTS_INSTRUCTIONS_CFG = json.load(instructions_file)
with open("./data/agents.personas.json", encoding="utf-8") as personas_file:
    AGENTS_PERSONAS_CFG = json.load(personas_file)

# Draws from the global `random` state seeded in the config cell: re-running this cell
# without re-running that one selects a different subset of personas.
AGENTS_PERSONAS_CFG = random.sample(AGENTS_PERSONAS_CFG, k=NUM_AGENTS)

rich.print(AGENTS_INSTRUCTIONS_CFG)
rich.print(len(AGENTS_PERSONAS_CFG))

In [None]:
# Embedding model used by the semantic-similarity ranker, called through the
# Hugging Face inference router's feature-extraction pipeline.
EMBEDDING_MODEL = "mxbai-embed-large-v1"
EMBEDDING_URL = "https://router.huggingface.co/hf-inference/models/mixedbread-ai/mxbai-embed-large-v1/pipeline/feature-extraction"

RANKER = SemanticSimilarityRanker(
    llm=LLM(
        api_key=ENV["HF_TOKEN"],
        model=EMBEDDING_MODEL,
        url=EMBEDDING_URL,
    ),
    args=RankerArgs(persistence=PERSISTENCE),
)

In [None]:
# Chat model shared by every agent (and by username generation).
# The ":cerebras" suffix presumably selects the inference provider — TODO confirm.
AGENT_MODEL = "meta-llama/Llama-3.1-8B-Instruct:cerebras"
AGENT_LLM = LLM(
    api_key=ENV["HF_TOKEN"],
    model=AGENT_MODEL,
)

In [None]:
# One username per sampled persona, generated by the agent LLM from the persona's messages.
usernames = [LLM.generate_username(AGENT_LLM, history.get("messages", [])) for history in AGENTS_PERSONAS_CFG]
# LLM output can repeat; duplicate ids would make two agents indistinguishable downstream
# (USERS are used as dict keys in INDIVIDUALS).
assert len(set(usernames)) == len(usernames), f"duplicate usernames generated: {usernames}"
USERS = [User(id=username) for username in usernames]
usernames[:5]

In [None]:
# Complete graph: every user is connected to every other user.
user_graph = networkx.complete_graph(len(USERS))
NETWORK = Network.from_graph(user_graph, USERS)
networkx.draw(NETWORK.root)

In [None]:
def _assistant_posts(persona: dict) -> list:
    """Return the ``content`` of every assistant message in ``persona["messages"]``, in order."""
    return [
        message.get("content")
        for message in persona["messages"]
        if message.get("role") == "assistant"
    ]


histories = [_assistant_posts(persona) for persona in AGENTS_PERSONAS_CFG]
histories[0][:2]

In [None]:
# Seed the timeline with the first assistant messages of each persona's history.
NUM_SEED_POSTS_PER_AGENT: int = 2

FEED = Feed(
    [
        Post(user=user, content=post)
        # strict=True: USERS and histories are both derived from AGENTS_PERSONAS_CFG, so a
        # length mismatch indicates stale state (e.g. cells re-run out of order), not data.
        for user, history in zip(USERS, histories, strict=True)
        for post in history[:NUM_SEED_POSTS_PER_AGENT]
    ]
)
len(FEED)

In [None]:
# One Agent per user: shared LLM and instructions, with memory pre-filled from the persona history.
INDIVIDUALS = {
    user: Agent(
        llm=AGENT_LLM,
        instructions=AgentInstructions(
            persona=AGENTS_INSTRUCTIONS_CFG["persona"], **AGENTS_INSTRUCTIONS_CFG["actions"]
        ),
        # Skip messages[0] (presumably the system prompt — TODO confirm), then keep
        # 2 * LENGTH_AGENT_MEMORY messages, i.e. LENGTH_AGENT_MEMORY prompt/completion pairs.
        memory=history["messages"][1:LENGTH_AGENT_MEMORY * 2 + 1],
        memory_length=LENGTH_AGENT_MEMORY,
    )
    # strict=True: USERS was built from AGENTS_PERSONAS_CFG, so the lengths must match.
    for user, history in zip(USERS, AGENTS_PERSONAS_CFG, strict=True)
}
# A dict comprehension silently drops agents whose keys compare equal.
assert len(INDIVIDUALS) == len(USERS), "duplicate users collapsed in INDIVIDUALS"
rich.print(INDIVIDUALS.get(USERS[0]))

In [None]:
# Create the output directory explicitly before building the simulation;
# the same string path is handed to Simulation.
OUTPUT_PATH = "Output/"
Path(OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

simulation = Simulation(
    args=SimulationArgs(num_steps=STEPS, num_posts_to_interact_with=NUM_POSTS_TO_INTERACT_WITH),
    ranker=RANKER,
    individuals=INDIVIDUALS,
    network=NETWORK,
    feed=FEED,
    output_path=OUTPUT_PATH,
)

In [None]:
# Run the configured simulation for STEPS steps; results presumably land under
# the "Output/" path passed as output_path — TODO confirm against Simulation.
simulation()