# DSL development

## Imports

In [None]:
# Use this to reload changes in python scripts
# Load (or reload, if already loaded) IPython's autoreload extension.
%reload_ext autoreload
# Mode 1: only modules registered with %aimport are reloaded automatically
# before executing code.
%autoreload 1
# Modules to auto-reload. Note that arc25.plot and arc25.logging (imported in
# the next cell) are not listed, so edits to them need a kernel restart.
%aimport arc25.dsl, arc25.training_tasks, arc25.input_generation
# The leading '-' marks these modules as excluded from auto-reloading.
%aimport -matplotlib, -matplotlib.pyplot

In [None]:
import random
import logging
import numpy as np
from IPython.display import Markdown, display
import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm.auto import tqdm

from arc25.dsl import *
from arc25.training_tasks import *
from arc25.input_generation import *
from arc25.plot import plot_task, plot_grid, plot_grids_with_shape
from arc25.logging import configure_logging

# Swap in the line below for more verbose output while debugging:
# configure_logging(level=logging.DEBUG)
configure_logging(level=logging.INFO)
# Hide matplotlib's own log records below CRITICAL.
logging.getLogger('matplotlib').setLevel(logging.CRITICAL)


# NOTE(review): the throwaway plot presumably initializes the notebook plotting
# backend before the defaults below are applied, so they are not overridden — confirm.
plt.plot()
plt.close('all')
# Plot defaults for the rest of the notebook (plt.rcParams is mpl.rcParams).
mpl.rcParams.update({
    'figure.figsize': (20, 3),
    'lines.linewidth': 3,
    'font.size': 12,
})

## Individual tasks

- Objects that touch other objects should have a different color from them.
- When objects are monochrome, they must not touch each other.

In [None]:
# Alternative generator to inspect instead:
# task_generator = HighlightRectangles(property='is_square')
task_generator = NormalizeImgsWithDifferentBackgroundColor()
for _ in range(3):
    task = task_generator.sample()
    # Render the task's generated code as a Python snippet, then plot the task.
    display(Markdown(f"```python\n{task.code}\n```"))
    plot_task(task)
    plt.show()

In [None]:
# Throughput check: sample 100 tasks and discard them; tqdm reports the rate.
# A plain loop avoids building (and keeping) a throwaway list of tasks.
for _ in tqdm(range(100), smoothing=0):
    task_generator.sample()

## All tasks

In [None]:
generator = training_tasks_generator()
for _ in tqdm(range(3)):
    task = next(generator)
    # Show the generated code as a Python snippet, followed by the task plot.
    display(Markdown(f"```python\n{task.code}\n```"))
    plot_task(task)
    plt.show()

In [None]:
# Throughput check: draw 1000 tasks and discard them; tqdm reports the rate.
# A plain loop avoids materializing a throwaway list of 1000 tasks.
for _ in tqdm(range(1000), smoothing=0):
    next(generator)

- Without code validation I can generate about 2,500 tasks per second; with code validation enabled, this drops to around 1,000 tasks per second.
- Downscale tasks are slow to generate.

## All tasks analysis

### Code

In [None]:
def analyze_dsl_function_usage(tasks):
    """Print how many ``tasks`` call each public function defined in ``arc25.dsl``.

    A task is counted once per function, however many times it calls it. A call
    is detected as the function name followed by ``(`` with a word boundary in
    front, so ``draw_line(`` is not counted inside ``redraw_line(``.

    Args:
        tasks: Sequence of objects exposing a ``code`` string attribute.
    """
    import re

    dsl_function_names = _get_dsl_function_names()
    counts = {}
    for function_name in dsl_function_names:
        # \b stops the name from matching as the tail of a longer identifier.
        call_pattern = re.compile(rf'\b{re.escape(function_name)}\(')
        counts[function_name] = sum(1 for task in tasks if call_pattern.search(task.code))
    # Most-used first; ties broken alphabetically.
    dsl_function_names = sorted(dsl_function_names, key=lambda name: (-counts[name], name))
    print(f'There are {len(dsl_function_names)} DSL functions defined in arc25.dsl:')
    print(f"\tDSL functions used in {len(tasks)} tasks:")
    for function_name in dsl_function_names:
        print(f"{function_name:30} {counts[function_name]:5} times")


