In [None]:
import atreides

# Fine-tune the base model on the capitals datasets.
# NOTE(review): "capitals/train" and "capitals/val" are created by cells later
# in this notebook — on a fresh kernel (Restart & Run All) run those first.
# NOTE(review): the other atreides calls in this notebook are awaited; confirm
# whether train_model is synchronous or returns an awaitable.
training_kwargs = dict(
    base_model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    train_dataset="capitals/train",
    val_dataset="capitals/val",
)
model = atreides.train_model(**training_kwargs)

In [None]:
from openai import AsyncOpenAI
from prisma import Prisma

client = AsyncOpenAI()
prisma = Prisma()

await prisma.connect()

datasets = await prisma.dataset.find_many()

print(datasets)

[]


In [4]:
import json

dataset = await prisma.dataset.create(
    {
        "name": "capitals/train",
        "prompts": {
            "create": [
                {
                    "messages": json.dumps(
                        [{"role": "user", "content": "What is the capital of France?"}]
                    ),
                    "metadata": json.dumps({})
                }
            ]
        },
    }
)
dataset

Dataset(id=1, name='capitals/train', prompts=None)

In [8]:
completion = await prisma.completion.create(
    {
        "prompt_id": 1,
        "model": "gpt-4o",
        "results": json.dumps([{"role": "assistant", "content": "Paris"}]),
    }
)
completion

Completion(id=1, prompt_id=1, prompt=None, model='gpt-4o', results='[{"role": "assistant", "content": "Paris"}]', rewards=[Reward(id=1, completion_id=1, completion=None, title='Accuracy', description='1.0 if the capital is correct, 0.0 otherwise', value=1.0, weight=1.0)])

In [7]:
dataset = await prisma.dataset.find_first(include={"prompts": True})
dataset.prompts

[Prompt(id=1, messages='[{"role": "user", "content": "What is the capital of France?"}]', metadata='{}', datasets=None, completions=None)]

In [3]:
import atreides

try:
    train_dataset = await atreides.get_dataset("capitals/train")
except atreides.NotFoundError:
    train_dataset = await atreides.create_dataset(
        name="capitals/train",
        tasks=[
            {
                "messages": [
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                "pattern_rewards": [
                    {
                        "regex_pattern": r"The capital of France is (\w+)",
                        "expected_capture_groups": ["Paris"],
                    },
                ],
            },
        ],
    )

try:
    val_dataset = await atreides.get_dataset("capitals/val")
except atreides.NotFoundError:
    val_dataset = await atreides.create_dataset(
        name="capitals/val",
        tasks=[
            {
                "messages": [
                    {"role": "user", "content": "What is the capital of England?"}
                ],
                "pattern_rewards": [
                    {
                        "regex_pattern": r"The capital of England is (\w+)",
                        "expected_capture_groups": ["London"],
                    },
                ],
            },
        ],
    )