In [3]:
from datasets import load_dataset, load_from_disk
from datasets import Dataset, DatasetDict
import numpy as np


# Names of the interactive datasets this notebook knows how to load.
ds_names = [
    "medium",
    "bigcodebench",
    "math-hard",
]
# Choose the dataset to work with by its position in `ds_names`.
ds_name = ds_names[1]


def _longest_prompt_row_per_idx(prompts, indices):
    '''
    For every distinct value in `indices`, finds the row position whose
    prompt contains the most messages (the earliest such row on ties).

    Returns a dict mapping each distinct idx value to that row position.
    '''
    best_row = {}
    for row, idx in enumerate(indices):
        current = best_row.get(idx)
        # Strict '>' keeps the earliest row when several prompts tie in length.
        if current is None or len(prompts[row]) > len(prompts[current]):
            best_row[idx] = row
    return best_row


def load_sft_dataset(ds_name: str, 
                     add_system_prompt: bool=True
                     ) -> DatasetDict:
    '''
    Loads the dataset corresponding to Tab 5 in the paper,
    and transforms it into an SFT dataset with one conversation per `idx`.

    Rows that share the same `idx` are grouped together. For each group, the
    row with the longest `prompt` (most messages) is kept, and its `chosen`
    response is appended as the final assistant message.

    Args:
        ds_name: Dataset name; the data is read from
            `../data/interactive-{ds_name}`.
        add_system_prompt: If True, the loaded dataset is first passed
            through `add_sys_prompt(dataset, key='prompt')`.

    Returns:
        A DatasetDict with the same splits as the loaded dataset. Each split
        has a 'chat' column (list of messages) and an 'idx' column holding
        the distinct idx values in sorted order.
    '''
    dataset = load_from_disk(f"../data/interactive-{ds_name}")
    
    if add_system_prompt:
        # NOTE(review): `add_sys_prompt` is not defined in the visible part of
        # this notebook -- it must already be in scope when this branch runs.
        dataset = add_sys_prompt(dataset, key='prompt')

    dataset_dict = {}
    for split in dataset.keys():
        prompts = dataset[split]['prompt']
        last_responses = dataset[split]['chosen']
        # Convert explicitly instead of relying on `column == numpy_scalar`
        # broadcasting; `.tolist()` yields plain Python values usable as dict keys.
        indices = np.asarray(dataset[split]['idx']).tolist()

        # One pass over the rows replaces a full `np.where` scan per unique idx.
        best_row = _longest_prompt_row_per_idx(prompts, indices)

        # np.unique returns the distinct idx values in sorted order.
        idx_lst = np.unique(indices).tolist()

        sft_lst = [
            prompts[best_row[idx]] +
            [{'role': 'assistant', 'content': last_responses[best_row[idx]]}]
            for idx in idx_lst
        ]
        
        dataset_dict[split] = Dataset.from_dict({'chat': sft_lst, 'idx': idx_lst})
    
    return DatasetDict(dataset_dict)

In [6]:
# Build the SFT view of the selected dataset (without a system prompt)
# and keep only its training split.
sft_ds = load_sft_dataset(
    ds_name,
    add_system_prompt=False,
)['train']
print(sft_ds)

# Every entry of the 'chat' column is a list of messages; add up their lengths.
num_total_turns = sum(map(len, sft_ds['chat']))
print(f"Total number of (user + assistant) turns: {num_total_turns}")

Dataset({
    features: ['chat', 'idx'],
    num_rows: 500
})
Total number of (user + assistant) turns: 5254


In [55]:
import random

# Fixed seed so a fresh "Restart & Run All" inspects the same conversations.
SAMPLE_SEED = 0
# How many conversations to display below.
NUM_SAMPLES = 3

# A dedicated Random instance leaves the global `random` state untouched;
# min() avoids a ValueError when the split has fewer rows than requested.
rand_indices = random.Random(SAMPLE_SEED).sample(
    range(len(sft_ds)),
    min(NUM_SAMPLES, len(sft_ds)),
)
rand_indices

[105, 334, 158]

In [61]:
from rich.console import Console
from rich.panel import Panel
from rich.json import JSON
from rich.text import Text
from rich import box

console = Console()

# Pretty-print each sampled conversation, one panel per message.
for position, idx in enumerate(rand_indices):
    # Fetch only the sampled row instead of materializing the whole 'chat' column.
    chat = sft_ds[idx]['chat']
    
    # Create title with styling
    console.print(f"[bold cyan]Index: {idx}[/bold cyan]", justify="center")
    
    # Process each message in the chat
    for message in chat:
        role_color = "green" if message["role"] == "assistant" else "yellow"
        console.print(Panel(
            # Wrap in Text so the content is shown literally: a plain str would
            # be parsed as Rich console markup, and bracketed text in a message
            # (e.g. `arr[i]` in code) would be restyled or raise MarkupError.
            Text(message["content"]),
            title=f"[{role_color}]{message['role'].capitalize()}[/{role_color}]",
            border_style=role_color,
            box=box.ROUNDED,
            expand=False
        ))
    
    # Add separator between different chats (by position, not by index value)
    if position != len(rand_indices) - 1:
        console.print("[dim]" + "─" * 80 + "[/dim]")
        console.print()