In [None]:
import skimage
import skimage.io
import numpy as np
import torch
import torch.nn as nn
from dotmap import DotMap
from torch.utils.serialization import load_lua
from operator import xor

In [None]:
class Layer:
    """Bit-flag codes for the Torch7 layer types this converter recognises.

    Every code is a distinct power of two, so codes can be OR-ed into masks
    (e.g. ``Layer.Sequential | Layer.Concat``) and tested with ``&``.
    """

    Identity = 1 << 0
    Convolution = 1 << 1
    Batch_Norm = 1 << 2
    ReLU = 1 << 3
    Sequential = 1 << 4
    Max_Pool = 1 << 5
    Add = 1 << 6
    Nearest_Upsample = 1 << 7
    Concat = 1 << 8

    @staticmethod
    def to_string(layer):
        """Return the name of the first flag (in declaration order) set in ``layer``.

        Returns ``None`` when no known flag is set.
        """
        labelled_flags = (
            (Layer.Identity, 'Identity'),
            (Layer.Convolution, 'Convolution'),
            (Layer.Batch_Norm, 'Batch_Norm'),
            (Layer.ReLU, 'ReLU'),
            (Layer.Sequential, 'Sequential'),
            (Layer.Max_Pool, 'Max_Pool'),
            (Layer.Add, 'Add'),
            (Layer.Nearest_Upsample, 'Nearest_Upsample'),
            (Layer.Concat, 'Concat'),
        )
        for flag, label in labelled_flags:
            if layer & flag:
                return label
        return None

    @staticmethod
    def from_name(name):
        """Map a Torch7 module description (``str(module)``) to its flag by prefix.

        Returns ``None`` for module types that are not recognised.
        """
        prefix_to_flag = (
            ('nn.Identity', Layer.Identity),
            ('nn.SpatialConvolution', Layer.Convolution),
            ('nn.SpatialBatchNormalization', Layer.Batch_Norm),
            ('nn.ReLU', Layer.ReLU),
            ('nn.Sequential', Layer.Sequential),
            ('nn.SpatialMaxPooling', Layer.Max_Pool),
            ('nn.CAddTable', Layer.Add),
            ('nn.SpatialUpSamplingNearest', Layer.Nearest_Upsample),
            # ConcatTable is matched by its full type name, not an 'nn.' prefix.
            ('torch.legacy.nn.ConcatTable.ConcatTable', Layer.Concat),
        )
        for prefix, flag in prefix_to_flag:
            if name.startswith(prefix):
                return flag
        return None

