In [1]:
from collections import namedtuple

import numpy as np

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
   
from PIL import Image, ImageFilter

import os
import sys
sys.path.append('../')

from model import LapSRN

from torchvision.models import vgg16

In [2]:
loss_model = vgg16(pretrained=True)

In [3]:
loss_model.features

Sequential(
  (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace)
  (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace)
  (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
  (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace)
  (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace)
  (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
  (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace)
  (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace)
  (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace)
  (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
  (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (18): ReLU(inplace)
  

In [10]:
LossOutput = namedtuple("LossOutput", ["relu1_2", "relu2_2", "relu3_3", "relu4_3"])
# https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119/3
class LossNetwork(torch.nn.Module):
    def __init__(self, vgg):
        super(LossNetwork, self).__init__()
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3",
            '22': "relu4_3"
        }
        
        # Drop all layers that are not needed from the model.
        last_layer = max([int(i) for i in self.layer_name_mapping])
        self.vgg_layers = nn.Sequential(*list(vgg.features)[:last_layer+1])
        
    def forward(self, x):
        output = {}
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            print(name)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
        return LossOutput(**output)


In [11]:
loss_features = loss_model.features

In [12]:
loss_net = LossNetwork(loss_model)

In [13]:
loss_net

LossNetwork(
  (vgg_layers): Sequential(
    (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (17): Conv2d (256, 512, kernel_siz

In [14]:
loss_net.forward(torch.autograd.Variable(torch.rand((1,3,256,256))))

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22


LossOutput(relu1_2=Variable containing:
( 0 , 0 ,.,.) = 
   0.0000   0.0000   0.0000  ...    0.0000   0.0000   0.0000
   0.0000   0.0000   0.0000  ...    0.2070   0.0000   1.0183
   0.0000   0.1787   0.0486  ...    0.4734   0.0000   0.7074
            ...               ⋱              ...            
   0.0000   0.2356   0.0000  ...    0.0000   0.0000   0.0000
   0.9285   1.9964   0.0000  ...    0.0000   0.0000   0.2014
   0.6978   2.6261   1.1103  ...    0.0000   0.0000   0.0000

( 0 , 1 ,.,.) = 
   0.4902   0.0000   0.0000  ...    0.0000   0.5289   0.0000
   1.1688   0.0669   0.0000  ...    0.0000   1.8362   0.0000
   0.3109   0.0000   0.0000  ...    0.0000   2.3589   0.0000
            ...               ⋱              ...            
   1.6491   0.0562   0.0000  ...    0.0000   0.2205   0.0000
   1.2490   1.0829   0.7546  ...    0.0000   0.0000   0.0000
   0.0000   0.7419   1.6634  ...    0.2008   0.3117   0.0000

( 0 , 2 ,.,.) = 
   0.0000   0.0000   0.0000  ...    0.0000   0.0000  

In [15]:
# https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion
# Gray-Channel YCbCr to RGB is RGB = (Y, Y, Y)