In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import datasets, models, transforms

import time
import os
import copy
import glob
import gc
import shutil

import numpy as np
from PIL import Image

## Set Device

In [2]:
# Expose only physical GPUs 6 and 7 to this process. This has to happen before
# CUDA is initialised, otherwise the variable is ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
NUM_GPU = torch.cuda.device_count()
print("Using {} GPUs".format(NUM_GPU))

Using 2 GPUs


## Configuration

In [3]:
class Config(object):
    """Run configuration shared, via inheritance, by the model, data and solver classes.

    Every setting can be overridden with a keyword argument; unknown keys are
    ignored. Relative paths are resolved under ``homedir`` (default ``".."``).

    Keyword Args:
        homedir: Base directory for data and output folders.
        datapath: Data folder under ``homedir``; read as ``<datapath>/<phase>/<class>/<image>``.
        target_classes: Class folder names; a class's list index is its integer label.
        model_backbone: Backbone name passed to the model factory (default ``"resnet18"``).
        pretrain: Whether to load pretrained backbone weights (default ``True``).
        batch_size, shuffle, num_worker: Data loader settings.
        num_epochs, learning_rate, momentum, lr_scheduler: Optimization settings.
        snapshot_folder: Per-epoch checkpoint folder under ``homedir``.
        results_folder: Best-model output folder under ``homedir``
            (the older key ``result_folder`` is still accepted).
    """

    def __init__(self, **kwargs):
        self._homedir = kwargs.get("homedir", "..")

        # Training data path
        self._datapath = os.path.join(
            self._homedir,
            kwargs.get("datapath", "hymenoptera_data")
        )
        # Copy so later mutation of the caller's list cannot change this config.
        self._target_classes = list(kwargs.get("target_classes", ["ants", "bees"]))
        # Derived from the list so names and label indices can never disagree.
        self._target_class_to_idx = {
            name: idx for idx, name in enumerate(self._target_classes)
        }

        # Model backbone
        self._model_backbone = kwargs.get("model_backbone", "resnet18")
        self._pretrain = kwargs.get("pretrain", True)

        # Data Loader configs
        self._batch_size = kwargs.get("batch_size", 16)
        self._shuffle = kwargs.get("shuffle", True)
        self._num_worker = kwargs.get("num_worker", 0)

        # Optimization params
        self._num_epochs = kwargs.get("num_epochs", 25)
        self._learning_rate = kwargs.get("learning_rate", 0.001)
        self._momentum = kwargs.get("momentum", 0.9)
        self._lr_scheduler_dict = kwargs.get("lr_scheduler", {
            "__name__": "step_lr",
            "step_size": 7,
            "gamma": 0.1
        })

        # Output file
        self._snapshot_folder = os.path.join(
            self._homedir,
            kwargs.get("snapshot_folder", "snapshots")
        )
        # "results_folder" matches the attribute name; "result_folder" is the
        # original key and is kept as a fallback for existing callers.
        self._results_folder = os.path.join(
            self._homedir,
            kwargs.get("results_folder", kwargs.get("result_folder", "results"))
        )

## Model 

In [4]:
class FineTuneModel(Config):
    """Factory for a torchvision backbone whose classifier head is replaced for fine-tuning."""

    # torchvision ResNet constructors; each exposes its final linear layer as ``.fc``.
    _RESNET_BACKBONES = ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_model(self, num_labels):
        """Return the configured backbone with a new ``num_labels``-way ``fc`` layer.

        Args:
            num_labels: Number of output classes for the new head.

        Returns:
            The torchvision model; only the ``fc`` layer is freshly initialised.

        Raises:
            ValueError: If ``self._model_backbone`` is not a supported backbone.
        """
        if self._model_backbone not in self._RESNET_BACKBONES:
            raise ValueError(
                "Unsupported model backbone {!r}; expected one of {}".format(
                    self._model_backbone, ", ".join(self._RESNET_BACKBONES)
                )
            )

        constructor = getattr(models, self._model_backbone)
        # NOTE: newer torchvision releases prefer ``weights=`` over ``pretrained=``;
        # ``pretrained`` is kept for compatibility with the torchvision this
        # notebook was written against.
        model_ft = constructor(pretrained=self._pretrain)
        num_ftrs = model_ft.fc.in_features

        model_ft.fc = nn.Linear(num_ftrs, num_labels)

        return model_ft

    def _num_total_params(self, _model):
        """Total number of scalar parameters in ``_model``."""
        return sum(p.numel() for p in _model.parameters())

    def _num_trainable_params(self, _model):
        """Number of scalar parameters in ``_model`` with ``requires_grad=True``."""
        return sum(p.numel() for p in _model.parameters() if p.requires_grad)

