In [90]:
from torchvision import datasets, transforms

dataset = datasets.MNIST(
    '../data', train=True, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ]),
)

x = dataset.data.float() / 255.
y = dataset.targets

In [91]:
x.shape

torch.Size([60000, 28, 28])

In [92]:
type(x)

torch.Tensor

In [93]:
x = x.view(x.size(0), -1)

In [94]:
x.size()

torch.Size([60000, 784])

model.py

In [95]:
import torch
import torch.nn as nn

class MnistClassifier(nn.Module):
    
    def __init__(self,
                 input_size,
                 output_size):
        self.input_size = input_size
        self.output_size = output_size
        
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Linear(input_size, 500),
            nn.LeakyReLU(),
            nn.BatchNorm1d(400),
            
            nn.Linear(400, 300),
            nn.LeakyReLU(),
            nn.BatchNorm1d(300),
            
            nn.Linear(300, 200),
            nn.LeakyReLU(),
            nn.BatchNorm1d(200),

            nn.Linear(200, 100),
            nn.LeakyReLU(),
            nn.BatchNorm1d(100),

            nn.Linear(100, 50),
            nn.LeakyReLU(),
            nn.BatchNorm1d(50),
            
            nn.Linear(50, output_size),
            nn.LogSoftmax(dim=-1)            
        )
        
    def forward(self, x):
        #|x| = (batch_size, input_size)
        y = self.layers(x)
        # |y| = (batch_size, output_size)
        return y 
        

train.py

In [96]:
# cnt 설정
train_ratio = 0.8
train_cnt = int(x.size(0) * train_ratio)
valid_cnt = x.size(0) - train_cnt

In [97]:
# shuffle dataset to split into train/valid set
indices = torch.randperm(x.size(0))

x = torch.index_select(
    x, 
    dim=0,
    index=indices
).split([train_cnt, valid_cnt], dim = 0)


y = torch.index_select(
    y, 
    dim=0,
    index=indices
).split([train_cnt, valid_cnt], dim = 0)

In [98]:
print("Train:", x[0].shape, y[0].shape)
print("Valid:", x[1].shape, y[1].shape)

Train: torch.Size([48000, 784]) torch.Size([48000])
Valid: torch.Size([12000, 784]) torch.Size([12000])


In [99]:
import torch.optim as optim

# model, optimizer, crit 설정 
model = MnistClassifier(28*2, 10)
optimizer = optim.Adam(model.parameters())
crit = nn.NLLLoss()

trainer.py

In [104]:
from copy import deepcopy

import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim


class Trainer():
    
    def __init__(self, model, optimizer, crit):
        self.model = model
        self.optimizer = optimizer 
        self.crit = crit 
        
        super().__init__()
        
    def _train(self, x, y, config):
        self.model.train()
        
        # shuffle before begin
        indices = torch.randperm(x.size(0))
        x = torch.index_select(x, dim=0, index=indices).split(config['batch_size'], dim=0)
        y = torch.index_select(y, dim=0, index=indices).split(config['batch_size'], dim=0)
        
        total_loss = 0
        
        for i, (x_i, y_i) in enumerate(zip(x,y)):
            y_hat_i = self.model(x_i)
            loss_i = self.crit(y_hat_i, y_i.squeeze())
            
            # initialize the gradient of the model.
            self.optimizer.zero_grad()
            loss_i.backward()
            
            self.optimizer.step()
            
            total_loss += float(loss_i)
            
        return total_loss / len(x)
    
    def _validate(self, x, y, config):
        # Turn evaluation mode on. 
        self.model.eval()
        
        # Turn on the no_grad mode to make more efficiently.
        with torch.no_grad():
            # Suffle before begin.
            indices = torch.randperm(x.size(0))
            x = torch.index_select(x, dim=0, index=indices).split(config['batch_size'], dim=0)
            print(x.size())
            y = torch.index_select(y, dim=0, index=indices).split(config['batch_size'], dim=0)
            print(y.size())
            total_loss = 0
            
            for i, (x_i, y_i) in enumerate(zip(x,y)):
                y_hat_i = self.model(x_i)
                loss_i = self.crit(y_hat_i, y_i.squeeze())
                
                total_loss += float(loss_i)
            
            return total_loss / len(x)
    
    def train(self, train_data, valid_data, config):
        lowest_loss = np.inf
        best_model = None 
        
        for epoch_index in range(config['n_epochs']):
            train_loss = self._train(train_data[0], train_data[1], config)
            valid_loss = self._validate(valid_data[0], valid_data[1], config)
            
            # You must use deep copy to take a snapshot of current best weights.
            if valid_loss <= lowest_loss:
                lowest_loss = valid_loss
                best_model = deepcopy(self.model.state_dict())
            
            print("Epoch(%d/%d): train_loss=%.4e valid_loss=%.4e lowest_loss=%.4e" % (
                epoch_index + 1,
                config.n_epochs,
                train_loss,
                valid_loss,
                lowest_loss
            ))
        # Restore to best model.
        self.model.load_state_dict(best_model)
        

train.py

In [123]:
# config.. 
import argparse

config = {
    'batch_size' : 512,
    'n_epochs' : 10
}

In [124]:
x[0].size()

torch.Size([48000, 784])

In [125]:
x[0].size()

torch.Size([48000, 784])

In [126]:
trainer = Trainer(model, optimizer, crit)

trainer.train((x[0], y[0]), (x[1], y[1]), config)


RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x784 and 56x500)

In [None]:

# Save best model weights.
torch.save({
    'model': trainer.model.state_dict(),
    'config': config,
}, config.model_fn)