In [None]:
import warnings
# NOTE: this silences *every* warning for the whole kernel session (library
# deprecations included); narrow it to specific categories when debugging.
warnings.filterwarnings('ignore')

from dualneuron.screening.sets import ImagenetImages
from dualneuron.screening.utils import load_poles
from dualneuron.dream.axis import semantic_axis
import dualneuron

import numpy as np
import os
from pathlib import Path
import matplotlib.pyplot as plt

from dotenv import load_dotenv
load_dotenv()

# Secrets and locations come from the environment (or a local .env file),
# never from hard-coded strings.
token = os.getenv("HF_TOKEN")
data_dir = os.getenv("DATA_DIR")
if data_dir is None:
    raise RuntimeError(
        "DATA_DIR is not set; define it in the environment or in a .env file."
    )
# Later cells build paths as `data_dir + "<subdir>"`; guarantee a trailing
# separator so "/data" and "/data/" both yield "/data/<subdir>".
data_dir = os.path.join(data_dir, "")

In [None]:
# The mask for the V4ColorTaskDriven twin is shipped inside the installed
# dualneuron package, next to the model code.
package_dir = Path(dualneuron.__file__).parent
mask_path = package_dir.joinpath("twins", "V4ColorTaskDriven", "mask.npy")
mask = np.load(mask_path)

In [None]:
# ImageNet training split with the preprocessing flags used for the V4 twin.
# NOTE(review): the flags presumably center-crop to `crop_size`, resize to
# `output_size`, apply `mask` with `bg_value` outside it, and clip/normalise
# to `norm` — confirm against ImagenetImages.
dset = ImagenetImages(
    # os.path.join works whether or not DATA_DIR ends with a separator.
    data_dir=os.path.join(data_dir, "datasets"),
    token=token,
    split="train",
    use_center_crop=True,
    use_resize_output=True,
    use_grayscale=False,
    use_normalize=False,
    use_mask=True,
    use_norm=True,
    use_clip=True,
    mask=mask,
    num_channels=3,
    output_size=(224, 224),
    crop_size=236,
    bg_value=0.0,
    clip_min=0.0,
    clip_max=1.0,
    norm=80.0,
)

In [None]:
# Retrieve k images at each pole for neuron 6 from precomputed orderings;
# the two returned sets are used below as LAIs (least) and MAIs (most
# activating) images.
lais, mais = load_poles(
    neuron_id=6,
    dset=dset,
    # os.path.join works whether or not DATA_DIR ends with a separator.
    idx_dir=os.path.join(data_dir, 'v4_imagenet_ordered_indices'),
    k=10,
    pole='both'
)

In [None]:
from dreamsim import dreamsim
import torch

# DreamSim model of type 'dinov2_vitb14'; it is passed to semantic_axis below
# together with the two pole sets.
dreamsim_model, _ = dreamsim(
    pretrained=True,
    # Fall back to the CPU instead of hard-coding 'cuda', so the notebook
    # also runs on machines without a GPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    dreamsim_type='dinov2_vitb14',
    # os.path.join works whether or not DATA_DIR ends with a separator.
    cache_dir=os.path.join(data_dir, 'dreamsim_models'),
)

# NOTE(review): `axis` is not referenced by any later cell in this notebook
# (the optimisation passes simulation_axis=None) — confirm that is intended.
axis = semantic_axis(mais, lais, dreamsim_model)

In [None]:
# Side-by-side view of the retrieved poles: LAIs in the top row, MAIs in the
# bottom row. The column count follows the number of images actually loaded,
# so changing `k` in load_poles does not break this cell.
n_cols = min(len(lais), len(mais))
# squeeze=False keeps `axs` two-dimensional even when n_cols == 1.
fig, axs = plt.subplots(2, n_cols, figsize=(2 * n_cols, 4), squeeze=False)

