## Testing

In [1]:
import torch

### Clases más probables para una imagen

In [2]:
def pred_best_labels_image(net, image, n=5):
    """Return the ``n`` most probable classes predicted by ``net`` for one image.

    Parameters
    ----------
    net : torch.nn.Module
        Classifier. Must expose a ``device`` attribute; the input is moved there.
        It is switched to eval mode here.
    image : torch.Tensor
        A single, unbatched input; a batch dimension is added before the forward pass.
    n : int
        How many classes to return. Clamped to the number of classes the model
        outputs, so asking for more than exist no longer raises IndexError.

    Returns
    -------
    list of (int, float)
        ``(class_index, probability_in_percent)`` pairs, most probable first.
    """
    net.eval()
    # Inference only: no need to build the autograd graph.
    with torch.no_grad():
        logits = net(image.unsqueeze(0).to(net.device))[0]
    probs = logits.softmax(dim=0)
    k = max(0, min(n, probs.shape[0]))
    top_prob, top_idx = torch.topk(probs, k)
    return [(idx.item(), prob.item() * 100) for idx, prob in zip(top_idx, top_prob)]

### Distribución de clases para una imagen

In [None]:
from matplotlib import pyplot as plt

def pred_distribution(net, image):
    """Plot the class-probability distribution predicted by ``net`` for one image.

    ``image`` is a single unbatched input; a batch dimension is added here.
    If ``net`` exposes a ``device`` attribute, the input is moved there first;
    otherwise it is passed to the model unchanged. The model is switched to
    eval mode.
    """
    net.eval()
    device = getattr(net, 'device', None)
    batch = image.unsqueeze(0)
    if device is not None:
        batch = batch.to(device)
    # Inference only: no autograd graph needed (replaces the later .detach()).
    with torch.no_grad():
        output = net(batch)
    probs = output.softmax(dim=-1).cpu().numpy()

    plt.plot(probs[0])
    plt.xlabel('label')
    plt.ylabel('probability')
    plt.show()

### Predecir clases para archivos de una carpeta

In [3]:
from matplotlib import pyplot as plt
import os

