In [1]:
import numpy as np
import torch as th
import torch.nn as nn
import torchvision as tv
import pandas as pd
from tqdm import tqdm

from models.vgg_16 import BinaryVgg16
from trainer.binary_trainer import BinaryTrainer
from datasets.image_sharpness_ds import ImageSharpnessDS

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
def plot_sharpness_loss(result):
    result = np.array(result)
    result = result.flatten()
    result = np.sort(result)
    result = result[::-1]
    plt.plot(result)
    plt.ylim((0, 1))
    plt.title('Sharpness Loss Sorted')
    plt.ylabel('Loss')
    plt.xlabel('Samples')
    plt.show()


def transform_results_vector(result_list, item_enable=False):
    result_vector = []
    for vec in result_list:
        for item in vec:
            if item_enable:
                result_vector.append(item.item())
            else:
                result_vector.append(item)

    return result_vector

In [3]:
dev = th.device("cuda") if th.cuda.is_available() else th.device("cpu")
loaded_net = th.load('C:/workspace/Binary_sharpness/saved_models/model_final.pt')
net2 = BinaryVgg16(None).to(dev)
net2.load_state_dict(loaded_net['model_state_dict'])
net2.eval()

ds = ImageSharpnessDS('C:/workspace/datasets/binary_sharpness/labels_1.csv',
                      'C:/workspace/datasets/binary_sharpness/images',
                      transform=th.nn.Sequential(tv.transforms.Resize((128, 128))
                                                 ))

totalloader = th.utils.data.DataLoader(ds, batch_size=75, shuffle=True, num_workers=0)

loop = tqdm(totalloader)

names = []
labels = []
preds = []
losses = []

for j, data in enumerate(loop, 0):
    img, target, name = data
    target = target.to(dev)
    img = img.to(dev)

    label_np = target.cpu().detach().numpy()

    pred = net2.forward(img)

    pred_np = pred.cpu().detach().numpy()
    loss = nn.functional.l1_loss(pred, target, reduction='none')
    loss_np = loss.cpu().detach().numpy()

    names.append(name)
    labels.append(label_np)
    preds.append(pred_np)
    losses.append(loss_np)

100%|██████████| 276/276 [03:00<00:00,  1.53it/s]


In [13]:
result_table = pd.DataFrame().assign(name=transform_results_vector(names, item_enable=False),
                                     label=transform_results_vector(labels, item_enable=True),
                                     predictions=transform_results_vector(preds, item_enable=True),
                                     loss=transform_results_vector(losses, item_enable=True))

print(result_table)

                               name  label  predictions          loss
0         1643_i49_s11_r1_z5_c1.png    1.0     0.997776  2.223551e-03
1       16387_i608_s4_r2_z16_c1.png    0.0     0.024053  2.405262e-02
2         1506_i44_s11_r1_z5_c0.png    1.0     0.999987  1.347065e-05
3        9445_i321_s8_r1_z13_c1.png    1.0     1.000000  0.000000e+00
4        3558_i135_s2_r1_z11_c0.png    1.0     0.999999  8.344650e-07
...                             ...    ...          ...           ...
20667     9119_i310_s8_r1_z3_c1.png    1.0     0.986830  1.316994e-02
20668    4244_i157_s3_r1_z14_c0.png    1.0     1.000000  1.192093e-07
20669     3719_i140_s2_r1_z6_c1.png    1.0     0.999954  4.649162e-05
20670     5683_i204_s4_r1_z2_c1.png    1.0     0.981621  1.837891e-02
20671  11419_i380_s10_r2_z14_c1.png    1.0     0.999996  3.695488e-06

[20672 rows x 4 columns]


In [25]:
result_table = result_table.sort_values(by='loss', ascending=False)
print(result_table)
result_table.to_csv('C:/workspace/datasets/binary_sharpness/results_1.csv', header=True, index=False)

                               name  label   predictions  loss
17058     4978_i181_s3_r1_z7_c0.png    1.0  4.106361e-10   1.0
17405     4982_i181_s3_r1_z9_c0.png    1.0  1.009961e-09   1.0
12539    4992_i181_s3_r1_z14_c0.png    1.0  1.035908e-08   1.0
15311    4986_i181_s3_r1_z11_c0.png    1.0  1.705591e-09   1.0
18739    4990_i181_s3_r1_z13_c0.png    1.0  5.806085e-09   1.0
...                             ...    ...           ...   ...
5149    11438_i381_s10_r2_z7_c0.png    1.0  1.000000e+00   0.0
15620  11042_i369_s10_r2_z13_c0.png    1.0  1.000000e+00   0.0
19885  13214_i433_s12_r2_z11_c0.png    1.0  1.000000e+00   0.0
9135     5880_i209_s4_r1_z16_c0.png    1.0  1.000000e+00   0.0
11411  12408_i409_s11_r2_z16_c0.png    1.0  1.000000e+00   0.0

[20672 rows x 4 columns]
