# Visualizing Neural Networks
**Jin Yeom**  
jin.yeom@hudl.com

In [1]:
import html
import os

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import CIFAR10
from torchvision.transforms import functional as TF
from torchvision.utils import make_grid
from torchsummary import summary
from visdom import Visdom
from tqdm import tqdm_notebook as tqdm

from logbook import LogBook

In [2]:
# Record library versions so results can be tied to a specific environment.
for lib_label, lib in (("PyTorch Version:", torch), ("Torchvision Version:", torchvision)):
    print(lib_label, lib.__version__)

PyTorch Version: 1.1.0
Torchvision Version: 0.2.2


In [3]:
# Seed the RNGs this notebook draws from (weight init, shuffling,
# augmentation, reparameterization noise) so a Restart & Run All gives
# repeatable results (cuDNN kernels may still add some nondeterminism on GPU).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:", device)

device: cuda


For visualization, create a `Visdom` client pointing at the running Visdom server's address and port.

In [4]:
# Visdom connection settings. The defaults match a local `python -m visdom.server`;
# set VISDOM_SERVER / VISDOM_PORT / VISDOM_ENV to point elsewhere.
server = os.environ.get('VISDOM_SERVER', 'http://localhost')
port = int(os.environ.get('VISDOM_PORT', 8097))
env = os.environ.get('VISDOM_ENV', 'CIFAR-10_VAE')
viz = Visdom(port=port, server=server, env=env)

# Fail loudly up front instead of letting every later plot call error out.
if viz.check_connection():
    print(f"Visdom hosted at {server}:{port}/env/{env}")
else:
    print(f"WARNING: no Visdom server reachable at {server}:{port}; "
          "start one with `python -m visdom.server` before training.")



Visdom hosted at http://localhost:8097/env/CIFAR-10_VAE


In [5]:
def transform(img, p=0.5):
    """Training-time augmentation: random horizontal flip, then ``to_tensor``.

    The flip decision uses the torch RNG rather than NumPy's. DataLoader seeds
    torch separately in every worker process. On older PyTorch releases (such
    as 1.1.0), however, workers inherit an identical NumPy RNG state, so
    ``np.random`` would make the workers flip images in lock-step.

    Args:
        img: the image CIFAR10 hands to its transform (a PIL image).
        p: probability of flipping the image horizontally (default 0.5).

    Returns:
        The (possibly flipped) image converted by ``TF.to_tensor``.
    """
    if torch.rand(1).item() < p:
        img = TF.hflip(img)
    return TF.to_tensor(img)

# Training images get random horizontal flips (see `transform`).
cifar10_train = CIFAR10(
    'datasets/CIFAR-10', 
    train=True, 
    transform=transform, 
    download=True)
train_loader = DataLoader(
    cifar10_train, 
    batch_size=16, 
    shuffle=True, 
    num_workers=4)

# Test images are NOT augmented: random flips here would make the
# validation loss depend on the flip draws rather than only on the model.
cifar10_test = CIFAR10(
    'datasets/CIFAR-10', 
    train=False, 
    transform=TF.to_tensor, 
    download=True)
test_loader = DataLoader(
    cifar10_test, 
    batch_size=16, 
    shuffle=False, 
    num_workers=1)

Files already downloaded and verified
Files already downloaded and verified


Let's start with some images. The cell below takes a batch of training samples and displays it in Visdom as an image grid.

In [6]:
# `next(iter(...))` works on every PyTorch version. The Python-2-style
# `.next()` alias on DataLoader iterators was removed in newer releases.
images, labels = next(iter(train_loader))
sample = make_grid(images, nrow=4)
viz.image(sample, opts=dict(
    title='CIFAR-10 sample',
    width=400, 
    height=400
))

'window_375b0a404d32ca'

Next, we train a **variational autoencoder (VAE)** on CIFAR-10, focusing on visualizing how the loss values change and how the convolution-layer filters evolve as training proceeds.

