In [1]:
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from thermoclassifier.combined.net import ThermoClassifier
from thermoclassifier.dataset.dataset_creator import *

In [2]:
# --- Evaluation settings -----------------------------------------------------
measurement = 'C'
seq_len = 5        # time steps per sequence handed to DatasetCreator
batch_size = 256   # sequences per evaluation batch

# Build the dataset through the project's DatasetCreator and keep only the
# first of the three datasets it returns.
# NOTE(review): splits=(1., 0.) presumably routes all data into that first
# dataset -- confirm against DatasetCreator.
dc = DatasetCreator(
    elements=None,
    splits=(1.0, 0.0),
    validation=False,
    seq_len=seq_len,
    measurement=measurement,
    user='phase',
)
test_dataset, _, _ = dc.get_datasets()

# shuffle=True draws the batches in a random order on every pass.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=True,
)

Dataset shape:  (25605, 5, 4)


In [3]:
# Build the combined classifier using its default constructor arguments.
# NOTE(review): if ThermoClassifier is an nn.Module containing dropout or
# batch-norm layers, it should be put in eval mode (net.eval()) before it is
# scored -- confirm against thermoclassifier.combined.net.
net = ThermoClassifier()

In [4]:
# Score the classifier on every batch of `test_loader` and report three
# accuracies:
#   * element  - one label per sequence, read at the first time step (column 2)
#   * phase    - one label per time step (column 3)
#   * combined - a time step counts only when element AND phase both match
element_correct = 0
element_incorrect = 0
phase_correct = 0
phase_incorrect = 0
combined_correct = 0
combined_incorrect = 0

# Inference only: switching autograd off avoids recording a graph per batch.
with torch.no_grad():
    for d in test_loader:
        # Get the predictions.
        # Everything except the last two columns is fed to the network.  The
        # slice is a view of `d`, so clone it before the in-place rescale;
        # otherwise the division would also overwrite the batch itself.
        inp = d[:, :, :-2].clone()
        inp[:, :, 0] /= 1000  # rescale the first input column
        # NOTE(review): a bare squeeze() removes every size-1 axis, which
        # would include the batch axis if a batch ever held one sequence --
        # confirm the network's output shape and pin the squeezed dim.
        predictions = net(inp.float()).squeeze()

        # Get the correct/incorrect element predictions
        element_predictions = predictions[:, 0, 2]
        element_targets = d[:, 0, 2]
        correct = (element_predictions == element_targets).sum().item()
        element_correct += correct
        element_incorrect += element_targets.numel() - correct

        # Get the correct/incorrect phase predictions
        phase_predictions = predictions[:, :, 3]
        phase_targets = d[:, :, 3]
        correct = (phase_predictions == phase_targets).sum().item()
        phase_correct += correct
        phase_incorrect += phase_predictions.numel() - correct

        # Get the combined correct/incorrect predictions
        combined_predictions = predictions[:, :, [2, 3]]
        combined_targets = d[:, :, [2, 3]]
        # all(dim=-1) is True only where both the element and the phase match.
        both_match = (combined_predictions == combined_targets).all(dim=-1)
        correct = both_match.sum().item()
        combined_correct += correct
        combined_incorrect += both_match.numel() - correct

print('Element accuracy: ', element_correct/(element_correct + element_incorrect))
print('Phase accuracy: ', phase_correct/(phase_correct + phase_incorrect))
print('Combined accuracy: ', combined_correct/(combined_correct + combined_incorrect))

Element accuracy:  0.9785588752196837
Phase accuracy:  0.9455262644014841
Combined accuracy:  0.9364186682288616
