In [50]:
from torchvision import models, transforms
import torch.nn as nn
import torch
import os
from skimage import io
from pathlib import Path

In [153]:
# Task whose classifier is evaluated: one of the keys of `class_output`
# ('binary', 'class', 'species'); it is also matched against the
# underscore-separated tokens of the checkpoint filename.
model_name = "class"

In [154]:
# Build the ResNet-101 architecture only (random init). Every weight, including
# the replaced `fc` head, is overwritten by the strict load_state_dict of the
# fine-tuned checkpoint further down. Fetching ImageNet weights here
# (`pretrained=True`, deprecated in newer torchvision) would be a wasted
# network download.
model_eval = models.resnet101()

In [155]:
# Label maps per task: class name -> output index of the network head.
# Indices follow list position, so the order of each list is significant
# and must match the order used when the checkpoints were trained.
class_output = {
    'binary': {name: idx for idx, name in enumerate(
        ['Ghost', 'Animal', 'Unknown'])},
    'class': {name: idx for idx, name in enumerate(
        ['Ghost', 'Aves', 'Mammalia', 'Unknown'])},
    'species': {name: idx for idx, name in enumerate([
        'Ghost',
        'ArremonAurantiirostris',
        'Aves (Class)',
        'BosTaurus',
        'CaluromysPhilander',
        'CerdocyonThous',
        'CuniculusPaca',
        'Dasyprocta (Genus)',
        'DasypusNovemcinctus',
        'DidelphisAurita',
        'EiraBarbara',
        'Equus (Genus)',
        'Leopardus (Genus)',
        'LeptotilaPallida',
        'Mammalia (Class)',
        'MazamaAmericana',
        'Metachirus (Genus)',
        'Momota (Genus)',
        'Nasua (Genus)',
        'PecariTajacu',
        'ProcyonCancrivorus',
        'Rodentia (Order)',
        'Sciurus (Genus)',
        'SusScrofa',
        'TamanduaTetradactyla',
        'TinamusMajor',
        'Unknown',
    ])},
}

In [156]:
# Label map for the selected task: class name -> output index.
model_map = class_output[model_name]
# NOTE: despite the name, this is the number of OUTPUT classes; a later cell
# uses it as `out_features` of the replacement `fc` layer, so the name is kept.
in_features = len(model_map)
# Both the head size (len(model_map)) and the reverse lookup below assume the
# ids are exactly 0..n-1 with no gaps or duplicates; fail loudly if not.
assert sorted(model_map.values()) == list(range(in_features)), (
    f'class ids for {model_name!r} must be unique and contiguous from 0'
)
# Output index -> class name, used to decode the argmax of the network.
reverse_model_map = {v: k for k, v in model_map.items()}

In [157]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
using_gpu = torch.cuda.is_available()
device = torch.device('cuda:0') if using_gpu else torch.device('cpu')
print('Using GPU!' if using_gpu else 'Using CPU!')

Using CPU!


In [158]:
# Swap the stock classification layer for one whose output size equals the
# number of labels in model_map, then move the network to the selected device.
num_ftrs = model_eval.fc.in_features
new_head = nn.Linear(in_features=num_ftrs, out_features=in_features)
model_eval.fc = new_head
model_eval = model_eval.to(device)

In [159]:
# Pick the checkpoint for `model_name` from ./models/.
# - Filenames look like 'resnet101_class_20210416.pth'
#   ('<arch>_<task>_<YYYYMMDD>' is presumed; confirm).
# - The task name is matched against the underscore tokens of the file STEM,
#   so the extension does not glue onto the last token.
# - Only .pt/.pth files are considered.
# - The lexicographically greatest match wins. For one arch prefix with
#   YYYYMMDD dates that is the newest file. Different arch prefixes sort by
#   prefix first.
model_dir = Path('./models/')
model_candidates = [
    name for name in os.listdir(model_dir)
    if name.lower().endswith(('.pt', '.pth'))
    and model_name.lower() in Path(name).stem.split('_')
]
if not model_candidates:
    raise FileNotFoundError(f'no checkpoint for {model_name!r} found in {model_dir}')
model_full_name = max(model_candidates)

In [160]:
# Display the checkpoint filename that was selected for loading.
model_full_name

'resnet101_class_20210416.pth'

In [161]:
# Load the fine-tuned weights into the re-headed ResNet. map_location lets a
# checkpoint saved on GPU be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# on torch >= 1.13 consider passing weights_only=True.
try:
    model_eval.load_state_dict(torch.load(f'./models/{model_full_name}', map_location=device))
    print(f'Loading {model_full_name} pre-trained model')
except Exception as e:
    print(f'cannot load model! {e}')
    # Re-raise: continuing with un-loaded weights would make every later
    # prediction meaningless while looking like a normal result.
    raise

Loading resnet101_class_20210416.pth pre-trained model


In [162]:
# (height, width) every image is resized to; torchvision's Resize interprets a
# 2-sequence as (h, w).
regulated_size = 300, 450
# NOTE(review): there is no Normalize step here; confirm this matches the
# transform used when the checkpoint was trained.
default_val_transform = transforms.Compose([
    transforms.ToPILImage(),
    # io.imread can return grayscale (H, W) or RGBA (H, W, 4) arrays. The
    # ResNet stem needs exactly 3 channels, so force RGB (no-op for RGB input).
    transforms.Lambda(lambda pil_img: pil_img.convert('RGB')),
    transforms.Resize(size=regulated_size),
    transforms.ToTensor(),
])


In [170]:
## TODO: classify every image in the folder; the cells below handle only one.
img_root_path = Path('tmp_images')
# Case-insensitive .jpg/.jpeg match on the real suffix, in sorted order.
imgs = sorted(
    name for name in os.listdir(img_root_path)
    if Path(name).suffix.lower() in ('.jpg', '.jpeg')
)
if not imgs:
    raise FileNotFoundError(f'no .jpg/.jpeg images found in {img_root_path}')
# Downstream cells classify a single image (the last file in sorted order),
# so only that file is read instead of decoding every image and discarding it.
tmp_img = io.imread(img_root_path / imgs[-1])

In [166]:
# Convert the raw image array into the model's input tensor (ToTensor yields a
# (C, H, W) float tensor). The isinstance guard keeps this cell idempotent:
# re-running it does not push an already-converted tensor through the
# transform a second time.
if not isinstance(tmp_img, torch.Tensor):
    tmp_img = default_val_transform(tmp_img)

In [167]:
# Single-image inference: switch to eval mode, add a batch dimension, run the
# network without autograd, and decode the most probable class.
model_eval.eval()
softmax = torch.nn.Softmax(dim=1)

with torch.no_grad():
    # Leading batch axis of size 1, e.g. (C, H, W) -> (1, C, H, W).
    inputs = tmp_img.to(device).unsqueeze(0)
    outputs = model_eval(inputs)
    prob = softmax(outputs)

    top_prob, top_idx = prob.max(dim=1)
    pred_id = top_idx.tolist()
    pred_prob = top_prob.item()
    pred_str = reverse_model_map[pred_id[0]]

In [168]:
# Softmax probability assigned to the predicted class.
pred_prob

1.0

In [169]:
# Class name of the prediction, decoded through reverse_model_map.
pred_str

'Mammalia'