def predict_folder(net, folder_path, transformation=None, samples=0, labels=None):
    """Classify every JPEG in ``folder_path`` and optionally plot a grid of them.

    Parameters
    ----------
    net : torch.nn.Module
        Classifier. Must expose a ``device`` attribute; inputs are moved there.
        It is switched to eval mode here.
    folder_path : str or os.PathLike
        Directory scanned (non-recursively) for ``.jpg`` / ``.jpeg`` files.
        The extension match is case-insensitive, and files are processed in
        sorted filename order.
    transformation : callable, optional
        Applied to each opened image. Its result must be a tensor, because it
        is batched with ``unsqueeze`` before the forward pass.
    samples : int or 'all'
        How many of the classified images to plot (3 per row). ``'all'`` plots
        every one; larger values are clamped to the number of images, and
        ``0`` (the default) plots nothing.
    labels : sequence, optional
        Maps predicted class indices to display names used as subplot titles.

    NOTE(review): ``Image`` (presumably ``PIL.Image``) is not imported in the
    visible part of this notebook; confirm it is imported before this cell runs.
    """
    net.eval()

    valid_filenames = []
    folder_predictions = []

    # os.listdir returns entries in arbitrary order; sort for reproducible output.
    for filename in sorted(os.listdir(folder_path)):
        if os.path.splitext(filename)[-1].lower() not in ('.jpg', '.jpeg'):
            continue
        image_path = os.path.join(folder_path, filename)
        # The context manager closes the file handle the opened image keeps.
        with Image.open(image_path) as image:
            if transformation:
                image = transformation(image)
            with torch.no_grad():
                output = net(image.unsqueeze(0).to(net.device))
        valid_filenames.append(filename)
        folder_predictions.append(output.argmax(dim=1).item())

    if samples == 'all':
        samples = len(valid_filenames)
    samples = max(0, min(samples, len(valid_filenames)))
    if samples == 0:
        return

    ncols, nrows = 3, -(samples // -3)  # ceil(samples / 3)
    plt.figure(figsize=(9, 6))
    for i in range(samples):
        image_path = os.path.join(folder_path, valid_filenames[i])
        pred = folder_predictions[i]
        title = labels[pred] if labels else pred

        plt.subplot(nrows, ncols, i + 1)
        plt.title(title)
        plt.axis('off')
        # imshow converts the image to an array immediately, so it is safe
        # to close the file right after.
        with Image.open(image_path) as image:
            plt.imshow(image)
    plt.tight_layout()

### Matriz de confusión

In [4]:
from sklearn import metrics
from matplotlib import pyplot as plt

def predict_dataset(net, dataset):
    """Run ``net`` over every ``(image, label)`` pair of ``dataset``.

    Each image is classified on its own (a batch dimension is added here). The
    predicted class is the arg-max over dim 1 of the model output. The model is
    switched to eval mode, and inference runs without gradient tracking.

    NOTE(review): a later cell of this notebook defines another
    ``predict_dataset`` (a plotting helper with a different signature). Once
    that cell runs, it shadows this one. Consider renaming one of them.

    Returns
    -------
    (list, list)
        The labels exactly as yielded by ``dataset``, and the predicted class
        indices, both in iteration order.
    """
    net.eval()
    # Move inputs to the model's device only when the model advertises one.
    device = getattr(net, 'device', None)
    labels, preds = [], []
    with torch.no_grad():
        for image, label in dataset:
            batch = image.unsqueeze(0)
            if device is not None:
                batch = batch.to(device)
            preds.append(net(batch).argmax(dim=1).item())
            labels.append(label)
    return labels, preds

def confusion_matrix(labels, preds):
    """Draw the confusion matrix of ``preds`` against ``labels`` (orange
    colormap, 10x8 figure), then print scikit-learn's classification report
    for the same pair."""
    _fig, axis = plt.subplots(figsize=(10, 8))
    matrix = metrics.confusion_matrix(labels, preds)
    display = metrics.ConfusionMatrixDisplay(matrix)
    display.plot(cmap='Oranges', ax=axis)
    report = metrics.classification_report(labels, preds)
    print(report)

### Visualizar predicciones sobre muestras de un dataset

In [None]:
from matplotlib import pyplot as plt
from torchvision import transforms

def predict_dataset(net, dataset, n=21, random=True, unnormalize=None, classes_names=None):
    """Plot ``n`` dataset images, each titled ``'<true label> | <prediction>'``.

    NOTE(review): this redefines ``predict_dataset`` from an earlier cell (that
    version returns labels/predictions instead of plotting). Once this cell
    runs, the earlier version is shadowed. Consider renaming one of them.

    Parameters
    ----------
    net : torch.nn.Module
        Classifier; switched to eval mode here. If it exposes a ``device``
        attribute, each image is moved there for the forward pass.
    dataset : indexable
        ``dataset[i]`` returns ``(image, label)``. ``image`` is expected to be
        a channels-first tensor, because it is permuted ``(1, 2, 0)`` for display.
    n : int
        Number of images to show, laid out in rows of at most 7. Nothing is
        drawn when ``n <= 0``.
    random : bool
        Draw indices uniformly at random (with replacement) instead of taking
        the first ``n`` items.
    unnormalize : None, 'imagenet' or (mean, std)
        Undo input normalization for display. ``'imagenet'`` uses the standard
        ImageNet per-channel mean/std. A ``(mean, std)`` pair may be tensors,
        per-channel sequences or scalars.
    classes_names : sequence, optional
        Maps class indices to display names.
    """
    if n <= 0:
        return

    net.eval()
    device = getattr(net, 'device', None)

    # The normalization stats are the same for every image: resolve them once.
    if unnormalize == 'imagenet':
        mean, std = torch.tensor([(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)])
    elif unnormalize:
        # as_tensor lets plain lists/tuples broadcast against the image.
        mean, std = (torch.as_tensor(v, dtype=torch.float32, device='cpu') for v in unnormalize)
    else:
        mean, std = 0, 1

    rows = -(n // -7)     # math.ceil(n / 7)
    cols = -(n // -rows)  # math.ceil(n / rows)
    fig = plt.figure(figsize=(3 * cols, 3 * rows))

    for i in range(n):
        index = torch.randint(0, len(dataset), [1]).item() if random else i
        image, label = dataset[index]

        batch = image.unsqueeze(0)
        if device is not None:
            batch = batch.to(device)
        with torch.no_grad():
            pred = net(batch).argmax(dim=1).item()

        if classes_names:
            label = classes_names[label]
            pred = classes_names[pred]

        # CHW -> HWC for imshow.
        shown = image.detach().cpu().permute(1, 2, 0) * std + mean
        if shown.is_floating_point():
            # imshow clips float images to [0, 1] anyway (with a warning).
            shown = shown.clamp(0, 1)
        if shown.shape[-1] == 1:
            shown = shown[..., 0]  # imshow rejects (H, W, 1); grayscale must be (H, W)

        fig.add_subplot(rows, cols, i + 1)
        plt.imshow(shown)
        plt.axis('off')
        plt.title(f'{label} | {pred}')

    # Must run after the subplots exist; calling it earlier has no effect.
    fig.tight_layout()
        