In [None]:
import os
from dotenv import load_dotenv

# Populate os.environ from a local .env file (if one is found); HF_TOKEN and
# DATA_DIR are read from the environment at the end of this cell.
load_dotenv()
# Set before the dualneuron imports on purpose. NOTE(review): presumably these
# modules pull in huggingface_hub, which reads this flag when it is imported —
# confirm before moving this line below the imports.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

from dualneuron.screening.sets import ImagenetImages
from dualneuron.screening.run import screen_activations

from dualneuron.screening.visualize import (
    plot_population_statistics,
    plot_neuron_activation,
    plot_neuron_poles,
)

from dualneuron.screening.utils import compute_population_statistics
from dualneuron.twins.nets import model_summary

from pathlib import Path
import dualneuron
import numpy as np

import matplotlib.pyplot as plt

def read_dir_env(name):
    """Return environment variable ``name`` as a directory prefix ending in a separator.

    The notebook builds paths by plain string concatenation
    (e.g. ``data_dir + "datasets"``), which silently produces a wrong path such
    as ``/datadatasets`` when the configured value has no trailing slash.
    Values that already end in a separator are returned unchanged, as are
    unset (``None``) and empty (``""``) values.
    """
    value = os.getenv(name)
    if not value:
        return value
    # os.path.join(x, "") appends a separator only if x does not already end in one.
    return os.path.join(value, "")


token = os.getenv("HF_TOKEN")
data_dir = read_dir_env("DATA_DIR")

In [None]:
architecture = "resnet50"
# Summarise the backbone for one 224x224 image with 3 channels
# (presumably NCHW ordering — confirm against model_summary).
model, hooks = model_summary(
    architecture,
    input_size=(1, 3, 224, 224),
)

In [None]:
architecture = "resnet50"
layer = "layer3.1.relu"
dataset = "imagenet"

# Screening runs the whole test split through the network on CUDA (presumably
# slow). Its outputs under data_dir + "dryad" are read by later cells, so it is
# off by default; set this to True to regenerate them.
RUN_SCREENING = False

if RUN_SCREENING:
    screen_activations(
        data_dir + "datasets",
        output_dir=data_dir + "dryad",
        token=token,
        split="test",
        dataset=dataset,
        model=architecture,
        layer=layer,
        location="center",
        ensemble=True,
        batch_size=32,
        num_workers=0,
        device="cuda",
    )

In [None]:
model_name = "V4ColorTaskDriven"
# The mask is shipped inside the installed dualneuron package, under
# twins/<model_name>/mask.npy.
package_dir = Path(dualneuron.__file__).parent
mask_path = package_dir.joinpath("twins", model_name, "mask.npy")
mask = np.load(mask_path)

In [None]:
dset = ImagenetImages(
    # source
    data_dir=data_dir + "datasets",
    token=token,
    split="test",
    # geometry: presumably a 236 px centre crop resized to 224x224
    use_center_crop=True,
    crop_size=236,
    use_resize_output=True,
    output_size=(224, 224),
    # colour / intensity handling (all normalisation steps disabled)
    num_channels=3,
    use_grayscale=False,
    use_normalize=False,
    use_norm=False,
    norm=None,
    use_clip=False,
    # masking with the twin-model mask; bg_value is presumably the fill for
    # masked-out pixels — confirm in ImagenetImages
    use_mask=True,
    mask=mask,
    bg_value=0.0,
)

In [None]:
# Screening outputs live under <data_dir>/dryad/, one "<neuron_id>.npy" file per
# neuron in each directory. os.path.join keeps these paths correct whether or
# not DATA_DIR ends with a path separator (plain concatenation did not).
screen_prefix = f"{architecture}_{layer}_{dataset}"
idx_dir = os.path.join(data_dir, "dryad", f"{screen_prefix}_ordered_indices")
resp_dir = os.path.join(data_dir, "dryad", f"{screen_prefix}_ordered_responses")
num_neurons = sum(1 for name in os.listdir(resp_dir) if name.endswith('.npy'))
print(f"Number of .npy files in resp_dir: {num_neurons}")

In [None]:
# Summarise every neuron's response distribution; `active_neurons` is
# presumably ordered by the Gini coefficient requested here.
response_stats, active_neurons = compute_population_statistics(
    resp_dir,
    sort_by="gini",
)
plot_population_statistics(response_stats)

In [None]:
target_neuron = 23
# Position of the target in `active_neurons` (its rank, presumably by Gini).
# Assuming `active_neurons` is a list, .index() raises ValueError when the
# target is not among them, which stops the cell early.
unit = active_neurons.index(target_neuron)
inspected_neuron = active_neurons[unit]

plot_neuron_activation(inspected_neuron, resp_dir, response_stats)
plot_neuron_poles(inspected_neuron, dset, resp_dir, idx_dir)

