# Direction-Following Dataset Generator Tutorial

This tutorial demonstrates how to use the two- and four-directional direction-following dataset generator to create custom datasets for testing directional reasoning capabilities. The generator creates stories about actors moving in different directions, along with tasks that ask whether two selected actors end up facing the same direction.


In [None]:
import os
import pickle
import json
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from collections import defaultdict
import glob
from data.dataset_generator import *

## 1. Generating Datasets

The direction-following dataset consists of stories where actors move in various directions and interact with each other. Each story ends with a query asking if two selected actors are facing the same direction.

You can generate datasets in three ways:
1. **Single dataset**: Generate a dataset with specific parameters (density, actor counts, etc.)
2. **Experiment datasets**: Generate all datasets defined in the experiment
3. **Merge datasets**: Combine multiple datasets with the same direction type

Below, we'll demonstrate how to generate all experiment datasets and then how to merge them by direction type.

In [None]:

# Output location for every generated dataset file; created on first run.
output_dir = "./datasets_new"
os.makedirs(output_dir, exist_ok=True)

# Seed Python's `random` module before any generation call.
# NOTE(review): only `random` is seeded here — confirm the generator does not
# draw from another RNG if exact reproducibility matters.
random.seed(42)

# The string below is an inert reference snippet (it is never executed). Copy
# it out of the string to build one dataset with hand-picked parameters.
"""
# Example: Generate a single dataset for a specific density 
print("Generating simple_4dir dataset...")
dataset_path = generate_complete_dataset(
    train_size=200,
    valid_size=50,
    validB_size=100,
    test_size=100,
    n_directions=4,
    target_densities=[0.26],  # simple density
    balance_mode="actor_only",
    output_dir=output_dir,
    output_name="simple_4dir"
)
"""

# Build the full set of experiment datasets into `output_dir`.
print("\nGenerating all datasets...")
generate_experiment_datasets(output_dir=output_dir)

print("\nAll datasets generated successfully!")

## 2. Inspecting Generated Datasets

After generating the datasets, we should verify they have the expected properties:
- Correct number of stories in each split
- Proper distribution of actor counts
- Balanced density categories
- Appropriate number of positive/negative examples

The function below checks the most basic of these properties — the number of stories in each split — for each dataset file.

In [None]:

def check_dataset_file(file_path):
    """Print the number of stories per split for a single dataset pickle.

    The pickle is expected to hold a dict. Only entries whose key is a
    non-empty tuple are considered; the first tuple element (lower-cased when
    it is a string) names the split, and the entry's value must provide
    "pos" and "neg" collections whose lengths are summed.

    Args:
        file_path (str): Path to the dataset pickle file.
    """
    with open(file_path, 'rb') as handle:
        dataset = pickle.load(handle)

    # Running totals, keyed by lower-cased split name.
    per_split = {'train': 0, 'valid': 0, 'validb': 0, 'test': 0}

    for key, group in dataset.items():
        if not isinstance(key, tuple) or len(key) == 0:
            continue

        # Size is computed for every tuple-keyed group, even when the split
        # name turns out not to be one of the four tracked splits.
        n_stories = len(group["pos"]) + len(group["neg"])

        head = key[0]
        split = head.lower() if isinstance(head, str) else None
        if split in per_split:
            per_split[split] += n_stories

    print(f"Dataset file: {os.path.basename(file_path)}")
    print(f"  Total stories: {sum(per_split.values())}")
    for label, split in (
        ("Train", 'train'),
        ("Valid", 'valid'),
        ("ValidB", 'validb'),
        ("Test", 'test'),
    ):
        print(f"  {label}: {per_split[split]} stories")
    print()

# Summarize every per-density dataset pickle in the output directory.
# Merged files ("all_*") and metadata pickles ("*metadata.pkl") are skipped,
# matching the filter the merge step applies to the same directory. Names are
# sorted because os.listdir returns entries in arbitrary order, which would
# make the report order differ between runs.
dataset_dir = "./datasets_new"
for dataset_file in sorted(os.listdir(dataset_dir)):
    is_pickle = dataset_file.endswith('.pkl')
    is_merged = dataset_file.startswith('all_')
    is_metadata = dataset_file.endswith('metadata.pkl')
    if is_pickle and not is_merged and not is_metadata:
        check_dataset_file(os.path.join(dataset_dir, dataset_file))

## 3. Merging Datasets by Direction Type

To train the models on all densities combined, using the predetermined train/valid indices, we want to combine all datasets of the same direction type (2dir or 4dir).
The `merge_datasets_by_direction` function combines multiple datasets while maintaining split information.

This creates combined files like:
- `all_2dir.pkl` - All 2-direction datasets combined
- `all_4dir.pkl` - All 4-direction datasets combined