## Dataset

In [5]:
class FTDataset(Config, Dataset):
    """Image dataset for one split laid out as ``<datapath>/<phase>/<class_name>/<file>``.

    Each item is ``(image_tensor, label)``, where ``label`` is
    ``self._target_class_to_idx[class_name]``.

    Args:
        phase: Split sub-folder to read. ``"train"`` enables augmentation; any
            other split gets the deterministic evaluation transforms.
        **kwargs: ``Config`` overrides (e.g. ``datapath``, ``target_classes``).
    """

    def __init__(self, phase="train", **kwargs):
        # initialize config
        super().__init__(**kwargs)

        # current phase
        self._phase = phase

        # load raw data
        self._prepare_data()

    # Built lazily on first access; can be replaced through the setter.
    _image_transforms = None

    @property
    def image_transforms(self):
        """Transform pipeline applied to every PIL image in ``__getitem__``."""
        if self._image_transforms is None:
            self._image_transforms = self._set_image_transforms()
        return self._image_transforms

    @image_transforms.setter
    def image_transforms(self, image_transforms):
        self._image_transforms = image_transforms

    def _set_image_transforms(self):
        """Build the transform pipeline for ``self._phase``.

        ``"train"``: random resized crop and horizontal flip for augmentation.
        Any other phase (``"val"``, a test split, ...): deterministic resize and
        center crop. Both finish with tensor conversion and normalization with
        the standard ImageNet mean/std.
        """
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        if self._phase == "train":
            return transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])

        # Previously only "val" was handled and other phases got None, which
        # crashed in __getitem__; every non-training split is evaluated alike.
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

    @property
    def data_location(self):
        """Directory holding this split's class sub-folders."""
        return os.path.join(self._datapath, self._phase)

    def _prepare_data(self):
        """Index the split's images.

        Attributes:
            image_labels: [
                                (image_path_1, class_name_1),
                                (image_path_2, class_name_2),
                                ...
                          ]
                sorted by path so the sample order does not depend on the
                filesystem. Only regular files inside folders named in
                ``self._target_class_to_idx`` are kept; anything else would
                raise ``KeyError`` when the sample is fetched.
        """
        pattern = os.path.join(self.data_location, "*", "*")
        samples = []
        for path in sorted(glob.glob(pattern)):
            if not os.path.isfile(path):
                continue
            class_name = os.path.basename(os.path.dirname(path))
            if class_name not in self._target_class_to_idx:
                continue
            samples.append((path, class_name))
        self.image_labels = samples

    def __getitem__(self, idx):
        filename, target = self.image_labels[idx]

        # Close the file handle deterministically; convert() loads the pixel
        # data, so the returned image stays valid after the file is closed.
        with Image.open(filename) as raw_img:
            img = raw_img.convert('RGB')
        img_ = self.image_transforms(img)

        label_ = self._target_class_to_idx[target]

        return img_, label_

    def __len__(self):
        return len(self.image_labels)

