In [1]:
from torchvision import datasets, transforms

dataset = datasets.MNIST(
    '../data', train=True, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ]),
)

x = dataset.data.float() / 255.
y = dataset.targets

model.py

In [None]:
# v2
import torch 
import torch.nn as nn

class Block(nn.Moudel): 
    
    def __init__(self,
                 input_size,
                 output_size,
                 use_batch_norm=True,
                 dropout_p=.4):
        self.input_size = input_size
        self.output_size = output_size
        self.use_batch_norm = use_batch_norm
        self.dropout_p = dropout_p
        
        super().__init__()
        
        def get_regularizer(use_batch_norm, size):
            return nn.BatchNorm1d(size) if use_batch_norm else nn.Dropout(dropout_p)
        
        self.block = nn.Sequential(
            nn.Linear(input_size, output_size),
            nn.LeakyReLU(),
            get_regularizer(use_batch_norm, output_size)
        )
        
        
    def forward(self, x):
        # |x| = (batch_size, input_size)
        y = self.block(x)
        # |y| = (batch_size, output_size)
        
        return y

class MnistClassifier(nn.Module):
    
    def __init__(self,
                 input_size,
                 output_size,
                 hidden_sizes = [500,400,300,200,100],
                 use_batch_norm = True,
                 dropout_p = .3):
        super().__init__()
        
        assert len(hidden_sizes) > 0, "you need to specify hidden layers"
        
        last_hidden_size = input_size
        blocks = []
        
        for hidden_size in hidden_sizes:
            block += [Block(
                last_hidden_size,
                hidden_size,
                use_batch_norm,
                dropout_p
            )]
            last_hidden_size = hidden_size
        
        self.layers = nn.Sequential(
            *blocks,
            nn.Linear(last_hidden_size, output_size),
            nn.LogSoftmax(dim=-1)   
        )
    
    def forward(self, x):
        # |x| = (batch_size, input_size)        
        y = self.layers(x)
        # |y| = (batch_size, output_size)
        
        return y      

train.py

In [3]:
# cnt 설정
train_ratio = 0.8
train_cnt = int(x.size(0) * train_ratio)
valid_cnt = x.size(0) - train_cnt

In [4]:
# shuffle dataset to split into train/valid set
x = x.view(x.size(0), -1)
indices = torch.randperm(x.size(0))

x = torch.index_select(
    x, 
    dim=0,
    index=indices
).split([train_cnt, valid_cnt], dim = 0)


y = torch.index_select(
    y, 
    dim=0,
    index=indices
).split([train_cnt, valid_cnt], dim = 0)

In [5]:
print("Train:", x[0].shape, y[0].shape)
print("Valid:", x[1].shape, y[1].shape)

Train: torch.Size([48000, 784]) torch.Size([48000])
Valid: torch.Size([12000, 784]) torch.Size([12000])


In [7]:
import torch.optim as optim

# model, optimizer, crit 설정 
model = MnistClassifier(28**2, 10)
optimizer = optim.Adam(model.parameters())
crit = nn.NLLLoss()

trainer.py

In [27]:
from copy import deepcopy

import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim


