In [None]:
from init_notebook import *

In [None]:
from torchvision import models 

In [None]:
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1, progress=True)
# The network is only used to optimize input images, never trained. Eval mode makes
# BatchNorm use its pretrained running statistics instead of overwriting them
# with statistics of the synthetic batches fed to it below.
model.eval()

VF.to_pil_image(get_model_weight_images(model, normalize="each"))

In [None]:
import ipywidgets

def optimize_image(
    image: torch.Tensor,
    model: nn.Module,
    target: torch.Tensor,
    steps: int = 1000,
    batch_size: int = 4,
    device: str = "auto",
    rand_degrees: float = 5.,
    rand_translate: float = 0.1,
    lr: float = 0.001,
    display_every: int = 50,
):
    """Optimize an input image so the model's softmax output approaches ``target``.

    Every step builds a batch of ``batch_size`` randomly affine-jittered copies
    of the image, runs it through ``model``, and minimizes the L1 distance
    between the softmax output and ``target`` with Adam. Pixels are clamped to
    [0, 1] after each update. A 2x preview is shown in an ipywidgets Output
    every ``display_every`` steps. Interrupting the kernel (KeyboardInterrupt)
    stops early and still returns the current image.

    Args:
        image: starting image; presumably (C, H, W) in [0, 1] -- TODO confirm.
        model: classifier; moved to ``device`` and switched to eval mode so
            BatchNorm running statistics are not modified.
        target: 1-D tensor with one entry per model output (e.g. one-hot).
        steps: maximum number of optimization steps.
        batch_size: number of jittered copies per step.
        device: device spec passed to ``to_torch_device``.
        rand_degrees: max rotation for ``RandomAffine``.
        rand_translate: max relative translation for ``RandomAffine``.
        lr: Adam learning rate.
        display_every: preview interval in steps.

    Returns:
        The optimized image as a detached CPU tensor.

    Raises:
        ValueError: if the model output shape differs from the repeated target.
    """
    output_widget = ipywidgets.Output()
    display(output_widget)

    device = to_torch_device(device)
    image = nn.Parameter(image.clone().to(device))
    # Each batch entry is transformed separately, so it gets its own random jitter.
    trans = VT.Compose([
        VT.RandomAffine(degrees=rand_degrees, translate=(rand_translate, rand_translate)),
    ])

    # One copy of the (assumed 1-D) target per batch entry.
    target = target.unsqueeze(0).repeat(batch_size, 1).to(device)
    model.to(device)
    # Eval mode: BatchNorm must not learn statistics from the synthetic batches.
    model.eval()
    optim = torch.optim.Adam([image], lr=lr)
    loss_func = nn.L1Loss()

    progress = tqdm(range(steps))
    try:
        for step_idx in progress:
            batch = torch.concat([
                trans(image).unsqueeze(0)
                for _ in range(batch_size)
            ])
            # The optimizer owns only ``image``; model.zero_grad() alone would let
            # image.grad accumulate across steps. Model grads are cleared too so
            # they do not pile up in memory.
            optim.zero_grad()
            model.zero_grad()

            output = F.softmax(model(batch), dim=1)
            if output.shape != target.shape:
                raise ValueError(f"output={output.shape}, target={target.shape}")

            loss = loss_func(output, target)
            progress.set_postfix({"loss": loss.item()})
            loss.backward()
            optim.step()
            # Keep pixel values in the displayable [0, 1] range.
            with torch.no_grad():
                image.clamp_(0, 1)

            if step_idx % display_every == 0:
                with output_widget:
                    output_widget.clear_output()
                    display(VF.to_pil_image(resize(image.detach(), 2)))
    except KeyboardInterrupt:
        # Deliberate: a kernel interrupt ends the run but keeps the result.
        pass
    return image.detach().cpu()
        
# Flat 0.3-grey start image: the random draw is multiplied by zero.
init_image = torch.rand(3, 128, 128) * .0 + .3
# One-hot target over the 1000 outputs, selecting class index 0.
target_probs = torch.zeros(1000)
target_probs[0] = 1
img = optimize_image(init_image, model, target_probs, batch_size=16)

In [None]:
import ipywidgets