In [None]:
def _adaptive_sample_positions(sorted_responses, num_samples, seed):
    """Pick positions on a sorted response curve, favouring its steepest transitions."""
    diffs = np.abs(np.diff(sorted_responses))
    total = np.sum(diffs)
    if not np.isfinite(total) or total <= 0:
        raise ValueError(
            "Cannot sample adaptively: the sorted response curve has no finite, "
            "non-zero transitions."
        )
    probs = diffs / total

    # rng.choice(..., replace=False) needs at least num_samples non-zero weights.
    available = int(np.count_nonzero(probs))
    if num_samples > available:
        print(
            f"Only {available} non-zero transitions available; "
            f"reducing num_samples from {num_samples} to {available}."
        )
        num_samples = available

    rng = np.random.default_rng(seed=seed)
    transitions = rng.choice(len(probs), num_samples, p=probs, replace=False)
    # Transition i is the jump from position i to i + 1; show the image right after it.
    return np.sort(transitions + 1)


def _to_display_image(img, vmin, vmax):
    """Convert one dataset image to an HxW / HxWxC float array scaled to [0, 1]."""
    import torch

    if torch.is_tensor(img):
        img = img.cpu().numpy()

    if img.ndim == 3:
        # Treat a leading axis of size 1 or 3 as channels-first and move it last for imshow.
        if img.shape[0] in (1, 3):
            img = np.transpose(img, (1, 2, 0))
        if img.shape[2] == 1:
            img = img.squeeze(2)

    if vmin is not None and vmax is not None:
        img = np.clip(img, vmin, vmax)
        return (img - vmin) / (vmax - vmin)
    # Per-image min-max scaling; the epsilon avoids dividing by zero on constant images.
    return (img - img.min()) / (img.max() - img.min() + 1e-8)


def _render_animation_frame(
    neuron_id,
    image,
    frame_idx,
    num_frames,
    sorted_responses,
    sampled_positions,
    sampled_activations,
    figsize,
    dpi,
):
    """Draw one animation frame on a private Agg canvas and return it as HxWx4 RGBA uint8."""
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    # A standalone Figure + Agg canvas renders off-screen without touching
    # pyplot's global backend or figure registry, so inline notebook plotting
    # keeps working after this function runs.
    fig = Figure(figsize=figsize, dpi=dpi, facecolor='black')
    canvas = FigureCanvasAgg(fig)

    current_activation = sampled_activations[frame_idx]
    current_position = sampled_positions[frame_idx]
    value_color = '#00d4ff' if current_activation >= 0 else '#ff0080'

    # Left panel: the sampled image, its frame counter and its activation.
    ax_img = fig.add_subplot(1, 2, 1, aspect='equal')
    ax_img.set_axis_off()
    ax_img.set_facecolor('black')
    ax_img.imshow(image, cmap='gray', vmin=0, vmax=1)

    ax_img.text(
        0.5, 1.05, f'Sample {frame_idx + 1}/{num_frames}',
        transform=ax_img.transAxes,
        color='white', fontsize=10, ha='center', va='bottom',
        weight='bold'
    )
    ax_img.text(
        0.5, -0.05, f'{current_activation:.4f} Hz',
        transform=ax_img.transAxes,
        color=value_color, fontsize=10, ha='center', va='top',
        weight='bold',
        bbox=dict(
            boxstyle='round,pad=0.4',
            facecolor='#0a0a0a',
            edgecolor=value_color,
            linewidth=2
        )
    )

    # Right panel: full sorted curve, all samples (faded), visited samples
    # (brighter) and the current sample (glow + outlined marker).
    ax_plot = fig.add_subplot(1, 2, 2)
    ax_plot.set_facecolor('#0a0a0a')

    ax_plot.plot(sorted_responses, color='#00d4ff', linewidth=2, alpha=0.5, zorder=1)
    ax_plot.fill_between(
        range(len(sorted_responses)),
        sorted_responses,
        color='#00d4ff',
        alpha=0.2,
        zorder=1
    )
    ax_plot.scatter(
        sampled_positions,
        sampled_activations,
        c='#ff0080',
        s=30,
        alpha=0.3,
        zorder=2
    )
    if frame_idx > 0:
        ax_plot.scatter(
            sampled_positions[:frame_idx],
            sampled_activations[:frame_idx],
            c='#ff0080',
            s=50,
            alpha=0.6,
            zorder=3
        )
    ax_plot.scatter(
        current_position,
        current_activation,
        c=value_color,
        s=300,
        alpha=0.3,
        zorder=4
    )
    ax_plot.scatter(
        current_position,
        current_activation,
        c=value_color,
        s=150,
        marker='o',
        edgecolors='white',
        linewidths=2,
        zorder=5
    )

    ax_plot.set_xlabel('Sorted Image Index', color='white', fontsize=9, weight='bold')
    ax_plot.set_ylabel('Activation (Hz)', color='white', fontsize=9, weight='bold')
    ax_plot.tick_params(colors='white', labelsize=8)
    for spine in ax_plot.spines.values():
        spine.set_color('#00d4ff')
        spine.set_linewidth(2)
    ax_plot.grid(True, alpha=0.2, color='#00d4ff', linestyle='--', linewidth=0.8)

    fig.suptitle(
        f'Neuron {neuron_id}',
        color='white',
        fontsize=12,
        weight='bold'
    )
    fig.tight_layout()

    canvas.draw()
    # Copy so the frame does not alias the canvas buffer once the figure is discarded.
    return np.array(canvas.buffer_rgba())


