# Inference

In [None]:
import os
# Work from the repository root so the relative './configs/...' and
# './output/...' paths used below resolve (and, presumably, so the
# project-local `utils` / `modeling` imports are found).
os.chdir("/home/yinyin/salient_text_official")

from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms

from utils import load_config, load_checkpoint
from modeling.model import build_salientText_model
from modeling.optimizer import build_optimizer
from modeling.scheduler import build_scheduler
from modeling.data.generate_mask_bbox import LabelGeneration

In [None]:
# Paths are relative to the repository root selected by os.chdir above.
trained_weights = './output/model_160.pt'
config = load_config('./configs/ecom.yaml')
device = 'cpu'

model = build_salientText_model(
    backbone_cfg=config['MODEL']['BACKBONE_CFG'],
    input_size=config['INPUT']['SIZE'],
    device=device,
).to(device)

# The optimizer and scheduler are not used for inference in this notebook;
# they are built only because load_checkpoint expects them as arguments.
optimizer = build_optimizer(config, model)
scheduler = build_scheduler(config, optimizer, config['INPUT']['TRAIN_NUM_DATA'])

model, optimizer, scheduler, start_epoch = load_checkpoint(trained_weights, model, optimizer, scheduler)

# Switch to evaluation mode before running inference. Otherwise any dropout /
# batch-norm layers behave as in training and predictions are not
# deterministic. (Assumes `model` is a torch.nn.Module, as its `.to(device)`
# usage suggests.)
model.eval()

In [None]:
# Ids of the ECdata images visualised below (ten per row).
test_list = [
    3, 6, 10, 23, 27, 30, 32, 36, 65, 71,
    75, 97, 103, 105, 121, 124, 138, 191, 195, 205,
    212, 215, 222, 225, 228, 235, 237, 239, 262, 273,
    291, 297, 304, 311, 326, 341, 354, 389, 391, 400,
    403, 427, 433, 444, 449, 457, 461, 471, 484, 492,
    508, 512, 513, 518, 520, 541, 555, 567, 570, 583,
    602, 604, 606, 623, 645, 658, 666, 668, 687, 694,
    703, 713, 729, 739, 742, 744, 761, 780, 781, 783,
    786, 808, 822, 850, 856, 858, 867, 868, 874, 881,
    903, 915, 924, 925, 934, 942, 954, 963, 964, 968,
    969,
]
# Dataset root containing the ALLSTIMULI/, ALLFIXATIONMAPS/ and TEXT/ folders.
root_dir = r'/home/yinyin/salient_text/dataset/ECdata/'

# Model-input pipeline: PIL image -> float tensor, then resize to the
# configured input size.
# NOTE(review): if config['INPUT']['SIZE'] is a single int, Resize only scales
# the shorter side, so non-square images will not come out square. Confirm this
# matches what the model was trained with.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(config['INPUT']['SIZE']),
    ]
)

# Helper used below to read text annotations, rasterise boxes into masks,
# and turn predicted masks back into drawable boxes.
label = LabelGeneration()

## Show masks

In [None]:

# For every test image: run the model and plot the image, the thresholded
# predicted text/saliency masks, and the corresponding ground truths.
for num in test_list:
    try:
        img_path = os.path.join(root_dir, f'ALLSTIMULI/{num}.jpg')
        saliency_path = os.path.join(root_dir, f'ALLFIXATIONMAPS/{num}_fixMap.jpg')
        text_path = os.path.join(root_dir, f'TEXT/gt_{num}.txt')

        ori_img = Image.open(img_path).convert("RGB")
        saliency_gt = Image.open(saliency_path).convert("L")

        # Ground-truth text mask, rasterised from the box annotations.
        boxes = label.get_annotations(text_path)
        text_mask = label.box2mask(ori_img, boxes)
        text_gt = Image.fromarray(text_mask).convert("L")

        # Add a batch dimension and place the input on the model's device.
        img = preprocess(ori_img).unsqueeze(0).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            saliency_map, text_map = model(img)

        # Threshold the first (only) batch item and move channels last for
        # imshow; .cpu() is needed before .numpy() when device is a GPU.
        binary_text = (text_map[0] > 0.05).permute(1, 2, 0).cpu().numpy()
        binary_saliency = (saliency_map[0] > 0.5).permute(1, 2, 0).cpu().numpy()

        fig, axes = plt.subplots(1, 5, figsize=(10, 5))  # Adjust the figsize as needed

        axes[0].imshow(ori_img, cmap="gray")
        axes[0].set_title('Ori Image')
        axes[1].imshow(binary_text, cmap="gray")
        axes[1].set_title('Output Text Mask')
        axes[2].imshow(text_gt, cmap="gray")
        axes[2].set_title('GT Text')
        axes[3].imshow(binary_saliency, cmap="gray")
        axes[3].set_title('Output Saliency')
        axes[4].imshow(saliency_gt, cmap="gray")
        axes[4].set_title('GT Saliency')

        for ax in axes:
            ax.axis('off')
        plt.show()

    except Exception as e:
        # Best-effort visualisation: a missing or unreadable file for one id
        # should not stop the loop, but report which image was skipped and why.
        print(f'Skipping image {num}: {type(e).__name__}: {e}')

## Show output boxes

In [None]:
# Run the model on one image and draw the text boxes recovered from the
# predicted text mask.
image_id = 6
img_path = os.path.join(root_dir, f'ALLSTIMULI/{image_id}.jpg')
ori_img = Image.open(img_path).convert("RGB")

# Add a batch dimension and place the input on the model's device.
img = preprocess(ori_img).unsqueeze(0).to(device)

# run model (inference only: no autograd graph needed)
with torch.no_grad():
    saliency_map, text_map = model(img)

binary_text = (text_map[0] > 0.05).permute(1, 2, 0).cpu().numpy()

# show boxes
shrunk_map = label.shrunk_map(binary_text)
# Reshape to the prediction's spatial size rather than a hard-coded 224x224,
# so the cell keeps working if config['INPUT']['SIZE'] changes.
# NOTE(review): assumes label.shrunk_map preserves the spatial size of its input.
map_h, map_w = binary_text.shape[:2]
boxes_batch, scores_batch = label.mask2box(ori_img, shrunk_map.reshape(map_h, map_w))
out_img = label.draw_box(ori_img, boxes_batch)

fig, axes = plt.subplots(1, 1, figsize=(10, 5))  # Adjust the figsize as needed

axes.imshow(out_img, cmap="gray")
plt.show()