In [7]:
class Encoder(nn.Module):
    """Convolutional VAE encoder for 3x32x32 images.

    Three stride-2 4x4 convolutions, each followed by BatchNorm and ReLU,
    reduce a 32x32 input to a 128x2x2 feature map. That map is flattened
    to 512 features and fed to two linear heads. The heads assume exactly
    this 512-feature size, so other input resolutions will not fit.

    Args:
        latent_dim: size of the latent code produced by each head.

    ``forward`` returns ``(mean, logvar)``, each of shape ``(batch, latent_dim)``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)    # 32x32 -> 15x15
        self.norm2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 4, stride=2)   # 15x15 -> 6x6
        self.norm4 = nn.BatchNorm2d(64)
        self.conv5 = nn.Conv2d(64, 128, 4, stride=2)  # 6x6 -> 2x2
        self.norm6 = nn.BatchNorm2d(128)
        flat_features = 128 * 2 * 2
        self.fc7 = nn.Linear(flat_features, latent_dim)  # mean head
        self.fc8 = nn.Linear(flat_features, latent_dim)  # log-variance head

    def forward(self, x):
        stages = (
            (self.conv1, self.norm2),
            (self.conv3, self.norm4),
            (self.conv5, self.norm6),
        )
        h = x
        for conv, norm in stages:
            h = F.relu(norm(conv(h)), inplace=True)
        h = h.view(h.size(0), -1)
        return self.fc7(h), self.fc8(h)

In [8]:
summary(Encoder(latent_dim=16).to(device), input_size=(3, 32, 32))  # one CIFAR-10 image

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 15, 15]           1,568
       BatchNorm2d-2           [-1, 32, 15, 15]              64
            Conv2d-3             [-1, 64, 6, 6]          32,832
       BatchNorm2d-4             [-1, 64, 6, 6]             128
            Conv2d-5            [-1, 128, 2, 2]         131,200
       BatchNorm2d-6            [-1, 128, 2, 2]             256
            Linear-7                   [-1, 16]           8,208
            Linear-8                   [-1, 16]           8,208
Total params: 182,464
Trainable params: 182,464
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.15
Params size (MB): 0.70
Estimated Total Size (MB): 0.86
----------------------------------------------------------------


In [9]:
class Decoder(nn.Module):
    """Convolutional VAE decoder: latent code -> 3x32x32 image in [0, 1].

    A linear layer lifts the code to 512 features, which are treated as a
    512-channel 1x1 feature map. Three stride-2 transposed convolutions then
    upsample it 1 -> 5 -> 14 -> 32 pixels, and a sigmoid squashes the output
    into [0, 1].

    Args:
        latent_dim: size of the latent code accepted by ``forward``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.fc1 = nn.Linear(latent_dim, 512)
        self.deconv2 = nn.ConvTranspose2d(512, 64, 5, stride=2)  # 1x1 -> 5x5
        self.norm3 = nn.BatchNorm2d(64)
        self.deconv4 = nn.ConvTranspose2d(64, 32, 6, stride=2)   # 5x5 -> 14x14
        self.norm5 = nn.BatchNorm2d(32)
        self.deconv6 = nn.ConvTranspose2d(32, 3, 6, stride=2)    # 14x14 -> 32x32

    def forward(self, x):
        h = F.relu(self.fc1(x), inplace=True)
        h = h.view(h.size(0), 512, 1, 1)  # 512 features as a 1x1 feature map
        for deconv, norm in ((self.deconv2, self.norm3), (self.deconv4, self.norm5)):
            h = F.relu(norm(deconv(h)), inplace=True)
        return torch.sigmoid(self.deconv6(h))

In [10]:
summary(Decoder(latent_dim=16).to(device), input_size=(16,))  # one 16-d latent code

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 512]           8,704
   ConvTranspose2d-2             [-1, 64, 5, 5]         819,264
       BatchNorm2d-3             [-1, 64, 5, 5]             128
   ConvTranspose2d-4           [-1, 32, 14, 14]          73,760
       BatchNorm2d-5           [-1, 32, 14, 14]              64
   ConvTranspose2d-6            [-1, 3, 32, 32]           3,459
Total params: 905,379
Trainable params: 905,379
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.15
Params size (MB): 3.45
Estimated Total Size (MB): 3.60
----------------------------------------------------------------


