In [1]:
# trainer.py

from copy import deepcopy

import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim

In [2]:
class Trainer():
    """Mini-batch training/validation loop around a model, optimizer and loss.

    The trainer runs ``config.n_epochs`` epochs; after every epoch it evaluates
    the validation loss and remembers a copy of the model weights that produced
    the lowest validation loss so far. When training finishes, those best
    weights are loaded back into ``self.model``.

    Attributes read from ``config``: ``batch_size``, ``n_epochs``, ``verbose``.
    """

    def __init__(self, model, optimizer, crit):
        """Store the collaborators used by the training loop.

        Args:
            model: callable module exposing ``train()``, ``eval()``,
                ``state_dict()`` and ``load_state_dict()``.
            optimizer: object exposing ``zero_grad()`` and ``step()``.
            crit: loss callable invoked as ``crit(y_hat, y)``; its result must
                support ``backward()`` and conversion via ``float()``.
        """
        self.model = model
        self.optimizer = optimizer
        self.crit = crit

        super().__init__()

    @staticmethod
    def _shuffled_batches(x, y, batch_size):
        """Shuffle ``x`` and ``y`` with one shared permutation and split into batches.

        Args:
            x: input tensor; samples are indexed along dim 0.
            y: target tensor with the same size along dim 0 as ``x``.
            batch_size: maximum number of samples per batch.

        Returns:
            Tuple ``(x_batches, y_batches)`` of equally long tuples of tensors.
        """
        # One permutation, created on x's device, keeps (x, y) pairs aligned.
        indices = torch.randperm(x.size(0), device=x.device)
        x_batches = torch.index_select(x, dim=0, index=indices).split(batch_size, dim=0)
        y_batches = torch.index_select(y, dim=0, index=indices).split(batch_size, dim=0)
        return x_batches, y_batches

    def _train(self, x, y, config):
        """Run one training epoch and return the mean per-batch loss.

        Args:
            x: input tensor (samples along dim 0).
            y: target tensor aligned with ``x``.
            config: object providing ``batch_size`` and ``verbose``.

        Returns:
            float: sum of batch losses divided by the number of batches.

        Note:
            Targets are passed to the criterion as ``y_i.squeeze()``, which
            removes *every* size-1 dimension; a trailing batch of size 1 is
            therefore collapsed to a scalar as well.
        """
        self.model.train()

        # Shuffle before every epoch.
        x, y = self._shuffled_batches(x, y, config.batch_size)

        total_loss = 0

        for i, (x_i, y_i) in enumerate(zip(x, y)):
            y_hat_i = self.model(x_i)
            loss_i = self.crit(y_hat_i, y_i.squeeze())

            # Required order: zero_grad -> backward -> step.
            self.optimizer.zero_grad()  # clear gradients left from the previous step
            loss_i.backward()

            self.optimizer.step()

            if config.verbose >= 2:
                print("TRAIN ITERATION(%d/%d): loss=%.4e" % (i + 1, len(x), float(loss_i)))

            # Accumulate a plain Python float, not the tensor: summing tensors
            # would keep them (and what they reference) alive and waste memory.
            total_loss += float(loss_i)

        return total_loss / len(x)  # mean over batches

    def _validate(self, x, y, config):
        """Evaluate the model on ``(x, y)`` and return the mean per-batch loss.

        Args:
            x: input tensor (samples along dim 0).
            y: target tensor aligned with ``x``.
            config: object providing ``batch_size`` and ``verbose``.

        Returns:
            float: sum of batch losses divided by the number of batches.
        """
        # Turn evaluation mode on.
        self.model.eval()

        # No gradients are needed for evaluation.
        with torch.no_grad():
            x, y = self._shuffled_batches(x, y, config.batch_size)

            total_loss = 0

            for i, (x_i, y_i) in enumerate(zip(x, y)):
                y_hat_i = self.model(x_i)
                loss_i = self.crit(y_hat_i, y_i.squeeze())

                # No zero_grad/backward/step here: just accumulate the loss.
                if config.verbose >= 2:
                    print("VALID ITERATION(%d/%d): loss=%.4e" % (i + 1, len(x), float(loss_i)))

                total_loss += float(loss_i)

            return total_loss / len(x)

    def train(self, train_data, valid_data, config):
        """Train for ``config.n_epochs`` epochs and restore the best weights.

        Args:
            train_data: indexable pair ``(x, y)`` used for training.
            valid_data: indexable pair ``(x, y)`` used for validation.
            config: object providing ``n_epochs``, ``batch_size``, ``verbose``.

        Returns:
            None. ``self.model`` ends up holding the weights of the epoch with
            the lowest validation loss, if any epoch produced one.
        """
        lowest_loss = np.inf
        best_model = None

        for epoch_index in range(config.n_epochs):
            train_loss = self._train(train_data[0], train_data[1], config)
            valid_loss = self._validate(valid_data[0], valid_data[1], config)

            if valid_loss <= lowest_loss:  # track the lowest validation loss
                lowest_loss = valid_loss
                # deepcopy: state_dict() tensors alias the live parameters.
                best_model = deepcopy(self.model.state_dict())

            print("Epoch(%d/%d): train_loss = %.4e valid_loss=%.4e lowest_loss=%.4e" % (
                epoch_index + 1,
                config.n_epochs,
                train_loss,
                valid_loss,
                lowest_loss)
            )

        # Restore the best model. best_model stays None when no epoch ran or
        # every validation loss was NaN; keep the current weights in that case.
        if best_model is not None:
            self.model.load_state_dict(best_model)