In [None]:
# load schp
import os
from collections import OrderedDict
import torch

import starry.utils.env
from starry.schp import networks as schp_networks


# Path of the pretrained SCHP checkpoint, taken from the environment.
SCHP_PRETRAINED = os.getenv('SCHP_PRETRAINED')
if not SCHP_PRETRAINED:
	# Fail early with a readable message instead of an obscure error from torch.load(None).
	raise RuntimeError('environment variable SCHP_PRETRAINED must point to the SCHP checkpoint file')

model = schp_networks.init_model('resnet101', num_classes=18, pretrained=None)

# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# machine without CUDA; load_state_dict() below copies the values into the
# model's own parameters, so the model's device placement is unaffected.
state_dict = torch.load(SCHP_PRETRAINED, map_location='cpu')['state_dict']

# Checkpoint keys are expected to carry a `module.` prefix (presumably written by
# a DataParallel-wrapped model -- confirm against the training code).
# Strip the prefix only where it is actually present, so keys saved without it
# are not truncated by a blind 7-character slice.
prefix = 'module.'
new_state_dict = OrderedDict(
	(k[len(prefix):] if k.startswith(prefix) else k, v)
	for k, v in state_dict.items()
)
model.load_state_dict(new_state_dict)

# Inference only: switch to eval mode and freeze every parameter.
model.eval()
for param in model.parameters():
	param.requires_grad = False


In [None]:
# load dataset
import matplotlib.pyplot as plt
import numpy as np

from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset


# Dataset location is supplied through the environment (may be None when unset).
DATA_DIR = os.getenv('DATA_DIR')

# Build the configuration object from the YAML file.
config = Configuration.create(
	'configs/peris-score-simple-balance.yaml',
	volatile=True,
)

# The single-element unpacking requires loadDataset to yield exactly one
# dataset for this split spec; any other count raises ValueError here.
[data] = loadDataset(
	config,
	data_dir=DATA_DIR,
	splits='*0/1',
)

# Iterator over the dataset, kept at module level so later cells can pull from it.
it = iter(data)

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

from starry.vision.data.masker import Masker


# Wrap the model loaded in the first cell in a Masker.
masker = Masker(model, resize=512, mask_semantic=0, blur_iterations=40)

# cv2.imread signals a missing/unreadable file by returning None rather than raising.
image = cv2.imread('./test/download.jpg')
if image is None:
	raise FileNotFoundError("missing or unreadable image: './test/download.jpg'")

# HWC uint8 in [0, 255] -> CHW float32 in [0, 1].
image = torch.from_numpy((image / 255.).astype(np.float32)).permute(2, 0, 1)
masked = masker.mask(image)

# detach()/cpu() make the conversion work even if the result is on a GPU or
# still attached to an autograd graph; CHW -> HWC for OpenCV/matplotlib.
masked_img = masked.detach().cpu().permute(1, 2, 0).numpy()
# Round and clip before the uint8 cast: a bare astype() truncates (127.9999 -> 127)
# and wraps values that fall outside [0, 255].
masked_img = np.clip(np.rint(masked_img * 255), 0, 255).astype(np.uint8)
print('masked_img:', masked_img.shape)

# imwrite reports failure (e.g. missing directory) through its return value.
if not cv2.imwrite('test/download_masked.jpg', masked_img):
	print('warning: failed to write test/download_masked.jpg')

# OpenCV loads images as BGR while matplotlib expects RGB; reverse the channel
# axis for display only (assumes the masker keeps the input channel order -- confirm).
plt.imshow(masked_img[..., ::-1])
plt.show()