class FTDataLoader(Config):
    """Bundle an ``FTDataset`` for one split with a lazily created ``DataLoader``.

    Keyword arguments are ``Config`` overrides and are forwarded unchanged to
    ``FTDataset``. The extra key ``global_batch_size`` (falls back to
    ``batch_size``) is the batch size handed to the ``DataLoader``.
    """

    # Created on first access of ``dataloader``.
    _dataloader = None

    def __init__(self, phase="train", **kwargs):
        super().__init__(**kwargs)

        self._phase = phase
        self.image_dataset = FTDataset(phase=phase, **kwargs)

        # Batch size across all devices (for multi-device training).
        self._global_batch_size = kwargs.get("global_batch_size", self._batch_size)

    @property
    def dataloader(self):
        """The ``DataLoader`` over ``image_dataset``, built on first use."""
        if self._dataloader is None:
            # Routed through the setter below.
            self.dataloader = self._get_data_loader()
        return self._dataloader

    @dataloader.setter
    def dataloader(self, dataloader):
        self._dataloader = dataloader

    def _get_data_loader(self):
        loader_options = dict(
            batch_size=self._global_batch_size,
            shuffle=self._shuffle,
            num_workers=self._num_worker,
        )
        return DataLoader(self.image_dataset, **loader_options)

    @property
    def _size(self):
        """Number of samples in the wrapped dataset."""
        dataset = self.image_dataset
        return len(dataset)

    @property
    def _classes(self):
        """Class names configured on the wrapped dataset."""
        dataset = self.image_dataset
        return dataset._target_classes

## Solver

