# CNN test

In [1]:
import os, sys, logging, configparser
# sys.path.append('/'.join(os.getcwd().split('/')[:-2]))

import warnings
warnings.filterwarnings("ignore")
#os.environ['CUDA_VISIBLE_DEVICES']='1'

import torch
from tensorboardX import SummaryWriter
from XAE.logging_daily import logging_daily

is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

logger = logging_daily('./config/log_info.yaml')
log = logger.get_logging()
log.setLevel(logging.INFO)

In [37]:
# from model.exp1.train_wae_gan import WAE_GAN_MNIST

cfg = configparser.ConfigParser()
cfg.read('./config/train_config_cnn.cfg')

['./config/train_config_cnn.cfg']

In [38]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from XAE.util import inc_avg, save_sample_images
from XAE import dataset, sampler

class CNN_MNIST(nn.Module):
    def __init__(self, cfg, log, device = 'cpu', verbose = 1):
        super(CNN_MNIST, self).__init__()
        self.log = log
        if verbose == 1:
            self.log.info('------------------------------------------------------------')
            for key in cfg['train_info']:
                self.log.info('%s : %s' % (key, cfg['train_info'][key]))

            for key in cfg['path_info']:
                self.log.info('%s : %s' % (key, cfg['path_info'][key]))

        self.cfg = cfg
    
        # Concrete Parts
        self.device = device
        self.z_dim = int(cfg['train_info']['z_dim'])
        self.z_sampler = getattr(sampler, cfg['train_info']['z_sampler']) # generate prior

        data_class = getattr(dataset, cfg['train_info']['train_data'])
        labeled = cfg['train_info'].getboolean('train_data_label')
        self.validate_batch = cfg['train_info'].getboolean('validate')
        try:
            self.train_data =  data_class(cfg['path_info']['data_home'], train = True, label = labeled)
            self.test_data = data_class(cfg['path_info']['data_home'], train = False, label = labeled)

            self.batch_size = int(cfg['train_info']['batch_size'])
            if cfg['train_info'].getboolean('replace'):
                it = int(cfg['train_info']['iter_per_epoch'])
                train_sampler = torch.utils.data.RandomSampler(self.train_data, replacement = True, num_samples = self.batch_size * it)
                self.train_generator = torch.utils.data.DataLoader(self.train_data, self.batch_size, num_workers = 5, sampler = train_sampler, pin_memory=True)
            else:
                self.train_generator = torch.utils.data.DataLoader(self.train_data, self.batch_size, num_workers = 5, shuffle = True, pin_memory=True, drop_last=True)
            self.test_generator = torch.utils.data.DataLoader(self.test_data, self.batch_size, num_workers = 5, shuffle = False, pin_memory=True, drop_last=True)
        except KeyError:
            pass
            
        self.save_best = cfg['train_info'].getboolean('save_best')
        self.save_path = cfg['path_info']['save_path']
        self.tensorboard_dir = cfg['path_info']['tb_logs']
        self.save_img_path = cfg['path_info']['save_img_path']
        self.save_state = cfg['path_info']['save_state']
        
        self.encoder_pretrain = cfg['train_info'].getboolean('encoder_pretrain')
        if self.encoder_pretrain:
            self.encoder_pretrain_batch_size = int(cfg['train_info']['encoder_pretrain_batch_size'])
            self.encoder_pretrain_step = int(cfg['train_info']['encoder_pretrain_max_step'])
            self.pretrain_generator = torch.utils.data.DataLoader(self.train_data, self.encoder_pretrain_batch_size, num_workers = 5, shuffle = True, pin_memory=True, drop_last=True)

        
        self.lr = float(cfg['train_info']['lr'])
        self.beta1 = float(cfg['train_info']['beta1'])
        self.lamb = float(cfg['train_info']['lambda'])
        self.lr_schedule = cfg['train_info']['lr_schedule']
        self.num_epoch = int(cfg['train_info']['epoch'])

        # Abstract Parts need overriding
        self.d = 64
        d = self.d
        self.embed_data = nn.Sequential(
            nn.Conv2d(1, d, kernel_size = 4, stride = 2, padding = 1, bias = False),
            nn.BatchNorm2d(d),
            nn.ReLU(True),

            nn.Conv2d(d, d, kernel_size = 4, padding = 'same', bias = False),
            nn.BatchNorm2d(d),
            nn.ReLU(True),

            nn.Conv2d(d, 2*d, kernel_size = 4, stride = 2, padding = 1, bias = False),
            nn.BatchNorm2d(2*d),
            nn.ReLU(True),

            nn.Conv2d(2*d, 2*d, kernel_size = 4, padding = 'same', bias = False),
            nn.BatchNorm2d(2*d),
            nn.ReLU(True),

            nn.Flatten(),
        ).to(device)
        
        self.enc = nn.Sequential(
            nn.Linear(49*2*d, d),
            nn.BatchNorm1d(d),
            nn.ReLU(True),

            nn.Linear(d, d),
            nn.BatchNorm1d(d),
            nn.ReLU(True),
            nn.Linear(d, 10),
            nn.Sigmoid()
            ).to(device)

        self.loss = nn.MSELoss()

        self.encoder_trainable = [self.enc, self.embed_data]
    
    def main_loss(self, x, y):
        return self.loss(x, y)
    
    def encode(self, x):
        return self.enc(self.embed_data(x))
    
    def forward(self, x):
        return self.encode(x)
    
    def lr_scheduler(self, optimizer, decay = 1.0):
        lamb = lambda e: decay
        if self.lr_schedule is "basic":
            lamb = lambda e: 1.0 / (1.0 + decay * e)
        if self.lr_schedule is "manual":
            lamb = lambda e: decay * 1.0 * (0.5 ** (e >= 30)) * (0.2 ** (e >= 50)) * (0.1 ** (e >= 100))
        return optim.lr_scheduler.MultiplicativeLR(optimizer, lamb)

    def train(self, resume = False):
        self.train_main_list = []
        self.test_main_list = []

        for net in self.encoder_trainable:
            net.train()
     
        optimizer = optim.Adam(sum([list(net.parameters()) for net in self.encoder_trainable], []), lr = self.lr, betas = (self.beta1, 0.999))

        start_epoch = 0
        scheduler = self.lr_scheduler(optimizer)

        if resume:
            checkpoint = torch.load(self.save_state)
            start_epoch = checkpoint['epoch']
            self.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if len(self.lr_schedule) > 0:
                scheduler.load_state_dict(checkpoint['scheduler'])

        self.log.info('------------------------------------------------------------')
        self.log.info('Training Start!')
        start_time = time.time()
        
        for epoch in range(start_epoch, self.num_epoch):
            # train_step
            train_loss_main = inc_avg()
            train_loss_penalty = inc_avg()
            
            for i, data in enumerate(self.train_generator):

                for net in self.encoder_trainable:
                    net.zero_grad()

                n = len(data[0])
                x = data[0].to(self.device)
                y = data[1].to(self.device)
                predicted = self.encode(x)
                
                loss = self.main_loss(y, predicted)
                loss.backward()
                optimizer.step()
                
                train_loss_main.append(loss.item(), n)
                
                print('[%i/%i]\ttrain_main: %.4f' % (i+1, len(self.train_generator), train_loss_main.avg), 
                      end = "\r")

            self.train_main_list.append(train_loss_main.avg)
        
            # validation_step
            test_loss_main = inc_avg()

            if self.validate_batch:
                for i, data in enumerate(self.test_generator):

                    n = len(data[0])
                    x = data[0].to(self.device)
                    y = data[1].to(self.device)
                    predicted = self.encode(x)

                    test_loss_main.append(self.main_loss(y, predicted).item(), n)
                    print('[%i/%i]\ttest_main: %.4f' % (i, len(self.test_generator), test_loss_main.avg), end = "\r")

                self.test_main_list.append(test_loss_main.avg)
                
                self.log.info('[%d/%d]\ttrain_main: %.6e\ttest_main: %.6e'
                      % (epoch + 1, self.num_epoch, train_loss_main.avg, test_loss_main.avg))

            scheduler.step()

        self.log.info('Training Finished!')
        self.log.info("Elapsed time: %.3fs" % (time.time() - start_time))


