In [None]:
import os
from collections import OrderedDict
import torch

import starry.utils.env
from starry.schp import networks as schp_networks


SCHP_PRETRAINED = os.getenv('SCHP_PRETRAINED')
if not SCHP_PRETRAINED:
	# torch.load(None) fails with an unhelpful error; say what is actually missing.
	raise RuntimeError('SCHP_PRETRAINED environment variable is not set; point it at the SCHP checkpoint file.')

# Build the architecture only; weights are loaded from the checkpoint below.
model = schp_networks.init_model('resnet101', num_classes=18, pretrained=None)

# map_location='cpu': a checkpoint saved from GPU would otherwise fail to load on a
# CPU-only machine, and the smoke test below runs on a CPU tensor anyway.
# NOTE: torch.load unpickles the file -- only point SCHP_PRETRAINED at trusted checkpoints.
checkpoint = torch.load(SCHP_PRETRAINED, map_location='cpu')
state_dict = checkpoint['state_dict']

# Keys are expected to carry a `module.` prefix (typical of checkpoints saved from a
# DataParallel-wrapped model). Strip it only where present, so keys without the
# prefix are not truncated.
MODULE_PREFIX = 'module.'
new_state_dict = OrderedDict(
	(k[len(MODULE_PREFIX):] if k.startswith(MODULE_PREFIX) else k, v)
	for k, v in state_dict.items()
)
model.load_state_dict(new_state_dict)

# Inference only: switch to eval mode and freeze all parameters.
model.eval()
for param in model.parameters():
	param.requires_grad = False

# Smoke test: one forward pass on random input to confirm the weights load and run.
with torch.no_grad():
	result = model(torch.randn(1, 3, 256, 256))


In [None]:
# Shapes of selected outputs from the smoke-test forward pass (displayed as a tuple).
(
	result[0][0].shape,
	result[0][1].shape,
	result[1][0].shape,
)


In [None]:
# load dataset
from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset


DATA_DIR = os.getenv('DATA_DIR')

config = Configuration.create(
	'configs/peris-score-simple-b0-balance.local.yaml',
	volatile=True,
)
# The single-element unpacking asserts that loadDataset returned exactly one dataset.
[data] = loadDataset(config, data_dir=DATA_DIR, splits='*0/1')

# Iterator advanced with next(it) when drawing samples to visualise.
it = iter(data)


In [None]:
# show heatmap
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms


def showImage (tensor):
	image = tensor[0].permute(1, 2, 0).numpy()
	image = (image * 255).astype(np.uint8)
	plt.imshow(image)
	plt.show()

# Width (in pixels) the input is resized to; height follows the aspect ratio.
TARGET_WIDTH = 256

source, labels = next(it)
print('origin shape:', source.shape)

# Assumes source is (N, C, H, W). Resize takes (height, width): scale the height
# so the original aspect ratio is preserved at the target width.
size = (TARGET_WIDTH * source.shape[2] // source.shape[3], TARGET_WIDTH)
resize = transforms.Resize(size)
source = resize(source)
print('resized shape:', source.shape)

with torch.no_grad():
	result = model(source)

# showImage scales by 255 and casts to uint8, so it needs values in [0, 1];
# network outputs are not guaranteed to lie in that range. Min-max normalise
# each of the first three channels independently so each fills the display range.
heatmap = result[0][0][:, :3, :, :]
lo = heatmap.amin(dim=(2, 3), keepdim=True)
hi = heatmap.amax(dim=(2, 3), keepdim=True)
heatmap = (heatmap - lo) / (hi - lo).clamp_min(1e-8)
showImage(heatmap)
