In [None]:
import random
import json
from datetime import datetime
from typing import List, Dict, Any

# Seed Python's global `random` module so the generation below is repeatable
# when the cell is run from a fresh kernel.
seed = 222
random.seed(seed)

# Constants
TRAIN_SIZE = 40000  # author's note: 50% more than original (the original size is not shown here)
TEST_SIZE = TRAIN_SIZE // 10  # test split is one tenth the size of the training split
NUM_RANGE = (0, 400)  # inclusive (low, high) bounds for operands; results outside them are labelled abstain
ABSTENTION_RATE = 0.17  # Target 17% abstention rate
OPERATORS = ["+", "-", "@"]  # NOTE(review): not referenced anywhere in the visible code - confirm it is still needed

def generate_controlled_dataset(size):
    """Build a shuffled list of ``size`` arithmetic records with a fixed abstention share.

    ``int(size * ABSTENTION_RATE)`` records are abstentions: half (rounded
    down) use the ``@`` operator, the remainder are ``+``/``-`` problems whose
    result leaves ``NUM_RANGE``.  Every other record is a ``+``/``-`` problem
    whose result stays inside ``NUM_RANGE``.  Operands are drawn with the
    module-level ``random`` generator, so the output depends on its state.

    Args:
        size: Total number of records to produce.

    Returns:
        A list of dicts with keys ``'Argument 1'``, ``'Operator'``,
        ``'Argument 2'``, ``'Result'`` (``-1`` for abstentions) and
        ``'Should Abstain?'``.
    """
    low, high = NUM_RANGE

    abstain_total = int(size * ABSTENTION_RATE)
    special_total = int(abstain_total * 0.5)  # '@' supplies half of the abstentions

    def record(lhs, operator, rhs, result=-1, abstain=True):
        # Single place that fixes the key order of every emitted record.
        return {
            'Argument 1': lhs,
            'Operator': operator,
            'Argument 2': rhs,
            'Result': result,
            'Should Abstain?': abstain,
        }

    records = []

    # '@' problems: always an abstention, whatever the operands are.
    for _ in range(special_total):
        lhs = random.randint(low, high)
        rhs = random.randint(low, high)
        records.append(record(lhs, '@', rhs))

    # Overflow/underflow problems: rejection-sample until the abstention quota is full.
    while len(records) < abstain_total:
        operator = random.choice(['+', '-'])
        lhs = random.randint(low, high)
        rhs = random.randint(low, high)

        overflows = operator == '+' and lhs + rhs > high
        underflows = operator == '-' and lhs - rhs < low
        if overflows or underflows:
            records.append(record(lhs, operator, rhs))

    # Answerable problems: keep only draws whose result stays inside the range.
    while len(records) < size:
        operator = random.choice(['+', '-'])
        lhs = random.randint(low, high)
        rhs = random.randint(low, high)

        if operator == '+':
            value = lhs + rhs
        else:
            value = lhs - rhs
        if value < low or value > high:
            continue
        records.append(record(lhs, operator, rhs, value, False))

    # Mix the three groups so the categories are not laid out in contiguous runs.
    random.shuffle(records)
    return records

def analyze_dataset(data, name):
    """Print summary statistics for a list of arithmetic records.

    Reports the sample count, the abstention rate, the share of each operator
    and, for ``+`` and ``-`` only, how many records are marked as abstentions
    (the overflow/underflow cases).

    Args:
        data: Sequence of dicts that each provide ``'Operator'`` and
            ``'Should Abstain?'``.  May be empty.
        name: Label printed in the report heading.

    Returns:
        None.  The report is written to stdout.
    """
    total = len(data)
    abstention_count = sum(1 for item in data if item['Should Abstain?'])

    def percent(count):
        # An empty dataset has no meaningful rate; report 0 instead of dividing by zero.
        return count / total * 100 if total else 0.0

    # Pre-seeded so the three expected operators are always listed, in this order.
    op_counts = {'+': 0, '-': 0, '@': 0}
    overflow_counts = {'+': 0, '-': 0}

    for item in data:
        op = item['Operator']
        # Operators outside the expected three are tallied too rather than raising KeyError.
        op_counts[op] = op_counts.get(op, 0) + 1
        if op in overflow_counts and item['Should Abstain?']:
            overflow_counts[op] += 1

    print(f"\n{name} Set Analysis:")
    print(f"Total samples: {total}")
    print(f"Abstention rate: {percent(abstention_count):.2f}%")
    print("\nOperator distribution:")
    for op, count in op_counts.items():
        print(f"{op}: {percent(count):.2f}% ({count} samples)")
    print("\nOverflow/Underflow counts:")
    for op, count in overflow_counts.items():
        print(f"{op}: {count} cases")

# Generate datasets. Both splits draw from the same seeded RNG stream, so the
# training set must be generated before the test set to reproduce the same data.
print("Generating training set...")
train_data = generate_controlled_dataset(TRAIN_SIZE)
print("Generating test set...")
test_data = generate_controlled_dataset(TEST_SIZE)

