In [17]:
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms

import random
import os
import numpy as np

from PIL import Image
import matplotlib.pyplot as plt
import json

In [19]:
img_dir = "./data/seeds/"
for f in os.listdir(img_dir):
    if f[-4:] != "JPEG":
        continue
    img=Image.open(img_dir+f)
    
    print(np.array(img).shape)
    break

(343, 500)


In [165]:
import itertools
import random

def get_img(dir_path):
    
    img = Image.open(path).resize((224,224), Image.LANCZOS)
    img.show()
    return transforms.ToTensor()(img).reshape(-1, 3, 224, 224)

def to_image(img_tensor):
    img = img_tensor.squeeze(0).detach()
    img = img.transpose(0,2).transpose(0,1).numpy()
    return img

def output_table(model):
    d = {}
    def set_table(name):
        def hook(model, i, o):
            d[name] = o
        return hook
    for name, layer in model.named_modules():
        if name == '' or name == 'fc' or isinstance(layer, torch.nn.Sequential) or isinstance(layer, models.resnet.BasicBlock):
            continue
        layer.register_forward_hook(set_table(name))
    return d

def scale(o):
    return (o - o.min()) / (o.max() - o.min())

def update_coverage(coverage_table, output_table, threshold):
    for key in output_table.keys():
        scaled = scale(output_table[key][0])
        scaled = scaled.mean(dim=list(range(len(scaled.shape) - 1)))
        try:
            coverage_table[key] = coverage_table[key] | (scaled > threshold)
        except:
            coverage_table[key] = scaled > threshold
            
def neuron_coverage(coverage_table):
    activated = sum(map(lambda key: coverage_table[key].sum(), coverage_table.keys()))
    neurons = sum(map(lambda key: len(coverage_table[key]), coverage_table.keys()))
    return activated / neurons

def neuron_to_cover(coverage_table):
    to_cover = []
    for layer in coverage_table.keys():
        to_cover.extend(itertools.product([layer], torch.where(coverage_table[layer] == False)[0]))
    return random.choice(to_cover)

def compute_obj1(x, c, out, lambda1):
    loss = sum(o[c] for o in out[1:])
    loss -= out[0][c] * lambda1
    return loss

def compute_obj2(coverage_tables, output_tables):
    neurons_to_cover = [neuron_to_cover(cov) for cov in coverage_tables]
    loss = 0
    for (layer, index), o in zip(neurons_to_cover, output_tables):
        loss += o[layer][..., index].mean()
    return loss

def constraint_light(grad):
    return 1e4 * grad.mean() * torch.ones_like(grad)

In [170]:
class deepXplore:
    def __init__(self, dnns, lambda_1=0.1, lambda_2=0.5, threshold=0.25, s=0.1):
        self.dnns = dnns
        self.output_tables = list(map(output_table, dnns))
        self.coverage_tables = list({} for dnn in dnns)
        self.t = threshold
        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2
        self.s = s
        
    def generate(self, x):
        gen_x = x.detach().clone()
        gen_x.requires_grad=True
        itr = 0
        while True:
            out = [dnn(gen_x).squeeze() for dnn in self.dnns]
            labels = [o.argmax() for o in out]
            d = self.dnns[0]
            classes = json.load(open("data/imagenet_classes.json"))
            if all(label == labels[0] for label in labels):
                for ct, ot in zip(self.coverage_tables, self.output_tables):
                    update_coverage(ct, ot, self.t)
                obj1 = compute_obj1(gen_x, labels[0], out, self.lambda_1)
                obj2 = compute_obj2(self.coverage_tables, self.output_tables)
                loss = obj1 + self.lambda_2 * obj2
                if itr % 100 == 0:
                    print(itr, loss)
                loss.backward()
                grad = constraint_light(gen_x.grad)
                gen_x = gen_x.detach()
                gen_x += self.s * grad
                gen_x.requires_grad=True
                itr += 1
            else:
                break
        return gen_x, itr

In [117]:
vgg16 = models.vgg16(pretrained=True)
o = output_table(vgg16)
vgg16(torch.randn((1,3, 224,224)))
o.keys()

dict_keys(['features.0', 'features.1', 'features.2', 'features.3', 'features.4', 'features.5', 'features.6', 'features.7', 'features.8', 'features.9', 'features.10', 'features.11', 'features.12', 'features.13', 'features.14', 'features.15', 'features.16', 'features.17', 'features.18', 'features.19', 'features.20', 'features.21', 'features.22', 'features.23', 'features.24', 'features.25', 'features.26', 'features.27', 'features.28', 'features.29', 'features.30', 'avgpool', 'classifier.0', 'classifier.1', 'classifier.2', 'classifier.3', 'classifier.4', 'classifier.5', 'classifier.6'])

In [171]:
seeds = 1
img_dir = "./data/seeds/"
resnet18 = models.resnet18(pretrained=True)
resnet34 = models.resnet34(pretrained=True)
dxp = deepXplore([resnet18, resnet34])

while True:
    data_path = img_dir + random.choice(os.listdir(img_dir))
    try:
        x = preprocess_img(data_path)
    except:
        continue
    gen_x, itr = dxp.generate(x)
    if itr == 0:
        continue
    print(gen_x)
    break

['bucket, pail', 'hook, claw']
['bucket, pail', 'hook, claw']
['bucket, pail', 'hook, claw']
['bucket, pail', 'bucket, pail']
0 tensor(2.2718, grad_fn=<AddBackward0>)
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']
['bucket, pail', 'bucket, pail']


KeyboardInterrupt: 

In [92]:
a=torch.tensor(3.0)
a.requires_grad=True
b=a*3+a
b

tensor(12., grad_fn=<AddBackward0>)

In [93]:
b.backward()
a.grad

tensor(4.)