In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import os
import time

In [None]:
class AverageMeter(object):
    """
    Running tracker for a stream of measurements.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted sum of all values seen since the last ``reset``.
        count: total weight (sum of ``n``) seen since the last ``reset``.
        avg: ``sum / count`` as of the last ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new measurement.
            n: how many samples ``val`` stands for (its weight).
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        total, weight = self.sum, self.count
        self.avg = total / weight

In [None]:
class Trainer(object):
    """
    Drives training, evaluation and checkpointing of a ``focusLocNet`` model.

    ``focusLocNet`` is not defined in this block; it must be available in the
    enclosing module/notebook namespace when a ``Trainer`` is constructed.
    """

    def __init__(self, config, dataloader):
        """
        Build the model and optimizer and store the run configuration.

        Args:
            config: object exposing ``std``, ``is_train``, ``epochs``,
                ``init_lr``, ``device``, ``ckpt_dir``, ``resume`` and
                ``model_name`` attributes.
            dataloader: iterable of ``(frames, depth)`` batches with a
                ``dataset`` attribute supporting ``len()``. It is stored as
                ``train_loader`` when ``config.is_train`` is truthy and as
                ``test_loader`` otherwise.
        """
        self.config = config

        self.device = config.device
        self.ckpt_dir = config.ckpt_dir
        self.resume = config.resume
        self.model_name = config.model_name

        self.model = focusLocNet()
        # Batches are moved to ``self.device`` before being fed to the model,
        # so the model's parameters have to live on the same device.
        self.model.to(self.device)
        self.std = config.std

        if config.is_train:
            self.train_loader = dataloader
            self.num_train = len(self.train_loader.dataset)
        else:
            self.test_loader = dataloader
            self.num_test = len(self.test_loader.dataset)

        self.start_epoch = 0
        self.epochs = config.epochs
        self.batch_size = 1
        self.lr = config.init_lr
        # Use the configured learning rate so that the value reported by
        # train() is the one the optimizer actually applies.
        self.optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.lr,
        )

    def train(self):
        """
        Train the model on the training set.

        If ``self.resume`` is set, the latest checkpoint is loaded first.
        After every epoch the mean training loss is printed and a checkpoint
        (next epoch index, model state, optimizer state) is written via
        ``save_checkpoint``.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint()

        self.model.train()

        for epoch in range(self.start_epoch, self.epochs):

            print(
                '\nEpoch: {}/{} - LR: {:.6f}'.format(
                    epoch+1, self.epochs, self.lr)
            )

            # train for 1 epoch
            train_loss = self.train_one_epoch(epoch)

            print("train loss: {:.3f}".format(train_loss))

            self.save_checkpoint(
                {'epoch': epoch + 1,
                 'model_state': self.model.state_dict(),
                 'optim_state': self.optimizer.state_dict(),
                 }
            )

    def _rollout_loss(self, frames, dpt):
        """
        Run the estimate/capture/fuse rollout on one batch and return the
        scalar REINFORCE loss.

        Args:
            frames: image sequences, laid out as (B, Seq, C, H, W) per the
                original data-shape note; frame 0 seeds the estimate.
            dpt: per-frame data passed to ``getDefocuesImage``; indexed as
                ``dpt[:, t+1, ...]`` so it shares the first two axes with
                ``frames``.

        Returns:
            A 0-dim tensor: per-sample losses summed over time steps, then
            averaged over the batch.
        """
        self.batch_size = frames.size()[0]
        self.model.init_hidden()

        num_steps = frames.size()[1] - 1

        log_pi = []
        # set J_prev to be the first frame of the image sequences
        J_prev = frames[:, 0, ...]
        J_est = [J_prev]

        for t in range(num_steps):
            # for each time step: estimate, capture and fuse.
            _, l, p = self.model(J_prev)
            log_pi.append(p)
            I = self.model.getDefocuesImage(l, frames[:, t+1, ...], dpt[:, t+1, ...])
            J_prev = self.model.fuseTwoImages(I, J_prev)
            J_est.append(J_prev)

        J_est = torch.stack(J_est, dim=1)

        # Stacking puts the time axis first; swap so the batch axis leads.
        log_pi = torch.stack(log_pi).transpose(1, 0)
        # Reward is the negated reconstruction loss, repeated for every step
        # (assumes reconsLoss returns one value per sample -- TODO confirm).
        # NOTE(review): R is not detached; if reconsLoss is differentiable
        # w.r.t. the model, gradients also flow through R -- confirm intent.
        R = -self.model.reconsLoss(J_est, frames)
        R = R.unsqueeze(1).repeat(1, num_steps)

        ## Basic REINFORCE algorithm
        loss = torch.sum(-log_pi*R, dim=1)
        return torch.mean(loss, dim=0)

    def train_one_epoch(self, epoch=None):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Args:
            epoch: index of the current epoch. Accepted because train()
                passes it; it is not used by the computation.

        Returns:
            The sample-weighted mean loss over the epoch.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        tic = time.time()

        for frames, dpt in self.train_loader:

            frames = frames.to(self.device)
            dpt = dpt.to(self.device)

            loss = self._rollout_loss(frames, dpt)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            losses.update(loss.item(), frames.size()[0])

            # measure elapsed time for this batch only
            toc = time.time()
            batch_time.update(toc-tic)
            tic = toc

        return losses.avg

    def test(self):
        """
        Evaluate the model on ``self.test_loader``.

        The model is switched to eval mode and the pass runs without
        gradient tracking.

        Returns:
            The sample-weighted mean loss over the test set.
        """
        self.model.eval()
        losses = AverageMeter()

        # No parameter update happens here, so skip building autograd graphs.
        with torch.no_grad():
            for frames, dpt in self.test_loader:

                frames = frames.to(self.device)
                dpt = dpt.to(self.device)

                loss = self._rollout_loss(frames, dpt)

                losses.update(loss.item(), frames.size()[0])

        return losses.avg

    def save_checkpoint(self, state):
        """
        Save a copy of the model so that it can be loaded at a future
        date.

        Args:
            state: picklable object to store; train() passes a dict with
                the keys ``'epoch'``, ``'model_state'`` and ``'optim_state'``.
        """
        print("[*] Saving model to {}".format(self.ckpt_dir))

        # Create the target directory on first use (an empty ckpt_dir means
        # the current working directory, which needs no creation).
        if self.ckpt_dir:
            os.makedirs(self.ckpt_dir, exist_ok=True)

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        print("[*] Saved model to {}".format(self.ckpt_dir))

    def load_checkpoint(self):
        """
        Restore ``start_epoch``, model weights and optimizer state from the
        checkpoint written by ``save_checkpoint``.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # map_location lets a checkpoint saved on another device be loaded
        # onto the configured one. torch.load unpickles the file, so only
        # load checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location=self.device)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        print("[*] Loaded model from {}".format(self.ckpt_dir))
    