In [13]:
from copy import deepcopy
import os

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

from ignite.engine import Engine
from ignite.engine import Events
from ignite.metrics import RunningAverage
from ignite.contrib.handlers.tqdm_logger import ProgressBar

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

In [2]:
from data import MyDataset
from model import MyModel
from engine import MyEngine
from trainer import MyTrainer

### Configuration

In [3]:
config = {
    'MNIST_SAVE_PATH' : os.path.join(os.getcwd(), './data'), 
    'MODEL_FN' : './model.pt',
    'N_EPOCH' : 10,
    'BATCH_SIZE' : 64, 
    'lr' : 1e-4
}

### Dataset

In [4]:
class MyDataset:
    def __init__(self, config):
        self.config = config

    def get_train_dataset(self):
        return datasets.MNIST(root='./data',
                download=True,
                train=True,
                transform=transforms.Compose([
                transforms.Resize(28),
                transforms.ToTensor(),
                transforms.Normalize((.5), (.5))])) 

    def get_test_dataset(self):
        return datasets.MNIST(root='./data', 
                download=True,
                train=False,
                transform= transforms.Compose([
                transforms.Resize(28),
                transforms.ToTensor(),
                transforms.Normalize((.5),(.5))]))
                
    def get_train_loader(self):
        return DataLoader(self.get_train_dataset(),
                        batch_size=self.config['BATCH_SIZE'],
                        shuffle=True)

    def get_test_loader(self):
        return DataLoader(self.get_test_dataset(),
                        batch_size=self.config['BATCH_SIZE'],
                        shuffle=False)

In [5]:
dataset = MyDataset(config)
train_loader = dataset.get_train_loader()
test_loader = dataset.get_test_loader()

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


## model

In [6]:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel,self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2) 

        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.flat = nn.Flatten()

        self.dense = nn.Linear(in_features = 256, out_features=120)
        self.dense2 = nn.Linear(in_features = 120, out_features=84)
        self.dense3 = nn.Linear(in_features = 84, out_features=10)

        self.relu = nn.ReLU()

    def forward(self,x):
        x = self.conv(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.pool2(x)

        x = self.flat(x)

        x = self.dense(x)
        x = self.relu(x)

        x = self.dense2(x)
        x = self.relu(x)

        x = self.dense3(x)
        return F.softmax(x, dim = 1)

In [7]:
model = MyModel()
print(model)

MyModel(
  (conv): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flat): Flatten(start_dim=1, end_dim=-1)
  (dense): Linear(in_features=256, out_features=120, bias=True)
  (dense2): Linear(in_features=120, out_features=84, bias=True)
  (dense3): Linear(in_features=84, out_features=10, bias=True)
  (relu): ReLU()
)


## Create Instance

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), 
            lr = config['lr'])

## Trainer

In [9]:

class MyEngine(Engine):
    def __init__(self, func, model, criterion, optimizer, config):
        """
        func : torch ignite를 돌리는 함수
        - 아래에서 구현한 static method를 인수로 넣어주면 됨 (train, test, attach...)
        - 그 외에 부분 (model, criterion, optimizer, config)은 직접 지정
        """        
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        if 'MODEL_FN' not in config.keys():config['MODEL_FN'] = './model.pt'
        self.config = config
        super().__init__(func)

        self.best_loss = np.inf
        self.best_model = None
        
        self.device = next(model.parameters()).device

    @staticmethod
    def train(engine, batch):
        def acc(prediction, label):                     # 정확도
            assert prediction.size(0) == label.size(0)
            batch_size = prediction.size(0)
            _, prediction = torch.max(prediction.data, 1)
            return float((prediction == label).sum().item()/batch_size)

        engine.model.train()                            # test 모드
        engine.optimizer.zero_grad()                    # optimizer 초기화

        device = engine.device
        image, label = batch
        batch_size = image.size(0)
        if image.shape != [batch_size, 1, 28, 28]:      # 채널길이
            image=image.reshape(batch_size, 1, 28, 28)
        image = image.to(device)
        label = label.to(device)  

        prediction = engine.model(image)                 # 예측
        loss = engine.criterion(prediction, label)
        loss.backward()                                 # gradient계산
        engine.optimizer.step()                         # 업데이트
      
        return {'loss' : float(loss), 'accuracy' : acc(prediction, label)}      # metric 반환

    @staticmethod
    def test(engine, batch):
        
        def acc(prediction, label):                     # 정확도
            assert prediction.size(0) == label.size(0)
            batch_size = prediction.size(0)
            _, prediction = torch.max(prediction.data, 1)
            return float((prediction == label).sum().item()/batch_size)

        engine.model.eval()                             # 평가 모드
        with torch.no_grad():                           # gradient 계산 x (이래야 더 빠름)
            device = engine.device
            image, label = batch
            batch_size = image.size(0)          
            if image.shape != [batch_size, 1, 28, 28]:  # 채널길이
                image=image.reshape(batch_size, 1, 28, 28)
            image = image.to(device)
            label = label.to(device)
            prediction = engine.model(image)
            loss = engine.criterion(prediction, label)        

            return {'loss' : float(loss), 'accuracy' : acc(prediction, label)}   # metric 반환

    @staticmethod
    def attach(train_engine, test_engine):
        """
        engine에 status bar와 progress bar를 붙여주는 부분
        """
        
        metrics = ['loss', 'accuracy']

        def attach_running_average(engine, metric):
            RunningAverage(output_transform=lambda x:x[metric]).attach(engine, metric,)
        for metric in metrics:
            attach_running_average(train_engine, metric)
            attach_running_average(test_engine, metric)

        # train
        ProgressBar(bar_format=None, ncols=100).attach(train_engine, metrics)
        @train_engine.on(Events.EPOCH_COMPLETED)
        def print_train_status(engine):
            print(f"Train | epoch {engine.state.epoch} - \
                loss : {round(float(engine.state.metrics['loss']), 2)} - \
                    accuracy : {round(float(engine.state.metrics['accuracy']*100), 2)} %")

        # test
        ProgressBar(bar_format=None, ncols=100).attach(test_engine, metrics)
        @test_engine.on(Events.EPOCH_COMPLETED)
        def print_test_status(engine):
            print(f"Test | - loss : {round(float(engine.state.metrics['loss']), 2)} - \
                    accuracy : {round(float(engine.state.metrics['accuracy']*100), 2)} %")

    @staticmethod
    def check_best(engine):
        """
        가장 loss값이 작은 경우 copy해서 engine.best_model에 저장
        """
        current_loss = float(engine.state.metrics['loss'])
        if current_loss < engine.best_loss:
            engine.best_loss = current_loss    
            engine.best_model = deepcopy(engine.model.state_dict())
        print(f"\n model updated  - loss : {round(engine.best_loss, 4)} ")
        
    @staticmethod
    def save_model(engine, train_engine, config, **kwargs):
        torch.save({'model' : engine.best_model, 'config' : config, **kwargs},
            config.MODEL_FN)