for i in range(n_cols):
    for row, (images, label) in enumerate(((lais, 'LAI'), (mais, 'MAI'))):
        ax = axs[row, i]
        # permute(1, 2, 0) moves the leading (assumed channel) axis last,
        # the layout imshow expects.
        ax.imshow(images[i].permute(1, 2, 0))
        ax.axis('off')
        ax.set_title(f'{label} {i+1}')

plt.tight_layout()
plt.show()

In [None]:
from dualneuron.twins.nets import V4ColorTaskDriven
from dualneuron.twins.activations import WrapLayer, model_summary

from dualneuron.synthesis.ascend import fourier_ascending
from dualneuron.synthesis.visualize import (
    blend, plot_group, plot_poles, sequence_animation
)

import torch
import numpy as np

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Instantiate the V4 twin (centered, non-ensemble), switch it to eval mode and
# move it to the selected device.
function = V4ColorTaskDriven(centered=True, ensemble=False)
function = function.eval().to(device)

In [None]:
layer = WrapLayer(model=function, layer=function)
# One forward pass on random input to read the output shape.
shape = layer(torch.randn(1, 3, 100, 100).to(device)).shape
# NOTE(review): `units` is not used by the optimisation below (the objective
# indexes TARGET_UNIT) and is drawn without a seed; kept in case other cells
# rely on it.
units = np.random.choice(shape[1], size=30, replace=False)

TARGET_UNIT = 35  # index along dim 1 of the wrapped output that is optimised
lr = 1.0
nb_crops = 16
norm = 40.0

# The objective is act_weight * mean(output[:, TARGET_UNIT]); if
# fourier_ascending maximises its objective (confirm), a negative weight drives
# that unit's activation down.
act_weight = -1.0
sim_weights = [-2.0, -1.0, 1.0, 2.0]
results = []

try:
    # One optimisation run per similarity weight.
    for sim_weight in sim_weights:
        result = fourier_ascending(
            lambda x: act_weight * torch.mean(layer(x)[:, TARGET_UNIT]),
            magnitude_path='natural_rgb.npy',
            image_size=None,
            init_image=None,
            total_steps=128,
            learning_rate=lr,
            lr_schedule=True,
            noise=0.0,
            values_range=(-2.0, 2.0),
            range_fn='sigmoid',
            nb_crops=nb_crops,
            box_size=(1.0, 1.0),
            target_norm=norm,
            tv_weight=0.0,
            jitter_std=0.2,
            oversample=1,
            reflect_pad_frac=0.2,
            simulation_function=lambda x: dreamsim_model.embed(x),
            # NOTE(review): the semantic `axis` computed from the MAI/LAI poles
            # is never passed here; confirm whether simulation_axis=axis was
            # intended.
            simulation_axis=None,
            simulation_weight=sim_weight,
            # Use the device selected earlier instead of hard-coding 'cuda', so
            # this cell also runs on CPU-only machines.
            device=str(device),
            verbose=True,
            save_all_steps=True,
        )
        results.append(result)
finally:
    # Run the wrapper's cleanup even if one of the optimisation runs raises.
    layer.remove()

In [None]:
# Blend the final frame of every run with its final alpha map; one panel per
# simulation weight, laid out four per row.
poles = [
    blend(run['image'][-1], run['alpha'][-1], imagecut=0.0, alphacut=90.0, boost=1.2)
    for run in results
]

plot_group(poles, cols=4)

In [None]:
# Animate the optimisation trajectory of a single run (index into `results`).
run_index = 0
run = results[run_index]

# One blended frame per saved optimisation step of that run.
sequence = [
    blend(img, alp, imagecut=0.0, alphacut=90.0, boost=1.2)
    for img, alp in zip(run['image'], run['alpha'])
]

sequence_animation(
    sequence,
    # The activation trace must come from the SAME run as the frames. Reading
    # the leftover loop variable `result` would pick the last run instead.
    # The x10 factor is kept as-is (presumably a display scaling).
    np.array(run['activation']) * 10,
    savename=None,
    dpi=50,
    interval=50,
    title="Optimization Progress"
)