def _get_dsl_function_names(dsl_module_name='arc25.dsl'):
    """Return the names of the public functions defined in ``dsl_module_name``.

    Only functions whose ``__module__`` is the module itself are kept (names
    imported into it are excluded), and names starting with ``_`` are skipped.
    The module must already be imported, since it is looked up in ``sys.modules``.
    """
    # Imported explicitly: neither name is imported by this notebook, so relying
    # on them leaking in through a star import is fragile.
    import inspect
    import sys

    return [
        name
        for name, member in inspect.getmembers(sys.modules[dsl_module_name], inspect.isfunction)
        if member.__module__ == dsl_module_name and not name.startswith('_')
    ]

In [None]:
def analyze_dsl_attributes_usage(tasks):
    """Print how many ``tasks`` access each public attribute of the DSL classes.

    The classes are the public ones defined in ``arc25.dsl`` except ``Img``.
    Their public attributes come from ``dir()``. An access is ``.attr`` followed
    by a word boundary, so ``.area`` is not counted inside ``.area_max``. Each
    task counts at most once per attribute.

    Args:
        tasks: Sequence of objects exposing a ``code`` string attribute.
    """
    import re

    dsl_classes = _get_dsl_function_classes()
    attributes = {cls.__name__: [attribute for attribute in dir(cls) if not attribute.startswith('_')]
                  for cls in dsl_classes}
    unique_attributes = {attr for attrs in attributes.values() for attr in attrs}
    counts = {}
    for attr in unique_attributes:
        # \b stops '.area' from also matching '.area_max'.
        access_pattern = re.compile(rf'\.{re.escape(attr)}\b')
        counts[attr] = sum(1 for task in tasks if access_pattern.search(task.code))
    # Break ties by name: set iteration order of strings changes between runs
    # (hash randomization), which made the printed order nondeterministic.
    unique_attributes = sorted(unique_attributes, key=lambda attr: (-counts[attr], attr))
    print(f'There are {len(unique_attributes)} DSL attributes defined in arc25.dsl:')
    print(f"\tDSL attributes used in {len(tasks)} tasks:")
    for attr in unique_attributes:
        possible_classes = [cls for cls, attrs in attributes.items() if attr in attrs]
        print(f"{attr:30} {counts[attr]:5} times ({', '.join(possible_classes)})")


def _get_dsl_function_classes(dsl_module_name='arc25.dsl'):
    """Return the public classes defined in ``dsl_module_name``, excluding ``Img``.

    Only classes whose ``__module__`` is the module itself are kept, and names
    starting with ``_`` are skipped. The module must already be imported, since
    it is looked up in ``sys.modules``.
    """
    # Imported explicitly: neither name is imported by this notebook.
    import inspect
    import sys

    dsl_classes = [
        cls
        for name, cls in inspect.getmembers(sys.modules[dsl_module_name], inspect.isclass)
        if cls.__module__ == dsl_module_name
        and not name.startswith('_')
        and name != 'Img'
    ]
    return dsl_classes


### Analysis

In [None]:
# Draw 1000 tasks from a new training-task generator for the usage analysis below.
generator = training_tasks_generator()
print(f'There are {len(_get_dsl_function_names())} DSL functions defined in arc25.dsl:')
tasks = [
    next(generator)
    for _ in tqdm(range(1000), smoothing=0)
]

In [None]:
# For each DSL function, count how many of the sampled tasks call it.
analyze_dsl_function_usage(tasks=tasks)

In [None]:
# For each DSL class attribute, count how many of the sampled tasks access it.
analyze_dsl_attributes_usage(tasks=tasks)

## Visualize generators

In [None]:
# Three 10x10 images, each with 5 objects of size 5 or 6 and always-random shapes.
# Only the image (element 0 of the returned value) is kept for plotting.
imgs = [
    generate_arc_image_with_random_objects(
        image_shape=(10, 10),
        n_objects=5,
        allowed_sizes=(5, 6),
        random_shape_probability=1.0,
    )[0]
    for _ in range(3)
]
plot_grids_with_shape(imgs)

In [None]:
# Three 10x10 images, each with 7 objects of mixed sizes. Shape probabilities:
# random shapes 0.25, line shapes 0.5. Element 0 of the returned value is plotted.
imgs = [
    generate_arc_image_with_random_objects(
        image_shape=(10, 10),
        n_objects=7,
        allowed_sizes=[3, 4, 6, 8, 9],
        random_shape_probability=0.25,
        line_shape_probability=0.5,
    )[0]
    for _ in range(3)
]
plot_grids_with_shape(imgs)