In [24]:
import json
from enum import Enum
from typing import List

from minichain.agent import Agent
from minichain.functions import tool
from pydantic import BaseModel, Field

from dumbdb import store
from permuted_layers import complete, get_cross_entropy_loss, permuted
from settings import model, model_name, tasks


class PermutedLayerExperimentConfig(BaseModel):
    """Configures the order in which to apply the models' layers"""
    layers: List[int] = Field(
        ...,
        description="The layers to use in the corrupted model. [0, 1, ..., 12] leaves the model unchanged. [0, 1, ..., i] keeps only layers 0-1. [0, 4, 2] applies layers 0, then 4, then 2, etc.",
    )


# Dynamically create TaskNameEnum with task names from the 'tasks' list
TaskNameEnum = Enum("TaskNameEnum", {task["name"]: task["name"] for task in tasks})


def permuted_layers_experiment(
    configurations: List[PermutedLayerExperimentConfig] = Field(
        ..., description="Defines the configurations for which to run the experiment"
    ),
    experiment_name: str = Field(..., description="Name for this experiment"),
    task_name: TaskNameEnum = Field(..., description="Evaluation task"),
):
    """Corrupt the model by permuting / skipping / repeating its transformer blocks.

    Returns a completion from the corrupted model aling with the 'average_answer_token_loss' (avg loss of the tokens of the ideal completion) for each (prompt, answer) pair in the task.
    """
    outputs = []
    for config in configurations:
        corrupted = permuted(model, config['layers'])
        task = [i for i in tasks if i["name"] == task_name][0]
        for example in task["examples"]:
            completion = complete(corrupted, example["prompt"])
            loss = get_cross_entropy_loss(corrupted, example)
            outputs += [store(
                {
                    "model": model_name,
                    "task_name": task["name"],
                    "experiment_name": experiment_name,
                    "ablation": {'layer-permutation': config['layers']},
                    "prompt": example["prompt"],
                    "answer": example["answer"],
                    "completion": completion,
                    "average_answer_token_loss": loss,
                }
            )]
    store.to_disk()
    return format_outputs(outputs)


def format_outputs(outputs):
    outputs_str = ""
    for i in outputs:
        outputs_str += f"## Ablation\n{i['ablation']}\n## Prompt\n{i['prompt']}\n##Completion\n{i['completion']}\n## Avg loss of: '{i['answer']}'\n{i['average_answer_token_loss']}\n---\n"
    return outputs_str


def test_permuted_layers_experiment():
    output = permuted_layers_experiment(
        configurations=[
            {"layers": list(range(i))}
            for i in [1, 5]
            # for i in range(1, len(model.blocks))
        ],
        experiment_name="Test experiment 1",
        task_name='Facts about the world'
    )
    print(output)

# print(test_permuted_layers_experiment())
outputs = store.items
print(format_outputs(outputs))

## Ablation
{'layer-permutation': [0]}
## Prompt
Madrid is the captial of
##Completion
<|endoftext|>Madrid is the captial of the same exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact
## Avg loss of: 'Spain'
16.864439010620117
---
## Ablation
{'layer-permutation': [0]}
## Prompt
The capital of Spain is
##Completion
<|endoftext|>The capital of Spain is not yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet yet 

In [3]:
from uuid import uuid4
def test_get_cross_entropy_loss(model, example):
    print(example)
    loss = get_cross_entropy_loss(model, example)
    corrupted_example = dict(**example)
    corrupted_example['answer'] = str(uuid4().hex)
    ref_loss = get_cross_entropy_loss(model, corrupted_example)
    print(f"loss: {loss:.4f} ref: {ref_loss:.4f} diff: {(loss - ref_loss):.4f}")


for task in tasks[-2:-1]:
    print(task)
    for example in task['examples']:
        test_get_cross_entropy_loss(model, example)

