In [1]:
import glob
import os
from typing import Any, Dict, List, Tuple, Union

import torch
import yaml
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder, VisionDataset

from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from model import CustomVGG
import torchvision

In [2]:
class Hook:
    
    def before_train(self, **kargs):
        pass
    
    def after_train(self):
        pass        
    
    def before_epoch(self):
        pass

    def after_epoch(self, inputs):
        pass

    def before_step(self):
        pass
    
    def after_step(self, inputs):
        pass

In [3]:
class LoggerHook(Hook):

    def before_train(self, **kargs):
        self.state = kargs['state']
        self.max_epochs = kargs['num_epochs']
        print('Start Training...\n')
    
    def before_epoch(self):
        if self.state['mode'] == 'train':
            print('Epoch {}/{}'.format(self.state['epoch'], self.max_epochs ))

    def after_epoch(self, inputs):
        loss, acc, f1 = inputs['loss'], inputs['acc'], inputs['f1'] 
        print('[{}] Epoch: {} Loss: {:.4f} Acc: {:.4f} F1: {:.4f}'.format(self.state['mode'].title(), self.state['epoch'], loss, acc, f1))
        if self.state['mode'] == 'valid':
            print()

    def after_train(self):
        print('Training complete!')
        print('Best f1 score {:.4f} at epoch {}'.format(best_score, best_epoch))

class LossMeter(Hook):
    
    def __init__(self):
        self.needs = ['loss', 'cnt']
        self.writes = []
        
    def before_epoch(self):
        self.loss = 0
        self.cnt = 0
        
    def after_step(self, inputs):
        loss, cnt = inputs['loss'], inputs['cnt']
        self.cnt += cnt
        self.loss += loss * cnt

    def after_epoch(self, inputs):
        inputs['loss'] = self.loss / self.cnt

In [4]:
class ConfusionMatrix:
    """
    T  8  3  6  2
    R  5  9  3  9
    U  0  3  1  4
    E  0  9  9  5
       P  R  E  D
    """
    
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.conf_mtx = np.zeros((num_classes, num_classes), dtype=int)
        
    def accumulate(self, y_pred, y_true):
        self.conf_mtx += np.bincount(y_true * self.num_classes + y_pred, minlength=self.num_classes**2).reshape((self.num_classes, self.num_classes))
        
    def accuracy(self):
        return np.diag(self.conf_mtx).sum() / self.conf_mtx.sum()
    
    def f1_score(self):
        tp = np.diag(self.conf_mtx)
        precision = tp / np.sum(self.conf_mtx, axis=0)
        recall = tp / np.sum(self.conf_mtx, axis=1)
        f1_score = 2 * precision * recall / (precision + recall)
        return np.mean(f1_score)

    def reset(self):
        self.conf_mtx.fill(0)

In [5]:
class ConfusionMatrixMeter(Hook):
    
    def __init__(self, num_classes):
        self.needs = ['pred', 'true']
        self.writes = []
        self.conf_mtx = ConfusionMatrix(num_classes)
        
    def before_epoch(self):
        self.conf_mtx.reset()
        
    def after_step(self, inputs):
        pred, true = inputs['pred'], inputs['true']
        self.conf_mtx.accumulate(pred, true)

    def after_epoch(self, inputs):
        inputs['acc'] = self.conf_mtx.accuracy()
        inputs['f1'] = self.conf_mtx.f1_score()

In [None]:
class ValidateHook(Hook):
    
    def before_train(self, **kargs):
        self.model = kargs['model']
        self.dataloader = kargs['valid_dataloader']
        self.criterion = kargs['criterion']
        self.device = next(iter(self.model.parameters)).device
        self.conf_mtx = ConfusionMatrix(len(self.dataloader.classes))
        
    @torch.no_grad()
    def after_epoch(self, inputs):
        
        running_cnt, running_loss = 0, 0
        self.conf_mtx.reset()
        self.model.eval()

        for inputs, labels in self.dataloader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            outputs = self.model(inputs)
            _, preds = torch.max(outputs, 1)
            preds, true = preds.cpu().numpy(), labels.cpu().numpy()
            self.conf_mtx.accumulate(preds, true)

            running_cnt += inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / running_cnt
        epoch_acc = self.conf_mtx.accuracy()
        f1 = self.conf_mtx.f1_score()

        print('Acc: {:.4f} F1: {:.4f}'.format(epoch_acc, f1))