In [11]:
class VAE(nn.Module):
    """Variational autoencoder assembled from an ``Encoder`` and a ``Decoder``.

    ``forward`` always draws a reparameterized sample ``z``, in train and
    eval mode alike. ``encode`` samples only in training mode and returns the
    posterior mean in eval mode.

    Args:
        latent_dim: size of the latent code shared by encoder and decoder.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def encode(self, x):
        """Map images to latent codes: a sample when training, the mean otherwise."""
        mean, logvar = self.encoder(x)
        return self.reparameterize(mean, logvar) if self.training else mean

    def reparameterize(self, mean, logvar):
        """Return ``mean + eps * std`` where ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(mean)
        return mean + noise * std

    def decode(self, z):
        """Decode latent codes back into images."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(mean, logvar, z, recon_x)`` for a batch of images."""
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        return mean, logvar, z, self.decode(z)

In [12]:
# 16-dimensional latent space; Adam with its default hyper-parameters.
model = VAE(latent_dim=16).to(device)
optimizer = optim.Adam(params=model.parameters())

Now, let's add a summary of the model's architecture to Visdom.

In [13]:
# Visdom.text renders its argument as HTML. Escape the repr first so any
# `<`, `>` or `&` is shown literally. Then preserve its layout with
# non-breaking spaces (`&nbsp;`, semicolon included) and explicit line breaks.
model_str = html.escape(str(model))
model_str = model_str.replace('\n', '<br>').replace(' ', '&nbsp;')
viz.text(model_str)

'window_375b0a41cad768'

In [14]:
def mse(recon_x, x):
    """Reconstruction loss: per-sample sum of squared errors, averaged over the batch.

    Both tensors must have the same shape with the batch on dim 0. All
    remaining dims are flattened and summed.

    Returns:
        A 0-d tensor.
    """
    squared_err = (recon_x - x).pow(2)
    per_sample = squared_err.view(squared_err.size(0), -1).sum(dim=1)
    return per_sample.mean(dim=0)

In [15]:
def kld(mean, logvar, reduction='sum'):
    """KL(N(mean, exp(logvar)) || N(0, I)) per sample, averaged over the batch.

    For a diagonal Gaussian the divergence is a *sum* over latent dimensions:
        KL = -0.5 * sum_j (1 + logvar_j - mean_j**2 - exp(logvar_j))
    Summing matches ``mse``, which also sums over features per sample.
    Averaging over latent dims instead would return KL / latent_dim and
    silently down-weight the KL term of the loss.

    Args:
        mean, logvar: tensors of shape ``(batch, latent_dim)``.
        reduction: ``'sum'`` (default) gives the true per-sample KL.
            ``'mean'`` reproduces the per-dimension average (KL / latent_dim).

    Returns:
        A 0-d tensor.

    Raises:
        ValueError: if ``reduction`` is not ``'sum'`` or ``'mean'``.
    """
    per_dim = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp())
    if reduction == 'sum':
        per_sample = per_dim.sum(dim=1)
    elif reduction == 'mean':
        per_sample = per_dim.mean(dim=1)
    else:
        raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")
    return per_sample.mean(dim=0)