In [6]:
class Solver(Config):
    """Own the data loaders, model, optimizer and LR schedule of one fine-tuning run.

    Keyword arguments are ``Config`` overrides. They are also forwarded to the
    ``FTDataLoader``s and the ``FineTuneModel`` built here, so settings such as
    ``datapath``, ``batch_size``, ``shuffle`` or ``num_worker`` apply to the
    data pipeline and model as well as to the solver itself.

    Args:
        gpu_number: Accepted for backward compatibility but not used; the
            devices in use are those exposed through ``CUDA_VISIBLE_DEVICES``.
    """

    def __init__(self, gpu_number=0, **kwargs):
        super().__init__(**kwargs)
        # Keep the overrides so that sub-components share this configuration.
        self._kwargs = kwargs

        # prepare data
        self._get_dataset()

        # prepare model
        self._get_or_load_model()
        self._set_optimizer(self.model.parameters())
        self._set_criterion()
        self._set_learningrate_scheduler()

    def _get_dataset(self):
        """Set ``train_/val_dataloader`` and ``train_/val_datasize`` attributes."""
        loader_kwargs = dict(self._kwargs)
        loader_kwargs.pop("phase", None)  # phase is chosen per split below
        # nn.DataParallel splits each batch over the visible GPUs, so the loader
        # batch is the per-device batch size times the GPU count. On a CPU-only
        # machine NUM_GPU is 0, which would give an invalid batch size of 0.
        loader_kwargs["global_batch_size"] = max(NUM_GPU, 1) * self._batch_size
        for p in ["train", "val"]:
            temp_d = FTDataLoader(phase=p, **loader_kwargs)
            setattr(self, "{}_dataloader".format(p), temp_d.dataloader)
            setattr(self, "{}_datasize".format(p), temp_d._size)

    def _get_or_load_model(self):
        """Build the backbone, wrap it in ``nn.DataParallel`` and move it to ``DEVICE``."""
        self.model = FineTuneModel(**self._kwargs).get_model(len(self._target_classes))
        self.model = nn.DataParallel(self.model)
        self.model.to(DEVICE)

    def _set_optimizer(self, parameters):
        """SGD with momentum over ``parameters``."""
        self.optimizer = optim.SGD(
            parameters,
            lr=self._learning_rate,
            momentum=self._momentum,
        )

    def _set_criterion(self):
        # Default reduction="mean": the returned loss is the batch average.
        self.criterion = nn.CrossEntropyLoss()

    def _set_learningrate_scheduler(self):
        """Create ``self.lr_scheduler`` from the ``lr_scheduler`` config dict.

        Raises:
            ValueError: If ``"__name__"`` is not a supported scheduler
                (currently only ``"step_lr"``).
        """
        name = self._lr_scheduler_dict.get("__name__")
        if name != "step_lr":
            raise ValueError("Unsupported lr_scheduler: {!r}".format(name))
        self.lr_scheduler = lr_scheduler.StepLR(
            self.optimizer,
            step_size=self._lr_scheduler_dict.get("step_size", 7),
            gamma=self._lr_scheduler_dict.get("gamma", 0.1)
        )

    def save_checkpoint(self, state, epoch, filename='checkpoint.pth.tar'):
        """Save ``state`` to ``<snapshot_folder>/epoch_<epoch>_<filename>``."""
        os.makedirs(self._snapshot_folder, exist_ok=True)

        absolute_path = os.path.join(self._snapshot_folder, "epoch_{}_{}".format(epoch, filename))
        torch.save(state, absolute_path)

    def update_best_model(self, epoch, acc, filename='checkpoint.pth.tar'):
        """Copy epoch ``epoch``'s checkpoint to ``best_<filename>`` and into the results folder.

        The results copy is named ``best_<backbone>_acc<acc>_<filename>``.
        """
        os.makedirs(self._results_folder, exist_ok=True)

        current_absolute_path = os.path.join(self._snapshot_folder, "epoch_{}_{}".format(epoch, filename))
        best_absolute_path = os.path.join(self._snapshot_folder, "best_{}".format(filename))
        best_absolute_result = os.path.join(
            self._results_folder,
            "best_{}_acc{:.4f}_{}".format(self._model_backbone, acc, filename)
        )

        shutil.copyfile(current_absolute_path, best_absolute_path)
        shutil.copyfile(current_absolute_path, best_absolute_result)
        print("Saving new best model to results: {}".format(best_absolute_result))

    def restore_model(self, resultname=None, epoch=-1, filename='checkpoint.pth.tar'):
        """Load model weights from a checkpoint's ``"state_dict"``.

        With ``resultname`` the file ``<results_folder>/<resultname>`` is used;
        otherwise ``<snapshot_folder>/best_<filename>`` when ``epoch == -1``, or
        ``<snapshot_folder>/epoch_<epoch>_<filename>``. Only the model weights
        are restored (not the optimizer or LR scheduler state).
        """
        if resultname is None:
            if epoch == -1:
                model_path = "best_{}".format(filename)
            else:
                model_path = "epoch_{}_{}".format(epoch, filename)

            model_fullpath = os.path.join(self._snapshot_folder, model_path)
        else:
            model_fullpath = os.path.join(self._results_folder, resultname)

        print("Loading model: {}".format(model_fullpath))
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        checkpoint = torch.load(model_fullpath, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.to(DEVICE)

    def _run_phase(self, phase):
        """Make one pass over the ``phase`` loader and return ``(mean_loss, accuracy)``.

        For ``"train"`` the model is in training mode and parameters are
        updated; for any other phase the model is in eval mode and no
        gradients are tracked. Both values are plain Python floats.

        Raises:
            RuntimeError: If the split contains no samples.
        """
        is_train = phase == 'train'
        if is_train:
            self.model.train()  # Set model to training mode
        else:
            self.model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in getattr(self, "{}_dataloader".format(phase)):
            # map data to device
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            if is_train:
                # Clear gradients left over from the previous step.
                self.optimizer.zero_grad()

            # Track history only while training.
            with torch.set_grad_enabled(is_train):
                outputs = self.model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = self.criterion(outputs, labels)

                if is_train:
                    loss.backward()
                    self.optimizer.step()

            # The criterion averages over the batch, so re-weight by the batch
            # size to obtain an exact dataset-level mean at the end.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()

        datasize = getattr(self, "{}_datasize".format(phase))
        if datasize == 0:
            raise RuntimeError(
                "No samples found for the '{}' split under {}".format(phase, self._datapath)
            )
        return running_loss / datasize, running_corrects / datasize

    def train(self, load_epoch=None, load_model=None):
        """Train for ``num_epochs``, checkpoint every epoch and export the best one.

        Args:
            load_epoch: If given (and ``load_model`` is not), start from that
                epoch's snapshot (``-1`` selects the snapshot ``best_`` file).
            load_model: If given, start from this file in the results folder.
        """
        print('Start training...')
        since_ = time.time()

        if load_model is not None:
            self.restore_model(resultname=load_model)
        elif load_epoch is not None:
            self.restore_model(epoch=load_epoch)

        best_epoch = 0
        best_acc = 0.0

        for epoch in range(self._num_epochs):
            print('Epoch {}/{}'.format(epoch, self._num_epochs - 1))
            print('-' * 10)

            loss_dict = {}
            acc_dict = {}

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                epoch_loss, epoch_acc = self._run_phase(phase)

                # learning rate update for training only, after the optimizer steps
                if phase == 'train':
                    self.lr_scheduler.step()

                loss_dict[phase] = epoch_loss
                acc_dict[phase] = epoch_acc
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc
                ))

            # save intermediate model state, params and etc.
            self.save_checkpoint(
                {
                    'lr': self.optimizer.param_groups[0]["lr"],
                    'state_dict': self.model.state_dict(),
                    'loss_stats': loss_dict,
                    'acc_stats': acc_dict
                },
                epoch
            )

            # remember which epoch's checkpoint had the best validation accuracy
            if acc_dict['val'] > best_acc:
                best_acc = acc_dict['val']
                best_epoch = epoch

            gc.collect()
            torch.cuda.empty_cache()

        # Without any epoch there is no checkpoint to export.
        if self._num_epochs > 0:
            self.update_best_model(best_epoch, best_acc)

        time_elapsed = time.time() - since_
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:.4f}'.format(best_acc))

    def evaluate(self, epoch=-1):
        """Load a snapshot and report loss/accuracy on the validation split.

        Args:
            epoch: Snapshot epoch to load; ``-1`` loads the ``best_`` snapshot.

        Returns:
            ``(total_loss, total_acc)`` as floats.
        """
        # restore_model's first positional parameter is ``resultname``, so the
        # epoch must be passed by keyword.
        self.restore_model(epoch=epoch)

        total_loss, total_acc = self._run_phase('val')

        print("Total Loss: {:.4f}, Acc: {:.4f}".format(total_loss, total_acc))
        return total_loss, total_acc

