In [None]:
import os
import torch
import torchvision.transforms

import matplotlib.pyplot as plt
import numpy as np

from skimage import io, color
from src import colnet
from src import dataset
from src import utils

In [None]:
# Path to the trained ColNet checkpoint (places10 run) used for inference.
_checkpoint_dir = './model/places10'
model = _checkpoint_dir + '/colnet181211-23-40-45-19.pt'

In [None]:
# Load the checkpoint onto the CPU and rebuild the network from it.
# NOTE: torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints that come from a trusted source.
checkpoint = torch.load(model, map_location=torch.device("cpu"))

# Hyper-parameters stored alongside the weights.
classes = checkpoint['classes']
num_classes = len(classes)
net_divisor = checkpoint['net_divisor']

net = colnet.ColNet(num_classes=num_classes, net_divisor=net_divisor)
net.load_state_dict(checkpoint['model_state_dict'])


In [None]:
# Image to colorize. To try a different picture, change this path
# (e.g. 'shinjuku-gyoen-square-224.jpg').
img_path = 'data/places10/test/japanese_garden/Places365_val_00007875.jpg'

img = io.imread(img_path)

# skimage.io.imshow / io.show are deprecated, so display with matplotlib instead.
# Single-channel images get a gray colormap so they are not rendered in false colour.
fig, ax = plt.subplots()
ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
ax.set_title(img_path)
ax.axis('off')
plt.show()

In [None]:
# Preprocessing applied to the raw image before it is fed to the network:
# grayscale handling -> 224px crop -> normalised Lab -> tensor -> split into (L, ab).
# NOTE(review): RandomCrop presumably picks a random window, so repeated runs
# may colorize different crops; a deterministic crop would make inference
# reproducible — TODO confirm what `dataset` provides.
_transform_steps = [
    dataset.HandleGrayscale(),
    dataset.RandomCrop(224),
    dataset.Rgb2LabNorm(),
    dataset.ToTensor(),
    dataset.SplitLab(),
]
composed_transforms = torchvision.transforms.Compose(_transform_steps)

In [None]:
# Run the preprocessing, which yields the lightness channel L and the colour
# channels ab (ab is not needed for inference).
L, ab = composed_transforms(img)

# Prepend a batch axis of size 1 so the single image can be passed to the net.
# np.expand_dims also converts L to a NumPy array if it is not one already.
L_batch = np.expand_dims(L, axis=0)
L_tensor = torch.from_numpy(L_batch)

In [None]:
# File name of the input image without its directories; used to name the output file.
_head, img_name = os.path.split(img_path)
img_name

In [None]:
# Show the class labels read from the checkpoint; the prediction cell pairs
# them positionally with the network's class probabilities.
classes

In [None]:
# Number of highest-probability classes to report.
TOP_K = 10

net.eval()
with torch.no_grad():
    # The network returns the predicted ab channels and the class logits.
    ab_out, class_logits = net(L_tensor)
    # Recombine the input lightness with the predicted colour channels.
    img_colorized = utils.net_out2rgb(L, ab_out[0])
    # Softmax over the class axis, then take the single image in the batch.
    probs = torch.softmax(class_logits, dim=1)[0].numpy()

colorized_img_name = "colorized-" + img_name
io.imsave(colorized_img_name, img_colorized)

# skimage.io.imshow / io.show are deprecated, so display with matplotlib instead.
fig, ax = plt.subplots()
ax.imshow(img_colorized)
ax.set_title(colorized_img_name)
ax.axis("off")
plt.show()

print("Saved image to: {}\n".format(colorized_img_name))

# Pair each probability with its label and sort from most to least likely.
probs_and_classes = sorted(zip(probs, classes), key=lambda x: x[0], reverse=True)

print("Predicted labels: \n")
for p, c in probs_and_classes[:TOP_K]:
    print("{:>7.2f}% \t{}".format(p * 100.0, c))