# import modules

In [None]:
import numpy as np
import scipy.ndimage as nd
from PIL import Image

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# config

In [None]:
# Checkpoint file holding the trained weights.
LOAD_PATH = "nsfw.pth"

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [None]:
# Image -> normalized tensor: convert to a tensor, then shift/scale each of the
# three channels with a fixed per-channel mean and standard deviation.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def deprocess(x):
    """Undo the per-channel normalization applied by ``preprocess``.

    Args:
        x: Tensor whose last dimension holds the three normalized channels
            (the length-3 constants broadcast against that dimension).

    Returns:
        ``x * std + mean`` as a new tensor on the same device as ``x``.
    """
    # Build the constants on x's own device so the function works no matter
    # where the caller keeps the tensor.
    std = torch.tensor([0.229, 0.224, 0.225], device=x.device)
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)
    return x * std + mean

# load model

In [None]:
# Build a GoogLeNet, replace its final classifier with a 148-output linear
# layer, and restore the trained weights from LOAD_PATH.
model = torchvision.models.googlenet()
model.fc = nn.Linear(model.fc.in_features, 148)
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved on
# a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(LOAD_PATH, map_location=DEVICE))
model = model.to(DEVICE)
# Evaluation mode: the network is only used for forward/backward passes here,
# never trained.
model.eval()

# generate dream image

In [None]:
def make_step(model, x, end, step, objective_fn):
    """Apply one normalized gradient-ascent update to ``x``.

    ``x`` is fed through ``model``'s direct children in registration order,
    stopping after the child named ``end``. The activations are then
    back-propagated using ``objective_fn(activations)`` as the upstream
    gradient, and ``x`` is moved along its gradient by ``step`` divided by the
    mean absolute gradient.

    Args:
        model: Module whose ``named_children()`` are applied sequentially.
        x: Input tensor. It is modified in place: ``requires_grad`` is set to
            True and its data is overwritten with the updated values.
        end: Name of the child after which the forward pass stops. If no child
            has this name, every child is applied.
        step: Step size of the update.
        objective_fn: Callable mapping the activations to the gradient tensor
            passed to ``backward``.

    Returns:
        A detached copy of the updated ``x``.
    """
    x.requires_grad = True
    # Drop any gradient left over from an earlier backward pass; otherwise
    # backward() would accumulate into it.
    x.grad = None
    model.zero_grad()

    activations = x
    for name, child in model.named_children():
        activations = child(activations)
        if name == end:
            break

    activations.backward(objective_fn(activations))

    grad = x.grad.data
    grad_scale = grad.abs().mean()
    # An all-zero gradient would make the normalization 0/0 and fill the image
    # with NaNs; in that case there is nothing to ascend, so leave x as is.
    if grad_scale > 0:
        x.data = x.data + step / grad_scale * grad

    return x.clone().detach()

In [None]:
def _forward_until(model, x, end):
    """Feed ``x`` through ``model``'s children in order, stopping after ``end``."""
    for name, child in model.named_children():
        x = child(x)
        if name == end:
            break
    return x


def deep_dream(model, base_img, end, iterations, step, octave_scale, num_octave, guide_image=None):
    """Generate a dream image from ``base_img`` over several scales.

    The image is processed from the smallest octave to the largest. At each
    octave the detail produced so far is resized and added to that octave's
    base image, then ``iterations`` calls to ``make_step`` are applied.

    Args:
        model: Network whose child named ``end`` provides the activations.
        base_img: Image to start from; must be accepted by ``preprocess``.
        end: Name of the child layer whose activations are amplified.
        iterations: Number of ``make_step`` updates per octave.
        step: Step size forwarded to ``make_step``.
        octave_scale: Each successive octave is smaller by this factor.
        num_octave: Number of octaves, including the full-size one.
        guide_image: Optional image (accepted by ``preprocess``). When given,
            each position of the activations is steered toward the guide
            feature it matches best instead of toward the activations
            themselves.

    Returns:
        A ``PIL.Image.Image`` built from the final tensor after
        de-normalization and clamping to [0, 1].
    """
    def objective_fn(dst):
        # Default objective: use the activations as their own gradient.
        return dst.detach()

    # Identity check rather than truthiness: an array-like image has no
    # unambiguous truth value.
    if guide_image is not None:
        guide = preprocess(guide_image).unsqueeze(0).to(DEVICE)
        # The guide features are constants, so no autograd graph is needed for
        # this forward pass.
        with torch.no_grad():
            guide_features = _forward_until(model, guide, end)

        def objective_guide(dst, guide_features=guide_features):
            _, ch, h, w = dst.shape
            dst_flat = dst.detach().reshape(ch, -1)
            guide_flat = guide_features.reshape(ch, -1)
            # Dot product of every activation position with every guide position.
            similarity = dst_flat.t().mm(guide_flat)
            # For each activation position, take the best-matching guide column.
            return guide_flat[:, similarity.argmax(1)].reshape(1, ch, h, w)

        objective_fn = objective_guide

    # Octave pyramid: index 0 is full size, each later entry is smaller by
    # octave_scale.
    img_tensor = preprocess(base_img).unsqueeze(0)
    octaves = [img_tensor]
    shrink = 1.0 / octave_scale
    for _ in range(num_octave - 1):
        smaller = nd.zoom(octaves[-1].numpy(), (1, 1, shrink, shrink), order=1)
        octaves.append(torch.tensor(smaller))

    detail = torch.zeros_like(octaves[-1])
    for octave, octave_base in enumerate(reversed(octaves)):
        h, w = octave_base.shape[-2:]
        if octave > 0:
            # Resize the detail from the previous (smaller) octave to this one.
            h1, w1 = detail.shape[-2:]
            detail = torch.tensor(nd.zoom(detail.numpy(), (1, 1, h / h1, w / w1), order=1))

        img_tensor = (octave_base + detail).to(DEVICE)
        for _ in range(iterations):
            img_tensor = make_step(model, img_tensor, end, step, objective_fn)

        # Keep only what the dreaming added on top of this octave's base image.
        detail = img_tensor.cpu() - octave_base

    # Drop only the batch dimension (a bare squeeze() would also drop a height
    # or width of 1), then move channels last for deprocess and PIL.
    ret = img_tensor.detach().squeeze(0).permute(1, 2, 0)
    ret = deprocess(ret).clamp(0, 1)

    return Image.fromarray(np.uint8((ret.cpu() * 255).numpy()))

In [None]:
# Run deep_dream on sky.jpg, guided by flower.jpg. The call is the last
# expression so the notebook displays the returned image.
image = Image.open('sky.jpg')
guide_image = Image.open('flower.jpg')
deep_dream(
    model,
    image,
    end="inception4b",
    iterations=10,
    step=0.3,
    octave_scale=1.4,
    num_octave=4,
    guide_image=guide_image,
)