In [None]:
class Node:
    """One forward node of the exported Torch7 graph, paired with its module.

    ``forwardnode`` is a text line of the form ``"<id>;<child> <child> ..."``.
    ``id`` and the entries of ``children`` are kept as digit strings, ``data``
    holds the Torch7 module and ``op`` its ``Layer`` flag (derived from
    ``str(module)``).
    """

    def __init__(self, forwardnode, module):
        # Only the first line is meaningful. Splitting the child list on any
        # whitespace (and stripping the id) tolerates '\r\n' line endings and
        # repeated/leading spaces, which previously tripped the isdigit asserts.
        forwardnode = forwardnode.split('\n')[0]
        node_id, children = forwardnode.split(';')
        self.id = node_id.strip()
        self.children = children.split()
        self.data = module
        self.op = Layer.from_name(str(self.data))

        assert self.op is not None
        assert self.id.isdigit()
        assert all(child.isdigit() for child in self.children)

    def __str__(self):
        return '{node}; {operation}'.format(node=self.id, operation=Layer.to_string(self.op))

    @staticmethod
    def _get_param(module):
        """Extract an ``(op, param)`` pair from a Torch7 module, recursing into containers.

        ``param`` is ``(weight, bias)`` for convolutions,
        ``(running_mean, running_var, weight, bias, momentum)`` for batch-norm,
        a list of child ``(op, param)`` pairs for Sequential/Concat containers,
        and ``None`` for every other recognised layer.

        Raises:
            NotImplementedError: if ``Layer`` does not recognise the module type.
        """
        op = Layer.from_name(str(module))
        if op is None:
            # Without this check the flag tests below fail with an opaque TypeError.
            raise NotImplementedError(
                'Unsupported Torch7 module: {}'.format(str(module).split('\n')[0]))

        if op & Layer.Convolution:
            return op, (module.weight, module.bias)
        if op & Layer.Batch_Norm:
            return op, (module.running_mean, module.running_var,
                        module.weight, module.bias, module.momentum)
        if op & (Layer.Sequential | Layer.Concat):
            return op, [Node._get_param(sub_module) for sub_module in module.modules]
        return op, None

    def get_param(self):
        """Return the ``(op, param)`` tree of this node's module (see ``_get_param``)."""
        return Node._get_param(self.data)

    @staticmethod
    def _copy_to_convolution(source, target):
        """Copy a Torch7 convolution's weight and bias into ``target``; shapes must match."""
        op, param = source
        weight, bias = param

        assert target.weight.shape == weight.shape
        assert target.bias.shape == bias.shape

        target.weight.data = weight
        target.bias.data = bias

    @staticmethod
    def _copy_to_batch_norm(source, target):
        """Copy Torch7 batch-norm statistics, affine parameters and momentum into ``target``."""
        op, param = source
        running_mean, running_var, weight, bias, momentum = param

        assert target.running_mean.shape == running_mean.shape
        assert target.running_var.shape == running_var.shape
        assert target.weight.shape == weight.shape
        assert target.bias.shape == bias.shape
        assert isinstance(target.momentum, float) and isinstance(momentum, float)

        target.running_mean = running_mean
        target.running_var = running_var
        target.weight.data = weight
        target.bias.data = bias
        target.momentum = momentum

    @staticmethod
    def _copy_to_residual(source, target):
        """Copy a Torch7 residual block into a module exposing ``resSeq`` and ``conv_skip``.

        Expected ``source`` layout (the one ``_is_residual`` checks for)::

            (Sequential, [(Concat, [(Sequential, conv_block), skip]), ...])

        ``conv_block`` is paired layer-by-layer with ``target.resSeq``. ``skip`` is
        either an Identity (nothing to copy) or a Sequential holding a single
        Convolution, which is copied into ``target.conv_skip``.

        Raises:
            NotImplementedError: for an unsupported layer in ``conv_block`` or an
                unsupported skip branch.
        """
        _, outer = source
        concat_op, branches = outer[0]
        assert concat_op & Layer.Concat

        block_op, conv_block = branches[0]
        assert block_op & Layer.Sequential
        # zip() would otherwise silently skip trailing layers on a mismatch.
        assert len(conv_block) == len(target.resSeq), \
            'residual length mismatch: {} vs {}'.format(len(conv_block), len(target.resSeq))

        for torch_module, pytorch_module in zip(conv_block, target.resSeq):
            op, _ = torch_module

            if op & Layer.Batch_Norm:
                Node._copy_to_batch_norm(torch_module, pytorch_module)

            elif op & Layer.ReLU:
                continue

            elif op & Layer.Convolution:
                Node._copy_to_convolution(torch_module, pytorch_module)

            else:
                raise NotImplementedError()

        skip_op, skip_param = branches[1]

        if skip_op & Layer.Identity:
            # Nothing to copy, but the target must then also use an identity skip.
            assert getattr(target, 'in_channels', None) == getattr(target, 'out_channels', None)

        elif skip_op & Layer.Sequential:
            assert len(skip_param) == 1 and skip_param[0][0] & Layer.Convolution
            Node._copy_to_convolution(skip_param[0], target.conv_skip)

        else:
            raise NotImplementedError()

    @staticmethod
    def _copy_to(source, target):
        """Copy the weights described by ``source`` (an ``(op, param)`` pair) into ``target``.

        Parameter-free layers (Identity, ReLU, Add, Max_Pool, Nearest_Upsample)
        are no-ops. Residual blocks are handled by ``_copy_to_residual``. Any
        other Sequential is copied child-by-child into an indexable ``target`` of
        the same length.

        Raises:
            NotImplementedError: for layer types without a copy rule (e.g. a bare Concat).
        """
        op, param = source
        if op & Layer.Convolution:
            Node._copy_to_convolution(source, target)

        elif op & Layer.Batch_Norm:
            Node._copy_to_batch_norm(source, target)

        elif op & Layer.Sequential:
            if Node._is_residual(source):
                Node._copy_to_residual(source, target)
            else:
                assert len(param) == len(target), \
                    'Sequential length mismatch: {} vs {}'.format(len(param), len(target))
                for torch_module, pytorch_module in zip(param, target):
                    Node._copy_to(torch_module, pytorch_module)

        elif op & (Layer.Identity | Layer.ReLU | Layer.Add | Layer.Max_Pool | Layer.Nearest_Upsample):
            pass

        else:
            raise NotImplementedError()

    def copy_to(self, target):
        """Copy this node's Torch7 weights into the PyTorch module ``target``."""
        Node._copy_to(self.get_param(), target)

    @staticmethod
    def _is_residual(source):
        """True if ``source`` has the residual layout consumed by ``_copy_to_residual``.

        That layout is a Sequential whose first child is a Concat with exactly
        two branches (conv block and skip).
        """
        op, sub_modules = source
        if not op & Layer.Sequential or not sub_modules:
            return False
        first_op, first_param = sub_modules[0]
        return bool(first_op & Layer.Concat) and len(first_param) == 2