In [6]:
class Trainer:
    
    def __init__(self, hooks):
        self.hooks = hooks
        self.state = {}
    
    def train(self, model, dataloaders, criterion, optimizer, num_epochs, **kargs):
        
        self.train_setup(model=model, dataloaders=dataloaders, criterion=criterion, optimizer=optimizer, num_epochs=num_epochs, **kargs)
        
        for self.state['epoch'] in range(1, num_epochs+1):
            for self.state['mode'] in ['train', 'valid']:
                
                model.train()

                for hook in self.hooks:
                    hook.before_epoch()
            
                for inputs in tqdm(dataloaders[self.state['mode']]):

                    for hook in self.hooks:
                        hook.before_step()

                    outputs = self.step_train(inputs) if self.state['mode'] == 'train' else self.step_valid(inputs)

                    for hook in self.hooks:
                        hook.after_step(outputs)

                outputs = {}
                for hook in self.hooks:
                    hook.after_epoch(outputs)
                
            if self.state['end']:
                break
                    
        self.train_end()
        
    def train_setup(self, **kargs):
        self.model = kargs['model'] 
        self.state['epoch'] = 1
        self.state['train'] = True
        self.state['end'] = False
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        for hook in self.hooks:
            hook.before_train(state=self.state, **kargs)
        
    def train_end(self):
        pass

    def step_train(self, inputs):
        
        imgs, labels = inputs
        imgs, labels = imgs.to(self.device), labels.to(self.device)

        outputs = self.model(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            _, preds = torch.max(outputs, 1)
            
        outputs = {'loss':loss.item(), 'cnt':preds.size(0), 'pred':preds.cpu().numpy(), 'true':labels.cpu().numpy()}
        return outputs  

    def step_valid(self, inputs):

        with torch.no_grad():
            imgs, labels = inputs
            imgs, labels = imgs.to(self.device), labels.to(self.device)

            outputs = self.model(imgs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            
            outputs = {'loss':loss.item(), 'cnt':preds.size(0), 'pred':preds.cpu().numpy(), 'true':labels.cpu().numpy()}
        return outputs  

    def check_hooks(self):
        pass

    # print('{} Epoch: {} Loss: {:.4f} Acc: {:.4f} F1: {:.4f}'.format(phase, epoch, epoch_loss, epoch_acc, f1))

#     if phase == 'valid':
#         if f1 > best_score:
#             best_score = f1
#             best_epoch = epoch
#             patience_cnt = 0
#             torch.save(model.state_dict(), os.path.join(save_dir, f'{save_name}.pt'))
#         else:
#             if patience_cnt == patience:
#                 print()
#                 print('Training complete!')
#                 print('Best f1 score {:.4f} at epoch {}'.format(best_score, best_epoch))
#                 print()
#                 return
#             patience_cnt += 1

#         lr_scheduler.step()

In [7]:
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from dataset import get_dataset, get_weighted_sampler, get_concat_dataset
import torch
import numpy as np
import matplotlib.pyplot as plt

input_size = 224
batch_size = 128
n_worker = 8

train_dataset, valid_dataset = get_dataset(input_size)
concat_dataset = get_concat_dataset()

# sample_freq = np.bincount(train_dataset.targets + valid_dataset.targets)
# sample_weight = np.array([1/sample_freq[x] for x in train_dataset.targets] + [1/sample_freq[x] for x in valid_dataset.targets])
# sample_weight = torch.from_numpy(sample_weight)
# sampler = WeightedRandomSampler(sample_weight.type('torch.DoubleTensor'), len(sample_weight)//2)

sampler = get_weighted_sampler()

train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True, sampler = sampler, num_workers=n_worker)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=n_worker)

dataloaders = { 'train' : train_loader, 'valid' : valid_loader}

In [8]:
import torch
import torchvision
from model import CustomVGG
import geffnet

cfg = [[int(64*0.75)], [int(128*0.75)], [int(256*0.75), int(256*0.75)], [int(512*0.75), int(512*0.75)], [int(512*0.75), 512]]
device = 'cuda'

# model = CustomVGG(cfg=cfg, bias=True)
# model.load_state_dict(torch.load('save/pruned_final.pt'))

model = torchvision.models.vgg11_bn(pretrained=True)
model.avgpool = torch.nn.AvgPool2d(7)
model.classifier = torch.nn.Linear(512, 6)
model.to(device)
print()




In [9]:
import torch.optim as optim
from torch.nn import CrossEntropyLoss

lr = 0.0001

criterion = CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
#lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30], gamma=0.1) # change trainer step
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', verbose=True, patience=5) # optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=60, eta_min=0.000001)#

In [10]:
trainer = Trainer(hooks=[LossMeter(), ConfusionMatrixMeter(6), LoggerHook()])

In [11]:
trainer.train(model=model, dataloaders=dataloaders, criterion=criterion, optimizer=optimizer, num_epochs=10)

Start Training...

Epoch 1/10


100%|██████████| 61/61 [00:27<00:00,  2.18it/s]


[Train] Epoch: 1 Loss: 1.2850 Acc: 0.5847 F1: 0.5711


100%|██████████| 41/41 [00:15<00:00,  2.66it/s]


[Valid] Epoch: 1 Loss: 1.1124 Acc: 0.6049 F1: 0.5654

Epoch 2/10


  0%|          | 0/61 [00:02<?, ?it/s]


KeyboardInterrupt: 