In [39]:
model = CNN_MNIST(cfg, log, device = device)
model.train()
# model.train(resume = True)

[default |INFO|<ipython-input-38-9e011345e519>:13] ------------------------------------------------------------
[default |INFO|<ipython-input-38-9e011345e519>:15] model_name : CNN
[default |INFO|<ipython-input-38-9e011345e519>:15] train_data : MNIST
[default |INFO|<ipython-input-38-9e011345e519>:15] z_sampler : gaus
[default |INFO|<ipython-input-38-9e011345e519>:15] z_dim : 8
[default |INFO|<ipython-input-38-9e011345e519>:15] y_sampler : multinomial
[default |INFO|<ipython-input-38-9e011345e519>:15] y_dim : 10
[default |INFO|<ipython-input-38-9e011345e519>:15] prob_enc : True
[default |INFO|<ipython-input-38-9e011345e519>:15] train_data_label : True
[default |INFO|<ipython-input-38-9e011345e519>:15] encoder_pretrain : False
[default |INFO|<ipython-input-38-9e011345e519>:15] encoder_pretrain_batch_size : 1000
[default |INFO|<ipython-input-38-9e011345e519>:15] encoder_pretrain_max_step : 200
[default |INFO|<ipython-input-38-9e011345e519>:15] lr : 1e-3
[default |INFO|<ipython-input-38-9e0

[68/100]	test_main: 0.003163

[default |INFO|<ipython-input-38-9e011345e519>:186] [1/40]	train_main: 1.630894e-02	test_main: 2.580448e-03


[68/100]	test_main: 0.002221

[default |INFO|<ipython-input-38-9e011345e519>:186] [2/40]	train_main: 2.129726e-03	test_main: 1.769781e-03


[68/100]	test_main: 0.001814

[default |INFO|<ipython-input-38-9e011345e519>:186] [3/40]	train_main: 1.433445e-03	test_main: 1.409638e-03


[67/100]	test_main: 0.001711

[default |INFO|<ipython-input-38-9e011345e519>:186] [4/40]	train_main: 1.086634e-03	test_main: 1.255981e-03


[65/100]	test_main: 0.001709

[default |INFO|<ipython-input-38-9e011345e519>:186] [5/40]	train_main: 8.753754e-04	test_main: 1.269023e-03


[67/100]	test_main: 0.001508

[default |INFO|<ipython-input-38-9e011345e519>:186] [6/40]	train_main: 7.677241e-04	test_main: 1.152738e-03


[67/100]	test_main: 0.001307

[default |INFO|<ipython-input-38-9e011345e519>:186] [7/40]	train_main: 6.826462e-04	test_main: 1.036847e-03


[68/100]	test_main: 0.001406

[default |INFO|<ipython-input-38-9e011345e519>:186] [8/40]	train_main: 5.774166e-04	test_main: 1.067348e-03


[67/100]	test_main: 0.001305

[default |INFO|<ipython-input-38-9e011345e519>:186] [9/40]	train_main: 5.382020e-04	test_main: 1.008217e-03


[67/100]	test_main: 0.001705

[default |INFO|<ipython-input-38-9e011345e519>:186] [10/40]	train_main: 4.901054e-04	test_main: 1.226150e-03


[68/100]	test_main: 0.001504

[default |INFO|<ipython-input-38-9e011345e519>:186] [11/40]	train_main: 3.539226e-04	test_main: 1.198155e-03


[68/100]	test_main: 0.001404

[default |INFO|<ipython-input-38-9e011345e519>:186] [12/40]	train_main: 4.286752e-04	test_main: 1.090728e-03


[67/100]	test_main: 0.001804

[default |INFO|<ipython-input-38-9e011345e519>:186] [13/40]	train_main: 4.053624e-04	test_main: 1.347016e-03


[67/100]	test_main: 0.001603

[default |INFO|<ipython-input-38-9e011345e519>:186] [14/40]	train_main: 3.067185e-04	test_main: 1.232146e-03


[67/100]	test_main: 0.001503

[default |INFO|<ipython-input-38-9e011345e519>:186] [15/40]	train_main: 2.831184e-04	test_main: 1.069254e-03


[68/100]	test_main: 0.001403

[default |INFO|<ipython-input-38-9e011345e519>:186] [16/40]	train_main: 3.156687e-04	test_main: 1.121161e-03


[67/100]	test_main: 0.001403

[default |INFO|<ipython-input-38-9e011345e519>:186] [17/40]	train_main: 2.619827e-04	test_main: 1.044001e-03


[67/100]	test_main: 0.001402

[default |INFO|<ipython-input-38-9e011345e519>:186] [18/40]	train_main: 2.414495e-04	test_main: 1.086406e-03


[67/100]	test_main: 0.001303

[default |INFO|<ipython-input-38-9e011345e519>:186] [19/40]	train_main: 2.722908e-04	test_main: 9.300032e-04


[67/100]	test_main: 0.001502

[default |INFO|<ipython-input-38-9e011345e519>:186] [20/40]	train_main: 1.939653e-04	test_main: 1.121564e-03


[67/100]	test_main: 0.001202

[default |INFO|<ipython-input-38-9e011345e519>:186] [21/40]	train_main: 1.994157e-04	test_main: 9.408824e-04


[67/100]	test_main: 0.001502

[default |INFO|<ipython-input-38-9e011345e519>:186] [22/40]	train_main: 1.764147e-04	test_main: 1.077299e-03


[67/100]	test_main: 0.001502

[default |INFO|<ipython-input-38-9e011345e519>:186] [23/40]	train_main: 2.284284e-04	test_main: 1.127868e-03


[67/100]	test_main: 0.001402

[default |INFO|<ipython-input-38-9e011345e519>:186] [24/40]	train_main: 2.303643e-04	test_main: 1.089709e-03


[65/100]	test_main: 0.001402

[default |INFO|<ipython-input-38-9e011345e519>:186] [25/40]	train_main: 2.315841e-04	test_main: 9.889146e-04


[67/100]	test_main: 0.001302

[default |INFO|<ipython-input-38-9e011345e519>:186] [26/40]	train_main: 1.592319e-04	test_main: 9.313517e-04


[67/100]	test_main: 0.001401

[default |INFO|<ipython-input-38-9e011345e519>:186] [27/40]	train_main: 1.465711e-04	test_main: 1.095978e-03


[67/100]	test_main: 0.001302

[default |INFO|<ipython-input-38-9e011345e519>:186] [28/40]	train_main: 1.638494e-04	test_main: 1.014288e-03


[67/100]	test_main: 0.001502

[default |INFO|<ipython-input-38-9e011345e519>:186] [29/40]	train_main: 1.540917e-04	test_main: 1.149614e-03


[66/100]	test_main: 0.001202

[default |INFO|<ipython-input-38-9e011345e519>:186] [30/40]	train_main: 1.928252e-04	test_main: 8.625191e-04


[66/100]	test_main: 0.001301

[default |INFO|<ipython-input-38-9e011345e519>:186] [31/40]	train_main: 9.338132e-05	test_main: 9.323479e-04


[68/100]	test_main: 0.001802

[default |INFO|<ipython-input-38-9e011345e519>:186] [32/40]	train_main: 1.820648e-04	test_main: 1.317116e-03


[67/100]	test_main: 0.001402

[default |INFO|<ipython-input-38-9e011345e519>:186] [33/40]	train_main: 1.607171e-04	test_main: 1.034681e-03


[68/100]	test_main: 0.001202

[default |INFO|<ipython-input-38-9e011345e519>:186] [34/40]	train_main: 1.675914e-04	test_main: 9.005978e-04


[68/100]	test_main: 0.001401

[default |INFO|<ipython-input-38-9e011345e519>:186] [35/40]	train_main: 1.074427e-04	test_main: 1.005468e-03


[67/100]	test_main: 0.001501

[default |INFO|<ipython-input-38-9e011345e519>:186] [36/40]	train_main: 1.221686e-04	test_main: 1.103856e-03


[67/100]	test_main: 0.001402

[default |INFO|<ipython-input-38-9e011345e519>:186] [37/40]	train_main: 1.621302e-04	test_main: 9.988523e-04


[67/100]	test_main: 0.001101

[default |INFO|<ipython-input-38-9e011345e519>:186] [38/40]	train_main: 1.193144e-04	test_main: 8.491961e-04


[67/100]	test_main: 0.001301

[default |INFO|<ipython-input-38-9e011345e519>:186] [39/40]	train_main: 1.129870e-04	test_main: 9.250285e-04


[68/100]	test_main: 0.001301

[default |INFO|<ipython-input-38-9e011345e519>:186] [40/40]	train_main: 9.006116e-05	test_main: 1.032913e-03
[default |INFO|<ipython-input-38-9e011345e519>:191] Training Finished!
[default |INFO|<ipython-input-38-9e011345e519>:192] Elapsed time: 272.032s


[69/100]	test_main: 0.0013[70/100]	test_main: 0.0013[71/100]	test_main: 0.0013[72/100]	test_main: 0.0013[73/100]	test_main: 0.0012[74/100]	test_main: 0.0012[75/100]	test_main: 0.0012[76/100]	test_main: 0.0012[77/100]	test_main: 0.0012[78/100]	test_main: 0.0012[79/100]	test_main: 0.0012[80/100]	test_main: 0.0012[81/100]	test_main: 0.0011[82/100]	test_main: 0.0011[83/100]	test_main: 0.0011[84/100]	test_main: 0.0011[85/100]	test_main: 0.0011[86/100]	test_main: 0.0011[87/100]	test_main: 0.0011[88/100]	test_main: 0.0011[89/100]	test_main: 0.0011[90/100]	test_main: 0.0011[91/100]	test_main: 0.0010[92/100]	test_main: 0.0010[93/100]	test_main: 0.0010[94/100]	test_main: 0.0010[95/100]	test_main: 0.0010[96/100]	test_main: 0.0010[97/100]	test_main: 0.0010[98/100]	test_main: 0.0010[99/100]	test_main: 0.0010

In [56]:
from XAE.dataset import MNIST

batch_size = 100
test_data = MNIST(cfg['path_info']['data_home'], train = False, label = True)
test_generator = torch.utils.data.DataLoader(test_data, batch_size, num_workers = 5, shuffle = False, pin_memory=True, drop_last=True)

res = []
for data in test_generator:
    predicted = model(data[0].to(device)).cpu().detach()
    pred_label = torch.topk(predicted, 2, dim=1)[1]
    actual_label = torch.max(data[1],1)[1]
    for i, j in zip(pred_label, actual_label):
        if j in i :
            res.append((i,j))


In [77]:
n=[]
for i in res:
    if 9 == i[0][0]:
        n.append(i[1].numpy())

for i in range(10):
    print(i, n.count(i))


0 0
1 0
2 0
3 0
4 5
5 0
6 0
7 1
8 1
9 996