In [None]:
class Graph:
    """The exported Torch7 network as a flat, ordered list of ``Node`` objects.

    ``self.node[i]`` pairs the i-th forward-node line (after the dummy first
    line) with the i-th item of the object loaded by ``load_lua``. The
    ``copy_to_*`` helpers address Torch7 layers by these positions.
    """

    def __init__(self, model_path='cpu.t7', forwardnodes_path='forwardnodes.txt'):
        """Load and parse the Torch7 graph.

        Args:
            model_path: Torch7 serialized file read with ``load_lua``.
            forwardnodes_path: text file with one ``"<id>;<children>"`` line per
                forward node. Its first line is skipped.
        """
        forwardnodes, modules = self.load_data(model_path, forwardnodes_path)

        # Nodes and modules are paired positionally; zip() stops at the shorter one.
        self.node = [Node(forwardnode, module)
                     for forwardnode, module in zip(forwardnodes, modules)]

    def load_data(self, model_path='cpu.t7', forwardnodes_path='forwardnodes.txt'):
        """Return ``(forwardnode_lines, modules)`` read from the two files."""
        modules = load_lua(model_path)
        with open(forwardnodes_path, 'r') as fd:
            lines = fd.readlines()
        forwardnodes = lines[1:]  # The 1st forwardnode is dummy, a input distributor
        return forwardnodes, modules

    def find_by_id(self, key):
        """Return the node whose id equals ``key`` (str or int).

        Raises:
            LookupError: if no node has that id.
        """
        key = str(key)  # Node ids are stored as digit strings.
        for node in self.node:
            if node.id == key:
                return node
        raise LookupError('No node with id {!r}'.format(key))

    def copy_to_hg(self, first_res_in_torch7, hg):
        """Copy the 13 residual modules of one Torch7 hourglass into ``hg``.

        Offsets in ``res_in_torch7`` are relative to ``first_res_in_torch7`` and
        are paired positionally with ``res_in_pytorch``. The PyTorch side
        assumes a depth-4 hourglass with one residual module per stage (index
        ``[0]``). Size labels are the original author's annotations.
        """
        res_in_torch7 = [
                        0,  # 64x64 skip
                        2,  # 32x32 res
                        3,  # 32x32 skip
                        5,  # 16x16 res
                        6,  # 16x16 skip
                        8,  # 8x8 res
                        9,  # 8x8 skip
                        11,  # 4x4 res
                        12,  # 4x4 lowest
                        13,  # 4x4 res
                        16,  # 8x8 res
                        19,  # 16x16 res
                        22,  # 32x32 res
                    ]
        res_in_pytorch = [
                        hg.res1[0],
                        hg.res2[0],
                        hg.subHourglass.res1[0],
                        hg.subHourglass.res2[0],
                        hg.subHourglass.subHourglass.res1[0],
                        hg.subHourglass.subHourglass.res2[0],
                        hg.subHourglass.subHourglass.subHourglass.res1[0],
                        hg.subHourglass.subHourglass.subHourglass.res2[0],
                        hg.subHourglass.subHourglass.subHourglass.resWaist[0],
                        hg.subHourglass.subHourglass.subHourglass.res3[0],
                        hg.subHourglass.subHourglass.res3[0],
                        hg.subHourglass.res3[0],
                        hg.res3[0],
                    ]
        for torch7_idx, pytorch_module in zip(res_in_torch7, res_in_pytorch):
            torch7_module = self.node[first_res_in_torch7 + torch7_idx]
            torch7_module.copy_to(pytorch_module)

    def copy_to_intermediate(self, first_conv_in_torch7, lin, htmap, llBar, htmapBar):
        """Copy one stack's post-hourglass layers, starting at node ``first_conv_in_torch7``.

        Nodes +0..+2 go to ``lin[0..2]`` and node +3 to ``htmap``. Nodes +4 and
        +5 go to ``llBar`` and ``htmapBar``, which must both be ``None`` (last
        stack, nothing to copy) or both be given.

        Raises:
            ValueError: if exactly one of ``llBar`` / ``htmapBar`` is ``None``.
        """
        if (llBar is None) != (htmapBar is None):
            raise ValueError('llBar and htmapBar must both be given or both be None')

        base = first_conv_in_torch7
        self.node[base + 0].copy_to(lin[0])  # Conv
        self.node[base + 1].copy_to(lin[1])  # Batch-norm
        self.node[base + 2].copy_to(lin[2])  # ReLU, ll in Newell's

        self.node[base + 3].copy_to(htmap)  # Conv, tmpOut in Newell's

        if llBar is None:
            return

        self.node[base + 4].copy_to(llBar)  # Conv, ll_ in Newell's
        self.node[base + 5].copy_to(htmapBar)  # Conv, tmpOut_ in Newell's

