In [None]:
"""Evaluation of a PointNet classification model trained on the MCB B dataset.
"""

In [None]:
import pathlib
import torch
from config import config
from typing import List
from typing import Union
from typing import Any
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from sklearn.metrics import jaccard_score
import matplotlib
import matplotlib.pyplot as plt
from model.model import PointNetClassification
from utils import viewer
import model.dataset as dataset
from torchvision import transforms
import itertools
import numpy as np

In [None]:
def plot_loss(
    model_name: str, train_loss: List[float], valid_loss: List[float]
) -> None:
    """Draws training and validation loss curves for a trained model.

    Both series share one set of axes on a 12x7 figure titled
    "<model_name> loss values"; the figure is shown before returning.

    Args:
        model_name: Name of the trained model, used in the figure title.
        train_loss: Loss values recorded during training, one per epoch.
        valid_loss: Loss values recorded during validation, one per epoch.
    """
    fig, ax = plt.subplots(figsize=(12, 7))
    fig.suptitle(model_name + " loss values")
    # Legend entries are taken from the keys, in the same order as the curves.
    curves = {"Training Loss": train_loss, "Validation Loss": valid_loss}
    for values in curves.values():
        ax.plot(values)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.legend(list(curves))
    plt.show()

# Plot the loss

In [None]:
# Directory holding the trained model and its training artefacts.
# NOTE(review): absolute local path — adjust for your machine.
TRAINED_MODELS_DIR = pathlib.Path(
    "/media/gromovnik/Vice_SSD/01. Projects/01. THEIA/theia_pointnet/model/trained_models"
)
# File-name stem shared by the model file and its loss logs; change this single
# value to evaluate a different training run.
MODEL_STEM = "mcb_2"

TRAINED_MODEL_PATH = TRAINED_MODELS_DIR / f"{MODEL_STEM}.pt"
TRAIN_LOSS_PATH = TRAINED_MODELS_DIR / f"{MODEL_STEM}_training_loss.txt"
VALID_LOSS_PATH = TRAINED_MODELS_DIR / f"{MODEL_STEM}_validation_loss.txt"

In [None]:
def txt_to_list(
    input_path: Union[pathlib.Path, str], encoding: str = "utf-8"
) -> List[str]:
    """Loads a text file into a list of lines.

    Args:
        input_path: Path to the .txt file.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale settings.

    Returns:
        One string per line of the file, with line terminators removed.
    """
    with open(input_path, "r", encoding=encoding) as f:
        return f.read().splitlines()

In [None]:
# Display name = model file name without its extension; the loss logs are read
# as raw lines (one string per line) and converted in the next cell.
model_name = pathlib.Path(TRAINED_MODEL_PATH).stem
train_loss, valid_loss = (txt_to_list(p) for p in (TRAIN_LOSS_PATH, VALID_LOSS_PATH))

In [None]:
# Convert the raw loss lines to floats rounded to 4 decimals for plotting.
def _to_rounded_floats(
    raw_values: List[Union[str, float]], ndigits: int = 4
) -> List[float]:
    """Parses numeric values into floats rounded to ``ndigits`` decimals.

    Blank / whitespace-only entries (e.g. an empty trailing line in a log file)
    are skipped instead of making ``float()`` raise ``ValueError``. Already
    numeric values pass through unchanged, so re-running this cell is safe.
    """
    return [round(float(v), ndigits) for v in raw_values if str(v).strip()]

train_loss = _to_rounded_floats(train_loss)
valid_loss = _to_rounded_floats(valid_loss)

In [None]:
plot_loss(model_name=model_name, train_loss=train_loss, valid_loss=valid_loss)

# Load the model & evaluate results

In [None]:
# The state dict sits next to the model file as "<stem>_state_dict". Deriving it
# from TRAINED_MODEL_PATH guarantees the loss curves plotted earlier and the
# weights evaluated below come from the same training run.
_trained_model_path = pathlib.Path(TRAINED_MODEL_PATH)
TRAINED_STATE_DICT_PATH = _trained_model_path.with_name(
    f"{_trained_model_path.stem}_state_dict"
)

In [None]:
# Load the global project configuration. Its `.config` attribute is indexed
# below like a nested dict ("dataset" -> "test", "batch_size", "lr").
config_file = config.Config()

In [None]:
# Use the first CUDA GPU when available, otherwise fall back to the CPU, then
# release unoccupied memory cached by PyTorch's CUDA allocator.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device: ", device)
torch.cuda.empty_cache()

In [None]:
# Load the test dataset and build the (untrained) model object.
dataset_test_path = pathlib.Path(config_file.config["dataset"]["test"])
batch_size = config_file.config["batch_size"]
test_transforms = transforms.Compose([dataset.NormalizePc()])
dataset_test = dataset.McbData(dataset_test_path, test_transforms)
# shuffle=False keeps batches in dataset order, so predictions collected in
# loader order can be indexed with dataset indices (e.g. dataset_test[i]).
# With shuffle=True the point clouds would have to be kept in memory to plot them.
test_loader = torch.utils.data.DataLoader(
    dataset=dataset_test, batch_size=batch_size, shuffle=False
)

