In [14]:
import json
import torch
from pathlib import Path
from src import utils
from src.model import RBVPredictor
from PIL import Image

In [7]:
data_root = Path("/ssd_scratch/cvit/ishaanshah/tree_transformed_dataset/")
model_path = Path("./model/model.ckpt")

In [10]:
model = RBVPredictor.load_from_checkpoint(str(model_path))
model.eval()

RBVPredictor(
  (criterion): MSELoss()
  (accuracy): CosineSimilarity()
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (13): ReLU()
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(12

In [32]:
def infer(idx):
    label = utils.process_label(Image.open(str(data_root / 'train' / 'Labels' / f'{idx}.png')))
    label = torch.unsqueeze(label, dim=0)
    with open(str(data_root / 'train' / 'Info' / f'{idx}.json')) as f:
        gt = json.load(f)['rbv']

    with torch.no_grad():
        x1, _, _, inferred = model(label)
    print(x1)
    return torch.tensor(gt), torch.reshape(inferred, (8, 8))

In [34]:
gt, inferred = infer(0)
print(torch.abs(gt - inferred))
print(inferred)

tensor([[0.4959]])
tensor([[0.0072, 0.0098, 0.0068, 0.0081, 0.0077, 0.0086, 0.0070, 0.0095],
        [0.4029, 0.4263, 0.1420, 0.0539, 0.0821, 0.4706, 0.2740, 0.0706],
        [0.3300, 0.3098, 0.2688, 0.1228, 0.0964, 0.3439, 0.3550, 0.2920],
        [0.2387, 0.2405, 0.0228, 0.1609, 0.2089, 0.2222, 0.2435, 0.2434],
        [0.0947, 0.1526, 0.0317, 0.1855, 0.1651, 0.1544, 0.1532, 0.1361],
        [0.0326, 0.1291, 0.0325, 0.0060, 0.1344, 0.0950, 0.1018, 0.0508],
        [0.0148, 0.0201, 0.0800, 0.0715, 0.1046, 0.0730, 0.1414, 0.1226],
        [0.1379, 0.0939, 0.1775, 0.1766, 0.1027, 0.1084, 0.1796, 0.1643]])
tensor([[0.0462, 0.0487, 0.0457, 0.0470, 0.0476, 0.0476, 0.0459, 0.0484],
        [0.1582, 0.1581, 0.1590, 0.1579, 0.1584, 0.1577, 0.1599, 0.1552],
        [0.2813, 0.2833, 0.2796, 0.2798, 0.2803, 0.2798, 0.2817, 0.2799],
        [0.3545, 0.3566, 0.3528, 0.3541, 0.3559, 0.3548, 0.3545, 0.3567],
        [0.3602, 0.3579, 0.3609, 0.3596, 0.3614, 0.3590, 0.3604, 0.3583],
        [0.3283, 0

In [35]:
gt, inferred = infer(1)
print(torch.abs(gt - inferred))
print(gt)
print(inferred)

tensor([[0.4959]])
tensor([[0.0031, 0.0135, 0.0159, 0.0195, 0.0201, 0.0200, 0.0090, 0.0077],
        [0.1830, 0.1581, 0.1590, 0.1579, 0.1584, 0.2190, 0.2170, 0.1656],
        [0.0719, 0.2131, 0.0530, 0.0773, 0.0240, 0.1332, 0.1759, 0.1746],
        [0.0166, 0.1643, 0.1301, 0.1206, 0.0570, 0.0431, 0.1075, 0.0912],
        [0.2208, 0.1889, 0.2260, 0.1001, 0.0234, 0.0091, 0.0174, 0.0191],
        [0.0420, 0.1856, 0.1648, 0.0454, 0.0660, 0.0179, 0.0412, 0.0138],
        [0.0031, 0.2740, 0.1141, 0.1141, 0.0020, 0.0671, 0.0811, 0.0333],
        [0.0918, 0.1753, 0.1775, 0.1766, 0.0256, 0.1620, 0.1642, 0.1203]])
tensor([[0.0430, 0.0352, 0.0298, 0.0275, 0.0275, 0.0275, 0.0549, 0.0561],
        [0.3412, 0.0000, 0.0000, 0.0000, 0.0000, 0.3767, 0.3769, 0.3208],
        [0.3532, 0.0702, 0.2266, 0.2025, 0.3043, 0.4131, 0.4576, 0.4545],
        [0.3378, 0.1923, 0.2227, 0.2334, 0.2989, 0.3979, 0.4619, 0.4479],
        [0.1394, 0.1690, 0.1348, 0.2594, 0.3380, 0.3681, 0.3778, 0.3391],
        [0.2862, 0

In [40]:
gt, inferred = infer(2)
#print(torch.abs(gt - inferred))
# print(torch.mean(torch.abs(gt - inferred)))
#print(gt)
#print(inferred)

tensor([[0.4959]])