In [None]:
# Parse the exported Torch7 network (Graph reads cpu.t7 and forwardnodes.txt by default).
graph = Graph()

In [None]:
class CONFIG:
    """Model hyper-parameters shared by the PyTorch modules and the weight-copy cells.

    NOTE(review): the hard-coded Torch7 node offsets used when copying weights
    (e.g. the depth-4 indexing in Graph.copy_to_hg) appear to assume these values.
    """

    # Network size
    nStacks = 8      # number of stacked hourglasses (one heatmap output each)
    nDepth = 4       # nesting depth of each Hourglass
    nModules = 1     # residual modules per hourglass stage

    # Channel widths
    nFeatures = 256  # feature channels inside the hourglasses
    nJoints = 16     # heatmap output channels

In [None]:
class ResModule(nn.Module):
    """Pre-activation bottleneck residual block computing ``skip(x) + resSeq(x)``.

    ``resSeq`` is three BN -> ReLU -> Conv stages (1x1, 3x3 with padding 1, 1x1)
    with a bottleneck width of ``out_channels // 2``. The skip path is the
    identity when the channel counts match and the 1x1 ``conv_skip`` otherwise.

    Note: ``conv_skip`` is always constructed, so it appears in ``state_dict()``
    even when ``forward`` never uses it.
    """

    def __init__(self, in_channels, out_channels):
        super(ResModule, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        bottleneck = out_channels // 2
        # (input channels, output channels, kernel size, padding) of each stage's conv;
        # layers are created in the same BN, ReLU, Conv order as a literal listing.
        stage_specs = (
            (in_channels, bottleneck, 1, 0),
            (bottleneck, bottleneck, 3, 1),
            (bottleneck, out_channels, 1, 0),
        )
        layers = []
        for stage_in, stage_out, kernel, pad in stage_specs:
            layers.append(nn.BatchNorm2d(stage_in))
            layers.append(nn.ReLU())
            layers.append(nn.Conv2d(stage_in, stage_out, kernel_size=kernel, stride=1, padding=pad))
        self.resSeq = nn.Sequential(*layers)

    def forward(self, x):
        skip = x if self.in_channels == self.out_channels else self.conv_skip(x)
        return skip + self.resSeq(x)

In [None]:
class Hourglass(nn.Module):
    """Recursive hourglass block that keeps ``nFeatures`` channels throughout.

    ``forward(x) = res1(x) + upsample(res3(inner(res2(pool(x)))))``, where
    ``inner`` is a nested Hourglass of depth ``hg_depth - 1``. At the innermost
    level (``hg_depth == 1``) it is the ``resWaist`` residual stack instead.
    Each level halves and then doubles the spatial size, so H and W must be
    divisible by ``2 ** hg_depth`` for the final addition to line up.
    """

    def __init__(self, hg_depth, nFeatures, nModules=None):
        """
        Args:
            hg_depth: number of nested down/up-sampling levels (>= 1).
            nFeatures: channel count used by every residual module.
            nModules: residual modules per stage; defaults to ``CONFIG.nModules``.
        """
        super(Hourglass, self).__init__()
        if nModules is None:
            nModules = CONFIG.nModules
        self.hg_depth = hg_depth
        self.nFeatures = nFeatures
        self.nModules = nModules

        def residual_stack():
            # One stage: nModules width-preserving residual modules in sequence.
            return nn.Sequential(*[ResModule(nFeatures, nFeatures) for _ in range(nModules)])

        self.res1 = residual_stack()  # full-resolution branch
        self.res2 = residual_stack()  # right after down-sampling
        self.res3 = residual_stack()  # right before up-sampling
        self.subHourglass = None
        self.resWaist = None
        if self.hg_depth > 1:
            self.subHourglass = Hourglass(self.hg_depth - 1, nFeatures, nModules)
        else:
            self.resWaist = residual_stack()

        # Parameter-free resampling layers, built once here rather than on every forward().
        # They hold no parameters or buffers, so state_dict() keys are unchanged.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.UpsamplingNearest2d(scale_factor=2)

    def forward(self, x):
        up = self.res1(x)
        low1 = self.res2(self.pool(x))

        if self.hg_depth > 1:
            low2 = self.subHourglass(low1)
        else:
            low2 = self.resWaist(low1)

        low3 = self.res3(low2)
        return up + self.upsample(low3)

In [None]:
class MainModel(nn.Module):
    """Stacked-hourglass network returning a list with one heatmap tensor per stack.

    ``beforeHourglass`` is the stem. Stage ``i`` then runs ``hgArray[i]``,
    ``linArray[i]`` (1x1 conv -> BN -> ReLU) and the heatmap head
    ``htmapArray[i]``. Every stage except the last adds
    ``llBarArray[i](ll) + htmapBarArray[i](htmap)`` to the running input
    ``inter`` of the next stage.
    """

    def __init__(self, in_channels=3):
        super(MainModel, self).__init__()
        n_feats = CONFIG.nFeatures
        n_joints = CONFIG.nJoints

        # Stem. The layer order matters: the weight-copy loop pairs these layers
        # positionally with Torch7 nodes.
        self.beforeHourglass = nn.Sequential(*[
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            ResModule(in_channels=64, out_channels=128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ResModule(128, 128),
            ResModule(128, n_feats),
        ])

        self.hgArray = nn.ModuleList()
        self.linArray = nn.ModuleList()
        self.htmapArray = nn.ModuleList()
        self.llBarArray = nn.ModuleList()
        self.htmapBarArray = nn.ModuleList()

        # Per-stage modules, created interleaved (hourglass, lin, heatmap head) stage by stage.
        for _ in range(CONFIG.nStacks):
            self.hgArray.append(Hourglass(CONFIG.nDepth, n_feats))
            self.linArray.append(self.lin(n_feats, n_feats))
            self.htmapArray.append(nn.Conv2d(n_feats, n_joints, kernel_size=1, stride=1, padding=0))

        # Re-injection 1x1 convs; the last stage has none.
        for _ in range(CONFIG.nStacks - 1):
            self.llBarArray.append(nn.Conv2d(n_feats, n_feats, kernel_size=1, stride=1, padding=0))
            self.htmapBarArray.append(nn.Conv2d(n_joints, n_feats, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        inter = self.beforeHourglass(x)
        outHeatmap = []
        last_stage = CONFIG.nStacks - 1

        for i in range(CONFIG.nStacks):
            ll = self.linArray[i](self.hgArray[i](inter))
            htmap = self.htmapArray[i](ll)
            outHeatmap.append(htmap)

            if i == last_stage:
                continue
            ll_ = self.llBarArray[i](ll)
            htmap_ = self.htmapBarArray[i](htmap)
            inter = inter + ll_ + htmap_

        return outHeatmap

    def lin(self, in_channels, out_channels):
        """Build a 1x1 conv -> BatchNorm -> ReLU block."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

In [None]:
# Freshly initialised PyTorch model; the following cells overwrite its weights from the Torch7 graph.
sh = MainModel(in_channels=3)

In [None]:
# Stem: Torch7 nodes 1..7 (node 0 is skipped) are paired positionally with
# sh.beforeHourglass (conv, BN, ReLU, res, pool, res, res).
for torch_module, pytorch_module in zip(graph.node[1:8], sh.beforeHourglass):
    torch_module.copy_to(pytorch_module)

In [None]:
# Hourglass of stack k starts at Torch7 node 8 + 32 * k; exactly one offset per PyTorch hourglass.
for stack_idx, pytorch_hg in enumerate(sh.hgArray):
    torch_idx = 8 + 32 * stack_idx
    graph.copy_to_hg(first_res_in_torch7=torch_idx, hg=pytorch_hg)

In [None]:
# Copy each stack's post-hourglass layers (lin conv/BN/ReLU, heatmap conv and, except for the
# last stack, the two re-injection convs). Stack k's first such Torch7 node is 33 + 32 * k.
last_stack = CONFIG.nStacks - 1  # llBarArray / htmapBarArray only hold nStacks - 1 entries
for pytorch_idx in range(CONFIG.nStacks):
    torch_idx = 33 + 32 * pytorch_idx
    lin = sh.linArray[pytorch_idx]
    htmap = sh.htmapArray[pytorch_idx]
    llBar = sh.llBarArray[pytorch_idx] if pytorch_idx != last_stack else None
    htmapBar = sh.htmapBarArray[pytorch_idx] if pytorch_idx != last_stack else None
    graph.copy_to_intermediate(first_conv_in_torch7=torch_idx, lin=lin, htmap=htmap,
                               llBar=llBar, htmapBar=htmapBar)

In [None]:
# Load the test image as a float array of shape (H, W, 3); img_as_float rescales integer images to [0, 1].
# Grayscale images are replicated to three channels and an alpha channel is dropped, because the
# next cell transposes with (2, 0, 1) and MainModel() is built for 3 input channels.
# NOTE(review): no resizing happens here; the Torch7 model presumably expects 256x256 crops.
rgb = np.asarray(skimage.img_as_float(skimage.io.imread('asdf.jpg')))
if rgb.ndim == 2:
    rgb = np.stack([rgb] * 3, axis=-1)
elif rgb.ndim == 3 and rgb.shape[-1] == 4:
    rgb = rgb[..., :3]

In [None]:
# HWC -> CHW, then prepend a batch axis of size one: (1, C, H, W).
rgb = rgb.transpose(2, 0, 1)[np.newaxis]

In [None]:
# Convert to a tensor of the default tensor type (float32 unless changed); img_as_float output is typically float64.
rgb = torch.Tensor(rgb)

In [None]:
# Inference. eval() makes BatchNorm normalise with the running statistics copied from Torch7.
# In train mode it would use this single image's batch statistics and overwrite running_mean /
# running_var, which would also corrupt the weights saved later. no_grad() skips autograd bookkeeping.
sh.eval()
with torch.no_grad():
    htmaps = sh(rgb)

In [None]:
# sh(...) returns one heatmap tensor per stack; keep the last stack's prediction.
# NOTE: rebinds `htmaps`, so re-running this cell on its own indexes again.
htmaps = htmaps[-1]

In [None]:
# Drop the batch dimension (the input batch holds a single image).
# NOTE: rebinds `htmaps`, so re-running this cell on its own indexes again.
htmaps = htmaps[0]

In [None]:
# (nJoints, H', W'); H' and W' are presumably H/4 and W/4 given the stride-2 conv and 2x2 pool in the stem.
htmaps.shape

In [None]:
from itertools import product

In [None]:
# Overwrite channel 0 with the per-pixel maximum over all joint heatmaps.
# Vectorised: works for any spatial size (not only 64x64) and avoids a 4096-step Python loop
# that also clobbered the globals `x` and `y`.
htmaps[0] = htmaps.max(dim=0)[0]

In [None]:
# Channel 0 (the joint-max map written by the previous loop cell) as a NumPy array sharing memory with the tensor.
x = htmaps[0].detach().cpu().numpy()

In [None]:
import imageio

In [None]:
# Save the joint-max heatmap. `x` is float32; imageio converts it to 8-bit itself
# (NOTE(review): check how your imageio version scales values outside [0, 1]).
imageio.imwrite('pred.jpg', x)

In [None]:
# Persist only the converted weights, under the 'state' key.
torch.save(dict(state=sh.state_dict()), 'torch7.save')

In [None]:
def beautify(pair, indent=0):
    """Render an ``(op, param)`` tree, as returned by ``Node.get_param()``, as indented text.

    Convolutions show their weight shape and batch-norms their running-mean
    shape. Sequential and Concat containers are bracketed, with their children
    indented one tab deeper. Other layers show only their name.

    Args:
        pair: ``(op, param)`` where ``op`` is a ``Layer`` flag.
        indent: current nesting depth (number of leading tabs).

    Returns:
        The rendered text, one layer per line.
    """
    op, param = pair
    prefix = '\t' * indent
    # Layer flags are plain ints (no `.name`), so the label comes from Layer.to_string.
    name = Layer.to_string(op)

    if op == Layer.Convolution:
        weight, _bias = param
        return '{indent}{operation}; {shape}'.format(indent=prefix, operation=name, shape=weight.shape)

    if op == Layer.Batch_Norm:
        # get_param packs (running_mean, running_var, weight, bias, momentum).
        running_mean = param[0]
        return '{indent}{operation}; {shape}'.format(indent=prefix, operation=name, shape=running_mean.shape)

    if op == Layer.Sequential or op == Layer.Concat:
        lines = ['{indent}['.format(indent=prefix)]
        lines.extend(beautify(sub_module, indent=indent + 1) for sub_module in param)
        lines.append('{indent}]'.format(indent=prefix))
        return '\n'.join(lines)

    return '{indent}{operation}'.format(indent=prefix, operation=name)

In [None]:
# Debug: show node 30 and each of its children, skipping child id '2' (reason not recorded).
idx = 30
print(str(graph.node[idx]))
for child in graph.node[idx].children:
    if child != '2':
        print('\t', str(graph.find_by_id(child)))