In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from torchvision import datasets, models, transforms
import os
import glob

from PIL import Image

## Set up device

In [2]:
# Expose only GPU index 6 to this process via CUDA_VISIBLE_DEVICES.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

# Fall back to the CPU when no CUDA device is usable.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Number of CUDA devices visible after the restriction above.
NUM_GPU = torch.cuda.device_count()
print("Using {} GPUs".format(NUM_GPU))

Using 1 GPUs


## Configuration

In [3]:
class Config(object):
    """Shared settings for the fine-tuning and inference classes.

    All values are stored on private attributes. Every path is resolved
    relative to ``self._homedir`` (``".."``).

    Keyword Args:
        datapath (str): dataset folder name. Default ``"hymenoptera_data"``.
        target_classes (sequence of str): class names; a class's position in
            the sequence is its label index. Default ``["ants", "bees"]``.
        model_backbone (str): backbone identifier. Default ``"resnet18"``.
        pretrain (bool): whether to start from pretrained weights.
            Default ``True``.
        batch_size (int): default ``16``.
        shuffle (bool): default ``True``.
        num_worker (int): default ``0``.
        num_epochs (int): default ``25``.
        learning_rate (float): default ``0.001``.
        momentum (float): default ``0.9``.
        lr_scheduler (dict): scheduler description. Default is a ``step_lr``
            entry with ``step_size=7`` and ``gamma=0.1``.
        snapshot_folder (str): default ``"snapshots"``.
        result_folder (str): default ``"results"``.

    Unrecognised keyword arguments are ignored.
    """

    def __init__(self, **kwargs):
        self._homedir = ".."

        # Training data path
        self._datapath = os.path.join(
            self._homedir,
            kwargs.get("datapath", "hymenoptera_data")
        )

        # The index map is derived from the class list so the two can never
        # disagree; the defaults yield {"ants": 0, "bees": 1}.
        self._target_classes = list(
            kwargs.get("target_classes", ("ants", "bees"))
        )
        self._target_class_to_idx = {
            name: idx for idx, name in enumerate(self._target_classes)
        }

        # Model backbone
        self._model_backbone = kwargs.get("model_backbone", "resnet18")
        self._pretrain = kwargs.get("pretrain", True)

        # Data loader configs
        self._batch_size = kwargs.get("batch_size", 16)
        self._shuffle = kwargs.get("shuffle", True)
        self._num_worker = kwargs.get("num_worker", 0)

        # Optimization params
        self._num_epochs = kwargs.get("num_epochs", 25)
        self._learning_rate = kwargs.get("learning_rate", 0.001)
        self._momentum = kwargs.get("momentum", 0.9)
        # The default dict literal is rebuilt on every call, so instances
        # never share (and cannot accidentally mutate) the same default.
        self._lr_scheduler_dict = kwargs.get("lr_scheduler", {
            "__name__": "step_lr",
            "step_size": 7,
            "gamma": 0.1
        })

        # Output folders
        self._snapshot_folder = os.path.join(
            self._homedir,
            kwargs.get("snapshot_folder", "snapshots")
        )
        # NOTE: the keyword is "result_folder" (singular) although the
        # attribute is "_results_folder"; kept as-is for compatibility.
        self._results_folder = os.path.join(
            self._homedir,
            kwargs.get("result_folder", "results")
        )

## Model

In [4]:
class FineTuneModel(Config):
    """Builds a torchvision backbone whose final layer is replaced for fine-tuning."""

    # torchvision ResNet variants; each exposes its classifier head as ``.fc``.
    _SUPPORTED_BACKBONES = (
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_model(self, num_labels):
        """Return the configured backbone with a fresh ``num_labels``-way head.

        Args:
            num_labels (int): number of output units of the new ``fc`` layer.

        Returns:
            The torchvision model named by ``self._model_backbone``, with
            its ``fc`` layer replaced by a new ``nn.Linear``.

        Raises:
            ValueError: if ``self._model_backbone`` is not a supported name.
                Previously this case silently returned ``None``.
        """
        if self._model_backbone not in self._SUPPORTED_BACKBONES:
            raise ValueError(
                "Unsupported model backbone: {!r} (supported: {})".format(
                    self._model_backbone,
                    ", ".join(self._SUPPORTED_BACKBONES),
                )
            )

        backbone_factory = getattr(models, self._model_backbone)
        model_ft = backbone_factory(pretrained=self._pretrain)

        # Swap the original classifier for one sized to the target classes.
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_labels)

        return model_ft

    def _num_total_params(self, _model):
        """Return the total number of scalar parameters in ``_model``."""
        return sum(p.numel() for p in _model.parameters())

    def _num_trainable_params(self, _model):
        """Return the number of scalar parameters with ``requires_grad`` set."""
        return sum(p.numel() for p in _model.parameters() if p.requires_grad)

## Inference 

