# Inteligência Computacional em Saúde - Trabalho 02

## Carregamento e Pré-Processamento dos Dados

In [1]:
# imports and setup
import os
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from typing import Tuple, List
from tqdm.notebook import tqdm

from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchsummary import summary
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import Dataset, DataLoader, Subset

from torchvision import models, datasets
from torchvision.transforms import v2 as transforms

In [2]:
# Filesystem locations used throughout the notebook.
DATASET_PATH = './data/bone-marrow-dataset'
CHARTS_PATH = "./charts"
MODELS_PATH = "./models"

# Make sure the output directories for figures and checkpoints exist.
for _output_dir in (CHARTS_PATH, MODELS_PATH):
    os.makedirs(_output_dir, exist_ok=True)

# Class abbreviations for the dataset.
# NOTE(review): not referenced by the cells shown here; `class_names` from
# ImageFolder is what the rest of the notebook uses.
TARGET_CLASSES = ['ABE', 'BAS', 'BLA', 'EOS', 'FGC', 'MMZ', 'MYB', 'NGB', 'NGS', 'PMO']

In [3]:
class TransformedDataset(Dataset):
    """
    Dataset wrapper that applies a transformation pipeline to each image on access.

    Keeping the transforms outside the wrapped dataset lets several subsets of
    the same underlying dataset (e.g. train/val/test ``Subset``s) use different
    pipelines.

    Args:
        data (Dataset): the dataset to wrap, typically a ``torch.utils.data.Subset``
            of an ``ImageFolder``. Its samples must be ``(image, label)`` pairs.
        transforms (callable, optional): applied to each image; labels are
            returned unchanged. ``None`` means images are returned as-is.

    Attributes:
        classes: class names, taken from ``data.dataset.classes`` when ``data``
            wraps another dataset (as a ``Subset`` does), otherwise from
            ``data.classes``; ``None`` when neither exists.
    """
    def __init__(self, data: Dataset, transforms=None) -> None:
        """Store the wrapped dataset and the optional transform pipeline."""
        self._data = data
        self._transforms = transforms
        # A Subset does not expose `classes` itself, so look through to the
        # underlying dataset first and fall back to the wrapped object.
        self.classes = getattr(getattr(data, "dataset", None), "classes", None)
        if self.classes is None:
            self.classes = getattr(data, "classes", None)

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self._data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, label)`` at ``idx``, with the image transformed if a pipeline was given."""
        img, label = self._data[idx]
        # Explicit None check: a transform object's truthiness is not a reliable
        # signal that it should be applied.
        if self._transforms is not None:
            img = self._transforms(img)

        return img, label

In [4]:
# Load the whole image tree (ImageFolder derives class labels from sub-directories).
full_dataset = datasets.ImageFolder(root=DATASET_PATH)

# Per-sample labels drive the stratified splits below.
labels = full_dataset.targets
class_names = full_dataset.classes
indices = list(range(len(full_dataset)))

# Split 1 (stratified): 70% train, 30% held out for validation + test.
train_indices, temp_indices, train_labels, temp_labels = train_test_split(
    indices,
    labels,
    test_size=0.3,
    stratify=labels,
    random_state=42,
)

# Split 2 (stratified): halve the held-out part -> 15% validation, 15% test.
val_indices, test_indices, _, _ = train_test_split(
    temp_indices,
    temp_labels,
    test_size=0.5,
    stratify=temp_labels,
    random_state=42,
)

# Index-based views over the same underlying ImageFolder (no transforms yet).
train_subset, val_subset, test_subset = (
    Subset(full_dataset, split_idx)
    for split_idx in (train_indices, val_indices, test_indices)
)

In [5]:
# ImageNet channel statistics, shared by both pipelines (the pretrained
# ResNet-50 models later in the notebook were trained with this normalization).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light geometric/photometric augmentation, then
# conversion to a normalized float tensor.
# No ToPILImage() step is needed here: ImageFolder's default loader already
# yields PIL images, and v2.ToPILImage passes PIL inputs through unchanged.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: deterministic, with the same resize and normalization as training.
val_test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [6]:
# Attach split-specific pipelines: augmentation for training only,
# deterministic preprocessing for validation and test.
train_dataset = TransformedDataset(train_subset, transforms=train_transforms)
val_dataset, test_dataset = (
    TransformedDataset(split, transforms=val_test_transforms)
    for split in (val_subset, test_subset)
)

In [7]:
# Oversample minority classes: each training sample gets weight
# 1 / (number of training samples in its class) and is drawn with replacement.
y_train = torch.tensor(train_labels)
class_weights = 1. / torch.bincount(y_train)
sample_weights = class_weights[y_train]
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# The sampler replaces shuffling for training; evaluation loaders keep dataset
# order so collected predictions line up with dataset indices.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [8]:
def show_images(dataloader: DataLoader, n_images: int = 10) -> None:
    """Display the first images of one batch drawn from ``dataloader``.

    Images are laid out on a grid of at most two rows and titled with their
    class name, after undoing the ImageNet normalization from the transforms.

    Args:
        dataloader: loader whose ``dataset`` exposes a ``classes`` list and
            whose batches are ``(images, labels)`` with each image a CHW tensor.
        n_images: maximum number of images to display; capped by the batch size.
    """
    # Draw a single batch (for the training loader this goes through its sampler).
    images, labels = next(iter(dataloader))
    class_names = dataloader.dataset.classes

    n_shown = min(n_images, len(images))

    # Two rows when there are at least two images; ceil-divide the columns so
    # odd counts and single images still get a valid subplot grid.
    n_rows = min(2, max(n_shown, 1))
    n_cols = max(1, int(np.ceil(n_shown / n_rows)))

    # Inverse of Normalize(mean, std) with the ImageNet statistics.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    plt.figure(figsize=(16, 8))
    for i in range(n_shown):
        plt.subplot(n_rows, n_cols, i + 1)

        # CHW tensor -> HWC array, de-normalize, clamp to the displayable range.
        img = images[i].numpy().transpose((1, 2, 0))
        img = np.clip(std * img + mean, 0, 1)

        plt.imshow(img)
        plt.title(class_names[labels[i]])
        plt.axis("off")

    plt.suptitle("Visualização de Imagens do DataLoader", fontsize=16)
    plt.show()

In [None]:
# Preview one batch from the class-balanced (sampler-driven) training loader.
show_images(dataloader=train_loader)

In [None]:
# Preview one batch from the validation loader (no augmentation, fixed order).
show_images(dataloader=val_loader)

## Funções Auxiliares 

In [11]:
def train_loop(model: nn.Module, train_loader: DataLoader, loss_fn, optimizer, device: torch.device) -> Tuple[float, List[int], List[int]]:
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: network to optimize; switched to training mode here.
        train_loader: yields ``(inputs, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``.
        optimizer: optimizer over the parameters to update.
        device: device each batch is moved to (the model is assumed to already be there).

    Returns:
        ``(mean per-batch loss, true labels, predicted labels)``; the label
        lists cover every sample seen during the epoch, in iteration order.
    """
    model.train()
    running_loss = 0.0
    y_true: List[int] = []
    y_pred: List[int] = []

    for batch_inputs, batch_labels in tqdm(train_loader, desc="Training"):
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_inputs)
        batch_loss = loss_fn(logits, batch_labels)

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.cpu().item()

        # Keep labels and arg-max predictions on the CPU for sklearn metrics.
        y_true.extend(batch_labels.cpu().numpy())
        y_pred.extend(logits.argmax(dim=1).cpu().numpy())

    return running_loss / len(train_loader), y_true, y_pred