## Training

In [7]:
# StepLR with step_size=1 multiplies the learning rate by gamma (0.1) after every
# epoch, so the LR becomes negligible after only a few epochs.
STEP_LR_EVERY_EPOCH = {
    "__name__": "step_lr",
    "step_size": 1,
    "gamma": 0.1,
}
# NOTE: Solver does not use ``gpu_number``; the GPUs actually used are the ones
# exposed through CUDA_VISIBLE_DEVICES in the device cell.
s = Solver(num_epochs=10, gpu_number=[6], lr_scheduler=STEP_LR_EVERY_EPOCH)
s.train(load_model="best_resnet18_acc0.9477_checkpoint.pth.tar")

Start training...
Loading model: ../results/best_resnet18_acc0.9477_checkpoint.pth.tar
Epoch 0/9
----------
train Loss: 0.2018 Acc: 0.9306
val Loss: 0.1886 Acc: 0.9477
Epoch 1/9
----------
train Loss: 0.1801 Acc: 0.9510
val Loss: 0.1939 Acc: 0.9477
Epoch 2/9
----------
train Loss: 0.1854 Acc: 0.9469
val Loss: 0.1898 Acc: 0.9477
Epoch 3/9
----------
train Loss: 0.1706 Acc: 0.9347
val Loss: 0.1878 Acc: 0.9412
Epoch 4/9
----------
train Loss: 0.1660 Acc: 0.9469
val Loss: 0.1917 Acc: 0.9412
Epoch 5/9
----------
train Loss: 0.1499 Acc: 0.9469
val Loss: 0.1856 Acc: 0.9477
Epoch 6/9
----------
train Loss: 0.1770 Acc: 0.9429
val Loss: 0.1859 Acc: 0.9477
Epoch 7/9
----------
train Loss: 0.1674 Acc: 0.9265
val Loss: 0.1839 Acc: 0.9412
Epoch 8/9
----------
train Loss: 0.1778 Acc: 0.9429
val Loss: 0.1833 Acc: 0.9477
Epoch 9/9
----------
train Loss: 0.1810 Acc: 0.9306
val Loss: 0.1803 Acc: 0.9412
Saving new best model to results: ../results/best_resnet18_acc0.9477_checkpoint.pth.tar
Training comple