In [None]:
import os
import pathlib
import random

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from tqdm import tqdm

from efficient_net import Efficient_Net
from get_data import get_dataloader
from get_data_v2 import get_dataloader_v2
from train import get_needed_metrics

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Root directory of the test split. Set the TEST_DATA_PATH environment variable
# to point at a different location without editing the notebook.
test_path = os.environ.get('TEST_DATA_PATH', '/home/shirshak/Thesis_Data/DOES/TEST_raw/')

In [None]:
# DataLoader over the test split. get_path=True is what lets the evaluation loop
# unpack an image_path per batch; shuffle=False keeps the iteration order fixed.
test_loader = get_dataloader(
    test_path,
    batch_size=128,
    get_path=True,
    shuffle=False,
)

In [None]:
# Class names are the sorted sub-directory names of the test root; the plotting
# cells index this list with integer labels (classes[label]).
# Plain files in the root (e.g. a stray .DS_Store) are skipped so they cannot
# shift the class indices.
# NOTE(review): this assumes get_dataloader assigns labels the same way
# (sorted sub-directories, as torchvision's ImageFolder does). Confirm.
root = pathlib.Path(test_path)
classes = sorted(entry.name for entry in root.iterdir() if entry.is_dir())
classes

In [None]:
# Build the classifier and restore its weights from the saved checkpoint
# (the "best val F1" checkpoint, judging by its filename).
model = Efficient_Net(classes=classes)
loss_func = nn.CrossEntropyLoss()

model.to(device)
# Override with the CHECKPOINT_PATH environment variable to evaluate another checkpoint.
checkpoint_path = os.environ.get(
    'CHECKPOINT_PATH', "/home/shirshak/Thesis_Classification_Code/model_best_val_f1.pth"
)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU also loads on a CPU-only machine.
# weights_only=False unpickles arbitrary Python objects, so only load
# checkpoints that come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])


In [None]:
# Evaluate the model on the whole test set. Also collect, per test image, the
# transformed input, true label, predicted label and file path for the analysis
# cells below.
model.eval()

overall_input, overall_labels, overall_predicted, overall_image_path = [], [], [], []
total_loss, total_samples = 0.0, 0
with torch.no_grad():
    for inputs, labels, image_path in tqdm(test_loader, desc='Testing', unit='batch'):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # nn.CrossEntropyLoss() defaults to reduction='mean' (per-batch average).
        # Weight it by batch size so the dataset mean is exact even when the
        # final batch is smaller.
        test_loss = loss_func(outputs, labels)
        batch_size = labels.size(0)
        total_loss += test_loss.item() * batch_size
        total_samples += batch_size
        _, predicted_test = torch.max(outputs, 1)

        # Move everything to the CPU so the whole test set is not kept alive in
        # GPU memory. Labels and predictions are stored as plain ints.
        overall_input.extend(inputs.cpu())
        overall_labels.extend(labels.cpu().tolist())
        overall_predicted.extend(predicted_test.cpu().tolist())
        overall_image_path.extend(image_path)

if total_samples == 0:
    raise RuntimeError(f'test_loader yielded no samples; check test_path={test_path!r}')

# Compute the metrics once over all test samples. Averaging per-batch
# precision/recall/F1 does not equal the dataset-level value, and it gives the
# short final batch the same weight as a full batch.
acc_test, precision_test, recall_test, f1_test = get_needed_metrics(overall_labels, overall_predicted)
mean_test_loss = total_loss / total_samples

print(
    f'Test Loss: {mean_test_loss:.4f}, '
    f'Test Accuracy: {float(acc_test) * 100:.2f}%, '
    f'Test Precision: {float(precision_test) * 100:.2f}%, '
    f'Test Recall: {float(recall_test) * 100:.2f}%, '
    f'Test F1: {float(f1_test) * 100:.2f}%')

In [None]:
# Inspect the collected true labels and predicted labels (one entry per test sample, in loader order).
overall_labels, overall_predicted

In [None]:
# Confusion matrix over the full test set, saved to confusion_matrix.png.
# Short display names, one per class, in the same (sorted) order as `classes`.
cm_display_labels = ['BG', 'E1', 'E2', 'E3', 'E40', 'E5H', 'E6', 'E8', 'EHRB']
assert len(cm_display_labels) == len(classes), (
    f'{len(cm_display_labels)} display labels for {len(classes)} classes: {classes}'
)

