In [None]:
"""
MM2IM Mapper Debugger
Visualisasi CMap, OMap, dan Partial Sum Overlap untuk Transposed Convolution

Author: EEG Denoising Accelerator Team
Purpose: Debug dan validasi MM2IM mapping logic
"""

import numpy as np
import pandas as pd
from typing import Tuple, List, Dict
import matplotlib.pyplot as plt
import seaborn as sns

class MM2IMMapperDebugger:
    """
    Debugger for the MM2IM mapper of a 1-D transposed convolution layer.

    Features:
    - Generate CMap and OMap for every (row, tile) pair
    - Visualise the mapping
    - Detect overlaps (output positions that need partial-sum accumulation)
    - Export the mappings in an easy-to-read CSV format

    Column model: the layer's filter matrix has ``N = Ks * Oc`` columns.
    Global column ``g`` belongs to output channel ``g // Ks`` and kernel tap
    ``g % Ks``. Columns are processed ``NUM_PES`` at a time; each such group is
    one tile, so tile ``t`` covers global columns
    ``t * NUM_PES .. t * NUM_PES + NUM_PES - 1``. Columns past the last channel
    (in a partial final tile) are reported as skipped.
    """

    NUM_PES = 16            # PE columns per tile (CMap width in bits)
    NUM_BRAMS = 16          # output channels are striped over this many BRAMs
    BRAM_ADDR_BITS = 10     # OMap = (bram_select << BRAM_ADDR_BITS) | bram_addr
    BRAM_REGION_DEPTH = 64  # address stride between channels sharing one BRAM
    INVALID_OMAP = 0x3FFF   # OMap value emitted for skipped columns
    # NOTE(review): if Oh > BRAM_REGION_DEPTH, or if bram_addr needs more than
    # BRAM_ADDR_BITS bits (Oc > NUM_BRAMS * 2**BRAM_ADDR_BITS / 64), distinct
    # (channel, position) pairs share an encoding. Confirm against the RTL.

    def __init__(self, Ih: int, Ic: int, Oc: int, Oh: int, Ks: int, S: int, P: int,
                 *, verbose: bool = True):
        """
        Store the layer parameters and derive the tiling.

        Args:
            Ih: Input height (number of input positions / rows)
            Ic: Input channels
            Oc: Output channels
            Oh: Output height
            Ks: Kernel size (must be positive)
            S: Stride
            P: Padding
            verbose: Print the layer configuration summary (default True).

        Raises:
            ValueError: if ``Ks`` is not positive.
        """
        if Ks <= 0:
            raise ValueError(f"Kernel size must be positive, got Ks={Ks}")

        self.Ih = Ih
        self.Ic = Ic
        self.Oc = Oc
        self.Oh = Oh
        self.Ks = Ks
        self.S = S
        self.P = P

        # Derived parameters
        self.N = Ks * Oc  # Total filter columns
        # Round up so a partially-filled final tile is still processed.
        self.num_tiles = (self.N + self.NUM_PES - 1) // self.NUM_PES

        if verbose:
            print("=== Layer Configuration ===")
            print(f"Input:  [{Ih}, {Ic}]")
            print(f"Output: [{Oh}, {Oc}]")
            print(f"Kernel: {Ks}, Stride: {S}, Padding: {P}")
            print(f"Total filter columns: {self.N}")
            print(f"Number of tiles: {self.num_tiles}")
            print()

    def calculate_h_pad(self, row_id: int) -> int:
        """Return the output position that kernel tap 0 of input row ``row_id`` lands on."""
        return -self.P + (self.S * row_id)

    def generate_cmap_omap(self, row_id: int, tile_id: int) -> Tuple[np.ndarray, np.ndarray, Dict]:
        """
        Generate the CMap and OMap for one input row and one tile.

        Returns:
            cmap: int array of length NUM_PES (1 = valid, 0 = skip)
            omap: int array of length NUM_PES holding 14-bit
                ``(bram_select << 10) | bram_addr`` words, or INVALID_OMAP
                for skipped columns
            info: dict with row/tile ids, h_pad, CMap as binary/hex strings
                (MSB = highest column), valid/skip counts, efficiency (%), and
                a per-column ``details`` list
        """
        h_pad = self.calculate_h_pad(row_id)
        n_pes = self.NUM_PES

        cmap = np.zeros(n_pes, dtype=int)
        omap = np.zeros(n_pes, dtype=int)
        details = []

        for col in range(n_pes):
            # Locate this PE column in the full filter matrix; the channel and
            # kernel tap follow from the global index for any kernel size.
            global_col = tile_id * n_pes + col
            channel = global_col // self.Ks
            k_pos = global_col % self.Ks

            # Output position this (row, tap) pair contributes to
            output_pos = h_pad + k_pos

            # Skip padding columns beyond the last channel and taps that fall
            # outside the output.
            is_valid = (0 <= channel < self.Oc) and (0 <= output_pos < self.Oh)

            if is_valid:
                bram_select = channel % self.NUM_BRAMS
                tile_offset = channel // self.NUM_BRAMS
                bram_addr = tile_offset * self.BRAM_REGION_DEPTH + output_pos

                cmap[col] = 1
                omap[col] = (bram_select << self.BRAM_ADDR_BITS) | bram_addr
            else:
                cmap[col] = 0
                omap[col] = self.INVALID_OMAP
                bram_select = -1
                tile_offset = -1
                bram_addr = -1

            details.append({
                'col': col,
                'oc': channel,
                'k_pos': k_pos,
                'output_pos': output_pos,
                'valid': is_valid,
                'bram_sel': bram_select,
                'tile_offset': tile_offset,
                'bram_addr': bram_addr,
                'omap_hex': f"0x{omap[col]:04X}"
            })

        # MSB first: the highest-numbered column is the leftmost bit.
        cmap_bin = ''.join(str(b) for b in cmap[::-1])
        hex_width = (n_pes + 3) // 4
        valid_count = np.sum(cmap)

        info = {
            'row_id': row_id,
            'tile_id': tile_id,
            'h_pad': h_pad,
            'cmap_bin': cmap_bin,
            'cmap_hex': f"0x{int(cmap_bin, 2):0{hex_width}X}",
            'valid_count': valid_count,
            'skip_count': n_pes - valid_count,
            'efficiency': valid_count / n_pes * 100,
            'details': details
        }

        return cmap, omap, info

    def print_mapping(self, row_id: int, tile_id: int):
        """Print the mapping for one (row, tile) as a readable table."""
        cmap, omap, info = self.generate_cmap_omap(row_id, tile_id)
        n_pes = self.NUM_PES

        print(f"\n{'='*80}")
        print(f"ROW {row_id}, TILE {tile_id}")
        print(f"{'='*80}")
        print(f"h_pad: {info['h_pad']}")
        print(f"CMap (binary): {info['cmap_bin']}")
        print(f"CMap (hex):    {info['cmap_hex']}")
        print(f"Valid: {info['valid_count']}/{n_pes} ({info['efficiency']:.1f}%)")
        print(f"Skip:  {info['skip_count']}/{n_pes}")
        print()

        df = pd.DataFrame(info['details'])
        df['status'] = df['valid'].apply(lambda x: '✓ VALID' if x else '✗ SKIP')

        print("Column Mapping:")
        print(df.to_string(index=False))
        print()

    def visualize_mapping(self, row_id: int, tile_id: int):
        """Plot the CMap and the per-BRAM write distribution for one (row, tile)."""
        cmap, omap, info = self.generate_cmap_omap(row_id, tile_id)
        n_pes = self.NUM_PES

        fig, axes = plt.subplots(1, 2, figsize=(14, 4))

        # Plot 1: CMap, one bar per PE column (green = valid, red = skip)
        colors = ['red' if c == 0 else 'green' for c in cmap]
        axes[0].bar(range(n_pes), cmap, color=colors, alpha=0.7)
        axes[0].set_xlabel('Column')
        axes[0].set_ylabel('Valid (1) / Skip (0)')
        axes[0].set_title(f'CMap - Row {row_id}, Tile {tile_id}\n{info["cmap_hex"]}')
        axes[0].set_xticks(range(n_pes))
        axes[0].grid(axis='y', alpha=0.3)

        # Plot 2: number of valid columns writing into each BRAM
        df = pd.DataFrame(info['details'])
        valid_df = df[df['valid']]
        if len(valid_df) > 0:
            bram_counts = valid_df['bram_sel'].value_counts().sort_index()
            axes[1].bar(bram_counts.index, bram_counts.values, color='steelblue', alpha=0.7)
        # Label the panel even when every column is skipped, so it is not left blank.
        axes[1].set_xlabel('BRAM ID')
        axes[1].set_ylabel('Number of Writes')
        axes[1].set_title(f'BRAM Distribution - Row {row_id}, Tile {tile_id}')
        axes[1].set_xticks(range(self.NUM_BRAMS))
        axes[1].grid(axis='y', alpha=0.3)

        plt.tight_layout()
        plt.show()

    def detect_overlaps(self, tile_id: int) -> pd.DataFrame:
        """
        Find output locations written more than once across the rows of a tile.

        Each such location needs partial-sum accumulation. Only writes from
        this tile are considered.

        Returns:
            DataFrame with one row per overlapping (bram_id, address), sorted by
            those two columns; an empty DataFrame if there are no overlaps.
        """
        # (bram_id, addr) -> list of (row, col, channel, output_pos)
        writes = {}

        for row_id in range(self.Ih):
            _, _, info = self.generate_cmap_omap(row_id, tile_id)

            for detail in info['details']:
                if detail['valid']:
                    key = (detail['bram_sel'], detail['bram_addr'])
                    writes.setdefault(key, []).append(
                        (row_id, detail['col'], detail['oc'], detail['output_pos'])
                    )

        overlaps = []
        for (bram_id, addr), write_list in writes.items():
            if len(write_list) > 1:
                overlaps.append({
                    'bram_id': bram_id,
                    'address': addr,
                    'num_writes': len(write_list),
                    'rows': [w[0] for w in write_list],
                    'columns': [w[1] for w in write_list],
                    'channels': [w[2] for w in write_list],
                    'output_pos': write_list[0][3]
                })

        if not overlaps:
            return pd.DataFrame()
        return pd.DataFrame(overlaps).sort_values(['bram_id', 'address'])

    def print_overlap_summary(self, tile_id: int):
        """Print the overlap (partial-sum) positions and statistics for a tile."""
        print(f"\n{'='*80}")
        print(f"OVERLAP DETECTION - TILE {tile_id}")
        print(f"{'='*80}")

        overlaps = self.detect_overlaps(tile_id)

        if len(overlaps) == 0:
            # Adjacent rows can only overlap when the stride is smaller than the kernel.
            if self.S < self.Ks:
                print("No overlaps detected (unexpected for transposed convolution!)")
            else:
                print("No overlaps detected (expected: stride >= kernel size)")
        else:
            print(f"Found {len(overlaps)} overlap positions")
            print("\nOverlap Details:")
            print(overlaps.to_string(index=False))

            print(f"\n{'='*80}")
            print("Overlap Statistics:")
            print(f"  Total overlap positions: {len(overlaps)}")
            print(f"  Average writes per overlap: {overlaps['num_writes'].mean():.2f}")
            print(f"  Max writes to single position: {overlaps['num_writes'].max()}")
        print()

    def export_all_mappings(self, output_file: str = "mm2im_mappings.csv",
                            max_tiles: int = 4):
        """
        Export per-column mappings for every row of the first tiles to CSV.

        Args:
            output_file: Destination CSV path.
            max_tiles: Export at most this many tiles, starting at tile 0 (default 4).

        Returns:
            The exported DataFrame.
        """
        all_data = []

        for tile_id in range(min(max_tiles, self.num_tiles)):
            for row_id in range(self.Ih):
                _, _, info = self.generate_cmap_omap(row_id, tile_id)

                for detail in info['details']:
                    all_data.append({
                        'row_id': row_id,
                        'tile_id': tile_id,
                        'h_pad': info['h_pad'],
                        **detail
                    })

        df = pd.DataFrame(all_data)
        df.to_csv(output_file, index=False)
        print(f"Exported {len(df)} mappings to {output_file}")
        return df

    def visualize_full_layer(self, tile_id: int):
        """Show a rows x columns CMap heatmap for one tile and print its skip statistics."""
        n_pes = self.NUM_PES
        fig, ax = plt.subplots(figsize=(12, 8))

        # Matrix: one row per input row, one column per PE column
        validity_matrix = np.zeros((self.Ih, n_pes))

        for row_id in range(self.Ih):
            cmap, _, _ = self.generate_cmap_omap(row_id, tile_id)
            validity_matrix[row_id, :] = cmap

        sns.heatmap(validity_matrix, cmap='RdYlGn', cbar_kws={'label': 'Valid (1) / Skip (0)'},
                    xticklabels=[f'C{i}' for i in range(n_pes)],
                    yticklabels=[f'R{i}' if i % 4 == 0 else '' for i in range(self.Ih)],
                    ax=ax, vmin=0, vmax=1, linewidths=0.5, linecolor='gray')

        ax.set_xlabel('Column')
        ax.set_ylabel('Row')
        ax.set_title(f'CMap Validity Matrix - Tile {tile_id}\n(Green=Valid, Red=Skip)')
        plt.tight_layout()
        plt.show()

        total_computations = self.Ih * n_pes
        valid_computations = np.sum(validity_matrix)
        skip_computations = total_computations - valid_computations
        # Guard against Ih == 0 so the percentages do not divide by zero.
        denom = total_computations if total_computations else 1

        print(f"\nTile {tile_id} Statistics:")
        print(f"  Total possible computations: {total_computations}")
        print(f"  Valid computations: {int(valid_computations)} ({valid_computations/denom*100:.1f}%)")
        print(f"  Skipped computations: {int(skip_computations)} ({skip_computations/denom*100:.1f}%)")


