diff --git a/captum/optim/README.md b/captum/optim/README.md new file mode 100644 index 0000000000..2d3be54415 --- /dev/null +++ b/captum/optim/README.md @@ -0,0 +1,13 @@ +# Captum "optim" module + +This is a WIP PR to integrate existing feature visualization code from the authors of `tensorflow/lucid` into Captum. +It is also an opportunity to review which parts of such interpretability tools still feel rough to implement in a system like PyTorch, and to make suggestions to the core PyTorch team for how to improve these aspects. + +## Roadmap + +* unify API with Captum API: a single class that's callable per "technique" (check for details before implementing) +* Consider if we need an abstraction around "an optimization process" (in terms of stopping criteria, reporting losses, etc.) or if there are sufficiently strong conventions in PyTorch land for such tasks +* integrate Eli's FFT param changes (mostly for simplification) +* make a table of PyTorch interpretability tools for the readme? +* do we need image viewing helpers and io helpers, or should we throw those out? +* can we integrate paper references more closely with the code? \ No newline at end of file diff --git a/captum/optim/__init__.py b/captum/optim/__init__.py new file mode 100644 index 0000000000..ce86513f77 --- /dev/null +++ b/captum/optim/__init__.py @@ -0,0 +1,27 @@ +from typing import Dict, Optional, Union, Callable, Iterable +from typing_extensions import Protocol + +import torch +import torch.nn as nn + +ParametersForOptimizers = Iterable[Union[torch.Tensor, Dict[str, torch.Tensor]]] + + +class HasLoss(Protocol): + def loss(self) -> torch.Tensor: + ... + + +class Parameterized(Protocol): + parameters: ParametersForOptimizers + + +class Objective(Parameterized, HasLoss): + def cleanup(self): + pass + + +ModuleOutputMapping = Dict[nn.Module, Optional[torch.Tensor]] + +StopCriteria = Callable[[int, Objective, Iterable, torch.optim.Optimizer], bool] + diff --git a/captum/optim/_scrap_and_testing.py b/captum/optim/_scrap_and_testing.py new file mode 100644 index 0000000000..705bafe356 --- /dev/null +++ b/captum/optim/_scrap_and_testing.py @@ -0,0 +1,139 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import requests +from PIL import Image +from IPython.display import display + +from clarity.pytorch.inception_v1 import googlenet +from lucid.misc.io import show, load, save +from lucid.modelzoo.other_models import InceptionV1 + +# get a test image +img_url = ( + "https://lucid-static.storage.googleapis.com/building-blocks/examples/dog_cat.png" +) +img_tf = load(img_url) +img_pt = torch.as_tensor(img_tf.transpose(2, 0, 1))[None, ...]
+img_pil = Image.open(requests.get(img_url, stream=True).raw) + +# instantiate ported model +net = googlenet(pretrained=True) + +# get predictions +out = net(img_pt) +logits = out.detach().numpy()[0] +top_k = np.argsort(-logits)[:5] + +# load labels +labels = load(InceptionV1.labels_path, split=True) + +# show predictions +for i, k in enumerate(top_k): + prediction = logits[k] + label = labels[k] + print(f"{i}: {label} ({prediction*100:.2f}%)") + +# transforms + + +# def build_grid(source_size, target_size): +# k = float(target_size) / float(source_size) +# direct = ( +# torch.linspace(0, k, target_size) +# .unsqueeze(0) +# .repeat(target_size, 1) +# .unsqueeze(-1) +# ) +# full = torch.cat([direct, direct.transpose(1, 0)], dim=2).unsqueeze(0) +# return full.cuda() + + +# def random_crop_grid(x, grid): +# d = x.size(2) - grid.size(1) +# grid = grid.repeat(x.size(0), 1, 1, 1).cuda() +# # Add random shifts by x +# grid[:, :, :, 0] += torch.FloatTensor(x.size(0)).cuda().random_(0, d).unsqueeze( +# -1 +# ).unsqueeze(-1).expand(-1, grid.size(1), grid.size(2)) / x.size(2) +# # Add random shifts by y +# grid[:, :, :, 1] += torch.FloatTensor(x.size(0)).cuda().random_(0, d).unsqueeze( +# -1 +# ).unsqueeze(-1).expand(-1, grid.size(1), grid.size(2)) / x.size(2) +# return grid + + +# # We want to crop a 80x80 image randomly for our batch +# # Building central crop of 80 pixel size +# grid_source = build_grid(224, 80) +# # Make radom shift for each batch +# grid_shifted = random_crop_grid(batch, grid_source) +# # Sample using grid sample +# sampled_batch = F.grid_sample(batch, grid_shifted) + + +from clarity.pytorch.transform import RandomSpatialJitter, RandomUpsample + +# crop = torchvision.transforms.RandomCrop( +# 224, padding=34, pad_if_needed=True, padding_mode="reflect" +# ) +jitter = RandomSpatialJitter(16) +ups = RandomUpsample() +for i in range(10): + cropped = ups(img_pt) + show(cropped.numpy()[0].transpose(1, 2, 0)) + # display(cropped) + + +# result = param().cpu().detach().numpy()[0].transpose(1, 2, 0) +# loss_curve = objective.history + +# 2019-11-21 notes from Pytorch team +# Set up model +# net = googlenet(pretrained=True) +# parameterization = Image() # TODO: make size adjustable, currently hardcoded +# input_image = parameterization() + +# writer = SummaryWriter() +# writer.add_graph(net, (input_image,)) +# writer.close() + +# Specify target module / "objective" +# target_module = net.mixed3b._pool_reduce[1] +# target_channel = 54 +# hook = OutputHook(target_module) # TODO: investigate detach on rerun +# parameterization = Image() # TODO: make size adjustable, currently hardcoded +# optimizer = optim.Adam(parameterization.parameters, lr=0.025) + +# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# net = net.to(device) +# parameterization = parameterization.to(device) +# for i in range(1000): +# optimizer.zero_grad() + +# # forward pass through entire net +# input_image = parameterization() +# with suppress(AbortForwardException): +# _ = net(input_image.to(device)) + +# # activations were stored during forward pass +# assert hook.saved_output is not None +# loss = -hook.saved_output[:, target_channel, :, :].sum() # channel 13 + +# loss.backward() +# optimizer.step() + +# if i % 100 == 0: +# print("Loss: ", -loss.cpu().detach().numpy()) +# url = show( +# parameterization.raw_image.cpu() +# .detach() +# .numpy()[0] +# .transpose(1, 2, 0) +# ) + +# traced_net = torch.jit.trace(net, example_inputs=(input_image,)) +# print(traced_net.graph) diff --git 
a/captum/optim/io/__init__.py b/captum/optim/io/__init__.py new file mode 100644 index 0000000000..3431237f9f --- /dev/null +++ b/captum/optim/io/__init__.py @@ -0,0 +1 @@ +from .io import show diff --git a/captum/optim/io/fixtures.py b/captum/optim/io/fixtures.py new file mode 100644 index 0000000000..6f52ea0988 --- /dev/null +++ b/captum/optim/io/fixtures.py @@ -0,0 +1,13 @@ +import torch + +# TODO: use imageio to redo load and avoid TF dependency +from lucid.misc.io import load + +DOG_CAT_URL = ( + "https://lucid-static.storage.googleapis.com/building-blocks/examples/dog_cat.png" +) + + +def image(url: str = DOG_CAT_URL): + img_np = load(url) + return torch.as_tensor(img_np.transpose(2, 0, 1)) diff --git a/captum/optim/io/formatters.py b/captum/optim/io/formatters.py new file mode 100644 index 0000000000..4743a49964 --- /dev/null +++ b/captum/optim/io/formatters.py @@ -0,0 +1,22 @@ +from io import BytesIO + +import torch +from torchvision import transforms + +from IPython import display, get_ipython + + +def tensor_jpeg(tensor: torch.Tensor): + if tensor.dim() == 3: + pil_image = transforms.ToPILImage()(tensor.cpu().detach()).convert("RGB") + buffer = BytesIO() + pil_image.save(buffer, format="jpeg") + data = buffer.getvalue() + return data + else: + return tensor + + +def register_formatters(): + jpeg_formatter = get_ipython().display_formatter.formatters["image/jpeg"] + jpeg_formatter.for_type(torch.Tensor, tensor_jpeg) diff --git a/captum/optim/io/io.py b/captum/optim/io/io.py new file mode 100644 index 0000000000..822a3a5425 --- /dev/null +++ b/captum/optim/io/io.py @@ -0,0 +1,11 @@ +# TODO: redo show using display or register handler for jupyter display directly +# maybe we could even have subtypes of tensors that are "ImageTensors" or "ActivationTensors" etc +from lucid.misc.io import show as lucid_show + + +def show(thing): + if len(thing.shape) == 3: + numpy_thing = thing.cpu().detach().numpy().transpose(1, 2, 0) + elif len(thing.shape) == 4: + numpy_thing = thing.cpu().detach().numpy()[0].transpose(1, 2, 0) + lucid_show(numpy_thing) diff --git a/captum/optim/models/__init__.py b/captum/optim/models/__init__.py new file mode 100644 index 0000000000..353da46b6f --- /dev/null +++ b/captum/optim/models/__init__.py @@ -0,0 +1 @@ +from .inception_v1 import googlenet diff --git a/captum/optim/models/conv2d.py b/captum/optim/models/conv2d.py new file mode 100644 index 0000000000..be6c2a51da --- /dev/null +++ b/captum/optim/models/conv2d.py @@ -0,0 +1,119 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _get_padding(kernel_size, stride=1, dilation=1, **_): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def _calc_same_pad(i, k, s, d): + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, + groups, bias) + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] 
+ pad_h = _calc_same_pad(ih, kh, self.stride[0], self.dilation[0]) + pad_w = _calc_same_pad(iw, kw, self.stride[1], self.dilation[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2]) + return F.conv2d(x, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +# def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): +# padding = kwargs.pop('padding', '') +# kwargs.setdefault('bias', False) +# if isinstance(padding, str): +# # for any string padding, the padding will be calculated for you, one of three ways +# padding = padding.lower() +# if padding == 'same': +# # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact +# if _is_static_pad(kernel_size, **kwargs): +# # static case, no extra overhead +# padding = _get_padding(kernel_size, **kwargs) +# return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) +# else: +# # dynamic padding +# return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) +# elif padding == 'valid': +# # 'VALID' padding, same as padding=0 +# return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs) +# else: +# # Default to PyTorch style 'same'-ish symmetric padding +# padding = _get_padding(kernel_size, **kwargs) +# return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) +# else: +# # padding was specified as a number or pair +# return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + +# class MixedConv2d(nn.Module): +# """ Mixed Grouped Convolution +# Based on MDConv and GroupedConv in MixNet impl: +# https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py +# """ + +# def __init__(self, in_channels, out_channels, kernel_size=3, +# stride=1, padding='', dilated=False, depthwise=False, **kwargs): +# super(MixedConv2d, self).__init__() + +# kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] +# num_groups = len(kernel_size) +# in_splits = _split_channels(in_channels, num_groups) +# out_splits = _split_channels(out_channels, num_groups) +# for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): +# d = 1 +# # FIXME make compat with non-square kernel/dilations/strides +# if stride == 1 and dilated: +# d, k = (k - 1) // 2, 3 +# conv_groups = out_ch if depthwise else 1 +# # use add_module to keep key space clean +# self.add_module( +# str(idx), +# conv2d_pad( +# in_ch, out_ch, k, stride=stride, +# padding=padding, dilation=d, groups=conv_groups, **kwargs) +# ) +# self.splits = in_splits + +# def forward(self, x): +# x_split = torch.split(x, self.splits, 1) +# x_out = [c(x) for x, c in zip(x_split, self._modules.values())] +# x = torch.cat(x_out, 1) +# return x + + +# # helper method +# def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): +# assert 'groups' not in kwargs # only use 'depthwise' bool arg +# if isinstance(kernel_size, list): +# # We're going to use only lists for defining the MixedConv2d kernel groups, +# # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
+# return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) +# else: +# depthwise = kwargs.pop('depthwise', False) +# groups = out_chs if depthwise else 1 +# return conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) diff --git a/captum/optim/models/import_inceptionv1.py b/captum/optim/models/import_inceptionv1.py new file mode 100644 index 0000000000..2d49717904 --- /dev/null +++ b/captum/optim/models/import_inceptionv1.py @@ -0,0 +1,100 @@ +import os +from pprint import pprint +import tensorflow as tf +from lucid.misc.io.loading import _load_graphdef_protobuf +from lucid.misc.io.writing import write_handle +from lucid.modelzoo.vision_models import InceptionV1 + +# tf_path = os.path.abspath('./inception_v1.pb') # Path to our TensorFlow checkpoint +# with open(tf_path, 'rb') as f: +# graph_def = _load_graphdef_protobuf(f) +# pprint(tf_vars) + +inception_v1_tf = InceptionV1() + +# better ds? +node_info = dict((n.name, n) for n in inception_v1_tf.graph_def.node) + +# interactive +op_types = set() +for node in inception_v1_tf.graph_def.node: + op_types.add(node.op) +pprint(op_types) + +aconst = None +for node in inception_v1_tf.graph_def.node: + if node.op == "Const": + print(node.name, node.op) + aconst = node + break + +sess = tf.InteractiveSession() +tf.import_graph_def(inception_v1_tf.graph_def) +graph = tf.get_default_graph() + + +for op in graph.get_operations(): + if op.type == "Const": + print(op.name, op.type) + + + +# Testing our reimplementation +import torch +import numpy as np + +from lucid.misc.io import load + +img_tf = load( + "https://lucid-static.storage.googleapis.com/building-blocks/examples/dog_cat.png" +) +img_pt = torch.as_tensor(img_tf.transpose(2, 0, 1))[None, ...] + + +from clarity.pytorch.inception_v1 import GoogLeNet, googlenet, GS_SAVED_WEIGHTS_URL + + +fresh_import = True + +if fresh_import: + net = GoogLeNet(transform_input=True) + net.import_weights_from_tf(inception_v1_tf) + + tmp_dst = '/tmp/inceptionv1_weights.pth' + torch.save(net.state_dict(), tmp_dst) + with write_handle(GS_SAVED_WEIGHTS_URL, 'wb') as handle: + with open(tmp_dst, 'rb') as tmp_file: + handle.write(tmp_file.read()) +else: + net = googlenet(pretrained=True) + +# forward pass PyTorch +out_pt = net(img_pt).detach() + +latest_op_name = "softmax2" +# forward pass TF +from lucid.optvis.render import import_model + +with tf.Graph().as_default(), tf.Session() as sess: + t_img = tf.placeholder("float32", [None, None, None, 3]) + T = import_model(inception_v1_tf, t_img, t_img) + out_tf = T(latest_op_name).eval(feed_dict={t_img: img_tf[None]}) + +# diagnostic +print(f"\nDiagnostics… evaluating at '{latest_op_name}'") +print( + f"PyTorch: {tuple(out_pt.shape)} µ: {out_pt.mean().item():.3f}, ↓: {out_pt.min().item():.1f}, ↑: {out_pt.max().item():8.3f}" +) +print( + f"TnsrFlw: {tuple(out_tf.shape)} µ: {out_tf.mean().item():.3f}, ↓: {out_tf.min().item():.1f}, ↑: {out_tf.max().item():8.3f}" +) + +if len(out_pt.shape) == 4: + mean_error = np.abs(out_tf.transpose(0, 3, 1, 2) - out_pt.numpy()).mean() +else: + mean_error = np.abs(out_tf - out_pt.numpy()).mean() +print(f"Mean Error: {mean_error:.5f}") + + + + diff --git a/captum/optim/models/inception_v1.py b/captum/optim/models/inception_v1.py new file mode 100644 index 0000000000..05b50528fe --- /dev/null +++ b/captum/optim/models/inception_v1.py @@ -0,0 +1,448 @@ +from __future__ import division + +import warnings +from collections import namedtuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from 
torch.jit.annotations import Optional, Tuple +from torch import Tensor +from .conv2d import Conv2dSame + +from torch.hub import load_state_dict_from_url + +# __all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"] + +GS_SAVED_WEIGHTS_URL = ( + "https://storage.googleapis.com/openai-clarity/temp/InceptionV1_pytorch.pth" +) + +GoogLeNetOutputs = namedtuple( + "GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"] +) +GoogLeNetOutputs.__annotations__ = { + "logits": Tensor, + "aux_logits2": Optional[Tensor], + "aux_logits1": Optional[Tensor], +} + +# Script annotations failed with _GoogleNetOutputs = namedtuple ... +# _GoogLeNetOutputs set here for backwards compat +_GoogLeNetOutputs = GoogLeNetOutputs + + +def googlenet(pretrained=False, progress=True, **kwargs): + r"""GoogLeNet (Inception v1) model architecture from + `"Going Deeper with Convolutions" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + aux_logits (bool): If True, adds two auxiliary branches that can improve training. + Default: *False* when pretrained is True otherwise *True* + transform_input (bool): If True, preprocesses the input according to the method with which it + was trained on ImageNet. Default: *False* + """ + if pretrained: + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + if "aux_logits" not in kwargs: + kwargs["aux_logits"] = False + if kwargs["aux_logits"]: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, " + "so make sure to train them" + ) + original_aux_logits = kwargs["aux_logits"] + kwargs["aux_logits"] = True + kwargs["init_weights"] = False + model = GoogLeNet(**kwargs) + + state_dict = load_state_dict_from_url( + GS_SAVED_WEIGHTS_URL, progress=progress, check_hash=False + ) + model.load_state_dict(state_dict) + # if not original_aux_logits: + # model.aux_logits = False + # del model.aux1, model.aux2 + return model + + return GoogLeNet(**kwargs) + + +def _get_tf_value_by_name(name, graph, sess): + op = graph.get_operation_by_name(name) + return sess.run(op.values()[0]) + + +def _import_weight_into_module(pt_param, tf_name, graph, sess): + tf_value = _get_tf_value_by_name(tf_name, graph, sess) + if len(tf_value.shape) == 4 and len(pt_param.shape) == 4: + # assume k,k,c_in,c_out -> c_out,c_in,k,k + tf_value_transposed = tf_value.transpose(3, 2, 0, 1) + if tf_value_transposed.shape == pt_param.shape: + pt_param.data = torch.as_tensor(tf_value_transposed) + else: + raise RuntimeError( + f"non-matching shapes: {tf_value_transposed.shape} != {pt_param.shape}" + ) + elif len(tf_value.shape) == 2 and len(pt_param.shape) == 2: + if tf_value.shape == pt_param.shape: + pt_param.data = torch.as_tensor(tf_value) + elif tf_value.transpose(1, 0).shape == pt_param.shape: + pt_param.data = torch.as_tensor(tf_value.transpose(1, 0)) + else: + raise RuntimeError( + f"non-matching shapes: {tf_value.shape} != {pt_param.shape}" + ) + elif len(tf_value.shape) == 1 and len(pt_param.shape) == 1: + if tf_value.shape == pt_param.shape: + pt_param.data = torch.as_tensor(tf_value) + else: + raise RuntimeError( + f"non-matching shapes: {tf_value.shape} != {pt_param.shape}" + ) + else: + raise NotImplementedError + + +def _tf_param_name_for_module(module, pt_param_name): + if hasattr(module, "tf_param_name"): + return module.tf_param_name(pt_param_name) + + if isinstance(module, (nn.Conv2d, nn.Linear)): + assert pt_param_name in 
["weight", "bias"] + return pt_param_name[0] # will be w or b + elif isinstance(module, nn.Sequential): + sequence, pt_param_name = pt_param_name.split(".") + assert pt_param_name in ["weight", "bias"] + if int(sequence) == 0: + return f"bottleneck_{pt_param_name[0]}" + elif int(sequence) == 1 or int(sequence) == 2: + return pt_param_name[0] + else: + raise NotImplementedError(f"cannot handle sequence blocks larger than 3") + else: + raise NotImplementedError(f"unknown module: {module}") + + +class GoogLeNet(nn.Module): + # __constants__ = ['aux_logits', 'transform_input'] + + def __init__( + self, + num_classes=1008, + aux_logits=True, + transform_input=True, + init_weights=True, + blocks=None, + ): + super(GoogLeNet, self).__init__() + if blocks is None: + blocks = [BasicConv2d, Inception, InceptionAux] + assert len(blocks) == 3 + conv_block = blocks[0] + inception_block = blocks[1] + inception_aux_block = blocks[2] + + self.aux_logits = aux_logits + self.transform_input = transform_input + + self.conv2d0 = Conv2dSame(3, 64, kernel_size=7, stride=2, padding=3) + # self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3) + self.maxpool0 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + # nn.modules.LocalResponseNorm specifies size rather than radius + tf_radius = 5 + pt_size = tf_radius * 2 + 1 + self.lrn = nn.LocalResponseNorm(pt_size, alpha=0.0001 * pt_size, beta=0.5, k=2) + self.conv2d1 = nn.Conv2d(64, 64, kernel_size=1) + self.conv2d2 = Conv2dSame(64, 192, kernel_size=3, padding=1) + self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + + self.mixed3a = inception_block(192, 64, 96, 128, 16, 32, 32) + self.mixed3b = inception_block(256, 128, 128, 192, 32, 96, 64) + self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + + self.mixed4a = inception_block(480, 192, 96, 204, 16, 48, 64) + self.mixed4b = inception_block(508, 160, 112, 224, 24, 64, 64) + self.mixed4c = inception_block(512, 128, 128, 256, 24, 64, 64) + self.mixed4d = inception_block(512, 112, 144, 288, 32, 64, 64) + self.mixed4e = inception_block(528, 256, 160, 320, 32, 128, 128) + self.maxpool10 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + + self.mixed5a = inception_block(832, 256, 160, 320, 48, 128, 128) + self.mixed5b = inception_block(832, 384, 192, 384, 48, 128, 128) + + # if aux_logits: + # self.aux1 = inception_aux_block(512, num_classes) + # self.aux2 = inception_aux_block(528, num_classes) + + self.avgpool0 = nn.AdaptiveAvgPool2d((1, 1)) + # self.dropout = nn.Dropout(0.2) + self.softmax2_pre_activation = nn.Linear(1024, num_classes) + self.softmax2 = nn.Softmax() + + # if init_weights: + # self._initialize_weights() + + # def _initialize_weights(self): + # for m in self.modules(): + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + # import scipy.stats as stats + # X = stats.truncnorm(-2, 2, scale=0.01) + # values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) + # values = values.view(m.weight.size()) + # with torch.no_grad(): + # m.weight.copy_(values) + # elif isinstance(m, nn.BatchNorm2d): + # nn.init.constant_(m.weight, 1) + # nn.init.constant_(m.bias, 0) + + def _transform_input(self, x): + # type: (Tensor) -> Tensor + if self.transform_input: + # x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 + # x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 + # x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + # x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + assert x.min() >= 0.0 and x.max() <= 1.0 + x = x 
* 255 - 117 + return x + + def _forward(self, x): + # assert x.size(1) == 3 + # type: (Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]] + # N x 3 x 224 x 224 + x = self.conv2d0(x) + x = F.relu(x, inplace=True) + # N x 64 x 112 x 112 + x = self.maxpool0(x) + x = self.lrn(x) + # N x 64 x 56 x 56 + x = self.conv2d1(x) + x = F.relu(x, inplace=True) + # N x 64 x 56 x 56 + x = self.conv2d2(x) + x = F.relu(x, inplace=True) + x = self.lrn(x) + # N x 192 x 56 x 56 + x = self.maxpool1(x) + + # # N x 192 x 28 x 28 + # x = self.mixed3a(x) + x = self.mixed3a(x) + # # N x 256 x 28 x 28 + x = self.mixed3b(x) + # # N x 480 x 28 x 28 + x = self.maxpool4(x) + # # N x 480 x 14 x 14 + x = self.mixed4a(x) + # # N x 512 x 14 x 14 + # aux_defined = self.training and self.aux_logits + # if aux_defined: + # aux1 = self.aux1(x) + # else: + # aux1 = None + + x = self.mixed4b(x) + # # N x 512 x 14 x 14 + x = self.mixed4c(x) + # # N x 512 x 14 x 14 + x = self.mixed4d(x) + # # N x 528 x 14 x 14 + # if aux_defined: + # aux2 = self.aux2(x) + # else: + # aux2 = None + + x = self.mixed4e(x) + # # N x 832 x 14 x 14 + x = self.maxpool10(x) + # # N x 832 x 7 x 7 + x = self.mixed5a(x) + # # N x 832 x 7 x 7 + x = self.mixed5b(x) + # # N x 1024 x 7 x 7 + + x = self.avgpool0(x) + # # N x 1024 x 1 x 1 + x = torch.flatten(x, 1) + # # N x 1024 + # x = self.dropout(x) + x = self.softmax2_pre_activation(x) + x = self.softmax2(x) + # # N x 1000 (num_classes) + aux2, aux1 = None, None + return x, aux2, aux1 + + # @torch.jit.unused + # def eager_outputs(self, x, aux2, aux1): + # # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> GoogLeNetOutputs + # if self.training and self.aux_logits: + # return _GoogLeNetOutputs(x, aux2, aux1) + # else: + # return x + + def forward(self, x): + # type: (Tensor) -> GoogLeNetOutputs + x = self._transform_input(x) + x, aux1, aux2 = self._forward(x) + return x + # aux_defined = self.training and self.aux_logits + # if torch.jit.is_scripting(): + # if not aux_defined: + # warnings.warn( + # "Scripted GoogleNet always returns GoogleNetOutputs Tuple" + # ) + # return GoogLeNetOutputs(x, aux2, aux1) + # else: + # return self.eager_outputs(x, aux2, aux1) + + def import_weights_from_tf(self, model): + import tensorflow as tf + + print("Setting Paramaters…") + with tf.Graph().as_default() as graph, tf.Session() as sess: + tf.import_graph_def(model.graph_def) + + prefix = "import/" + for module_name, module in self.named_children(): + print("named child", module_name) + if module_name == "softmax2_pre_activation": + module_name = "softmax2" + if not hasattr(module, "import_weights_from_tf"): + for param_name, pt_param in module.named_parameters(recurse=False): + tf_param_name = _tf_param_name_for_module(module, param_name) + tf_param_name = f"{prefix}{module_name}_{tf_param_name}" + + print( + f"Setting {module_name}.{param_name} to value of {tf_param_name} ({pt_param.shape})" + ) + _import_weight_into_module(pt_param, tf_param_name, graph, sess) + else: + module.import_weights_from_tf(prefix, module_name, graph, sess) + + # print(name, type(module)) + # for param_name, pt_param in module.named_parameters(recurse=False): + # print('module param', param_name) + + # for name, pt_param in self.named_parameters(recurse=False): + # print('non-recurse', name) + + # for name, pt_param in self.named_parameters(recurse=True): + # print('recurse', name) + + # print("Setting Paramaters…") + # with tf.Graph().as_default() as graph, tf.Session() as sess: + # tf.import_graph_def(model.graph_def) + # for name, 
pt_param in self.named_parameters(recurse=True): + # module_name, w_or_b = name.rsplit(".", 1) + # tf_name = f"import/{module_name}_{w_or_b[0]}" + # tf_name = tf_name.replace('.', '') + # _import_weight_into_module(pt_param, name, tf_name, graph, sess) + + +class Inception(nn.Module): + # __constants__ = ["branch2", "branch3", "branch4"] + # (192, 64, 96, 128, 16, 32, 32) + def __init__( + self, + in_channels, # 192 + ch1x1, # 64 + ch3x3bottleneck, # 96 + ch3x3, # 128 + ch5x5bottleneck, # 16 + ch5x5, # 32 + pool_proj, # 32 + conv_block=None, + ): + super(Inception, self).__init__() + if conv_block is None: + conv_block = Conv2dSame + self._1x1 = conv_block(in_channels, ch1x1, kernel_size=1) + + self._3x3 = nn.Sequential( + conv_block(in_channels, ch3x3bottleneck, kernel_size=1), + nn.ReLU(inplace=True), + conv_block(ch3x3bottleneck, ch3x3, kernel_size=3, padding=1), + ) + + self._5x5 = nn.Sequential( + conv_block(in_channels, ch5x5bottleneck, kernel_size=1), + nn.ReLU(inplace=True), + conv_block(ch5x5bottleneck, ch5x5, kernel_size=5, padding=1), + ) + + self._pool_reduce = nn.Sequential( + nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), + conv_block(in_channels, pool_proj, kernel_size=1), + ) + + def _forward(self, x): + _1x1 = self._1x1(x) + # _3x3_bottleneck = self._3x3[0](x) + _3x3 = self._3x3(x) + _5x5 = self._5x5(x) + # _5x5_bottleneck = self._5x5[0](x) + _pool_reduce = self._pool_reduce(x) + + outputs = [_1x1, _3x3, _5x5, _pool_reduce] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return F.relu(torch.cat(outputs, 1), inplace=True) + + def import_weights_from_tf(self, prefix, own_name, graph, sess): + for module_name, module in self.named_children(): + print(f"{own_name}: named child {module_name}") + if not hasattr(module, "import_weights_from_tf"): + for param_name, pt_param in module.named_parameters(recurse=True): + tf_param_name = _tf_param_name_for_module(module, param_name) + tf_param_name = f"{prefix}{own_name}{module_name}_{tf_param_name}" + + print( + f"Setting {module_name}.{param_name} to value of {tf_param_name} ({pt_param.shape})" + ) + _import_weight_into_module(pt_param, tf_param_name, graph, sess) + else: + module.import_weights_from_tf(module_name, prefix, graph, sess) + + +class InceptionAux(nn.Module): + def __init__(self, in_channels, num_classes, conv_block=None): + super(InceptionAux, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.conv = conv_block(in_channels, 128, kernel_size=1) + + self.fc1 = nn.Linear(2048, 1024) + self.fc2 = nn.Linear(1024, num_classes) + + def forward(self, x): + # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 + x = F.adaptive_avg_pool2d(x, (4, 4)) + # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 + x = self.conv(x) + # N x 128 x 4 x 4 + x = torch.flatten(x, 1) + # N x 2048 + x = F.relu(self.fc1(x), inplace=True) + # N x 1024 + x = F.dropout(x, 0.7, training=self.training) + # N x 1024 + x = self.fc2(x) + # N x 1000 (num_classes) + + return x + + +class BasicConv2d(nn.Module): + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) diff --git a/captum/optim/optim/__init__.py b/captum/optim/optim/__init__.py new file mode 100644 index 0000000000..3d8a464331 --- /dev/null +++ b/captum/optim/optim/__init__.py 
@@ -0,0 +1 @@ +from .output_hook import AbortForwardException, ModuleOutputsHook diff --git a/captum/optim/optim/objectives.py b/captum/optim/optim/objectives.py new file mode 100644 index 0000000000..a373aac0bf --- /dev/null +++ b/captum/optim/optim/objectives.py @@ -0,0 +1,168 @@ +from contextlib import suppress +from typing import Callable, Iterable, Optional + +import torch +import torch.nn as nn + +from clarity.pytorch import Parameterized, Objective, ModuleOutputMapping +from clarity.pytorch.param import ImageParameterization, NaturalImage, RandomAffine + +from .output_hook import AbortForwardException, ModuleOutputsHook + +LossFunction = Callable[[ModuleOutputMapping], torch.Tensor] +SingleTargetLossFunction = Callable[[torch.Tensor], torch.Tensor] + + +class InputOptimization(Objective, Parameterized): + net: nn.Module + input_param: ImageParameterization + transform: nn.Module + + def __init__( + self, + net: nn.Module, + input_param: Optional[nn.Module], + transform: Optional[nn.Module], + targets: Iterable[nn.Module], + loss_function: LossFunction, + ): + self.net = net + self.hooks = ModuleOutputsHook(targets) + self.input_param = input_param or NaturalImage((224, 224)) + self.transform = transform or RandomAffine(scale=True, translate=True) + self.loss_function = loss_function + + def loss(self) -> torch.Tensor: + image = self.input_param()[None, ...] + + if self.transform: + image = self.transform(image) + + with suppress(AbortForwardException): + _unreachable = self.net(image) + + # consume_outputs returns the captured values and resets the hook's state + module_outputs = self.hooks.consume_outputs() + loss_value = self.loss_function(module_outputs) + return loss_value + + def cleanup(self): + self.hooks.remove_hooks() + + # Targets are managed by ModuleOutputsHook; we mainly just want a convenient setter + @property + def targets(self): + return self.hooks.targets + + @targets.setter + def targets(self, value): + self.hooks.remove_hooks() + self.hooks = ModuleOutputsHook(value) + + def parameters(self): + return self.input_param.parameters() + + +def channel_activation(target: nn.Module, channel_index: int) -> LossFunction: + # ensure channel_index will be valid + assert channel_index < target.out_channels + + def loss_function(targets_to_values: ModuleOutputMapping): + activations = targets_to_values[target] + assert activations is not None + assert len(activations.shape) == 4 # assume NCHW + return activations[:, channel_index, ...]
+ + return loss_function + + +def neuron_activation( + target: nn.Module, channel_index: int, x: int = None, y: int = None +) -> LossFunction: + # ensure channel_index will be valid + assert channel_index < target.out_channels + + def loss_function(targets_to_values: ModuleOutputMapping): + activations = targets_to_values[target] + assert activations is not None + assert len(activations.shape) == 4 # assume NCHW + _, _, H, W = activations.shape + + if x is None: + _x = W // 2 + else: + assert x < W + _x = x + + if y is None: + _y = H // 2 + else: + assert y < H + _y = y + + return activations[:, channel_index, _y, _x] + + return loss_function + + +def single_target_objective( + target: nn.Module, loss_function: SingleTargetLossFunction +) -> LossFunction: + def inner(targets_to_values: ModuleOutputMapping): + value = targets_to_values[target] + return loss_function(value) + + return inner + + +class SingleTargetObjective(Objective): + def __init__( + self, + net: nn.Module, + target: nn.Module, + loss_function: Callable[[torch.Tensor], torch.Tensor], + ): + super(SingleTargetObjective, self).__init__(net=net, targets=[target]) + self.loss_function = loss_function + + def loss(self, targets_to_values): + assert len(self.targets) == 1 + target = self.targets[0] + target_value = targets_to_values[target] + loss_value = self.loss_function(target_value) + self.history.append(loss_value.sum().cpu().detach().numpy().squeeze().item()) + return loss_value + + +# class MultiObjective(Objective): +# def __init__( +# self, objectives: List[Objective], weights: Optional[Iterable[float]] = None +# ): +# net = objectives[0].net +# assert all(o.net == net for o in objectives) +# targets = (target for objective in objectives for target in objective.targets) +# super(MultiObjective, self).__init__(net=net, targets=targets) +# self.objectives = objectives +# self.weights = weights or len(objectives) * [1] + +# def loss(self, targets_to_values): +# losses = ( +# objective.loss_function(targets_to_values) for objective in self.objectives +# ) +# weighted = (loss * weight for weight in self.weights) +# loss_value = sum(weighted) +# self.history.append(loss_value.cpu().detach().numpy().squeeze().item()) +# return loss_value + +# @property +# def histories(self) -> List[List[float]]: +# return [objective.history for objective in self.objectives] + + +# class ChannelObjective(SingleTargetObjective): +# def __init__(self, channel: int, *args, **kwargs): +# loss_function = lambda activation: activation[:, channel, :, :].mean() +# super(ChannelObjective, self).__init__( +# *args, loss_function=loss_function, **kwargs +# ) + diff --git a/captum/optim/optim/optimize.py b/captum/optim/optim/optimize.py new file mode 100644 index 0000000000..4f935016c2 --- /dev/null +++ b/captum/optim/optim/optimize.py @@ -0,0 +1,50 @@ +from contextlib import suppress +from typing import Dict, Callable, Iterable, Optional, List, Union +from typing_extensions import Protocol +from tqdm.auto import tqdm + +import torch +import torch.nn as nn +import torch.optim as optim + +from clarity.pytorch import StopCriteria, Objective + + +def optimize( + objective: Objective, + stop_criteria: Optional[StopCriteria] = None, + optimizer: Optional[optim.Optimizer] = None, +): + stop_criteria = stop_criteria or n_steps(1024) + optimizer = optimizer or optim.Adam(objective.parameters(), lr=0.025) + assert isinstance(optimizer, optim.Optimizer) + + history = [] + step = 0 + while stop_criteria(step, objective, history, optimizer): + optimizer.zero_grad()
+ + loss_value = objective.loss() + history.append(loss_value.cpu().detach().numpy()) + (-1 * loss_value.mean()).backward() + optimizer.step() + step += 1 + + objective.cleanup() + return history + + +def n_steps(n: int) -> StopCriteria: + pbar = tqdm(total=n, unit="step") + + def continue_while(step, obj, history, optim): + if len(history) > 0: + pbar.set_postfix({"Objective": f"{history[-1].mean():.1f}"}, refresh=False) + if step < n: + pbar.update() + return True + else: + pbar.close() + return False + + return continue_while diff --git a/captum/optim/optim/output_hook.py b/captum/optim/optim/output_hook.py new file mode 100644 index 0000000000..767819e6c7 --- /dev/null +++ b/captum/optim/optim/output_hook.py @@ -0,0 +1,88 @@ +from warnings import warn +from typing import Iterable, Dict, Optional + +import torch +import torch.nn as nn + +from clarity.pytorch import ModuleOutputMapping + + +class AbortForwardException(Exception): + pass + + +class ModuleReuseException(Exception): + pass + + +# class SingleTargetHook: +# def __init__(self, module: nn.Module): +# self.saved_output = None +# self.target_modules = [module] +# self.remove_forward = module.register_forward_hook(self._forward_hook()) + +# @property +# def is_ready(self) -> bool: +# return self.saved_output is not None + +# def _forward_hook(self): +# def forward_hook(module, input, output): +# assert self.module == module +# self.saved_output = output +# raise AbortForwardException("Forward hook called, output saved.") + +# return forward_hook + +# def __del__(self): +# self.remove_forward() + + +class ModuleOutputsHook: + def __init__(self, target_modules: Iterable[nn.Module]): + self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None) + self.hooks = [ + module.register_forward_hook(self._forward_hook()) + for module in target_modules + ] + + def _reset_outputs(self): + self.outputs = dict.fromkeys(self.outputs.keys(), None) + + @property + def is_ready(self) -> bool: + return all(value is not None for value in self.outputs.values()) + + def _forward_hook(self): + def forward_hook(module, input, output): + assert module in self.outputs.keys() + if self.outputs[module] is None: + self.outputs[module] = output + else: + warn( + f"Hook attached to {module} was called multiple times. As of 2019-11-22 please don't reuse nn.Modules in your models." + ) + if self.is_ready: + raise AbortForwardException("Forward hook called, all outputs saved.") + + return forward_hook + + def consume_outputs(self) -> ModuleOutputMapping: + if not self.is_ready: + warn( + "Consume captured outputs, but not all requested target outputs have been captured yet!" + ) + outputs = self.outputs + self._reset_outputs() + return outputs + + @property + def targets(self): + return self.outputs.keys() + + def remove_hooks(self): + for hook in self.hooks: + hook.remove() + + def __del__(self): + print(f"DEL HOOKS!: {list(self.outputs.keys())}") + self.remove_hooks() diff --git a/captum/optim/param/__init__.py b/captum/optim/param/__init__.py new file mode 100644 index 0000000000..6495a250c0 --- /dev/null +++ b/captum/optim/param/__init__.py @@ -0,0 +1,4 @@ +"""(Differentiable) Input Parameterizations. 
Currently only 3-channel images""" + +from .images import ImageParameterization, NaturalImage +from .transform import RandomAffine, GaussianSmoothing, BlendAlpha, IgnoreAlpha diff --git a/captum/optim/param/images.py b/captum/optim/param/images.py new file mode 100644 index 0000000000..9c6fc65ffc --- /dev/null +++ b/captum/optim/param/images.py @@ -0,0 +1,263 @@ +from typing import List, Union, Tuple +import numpy as np + +import torch +import torch.nn as nn +import torchvision.models as models +import torch.nn.functional as F +from torchvision import transforms +import torchvision.transforms.functional as TF + +from lucid.misc.io import load, save, show + +# mean = [0.485, 0.456, 0.406] +# std = [0.229, 0.224, 0.225] + +# normalize = transforms.Normalize(mean=mean, std=std) + +# mean = torch.Tensor(mean)[None, :, None, None] +# std = torch.Tensor(std)[None, :, None, None] + + +# def denormalize(x: torch.Tensor): +# return std * x + mean + + +def logit(p: torch.Tensor, epsilon=1e-6) -> torch.Tensor: + p = torch.clamp(p, min=epsilon, max=1.0 - epsilon) + assert p.min() >= 0 and p.max() < 1 + return torch.log(p / (1 - p)) + + +# def jitter(x: torch.Tensor, pad_width=2, pad_value=0.5): +# _, _, H, W = x.shape +# y = F.pad(x, 4 * (pad_width,), value=pad_value) +# idx, idy = np.random.randint(low=0, high=2 * pad_width, size=(2,)) +# return y[:, :, idx : idx + H, idy : idy + W] + + +# def color_correction(): +# S = np.asarray( +# [[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]] +# ).astype("float32") +# C = S / np.max(np.linalg.norm(S, axis=0)) +# C = torch.Tensor(C) +# return C.transpose(0, 1) + + +class ToRGB(nn.Module): + """Transforms arbitrary channels to RGB. We use this to ensure our + image parameteriation itself can be decorrelated. So this goes between + the image parameterization and the normalization/sigmoid step. + + We offer two transforms: Karhunen-Loève (KLT) and I1I2I3. + + KLT corresponds to the empirically measured channel correlations on imagenet. + I1I2I3 corresponds to an aproximation for natural images from Ohta et al.[0] + + [0] Y. Ohta, T. Kanade, and T. Sakai, "Color information for region segmentation," + Computer Graphics and Image Processing, vol. 13, no. 3, pp. 222–241, 1980 + https://www.sciencedirect.com/science/article/pii/0146664X80900477 + """ + + @staticmethod + def klt_transform(): + """Karhunen-Loève transform (KLT) measured on ImageNet""" + KLT = [[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]] + transform = np.asarray(KLT, dtype=np.float32) + transform /= np.max(np.linalg.norm(transform, axis=0)) + return torch.as_tensor(transform) + + @staticmethod + def i1i2i3_transform(): + i1i2i3_matrix = [ + [1 / 3, 1 / 3, 1 / 3], + [1 / 2, 0, -1 / 2], + [-1 / 4, 1 / 2, -1 / 4], + ] + return torch.Tensor(i1i2i3_matrix) + + def __init__(self, transform_name="klt"): + super().__init__() + + if transform_name == "klt": + self.register_buffer("transform", ToRGB.klt_transform()) + elif transform_name == "i1i2i3": + self.register_buffer("transform", ToRGB.i1i2i3_transform()) + else: + raise ValueError(f"transform_name has to be either 'klt' or 'i1i2i3'") + + def forward(self, x, inverse=False): + assert x.dim() == 3 + + # alpha channel is taken off... 
+ has_alpha = x.size("C") == 4 + if has_alpha: + x, alpha_channel = x[:3], x[3:] + assert x.dim() == alpha_channel.dim() # ensure we "keep_dim" + + h, w = x.size("H"), x.size("W") + flat = x.flatten(("H", "W"), "spatials") + if inverse: + correct = self.transform.t() @ flat + else: + correct = self.transform @ flat + chw = correct.unflatten("spatials", (("H", h), ("W", w))).refine_names("C", ...) + + # ...alpha channel is concatenated on again. + if has_alpha: + chw = torch.cat([chw, alpha_channel], 0) + + return chw + + +# def model(layer): +# net = models.googlenet(pretrained=True) +# net.train(False) + +# def get_subnet_at_layer(lay_idx): +# subnet = nn.Sequential(*list(net._modules.values())[: lay_idx + 1]) +# subnet +# for p in subnet.parameters(): +# p.requires_grad = False +# return subnet + +# subnet = get_subnet_at_layer(layer) +# subnet.train(False) +# return subnet + + +# def upsample(): +# upsample = torch.nn.Upsample(scale_factor=1.1, mode="bilinear", align_corners=True) + +# def up(x): +# upsample.scale_factor = ( +# 1 + np.random.randn(1)[0] / 50, +# 1 + np.random.randn(1)[0] / 50, +# ) +# return upsample(x) + +# return up + + +class ImageTensor(torch.Tensor): + pass + + +class ImageParameterization(torch.nn.Module): + def set_image(self, x: torch.Tensor): + ... + + +class FFTImage(ImageParameterization): + """Parameterize an image using inverse real 2D FFT""" + + def __init__(self, size, channels=3): + super().__init__() + assert len(size) == 2 + self.size = size + + coeffs_shape = (channels, size[0], size[1] // 2 + 1, 2) + random_coeffs = torch.randn( + coeffs_shape + ) # names=["C", "H_f", "W_f", "complex"] + self.fourier_coeffs = nn.Parameter(random_coeffs / 50) + + frequencies = FFTImage.rfft2d_freqs(*size) + scale = 1.0 / np.maximum(frequencies, 1.0 / max(*size)) + scale *= np.sqrt(size[0] * size[1]) + spectrum_scale = torch.Tensor(scale[None, :, :, None].astype(np.float32)) + self.register_buffer("spectrum_scale", spectrum_scale) + + @staticmethod + def rfft2d_freqs(height, width): + """Computes 2D spectrum frequencies.""" + f_y = np.fft.fftfreq(height)[:, None] + # on odd input dimensions we need to keep one additional frequency + add = 2 if width % 2 == 1 else 1 + f_x = np.fft.fftfreq(width)[: width // 2 + add] + return np.sqrt(f_x * f_x + f_y * f_y) + + def set_image(self, correlated_image: torch.Tensor): + coeffs = torch.rfft(correlated_image, signal_ndim=2) + self.fourier_coeffs = coeffs / self.spectrum_scale + + def forward(self): + h, w = self.size + scaled_spectrum = self.fourier_coeffs * self.spectrum_scale + output = torch.irfft(scaled_spectrum, signal_ndim=2)[:, :h, :w] + return output.refine_names("C", "H", "W") + + +class PixelImage(ImageParameterization): + def __init__(self, size=None, channels: int = 3, init: torch.Tensor = None): + super().__init__() + if init is None: + assert size is not None and channels is not None + init = torch.randn([channels, size[0], size[1]]) / 10 + 0.5 + else: + assert init.shape[0] == 3 + self.image = nn.Parameter(init) + + def forward(self): + return self.image + + def set_image(self, correlated_image: torch.Tensor): + self.image = nn.Parameter(correlated_image) + + +class LaplacianImage(ImageParameterization): + def __init__(self): + super().__init__() + power = 0.1 + X = [] + scaler = [] + for scale in [1, 2, 4, 8, 16, 32]: + upsample = torch.nn.Upsample(scale_factor=scale, mode="nearest") + x = torch.randn([1, 3, 224 // scale, 224 // scale]) / 10 + x = x.cuda() + x = x * (scale ** power) / (32 ** power) + 
x.requires_grad = True + X.append(x) + scaler.append(upsample) + + self.parameters = X + self.scaler = scaler + + def forward(self): + A = [] + for xi, upsamplei in zip(self.X, self.scaler): + A.append(upsamplei(xi)) + return torch.sum(torch.cat(A), 0) + 0.5 + + +class NaturalImage(ImageParameterization): + r"""Outputs an optimizable input image. + + By convention, single images are CHW and float32s in [0,1]. + The underlying parameterization is decorrelated via a ToRGB transform. + When used with the (default) FFT parameterization, this results in a fully + uncorrelated image parameterization. :-) + + If a model requires a normalization step, such as normalizing imagenet RGB values, + or rescaling to [0,255], it has to perform that step inside its computation. + For example, our GoogleNet factory function has a `transform_input=True` argument. + """ + + def __init__(self, size, channels=3, Parameterization=FFTImage): + super().__init__() + + self.parameterization = Parameterization(size=size, channels=channels) + self.decorrelate = ToRGB(transform_name="klt") + + def forward(self): + image = self.parameterization() + image = self.decorrelate(image) + image = image.rename(None) # TODO: the world is not yet ready + return torch.sigmoid_(image) + + def set_image(self, image): + logits = logit(image, epsilon=1e-4) + correlated = self.decorrelate(logits, inverse=True) + self.parameterization.set_image(correlated) + diff --git a/captum/optim/param/test_images.py b/captum/optim/param/test_images.py new file mode 100644 index 0000000000..5ba75a3ffe --- /dev/null +++ b/captum/optim/param/test_images.py @@ -0,0 +1 @@ +from .images import Image, FFTImage, PixelImage \ No newline at end of file diff --git a/captum/optim/param/transform.py b/captum/optim/param/transform.py new file mode 100644 index 0000000000..5d202e42cb --- /dev/null +++ b/captum/optim/param/transform.py @@ -0,0 +1,225 @@ +import math +import numbers +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from kornia.geometry.transform import rotate, scale, shear, translate + + +class BlendAlpha(nn.Module): + r"""Blends a 4 channel input parameterization into an RGB image. + + You can specify a fixed background, or a random one will be used by default. + """ + + def __init__(self, background: torch.Tensor = None): + super().__init__() + self.background = background + + def forward(self, x): + assert x.size(1) == 4 + rgb, alpha = x[:, :3, ...], x[:, 3:4, ...] + background = self.background or torch.rand_like(rgb) + blended = alpha * rgb + (1 - alpha) * background + return blended + + +class IgnoreAlpha(nn.Module): + r"""Ignores a 4th channel""" + + def forward(self, x): + assert x.size(1) == 4 + rgb = x[:, :3, ...] 
+ return rgb + + +def center_crop(input: torch.Tensor, output_size) -> torch.Tensor: + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + if len(output_size) == 4: # assume NCHW + output_size = output_size[2:] + + assert len(output_size) == 2 and len(input.shape) == 4 + + image_width, image_height = input.shape[2:] + height, width = output_size + top = int(round((image_height - height) / 2.0)) + left = int(round((image_width - width) / 2.0)) + + return F.pad( + input, [top, height - image_height - top, left, width - image_width - left] + ) + + +# class RandomSpatialJitter(nn.Module): +# def __init__(self, max_distance): +# super().__init__() + +# self.pad_range = 2 * max_distance +# self.pad = nn.ReflectionPad2d(max_distance) + +# def forward(self, x): +# padded = self.pad(x) +# insets = torch.randint(high=self.pad_range, size=(2,)) +# tblr = [ +# -insets[0], +# -(self.pad_range - insets[0]), +# -insets[1], +# -(self.pad_range - insets[1]), +# ] +# cropped = F.pad(padded, pad=tblr) +# assert cropped.shape == x.shape +# return cropped + + +# class RandomScale(nn.Module): +# def __init__(self, *args, **kwargs): +# super().__init__() +# self.scale = torch.distributions.Uniform(0.95, 1.05) + +# def forward(self, x): +# by = self.scale.sample().item() +# return F.interpolate(x, scale_factor=by, mode="bilinear") + + +# class TransformationRobustness(nn.Module): +# def __init__(self, jitter=False, scale=False): +# super().__init__() +# if jitter: +# self.jitter = RandomSpatialJitter(4) +# if scale: +# self.scale = RandomScale() + +# def forward(self, x): +# original_shape = x.shape +# if hasattr(self, "jitter"): +# x = self.jitter(x) +# if hasattr(self, "scale"): +# x = self.scale(x) +# cropped = center_crop(x, original_shape) +# return cropped + + +class RandomAffine(nn.Module): + """TODO: Can we look into Distributions more to give more control and be more PyTorch-y?""" + + def __init__(self, rotate=False, scale=False, shear=False, translate=False): + super().__init__() + self.rotate = rotate + self.scale = scale + self.shear = shear + self.translate = translate + + def forward(self, x): + if self.rotate: + rotate_angle = torch.randn(1, device=x.device) # >95% < 6deg + logging.info(f"Rotate: {rotate_angle}") + x = rotate(x, rotate_angle) + if self.scale: + scale_factor = (torch.randn(1, device=x.device) / 40.0) + 1 + logging.info(f"Scale: {scale_factor}") + x = scale(x, scale_factor) + if self.shear: + shear_matrix = torch.randn((1, 2), device=x.device) / 40.0 # >95% < 2deg + logging.info(f"Shear: {shear_matrix}") + x = shear(x, shear_matrix) + if self.translate: + translation = torch.randn((1, 2), device=x.device) + logging.info(f"Translate: {translation}") + x = translate(x, translation) + return x + + +# class RandomHomography(nn.Module): +# def __init__(self): +# super().__init__() + +# def forward(self, x): +# _, _, H, W = x.shape +# self.homography_warper = HomographyWarper( +# height=H, width=W, padding_mode="reflection" +# ) +# homography = +# return self.homography_warper(x, homography) + + +# via https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/9 +class GaussianSmoothing(nn.Module): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. 
+ kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + def __init__(self, channels, kernel_size, sigma, dim=2): + super().__init__() + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [torch.arange(size, dtype=torch.float32) for size in kernel_size] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= ( + 1 + / (std * math.sqrt(2 * math.pi)) + * torch.exp(-((mgrid - mean) / std) ** 2 / 2) + ) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim) + ) + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight, groups=self.groups) + + +def test_transform(): + from clarity.pytorch.fixtures import image + from clarity.pytorch.io import show + + input_image = image()[None, ...] 
+ show(input_image) + transform = GaussianSmoothing(channels=3, kernel_size=(5, 5), sigma=2) + transformed_image = transform(input_image) + show(transformed_image) diff --git a/captum/optim/techs/__init__.py b/captum/optim/techs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/captum/optim/techs/feature_vis.py b/captum/optim/techs/feature_vis.py new file mode 100644 index 0000000000..a73f441768 --- /dev/null +++ b/captum/optim/techs/feature_vis.py @@ -0,0 +1,118 @@ +# %load_ext autoreload +# %autoreload 2 + +from typing import Dict, Callable, Iterable + +from tqdm.auto import tqdm + +import torch +import torch.nn as nn +import torch.optim as optim + +from kornia.losses import total_variation + +from lucid.misc.io import save +from clarity.pytorch.io import show +from clarity.pytorch.models import googlenet +from clarity.pytorch.param import ( + NaturalImage, + RandomAffine, + GaussianSmoothing, + BlendAlpha, + IgnoreAlpha, +) +from clarity.pytorch.optim.objectives import ( + InputOptimization, + single_target_objective, + channel_activation, + neuron_activation, +) +from clarity.pytorch.optim.optimize import optimize, n_steps + +if torch.cuda.is_available(): + device = torch.device("cuda:0") +else: + device = torch.device("cpu") + + +def alpha_neuron(target, param, channel_index): + def innr(mapping): + acts = mapping[target] + input_value = mapping[param] + _, _, H, W = acts.shape + obj = acts[:, channel_index, H // 2, W // 2] + mean_alpha = input_value[3].mean() + # mean_tv = total_variation(input_value[3:]) + return obj * (1 - mean_alpha) # - mean_tv + + return innr + + +def run(): + net = googlenet(pretrained=True).to(device) + + robustness_transforms = nn.Sequential( + RandomAffine(translate=True, rotate=True), + BlendAlpha(), + ) + + param = NaturalImage((112, 112), channels=4).to(device) + + target = net.mixed3a._1x1 + objective = InputOptimization( + net=net, + input_param=param, + transform=robustness_transforms, + targets=[target, param], + loss_function=alpha_neuron(target, param, channel_index=8), + ) + + optimize(objective, n_steps(128)) + result = objective.input_param() + save(result.detach().numpy().transpose(1, 2, 0), "image.png") + return result.cpu() + + +def example_of_two_stage_optimization(): + # TODO: Objective abstraction doesn't work that well. Should just be nn.Module?? + # TODO: loss functions don't yet compose, thus "alpha_neuron"; dirty as it is + # TODO: eliminate having to pass targets twice: once to objective as targets, once to loss function. Should be one abstraction? Sketch first! 
+
+    net = googlenet(pretrained=True).to(device)
+    param = NaturalImage((112, 112), channels=4).to(device)
+    target = net.mixed3a._3x3[-1]
+    channel_index = 76 - 64
+
+    ignore_transforms = nn.Sequential(
+        RandomAffine(translate=True, rotate=True), IgnoreAlpha()
+    )
+    objective = InputOptimization(
+        net=net,
+        input_param=param,
+        transform=ignore_transforms,
+        targets=[target],
+        loss_function=neuron_activation(target, channel_index=channel_index),
+    )
+
+    optimize(objective, n_steps(256))
+
+    intermediate_result = param()
+    show(intermediate_result)
+
+    blend_transforms = nn.Sequential(
+        RandomAffine(translate=True, rotate=True), BlendAlpha()
+    )
+    objective.transform = blend_transforms
+    objective.targets = [target, param]
+    objective.loss_function = alpha_neuron(target, param, channel_index=channel_index)
+
+    optimize(objective, n_steps(512))
+
+    final_result = param()
+    show(final_result)
+
+    return intermediate_result, final_result
+
+
+if __name__ == "__main__":
+    example_of_two_stage_optimization()
diff --git a/captum/optim/techs/feature_vis_two_stage.py b/captum/optim/techs/feature_vis_two_stage.py
new file mode 100644
index 0000000000..5107b076c2
--- /dev/null
+++ b/captum/optim/techs/feature_vis_two_stage.py
@@ -0,0 +1,91 @@
+from typing import Dict, Callable, Iterable
+
+from tqdm.auto import tqdm
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from kornia.losses import total_variation
+
+from lucid.misc.io import save
+from clarity.pytorch.io import show
+from clarity.pytorch.models import googlenet
+from clarity.pytorch.param import (
+    NaturalImage,
+    RandomAffine,
+    GaussianSmoothing,
+    BlendAlpha,
+    IgnoreAlpha,
+)
+from clarity.pytorch.optim.objectives import (
+    InputOptimization,
+    single_target_objective,
+    channel_activation,
+    neuron_activation,
+)
+from clarity.pytorch.optim.optimize import optimize, n_steps
+
+# set device based on availability
+if torch.cuda.is_available():
+    device = torch.device("cuda:0")
+else:
+    device = torch.device("cpu")
+
+
+# TODO: loss functions don't yet compose, thus "alpha_neuron"; dirty as it is
+def alpha_neuron(target, param, channel_index):
+    def inner(mapping):
+        acts = mapping[target]
+        input_value = mapping[param]
+        _, _, H, W = acts.shape
+        obj = acts[:, channel_index, H // 2, W // 2]
+        mean_alpha = input_value[3].mean()
+        # mean_tv = total_variation(input_value[3:])
+        return obj * (1 - mean_alpha)  # - mean_tv
+
+    return inner
+
+
+def example_of_two_stage_optimization():
+    # TODO: Objective abstraction doesn't work that well. Should just be an nn.Module??
+    # TODO: eliminate having to pass targets twice: once to objective as targets, once to loss function. Should be one abstraction? Sketch first!
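+    # Sketch for the "targets passed twice" TODO (assumed API, not implemented):
+    # a loss object could carry the modules it reads, e.g. expose a .targets
+    # attribute, so the two InputOptimization setups below could shrink to roughly
+    #     loss = alpha_neuron(target, param, channel_index)
+    #     objective = InputOptimization(net, param, transform, loss)
+    # and InputOptimization would register hooks from loss.targets itself.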
+
+    net = googlenet(pretrained=True).to(device)
+    param = NaturalImage((112, 112), channels=4).to(device)
+    target = net.mixed3a  # or more complicated: net.mixed3a._3x3[-1][12]
+    channel_index = 76
+
+    ignore_transforms = nn.Sequential(
+        RandomAffine(translate=True, rotate=True, shear=True), IgnoreAlpha()
+    )
+    objective = InputOptimization(
+        net=net,
+        input_param=param,
+        transform=ignore_transforms,
+        targets=[target],
+        loss_function=neuron_activation(target, channel_index),
+    )
+
+    optimize(objective, n_steps(256))
+
+    intermediate_result = param()
+    show(intermediate_result)
+
+    blend_transforms = nn.Sequential(
+        RandomAffine(translate=True, rotate=True, shear=True), BlendAlpha()
+    )
+    objective.transform = blend_transforms
+    objective.targets = [target, param]
+    objective.loss_function = alpha_neuron(target, param, channel_index)
+
+    optimize(objective, n_steps(512))
+
+    final_result = param()
+    show(final_result)
+
+    return intermediate_result, final_result
+
+
+if __name__ == "__main__":
+    example_of_two_stage_optimization()
diff --git a/captum/optim/techs/optimization_schedule_test.py b/captum/optim/techs/optimization_schedule_test.py
new file mode 100644
index 0000000000..0a2d05b66c
--- /dev/null
+++ b/captum/optim/techs/optimization_schedule_test.py
@@ -0,0 +1,118 @@
+# %load_ext autoreload
+# %autoreload 2
+
+from typing import Dict, Callable, Iterable
+
+from tqdm.auto import tqdm
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from kornia.losses import total_variation
+
+from lucid.misc.io import save
+from clarity.pytorch.io import show
+from clarity.pytorch.models import googlenet
+from clarity.pytorch.param import (
+    NaturalImage,
+    TransformationRobustness,
+    RandomAffine,
+    GaussianSmoothing,
+    BlendAlpha,
+)
+from clarity.pytorch.optim.objectives import (
+    InputOptimization,
+    single_target_objective,
+    channel_activation,
+    neuron_activation,
+)
+from clarity.pytorch.optim.optimize import optimize, n_steps
+
+if torch.cuda.is_available():
+    device = torch.device("cuda:0")
+else:
+    device = torch.device("cpu")
+
+
+def alpha_neuron(target, param):
+    def inner(mapping):
+        acts = mapping[target]
+        input_value = mapping[param]
+        _, _, H, W = acts.shape
+        obj = acts[:, 8, H // 2, W // 2]
+        mean_alpha = input_value[3].mean()
+        mean_tv = total_variation(input_value[3:])
+        return obj * (1 - mean_alpha) - mean_tv
+
+    return inner
+
+
+def run():
+    net = googlenet(pretrained=True).to(device)
+
+    robustness_transforms = nn.Sequential(
+        TransformationRobustness(jitter=True),
+        TransformationRobustness(jitter=True),
+        TransformationRobustness(jitter=True),
+        # GaussianSmoothing(channels=3, kernel_size=(3, 3), sigma=1),
+        # TransformationRobustness(jitter=True),
+        RandomAffine(),
+        BlendAlpha(),
+    )
+
+    param = NaturalImage((112, 112), channels=4, color_correct=True, normalize=False).to(
+        device
+    )
+
+    target = net.mixed3a._1x1
+    objective = InputOptimization(
+        net=net,
+        input_param=param,
+        transform=robustness_transforms,
+        targets=[target, param],
+        loss_function=alpha_neuron(target, param),
+    )
+
+    optimize(objective, n_steps(128))
+    result = objective.input_param()
+    save(result.cpu().detach().numpy().transpose(1, 2, 0), "image.png")
+    return result.cpu()
+
+
+run()
+
+
+def two_stage_opt():
+
+    net = googlenet(pretrained=True).to(device)
+
+    robustness_transforms = nn.Sequential(
+        RandomAffine(rotate=True, scale=True, translate=True, shear=True), BlendAlpha()
+    )
+
+    # channels=4: the alpha channel is required by BlendAlpha and alpha_neuron
+    param = NaturalImage((112, 112), channels=4).to(device)
+
+    target = net.mixed3a._1x1
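+    # The hand-rolled loop below does roughly what optimize(objective, n_steps(...))
+    # does, written out so the per-step loss history can be collected and returned
+    # for inspecting the optimization schedule.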
+    objective = InputOptimization(
+        net=net,
+        input_param=param,
+        transform=robustness_transforms,
+        targets=[target, param],
+        loss_function=alpha_neuron(target, param),
+    )
+
+    optimizer = optim.Adam(objective.parameters(), lr=0.025)
+
+    history = []
+    for _ in tqdm(range(256)):
+        optimizer.zero_grad()
+
+        loss_value = objective.loss()
+        history.append(loss_value.cpu().detach().numpy())
+        (-1 * loss_value.mean()).backward()
+        optimizer.step()
+
+    return history
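+
+
+# Hypothetical usage (not part of this test script yet): plot the recorded
+# schedule to compare settings, e.g.
+#     import matplotlib.pyplot as plt
+#     plt.plot(two_stage_opt())
+#     plt.xlabel("step"); plt.ylabel("loss")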