{'name': 'Common Phrase Completion', 'description': 'Can the model complete a common phrase or saying?', 'examples': [{'prompt': 'A penny for your', 'answer': 'thoughts.'}, {'prompt': 'An apple a day keeps the', 'answer': 'doctor away.'}, {'prompt': 'Better late than', 'answer': 'never.'}, {'prompt': 'Birds of a feather flock', 'answer': 'together.'}, {'prompt': 'Actions speak louder than', 'answer': 'words.'}, {'prompt': 'Every cloud has a silver', 'answer': 'lining.'}, {'prompt': 'Great minds think', 'answer': 'alike.'}, {'prompt': "It's raining cats and", 'answer': 'dogs.'}, {'prompt': 'Laughter is the best', 'answer': 'medicine.'}, {'prompt': 'Once bitten, twice', 'answer': 'shy.'}, {'prompt': 'Out of sight, out of', 'answer': 'mind.'}, {'prompt': 'The early bird catches the', 'answer': 'worm.'}, {'prompt': 'Time heals all', 'answer': 'wounds.'}, {'prompt': 'When in Rome, do as the', 'answer': 'Romans do.'}, {'prompt': "You can't judge a book by its", 'answer': 'cover.'}, {'prompt'

In [27]:
fields = set([])
for item in store.items:
    fields = fields.union(set(item.keys()))
fields = list(fields)
print("Fields: ", fields)
FieldEnum = Enum("FieldEnum", {field: field for field in fields})


class FilterCondition(BaseModel):
    key: str = Field(..., description="Key of the filter")
    value: str = Field(..., description="Value of the filter")

import pandas as pd

from typing import List
import pandas as pd

# Assuming FilterCondition and FieldEnum are defined elsewhere in your code
# If not, you'll need to define these or replace them with appropriate types

# Define the FilterCondition class
class FilterCondition(BaseModel):
    key: str = Field(..., description="Key of the filter")
    value: str = Field(..., description="Value of the filter")

# Define a function to flatten nested dictionaries and convert lists to strings
def flatten_dict(d, parent_key='', sep='_'):
    items = []
    for k, v in d.items():
        new_key = f'{parent_key}{sep}{k}' if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        elif isinstance(v, list):
            items.append((new_key, str(v)))  # Convert list to string
        else:
            items.append((new_key, v))
    return dict(items)

# Define the aggregate_results function with the correct groupby field
def aggregate_results(filters, groupby, select):
    # Convert filters to a dictionary
    filters_dict = {filter_condition.key: filter_condition.value for filter_condition in filters}
    
    # Filter the items using the store's filter method
    items = store.filter(**filters_dict)
    
    # Flatten the nested dictionaries in the items list
    flattened_items = [flatten_dict(item) for item in items]
    
    # Create a DataFrame from the flattened items
    df = pd.DataFrame(flattened_items)
    print(df)
    breakpoint()
    
    # Group the DataFrame by the specified 'groupby' fields
    grouped = df.groupby(groupby)
    
    # Initialize an empty list to store markdown strings for each group
    markdown_strings = []
    
    prev_group_keys = []
    # Iterate over each group
    for group_keys, group_df in grouped:
        # Ensure group_keys is a tuple for consistent formatting
        group_keys = (group_keys,) if isinstance(group_keys, str) else group_keys
        changed_keys = [f"{key}={i}" for key, i in zip(groupby, group_keys) if not i in prev_group_keys]
        print("changed_keys", changed_keys)
        prev_group_keys = group_keys
        # Create the headline with concatenated keys
        headline = '### '
        for k in changed_keys:
            headline += ' | '.join(map(str, k))
        
        # Create the headline with concatenated keys
        headline = '### ' + ' | '.join(map(str, group_keys))
        
        # Initialize a list to store selected field values for each row in the group
        selected_fields_text = []
        
        # Iterate over each row in the group to get the selected fields
        for _, row in group_df.iterrows():
            # Format the selected fields and their values
            fields = '\n'.join(f"{field}: {row[field]}" for field in select)
            selected_fields_text.append(fields)
        
        # Concatenate all selected fields for the group
        group_text = '\n\n'.join(selected_fields_text)
        
        # Combine the headline and the group text
        markdown_group = f"{headline}\n{group_text}"
        
        # Add the markdown for the current group to the list
        markdown_strings.append(markdown_group)
    
    # Concatenate all markdown strings for each group into a single markdown string
    markdown_output = '\n\n---\n\n'.join(markdown_strings)
    
    return markdown_output

# Test the aggregate_results function with the correct groupby field
print(aggregate_results(
    filters=[FilterCondition(key="experiment_name", value="Test experiment 1")],
    groupby=['prompt', 'ablation_layer-permutation'],
    select=['completion']
))

Fields:  ['ablation', 'experiment_name', 'average_answer_token_loss', 'task_name', 'answer', 'model', 'completion', 'prompt']
         model              task_name    experiment_name  \
0   gpt2-small  Facts about the world  Test experiment 1   
1   gpt2-small  Facts about the world  Test experiment 1   
2   gpt2-small  Facts about the world  Test experiment 1   
3   gpt2-small  Facts about the world  Test experiment 1   
4   gpt2-small  Facts about the world  Test experiment 1   
5   gpt2-small  Facts about the world  Test experiment 1   
6   gpt2-small  Facts about the world  Test experiment 1   
7   gpt2-small  Facts about the world  Test experiment 1   
8   gpt2-small  Facts about the world  Test experiment 1   
9   gpt2-small  Facts about the world  Test experiment 1   
10  gpt2-small  Facts about the world  Test experiment 1   
11  gpt2-small  Facts about the world  Test experiment 1   
12  gpt2-small  Facts about the world  Test experiment 1   
13  gpt2-small  Facts about the wo

In [9]:
store.items

[{'ablation': {'layer-permutation': [0]},
  'prompt': 'Madrid is the captial of',
  'answer': 'Spain',
  'completion': '<|endoftext|>Madrid is the captial of the same exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact exact',
  'average_answer_token_loss': 16.864439010620117},
 {'ablation': {'layer-permutation': [0]},
  'prompt': 'The capital of Spain is',
  'answer': 'Madrid',
  'completion': '<|endoftext|>The capital of Spain is not yet yet yet yet yet yet yet ye