In [16]:
def validation(test_loader, model, device=None):
    """Average per-batch VAE loss (``mse`` + ``kld``) over ``test_loader``.

    Runs in eval mode under ``torch.no_grad()``, so no autograd graph is
    built. Afterwards it restores the model's previous train/eval mode, even
    if an exception is raised.

    Args:
        test_loader: iterable of ``(x, label)`` batches.
        model: module whose forward returns ``(mean, logvar, z, recon_x)``.
        device: where to move each batch. Defaults to the device of the
            model's first parameter, or CPU if it has none.

    Returns:
        Mean of the per-batch losses as a float, or ``nan`` if the loader
        yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for x, _ in test_loader:
                x = x.to(device)
                mean, logvar, _, recon_x = model(x)
                loss = mse(recon_x, x) + kld(mean, logvar)
                total_loss += loss.item()
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

In [17]:
def vis_encoder_conv1(model):
    """Tile the encoder's first-layer filters into a single image grid.

    conv1 has 3 input channels, so each filter is drawn as a colour tile.
    There are four filters per row, and ``normalize=True`` rescales the
    grid into [0, 1].
    """
    filters = model.encoder.conv1.weight
    filters = filters.detach().cpu()
    return make_grid(filters, normalize=True, nrow=4)

def vis_encoder_conv3(model):
    """Render every 2-D kernel of the encoder's conv3 layer as a grey tile.

    The weight has shape ``(out_channels, in_channels, kh, kw)``. Each
    (output filter, input channel) slice becomes a single-channel tile.
    Tiles are laid out so that each grid row holds the ``in_channels``
    kernels of one output filter, giving ``out_channels`` rows in total.
    ``normalize=True`` rescales the whole grid into [0, 1].

    Returns:
        The image grid produced by ``make_grid``.
    """
    weight = model.encoder.conv3.weight.detach().cpu()
    # Read the kernel size from the weight instead of hard-coding 4x4.
    out_channels, in_channels, kernel_h, kernel_w = weight.shape
    tiles = weight.reshape(out_channels * in_channels, 1, kernel_h, kernel_w)
    # Tiles are filter-major and make_grid fills rows left to right, so
    # nrow=in_channels puts exactly one output filter on each row.
    return make_grid(tiles, nrow=in_channels, normalize=True)

In [None]:
log = LogBook('i', 'mse', 'kld', 'loss')
valid_steps = []   # training step at which each validation loss was measured
valid_losses = []

NUM_EPOCHS = 20
LOG_EVERY = 50      # record training losses in the LogBook
PLOT_EVERY = 100    # refresh the loss curves and filter grids in Visdom
VALID_EVERY = 500   # evaluate on the full test set


def line_opts(title, ylabel):
    """Visdom line-plot options shared by every curve; x is the training step."""
    return dict(
        title=title,
        xtick=True,
        ytick=True,
        xlabel='training step',
        ylabel=ylabel,
    )


# NOTE: these windows can potentially be 
# wrapped in a `Visualizer` wrapper class
mse_win = None
kld_win = None
loss_win = None
conv1_win = None
conv3_win = None
valid_win = None

model.train()
global_iter = 0
for ep in tqdm(range(NUM_EPOCHS)):
    for x, _ in tqdm(train_loader, leave=False):
        x = x.to(device)
        mean, logvar, z, recon_x = model(x)
        recon_loss = mse(recon_x, x)
        latent_loss = kld(mean, logvar)
        loss = recon_loss + latent_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log training progress
        if global_iter % LOG_EVERY == 0:
            log.record(
                global_iter, 
                recon_loss.item(), 
                latent_loss.item(),
                loss.item())

        # update plots
        if global_iter > 0 and global_iter % PLOT_EVERY == 0:
            mse_win = viz.line(X=log['i'], Y=log['mse'], win=mse_win,
                               opts=line_opts('Mean Squared Error', 'MSE'))
            kld_win = viz.line(X=log['i'], Y=log['kld'], win=kld_win,
                               opts=line_opts('KL divergence', 'KLD'))
            loss_win = viz.line(X=log['i'], Y=log['loss'], win=loss_win,
                                opts=line_opts('Training loss', 'loss'))
            conv1_win = viz.image(vis_encoder_conv1(model), win=conv1_win, opts=dict(
                title='Encoder conv1',
                width=400,
                height=800,
            ))
            conv3_win = viz.image(vis_encoder_conv3(model), win=conv3_win, opts=dict(
                title='Encoder conv3',
                width=1000,
                height=1000,
            ))

        # Periodic validation. The curve is plotted against the training step
        # it ran at (matching the x-axis label), not against its list index.
        if global_iter > 0 and global_iter % VALID_EVERY == 0:
            valid_steps.append(global_iter)
            valid_losses.append(validation(test_loader, model))
            valid_win = viz.line(X=np.asarray(valid_steps), Y=np.asarray(valid_losses),
                                 win=valid_win, opts=line_opts('Validation loss', 'loss'))

        global_iter += 1

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))