It also generates appropriate index files for each split.

In [None]:
# Combine the separate per-density datasets into all_2dir and all_4dir files.

output_dir = "./datasets_new"
os.makedirs(output_dir, exist_ok=True)

# List the individual dataset pickles, leaving out already-merged files
# ("all_*") and metadata pickles.
print("Available datasets:")
all_dataset_files = [
    path
    for path in glob.glob(os.path.join(output_dir, "*.pkl"))
    if not (
        os.path.basename(path).startswith("all_")
        or os.path.basename(path).endswith("metadata.pkl")
    )
]
for dataset_file in all_dataset_files:
    print(f" - {os.path.basename(dataset_file)}")

# Merge each direction type in turn; a falsy result is skipped in the summary.
print("\nMerging 2-directional datasets:")
config_2dir = merge_datasets_by_direction(output_dir, "2dir")

print("\nMerging 4-directional datasets:")
config_4dir = merge_datasets_by_direction(output_dir, "4dir")

print("\nSummary of merged datasets:")

# (label shown to the user, key in the returned config dict)
summary_rows = (
    ("Stories file", 'stories_file'),
    ("Train indices", 'train_indices_file'),
    ("ValidA indices", 'valid_indices_file'),
    ("ValidB indices", 'validb_indices_file'),
    ("Test indices", 'test_indices_file'),
)

for heading, config in (
    ("Created all_2dir dataset:", config_2dir),
    ("\nCreated all_4dir dataset:", config_4dir),
):
    if not config:
        continue
    print(heading)
    for label, key in summary_rows:
        print(f" - {label}: {os.path.basename(config[key])}")

## 4. Detailed Dataset Analysis

Now let's examine the actor count distribution in more detail.
This is important to verify that our dataset generation maintained the expected actor ranges:
- Train/Valid: 2-8 actors
- ValidB: 9-20 actors  
- Test: 21-30 actors

We'll use the `inspect_dataset_actor_counts` function to check each dataset, then 
look at a specific dataset (simple_4dir) in more detail.

In [None]:

from collections import Counter

from debugging_functions import inspect_dataset_actor_counts

output_dir = "./datasets_new"
os.makedirs(output_dir, exist_ok=True)

print("Available datasets:")
all_dataset_files = glob.glob(os.path.join(output_dir, "*.pkl"))
# Filter out merged datasets ("all_*") and metadata pickles. The glob pattern
# already restricts matches to ".pkl", so only these two exclusions are needed.
dataset_files = [
    f for f in all_dataset_files
    if not os.path.basename(f).startswith("all_")
    and not os.path.basename(f).endswith("metadata.pkl")
]
for dataset_file in sorted(dataset_files):
    print(f" - {os.path.basename(dataset_file)}")
    inspect_dataset_actor_counts(dataset_file)

# Closer look at one dataset; silently skipped when the file is absent.
simple_4dir_path = os.path.join(output_dir, "simple_4dir.pkl")
if os.path.exists(simple_4dir_path):
    print("\nInspecting simple_4dir dataset in detail:")

    with open(simple_4dir_path, 'rb') as f:
        simple_data = pickle.load(f)

    # Collect num_actors for every story in groups whose key starts with
    # 'valid'. A story contributes only when its third element is a dict
    # carrying a 'num_actors' entry.
    valid_actors = []
    for key, group in simple_data.items():
        if isinstance(key, tuple) and len(key) > 0 and key[0] == 'valid':
            for story_type in ['pos', 'neg']:
                for story in group[story_type]:
                    if len(story) > 2 and isinstance(story[2], dict):
                        if 'num_actors' in story[2]:
                            valid_actors.append(story[2]['num_actors'])

    # Tally once instead of rescanning the list for every actor count.
    actor_histogram = Counter(valid_actors)

    print("Actor count distribution in simple_4dir validation set:")
    for actor_count in range(2, 9):
        count = actor_histogram[actor_count]
        print(f"  - {actor_count} actors: {count} stories")

## 5. Verifying Merged Datasets

Finally, let's verify that our merged datasets maintain the correct split information and actor ranges.
The `check_merged_dataset_stats` function validates:
- Index file integrity
- Actor counts for each split
- Category distribution 
- Whether stories are within their expected actor ranges

This helps ensure the merged datasets are properly structured for experiments.