y_true = torch.tensor(overall_labels).cpu()
y_pred = torch.tensor(overall_predicted).cpu()
# Passing `labels` explicitly keeps the matrix len(classes) x len(classes) even
# if a class never occurs in y_true or y_pred. Without it the matrix shrinks and
# no longer lines up with the display labels.
confusion_matrix_chart = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
cm_display = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix_chart, display_labels=cm_display_labels)

fig, ax = plt.subplots()
cm_display.plot(ax=ax)
ax.set_title("Confusion Matrix")
fig.savefig('confusion_matrix.png', dpi=500)
plt.show()
plt.close(fig)

In [None]:
# Partition every test sample into correctly and incorrectly classified cases.
# Each case tuple is (transformed_input, true_label, predicted_label, image_path).
right_cases, wrong_cases = [], []
for case in zip(overall_input, overall_labels, overall_predicted, overall_image_path):
    if case[1] == case[2]:
        right_cases.append(case)
    else:
        wrong_cases.append(case)

In [None]:
# Show the class-name list; the plots below caption labels with classes[label].
classes

In [None]:
# Plot a reproducible random sample of correctly classified test images:
# the original file on the left, the transformed model input on the right.
n_right_examples = 20
right_save_dir = None  # set to a directory path to also save each figure as correct{count}.jpg
right_rng = random.Random(0)  # fixed seed so re-running the cell shows the same images

some_right_examples = right_rng.sample(right_cases, min(n_right_examples, len(right_cases)))

for count, (transformed_img, true_label, pred_label, img_path) in enumerate(some_right_examples):
    fig, ax = plt.subplots(1, 2, figsize=(5, 5))

    # The context manager closes the image file once imshow has read the pixels.
    with Image.open(img_path) as original_img:
        ax[0].imshow(original_img)
    ax[0].set_title('Original Image')
    ax[0].axis('off')
    ax[0].text(0.5, -0.1, f'Real : {classes[true_label]}', ha='center', va='center', transform=ax[0].transAxes, fontsize=10)

    ax[1].imshow(torchvision.transforms.ToPILImage()(transformed_img))
    ax[1].set_title('Transformed Image')
    ax[1].axis('off')
    ax[1].text(0.5, -0.1, f'Predicted : {classes[pred_label]}', ha='center', va='center', transform=ax[1].transAxes, fontsize=10)

    fig.tight_layout()
    if right_save_dir is not None:
        # Save before plt.show() so the saved file is never an emptied figure.
        fig.savefig(pathlib.Path(right_save_dir) / f"correct{count}.jpg")
    plt.show()
    plt.close(fig)

In [None]:
# Plot a reproducible random sample of misclassified test images:
# the original file on the left, the transformed model input on the right.
n_wrong_examples = 20
wrong_save_dir = None  # set to a directory path to also save each figure as wrong{count}.jpg
wrong_rng = random.Random(0)  # fixed seed so re-running the cell shows the same images

some_wrong_examples = wrong_rng.sample(wrong_cases, min(n_wrong_examples, len(wrong_cases)))

for count, (transformed_img, true_label, pred_label, img_path) in enumerate(some_wrong_examples):
    fig, ax = plt.subplots(1, 2, figsize=(5, 5))

    # The context manager closes the image file once imshow has read the pixels.
    with Image.open(img_path) as original_img:
        ax[0].imshow(original_img)
    ax[0].set_title('Original Image')
    ax[0].axis('off')
    ax[0].text(0.5, -0.1, f'Real : {classes[true_label]}', ha='center', va='center', transform=ax[0].transAxes, fontsize=10)

    ax[1].imshow(torchvision.transforms.ToPILImage()(transformed_img))
    ax[1].set_title('Transformed Image')
    ax[1].axis('off')
    ax[1].text(0.5, -0.1, f'Predicted : {classes[pred_label]}', ha='center', va='center', transform=ax[1].transAxes, fontsize=10)

    fig.tight_layout()
    if wrong_save_dir is not None:
        # Save before plt.show(). The original commented-out savefig ran after
        # show(), which can write a blank image.
        fig.savefig(pathlib.Path(wrong_save_dir) / f"wrong{count}.jpg")
    plt.show()
    plt.close(fig)