In [20]:
import os

from torch.utils.data import DataLoader, Dataset

def get_loader(x, y,
               dataset,
               batch_size=256,
               is_train = True,
               train_ratio=.8, 
               valid_ratio=.2,):
        
    assert (train_ratio + valid_ratio) <= 1
    
    if is_train : 
        train_cnt = int(x.size(0) * train_ratio)
        valid_cnt = x.size(0) - train_cnt

        indices = torch.randperm(x.size(0))

        train_x, valid_x = torch.index_select(
            x,
            dim=0,
            index=indices
        ).split([train_cnt, valid_cnt])
        
        train_y, valid_y = torch.index_select(
            y,
            dim=0,
            index=indices
        ).split([train_cnt, valid_cnt])

        train_loader = DataLoader(
            dataset=dataset(train_x, train_y),
            batch_size=batch_size,
            shuffle=True
        )
        valid_loader = DataLoader(
            dataset=dataset(valid_x, valid_y),
            batch_size=batch_size,
            shuffle=True
        )
        
        return train_loader, valid_loader

    else :
        test_loader = DataLoader(
            dataset=dataset(x, y),
            batch_size=batch_size,
            shuffle=False
        )
        
        return test_loader

In [133]:
from copy import deepcopy

import numpy as np

import torch
from torch import nn
from torch import optim
from tqdm import tqdm

class Trainer():
    
    def __init__(self, model, optimizer, crit):
        self.model = model
        self.optimizer = optimizer
        self.crit = crit
    
    def _train(self, batch_item):
        self.model.train()
        
        x = batch_item[0]
        y = batch_item[1]
        
        y_hat = self.model(x)
        loss = self.crit(y_hat, y)

        self.optimizer.zero_grad()
        loss.backward()

        self.optimizer.step()
        
        return float(loss)
    
    def _validate(self, batch_item):
        self.model.eval()
        
        x = batch_item[0]
        y = batch_item[1]
        
        val_loss = 0
        with torch.no_grad():
            y_hat = self.model(x)
            val_loss = self.crit(y_hat, y)
        
        return float(val_loss)
            
    
    def train(self, train_loader, valid_loader, epochs):
        best_loss = np.inf
        best_model = None
        best_epoch = np.inf
        
        for i in range(epochs):
            total_loss, total_val_loss = 0, 0
            
            tqdm_dataset = tqdm(enumerate(train_loader))
            for batch, batch_item in tqdm_dataset:
                batch_loss = self._train(batch_item)
                total_loss += batch_loss
                
                tqdm
                tqdm_dataset.set_postfix({
                    'Epoch': i + 1,
                    'Loss': '%.6f' % batch_loss,
                })
            
            tqdm_dataset = tqdm(enumerate(valid_loader))
            for batch, batch_item in tqdm_dataset:
                batch_loss = self._validate(batch_item)
                total_val_loss += batch_loss
                
                tqdm_dataset.set_postfix({
                    'Epoch': i + 1,
                    'Val Loss': '%.6f' % batch_loss,
                })
            
            total_val_loss = total_val_loss / (batch + 1)
            if total_val_loss <= best_loss:
                best_loss = total_val_loss
                best_epoch = i
                best_model = deepcopy(self.model.state_dict())
         
        print('Best: Epoch= %d  val_loss= %.6f' % (best_epoch + 1, best_loss))
        
        self.model.load_state_dict(best_model)

In [134]:
class CustomDataset(Dataset):
    
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, i):
        x = self.data[i]
        y = self.labels[i]
        return x, y

In [135]:
def load_mnist(is_train=True, flatten=True):
    from torchvision import datasets, transforms

    dataset = datasets.MNIST(
        '../data', train=is_train, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ]),
    )

    x = dataset.data.float() / 255.
    y = dataset.targets

    if flatten:
        x = x.view(x.size(0), -1)

    return x, y
x, y = load_mnist()

In [136]:
train_loader, valid_loader = get_loader(x, y, CustomDataset)

In [137]:
x.shape

torch.Size([60000, 784])

In [138]:
model = nn.Sequential(
    nn.Linear(784, 500),
    nn.LeakyReLU(),
    nn.Linear(500, 300),
    nn.LeakyReLU(),
    nn.Linear(300, 100),
    nn.LeakyReLU(),
    nn.Linear(100, 10),
    nn.Softmax()
)
model

Sequential(
  (0): Linear(in_features=784, out_features=500, bias=True)
  (1): LeakyReLU(negative_slope=0.01)
  (2): Linear(in_features=500, out_features=300, bias=True)
  (3): LeakyReLU(negative_slope=0.01)
  (4): Linear(in_features=300, out_features=100, bias=True)
  (5): LeakyReLU(negative_slope=0.01)
  (6): Linear(in_features=100, out_features=10, bias=True)
  (7): Softmax(dim=None)
)

In [139]:
optimizer = optim.Adam(model.parameters())
crit = nn.CrossEntropyLoss()

In [140]:
trainer = Trainer(model, optimizer, crit)

In [141]:
trainer.train(train_loader, valid_loader, epochs=10)

188it [00:03, 61.80it/s, Epoch=1, Loss=1.535540, Mean_Loss=310.845429]
47it [00:00, 138.64it/s, Epoch=1, Val Loss=1.527023, Val Mean_Loss=72.532902]
188it [00:03, 57.97it/s, Epoch=2, Loss=1.505156, Mean_Loss=144.056532]
47it [00:00, 137.03it/s, Epoch=2, Val Loss=1.513799, Val Mean_Loss=35.583687]
188it [00:03, 59.63it/s, Epoch=3, Loss=1.491895, Mean_Loss=94.854366]
47it [00:00, 135.06it/s, Epoch=3, Val Loss=1.535857, Val Mean_Loss=23.644481]
188it [00:03, 59.53it/s, Epoch=4, Loss=1.529430, Mean_Loss=70.618893]
47it [00:00, 126.00it/s, Epoch=4, Val Loss=1.504084, Val Mean_Loss=17.638962]
188it [00:03, 59.46it/s, Epoch=5, Loss=1.476685, Mean_Loss=56.286452]
47it [00:00, 139.47it/s, Epoch=5, Val Loss=1.527619, Val Mean_Loss=14.090008]
188it [00:03, 60.45it/s, Epoch=6, Loss=1.485073, Mean_Loss=46.771618]
47it [00:00, 137.83it/s, Epoch=6, Val Loss=1.489187, Val Mean_Loss=11.736745]
188it [00:03, 60.04it/s, Epoch=7, Loss=1.484641, Mean_Loss=39.955299]
47it [00:00, 135.84it/s, Epoch=7, Val Lo

Best: Epoch= 9  val_loss= 1.491017



