In [1]:
import ipyvolume as ipv
import numpy as np
import torch
import torch.nn.functional as F

In [2]:
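# Corner numbering used by `cube()`: corners 0-3 lie on the z = z face and
# corners 4-7 on the z = z + depth face; 0 and 4 are the minimum-x,
# minimum-y corners, 1/5 are +x, 3/7 are +y, 2/6 are +x+y.
#
#    7       6
# 3      2
#
#    4       5
# 0      1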

def new_agg():
    """Return an empty mesh accumulator for `aggregate_plots` / `plot_agg`.

    Keys:
      'shift'     -- number of vertices stored so far; used to offset the
                     triangle indices of each newly appended part.
      'x','y','z' -- lists of per-part coordinate arrays.
      'color'     -- flat list with one color entry per vertex.
      'triangles' -- list of per-part triangle index arrays.
    """
    return {
        'shift': 0,
        'x': [],
        'y': [],
        'z': [],
        'color': [],
        'triangles': [],
    }

def plot_agg(agg):
    """Draw everything accumulated in `agg` as one ipyvolume triangle mesh.

    The per-part coordinate and triangle-index arrays are concatenated so a
    single `plot_trisurf` call renders every part; `agg['color']` is passed
    through as-is (it already holds one entry per vertex).
    """
    xs, ys, zs, tris = (np.concatenate(agg[key])
                        for key in ('x', 'y', 'z', 'triangles'))
    ipv.plot_trisurf(xs, ys, zs, triangles=tris, color=agg['color'])
    
def aggregate_plots(agg, x, y, z, triangles, color):
    """Append one mesh part (vertices, triangles, color) to accumulator `agg`.

    `triangles` indexes into this part's own vertices. It is offset by the
    number of vertices already stored, so the indices stay valid once all
    parts are concatenated. Every vertex of the part gets the same `color`.
    """
    offset = agg['shift']
    n_vertices = len(x)
    for key, coords in (('x', x), ('y', y), ('z', z)):
        agg[key].append(coords)
    agg['color'].extend([color] * n_vertices)
    # Element-wise shift (expects an array-like that supports `+ int`).
    agg['triangles'].append(triangles + offset)
    agg['shift'] = offset + n_vertices

def cube(agg, x, y, z, width, height, depth, color="gray"):
    """Append an axis-aligned box to the mesh accumulator `agg`.

    The box spans [x, x+width] x [y, y+height] x [z, z+depth]. Corners 0-3
    lie on the z = z face and corners 4-7 on the z = z + depth face. Each of
    the 6 faces is split into 2 triangles (12 triangles, 36 indices, stored
    as one flat index array).

    Parameters
    ----------
    agg : dict
        Accumulator created by `new_agg()`; modified in place.
    x, y, z : float
        Minimum corner of the box.
    width, height, depth : float
        Extent of the box along x, y and z.
    color : optional
        Color applied to all 8 vertices (e.g. a name or an RGB triple).
    """
    pointsx = np.array([0.0, width, width, 0.0, 0.0, width, width, 0.0]) + x
    pointsy = np.array([0.0, 0.0, height, height, 0.0, 0.0, height, height]) + y
    pointsz = np.array([0.0, 0.0, 0.0, 0.0, depth, depth, depth, depth]) + z
    # `np.int` was only an alias of the builtin `int` and was removed in
    # NumPy 1.24; using `int` keeps the same dtype and works on all versions.
    triangles = np.array([0, 1, 2, 0, 2, 3,   # z = z face
                          4, 7, 5, 5, 7, 6,   # z = z + depth face
                          1, 5, 6, 1, 6, 2,   # x = x + width face
                          0, 3, 4, 4, 3, 7,   # x = x face
                          3, 2, 7, 2, 6, 7,   # y = y + height face
                          0, 1, 5, 0, 5, 4,   # y = y face
                          ], dtype=int)
    aggregate_plots(agg, pointsx, pointsy, pointsz, triangles=triangles, color=color)

def draw_tensor(tensor, rgb=False, W=32, H=32, Cskip=1, BSskip=1, BSYskip=1, scale=True):
    """Render a tensor as a 3-D grid of shaded unit cubes with ipyvolume.

    The input is left-padded with singleton dims up to 5-D and read as
    (BSY, BS, C, H, W). Each (H, W) slice is average-pooled down to at most
    `H` x `W` cells and drawn as a sheet of unit cubes whose brightness is the
    cell value. Channels are stacked along -z, the BS axis is laid out along
    x and the BSY axis along y.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor with at most 5 dims. It is detached, moved to the CPU and cast
        to float32, so GPU tensors, tensors that require grad and non-float32
        tensors are all accepted.
    rgb : bool
        If True, channels are tinted red/green/blue in turn; otherwise white.
    W, H : int
        Maximum number of cells per slice along width / height.
    Cskip, BSskip, BSYskip : int
        Strides used to subsample the C, BS and BSY axes.
    scale : bool
        If True, each (H, W) slice is min-max normalised before drawing (the
        range is clamped to at least 0.01 to avoid dividing by ~0).
    """
    tensor = tensor.detach().cpu().float()
    while tensor.dim() < 5:
        tensor = tensor[None, ...]
    H = min(H, tensor.shape[-2])
    W = min(W, tensor.shape[-1])
    tensor = tensor[::BSYskip, ::BSskip, ::Cskip, :, :]

    # Pool each 4-D (BS, C, h, w) slice separately down to (H, W).
    tensor = torch.stack([F.adaptive_avg_pool2d(t, (H, W)) for t in tensor])
    if scale:
        # Per-slice min / max over the two spatial dims.
        max_value = (tensor.max(dim=-1)[0].max(dim=-1)[0])[..., None, None]
        min_value = (tensor.min(dim=-1)[0].min(dim=-1)[0])[..., None, None]
        scale_by = (max_value - min_value).clamp_min(0.01)
        tensor = (tensor - min_value) / scale_by
    values = tensor.numpy()
    bsy, bs, c, h, w = values.shape
    if rgb:
        colors = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
    else:
        colors = np.array([[1.0, 1.0, 1.0]])

    agg = new_agg()
    tx = 1.0  # cell pitch along x
    ty = 1.0  # cell pitch along y
    tc = 1.2  # distance between channel sheets along -z
    # Distance between neighbouring BS (x) / BSY (y) entries.
    stride_x = W * 1.1 + 1
    stride_y = H * 1.1 + 1
    # Shift so the whole row / column of entries is centred over x in [0, W]
    # and y in [0, H]. This must use the same stride as the placement below,
    # otherwise the layout drifts off-centre as bs / bsy grow.
    sx = -(bs - 1) * stride_x * 0.5
    sy = -(bsy - 1) * stride_y * 0.5
    for by in range(bsy):
        for b in range(bs):
            for i in range(c):
                channel_color = colors[i % len(colors)]
                for x in range(w):
                    for y in range(h):
                        cube(agg,
                             x * tx + sx + stride_x * b,
                             # Flip rows so tensor row 0 is drawn at the top.
                             (h - y - 1) * ty + sy + stride_y * by,
                             -i * tc,
                             1.0, 1.0, 1.0,
                             channel_color * values[by, b, i, y, x])
    ipv.clear()
    ipv.figure(width=400, height=400, controls=False, controls_light=False)
    ipv.style.box_off()
    plot_agg(agg)
    ipv.xyzlabel("W", "H", "C")
    m = max(w, h, c * 1.1)
    ipv.xyzlim(0, m)
    ipv.show()

def draw_image_tensor(*args, **kwargs):
    """Call `draw_tensor` with `rgb=True`, tinting channels red/green/blue.

    All other positional and keyword arguments are forwarded unchanged; any
    `rgb` keyword supplied by the caller is overridden with True.
    """
    draw_tensor(*args, **dict(kwargs, rgb=True))

In [3]:
from fastai import *
from fastai.vision import *

# fastai's PETS dataset; `path` is its root directory.
path = untar_data(URLs.PETS)

path_anno = path/'annotations'
path_img = path/'images'

fnames = get_image_files(path_img)

# Presumably seeds numpy's global RNG so the random split done when the
# data bunch is built in the next cell is reproducible.
np.random.seed(2)
# Group 1 captures the file name before the trailing "_<digits>.jpg"
# (presumably used as the class label by `from_name_re`). The dot is escaped
# so only a literal ".jpg" extension matches.
pat = r'/([^/]+)_\d+\.jpg$'

In [4]:
# Small images (32x32) and a batch of 4 keep the cube count manageable:
# draw_tensor emits one cube per pooled pixel per channel per image.
data = ImageDataBunch.from_name_re(
    path_img, fnames, pat,
    size=32, bs=4,
)
x, y = data.one_batch()
draw_image_tensor(x)

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=(0.0, 0.0, 0.0, …

In [5]:
learn = cnn_learner(data, models.resnet34)
# First module of the model's first child; the next cell calls it on the
# image batch. Assumes it is a conv layer with a `.weight` — TODO confirm.
conv = learn.model[0][0]
# If the weight is (out, in, kh, kw), draw_tensor lays filters out along x
# and input channels along z (tinted R/G/B).
draw_image_tensor(tensor=conv.weight)

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=(0.0, 0.0, 0.0, …

In [6]:
# Run the batch through the first layer and draw its activations.
# The input is moved to whatever device the layer's weights live on, so the
# cell also runs on CPU-only machines. no_grad skips building an autograd
# graph because the output is only visualised.
with torch.no_grad():
    out = conv(x.to(conv.weight.device))
draw_tensor(out)

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=(0.0, 0.0, 0.0, …