In [1]:
# Relative prefix that climbs three directory levels from the working directory.
BASE_DIR = '../'*3
# Path of the MLP model notebook that is executed below.
MLPMODEL = BASE_DIR +'code/MLP/MLP_MODEL/mlp_model.ipynb'
# Execute that notebook in this session. Names used later in this file but not
# defined here (Mlp_Torch, np, nn, torch, draw_images_horz) presumably come
# from it -- TODO confirm.
%run {MLPMODEL}

In [2]:
class CnnBasicModel(Mlp_Torch):
    '''Mlp_Torch subclass whose layers are described by typed layer configs.'''

    def __init__(self, name, dataset, hconfigs, show_maps = False):
        '''
        Normalise the layer configuration and initialise the parent model.

        Parameters
        ----------
        name, dataset :
            Forwarded unchanged to Mlp_Torch.__init__.
        hconfigs :
            Either a list of layer configs, or one single layer config
            (a list whose first element is neither a list nor an int), which
            is wrapped into a one-element list. An empty list is passed
            through unchanged. Examples:

            single full layer:
                ['full', {'width':30}]

            stack of conv / pooling layers:
                [['conv', {'ksize':3, 'chn':6, 'actfunc':'sigmoid'}],
                 ['max', {'stride':2}],
                 ['conv', {'ksize':3, 'chn':12, 'actfunc':'sigmoid'}],
                 ['max', {'stride':2}],
                 ['conv', {'ksize':3, 'chn':24, 'actfunc':'sigmoid'}],
                 ['avg', {'stride':3}]]
        show_maps : bool
            Stored on the instance; read by the visualization code.

        Notes
        -----
        Carried over from the original docstring: Mlp_Torch is expected to
        have the following attached elsewhere (not verifiable from here):
        make_layers, alloc_make_layer, train, test, eval_accuracy,
        get_estimate, visualize, prtinfo, get_optim, forward_postproc,
        forward_extra_cost.
        '''
        # The emptiness check keeps hconfigs[0] from raising IndexError on [].
        if isinstance(hconfigs, list) and hconfigs and \
        not isinstance(hconfigs[0], (list, int)):
            hconfigs = [hconfigs]
        self.show_maps = show_maps
        self.need_maps = False
        self.kernels = []  # TODO: confirm whether this is still needed

        # Attributes above are set first, before the parent initialiser runs.
        super(CnnBasicModel, self).__init__(name, dataset, hconfigs)
        self.use_adam = True

In [3]:
def cnn_basic_alloc_layer_param(self, input_shape, hconfig):
    """Allocate one layer by dispatching on the layer type of ``hconfig``.

    The method ``alloc_<type>_layer`` is looked up on ``self`` (``<type>`` is
    whatever ``get_layer_type(hconfig)`` returns) and called with
    ``(input_shape, hconfig)``.

    Returns:
        The ``(pm, output_shape)`` pair produced by that allocator.

    Raises:
        AttributeError: from ``getattr`` when no matching method exists.
    """
    allocator = getattr(self, 'alloc_{}_layer'.format(get_layer_type(hconfig)))
    layer_modules, out_shape = allocator(input_shape, hconfig)
    return layer_modules, out_shape


# Expose the dispatcher as a CnnBasicModel method.
CnnBasicModel.alloc_layer_param = cnn_basic_alloc_layer_param

In [4]:
def cnn_basic_forward_layer(self, x, hconfig, pm):
    """Run the forward pass of one layer by dispatching on its layer type.

    The method ``forward_<type>_layer`` is looked up on ``self`` (``<type>``
    is whatever ``get_layer_type(hconfig)`` returns) and called with
    ``(x, hconfig, pm)``.

    Returns:
        The ``(y, aux)`` pair produced by that method.

    Raises:
        AttributeError: from ``getattr`` when no matching method exists.
    """
    forward_fn = getattr(self, 'forward_{}_layer'.format(get_layer_type(hconfig)))
    output, extras = forward_fn(x, hconfig, pm)
    return output, extras


# Expose the dispatcher as a CnnBasicModel method.
CnnBasicModel.forward_layer = cnn_basic_forward_layer

In [5]:
def cnn_basic_backprop_layer(self, G_y, hconfig, pm, aux):
    """Run the backward pass of one layer by dispatching on its layer type.

    The method ``backprop_<type>_layer`` is looked up on ``self`` (``<type>``
    is whatever ``get_layer_type(hconfig)`` returns) and called with
    ``(G_y, hconfig, pm, aux)``.

    Returns:
        Whatever that method returns (the gradient w.r.t. the layer input,
        going by the original variable name ``G_input``).

    Raises:
        AttributeError: from ``getattr`` when no matching method exists.
    """
    backprop_fn = getattr(self, 'backprop_{}_layer'.format(get_layer_type(hconfig)))
    return backprop_fn(G_y, hconfig, pm, aux)


# Expose the dispatcher as a CnnBasicModel method.
CnnBasicModel.backprop_layer = cnn_basic_backprop_layer

