In [None]:
import os
import pandas as pd
import json
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter, defaultdict
import cv2

In [None]:
def load_json_data(base_dir):
    """Recursively load every ``*.json`` file under ``base_dir`` into a DataFrame.

    Each file must contain a single JSON object; its keys become columns and
    an extra ``file_path`` column records where the record came from. Files
    that cannot be read, cannot be parsed, or do not hold a JSON object are
    reported on stdout and skipped, so one bad file does not abort the load.

    Args:
        base_dir: Directory to walk (sub-directories are included).

    Returns:
        pandas.DataFrame with one row per successfully loaded file, in a
        stable order (directories and file names are visited sorted).
    """
    records = []

    for root, dirs, files in os.walk(base_dir):
        # Sorting `dirs` in place steers os.walk's descent, and sorting the
        # file names fixes the row order, so repeated runs give the same frame.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith('.json'):
                continue

            file_path = os.path.join(root, file)
            try:
                # JSON is UTF-8 by specification; don't depend on the platform default.
                with open(file_path, 'r', encoding='utf-8') as f:
                    json_data = json.load(f)
            except (OSError, ValueError, RecursionError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
                # RecursionError covers pathologically nested documents.
                print(f"Error loading {file_path}: {e}")
                continue

            if not isinstance(json_data, dict):
                # A top-level list/scalar cannot take the 'file_path' key below.
                print(f"Error loading {file_path}: expected a JSON object, "
                      f"got {type(json_data).__name__}")
                continue

            json_data['file_path'] = file_path
            records.append(json_data)

    return pd.DataFrame(records)

In [None]:
def get_channel_attribution(df, channel_name, frequency_threshold=0.7):
    """Find the dominant region+change combination for one channel.

    Args:
        df: DataFrame with a ``channel`` column. ``region`` and
            ``change_type`` are read per row when those columns exist.
        channel_name: Value of ``channel`` whose rows are analysed.
        frequency_threshold: Minimum share of the channel's rows (inclusive)
            that the top combination must reach to count as attributed.

    Returns:
        None when no row matches ``channel_name``; otherwise a dict with keys
        ``channel``, ``most_frequent_combination`` (``"<region>+<change>"``),
        ``frequency`` (share of rows, 0-1), ``total_samples``, ``attributed``
        (bool) and ``all_combinations`` (combination -> row count).
    """
    channel_data = df[df['channel'] == channel_name]

    total_samples = len(channel_data)
    if total_samples == 0:
        return None

    def _label(value):
        # row.get's default only applies when the column itself is absent; a
        # column that exists but has no value for this row yields NaN/None,
        # which would otherwise be formatted as the literal text "nan"/"None".
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return 'unknown'
        return value

    # Count region+change combinations
    region_change_counts = Counter()
    for _, row in channel_data.iterrows():
        region = _label(row.get('region', 'unknown'))
        change = _label(row.get('change_type', 'unknown'))
        region_change_counts[f"{region}+{change}"] += 1

    # Find most frequent combination (on a tie, the one encountered first wins)
    most_frequent = max(region_change_counts.items(), key=lambda x: x[1])
    frequency = most_frequent[1] / total_samples

    attribution = {
        'channel': channel_name,
        'most_frequent_combination': most_frequent[0],
        'frequency': frequency,
        'total_samples': total_samples,
        'attributed': bool(frequency >= frequency_threshold),
        'all_combinations': dict(region_change_counts),
    }
    return attribution

In [None]:
def analyze_all_channels_attribution(df, frequency_threshold=0.7):
    """Run the per-channel attribution analysis over every channel in ``df``.

    Each distinct value of the ``channel`` column is passed to
    ``get_channel_attribution`` together with ``frequency_threshold``.
    Channels for which that call returns a falsy result are left out.

    Returns:
        List of the attribution results that were kept, one per channel.
    """
    per_channel = (
        get_channel_attribution(df, name, frequency_threshold)
        for name in df['channel'].unique()
    )
    return [entry for entry in per_channel if entry]

In [None]:
def print_attribution_summary(attribution_results):
    """Print summary statistics of attribution analysis.

    Args:
        attribution_results: List of dicts, each providing the keys
            ``channel``, ``most_frequent_combination``, ``frequency``,
            ``total_samples`` and ``attributed``. May be empty.

    Prints an overall header (channel count, attributed count, attribution
    rate) followed by one line per channel. Returns None.
    """
    total_channels = len(attribution_results)
    attributed_channels = sum(1 for r in attribution_results if r['attributed'])

    print("Channel Attribution Summary")
    print("=" * 50)
    print(f"Total channels analyzed: {total_channels}")
    print(f"Channels with clear attribution: {attributed_channels}")
    if total_channels:
        print(f"Attribution rate: {attributed_channels/total_channels:.2%}")
    else:
        # No channels means no meaningful rate; avoid dividing by zero.
        print("Attribution rate: N/A")
    print()

    print("Channel Details:")
    print("-" * 50)
    for result in attribution_results:
        status = "✓" if result['attributed'] else "✗"
        # NOTE(review): the end of this line was cut off in the source cell;
        # the sample-count suffix is a reconstruction - confirm the wording.
        print(f"{status} {result['channel']}: {result['most_frequent_combination']} "
              f"({result['frequency']:.2%}, {result['total_samples']} samples)")

In [None]:
def plot_channel_samples(df, channels=None, samples_per_channel=3, random_state=None):
    """Display sample images for each channel for annotation purposes.

    Builds a grid of axes with one row per channel and ``samples_per_channel``
    columns. For each sampled row the image at ``image_path`` is shown when
    that path exists and can be read; otherwise a text placeholder is drawn.

    Args:
        df: DataFrame with a ``channel`` column; ``image_path``, ``region``
            and ``change_type`` are used when present.
        channels: Channels to show. Defaults to the first 6 distinct values
            of ``df['channel']``.
        samples_per_channel: Maximum number of rows sampled per channel.
        random_state: Forwarded to ``DataFrame.sample`` so the selection can
            be made reproducible; ``None`` keeps it random.

    Returns:
        None. Prints a short notice and draws nothing when there are no
        channels to show.
    """
    if channels is None:
        channels = df['channel'].unique()[:6]  # Limit to first 6 channels
    channels = list(channels)

    if not channels:
        # plt.subplots rejects a grid with zero rows.
        print("No channels to plot")
        return

    # squeeze=False keeps `axes` two-dimensional for every grid shape, so
    # axes[i, j] is valid for a single channel and/or a single sample column.
    _, axes = plt.subplots(len(channels), samples_per_channel,
                           figsize=(12, 3 * len(channels)), squeeze=False)

    # Switch every cell off up front so slots that receive no sample (channels
    # with fewer rows than samples_per_channel) render blank, not as empty frames.
    for ax in axes.flat:
        ax.axis('off')

    for i, channel in enumerate(channels):
        channel_data = df[df['channel'] == channel]
        sample_data = channel_data.sample(min(samples_per_channel, len(channel_data)),
                                          random_state=random_state)

        for j, (_, row) in enumerate(sample_data.iterrows()):
            ax = axes[i, j]
            label = f"{row.get('region', 'N/A')}+{row.get('change_type', 'N/A')}"

            # A missing value (NaN/None) is not a path; os.path.exists would
            # raise TypeError on it, so check the type first.
            image_path = row.get('image_path')
            has_image_file = (isinstance(image_path, (str, os.PathLike))
                              and os.path.exists(image_path))

            if has_image_file:
                img = cv2.imread(os.fspath(image_path))
                if img is not None:
                    # OpenCV loads BGR; matplotlib expects RGB.
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    ax.imshow(img)
                    ax.set_title(f"{channel}\n{label}")
                else:
                    ax.text(0.5, 0.5, 'Image not found', ha='center', va='center')
                    ax.set_title(f"{channel}")
            else:
                ax.text(0.5, 0.5, f"Sample {j+1}\n{label}",
                        ha='center', va='center')
                ax.set_title(f"{channel}")

    # NOTE(review): the source cell was cut off at `plt.tight_`; tight_layout()
    # followed by show() is the reconstructed ending - confirm.
    plt.tight_layout()
    plt.show()

In [None]:
def save_attribution_results(attribution_results, output_file='channel_attribution_results.csv'):
    """Save attribution results to CSV file.

    Only the scalar summary fields are written; the per-combination counts
    (``all_combinations``) are left out of the table.

    Args:
        attribution_results: List of dicts providing the keys ``channel``,
            ``most_frequent_combination``, ``frequency``, ``total_samples``
            and ``attributed``. May be empty.
        output_file: Destination CSV path.

    Returns:
        The pandas.DataFrame that was written (one row per result).

    Raises:
        KeyError: If a result is missing one of the summary keys.
    """
    columns = ['channel', 'most_frequent_combination', 'frequency',
               'total_samples', 'attributed']

    # Passing `columns` explicitly keeps the schema (and the CSV header) intact
    # even when there are no results; otherwise the frame would have no columns.
    results_df = pd.DataFrame(
        [{column: r[column] for column in columns} for r in attribution_results],
        columns=columns,
    )

    results_df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

    return results_df

In [None]:
def main(data_directory, frequency_threshold=0.7):
    """Main function to orchestrate the