In [None]:
import os
from collections import OrderedDict
import torch

import starry.utils.env
from starry.schp import networks as schp_networks


SCHP_PRETRAINED = os.getenv('SCHP_PRETRAINED')
if not SCHP_PRETRAINED:
	# Fail early with a clear message instead of an opaque torch.load(None) error.
	raise RuntimeError('SCHP_PRETRAINED is not set; point it at the SCHP checkpoint file')

# Build the parsing network with pretrained=None; all weights come from the checkpoint below.
model = schp_networks.init_model('resnet101', num_classes=18, pretrained=None)

# The model lives on CPU, so map every stored tensor to CPU; this also lets a
# checkpoint that was saved from GPU load on a CPU-only machine.
checkpoint = torch.load(SCHP_PRETRAINED, map_location='cpu')
state_dict = checkpoint['state_dict']

# Checkpoint keys are expected to carry a `module.` prefix (presumably left by a
# DataParallel wrapper at save time). Strip it only where it is actually present,
# so keys without the prefix are not truncated.
MODULE_PREFIX = 'module.'
new_state_dict = OrderedDict(
	(k[len(MODULE_PREFIX):] if k.startswith(MODULE_PREFIX) else k, v)
	for k, v in state_dict.items()
)
model.load_state_dict(new_state_dict)

# Inference only: eval mode and frozen weights.
model.eval()
model.requires_grad_(False)

# Smoke test: one forward pass on a random 1x3x256x256 input.
with torch.no_grad():
	result = model(torch.randn(1, 3, 256, 256))


In [None]:
# Inspect the shapes of a few entries of the nested smoke-test output (the
# nesting of `result` is defined by the SCHP model — TODO document it).
(result[0][0].shape, result[0][1].shape, result[1][0].shape)


In [None]:
# load dataset
import matplotlib.pyplot as plt
import numpy as np

from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset


CONFIG_PATH = 'configs/peris-score-simple-b0-balance.local.yaml'
DATA_DIR = os.getenv('DATA_DIR')

config = Configuration.create(CONFIG_PATH, volatile=True)

# loadDataset returns a sequence of datasets; the single-element unpack asserts
# that exactly one comes back for the requested split spec.
(data,) = loadDataset(config, data_dir=DATA_DIR, splits='*0/1')

# Shared iterator: each run of the display cell advances to the next sample.
it = iter(data)

def showImage(tensor):
	"""Display the first image of a batch with matplotlib.

	Args:
		tensor: batched image tensor; element 0 is permuted (1, 2, 0), so it is
			assumed to be channel-first (channels, height, width) — with 3 channels
			(RGB) or 1 for imshow to accept it. Values are expected in [0, 1].
			Anything outside that range is clipped rather than allowed to wrap
			around when cast to uint8.
	"""
	# detach/cpu so .numpy() works for tensors that track grads or live on a GPU.
	image = tensor[0].detach().cpu().permute(1, 2, 0).numpy()
	# Clip before the uint8 cast: e.g. 1.2 * 255 would otherwise wrap to a dark value.
	image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
	plt.imshow(image)
	plt.show()

def get_palette(num_cls):
	"""Build a flat PIL-style palette with one RGB triple per class index.

	The bits of each class index are spread over the channels three at a time.
	Bit 0 goes to R, bit 1 to G and bit 2 to B at intensity bit 7; the next three
	bits go to intensity bit 6, and so on. Class 0 is therefore black.

	Args:
		num_cls: number of classes.

	Returns:
		list of ``3 * num_cls`` ints laid out as ``[r0, g0, b0, r1, g1, b1, ...]``.
	"""
	palette = []
	for label in range(num_cls):
		rgb = [0, 0, 0]
		remaining, shift = label, 7
		while remaining:
			for channel in range(3):
				rgb[channel] |= ((remaining >> channel) & 1) << shift
			shift -= 1
			remaining >>= 3
		palette.extend(rgb)
	return palette

# One RGB triple per parsing class; 18 matches num_classes of the SCHP model.
palette = get_palette(num_cls=18)


In [None]:
# show heatmap
from torchvision import transforms
from PIL import Image


source, labels = next(it)
print('origin shape:', source.shape)

# Resize to width 512 while preserving the aspect ratio (height = 512 * H / W).
size = (512 * source.shape[2] // source.shape[3], 512)
source = transforms.Resize(size)(source)
print('resized shape:', source.shape)
showImage(source)

with torch.no_grad():
	result = model(source)
# Final parsing output; indexed below as (batch, classes, h, w) — assumed from usage.
logits = result[0][-1]

# Min-max normalise into [0, 1] and show the first three channels as a pseudo-RGB heatmap.
amax, amin = logits.max().item(), logits.min().item()
span = (amax - amin) or 1.0	# avoid division by zero on a constant output
showImage((logits[:, :3, :, :] - amin) / span)

# Per-pixel class = argmax over the channel axis of the first batch item.
image = logits[0].permute(1, 2, 0).numpy()	# (h, w, c)
parsing_result = np.argmax(image, axis=2)	# (h, w)
output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
output_img.putpalette(palette)
# np.array on a 'P'-mode image returns raw class indices; convert to RGB so the
# palette colours are what actually gets displayed.
plt.imshow(np.array(output_img.convert('RGB')))
plt.show()
