In [4]:
import common as common
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pdb

class DDSR(nn.Module):
    def __init__(self, conv=common.default_conv, sr_n_resblocks=12, dm_n_resblock=9, sr_n_feats=64,
                 dm_n_feats=64, kernel_size=3, scale=2, denoise=True, bias=False, act=nn.ReLU()):
        super(DDSR, self).__init__()
        # act = nn.LeakyReLU(negative_slope=0.1, inplace=False)
        # act = nn.PReLU(n_feats)
        if denoise:
            m_sr_head = [common.BasicBlock(5, sr_n_feats, 7, act=act, bn=False, bias=False)]
        else:
            m_sr_head = [common.BasicBlock(4, sr_n_feats, 7, act=act, bn=False, bias=False)]

        # define sr module
        m_sr_resblock = [
            common.ResBlock(
                conv, sr_n_feats, kernel_size, bn=False, act=act, res_scale=1, bias=bias
            ) for _ in range(sr_n_resblocks)
        ]
        m_sr_resblock += [common.BasicBlock(sr_n_feats, sr_n_feats, kernel_size, bias=False)]
        m_sr_up = [common.Upsampler(conv, scale, sr_n_feats, act=act, bias=False)]
        m_sr_up +=  [common.BasicBlock(sr_n_feats, 4, kernel_size, bias=True)]
        # branch for sr_raw output
        m_sr_tail = [nn.PixelShuffle(2)]

        # define demosaick module
        m_dm_head = [common.BasicBlock(4, dm_n_feats, 5, act=act, bn=False, bias=False)]
        m_dm_resblock = [
            common.ResBlock(
                conv, dm_n_feats, kernel_size, bn=False, act=act, res_scale=1, bias=bias
            ) for _ in range(dm_n_resblock)
        ]
        m_dm_resblock += [common.BasicBlock(dm_n_feats, dm_n_feats, kernel_size, bias=False)]
        m_dm_up = [common.Upsampler(conv, 2, dm_n_feats, act= act, bias=True),
                   common.BasicBlock(dm_n_feats, 3, kernel_size, bias=True)]
        # m_tail = [conv(n_feats, 16, kernel_size), nn.PixelShuffle(4)]
        # self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.sr_head = nn.Sequential(*m_sr_head)
        self.sr_resblock = nn.Sequential(*m_sr_resblock)
        self.sr_up = nn.Sequential(*m_sr_up)
        self.sr_tail = nn.Sequential(*m_sr_tail)
        self.dm_head = nn.Sequential(*m_dm_head)
        self.dm_resblock = nn.Sequential(*m_dm_resblock)        
        self.dm_up = nn.Sequential(*m_dm_up)   

        # for m in self.modules():
        #     # pdb.set_trace()
        #     if isinstance(m, nn.Conv2d):
        #         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels    # MSRA
        #         m.weight.data.normal_(0., math.sqrt(2. / n))
        #         # nn.init.xavier_normal(m.weight) # Xavier
        #         m.weight.requires_grad = True
        #         if m.bias is not None:
        #             m.bias.data.zero_()
        #             m.bias.requires_grad = True

    def forward(self, x):
        # x = self.sub_mean(x)
        x = self.sr_head(x)
        x_sr_res = self.sr_resblock(x)
        x_sr_res += x
        x_sr_res = self.sr_up(x_sr_res)
        sr_raw = self.sr_tail(x_sr_res)
        x = self.dm_head(x_sr_res)
        dm_res = self.dm_resblock(x)
        dm_res += x
        rgb = self.dm_up(dm_res)

        # x = self.add_mean(x)
        return sr_raw, rgb
        
    def demosaick_layer(self, raw):
        PS = nn.PixelShuffle(2)
        demosaick = PS(raw)
        return demosaick


In [7]:
model = DDSR(scale=2)

In [8]:
model.children()

<generator object Module.children at 0x7f3fc0291c50>

In [11]:
list(model.children())

[Sequential(
   (0): BasicBlock(
     (0): Conv2d(5, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
     (1): ReLU()
   )
 ), Sequential(
   (0): ResBlock(
     (body): Sequential(
       (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (1): ReLU()
       (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     )
   )
   (1): ResBlock(
     (body): Sequential(
       (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (1): ReLU()
       (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     )
   )
   (2): ResBlock(
     (body): Sequential(
       (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (1): ReLU()
       (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     )
   )
   (3): ResBlock(
     (body): Sequential(
       (0): Conv2d(64, 64, kernel_siz

In [None]:
model.sr_