class Trainer():
    
    def __init__(self, model, optimizer, crit):
        self.model = model
        self.optimizer = optimizer 
        self.crit = crit 
        
        super().__init__()
        
    def _train(self, x, y, config):
        self.model.train()
        
        # shuffle before begin
        indices = torch.randperm(x.size(0))
        x = torch.index_select(x, dim=0, index=indices).split(config['batch_size'], dim=0)
        y = torch.index_select(y, dim=0, index=indices).split(config['batch_size'], dim=0)
        
        total_loss = 0
        
        for i, (x_i, y_i) in enumerate(zip(x,y)):
            y_hat_i = self.model(x_i)
            loss_i = self.crit(y_hat_i, y_i.squeeze())
            
            # initialize the gradient of the model.
            self.optimizer.zero_grad()
            loss_i.backward()
            
            self.optimizer.step()
            
            total_loss += float(loss_i)
            
        return total_loss / len(x)
    
    def _validate(self, x, y, config):
        # Turn evaluation mode on. 
        self.model.eval()
        
        # Turn on the no_grad mode to make more efficiently.
        with torch.no_grad():
            # Suffle before begin.
            indices = torch.randperm(x.size(0))
            x = torch.index_select(x, dim=0, index=indices).split(config['batch_size'], dim=0)
            y = torch.index_select(y, dim=0, index=indices).split(config['batch_size'], dim=0)
            total_loss = 0
            
            for i, (x_i, y_i) in enumerate(zip(x,y)):
                y_hat_i = self.model(x_i)
                loss_i = self.crit(y_hat_i, y_i.squeeze())
                
                total_loss += float(loss_i)
            
            return total_loss / len(x)
    
    def train(self, train_data, valid_data, config):
        lowest_loss = np.inf
        best_model = None
        
        for epoch_index in range(config['n_epochs']):
            train_loss = self._train(train_data[0], train_data[1], config)
            valid_loss = self._validate(valid_data[0], valid_data[1], config)
            
            # You must use deep copy to take a snapshot of current best weights.
            if valid_loss <= lowest_loss:
                lowest_loss = valid_loss
                best_model = deepcopy(self.model.state_dict())
            
            print("Epoch(%d/%d): train_loss=%.4e valid_loss=%.4e lowest_loss=%.4e" % (
                epoch_index + 1,
                config['n_epochs'],
                train_loss,
                valid_loss,
                lowest_loss
            ))
        # Restore to best model.
        self.model.load_state_dict(best_model)
        

train.py

In [28]:
# config.. 
import argparse

config = {
    'batch_size' : 512,
    'n_epochs' : 10,
    'model_fn' : 'model.pth'
}

In [29]:
trainer = Trainer(model, optimizer, crit)

trainer.train((x[0], y[0]), (x[1], y[1]), config)


Epoch(1/10): train_loss=8.8278e-03 valid_loss=9.9897e-02 lowest_loss=9.9897e-02
Epoch(2/10): train_loss=8.0993e-03 valid_loss=9.4777e-02 lowest_loss=9.4777e-02
Epoch(3/10): train_loss=8.8160e-03 valid_loss=1.0403e-01 lowest_loss=9.4777e-02
Epoch(4/10): train_loss=9.1608e-03 valid_loss=9.2695e-02 lowest_loss=9.2695e-02
Epoch(5/10): train_loss=7.6500e-03 valid_loss=8.7332e-02 lowest_loss=8.7332e-02
Epoch(6/10): train_loss=6.2263e-03 valid_loss=9.9235e-02 lowest_loss=8.7332e-02
Epoch(7/10): train_loss=6.4061e-03 valid_loss=1.0266e-01 lowest_loss=8.7332e-02
Epoch(8/10): train_loss=7.7165e-03 valid_loss=9.9298e-02 lowest_loss=8.7332e-02
Epoch(9/10): train_loss=6.6252e-03 valid_loss=8.5051e-02 lowest_loss=8.5051e-02
Epoch(10/10): train_loss=5.4342e-03 valid_loss=9.0510e-02 lowest_loss=8.5051e-02


In [30]:
# Save best model weights.
torch.save({
    'model': trainer.model.state_dict(),
    'config': config,
}, config['model_fn'])

In [32]:
# predict
def load_mnist(is_train=True, flatten=True):
    from torchvision import datasets, transforms

    dataset = datasets.MNIST(
        '../data', train=is_train, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ]),
    )

    x = dataset.data.float() / 255.
    y = dataset.targets

    if flatten:
        x = x.view(x.size(0), -1)

    return x, y

In [38]:
def load(fn):
    d = torch.load(fn)
    
    return d['model']

In [34]:
x, y = load_mnist(is_train=False)

In [39]:
model_fn = "./model.pth"

model = MnistClassifier(28**2, 10)
model.load_state_dict(load(model_fn))

<All keys matched successfully>

In [44]:
def test(model, x, y, to_be_shown = True):
    model.eval()
    
    with torch.no_grad():
        y_hat = model(x)
        
        correct_cnt = (y.squeeze() == torch.argmax(y_hat, dim = -1)).sum()
        total_cnt = float(x.size(0))
        
        accuracy = correct_cnt / total_cnt
        print("Accuracy: %.4f" % accuracy)
        

In [45]:
test(model, x[:10], y[:10])

Accuracy: 1.0000