learning_rate = config_file.config["lr"]
point_net = PointNetClassification(len(dataset_test.classes), learning_rate)
# Trailing ';' stops Jupyter from printing the (very long) module repr.
point_net.to(device);

In [None]:
# Load the trained weights. map_location=device remaps tensors that were saved
# on a GPU onto the device selected above, so this works on CUDA and CPU-only
# machines alike.
# NOTE(review): torch.load unpickles the file — only load state dicts from a
# trusted source (PyTorch >= 1.13 also accepts weights_only=True).
point_net.load_state_dict(torch.load(TRAINED_STATE_DICT_PATH, map_location=device))

# Switch to inference mode (affects layers such as dropout / batch norm, if the
# model contains them). ';' suppresses the module repr output.
point_net.eval();

In [None]:
# Run inference over the whole test set, collecting predicted and true class
# indices in loader order (= dataset order, because the loader does not shuffle).
all_preds = []
all_labels = []
n_batches = len(test_loader)
with torch.no_grad():
    for batch_idx, batch in enumerate(test_loader):
        # '\r' rewrites a single progress line instead of printing one per batch.
        print('Batch [%4d / %4d]' % (batch_idx + 1, n_batches), end='\r')
        inputs = batch["pc"].to(device, dtype=torch.float32)

        # The first output holds the class scores; the other two are unused here.
        outputs, _, _ = point_net(inputs.transpose(1, 2))
        preds = outputs.argmax(dim=1)

        all_preds.extend(preds.cpu().tolist())
        # Labels are only needed on the CPU, so they never visit the GPU.
        all_labels.extend(batch["category_idx"].tolist())
print()

In [None]:
# Inspect a single test sample. Indexing all_preds / all_labels with a dataset
# index is only valid because the test loader does not shuffle.
n_sample = 2705  # any index in [0, len(dataset_test))
assert len(all_preds) == len(dataset_test), "inference did not cover the full test set"
assert 0 <= n_sample < len(dataset_test), (
    f"n_sample={n_sample} is out of range for a test set of {len(dataset_test)} samples"
)

class_names = list(dataset_test.classes.keys())
sample_data = dataset_test[n_sample]
print("Sample: ", n_sample)
print("Point cloud: ")
print(sample_data['pc'])
print("Predicted label: ", class_names[all_preds[n_sample]])
print("True label: ", class_names[all_labels[n_sample]])

In [None]:
# Render the point cloud of the inspected sample.
viz = viewer.Viewer()
sample_pc = dataset_test[n_sample]['pc']
viz.add_pc(sample_pc, size=2.5)
viz.show()

In [None]:
# Create & visualize confusion matrix.

In [None]:
def plot_confusion_matrix(cm: np.ndarray, classes: List[str], normalize: bool = False,
                          title: str = 'Confusion matrix',
                          cmap: matplotlib.colors.LinearSegmentedColormap = plt.cm.Blues) -> None:
    """Visualizes a confusion matrix as an annotated heat map.

    Draws into the current matplotlib figure, so create one beforehand to
    control its size. Adapted from https://deeplizard.com/learn/video/0LhiS6yu2qQ

    Args:
        cm: Square confusion matrix; rows are true labels and columns are
            predicted labels (scikit-learn's ``confusion_matrix`` layout).
        classes: Class names, one per row/column of ``cm``, in index order.
        normalize: If True, each row is divided by its sum so every cell shows
            the fraction of that true class assigned to each prediction. Rows
            with no samples are shown as zeros instead of NaN.
        title: Title of the plotted matrix.
        cmap: Colormap used for the heat map.
    """
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Avoid 0/0 for classes with no true samples; NaN cells would also turn
        # the text-colour threshold below into NaN.
        cm = np.divide(cm.astype(float), row_sums,
                       out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # White text on dark (high-value) cells, black text on light cells.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out only after the axis labels exist so they are not clipped.
    plt.tight_layout()

In [None]:
# Pass the full label set so the matrix has one row/column per class, in the
# same index order as dataset_test.classes, even when some class never occurs
# among labels or predictions. Otherwise rows would shift relative to the class
# names used as tick labels.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(dataset_test.classes))))

In [None]:
plt.figure(figsize=(15, 15))
class_names = list(dataset_test.classes.keys())
plot_confusion_matrix(cm, class_names, normalize=True)

In [None]:
# Compute F1 score.
# For single-label multi-class predictions, micro-averaged F1 equals plain
# accuracy, so macro F1 (unweighted mean over classes) is reported as well;
# it exposes poor performance on under-represented classes.
f1 = f1_score(all_labels, all_preds, average='micro')
f1_macro = f1_score(all_labels, all_preds, average='macro')
{"f1_micro": f1, "f1_macro": f1_macro}

In [None]:
# Compute IoU (Jaccard index).
# Micro averaging pools all classes and is dominated by the frequent ones;
# macro averaging gives the mean per-class IoU (mIoU).
iou = jaccard_score(all_labels, all_preds, average='micro')
iou_macro = jaccard_score(all_labels, all_preds, average='macro')
{"iou_micro": iou, "iou_macro": iou_macro}