# Analyze both sets
for split_label, split_records in (("Training", train_data), ("Test", test_data)):
    analyze_dataset(split_records, split_label)

# Print some sample calculations
print("\nSample calculations:")
for sample_number, sample in enumerate(train_data[:5], start=1):
    print(f"\nSample {sample_number}:")
    print(f"Operation: {sample['Argument 1']} {sample['Operator']} {sample['Argument 2']}")
    print(f"Result: {sample['Result']}")
    print(f"Should Abstain: {sample['Should Abstain?']}")

# Save dataset: both splits plus the settings needed to regenerate them
generation_metadata = {
    'random_seed': seed,
    'number_range': NUM_RANGE,
    'target_abstention_rate': ABSTENTION_RATE,
    'train_samples': len(train_data),
    'test_samples': len(test_data),
    'generation_date': str(datetime.now())
}
dataset_dict = {
    'train': train_data,
    'test': test_data,
    'metadata': generation_metadata,
}

with open('arithmetic_dataset.json', 'w') as dataset_file:
    json.dump(dataset_dict, dataset_file)

Generating training set...
Generating test set...

Training Set Analysis:
Total samples: 40000
Abstention rate: 17.00%

Operator distribution:
+: 45.51% (18203 samples)
-: 45.99% (18397 samples)
@: 8.50% (3400 samples)

Overflow/Underflow counts:
+: 1717 cases
-: 1683 cases

Test Set Analysis:
Total samples: 4000
Abstention rate: 17.00%

Operator distribution:
+: 46.35% (1854 samples)
-: 45.15% (1806 samples)
@: 8.50% (340 samples)

Overflow/Underflow counts:
+: 181 cases
-: 159 cases

Sample calculations:

Sample 1:
Operation: 223 + 102
Result: 325
Should Abstain: False

Sample 2:
Operation: 121 - 117
Result: 4
Should Abstain: False

Sample 3:
Operation: 64 - 22
Result: 42
Should Abstain: False

Sample 4:
Operation: 128 - 3
Result: 125
Should Abstain: False

Sample 5:
Operation: 255 + 359
Result: -1
Should Abstain: True