def sampled_images_animation(
    neuron_id,
    dset,
    resp_dir,
    idx_dir,
    num_samples=100,
    savename='neuron_animation.mp4',
    fps=5,
    figsize=(10, 5),
    dpi=100,
    vmin=None,
    vmax=None,
    seed=None,
):
    """
    Create an MP4 showing adaptively sampled images and their position on the activation curve.

    Positions along the neuron's sorted response curve are drawn without
    replacement, with probability proportional to the absolute jump between
    consecutive sorted responses, so steep parts of the curve get more samples.
    Each frame shows the sampled image (left) and the sorted curve with the
    current sample highlighted (right). Frames are rendered off-screen on a
    private Agg canvas (the global matplotlib backend is not changed) and
    streamed straight into an OpenCV ``VideoWriter``.

    Args:
        neuron_id: Neuron index; ``f"{neuron_id}.npy"`` is loaded from both
            ``resp_dir`` and ``idx_dir``.
        dset: Indexable dataset; ``dset[i]`` must return a 2-tuple whose first
            item is the image (torch tensor or numpy array, channels-first or
            channels-last).
        resp_dir: Directory of per-neuron responses, already sorted.
        idx_dir: Directory of per-neuron dataset indices, position-aligned with
            the sorted responses.
        num_samples: Requested number of frames. Clamped (with a message) to
            the number of non-zero transitions in the curve.
        savename: Output video path.
        fps: Frames per second of the output video.
        figsize: Matplotlib figure size in inches.
        dpi: Figure resolution; each frame is ``figsize * dpi`` pixels.
        vmin, vmax: When both are given, images are clipped to this range and
            scaled to [0, 1]; otherwise each image is min-max scaled on its own.
        seed: Seed for the sampling RNG. ``None`` keeps the historical
            behaviour of seeding with ``num_samples``.

    Raises:
        ValueError: if ``num_samples`` is not positive or the curve is flat.
        OSError: if OpenCV cannot open ``savename`` for writing.
    """
    import cv2  # imported up front so a missing OpenCV fails before any heavy work

    if num_samples < 1:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    print("Loading data...")
    # Both arrays are already sorted by response and aligned position-by-position.
    sorted_responses = np.load(os.path.join(resp_dir, f"{neuron_id}.npy"))
    sorted_dataset_indices = np.load(os.path.join(idx_dir, f"{neuron_id}.npy"))

    rng_seed = num_samples if seed is None else seed
    sampled_positions = _adaptive_sample_positions(sorted_responses, num_samples, rng_seed)
    num_frames = len(sampled_positions)
    sampled_dataset_idx = sorted_dataset_indices[sampled_positions]
    sampled_activations = sorted_responses[sampled_positions]

    print(f"Loading {num_frames} images...")
    images = []
    for idx in sampled_dataset_idx:
        img, _ = dset[idx]
        images.append(_to_display_image(img, vmin, vmax))

    print("Generating frames...")
    writer = None
    try:
        for frame_idx in range(num_frames):
            rgba = _render_animation_frame(
                neuron_id,
                images[frame_idx],
                frame_idx,
                num_frames,
                sorted_responses,
                sampled_positions,
                sampled_activations,
                figsize,
                dpi,
            )
            frame_bgr = cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGR)  # OpenCV expects BGR

            if writer is None:
                # Size the video from the first rendered frame.
                height, width = frame_bgr.shape[:2]
                print(f"Writing video to {savename}...")
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                writer = cv2.VideoWriter(savename, fourcc, fps, (width, height))
                if not writer.isOpened():
                    raise OSError(f"OpenCV could not open {savename!r} for writing")

            writer.write(frame_bgr)

            if (frame_idx + 1) % 10 == 0:
                print(f"  Generated {frame_idx + 1}/{num_frames} frames...")
    finally:
        # Release even on error so the container is finalised and the file handle freed.
        if writer is not None:
            writer.release()

    print(f"Done! Saved to {savename}")
    print(f"Duration: {num_frames / fps:.1f} seconds at {fps} fps")

In [None]:
animation_neuron = 4
sampled_images_animation(
    animation_neuron,
    dset,
    resp_dir,
    idx_dir,
    num_samples=128,
    savename=f"neuron_{animation_neuron}.mp4",
    fps=5,
    dpi=100,
)