In [101]:
def cnn_basic_alloc_full_layer(self, input_shape, hconfig):
    """Allocate a fully connected layer and its optional activation.

    Args:
        input_shape: Shape of one input sample; all of its dimensions are
            multiplied together to obtain the Linear layer's ``in_features``.
        hconfig: Layer config. The width is read from its ``'width'`` option;
            when ``hconfig`` is not a list (or has no such option) ``hconfig``
            itself is used as the width.

    Returns:
        ``(pm, [output_cnt])`` where ``pm`` is a list holding an ``nn.Linear``
        followed, if one was configured, by an activation module.
    """
    # np.prod yields a NumPy scalar; nn.Linear documents plain Python ints.
    input_cnt = int(np.prod(input_shape))
    output_cnt = get_conf_param(hconfig, 'width', hconfig)

    pm = [nn.Linear(in_features=input_cnt, out_features=output_cnt)]

    # self.activate may return None or the string sentinel 'none'; only real
    # modules may enter the list, because it is later fed to nn.Sequential.
    act = self.activate(hconfig)
    if isinstance(act, nn.Module):
        pm.append(act)

    return pm, [output_cnt]
    
def cnn_basic_alloc_conv_layer(self, input_shape, hconfig):
    """Allocate a 2-D convolution layer and its optional activation.

    Args:
        input_shape: ``[channels, height, width]`` of one input sample.
        hconfig: Layer config providing ``'ksize'`` (scalar or pair) and
            ``'chn'`` (number of output channels), plus an optional
            ``'actfunc'``.

    Returns:
        ``(pm, [ychn, xh, xw])`` where ``pm`` holds an ``nn.Conv2d`` followed,
        if one was configured, by an activation module. Height and width are
        reported unchanged because the convolution uses ``padding='same'``.

    Raises:
        AssertionError: if ``input_shape`` does not have exactly 3 entries.
    """
    assert len(input_shape) == 3
    xchn, xh, xw = input_shape
    kh, kw = get_conf_param_2d(hconfig, 'ksize')
    ychn = get_conf_param(hconfig, 'chn')
    actfunc = self.activate(hconfig)

    pm = [nn.Conv2d(kernel_size=(kh, kw), in_channels=xchn,
                    out_channels=ychn, padding='same')]

    # self.activate may return None or the string sentinel 'none'; only real
    # modules may enter the list, because it is later fed to nn.Sequential.
    if isinstance(actfunc, nn.Module):
        pm.append(actfunc)

    # TODO(review): nothing is recorded in self.kernels here, so the kernel
    # drawing in the visualization code has nothing to show -- confirm whether
    # kernel capture should be restored.

    return pm, [ychn, xh, xw]
    