# =============================================================================
# USAGE EXAMPLES
# =============================================================================

# Layer d1: DeconvBlock1D with 256 -> 128 channels, kernel 4, stride 2, padding 1.
print("LAYER d1: DeconvBlock1D(256, 128, k=4, s=2, p=1)")
print("=" * 80)

d1_params = dict(
    Ih=32,   # input height (positions)
    Ic=256,  # input channels
    Oc=128,  # output channels
    Oh=64,   # output height
    Ks=4,    # kernel size
    S=2,     # stride
    P=1,     # padding
)
debugger = MM2IMMapperDebugger(**d1_params)

# Example 1: mapping table for row 0 of tile 0.
print("\n### EXAMPLE 1: Row 0, Tile 0 ###")
debugger.print_mapping(row_id=0, tile_id=0)

# Example 2: plots for the same row/tile.
print("\n### EXAMPLE 2: Visualisasi Row 0, Tile 0 ###")
debugger.visualize_mapping(row_id=0, tile_id=0)

# Example 3: row 1 of tile 0, table then plots, to compare with row 0.
print("\n### EXAMPLE 3: Row 1, Tile 0 ###")
for show in (debugger.print_mapping, debugger.visualize_mapping):
    show(row_id=1, tile_id=0)

# Example 4: partial-sum overlap report for tile 0.
print("\n### EXAMPLE 4: Overlap Detection (Partial Sum) ###")
debugger.print_overlap_summary(tile_id=0)

# Example 5: validity heatmap over every row of tile 0.
print("\n### EXAMPLE 5: Full Layer Visualization ###")
debugger.visualize_full_layer(tile_id=0)

# Example 6: dump the mappings to CSV and preview the first rows.
print("\n### EXAMPLE 6: Export Mappings ###")
df_all = debugger.export_all_mappings(output_file="layer_d1_mappings.csv")
print(df_all.head(20))