In [None]:
class OODTestGenerator:
    """Generate out-of-distribution (OOD) test cases for the arithmetic abstention task.

    Three families of cases are produced: operands written in unusual number
    formats, operators other than ``+``/``-``, and ``+``/``-`` problems whose
    result lands close to the edges of ``num_range``.  All randomness comes
    from the module-level ``random`` generator, which the constructor re-seeds.
    """

    def __init__(self, seed: int = 42):
        """Store the seed, re-seed the global RNG, and set the operand range.

        Args:
            seed: Value passed to ``random.seed``.
        """
        self.seed = seed
        random.seed(seed)
        # Inclusive (low, high) operand bounds; a result outside them means "abstain".
        self.num_range = (0, 400)

    def _should_abstain(self, arg1: int, op: str, arg2: int) -> bool:
        """Return True when ``arg1 op arg2`` leaves ``num_range``.

        ``+`` is checked against the upper bound and ``-`` against the lower
        bound; any other operator yields False.
        """
        low, high = self.num_range
        if op == "+":
            return (arg1 + arg2) > high
        if op == "-":
            return (arg1 - arg2) < low
        return False

    def generate_number_format_variations(self, num: int) -> List[str]:
        """Generate different formats for the same number.

        Args:
            num: Integer to render.

        Returns:
            Three strings that all parse back to ``num``: zero-padded to four
            characters, one-decimal fixed point, and scientific notation.
        """
        # Keep every digit after the leading one in the mantissa (at least two),
        # so the scientific form encodes ``num`` itself and loses no digits.
        mantissa_digits = max(len(str(abs(num))) - 1, 2)
        return [
            str(num).zfill(4),             # Leading zeros
            f"{num:.1f}",                  # Decimal format
            f"{num:.{mantissa_digits}e}",  # Scientific notation
        ]

    def generate_number_format_tests(self, size: int = 150) -> List[Dict[str, Any]]:
        """Generate tests with varying number formats.

        Args:
            size: Number of test cases to create.

        Returns:
            A list of dicts holding the formatted operands, the operator, the
            abstain label (computed from the integer operands) and the
            integer operands themselves under ``"Original Values"``.
        """
        tests = []
        for _ in range(size):
            arg1 = random.randint(*self.num_range)
            arg2 = random.randint(*self.num_range)
            op = random.choice(["+", "-"])

            # Select random format for each number
            arg1_format = random.choice(self.generate_number_format_variations(arg1))
            arg2_format = random.choice(self.generate_number_format_variations(arg2))

            tests.append({
                "Argument 1": arg1_format,
                "Operator": op,
                "Argument 2": arg2_format,
                "Should Abstain?": self._should_abstain(arg1, op, arg2),
                "Original Values": (arg1, arg2)
            })
        return tests

    def generate_novel_operator_tests(self, size: int = 100) -> List[Dict[str, Any]]:
        """Generate tests with novel invalid operators.

        Args:
            size: Number of test cases to create.

        Returns:
            A list of dicts whose operator is one of ``# $ & ^ %`` and whose
            abstain label is always True.
        """
        novel_operators = "#$&^%"
        tests = []
        for _ in range(size):
            arg1 = random.randint(*self.num_range)
            arg2 = random.randint(*self.num_range)
            op = random.choice(novel_operators)

            tests.append({
                "Argument 1": str(arg1),
                "Operator": op,
                "Argument 2": str(arg2),
                "Should Abstain?": True,  # Always abstain for novel operators
                "Original Values": (arg1, arg2)
            })
        return tests

    def generate_cross_boundary_tests(self, size: int = 50) -> List[Dict[str, Any]]:
        """Generate tests near decision boundaries.

        Each case is, with equal probability, an addition whose sum is close
        to the upper bound or a subtraction whose difference is close to the
        lower bound, so both abstain and non-abstain labels occur.

        Args:
            size: Number of test cases to create.

        Returns:
            A list of dicts in the same shape as the other generators.
        """
        low, high = self.num_range
        tests = []
        for _ in range(size):
            # Generate cases very close to boundaries
            if random.random() < 0.5:
                # Near the upper boundary: first operand just below it, small addend
                op = "+"
                arg1 = random.randint(high - 10, high - 1)
                arg2 = random.randint(1, 15)
            else:
                # Near the lower boundary: small minuend, slightly wider subtrahend
                op = "-"
                arg1 = random.randint(low + 1, low + 15)
                arg2 = random.randint(1, 20)

            tests.append({
                "Argument 1": str(arg1),
                "Operator": op,
                "Argument 2": str(arg2),
                "Should Abstain?": self._should_abstain(arg1, op, arg2),
                "Original Values": (arg1, arg2)
            })
        return tests

    def generate_complete_test_set(self,
                                 format_size: int = 125,
                                 operator_size: int = 50,
                                 boundary_size: int = 75) -> Dict[str, Any]:
        """Generate complete test set with all variations.

        The three families are generated in the order number-format,
        novel-operator, cross-boundary, which fixes how the shared RNG stream
        is consumed.

        Args:
            format_size: Number of number-format cases.
            operator_size: Number of novel-operator cases.
            boundary_size: Number of cross-boundary cases.

        Returns:
            A dict with one list per family plus a ``"metadata"`` entry
            recording the seed and the requested sizes.
        """
        test_set = {
            "number_format_tests": self.generate_number_format_tests(format_size),
            "novel_operator_tests": self.generate_novel_operator_tests(operator_size),
            "cross_boundary_tests": self.generate_cross_boundary_tests(boundary_size),
            "metadata": {
                "seed": self.seed,
                "format_size": format_size,
                "operator_size": operator_size,
                "boundary_size": boundary_size,
                "total_size": format_size + operator_size + boundary_size
            }
        }
        return test_set

def main():
    """Build the OOD test set with seed 42, save it as JSON, and print a short summary."""
    # Build every OOD category in one call
    suite = OODTestGenerator(seed=42).generate_complete_test_set()

    # Persist the whole suite as indented JSON
    with open('ood_test_set.json', 'w') as out_file:
        json.dump(suite, out_file, indent=2)

    # Report how many cases each category contributed
    meta = suite['metadata']
    print(f"Generated {meta['total_size']} test cases:")
    print(f"- Number format variations: {meta['format_size']}")
    print(f"- Novel operator tests: {meta['operator_size']}")
    print(f"- Cross-boundary tests: {meta['boundary_size']}")

    # Show the first case of each category as a quick sanity check
    examples = (
        ("\nExample number format test:", 'number_format_tests'),
        ("\nExample novel operator test:", 'novel_operator_tests'),
        ("\nExample cross-boundary test:", 'cross_boundary_tests'),
    )
    for heading, category in examples:
        print(heading)
        print(suite[category][0])

if __name__ == "__main__":
    main()

Generated 250 test cases:
- Number format variations: 125
- Novel operator tests: 50
- Cross-boundary tests: 75

Example number format test:
{'Argument 1': '3.27e+00', 'Operator': '+', 'Argument 2': '57.0', 'Should Abstain?': False, 'Original Values': (327, 57)}

Example novel operator test:
{'Argument 1': '163', 'Operator': '#', 'Argument 2': '276', 'Should Abstain?': True, 'Original Values': (163, 276)}

Example cross-boundary test:
{'Argument 1': '395', 'Operator': '+', 'Argument 2': '11', 'Should Abstain?': True, 'Original Values': (395, 11)}
