In [None]:
import datetime
import json

import rich
import dotenv

import networkx
import huggingface_hub

In [None]:
from twon_lss.simulations.twon_base import (
    Simulation,
    SimulationArgs,
    Ranker,
    Agent,
    AgentInstructions,
)

from twon_lss.schemas import Post, User, Feed, Network
from twon_lss.utility import LLM, Message, Decay

In [None]:
# Load secrets and the agent configuration. The .env file is resolved as
# "../../../.env" relative to the notebook's working directory.
ENV = dotenv.dotenv_values("../" * 3 + ".env")
if "HF_TOKEN" not in ENV:
    # Fail here with a clear message instead of a bare KeyError later, when
    # the Hugging Face InferenceClients read ENV["HF_TOKEN"].
    raise KeyError("HF_TOKEN is missing from ../../../.env")

# Context manager closes the file deterministically; JSON is UTF-8 by spec.
with open("./data/agents.json", encoding="utf-8") as agents_file:
    AGENTS_CFG = json.load(agents_file)
rich.print(AGENTS_CFG)

In [None]:
# Feed ranker: a Decay(low=0.2) over a 3-day (72 h) window plus an LLM wrapper
# around the BAAI/bge-m3 model served through the Hugging Face InferenceClient.
RANKER = Ranker(
    llm=LLM(
        model="BAAI/bge-m3",
        client=huggingface_hub.InferenceClient(api_key=ENV["HF_TOKEN"]),
    ),
    decay=Decay(timedelta=datetime.timedelta(hours=72), low=0.2),
)
rich.print(RANKER)

In [None]:
# One fresh User per configured persona; later cells pair USERS[i] with
# AGENTS_CFG["personas"][i] positionally.
USERS = [User() for _persona in AGENTS_CFG["personas"]]
rich.print(USERS)

In [None]:
# Social graph: a random regular graph with one node per user.
# networkx requires len(USERS) * NETWORK_DEGREE to be even and
# NETWORK_DEGREE < len(USERS). The fixed seed keeps the topology identical
# across Restart & Run All, so simulation results are reproducible.
NETWORK_DEGREE = 3
NETWORK_SEED = 42
NETWORK = Network.from_graph(
    networkx.random_regular_graph(NETWORK_DEGREE, len(USERS), seed=NETWORK_SEED),
    USERS,
)
networkx.draw(NETWORK.root)

In [None]:
# Seed the feed with every persona's historical posts, each attributed to the
# user at the same index. strict=True raises instead of silently truncating if
# USERS and the personas list ever drift out of sync (e.g. after re-running
# cells out of order).
FEED = Feed(
    [
        Post(user=user, content=post)
        for user, persona in zip(USERS, AGENTS_CFG["personas"], strict=True)
        for post in persona["history"]
    ]
)
rich.print(FEED)

In [None]:
# Shared LLM handed to every Agent: Meta-Llama-3-8B-Instruct via the
# Hugging Face InferenceClient.
AGENT_LLM = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    client=huggingface_hub.InferenceClient(api_key=ENV["HF_TOKEN"]),
)
rich.print(AGENT_LLM)

In [None]:
# One Agent per user, keyed by the User object. Each agent's memory is seeded
# with that persona's own historical posts as assistant messages.
#
# Memory is built from persona["history"] rather than zipped against
# FEED.root: FEED.root is the flattened history of *all* personas, so pairing
# it positionally with USERS hands agents another user's post as soon as any
# persona has more than one history entry.
INDIVIDUALS = {
    user: Agent(
        llm=AGENT_LLM,
        instructions=AgentInstructions(
            persona=persona["description"], **AGENTS_CFG["instructions"]
        ),
        memory=[
            Message(role="assistant", content=post) for post in persona["history"]
        ],
    )
    for user, persona in zip(USERS, AGENTS_CFG["personas"], strict=True)
}
rich.print(INDIVIDUALS.get(USERS[0]))

In [None]:
# Wire the seeded feed, network, agents and ranker into a Simulation.
# The small num_steps / num_posts_to_interact_with values keep the run short;
# raise them for a full experiment.
simulation = Simulation(
    feed=FEED,
    network=NETWORK,
    individuals=INDIVIDUALS,
    ranker=RANKER,
    args=SimulationArgs(num_posts_to_interact_with=2, num_steps=2),
)

In [None]:
# Run the configured simulation.
# NOTE(review): presumably issues Hugging Face inference requests through
# RANKER and AGENT_LLM, so this cell may be slow and consume API quota — confirm.
simulation()