# Changing the rules about windows making reflections

This notebook demonstrates a rank-one weight change that reverses a nontrivial rule connecting the presence of windows to the appearance of reflections on countertops. When the rank-one change is made in layer6, adding a window removes a reflection instead of adding one.

To investigate this we proceed in several steps.

1. We load a pretrained progressive GAN for kitchen scenes, and then use GAN dissection to identify units in layer4 that control the presence of windows in the scene. This analysis is preloaded.
2. We choose a set of generated scenes that have windows or countertops, and use an interactive tool to create masks that activate or deactivate units to add and remove windows in those scenes. Using the same tool, we then mark reflections on countertops that change when windows are added or removed.
3. To identify whether any single layer is decisive in connecting reflections to windows, we iterate over layers 1 through 8 and, for each one, use optimization to search for a weight change in that layer which reverses the reflections (darkening countertops when windows are added, and vice versa). A change in layer6 achieves this best. Furthermore (not shown here), the change turns out to be very close to rank-one.
4. To find an exactly rank-one change to layer6, we repeat the optimization to flip the reflection rule, but periodically project the change to rank one (using SVD) during the optimization.
5. To demonstrate the rule change, we show an interactive tool that allows windows to be added or removed in scenes generated by both the original, unchanged model and the new model in which the reflection rule is flipped.

## 1. Load a pretrained kitchen progressive GAN

Also load a precomputed dissection that identifies units responsible for windows in layer4.

In [None]:
# 1. Load a kitchen model and dissection.

import torch, copy, os, json
from utils import nethook, proggan, zdataset, imgviz, show, upsample, tally
from utils import renormalize, quickdissect
torch.backends.cudnn.benchmark = True

modelname = 'kitchen'
dirname = 'masks/reflections/progan-%s' % modelname
dis = quickdissect.DissectVis(model=modelname)  # precomputed dissection results
device = torch.device('cuda:0')
model = proggan.load_pretrained(modelname)
model.to(device)
zds = zdataset.z_dataset_for_model(model, size=1000)  # latent vectors indexed by image number
os.makedirs(dirname, exist_ok=True)
torch.set_grad_enabled(False)  # inference by default; optimization cells re-enable grad locally

def rf(fn):
    return os.path.join(dirname, fn)

Run the model to generate 30 images that happen to have reflective countertops.

In [None]:
reflective_imgnums = sorted([
    958, 505, 487, 229, 723, 777, 819, 133,
    285, 275, 246, 594, 244, 21, 524, 770, 788,
    947, 350, 587, 889, 14, 38, 46, 52,
    126, 147, 186, 196, 33
])
zbatch = torch.cat([zds[i][0][None] for i in reflective_imgnums]).cuda()
imgbatch = model(zbatch)
iv = imgviz.ImageVisualizer(140)

show([[i, iv.image(imgbatch[j])] for j, i in enumerate(reflective_imgnums)])

## 2. Interactively create masks for adding and removing windows

The tool allows a user to point at two things:
1. locations where a window should be added or removed
2. locations of shiny countertops whose reflections change in response to the window.

This tool is used to save or load a set of such masks.

In [None]:
# 2. Collect and freeze a set of interventions both to add and to remove windows.
from utils import labwidget, paintwidget

class InteractionProber(labwidget.Widget):
    def __init__(self, model, state=None):
        super().__init__()
        self.model = model
        self.state = {int(k): v for k,v in state.items()} if state else {}
        self.orig = labwidget.Div()
        self.menu = labwidget.Menu(choices=EXAMPLES).on(
            'selection', self.set_selection)
        self.valuebox = labwidget.Textbox(10.0).on(
            'value', self.rerender)
        self.intervention = paintwidget.PaintWidget().on(
            'mask', self.rerender)
        self.ibutton = labwidget.Button('clear edit').on(
            'click', self.clear_intervention)
        self.revert = paintwidget.PaintWidget().on(
            'mask', self.rerender)
        self.rbutton = labwidget.Button('clear mask').on(
            'click', self.clear_revert)
        self.menu.selection = EXAMPLES[0]
        # self.rerender(None)
    
    def set_selection(self):
        imgnum = int(self.menu.selection)
        if imgnum in self.state:
            record = self.state[imgnum]
            self.valuebox.value, self.intervention.mask, self.revert.mask = [
                record[k] for k in 'value intervention revert'.split()]
        else:
            self.valuebox.value, self.intervention.mask, self.revert.mask = [
                '10.0', '', ''
            ]
        self.rerender()
        
    def rerender(self):
        imgnum = int(self.menu.selection)
        value = float(self.valuebox.value)
        self.state[imgnum] = dict(
            value=self.valuebox.value,
            intervention=self.intervention.mask,
            revert=self.revert.mask)
        self.orig.show(renormalize.as_image(self.render_result(
                imgnum, value, None, None)))
        self.intervention.image = renormalize.as_url(self.render_result(
                imgnum, value, self.intervention.mask, None))
        self.revert.image = renormalize.as_url(self.render_result(
                imgnum, value, self.intervention.mask, self.revert.mask))
        
    def render_result(self, image_num, value, intervention_mask, revert_mask):
        imodel = nethook.InstrumentedModel(copy.deepcopy(self.model))
        layername = 'layer4'
        imodel.retain_layer(layername)
        z = zds[image_num][0][None].cuda()
        orig_out = imodel(z)
        if not intervention_mask:
            return orig_out[0]
        acts = imodel.retained_layer(layername)
        units = dis.top_units(layername, 'window', 20)
        iarea = renormalize.from_url(
            intervention_mask, target='pt', size=acts.shape[2:])[0]
        def editrule(x, name):
            x[:,units] = (
                value * iarea[None].to(x.device) +
                x[:,units] * (1 - iarea[None].to(x.device))
            )
            return x
        imodel.edit_layer(layername, rule=editrule)
        inter_out = imodel(z)
        if not revert_mask:
            return inter_out[0]
        rarea = renormalize.from_url(
            revert_mask, target='pt', size=inter_out.shape[2:]
        )[0][None][None].to(orig_out.device)
        revert_out = (rarea * orig_out) + (1 - rarea) * inter_out
        return revert_out[0]
    
    def clear_intervention(self):
        self.intervention.mask = ''
        self.rerender()

    def clear_revert(self):
        self.revert.mask = ''
        self.rerender()
        
    def widget_html(self):
        return show.html([[self.menu], ['intervention strength:'], [self.valuebox], [[self.orig], [self.intervention, self.ibutton], [self.revert, self.rbutton]]])

EXAMPLES = reflective_imgnums

with open(rf('posneg.json')) as f:
    data = {int(k): v for k,v in json.load(f).items() if int(k) in EXAMPLES}

prober = InteractionProber(model, data)
show(prober)

Three panes are shown in the tool above.  To use it:
1. Select an image number.
2. Choose a strength for windows to add or remove (positive to add, negative to remove).
3. In the middle pane, draw windows where they should be added or removed.
4. In the right pane, mark reflective countertops that change in response.
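The arithmetic behind the middle and right panes can be sketched in NumPy. This is a minimal sketch, assuming single-image activations of shape (channels, H, W) and masks already resized to the matching resolution; the helper names `apply_unit_edit` and `revert_region` are hypothetical and not part of `utils`. The first function mirrors the `editrule` blend applied at layer4, and the second mirrors the revert blend applied to the output image in `render_result`.

```python
import numpy as np

def apply_unit_edit(acts, units, mask, value):
    """Blend selected channels toward a constant inside a soft spatial mask.

    acts:  activations for one image, shape (channels, H, W)
    units: channel indices to edit (e.g. the top window units)
    mask:  values in [0, 1], shape (H, W), at the layer's resolution
    value: intervention strength (positive adds windows, negative removes them)
    """
    edited = acts.copy()  # do not modify the caller's array
    edited[units] = value * mask + acts[units] * (1.0 - mask)
    return edited

def revert_region(original_img, edited_img, revert_mask):
    """Take pixels from the original image wherever revert_mask is 1."""
    return revert_mask * original_img + (1.0 - revert_mask) * edited_img

# Tiny usage example with made-up numbers.
acts = np.zeros((4, 2, 2))
mask = np.array([[1.0, 0.0],
                 [0.0, 0.0]])
out = apply_unit_edit(acts, [1, 3], mask, 10.0)
print(out[1])  # 10.0 at the top-left position, 0.0 elsewhere

original = np.ones((3, 2, 2))
edited = np.zeros((3, 2, 2))
composite = revert_region(original, edited, mask)
```

Because the masks are soft, values between 0 and 1 give a partial blend, which is why the same formula serves both for painting windows and for restoring countertop pixels.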

The masks are stored by the tool, and the cell below can be used to save them to a file for use in training (change `if False` to `if True` to write the file).
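One detail worth knowing when saving: JSON object keys are always strings, so the integer image numbers used as keys in `prober.state` come back as strings when `posneg.json` is read. That is why the loading code converts them with `int(k)`. A minimal, self-contained illustration (the state entry below is made-up example data):

```python
import json

# Made-up example of the prober's state: keys are integer image numbers.
state = {958: {'value': '10.0', 'intervention': '', 'revert': ''}}

raw = json.loads(json.dumps(state))
print(list(raw.keys()))  # ['958']: JSON turned the int key into a string

# Convert keys back to int, as InteractionProber and the loading cell do.
restored = {int(k): v for k, v in raw.items()}
```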

In [None]:
if False: # True to save changes
    with open(rf('posneg.json'), 'w') as f:
        json.dump(prober.state, f, indent=1)
len(prober.state)

## 3. Use optimization to try to find a change in a layer which reverses the reflections

Which layer is responsible for the reflection rule?

* To answer this, we ask: what rule would have to change so that all of the highlighted reflections in these examples are reversed (dark when windows are added, bright when windows are removed), without changing the rest of the image?
* We ask this question by searching for weight changes restricted to a single layer.
* We repeat this for each layer to see whether changes in one layer affect the reflections more than changes in others.
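The question can be written as an objective over one layer's weights. The notation is ours, chosen to follow the code in this section: $h_8^{+}$ and $h_8^{-}$ are the layer8 activations in `high_out` (windows added with strength 10) and `low_out` (windows removed with strength -5), both computed with the original weights; $M$ is the countertop (`revert`) mask resized to layer8; and $G_{\le 8}(z; W_\ell, v)$ is the generator up to layer8 with layer-$\ell$ weights $W_\ell$ and a layer4 window edit of strength $v$.

$$
\begin{aligned}
T^{+} &= M \odot h_8^{-} + (1 - M) \odot h_8^{+}, \\
T^{-} &= M \odot h_8^{+} + (1 - M) \odot h_8^{-}, \\
\mathcal{L}(W_\ell) &= \operatorname{MSE}\!\big(G_{\le 8}(z;\, W_\ell,\, 10),\, T^{+}\big) + \operatorname{MSE}\!\big(G_{\le 8}(z;\, W_\ell,\, -5),\, T^{-}\big), \\
W_\ell^{\star} &= \arg\min_{W_\ell} \mathcal{L}(W_\ell), \qquad \ell = 1, \dots, 8.
\end{aligned}
$$

Outside $M$, each target equals what the unmodified network produces under the same window edit; inside $M$, it takes the countertop appearance from the opposite edit. In `optimize_layer` this loss has no regularization term, and only the conv weight of layer $\ell$ is trainable.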

In [None]:
# Ingredients:
print('zbatch', zbatch.shape, zbatch.device)
print('EXAMPLES', EXAMPLES)
print('data', sorted(data.keys()))
print('data contents', data[next(iter(data))].keys())

In [None]:
# More ingredients: get all layer activations when the interventions are
# high and low and etc.

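def apply_window_edit(imodel, intervention, layername='layer4'):
    """Set the top window units to `intervention` inside each example's
    intervention mask, for the whole batch. Does nothing if intervention is None."""
    if intervention is None:
        return
    units = dis.top_units(layername, 'window', 20)
    imodel.retain_layer(layername)
    # Run one image only to learn the spatial size of the layer's activations.
    imodel(zbatch[0][None])
    acts = imodel.retained_layer(layername)
    # One mask per example; assumes `data` has an entry for every image in
    # zbatch, so that sorted key order matches the order of zbatch.
    batch_area = torch.cat([
        renormalize.from_url(
            data[k]['intervention'], target='pt',
            size=acts.shape[2:])[0][None,None]
        for k in sorted(data.keys())]).to(acts.device)
    def editrule(x, name):
        x = x.clone()
        x[:,units] = (
            intervention * batch_area + x[:,units] * (1 - batch_area)
        )
        return x
    imodel.edit_layer(layername, rule=editrule)
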
# Some functions to run the network with various interventions.
def render_with_change(intervention=None, return_layers=[]):
    """Run zbatch with an optional layer4 window edit; return the retained
    activations of `return_layers` plus the output images under key 'x'."""
    imodel = nethook.InstrumentedModel(copy.deepcopy(model))
    imodel.retain_layers(return_layers)
    apply_window_edit(imodel, intervention)
    out = imodel(zbatch)
    results = imodel.retained_features()
    results['x'] = out
    return results

ALL_LAYERS = ['layer%s' % i for i in range(1, 15)]
orig_out = render_with_change(None, return_layers=ALL_LAYERS)
high_out = render_with_change(10.0, return_layers=ALL_LAYERS)
low_out = render_with_change(-5.0, return_layers=ALL_LAYERS)

# Use this like
# blended_layer10 = paste_acts(high_out['layer10'], low_out['layer10'], 'revert'/'intervention')

def paste_acts(background, foreground, field):
    batch_area = torch.cat([
            renormalize.from_url(
                data[k][field], target='pt',
                size=foreground.shape[2:])[0][None,None]
            for k in sorted(data.keys())]).to(foreground.device)
    return (batch_area * foreground) + (1 - batch_area) * background    

def render_from_layer(back_all, fore_all, field, layername):
    rendering_model = nethook.subsequence(model, after_layer=layername)
    batch_acts = paste_acts(back_all[layername], fore_all[layername], field)
    return rendering_model(batch_acts)

if True:
    reflect_only_img = render_from_layer(low_out, high_out, 'revert', 'layer8')
    window_only_img = render_from_layer(high_out, low_out, 'revert', 'layer8')
    show([[i, iv.image(orig_out['x'][j]), iv.image(window_only_img[j]), iv.image(reflect_only_img[j])]
          for j, i in enumerate(reflective_imgnums)])


In [None]:
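# Optimize a single layer's weights so that the network reverses the reflections.
# The loss compares layer8 features against targets in which the countertop
# ('revert') region is swapped between the high-window and low-window runs,
# while the window edit is applied at layer4.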

def optimize_layer(optlayer, windowlayer='layer4', targlayer='layer8'):
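    # Targets at layer8 with the reflection rule reversed: windows added but
    # reflections removed (high), and windows removed but reflections present (low).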
    high_target_feature = paste_acts(high_out[targlayer], low_out[targlayer], 'revert')
    low_target_feature = paste_acts(low_out[targlayer], high_out[targlayer], 'revert')
    net = nethook.InstrumentedModel(copy.deepcopy(
        nethook.subsequence(model, last_layer=targlayer)))
    render_model = nethook.subsequence(model, after_layer=targlayer)
    def compute_loss():
        apply_window_edit(net, 10.0, windowlayer) # amp up windows to 10 at layer4
        reconstruct = torch.nn.functional.mse_loss(high_target_feature, net(zbatch))
        net.remove_edits()
        apply_window_edit(net, -5.0, windowlayer) # suppress windows (-5) at layer4
        reconstruct += torch.nn.functional.mse_loss(low_target_feature, net(zbatch))
        net.remove_edits()
        return reconstruct # no regularization
    nethook.set_requires_grad(False, net)
    weight = getattr(net.model, optlayer).conv.weight
    nethook.set_requires_grad(True, weight)
    params = [weight]
    optimizer = torch.optim.Adam(params, lr=0.02)
    for t in range(101):
        if t > 0:
            with torch.enable_grad():
                loss = compute_loss()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            if not t % 100:
                with torch.no_grad():
                    print(t, compute_loss().item())
                    imgout = render_model(net(zbatch))
                    show([[iv.image(imgout[i])] for i in range(6)])
    return weight.detach().clone()
unoptimized_weights = {}
optimized_weights = {}

for i in range(1, 9):
    optlayer = 'layer%d' % i
    print(optlayer)
    unoptimized_weights[optlayer] = getattr(model, optlayer).conv.weight.detach().clone()
    optimized_weights[optlayer] = optimize_layer(optlayer)

## 4. Find a rank-one change to layer6 to reverse reflections.

Above, we find that layer6 is the best layer to change. Furthermore, analysis (not shown here) reveals that the needed change in layer6 is very close to rank-one.

Therefore, we next seek an exactly rank-one change in layer6 that reverses the reflection rule. This optimization runs for 10,000 iterations and takes some time; if a saved result (`reflection_switched_layer6.npz`) exists, it is loaded instead.


In [None]:
import numpy
def optimize_layer_rank_one(optlayer, windowlayer='layer4', targlayer='layer8'):
    # Target feature is "reflections off" at layer8
    high_target_feature = paste_acts(high_out[targlayer], low_out[targlayer], 'revert')
    low_target_feature = paste_acts(low_out[targlayer], high_out[targlayer], 'revert')
    net = nethook.InstrumentedModel(copy.deepcopy(
        nethook.subsequence(model, last_layer=targlayer)))
    render_model = nethook.subsequence(model, after_layer=targlayer)
    
    def compute_loss():
        apply_window_edit(net, 10.0, windowlayer) # amp up windows to 10 at layer4
        reconstruct = torch.nn.functional.mse_loss(high_target_feature, net(zbatch))
        net.remove_edits()
        apply_window_edit(net, -5.0, windowlayer) # suppress windows (-5) at layer4
        reconstruct += torch.nn.functional.mse_loss(low_target_feature, net(zbatch))
        net.remove_edits()
        return reconstruct # no regularization

    nethook.set_requires_grad(False, net)
    weight = getattr(net.model, optlayer).conv.weight
    orig_weight = weight.clone()
    nethook.set_requires_grad(True, weight)
    params = [weight]
    optimizer = torch.optim.Adam(params, lr=0.01)
    for t in range(10001):
        if t > 0:
            with torch.enable_grad():
                loss = compute_loss()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
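        if not t % 10:
            # Every 10 steps, replace the accumulated weight change with its
            # best rank-one approximation (leading singular triple of the SVD).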
            with torch.no_grad():
                delta = (weight - orig_weight).detach().permute(1,0,2,3)
                mat = delta.reshape(delta.shape[0], -1)
                u, s, v = mat.svd()
                rankone = (u[:,:1] * s[None,:1]).matmul(v[:,:1].permute(1,0))
                rankone = rankone.view(delta.shape).permute(1,0,2,3)
                weight[...] = orig_weight + rankone
                if not t % 100:
                    print(t, compute_loss().item())
                    imgout = render_model(net(zbatch))
                    show([[iv.image(imgout[i])] for i in range(6)])
    return weight.detach().clone()
r1_unoptimized_weights = {}
r1_optimized_weights = {}

if os.path.isfile(rf('reflection_switched_layer6.npz')):
    loaded = numpy.load(rf('reflection_switched_layer6.npz'))
    r1_unoptimized_weights['layer6'] = torch.from_numpy(loaded['unopt_layer6']).cuda()
    r1_optimized_weights['layer6'] = torch.from_numpy(loaded['opt_layer6']).cuda()
else:    
    for i in range(6, 7):
        optlayer = 'layer%d' % i
        print(optlayer)
        r1_unoptimized_weights[optlayer] = getattr(model, optlayer).conv.weight.detach().clone()
        r1_optimized_weights[optlayer] = optimize_layer_rank_one(optlayer)

In [None]:
# Save the weights (change False to True to write the file).
if False:
    r1_weights = {'unopt_%s' % k: v.cpu().numpy() for k, v in r1_unoptimized_weights.items()}
    r1_weights.update({'opt_%s' % k: v.cpu().numpy() for k, v in r1_optimized_weights.items()})
    numpy.savez(rf('reflection_switched_layer6.npz'), **r1_weights)

## 5. Demonstrate the reversed rule.

By making a single rank-one change to layer6, we have now reversed the reflection rule.

To see the effect, paint a window into the scene. In the original model this will often create reflections on countertops, but in the modified model reflections are not added, or are reduced.

In [None]:
class H:
    def __init__(self, html):
        self.html = html
    def _repr_html_(self):
        return self.html

class ModelInterventionComparator(labwidget.Widget):
    def __init__(self, nameA, modelA, nameB, modelB):
        super().__init__()
        self.nameA = nameA
        self.modelA = copy.deepcopy(modelA)
        self.nameB = nameB
        self.modelB = copy.deepcopy(modelB)
        self.imgnumbox = labwidget.Textbox(19, desc="imgnum: ", size=4).on(
            'value', self.clear_intervention)
        self.valuebox = labwidget.Textbox(10.0, desc="&nbsp; intervention strength: ", size=4).on(
            'value', self.rerender)
        self.origA = labwidget.Image()
        self.canvasA = paintwidget.PaintWidget(oneshot=False, brushsize=20).on(
            'mask', self.sync_mask)
        self.origB = labwidget.Image()
        self.canvasB = paintwidget.PaintWidget(oneshot=False, brushsize=20).on(
            'mask', self.sync_mask)
        self.ibutton = labwidget.Button('clear edit').on(
            'click', self.clear_intervention)
        self.rerender()
    
    def sync_mask(self, e):
        # Make masks the same.
        self.canvasA.mask = e.value
        self.canvasB.mask = e.value
        self.rerender()
        
    def rerender(self):
        imgnum = int(self.imgnumbox.value)
        value = float(self.valuebox.value)
        self.origA.render(renormalize.as_image(self.render_result(
                self.modelA, imgnum, value, None)))
        self.canvasA.image = renormalize.as_url(self.render_result(
                self.modelA, imgnum, value, self.canvasA.mask))
        self.origB.render(renormalize.as_image(self.render_result(
                self.modelB, imgnum, value, None)))
        self.canvasB.image = renormalize.as_url(self.render_result(
                self.modelB, imgnum, value, self.canvasB.mask))
        
    def render_result(self, model, image_num, value, intervention_mask):
        with torch.no_grad():
            torch.cuda.empty_cache()
            imodel = nethook.InstrumentedModel(copy.deepcopy(model))
            layername = 'layer4'
            imodel.retain_layer(layername)
            z = zds[image_num][0][None].cuda()
            orig_out = imodel(z)
            if not intervention_mask:
                return orig_out[0]
            acts = imodel.retained_layer(layername)
            units = dis.top_units(layername, 'window', 20)
            iarea = renormalize.from_url(
                intervention_mask, target='pt', size=acts.shape[2:])[0]
            def editrule(x, name):
                x[:,units] = (
                    value * iarea[None].to(x.device) +
                    x[:,units] * (1 - iarea[None].to(x.device))
                )
                return x
            imodel.edit_layer(layername, rule=editrule)
            inter_out = imodel(z)
            return inter_out[0]
    
    def clear_intervention(self):
        self.canvasA.mask = ''
        self.canvasB.mask = ''
        self.rerender()

       
    def widget_html(self):
        return show.html([
            H('<hr style="border: 2px solid gray">'),
            [self.imgnumbox], [self.valuebox], # [self.ibutton],
            H('<hr style="border: 2px solid gray">'),
            [[self.origA], [self.canvasA]],
            self.nameA,
            H('<hr style="border: 2px solid gray">'),
            [[self.origB], [self.canvasB]],
            self.nameB,
            H('<hr style="border: 2px solid gray">'),
        ])

original_model = proggan.load_pretrained('kitchen').cuda()
modified_model = copy.deepcopy(original_model)
with torch.no_grad():
    modified_model.layer6.conv.weight[...] = r1_optimized_weights['layer6']

comparator = ModelInterventionComparator(
    'Original Unchanged Model (Paint windows on right)', original_model,
    'Model With Reflection Rule Inverted', modified_model)
show(comparator)

To use the tool above:
1. Select a generated image to manipulate.
2. Choose a strength for windows to add (positive) or remove (negative).
3. Paint windows on the right-hand image in either row (the mask is shared between the two models).
4. Observe the effect on reflections in both the original model and the altered model.