def cnn_basic_alloc_pool_layer(self, input_shape, hconfig):
    """Allocate a max- or average-pooling layer.

    Args:
        input_shape: ``[channels, height, width]`` of one input sample.
        hconfig: Layer config whose type is ``'max'`` or ``'avg'`` and whose
            ``'stride'`` option (scalar or pair) is used both as the pooling
            window size and as the stride.

    Returns:
        ``(pm, [xchn, xh // sh, xw // sw])`` where ``pm`` is a one-element
        list holding the pooling module.

    Raises:
        AssertionError: if ``input_shape`` does not have exactly 3 entries, or
            height / width is not divisible by the corresponding stride.
        ValueError: if the layer type is neither ``'max'`` nor ``'avg'``.
    """
    assert len(input_shape) == 3
    xchn, xh, xw = input_shape
    sh, sw = get_conf_param_2d(hconfig, 'stride')

    name = get_layer_type(hconfig)
    if name == 'max':
        pool_cls = nn.MaxPool2d
    elif name == 'avg':
        pool_cls = nn.AvgPool2d
    else:
        # Previously this fell through and failed later on an unbound local.
        raise ValueError('unsupported pooling layer type: {!r}'.format(name))

    # The reported output shape is only exact for evenly divisible inputs.
    assert xh % sh == 0
    assert xw % sw == 0

    pm = [pool_cls(kernel_size=(sh, sw), stride=(sh, sw))]
    return pm, [xchn, xh // sh, xw // sw]


# Attach the allocators under the names that the 'alloc_{}_layer' dispatch in
# cnn_basic_alloc_layer_param looks up. Both pooling types share one function,
# which tells them apart via get_layer_type(hconfig).
CnnBasicModel.alloc_full_layer = cnn_basic_alloc_full_layer
CnnBasicModel.alloc_conv_layer = cnn_basic_alloc_conv_layer
CnnBasicModel.alloc_max_layer = cnn_basic_alloc_pool_layer
CnnBasicModel.alloc_avg_layer = cnn_basic_alloc_pool_layer

In [8]:
def get_layer_type(hconfig):
    """Return the layer-type name of a layer config.

    A list config carries its type in element 0; any non-list config is
    treated as a ``'full'`` layer.
    """
    return hconfig[0] if isinstance(hconfig, list) else 'full'

def get_conf_param(hconfig, key, defval = None):
    """Look up one option of a layer config.

    The options live in ``hconfig[1]``. ``defval`` is returned when
    ``hconfig`` is not a list, has no options element, or the options do not
    contain ``key``.
    """
    has_options = isinstance(hconfig, list) and len(hconfig) > 1
    if has_options and key in hconfig[1]:
        return hconfig[1][key]
    return defval
    
def get_conf_param_2d(hconfig, key, defval = None):
    """Look up one option of a layer config and return it as a 2-element list.

    Args:
        hconfig: Layer config of the form ``[type, {options}]``.
        key: Option name to read from ``hconfig[1]``.
        defval: Returned as-is (not expanded) when the option is unavailable.

    Returns:
        ``defval`` when ``hconfig`` is not a list/tuple, has no options
        element, or lacks ``key``. Otherwise the option value: a list is
        returned unchanged, a tuple is converted to a list, and any other
        value ``v`` becomes ``[v, v]``.
    """
    # Guard against non-sequence configs (e.g. a bare int width), matching
    # get_conf_param, instead of failing inside len().
    if not isinstance(hconfig, (list, tuple)): return defval
    if len(hconfig) <= 1: return defval
    if key not in hconfig[1]: return defval
    val = hconfig[1][key]
    if isinstance(val, list): return val
    # A tuple such as (3, 5) is already a pair; do not duplicate it.
    if isinstance(val, tuple): return list(val)
    return [val, val]


In [9]:
def cnn_basic_activate(self, hconfig):
    """Build the activation module configured by a layer config.

    Args:
        hconfig: Layer config; its ``'actfunc'`` option selects the function.

    Returns:
        * ``None`` when ``hconfig`` is None or ``'actfunc'`` is absent or not
          a string.
        * The string ``'none'`` when ``'actfunc'`` is ``'none'``. NOTE: this is
          a sentinel, not a module -- callers must not place it in a module
          list.
        * A fresh ``nn.ReLU`` / ``nn.Sigmoid`` / ``nn.Tanh`` instance for
          ``'relu'`` / ``'sigmoid'`` / ``'tanh'``.

    Raises:
        AssertionError: for any other activation name.
    """
    if hconfig is None: return None
    func = get_conf_param(hconfig, 'actfunc')

    if func == 'none':
        return 'none'
    if not isinstance(func, str):
        return None

    factories = {'relu': nn.ReLU, 'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh}
    if func not in factories:
        # Raised explicitly (not via `assert`) so that the check survives
        # `python -O` and names the offending value.
        raise AssertionError('unknown actfunc: {!r}'.format(func))
    return factories[func]()


# Expose the helper as a CnnBasicModel method.
CnnBasicModel.activate = cnn_basic_activate

In [10]:
class Net(nn.Module):
    """Sequential container that flattens its input ahead of Linear layers."""

    def __init__(self, layers):
        """Wrap ``layers`` (an iterable of nn.Module) in one nn.Sequential.

        The container is stored as ``self.layer1``.
        """
        super(Net, self).__init__()

        self.layer1 = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layers in order and return the final output.

        Whenever an ``nn.Linear`` is about to receive a tensor with more than
        two dimensions, the tensor is first flattened from dimension 1 on.
        This covers a Linear first layer as well as a Linear that follows
        convolution / pooling layers. An empty layer list returns ``x``
        unchanged.
        """
        out = x
        for layer in self.layer1:
            if isinstance(layer, nn.Linear) and out.dim() > 2:
                out = torch.flatten(out, 1)
            out = layer(out)

        return out

In [None]:
def cnn_basic_visualize(self, num):
    """Show estimates for ``num`` samples, optionally with kernels and maps.

    Fetches visualization data from ``self.dataset``, obtains estimates via
    ``self.get_estimate``, optionally draws the entries of ``self.kernels``
    and ``self.maps`` when ``self.show_maps`` is set, and finally delegates to
    ``self.dataset.visualize``.

    ``self.need_maps`` is switched on for the duration of the call and
    ``self.maps`` starts as an empty list; both are reset (``False`` / ``None``)
    on exit, even if one of the steps raises.
    """
    print('Model {} Visualization'.format(self.name))

    self.need_maps = self.show_maps
    self.maps = []

    try:
        deX, deY = self.dataset.get_visualize_data(num)
        est = self.get_estimate(deX)

        if self.show_maps:
            # NOTE(review): the unpacking below expects 4-D kernels laid out as
            # (kh, kw, xchn, ychn), and nothing visible in this file appends to
            # self.kernels or self.maps -- confirm this path is still live.
            for kernel in self.kernels:
                kh, kw, _, _ = kernel.shape
                grids = kernel.reshape([kh, kw, -1]).transpose(2, 0, 1)
                draw_images_horz(grids[0:5, :, :])

            for pmap in self.maps:
                draw_images_horz(pmap[:, :, :, 0])

        self.dataset.visualize(deX, est, deY)
    finally:
        # Always leave map collection switched off for later forward passes.
        self.need_maps = False
        self.maps = None


# Replace the inherited visualize with the CNN-aware version.
CnnBasicModel.visualize = cnn_basic_visualize