In [10]:
class MyTrainer:
    def __init__(self, config):
        if 'N_EPOCH' not in config.keys():
            self.config['N_EPOCH'] = 20
        self.config = config
    
    def train(self, model, criterion, optimizer, train_loader, valid_loader):
        
        # engine
        train_engine = MyEngine(func=MyEngine.train, model=model, criterion=criterion,\
            optimizer=optimizer, config=self.config)
        valid_engine = MyEngine(func=MyEngine.test, model=model, criterion=criterion,\
            optimizer=optimizer, config=self.config)
        
        # attach
        MyEngine.attach(train_engine=train_engine, test_engine=valid_engine)
        
        # add event handler
        """
            매 epoch마다 validation을 하고, 성능이 제일 좋은 (loss가 작은) 모델을 best_model로 deep copy
        """
        def run_validation(engine, valid_engine, valid_loader):
            valid_engine.run(valid_loader, max_epochs=1)
        train_engine.add_event_handler(Events.EPOCH_COMPLETED, run_validation, valid_engine, valid_loader)
        valid_engine.add_event_handler(Events.EPOCH_COMPLETED, MyEngine.check_best)     
        
        # train
        train_engine.run(train_loader, max_epochs=self.config['N_EPOCH'])
    
        model.load_state_dict(valid_engine.best_model)
        return model

In [11]:
trainer = MyTrainer(config)

In [14]:
best_model = trainer.train(model=model, criterion=criterion, optimizer=optimizer, \
    train_loader=train_loader, valid_loader=test_loader)



Train | epoch 1 -                 loss : 1.49 -                     accuracy : 97.38 %




Test | - loss : 1.48 -                     accuracy : 98.31 %

 model updated  - loss : 1.4791 




Train | epoch 2 -                 loss : 1.49 -                     accuracy : 97.82 %




Test | - loss : 1.48 -                     accuracy : 98.35 %

 model updated  - loss : 1.4784 




Train | epoch 3 -                 loss : 1.49 -                     accuracy : 97.52 %




Test | - loss : 1.48 -                     accuracy : 98.64 %

 model updated  - loss : 1.4769 




Train | epoch 4 -                 loss : 1.48 -                     accuracy : 97.96 %




Test | - loss : 1.48 -                     accuracy : 98.58 %

 model updated  - loss : 1.4769 




Train | epoch 5 -                 loss : 1.48 -                     accuracy : 98.2 %




Test | - loss : 1.48 -                     accuracy : 98.53 %

 model updated  - loss : 1.4768 




Train | epoch 6 -                 loss : 1.48 -                     accuracy : 98.2 %




Test | - loss : 1.48 -                     accuracy : 98.64 %

 model updated  - loss : 1.4762 




Train | epoch 7 -                 loss : 1.48 -                     accuracy : 98.62 %




Test | - loss : 1.48 -                     accuracy : 98.71 %

 model updated  - loss : 1.4753 




Train | epoch 8 -                 loss : 1.48 -                     accuracy : 98.55 %




Test | - loss : 1.47 -                     accuracy : 98.74 %

 model updated  - loss : 1.4745 




Train | epoch 9 -                 loss : 1.48 -                     accuracy : 98.22 %




Test | - loss : 1.48 -                     accuracy : 98.7 %

 model updated  - loss : 1.4745 




Train | epoch 10 -                 loss : 1.48 -                     accuracy : 98.67 %


                                                                                                    

Test | - loss : 1.47 -                     accuracy : 98.7 %

 model updated  - loss : 1.4745 