In [5]:
class ImageClassification(Config):
    """Loads a fine-tuned checkpoint and classifies single images.

    Args:
        weight_path (str): path to a checkpoint readable by ``torch.load``
            that contains a ``"state_dict"`` entry.
        gpu_number: accepted for backward compatibility only; it is not
            read anywhere in this class. The model and inputs are placed on
            the module-level ``DEVICE``.
        **kwargs: configuration overrides forwarded to ``Config`` and to the
            ``FineTuneModel`` used to build the network.
    """

    def __init__(self, weight_path, gpu_number=0, **kwargs):
        super().__init__(**kwargs)

        # Remember the overrides so the model builder uses the same settings.
        self._config_kwargs = dict(kwargs)

        # Built once here instead of on every prediction.
        self._inference_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # prepare model
        self._load_model_weights(weight_path)

    def _preprocess_data(self, image_path):
        """Read ``image_path`` and return a normalised tensor with a leading batch axis of size 1."""
        # The context manager closes the underlying file handle promptly;
        # convert() returns an independent, fully loaded RGB copy.
        with Image.open(image_path) as raw_image:
            image_ = raw_image.convert('RGB')

        image_tensor = self._inference_transforms(image_).float().unsqueeze(0)

        return image_tensor

    def _process_output(self, image_tensor):
        """Run the model on one preprocessed image and decode the result.

        Returns:
            dict with ``"predicted_class"`` (class name), ``"raw_output"``
            (model outputs for the first batch item, rounded to 4 decimals,
            as Python floats) and ``"predicted_label"`` (Python int index).
        """
        input_ = image_tensor.to(DEVICE)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            output_ = self.model(input_)

        # float()/int() keep the result free of numpy scalar types, so it
        # prints cleanly and is JSON-serialisable.
        raw_output = [
            float(np.round(v, 4))
            for v in output_.cpu().numpy().tolist()[0]
        ]

        _, preds = torch.max(output_, 1)
        pred_index = int(preds[0].item())

        idx_to_class = {
            idx: name for name, idx in self._target_class_to_idx.items()
        }
        pred_class = idx_to_class[pred_index]

        return {
            "predicted_class": pred_class,
            "raw_output": raw_output,
            "predicted_label": pred_index
        }

    def _load_model_weights(self, weight_path):
        """Build the network, move it to ``DEVICE``, load the checkpoint and switch to eval mode."""
        print("Preparing model: {} ...".format(self._model_backbone))
        builder = FineTuneModel(**getattr(self, "_config_kwargs", {}))
        self.model = builder.get_model(len(self._target_classes))

        print("Preparing model: mapping to devices...")
        # The state dict is loaded into the DataParallel wrapper, so the
        # checkpoint keys presumably carry the "module." prefix.
        # TODO confirm against the training code.
        self.model = nn.DataParallel(self.model)
        self.model.to(DEVICE)

        print("Loading weights: {} ...".format(weight_path))
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(weight_path, map_location=DEVICE)

        # load_state_dict copies values into the existing parameters, which
        # already live on DEVICE, so no second .to(DEVICE) is needed.
        self.model.load_state_dict(checkpoint["state_dict"])

        print("Model is ready!")

        self.model.eval()

    def predict(self, image_path):
        """Classify the image at ``image_path``.

        Returns:
            The dict produced by ``_process_output`` plus an
            ``"image_path"`` entry echoing the input path.
        """
        image_tensor_ = self._preprocess_data(image_path)
        output_ = self._process_output(image_tensor_)

        output_["image_path"] = image_path

        return output_

In [6]:
import random

# NOTE: gpu_number is accepted by ImageClassification but not used by it.
c = ImageClassification(weight_path="../results/best_resnet18_acc0.9477_checkpoint.pth.tar", gpu_number=[6])

# Sort the glob result (its order is filesystem-dependent) and draw from a
# seeded generator so a fresh Restart & Run All shows the same images.
val_image_paths = sorted(glob.glob("../hymenoptera_data/val/*/*"))
rng = random.Random(42)

# Sample WITHOUT replacement so no image is shown twice, and cap the sample
# size so a small or empty validation folder does not raise.
for image_path in rng.sample(val_image_paths, k=min(20, len(val_image_paths))):
    print(c.predict(image_path))

Preparing model: resnet18 ...
Preparing model: mapping to devices...
Loading weights: ../results/best_resnet18_acc0.9477_checkpoint.pth.tar ...
Model is ready!
{'predicted_class': 'bees', 'raw_output': [-1.4259, 1.0546], 'predicted_label': 1, 'image_path': '../hymenoptera_data/val/bees/144098310_a4176fd54d.jpg'}
{'predicted_class': 'bees', 'raw_output': [-2.235, 1.5869], 'predicted_label': 1, 'image_path': '../hymenoptera_data/val/bees/2841437312_789699c740.jpg'}
{'predicted_class': 'ants', 'raw_output': [1.4502, -1.6709], 'predicted_label': 0, 'image_path': '../hymenoptera_data/val/ants/8124241_36b290d372.jpg'}
{'predicted_class': 'ants', 'raw_output': [2.0823, -2.0385], 'predicted_label': 0, 'image_path': '../hymenoptera_data/val/ants/F.pergan.28(f).jpg'}
{'predicted_class': 'ants', 'raw_output': [1.1002, -1.448], 'predicted_label': 0, 'image_path': '../hymenoptera_data/val/bees/54736755_c057723f64.jpg'}
{'predicted_class': 'ants', 'raw_output': [0.2554, 0.248], 'predicted_label': 0,