In [None]:
def check_merged_dataset_stats(stories_file, indices_files, all_stories_file=None):
    """
    Print per-split statistics for a merged dataset and its index files.

    For every split, the JSON index list is loaded and each index is resolved
    against a flat list of stories. The report shows how many indices are in
    bounds, how many of the resolved stories fall inside the split's expected
    actor range, and the distribution of actor counts and categories.

    Args:
        stories_file (str): Path to the merged dataset pickle file
        indices_files (dict): Dictionary of split names to index file paths.
            Split names are matched case-insensitively against
            train / valid / validb / test; any other name is checked against
            a permissive 0-100 actor range.
        all_stories_file (str, optional): Path to the all_stories list file for
            direct indexing. When it is not given or does not exist, the list
            is rebuilt by walking the merged dataset's groups ("pos" stories
            first, then "neg").
            NOTE(review): the rebuilt order is only assumed to match the order
            the index files were written against — confirm with the merge code.

    Returns:
        None. All results are printed.
    """
    print(f"\nChecking statistics for {os.path.basename(stories_file)}")

    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # this on dataset files you generated yourself.
    with open(stories_file, 'rb') as f:
        combined_stories = pickle.load(f)

    all_stories = []
    if all_stories_file and os.path.exists(all_stories_file):
        print(f"Using all_stories from {os.path.basename(all_stories_file)}")
        with open(all_stories_file, 'rb') as f:
            all_stories = pickle.load(f)
    else:
        # reconstruct from combined_stories
        print("Reconstructing all_stories from combined dataset")
        for group in combined_stories.values():
            for story_type in ("pos", "neg"):
                all_stories.extend(group[story_type])

    print(f"Total stories in dataset: {len(all_stories)}")

    # Expected actor ranges per split. Keys MUST be lower-case because the
    # lookup below lower-cases the split name; a mixed-case 'validB' key would
    # never match and ValidB would silently fall back to the 0-100 default.
    actor_ranges = {
        'train': (2, 8),
        'valid': (2, 8),
        'validb': (9, 20),
        'test': (21, 30)
    }

    # check each split's indices
    for split_name, indices_file in indices_files.items():
        if not os.path.exists(indices_file):
            print(f"  Warning: Indices file {indices_file} not found")
            continue

        with open(indices_file, 'r') as f:
            indices = json.load(f)

        min_actors, max_actors = actor_ranges.get(split_name.lower(), (0, 100))

        actor_counts = defaultdict(int)
        category_counts = defaultdict(int)
        valid_indices = 0
        invalid_indices = 0
        in_range = 0
        out_range = 0

        for idx in indices:
            # Out-of-bounds indices are counted and skipped, not raised.
            if idx < 0 or idx >= len(all_stories):
                invalid_indices += 1
                continue

            valid_indices += 1

            story = all_stories[idx]
            # Metadata is read from the story's third element when it is a
            # dict; stories without it still count as valid indices.
            if not (len(story) > 2 and isinstance(story[2], dict)):
                continue
            metadata = story[2]

            if 'num_actors' in metadata:
                actor_count = metadata['num_actors']
                actor_counts[actor_count] += 1
                if min_actors <= actor_count <= max_actors:
                    in_range += 1
                else:
                    out_range += 1

            if 'category' in metadata:
                category_counts[metadata['category']] += 1

        # Denominators are clamped to 1 so empty splits do not divide by zero.
        index_total = max(1, len(indices))
        valid_total = max(1, valid_indices)

        print(f"\n{split_name.upper()} SPLIT:")
        print(f"  Valid indices: {valid_indices} ({valid_indices/index_total*100:.1f}%)")

        if invalid_indices > 0:
            print(f"  Invalid indices: {invalid_indices} ({invalid_indices/index_total*100:.1f}%)")

        print("  Actor range validity:")
        print(f"    - Within expected range ({min_actors}-{max_actors}): {in_range} ({in_range/valid_total*100:.1f}%)")
        print(f"    - Outside expected range: {out_range} ({out_range/valid_total*100:.1f}%)")

        print("  Actor counts:")
        for actor_count in sorted(actor_counts):
            count = actor_counts[actor_count]
            range_status = "✓" if min_actors <= actor_count <= max_actors else "✗"
            print(f"    - {actor_count} actors: {count} stories ({count/valid_total*100:.1f}%) {range_status}")

        print("  Category distribution:")
        for category, count in sorted(category_counts.items()):
            print(f"    - {category}: {count} stories ({count/valid_total*100:.1f}%)")


# Validate the merged 4-direction dataset against its four split index files.
merged_stories_path = "./datasets_new/all_4dir.pkl"
flat_stories_path = "./datasets_new/all_stories_4dir.pkl"

# Split name -> JSON file holding that split's story indices.
split_index_paths = {
    'train': "./datasets_new/train_indices_all_4dir.json",
    'valid': "./datasets_new/valid_indices_all_4dir.json",
    'validB': "./datasets_new/validb_indices_all_4dir.json",
    'test': "./datasets_new/test_indices_all_4dir.json",
}

check_merged_dataset_stats(
    merged_stories_path,
    split_index_paths,
    all_stories_file=flat_stories_path,
)