In [107]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data import Subset
from torch.utils.data import random_split
from torch.utils.data import ConcatDataset
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
from torchvision.datasets import ImageFolder
from PIL import Image
import numpy as np
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F

test_file = 'data_train.npy'
labels_file = 'labels_train.npy'


In [108]:
class MathSymbols(Dataset):
    def __init__(self, data_file, data_transform=None, label_transform=None):
        '''
        data_file: Data file for training the model (should be .npy file)
        labels_file: Target labels (should be .npy file)   
        '''
        # Load data 
        self.data = np.load(data_file)
        #self.labels = np.load(labels_file)
        self.data_transform = data_transform
        #self.label_transform = label_transform
        
        # Reshape data
        self.data = np.reshape(np.transpose(self.data), (np.shape(self.data)[1], 100, 100))

        # Convert the images to have 3 channels so it will function with pretrained models
        self.data = self.convert_to_rgb(self.data)

    def __len__(self):
        # Length of data (and labels) array
        return len(self.data)

    def __getitem__(self, idx):
        # Fetch the dataset item
        sample = self.data[idx]
        #label = self.labels[idx]
        sample = Image.fromarray(sample)  # Convert numpy array to PIL Image
        
        if self.data_transform:
            # Transform the data
            sample = self.data_transform(sample)
        
       # if self.label_transform:
            # Transform the label
        #    label = self.label_transform(label)
        
        # Convert sample and label to tensor
        #label = torch.tensor(label, dtype=torch.long)
        
        return sample

    def convert_to_rgb(self, grayscale_images):
        # Add first input channel
        grayscale_images = np.expand_dims(grayscale_images, -1)
        # Change to 3 input channels to match with pretrained models
        rgb_images = grayscale_images.repeat(3, axis=-1)
        return rgb_images
    
    
    
class TransformSubset(torch.utils.data.Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x

    def __len__(self):
        return len(self.subset)
    
    
def Test(X):
    
    num_classes = 10 
    
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(0.5, 0.5)])

    batch_size = 8

    test_set = MathSymbols(data_file = X)

    # Apply transformations using the wrapper class
    test = TransformSubset(test_set, transform=transform)


    # Test set
    testloader = torch.utils.data.DataLoader(test, batch_size=batch_size,
                                            shuffle=False, num_workers=2)
    
    net = models.resnet50(pretrained=True)
    net.fc = nn.Linear(net.fc.in_features, num_classes) 

    # Load the saved state dictionary into the model
    net.load_state_dict(torch.load('Best_Model.pth'))
    
    predictions = np.array([])
    
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            images = data
            # calculate outputs by running images through the network
            outputs = net(images) 
            _, predicted = torch.max(outputs.data, 1)
            
            predictions = np.append(predictions, predicted)
        
           
    return predictions


In [109]:
Y = Test(X = test_file)

In [110]:
correct = 0
total = 0

num_classes = 10 

labels = np.load(labels_file)

total = labels.size

for i in range(total):
    
    if (labels[i] == Y[i]):
        correct += 1
    

print(f'Accuracy is: {100 * correct // total} %')

Accuracy is: 98 %