def evaluation_loop(model: nn.Module, loader: DataLoader, loss_fn, device: torch.device) -> Tuple[float, List[int], List[int]]:
    """Evaluate ``model`` on ``loader`` (validation or test) without updating it.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: yields ``(inputs, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``.
        device: device each batch is moved to.

    Returns:
        ``(mean per-batch loss, true labels, predicted labels)`` in loader order.
    """
    model.eval()
    summed_loss = 0
    y_true, y_pred = [], []

    # Inference only: no autograd bookkeeping.
    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(loader, desc="Evaluating"):
            batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)

            logits = model(batch_inputs)
            summed_loss += loss_fn(logits, batch_labels).item()

            y_pred.extend(logits.argmax(dim=1).cpu().numpy())
            y_true.extend(batch_labels.cpu().numpy())

    return summed_loss / len(loader), y_true, y_pred

In [12]:
def show_failure_cases_from_results(dataset, y_true, y_pred, class_names, n_failures=15, seed=None):
    """Plot a random sample of misclassified images.

    ``y_true``/``y_pred`` must be aligned with ``dataset`` index by index, i.e.
    produced by iterating ``dataset`` in order (a DataLoader with
    ``shuffle=False``), because the position of each mismatch is used to fetch
    the image back from ``dataset``.

    Args:
        dataset: indexable dataset returning ``(image_tensor, label)`` with
            normalized CHW image tensors.
        y_true: ground-truth labels, one per dataset sample.
        y_pred: predicted labels, one per dataset sample.
        class_names: sequence mapping label indices to display names.
        n_failures: maximum number of failures to display.
        seed: optional seed for choosing which failures to show; ``None`` uses
            NumPy's global random state.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Positions where the prediction disagrees with the ground truth.
    mismatched_idx = np.flatnonzero(y_true != y_pred)
    n_to_show = min(n_failures, len(mismatched_idx))

    # Nothing to draw: without this guard plt.subplots(0, ...) raises and the
    # loop that hides unused axes would reference an unbound index.
    if n_to_show <= 0:
        print("No misclassified samples to display.")
        return

    rng = np.random if seed is None else np.random.default_rng(seed)
    failures_to_show_idx = rng.choice(mismatched_idx, size=n_to_show, replace=False)

    # Inverse of the ImageNet Normalize step used by the transform pipelines.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    cols = 5
    rows = int(np.ceil(n_to_show / cols))
    # squeeze=False always yields a 2-D array of axes, whatever the grid shape.
    _, axes = plt.subplots(rows, cols, figsize=(20, 3 * rows), squeeze=False)
    axes = axes.flatten()

    for ax, idx in zip(axes, failures_to_show_idx):
        image_tensor, true_label_idx = dataset[idx]
        pred_label_idx = y_pred[idx]

        # CHW tensor -> HWC array, de-normalize, clamp to [0, 1] for imshow.
        img = image_tensor.detach().cpu().numpy().transpose((1, 2, 0))
        img = np.clip(std * img + mean, 0, 1)

        ax.imshow(img)
        ax.set_title(f"True: {class_names[true_label_idx]}\nPred: {class_names[pred_label_idx]}", color='red', fontsize=10)
        ax.axis('off')

    # Hide the unused cells of the last grid row.
    for ax in axes[n_to_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## CNN Customizada

In [None]:
# Custom CNN: four Conv-BN-ReLU-MaxPool stages that double the channel count
# (3 -> 16 -> 32 -> 64 -> 128), followed by an MLP classifier head.
conv_channels = [3, 16, 32, 64, 128]
feature_layers = []
for c_in, c_out in zip(conv_channels[:-1], conv_channels[1:]):
    feature_layers += [
        nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(c_out),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
    ]

# Each of the four 2x2 poolings halves the 224x224 input: 224 / 2**4 = 14.
custom_cnn = nn.Sequential(
    *feature_layers,
    nn.Flatten(),
    nn.Linear(in_features=128 * 14 * 14, out_features=512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(in_features=512, out_features=len(class_names)),
)

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
custom_cnn = custom_cnn.to(device)

# Layer-by-layer summary for a 3x224x224 input.
summary(custom_cnn, input_size=(3, 224, 224), device=device.type)

In [None]:
# Loss and optimizer: the whole custom network is trained from scratch.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(custom_cnn.parameters(), lr=1e-3)

# Training budget, early-stopping settings and checkpoint location, each
# defined once so the loop and its log messages cannot drift apart.
num_epochs = 20
patience = 5
checkpoint_name = "best_custom_cnn.pth"
checkpoint_path = f"{MODELS_PATH}/{checkpoint_name}"

best_epoch = 0  # 0-based index of the epoch with the lowest validation loss
epochs_no_improve = 0
best_val_loss = float('inf')

history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

for epoch in range(num_epochs):
    # training step (over the sampler-resampled training set)
    train_loss, true_train, pred_train = train_loop(custom_cnn, train_loader, loss_fn, optimizer, device)
    train_acc = accuracy_score(true_train, pred_train)

    # validation step
    val_loss, true_val, pred_val = evaluation_loop(custom_cnn, val_loader, loss_fn, device)
    val_acc = accuracy_score(true_val, pred_val)

    # store metrics
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    print(f"Epoch {epoch+1}/{num_epochs} Summary | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"Epoch {epoch+1}/{num_epochs} Summary | Train Acc: {train_acc:.4f}  | Val Acc: {val_acc:.4f}")

    # Checkpoint on every validation-loss improvement; otherwise spend patience.
    if val_loss < best_val_loss:
        best_epoch = epoch
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(custom_cnn.state_dict(), checkpoint_path)
        print(f"Validation loss improved. Saving model to {checkpoint_name}")
    else:
        epochs_no_improve += 1
        print(f"Validation loss did not improve. Patience: {epochs_no_improve}/{patience}")

    if epochs_no_improve >= patience:
        print(f"\nEarly stopping triggered after {patience} epochs with no improvement.")
        break

In [None]:
# Learning curves for the custom CNN. Epochs are numbered from 1 so the x-axis,
# the console log ("Epoch N/...") and the best-model legend entry agree.
epoch_numbers = range(1, len(history['train_loss']) + 1)
best_epoch_number = best_epoch + 1  # best_epoch holds the 0-based loop index

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# plot loss values
ax1.plot(epoch_numbers, history['train_loss'], label='Loss de Treinamento', marker='o', color='red')
ax1.plot(epoch_numbers, history['val_loss'], label='Loss de Validação', marker='o', color='blue')
ax1.axvline(x=best_epoch_number, color='g', linestyle='--', label=f'Melhor Modelo (Época {best_epoch_number})')
ax1.set_title("Evolução da Loss por Época")
ax1.set_xlabel('Épocas')
ax1.set_ylabel('Loss')
ax1.legend()

# plot accuracy values
ax2.plot(epoch_numbers, history['train_acc'], label='Acurácia de Treinamento', marker='o', color='red')
ax2.plot(epoch_numbers, history['val_acc'], label='Acurácia de Validação', marker='o', color='blue')
ax2.axvline(x=best_epoch_number, color='g', linestyle='--', label=f'Melhor Modelo (Época {best_epoch_number})')
ax2.set_title('Evolução da Acurácia por Época')
ax2.set_xlabel('Épocas')
ax2.set_ylabel('Acurácia')
ax2.legend()

plt.savefig(f"{CHARTS_PATH}/custom_cnn_training.svg", format='svg', bbox_inches='tight')
plt.show()

In [None]:
# Restore the checkpoint with the lowest validation loss, then score the test set.
# map_location keeps loading valid on the current device; weights_only=True
# restricts unpickling to tensors and plain containers (enough for a state_dict).
custom_cnn.load_state_dict(
    torch.load(f"{MODELS_PATH}/best_custom_cnn.pth", map_location=device, weights_only=True)
)
custom_cnn_test_loss, custom_cnn_true_test, custom_cnn_pred_test = evaluation_loop(custom_cnn, test_loader, loss_fn, device)
# Per-class precision/recall/F1; undefined ratios are reported as 0.
print(classification_report(custom_cnn_true_test, custom_cnn_pred_test, target_names=class_names, zero_division=0))

In [None]:
# Row-normalized (normalize='true') confusion matrix of the custom CNN on the test set.
fig, ax = plt.subplots(figsize=(10, 8))
ConfusionMatrixDisplay.from_predictions(
    y_true=custom_cnn_true_test,
    y_pred=custom_cnn_pred_test,
    display_labels=class_names,
    normalize='true',
    cmap='Reds',
    ax=ax,
)
# Override the default English axis labels with the report's captions.
ax.set_title('CNN Customizada - Matriz de Confusão')
ax.set_xlabel('Classe Prevista')
ax.set_ylabel('Classe Real')

plt.savefig(f"{CHARTS_PATH}/custom_cnn_confusion_matrix.svg", format='svg', bbox_inches='tight')
plt.show()

In [None]:
# Random sample of misclassified test images for the custom CNN.
show_failure_cases_from_results(
    dataset=test_dataset,
    y_true=custom_cnn_true_test,
    y_pred=custom_cnn_pred_test,
    class_names=class_names,
)

## CNN Pré-Treinada com Camadas Convolucionais Congeladas

In [None]:
# ResNet-50 with ImageNet weights, used as a fixed feature extractor.
pretrained_model = models.resnet50(weights="IMAGENET1K_V2")

# Freeze every pretrained parameter so that only the new head can learn.
# NOTE(review): requires_grad=False does not freeze BatchNorm running
# statistics. train_loop calls model.train(), so those buffers still update
# during training. Confirm whether that is intended.
pretrained_model.requires_grad_(False)

# Swap the original classification layer for one sized to this dataset. The
# new layer is created after the freeze, so its parameters remain trainable.
num_features = pretrained_model.fc.in_features
pretrained_model.fc = nn.Linear(num_features, len(class_names))

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_model = pretrained_model.to(device)

# Layer-by-layer summary for a 3x224x224 input.
summary(pretrained_model, input_size=(3, 224, 224), device=device.type)

In [None]:
# Loss and optimizer: only the new classification head is optimized.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(pretrained_model.fc.parameters(), lr=1e-3)

# Training budget, early-stopping settings and checkpoint location, each
# defined once so the loop and its log messages cannot drift apart.
num_epochs = 20
patience = 5
checkpoint_name = "best_frozen_pretrained_model.pth"
checkpoint_path = f"{MODELS_PATH}/{checkpoint_name}"

best_epoch = 0  # 0-based index of the epoch with the lowest validation loss
epochs_no_improve = 0
best_val_loss = float('inf')

history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

for epoch in range(num_epochs):
    # training step (over the sampler-resampled training set)
    train_loss, true_train, pred_train = train_loop(pretrained_model, train_loader, loss_fn, optimizer, device)
    train_acc = accuracy_score(true_train, pred_train)

    # validation step
    val_loss, true_val, pred_val = evaluation_loop(pretrained_model, val_loader, loss_fn, device)
    val_acc = accuracy_score(true_val, pred_val)

    # store metrics
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    print(f"Epoch {epoch+1}/{num_epochs} Summary | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"Epoch {epoch+1}/{num_epochs} Summary | Train Acc: {train_acc:.4f}  | Val Acc: {val_acc:.4f}")

    # Checkpoint on every validation-loss improvement; otherwise spend patience.
    if val_loss < best_val_loss:
        best_epoch = epoch
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(pretrained_model.state_dict(), checkpoint_path)
        print(f"Validation loss improved. Saving model to {checkpoint_name}")
    else:
        epochs_no_improve += 1
        print(f"Validation loss did not improve. Patience: {epochs_no_improve}/{patience}")

    if epochs_no_improve >= patience:
        print(f"\nEarly stopping triggered after {patience} epochs with no improvement.")
        break

In [None]:
# Learning curves for the frozen pretrained model. Epochs are numbered from 1
# so the x-axis, the console log ("Epoch N/...") and the best-model legend agree.
epoch_numbers = range(1, len(history['train_loss']) + 1)
best_epoch_number = best_epoch + 1  # best_epoch holds the 0-based loop index

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# plot loss values
ax1.plot(epoch_numbers, history['train_loss'], label='Loss de Treinamento', marker='o', color='red')
ax1.plot(epoch_numbers, history['val_loss'], label='Loss de Validação', marker='o', color='blue')
ax1.axvline(x=best_epoch_number, color='g', linestyle='--', label=f'Melhor Modelo (Época {best_epoch_number})')
ax1.set_title("Evolução da Loss por Época")
ax1.set_xlabel('Épocas')
ax1.set_ylabel('Loss')
ax1.legend()

# plot accuracy values
ax2.plot(epoch_numbers, history['train_acc'], label='Acurácia de Treinamento', marker='o', color='red')
ax2.plot(epoch_numbers, history['val_acc'], label='Acurácia de Validação', marker='o', color='blue')
ax2.axvline(x=best_epoch_number, color='g', linestyle='--', label=f'Melhor Modelo (Época {best_epoch_number})')
ax2.set_title('Evolução da Acurácia por Época')
ax2.set_xlabel('Épocas')
ax2.set_ylabel('Acurácia')
ax2.legend()

# NOTE(review): the output filename keeps the existing "pretained" spelling so
# that anything already referencing this chart still finds it; rename to
# "pretrained" if nothing depends on it.
plt.savefig(f"{CHARTS_PATH}/frozen_pretained_model_training.svg", format='svg', bbox_inches='tight')
plt.show()

In [None]:
# Restore the checkpoint with the lowest validation loss, then score the test set.
# map_location keeps loading valid on the current device; weights_only=True
# restricts unpickling to tensors and plain containers (enough for a state_dict).
pretrained_model.load_state_dict(
    torch.load(f"{MODELS_PATH}/best_frozen_pretrained_model.pth", map_location=device, weights_only=True)
)
pretrained_model_test_loss, pretrained_model_true_test, pretrained_model_pred_test = evaluation_loop(pretrained_model, test_loader, loss_fn, device)
# zero_division=0 matches the other models' reports: classes that are never
# predicted report 0 instead of raising UndefinedMetricWarning.
print(classification_report(pretrained_model_true_test, pretrained_model_pred_test, target_names=class_names, zero_division=0))

In [None]:
# Row-normalized (normalize='true') confusion matrix of the frozen pretrained model on the test set.
fig, ax = plt.subplots(figsize=(10, 8))
ConfusionMatrixDisplay.from_predictions(
    y_true=pretrained_model_true_test,
    y_pred=pretrained_model_pred_test,
    display_labels=class_names,
    normalize='true',
    cmap='Reds',
    ax=ax,
)
# Override the default English axis labels with the report's captions.
ax.set_title('CNN Pré-Treinada - Matriz de Confusão')
ax.set_xlabel('Classe Prevista')
ax.set_ylabel('Classe Real')

plt.savefig(f"{CHARTS_PATH}/pretrained_confusion_matrix.svg", format='svg', bbox_inches='tight')
plt.show()

In [None]:
# Random sample of misclassified test images for the frozen pretrained model.
show_failure_cases_from_results(
    dataset=test_dataset,
    y_true=pretrained_model_true_test,
    y_pred=pretrained_model_pred_test,
    class_names=class_names,
)

## CNN Pré-Treinada com *Fine-Tuning* de Todas as Camadas

In [None]:
# Fresh ResNet-50 with ImageNet weights for the fine-tuning experiment.
finetuned_model = models.resnet50(weights="IMAGENET1K_V2")

# New classification layer sized to this dataset.
num_features = finetuned_model.fc.in_features
finetuned_model.fc = nn.Linear(num_features, len(class_names))

# Start with only the new head trainable: freeze everything, then re-enable the head.
finetuned_model.requires_grad_(False)
finetuned_model.fc.requires_grad_(True)

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
finetuned_model = finetuned_model.to(device)

# Layer-by-layer summary for a 3x224x224 input.
summary(finetuned_model, input_size=(3, 224, 224), device=device.type)

In [None]:
# Warm up the newly created head while the pretrained backbone stays frozen.
loss_fn = nn.CrossEntropyLoss()
optimizer_head = optim.Adam(finetuned_model.fc.parameters(), lr=1e-3)

warmup_epochs = 5
for epoch in range(warmup_epochs):
    train_loss, _, _ = train_loop(finetuned_model, train_loader, loss_fn, optimizer_head, device)
    val_loss, _, _ = evaluation_loop(finetuned_model, val_loader, loss_fn, device)
    print(f"Epoch [Head] {epoch+1}/{warmup_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

In [None]:
# Unfreeze the whole network and fine-tune it end to end with a small
# learning rate (1e-5) for all parameters.
finetuned_model.requires_grad_(True)
optimizer_finetune = optim.Adam(finetuned_model.parameters(), lr=1e-5)

# Training budget, early-stopping settings and checkpoint location, each
# defined once so the loop and its log messages cannot drift apart.
num_epochs = 20
patience = 5
checkpoint_name = "best_finetuned_model.pth"
checkpoint_path = f"{MODELS_PATH}/{checkpoint_name}"

best_epoch = 0  # 0-based index of the epoch with the lowest validation loss
epochs_no_improve = 0
best_val_loss = float('inf')

history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

for epoch in range(num_epochs):
    # training step (over the sampler-resampled training set)
    train_loss, true_train, pred_train = train_loop(finetuned_model, train_loader, loss_fn, optimizer_finetune, device)
    train_acc = accuracy_score(true_train, pred_train)

    # validation step
    val_loss, true_val, pred_val = evaluation_loop(finetuned_model, val_loader, loss_fn, device)
    val_acc = accuracy_score(true_val, pred_val)

    # store metrics
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    print(f"Epoch {epoch+1}/{num_epochs} Summary | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"Epoch {epoch+1}/{num_epochs} Summary | Train Acc: {train_acc:.4f}  | Val Acc: {val_acc:.4f}")

    # Checkpoint on every validation-loss improvement; otherwise spend patience.
    if val_loss < best_val_loss:
        best_epoch = epoch
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(finetuned_model.state_dict(), checkpoint_path)
        print(f"Validation loss improved. Saving model to {checkpoint_name}")
    else:
        epochs_no_improve += 1
        print(f"Validation loss did not improve. Patience: {epochs_no_improve}/{patience}")

    if epochs_no_improve >= patience:
        print(f"\nEarly stopping triggered after {patience} epochs with no improvement.")
        break

In [None]:
# Learning curves for the fine-tuned model. Epochs are numbered from 1 so the
# x-axis, the console log ("Epoch N/...") and the best-model legend agree.
epoch_numbers = range(1, len(history['train_loss']) + 1)
best_epoch_number = best_epoch + 1  # best_epoch holds the 0-based loop index

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# plot loss values
ax1.plot(epoch_numbers, history['train_loss'], label='Loss de Treinamento', marker='o', color='red')
ax1.plot(epoch_numbers, history['val_loss'], label='Loss de Validação', marker='o', color='blue')
ax1.axvline(x=best_epoch_number, color='g', linestyle='--', label=f'Melhor Modelo (Época {best_epoch_number})')
ax1.set_title("Evolução da Loss por Época")
ax1.set_xlabel('Épocas')
ax1.set_ylabel('Loss')
ax1.legend()

# plot accuracy values
ax2.plot(epoch_numbers, history['train_acc'], label='Acurácia de Treinamento', marker='o', color='red')
ax2.plot(epoch_numbers, history['val_acc'], label='Acurácia de Validação', marker='o', color='blue')
ax2.axvline(x=best_epoch_number, color='g', linestyle='--', label=f'Melhor Modelo (Época {best_epoch_number})')
ax2.set_title('Evolução da Acurácia por Época')
ax2.set_xlabel('Épocas')
ax2.set_ylabel('Acurácia')
ax2.legend()

plt.savefig(f"{CHARTS_PATH}/finetuned_model_training.svg", format='svg', bbox_inches='tight')
plt.show()

In [None]:
# Restore the checkpoint with the lowest validation loss, then score the test set.
# map_location keeps loading valid on the current device; weights_only=True
# restricts unpickling to tensors and plain containers (enough for a state_dict).
finetuned_model.load_state_dict(
    torch.load(f"{MODELS_PATH}/best_finetuned_model.pth", map_location=device, weights_only=True)
)
finetuned_model_test_loss, finetuned_model_true_test, finetuned_model_pred_test = evaluation_loop(finetuned_model, test_loader, loss_fn, device)
# Per-class precision/recall/F1; undefined ratios are reported as 0.
print(classification_report(finetuned_model_true_test, finetuned_model_pred_test, target_names=class_names, zero_division=0))

In [None]:
# Row-normalized (normalize='true') confusion matrix of the fine-tuned model on the test set.
fig, ax = plt.subplots(figsize=(10, 8))
ConfusionMatrixDisplay.from_predictions(
    y_true=finetuned_model_true_test,
    y_pred=finetuned_model_pred_test,
    display_labels=class_names,
    normalize='true',
    cmap='Reds',
    ax=ax,
)
# Override the default English axis labels with the report's captions.
ax.set_title('CNN com Fine-Tuning - Matriz de Confusão')
ax.set_xlabel('Classe Prevista')
ax.set_ylabel('Classe Real')

plt.savefig(f"{CHARTS_PATH}/finetuned_confusion_matrix.svg", format='svg', bbox_inches='tight')
plt.show()

In [None]:
# Random sample of misclassified test images for the fine-tuned model.
show_failure_cases_from_results(
    dataset=test_dataset,
    y_true=finetuned_model_true_test,
    y_pred=finetuned_model_pred_test,
    class_names=class_names,
)