In [3]:
import torch
from torch.nn import functional as F
from models.feature_extractor import ConvMixerFeatureExtractor
from models.motion_classifier import MotionClassifier
from utils.data import CustomDataset
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

In [16]:

@torch.no_grad()
def test():
    testset = CustomDataset("test")
    test_loader = DataLoader(testset, batch_size=8, num_workers=4)

    feature_extractor = ConvMixerFeatureExtractor()
    motion_classifier = MotionClassifier(feature_extractor).to(0)

    motion_classifier.load_state_dict(torch.load("ckpts/motion_classifier_20220829.pt")["motion_classifier"])

    motion_classifier.eval()

    fusion_acc = 0.0
    frame_acc = 0.0
    sensor_acc = 0.0

    for frames, sensors, labels in test_loader:
        frames = frames.to(0)
        sensors = sensors.to(0)
        labels = labels.long().to(0)

        result =  motion_classifier(frames, sensors)

        fusion_preds = result["logits"].argmax(dim=1).cpu().numpy()
        frame_preds = result["image_motion_logits"].argmax(dim=1).cpu().numpy()
        sensor_preds = result["sensor_motion_logits"].argmax(dim=1).cpu().numpy()

        fusion_acc += accuracy_score(labels.cpu().numpy(), fusion_preds) / len(test_loader)
        frame_acc += accuracy_score(labels.cpu().numpy(), frame_preds) / len(test_loader)
        sensor_acc += accuracy_score(labels.cpu().numpy(), sensor_preds) / len(test_loader)

    return fusion_acc, frame_acc, sensor_acc

In [17]:
# Run the evaluation and show the three accuracies
# (fused, frame-only, sensor-only) as this cell's output tuple.
fusion_acc, frame_acc, sensor_acc = test()
fusion_acc, frame_acc, sensor_acc

(0.9375, 0.78125, 0.84375)