In [1]:
class Learner:
    """Train / evaluate a PyTorch model and keep per-epoch loss and accuracy history.

    Written for a notebook: unless given explicitly, it relies on notebook-level
    globals ``criterion``, ``optimizer``, ``trainloader``, ``testloader``,
    ``valloader``, ``args`` (for ``args.model_save_path``), ``plt``, ``tqdm``
    and ``clear_output``.  ``criterion`` / ``optimizer`` may be passed to the
    constructor, and ``save_model`` / ``predict`` accept an explicit model.
    """

    # Device used whenever a method is called with ``cuda=True``.
    _CUDA_DEVICE = 'cuda:0'

    def __init__(self, criterion=None, optimizer=None):
        # Per-batch loss values, accumulated over every epoch of every fit() call.
        self.training_loss_full = []
        self.testing_loss_full = []
        self.validation_loss_full = []
        # Mean batch loss per epoch.
        self.training_loss = []
        self.testing_loss = []
        self.validation_loss = []
        # Accuracy (integer percent) per epoch.
        self.train_acc = []
        self.test_acc = []
        self.val_acc = []
        # Best training accuracy seen so far; fit() checkpoints when it improves.
        self.max_train_acc = 0
        # Optional explicit loss function / optimizer.  When None, the
        # notebook-level globals ``criterion`` / ``optimizer`` are used.
        self.criterion = criterion
        self.optimizer = optimizer
        # Model most recently passed to fit(); default for save_model()/predict().
        self.model = None

    # ------------------------------------------------------------------ helpers

    def _get_criterion(self):
        """Return the explicit loss function, else the global ``criterion``."""
        return self.criterion if self.criterion is not None else criterion

    def _get_optimizer(self):
        """Return the explicit optimizer, else the global ``optimizer``."""
        return self.optimizer if self.optimizer is not None else optimizer

    def _resolve_model(self, model=None):
        """Pick the model to operate on.

        Priority: the explicit ``model`` argument, then the model last passed
        to fit(), then a module-level ``model`` global (legacy behaviour).

        Raises:
            ValueError: if none of those is available.
        """
        if model is not None:
            return model
        if self.model is not None:
            return self.model
        legacy = globals().get('model')
        if legacy is None:
            raise ValueError("No model available: pass model=... or call fit() first.")
        return legacy

    @staticmethod
    def _count_correct(output, target):
        """Count rows of ``output`` whose highest-scoring class equals ``target``.

        argmax of the raw output selects the same class as argmax of
        ``exp(output)`` (exp is monotonic) without risking overflow/underflow ties.
        """
        predicted = output.argmax(dim=1)
        return int((predicted == target.view_as(predicted)).sum().item())

    @staticmethod
    def _summarise(batch_losses, correct, seen):
        """Return ``(mean batch loss, accuracy %)`` for one pass over a loader.

        Accuracy is weighted by sample count, so a short final batch does not
        count as much as a full one. It is computed with integer arithmetic,
        so e.g. 29/100 yields 29 rather than 28 from float truncation.

        Raises:
            ValueError: if the loader yielded no batches / samples.
        """
        if not batch_losses or seen == 0:
            raise ValueError('data loader yielded no batches')
        return sum(batch_losses) / len(batch_losses), correct * 100 // seen

    def _update_best(self, accuracy):
        """Record ``accuracy``; return True if it beats ``self.max_train_acc``."""
        if accuracy > self.max_train_acc:
            self.max_train_acc = accuracy
            return True
        return False

    # --------------------------------------------------------------- public API

    def train(self, cuda, trainloader, model):
        """Run one training epoch over ``trainloader``.

        Args:
            cuda: if true, each batch is moved to ``cuda:0`` (the model is
                expected to already be there).
            trainloader: iterable of ``(data, target)`` batches.
            model: module trained in place.

        Returns:
            ``(batch_losses, mean_batch_loss, accuracy_percent)``; accuracy is
            an int computed over all samples.

        Raises:
            ValueError: if ``trainloader`` yields no batches.
        """
        loss_fn = self._get_criterion()
        opt = self._get_optimizer()
        batch_losses = []
        correct = seen = 0
        model.train()
        for data, target in trainloader:
            if cuda:
                data, target = data.to(self._CUDA_DEVICE), target.to(self._CUDA_DEVICE)
            opt.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            batch_losses.append(loss.item())
            correct += self._count_correct(output, target)
            seen += target.numel()
            loss.backward()
            opt.step()
        mean_loss, accuracy = self._summarise(batch_losses, correct, seen)
        return batch_losses, mean_loss, accuracy

    def test(self, cuda, loader, model):
        """Evaluate ``model`` on ``loader`` without gradient tracking.

        Same arguments, return value and ``ValueError`` as :meth:`train`, but no
        optimizer step; the model is left in eval mode.
        """
        loss_fn = self._get_criterion()
        batch_losses = []
        correct = seen = 0
        model.eval()
        with torch.no_grad():
            for data, target in loader:
                if cuda:
                    data, target = data.to(self._CUDA_DEVICE), target.to(self._CUDA_DEVICE)
                output = model(data)
                batch_losses.append(loss_fn(output, target).item())
                correct += self._count_correct(output, target)
                seen += target.numel()
        mean_loss, accuracy = self._summarise(batch_losses, correct, seen)
        return batch_losses, mean_loss, accuracy

    def plot(self, list1, list2, list3, title, label1, label2, label3):
        """Plot up to three per-epoch series on one figure; empty series are skipped."""
        series = ((list1, 'b-', label1), (list2, 'r-', label2), (list3, 'y-', label3))
        plotted = False
        for values, style, label in series:
            if len(values):
                plt.plot(values, style, label=label)
                plotted = True
        plt.title(title)
        plt.xlabel("Epochs")
        if plotted:
            # legend() warns when no labelled artists exist.
            plt.legend()
        plt.show()

    def save_model(self, path, model=None, filename='try1.pth'):
        """Save a model's ``state_dict`` to ``os.path.join(path, filename)``.

        ``model`` defaults to the model last passed to fit(), then to a global
        ``model`` (see ``_resolve_model``).
        """
        model = self._resolve_model(model)
        torch.save(model.state_dict(), os.path.join(path, filename))

    def weight_reset(self, m):
        """Re-initialise ``m`` if it is a Conv2d, Linear or LSTM layer (use with ``model.apply``)."""
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.LSTM)):
            m.reset_parameters()

    def fit(self,
            model,
            training=True,
            epoch=5,
            testing=False,
            validation=True,
            plot_loss=True,
            print_loss=False,
            plot_acc=False,
            print_acc=False,
            print_all=False,
            plot_overall_avg=True,
            cuda=True,
            reset_full_model=False):
        """Train and/or evaluate ``model`` for ``epoch`` epochs, recording history.

        The phases use the globals ``trainloader``, ``testloader`` and
        ``valloader``. A checkpoint is written to ``args.model_save_path``
        whenever training accuracy beats the best so far
        (``self.max_train_acc``).

        Args:
            model: module to train; remembered as ``self.model``.
            training, testing, validation: which phases to run each epoch.
            epoch: number of epochs to run.
            plot_loss, plot_acc: plot the loss / accuracy history after each epoch.
            print_loss: print the latest epoch's mean losses.
            print_acc: print the best training and best validation accuracy.
            print_all, plot_overall_avg: accepted for compatibility; currently unused.
            cuda: move the model and every batch to ``cuda:0``.
            reset_full_model: re-initialise Conv2d/Linear/LSTM layers first.
        """
        if reset_full_model:
            model.apply(self.weight_reset)
        if cuda:
            model.to(self._CUDA_DEVICE)
        self.model = model

        for _ in tqdm(range(epoch)):
            if training:
                batch_losses, mean_loss, accuracy = self.train(cuda, trainloader, model)
                self.training_loss_full.extend(batch_losses)
                self.training_loss.append(mean_loss)
                self.train_acc.append(accuracy)
                # Compare this epoch's value directly: history persists across
                # fit() calls, so indexing train_acc by epoch would be stale.
                if self._update_best(accuracy):
                    self.save_model(args.model_save_path, model=model)

            if testing:
                batch_losses, mean_loss, accuracy = self.test(cuda, testloader, model)
                self.testing_loss_full.extend(batch_losses)
                self.testing_loss.append(mean_loss)
                self.test_acc.append(accuracy)

            if validation:
                batch_losses, mean_loss, accuracy = self.test(cuda, valloader, model)
                self.validation_loss_full.extend(batch_losses)
                self.validation_loss.append(mean_loss)
                self.val_acc.append(accuracy)

            clear_output()

            if print_loss:
                for name, history in (('Training', self.training_loss),
                                      ('Validation', self.validation_loss),
                                      ('Testing', self.testing_loss)):
                    if history:
                        print('{} Loss:'.format(name), history[-1])

            if plot_loss:
                self.plot(self.training_loss, self.validation_loss, self.testing_loss, 'loss', 'train', 'val', 'test')

            if plot_acc:
                self.plot(self.train_acc, self.val_acc, self.test_acc, 'acc', 'train', 'val', 'test')

            if print_acc:
                print('Training Accuracy:', self.max_train_acc)
                # max() of an empty list raises when validation is disabled.
                if self.val_acc:
                    print('Validation Accuracy:', max(self.val_acc))

    def predict(self,
                model_path,
                cuda=False,
                load_model=False,
                print_acc=True,
                model=None,
                loader=None):
        """Evaluate a model on a test loader and return its accuracy (int percent).

        Args:
            model_path: state-dict file loaded when ``load_model`` is true.
            cuda: evaluate on ``cuda:0`` instead of the CPU.
            load_model: load weights from ``model_path`` before evaluating.
            print_acc: print the accuracy.
            model: model to evaluate; defaults to the model last passed to
                fit(), then to a global ``model``.
            loader: batches to evaluate; defaults to the global ``testloader``.
        """
        model = self._resolve_model(model)
        device = self._CUDA_DEVICE if cuda else 'cpu'
        # Batches are moved according to `cuda`; keep the model on the same device.
        model.to(device)
        if load_model:
            # map_location lets a checkpoint saved on GPU be loaded on CPU.
            model.load_state_dict(torch.load(model_path, map_location=device))
        if loader is None:
            loader = testloader
        _, _, test_accuracy = self.test(cuda, loader, model)
        if print_acc:
            print("The test accuracy is {0}%".format(test_accuracy))
        return test_accuracy
   
# TODO: add a standalone model-reset method.
# TODO: add transfer-learning modules.

In [2]:
class Args:
    """Notebook run configuration holding the paths used by the training code.

    Both paths default to the empty string, matching the previous hard-coded
    values; they can now be set at construction time.
    """

    def __init__(self, image_path="", model_save_path=""):
        # Image data location — not read by the code visible in this cell; TODO confirm usage.
        self.image_path = image_path
        # Directory passed to Learner.save_model() by Learner.fit(); "" means the CWD.
        self.model_save_path = model_save_path


args = Args()