In [None]:
import warnings
warnings.filterwarnings('ignore')

from dualneuron.twins.nets import load_model, model_summary
from dualneuron.synthesis.ascend import fourier_ascending, pixel_ascending
from dualneuron.synthesis.visualize import (
    blend, plot_group, plot_poles, sequence_animation
)

from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import numpy as np

# Compute device for this notebook: the first CUDA device when PyTorch can
# see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Build the 'v4' model and summarise it for a single 3x100x100 input.
# model_summary returns a (model, hooks) pair; `hooks` is not used later in
# this notebook.
model, hooks = model_summary(
    model='v4',
    input_size=(1, 3, 100, 100),
    # Use the auto-detected device ('cuda' or 'cpu') instead of hard-coding
    # 'cuda', so the notebook also runs on CPU-only machines. On a GPU box this
    # is exactly the string 'cuda', as before.
    device=device.type,
)

In [None]:
# Load the 'v4' network that the optimisation cells below treat as the
# objective: `function(x)` is indexed as `function(x)[:, neuron]` there.
# NOTE(review): the meaning of layer=None / centered=True is defined by
# dualneuron.twins.nets.load_model — confirm against its documentation.
function = load_model(
    architecture='v4',
    layer=None,
    ensemble=False,
    centered=True,
    untrained=False,  # load trained weights, not a random initialisation
    # Follow the auto-detected device rather than hard-coding 'cuda'; this is
    # still 'cuda' whenever a GPU is available.
    device=device.type,
)

In [None]:
# Optimise one image per (neuron, sign) pair. The objective is
# `weight * mean(function(x)[:, neuron])`, so weight=+1 and weight=-1 push the
# unit's response in opposite directions (presumably maximised by
# fourier_ascending, as its name suggests — confirm).
# `results` is filled in loop order: (4, -1), (4, +1), (6, -1), (6, +1).
NEURONS = [4, 6]
WEIGHTS = [-1, 1]
SEED = 0

lr = 0.5
nb_crops = 16
norm = None  # forwarded as target_norm

# The optimisation is stochastic (noise, jitter, random crops); seed the global
# RNGs so a Restart & Run All reproduces the same images.
# NOTE(review): assumes fourier_ascending draws from torch's / numpy's global
# RNG — confirm.
torch.manual_seed(SEED)
np.random.seed(SEED)


def make_objective(neuron, weight):
    """Return an objective computing `weight * mean(function(x)[:, neuron])`.

    The arguments are bound at creation time instead of being captured from
    the enclosing loop, so each objective stays fixed even if the optimiser
    keeps the callable around and invokes it later.
    """
    def objective(x):
        return weight * torch.mean(function(x)[:, neuron])
    return objective


results = []

for neuron in tqdm(NEURONS):
    for weight in WEIGHTS:
        # Optional: to warm-start the weight=+1 run from the preceding
        # weight=-1 run, pass init_image=results[-1]['image'][-1].
        result = fourier_ascending(
            make_objective(neuron, weight),
            magnitude_path='natural_rgb.npy',  # relative to the working dir
            image_size=None,
            total_steps=128,
            learning_rate=lr,
            lr_schedule=True,
            eta_min=0.0,
            noise=0.1,
            values_range=(-2.0, 2.0),
            range_fn='tanh',
            nb_crops=nb_crops,
            box_size=(0.1, 1.0),
            target_norm=norm,
            tv_weight=0.0,
            jitter_std=0.15,
            oversample=1,
            reflect_pad_frac=0.15,
            device=device.type,  # auto-detected instead of hard-coded 'cuda'
            verbose=False,
            save_all_steps=True,  # keep every step for the animation cell
        )
        results.append(result)

In [None]:
# For every run, blend the final optimised image with its final alpha map
# into a single "pole" picture. cols=2 matches the two weights per neuron, so
# (presumably) each row shows one neuron's -1 / +1 pair.
poles = [
    blend(
        run['image'][-1],
        run['alpha'][-1],
        imagecut=0.0,
        alphacut=90.0,
        boost=1.0,
    )
    for run in results
]

plot_group(poles, cols=2)

In [None]:
# Animate the optimisation trajectory of ONE run. Both the frames and the
# activation trace must come from the same entry of `results`. Previously the
# frames came from results[1] but the trace from results[-1], a different
# neuron. With the loop order (4, -1), (4, +1), (6, -1), (6, +1), index 1 is
# (neuron 4, weight +1).
ANIM_RUN = 1
# NOTE(review): the activation trace is multiplied by 10 before plotting,
# presumably a display scale for sequence_animation — confirm.
ACTIVATION_SCALE = 10

anim_result = results[ANIM_RUN]

# One blended frame per saved optimisation step (same blend settings as the
# final-pole cell).
sequence = [
    blend(img, alp, imagecut=0.0, alphacut=90.0, boost=1.0)
    for img, alp in zip(anim_result['image'], anim_result['alpha'])
]

sequence_animation(
    sequence,
    np.array(anim_result['activation']) * ACTIVATION_SCALE,
    savename=None,
    dpi=50,
    interval=50,
    title="Optimization Progress",
)