def optimize_image_2(
    image: torch.Tensor,
    model: nn.Module,
    steps: int = 1000,
    batch_size: int = 4,
    device: str = "auto",
    rand_degrees: float = 5.,
    rand_translate: float = 0.01,
    layer_name: str = "layer3.3.conv1",
    channel: int = 9,
    lr: float = 0.001,
    display_every: int = 50,
):
    """Optimize an input image towards a one-hot activation pattern at a layer.

    A forward hook on ``layer_name`` (looked up via ``find_module_layer``)
    captures the tensor passed *into* that layer. The L1 loss pulls that tensor
    towards zeros everywhere except ``[:, channel, 0, 0]``, which is pulled
    towards 1. Batches consist of ``batch_size`` copies of the image, each with
    its own random affine jitter and random erasing. Pixels are clamped to
    [0, 1] after each Adam step. A 2x preview is shown every ``display_every``
    steps. A KeyboardInterrupt stops early and still returns the image. The
    hook is always removed from the model before returning or raising.

    Args:
        image: starting image; presumably (C, H, W) in [0, 1] -- TODO confirm.
        model: network containing ``layer_name``; moved to ``device`` and
            switched to eval mode so BatchNorm statistics are not modified.
        steps: maximum number of optimization steps.
        batch_size: number of augmented copies per step.
        device: device spec passed to ``to_torch_device``.
        rand_degrees: max rotation for ``RandomAffine``.
        rand_translate: max relative translation for ``RandomAffine``.
        layer_name: name of the hooked layer.
        channel: channel whose (0, 0) position is driven towards 1.
        lr: Adam learning rate.
        display_every: preview interval in steps.

    Returns:
        The optimized image as a detached CPU tensor.

    Raises:
        RuntimeError: if the hooked layer is not executed by the forward pass.
    """
    output_widget = ipywidgets.Output()
    display(output_widget)

    device = to_torch_device(device)
    image = nn.Parameter(image.clone().to(device))
    trans = VT.Compose([
        VT.RandomAffine(degrees=rand_degrees, translate=(rand_translate, rand_translate)),
        VT.RandomErasing(),
    ])

    model.to(device)
    # Eval mode: BatchNorm must not learn statistics from the synthetic batches.
    model.eval()
    optim = torch.optim.Adam([image], lr=lr)
    loss_func = nn.L1Loss()

    layer = find_module_layer(model, layer_name)

    layer_snapshot = None

    def _hook(module, args, output):
        # Forward hooks receive (module, inputs, output). The layer *input* is
        # captured, as in the original notebook.
        # NOTE(review): confirm whether the layer output was intended instead.
        nonlocal layer_snapshot
        layer_snapshot = args[0]

    hook = layer.register_forward_hook(_hook)

    progress = tqdm(range(steps))
    try:
        for step_idx in progress:
            batch = torch.concat([
                trans(image).unsqueeze(0)
                for _ in range(batch_size)
            ])
            # The optimizer owns only ``image``; model.zero_grad() alone would let
            # image.grad accumulate across steps.
            optim.zero_grad()
            model.zero_grad()

            layer_snapshot = None
            model(batch)  # run only for the hook's side effect
            if layer_snapshot is None:
                raise RuntimeError(f"layer {layer_name!r} was not called during the forward pass")

            target_layer = torch.zeros_like(layer_snapshot)
            target_layer[:, channel, 0, 0] = 1
            loss = loss_func(layer_snapshot, target_layer)
            progress.set_postfix({"loss": loss.item()})
            loss.backward()
            optim.step()
            # Keep pixel values in the displayable [0, 1] range.
            with torch.no_grad():
                image.clamp_(0, 1)

            if step_idx % display_every == 0:
                with output_widget:
                    output_widget.clear_output()
                    display(VF.to_pil_image(resize(image.detach(), 2)))
    except KeyboardInterrupt:
        # Deliberate: a kernel interrupt ends the run but keeps the result.
        pass
    finally:
        # Never leave the hook attached to the shared model, even on errors.
        hook.remove()

    return image.detach().cpu()
        
# Flat 0.3-grey start image: the random draw is multiplied by zero.
init_image = torch.rand(3, 128, 128) * .0 + .3
img = optimize_image_2(init_image, model, batch_size=16)

In [None]:
# List the names yielded by iter_module_layers(model); the layer objects are ignored.
[name for name, _layer in iter_module_layers(model)]