In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import glob
import os
import json
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import re
import warnings
import numba
warnings.filterwarnings('ignore')

In [None]:
# Mount Google Drive (Colab-only module) so files under MyDrive are reachable from this runtime.
from google.colab import drive
print("Montando o Google Drive...")
drive.mount('/content/drive')
print("Drive montado com sucesso!")
# Project root folder (TCC) inside the mounted Drive.
BASE_PATH = "/content/drive/MyDrive/TCC"


Montando o Google Drive...
Mounted at /content/drive
Drive montado com sucesso!


In [None]:

class FMRISkimmedDataset(Dataset):
    """
    Dataset of fMRI slice images extracted with a skimming algorithm.

    Expected layout under ``base_path``::

        sub-<participant_id>/<orientation>/slice_<orientation>_<rank>_idx<index>.png
        e.g. sub-10159/axial/slice_axial_01_idx022.png

    ``__getitem__`` returns ``(images, label)``: ``images`` is ``torch.stack``
    (dim 0) of the ``slices_per_view`` lowest-ranked slices of each
    orientation, in the order sagital, coronal, axial, after ``transform``;
    ``label`` comes from ``label_map`` (CONTROL -> 0, SCHZ -> 1). On failure
    it returns ``(None, None)``, which PyTorch's default collate function
    cannot batch.
    """

    # Orientation sub-folders, in the order their slices are stacked.
    ORIENTATIONS = ('sagital', 'coronal', 'axial')

    def __init__(self, metadata_df, base_path, transform=None, slices_per_view=2, validate=True):
        """
        Args:
            metadata_df: DataFrame with (at least) ``participant_id`` and
                ``diagnosis`` columns.
            base_path: Directory that contains the ``sub-<id>`` folders
                (i.e. it already points at the slices folder).
            transform: Optional callable applied to each grayscale PIL image.
                Slices are compared by ``.shape`` and ``torch.stack``-ed, so it
                must return tensors.
            slices_per_view: Number of slices loaded per orientation.
            validate: If True (default), drop participants that lack at least
                ``slices_per_view`` ranked slices in every orientation.
        """
        self.metadata_df = metadata_df.reset_index(drop=True)
        self.base_path = base_path
        self.transform = transform
        self.slices_per_view = slices_per_view
        self.label_map = {'CONTROL': 0, 'SCHZ': 1}

        if validate:
            self._validate_dataset()

    def __len__(self):
        """Number of participants kept in ``metadata_df``."""
        return len(self.metadata_df)

    def _has_enough_slices(self, participant_id):
        """True if every orientation holds enough ranked slices for one item."""
        # At least one slice is always required, even for slices_per_view <= 0.
        needed = max(self.slices_per_view, 1)
        return all(
            len(self._get_ranked_slices(participant_id, orientation, needed)) >= needed
            for orientation in self.ORIENTATIONS
        )

    def _validate_dataset(self):
        """
        Keep only participants with at least ``slices_per_view`` ranked slices
        in every orientation, and report how many were removed.

        Uses the same file lookup as ``__getitem__``, so every kept participant
        yields a complete stack. Previously a single file per orientation was
        enough, which let short stacks through and broke batching later.
        """
        valid_positions = [
            pos
            for pos, participant_id in enumerate(self.metadata_df['participant_id'])
            if self._has_enough_slices(participant_id)
        ]

        original_size = len(self.metadata_df)
        self.metadata_df = self.metadata_df.iloc[valid_positions].reset_index(drop=True)
        removed = original_size - len(self.metadata_df)

        if removed > 0:
            print(f"{removed} pacientes removidos (sem slices processados)")
        print(f"‚úì {len(self.metadata_df)} pacientes v√°lidos no dataset")

    def _get_ranked_slices(self, participant_id, orientation, n_slices):
        """
        Return the paths of the ``n_slices`` lowest-ranked slices of one
        orientation.

        The rank is the ``XX`` in ``slice_<orientation>_XX_idxYYY.png``,
        compared numerically. Files without a numeric rank are ignored. An
        empty list is returned when the folder is missing or has no match.
        """
        orientation_dir = os.path.join(self.base_path, f"sub-{participant_id}", orientation)
        pattern = os.path.join(orientation_dir, f"slice_{orientation}_*_idx*.png")
        rank_re = re.compile(rf'slice_{re.escape(orientation)}_(\d+)_idx')

        ranked = []
        for filepath in glob.glob(pattern):
            match = rank_re.search(os.path.basename(filepath))
            if match:
                ranked.append((int(match.group(1)), filepath))

        # Sort by rank, then by path, so ties resolve deterministically
        # (glob order is filesystem-dependent).
        ranked.sort()
        return [filepath for _, filepath in ranked[:n_slices]]

    def __getitem__(self, idx):
        """
        Load all slices of participant ``idx`` and return ``(images, label)``.

        Returns ``(None, None)`` when no slice could be loaded, when the loaded
        slices have different shapes, or when stacking fails.
        """
        participant_id = self.metadata_df['participant_id'].iloc[idx]
        label = self.label_map[self.metadata_df['diagnosis'].iloc[idx]]

        orientations = list(self.ORIENTATIONS)
        all_slices = []

        for orientation in orientations:
            slice_paths = self._get_ranked_slices(participant_id, orientation, self.slices_per_view)

            if len(slice_paths) < self.slices_per_view:
                print(f"Participante {participant_id}, Orienta√ß√£o {orientation}: Encontrados {len(slice_paths)} slices, esperado {self.slices_per_view}")

            for img_path in slice_paths:
                try:
                    # The context manager closes the file handle; convert('L') yields grayscale.
                    with Image.open(img_path) as pil_img:
                        img = pil_img.convert('L')
                    if self.transform is not None:
                        img = self.transform(img)
                except Exception as e:
                    # Best effort: report the unreadable slice and continue with the rest.
                    print(f"Erro ao carregar slice {img_path}: {e}")
                    continue
                all_slices.append(img)

        expected_total_slices = self.slices_per_view * len(orientations)
        if len(all_slices) != expected_total_slices:
            print(f"Participante {participant_id}: Total de slices {len(all_slices)}, esperado {expected_total_slices}")

        if not all_slices:
            print(f"Participante {participant_id}: NENHUM slice carregado!")
            return None, None

        # torch.stack needs identical shapes; report instead of raising.
        first_shape = all_slices[0].shape
        if any(t.shape != first_shape for t in all_slices):
            print(f"Participante {participant_id}: Slices com shapes diferentes! {[t.shape for t in all_slices]}")
            return None, None

        try:
            images = torch.stack(all_slices, dim=0)
        except Exception as e:
            print(f"Erro ao empilhar slices para {participant_id}: {e}")
            print(f"Shapes dos slices: {[t.shape for t in all_slices]}")
            return None, None

        return images, label

In [None]:
class MultiSliceViT(nn.Module):
    """
    Multi-slice classifier in which a shared torchvision ViT encodes each slice
    separately. The per-slice features are concatenated, and an MLP produces
    the class logits.

    NOTE: a later cell of this notebook redefines ``MultiSliceViT`` (adding an
    optional slice Transformer). This version is only used if that cell is not
    run.
    """
    def __init__(self, num_slices=6, num_classes=2, vit_model='vit_b_16', pretrained=True):
        """
        Args:
            num_slices: Slices per sample. Sets the MLP input size
                (``feature_dim * num_slices``).
            num_classes: Number of output logits.
            vit_model: One of 'vit_b_16', 'vit_b_32' or 'vit_l_16' (torchvision).
            pretrained: Load torchvision's default weights and adapt the patch
                embedding to one channel by summing its RGB kernels.

        Raises:
            ValueError: For an unsupported ``vit_model``.
        """
        super(MultiSliceViT, self).__init__()

        self.num_slices = num_slices

        # Supported torchvision backbones and the size of their output features.
        feature_dims = {'vit_b_16': 768, 'vit_b_32': 768, 'vit_l_16': 1024}
        if vit_model not in feature_dims:
            raise ValueError(f"Modelo {vit_model} n√£o suportado")
        vit = getattr(models, vit_model)(weights='DEFAULT' if pretrained else None)
        feature_dim = feature_dims[vit_model]

        # Grayscale input: rebuild the patch embedding with a single input channel.
        # NOTE: the rebuilt conv keeps the bias. state_dicts saved before this
        # change lack 'encoder.conv_proj.bias' and need load_state_dict(strict=False).
        vit.conv_proj = self._grayscale_conv_from_rgb(vit.conv_proj, pretrained)

        # Remove the original classification head so the encoder returns features.
        vit.heads = nn.Identity()

        self.encoder = vit

        # Aggregation MLP over the concatenated per-slice features.
        self.aggregation = nn.Sequential(
            nn.Linear(feature_dim * num_slices, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    @staticmethod
    def _grayscale_conv_from_rgb(rgb_conv, pretrained=True):
        """
        Build a single-input-channel copy of ``rgb_conv``.

        The copy has the same out_channels, kernel_size, stride and padding,
        and a bias exactly when ``rgb_conv`` has one.

        When ``pretrained`` is True, the RGB kernels are summed over the input
        channel axis and the bias is copied. Applying the new conv to a
        grayscale image then equals applying ``rgb_conv`` to that image
        replicated on all three channels. The old code used ``bias=False`` and
        dropped the pretrained bias.
        """
        has_bias = rgb_conv.bias is not None
        gray_conv = nn.Conv2d(
            1,
            rgb_conv.out_channels,
            kernel_size=rgb_conv.kernel_size,
            stride=rgb_conv.stride,
            padding=rgb_conv.padding,
            bias=has_bias
        )
        if pretrained:
            with torch.no_grad():
                gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
                if has_bias:
                    gray_conv.bias.copy_(rgb_conv.bias)
        return gray_conv

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, num_slices, C, H, W), with C == 1
               to match the rebuilt patch embedding. num_slices must equal the
               constructor value, or the MLP input size will not match.

        Returns:
            Logits tensor of shape (batch_size, num_classes).
        """
        batch_size, num_slices, C, H, W = x.size()

        # Encode each slice position separately with the shared backbone.
        slice_features = [self.encoder(x[:, i, :, :, :]) for i in range(num_slices)]

        # (batch_size, feature_dim * num_slices)
        combined = torch.cat(slice_features, dim=1)

        return self.aggregation(combined)

In [None]:
class MultiSliceViT(nn.Module):
    """
    Multi-slice classifier in which a shared torchvision ViT encodes every
    slice. The logits are then produced either by a small Transformer over the
    slice features followed by mean pooling (``use_slice_transformer=True``),
    or by concatenation followed by an MLP.

    Input ``x`` has shape ``(batch, num_slices, C, H, W)`` with ``C == 1``,
    because the patch embedding is rebuilt for one channel. ``num_slices``
    must equal the constructor value, since the slice positional embedding
    and the MLP input size are derived from it.
    """
    def __init__(
        self,
        num_slices=6,
        num_classes=2,
        vit_model='vit_b_16',
        pretrained=True,
        use_slice_transformer=True,
        transformer_layers=2,
        transformer_heads=8
    ):
        """
        Args:
            num_slices: Slices per sample.
            num_classes: Number of output logits.
            vit_model: One of 'vit_b_16', 'vit_b_32' or 'vit_l_16' (torchvision).
            pretrained: Load torchvision's default weights and adapt the patch
                embedding to one channel by summing its RGB kernels.
            use_slice_transformer: Aggregate slices with a TransformerEncoder
                plus mean pooling instead of concatenation plus MLP.
            transformer_layers: Number of TransformerEncoder layers.
            transformer_heads: Attention heads; must divide the feature size.

        Raises:
            ValueError: For an unsupported ``vit_model``.
        """
        super(MultiSliceViT, self).__init__()

        self.num_slices = num_slices
        self.use_slice_transformer = use_slice_transformer

        # Supported torchvision backbones and the size of their output features.
        feature_dims = {'vit_b_16': 768, 'vit_b_32': 768, 'vit_l_16': 1024}
        if vit_model not in feature_dims:
            raise ValueError(f"Modelo {vit_model} n√£o suportado")
        vit = getattr(models, vit_model)(weights='DEFAULT' if pretrained else None)
        self.feature_dim = feature_dims[vit_model]

        # Grayscale input: rebuild the patch embedding with a single input channel.
        # NOTE: the rebuilt conv keeps the bias. state_dicts saved before this
        # change lack 'encoder.conv_proj.bias' and need load_state_dict(strict=False).
        vit.conv_proj = self._grayscale_conv_from_rgb(vit.conv_proj, pretrained)

        # Remove the original classification head so the encoder returns features.
        vit.heads = nn.Identity()

        self.encoder = vit

        # Learned per-slice positional embedding (slice order/position).
        # It is created in both modes but only used by the Transformer path.
        self.slice_pos_embedding = nn.Parameter(
            torch.randn(1, num_slices, self.feature_dim) * 0.02
        )

        if use_slice_transformer:
            # Transformer that models relations between slices.
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.feature_dim,
                nhead=transformer_heads,
                dim_feedforward=self.feature_dim * 4,
                dropout=0.1,
                batch_first=True
            )
            self.slice_transformer = nn.TransformerEncoder(
                encoder_layer,
                num_layers=transformer_layers
            )

            # Lighter classification head on the pooled slice features.
            self.classifier = nn.Sequential(
                nn.LayerNorm(self.feature_dim),
                nn.Linear(self.feature_dim, 256),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(256, num_classes)
            )
        else:
            # Simple aggregation: concatenation followed by an MLP.
            self.classifier = nn.Sequential(
                nn.Linear(self.feature_dim * num_slices, 512),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, 128),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(128, num_classes)
            )

    @staticmethod
    def _grayscale_conv_from_rgb(rgb_conv, pretrained=True):
        """
        Build a single-input-channel copy of ``rgb_conv``.

        The copy has the same out_channels, kernel_size, stride and padding,
        and a bias exactly when ``rgb_conv`` has one.

        When ``pretrained`` is True, the RGB kernels are summed over the input
        channel axis and the bias is copied. Applying the new conv to a
        grayscale image then equals applying ``rgb_conv`` to that image
        replicated on all three channels. The old code used ``bias=False`` and
        dropped the pretrained bias.
        """
        has_bias = rgb_conv.bias is not None
        gray_conv = nn.Conv2d(
            1,
            rgb_conv.out_channels,
            kernel_size=rgb_conv.kernel_size,
            stride=rgb_conv.stride,
            padding=rgb_conv.padding,
            bias=has_bias
        )
        if pretrained:
            with torch.no_grad():
                gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
                if has_bias:
                    gray_conv.bias.copy_(rgb_conv.bias)
        return gray_conv

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, num_slices, C, H, W).

        Returns:
            Logits tensor of shape (batch_size, num_classes).
        """
        batch_size, num_slices, C, H, W = x.size()

        # Encode all slices of all samples in one backbone call. reshape (not
        # view) also accepts non-contiguous input.
        x_flat = x.reshape(batch_size * num_slices, C, H, W)
        features = self.encoder(x_flat)  # (batch_size * num_slices, feature_dim)

        # Separate the batch and slice axes again.
        features = features.reshape(batch_size, num_slices, self.feature_dim)

        if self.use_slice_transformer:
            features = features + self.slice_pos_embedding
            features = self.slice_transformer(features)  # (batch_size, num_slices, feature_dim)
            pooled = features.mean(dim=1)  # mean over slices -> (batch_size, feature_dim)
            return self.classifier(pooled)

        combined = features.reshape(batch_size, -1)  # (batch_size, feature_dim * num_slices)
        return self.classifier(combined)


# Alternative aggregation: learned attention pooling over the slice tokens.
class MultiSliceViTWithAttention(MultiSliceViT):
    """
    Variant of ``MultiSliceViT`` that replaces mean pooling over slices with
    learned attention pooling when the slice Transformer is enabled. Without
    the Transformer it behaves like the parent (concatenation plus MLP).
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if not self.use_slice_transformer:
            return

        # Scores each slice token. A softmax over the slice axis turns the
        # scores into pooling weights.
        self.attention_pool = nn.Sequential(
            nn.Linear(self.feature_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        bsz, n_slices, channels, height, width = x.size()

        # Run the shared encoder on every slice at once, then restore the slice axis.
        tokens = self.encoder(x.view(bsz * n_slices, channels, height, width))
        tokens = tokens.view(bsz, n_slices, self.feature_dim)

        if not self.use_slice_transformer:
            return self.classifier(tokens.view(bsz, -1))

        tokens = self.slice_transformer(tokens + self.slice_pos_embedding)

        # (bsz, n_slices, 1) weights that sum to 1 over the slices.
        slice_weights = torch.softmax(self.attention_pool(tokens), dim=1)
        pooled = (tokens * slice_weights).sum(dim=1)  # (bsz, feature_dim)

        return self.classifier(pooled)

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """
    Run one training epoch over ``loader``.

    Args:
        model: Module that returns class logits with classes on dim 1.
        loader: Iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).
        criterion: Loss called as ``criterion(logits, labels)``. It is assumed
            to use mean reduction, since it is re-weighted by batch size below.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        accuracy over the samples actually drawn this epoch. Dividing by the
        number of drawn samples, not ``len(loader.dataset)``, keeps the loss
        correct with samplers or ``drop_last``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_n = images.size(0)
        loss_sum += loss.item() * batch_n
        _, preds = torch.max(outputs, 1)
        n_correct += (preds == labels).sum().item()
        n_seen += batch_n

    if n_seen == 0:
        raise ValueError("train_epoch: loader yielded no samples")

    return loss_sum / n_seen, n_correct / n_seen


def evaluate(model, loader, criterion, device):
    """
    Evaluate a binary classifier on ``loader`` without gradient tracking.

    Args:
        model: Module that returns 2-class logits (class 1 is the positive class).
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``. It is assumed
            to use mean reduction.
        device: Device the batches are moved to.

    Returns:
        Dict with keys loss, accuracy, precision, recall, sensitivity,
        specificity, f1, auc and confusion_matrix. The confusion matrix is
        always a 2x2 nested list ``[[tn, fp], [fn, tp]]``. ``auc`` is 0.0 when
        ROC AUC is undefined (e.g. only one class in the labels).

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            loss_sum += loss.item() * images.size(0)
            n_seen += images.size(0)
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())  # P(class 1) for ROC AUC

    if n_seen == 0:
        raise ValueError("evaluate: loader yielded no samples")

    epoch_loss = loss_sum / n_seen

    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    try:
        auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        # e.g. only one class present in all_labels: ROC AUC is undefined.
        auc = 0.0

    # labels=[0, 1] forces a 2x2 matrix even when a class is missing from both
    # labels and predictions. Otherwise the matrix is 1x1 and tn/fp/fn/tp
    # (and so sensitivity/specificity) would be wrongly reported as 0.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    print(cm)

    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    metrics = {
        'loss': epoch_loss,
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'f1': f1,
        'auc': auc,
        'confusion_matrix': cm.tolist()
    }

    return metrics


def train_model(model, train_loader, val_loader, criterion, optimizer,
                num_epochs, device, scheduler=None, early_stopping_patience=5):
    """
    Train with per-epoch validation, early stopping on the validation loss and
    restoration of the best-validation-loss weights.

    Args:
        model, train_loader, val_loader, criterion, optimizer, device: Passed
            through to ``train_epoch`` / ``evaluate``.
        num_epochs: Maximum number of epochs.
        scheduler: Optional. ``scheduler.step(val_loss)`` is called once per
            epoch. This presumably expects a ReduceLROnPlateau-style scheduler;
            confirm before passing another type.
        early_stopping_patience: Stop after this many consecutive epochs
            without a new lowest validation loss.

    Returns:
        ``(model, history)``: the model with the best weights loaded (when at
        least one epoch improved) and a dict of per-epoch lists 'train_loss',
        'train_acc', 'val_loss' and 'val_acc'.
    """
    best_val_loss = float('inf')
    best_model_wts = None
    patience_counter = 0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_metrics = evaluate(model, val_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_metrics['loss'])
        history['val_acc'].append(val_metrics['accuracy'])

        if scheduler is not None:
            scheduler.step(val_metrics['loss'])

        print(f"√âpoca {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_metrics['loss']:.4f}, "
              f"Train Acc: {train_acc:.4f}, Val Acc: {val_metrics['accuracy']:.4f}")

        # Early stopping on validation loss.
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            # Snapshot by value. state_dict() returns references to the live
            # tensors, so the old shallow dict .copy() kept changing as training
            # continued, and the "best" weights were really the last weights.
            best_model_wts = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= early_stopping_patience:
            print(f"Early stopping na √©poca {epoch+1}")
            break

    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    return model, history

In [None]:
import numpy as np
import torchvision

# Undo a per-channel Normalize(mean, std), i.e. x -> x * std + mean.
def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Undo per-channel normalization **in place** and return ``tensor``.

    The channel axis is dim -3, so ``(C, H, W)``, ``(B, C, H, W)`` and
    ``(B, S, C, H, W)`` inputs are all handled. Only the first
    ``min(C, len(mean), len(std))`` channels are rescaled; any extra channels
    are left unchanged, as the previous ``zip``-based version did for
    ``(C, H, W)``. The defaults are the ImageNet statistics.

    Raises:
        ValueError: If ``tensor`` has fewer than 3 dimensions.
    """
    # The old version zipped over dim 0. For a batch, that applied channel
    # statistics to whole samples instead of to channels.
    if tensor.dim() < 3:
        raise ValueError(
            f"denormalize expects a tensor with >= 3 dims, got shape {tuple(tensor.shape)}"
        )
    n_channels = min(tensor.size(-3), len(mean), len(std))
    for c in range(n_channels):
        # select() returns a view, so the in-place ops update `tensor` itself.
        tensor.select(-3, c).mul_(std[c]).add_(mean[c])
    return tensor

# Show one batch from a DataLoader as an image grid (inspection helper).
def show_batch(dataloader, num_images=16, denorm=True, mean=None, std=None):
    """
    Plot the first batch of ``dataloader`` as an image grid and print its stats.

    Args:
        dataloader: DataLoader (e.g. train_loader or val_loader) that yields
            ``(images, labels)``.
        num_images: Maximum number of samples taken from the batch.
        denorm: If True, undo the input normalization with ``denormalize``
            (on a copy) before display.
        mean, std: Statistics passed to ``denormalize``. ``None`` keeps its
            defaults (ImageNet). The ExperimentPipeline transform in this file
            normalizes with ``mean=[0.5], std=[0.5]``.

    Multi-slice batches of shape ``(B, S, C, H, W)`` are flattened to
    ``(B * S, C, H, W)`` and laid out one sample per grid row, because
    ``make_grid`` only takes 4-D input. 4-D batches keep the 4-per-row layout.
    """
    images, labels = next(iter(dataloader))

    images = images[:num_images]
    labels = labels[:num_images]

    if denorm:
        stats = {}
        if mean is not None:
            stats['mean'] = mean
        if std is not None:
            stats['std'] = std
        images = denormalize(images.clone(), **stats)

    # Clip into the displayable [0, 1] range.
    images = torch.clamp(images, 0, 1)

    if images.dim() == 5:
        grid_input = images.flatten(0, 1)
        nrow = images.size(1)  # one row per sample
    else:
        grid_input = images
        nrow = 4

    grid = torchvision.utils.make_grid(grid_input, nrow=nrow, padding=2)

    # (C, H, W) -> (H, W, C) for matplotlib.
    grid_np = grid.permute(1, 2, 0).cpu().numpy()

    plt.figure(figsize=(15, 15))
    plt.imshow(grid_np)
    plt.title(f'Batch de imagens - Labels: {labels.tolist()}')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    print(f"Shape das imagens: {images.shape}")
    print(f"Labels: {labels.tolist()}")
    print(f"Min pixel value: {images.min().item():.3f}")
    print(f"Max pixel value: {images.max().item():.3f}")


In [None]:
class ExperimentPipeline:
    """
    Pipeline para treinar modelo com diferentes quantidades de slices
    """
    def __init__(self, metadata_df, base_path, results_dir, device):
        """
        Store the experiment inputs, create ``results_dir`` if needed and build
        the image transform shared by every experiment.

        Args:
            metadata_df: Participant metadata (the experiments read its
                ``participant_id`` and ``diagnosis`` columns).
            base_path: Folder holding the per-participant slice folders.
            results_dir: Output folder for metrics, plots and model weights.
            device: Torch device used for training and evaluation.
        """
        self.metadata_df, self.base_path = metadata_df, base_path
        self.results_dir = results_dir
        os.makedirs(results_dir, exist_ok=True)
        self.device = device

        # Resize to 224x224, convert to a tensor, then normalize with mean 0.5 / std 0.5.
        preprocessing_steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

        # One results dict per completed experiment, in run order.
        self.all_results = []

    def run_experiment(self, slices_per_view, batch_size=2, num_epochs=20, lr=1e-5):
        """
        Train and evaluate one MultiSliceViT with ``slices_per_view`` slices per
        orientation (``3 * slices_per_view`` slices per sample).

        Steps:
          1. Keep participants with enough slices.
          2. Make a stratified 80/20 train/validation split (random_state=42).
          3. Oversample the training split so the classes are balanced.
          4. Train with early stopping on the validation split, then evaluate.
          5. Save metrics, plots and model weights under ``self.results_dir``.

        Returns:
            The results dict (also appended to ``self.all_results``), or None
            if no participant has enough slices.
        """
        print(f"\n{'='*80}")
        print(f"EXPERIMENTO: {slices_per_view} slices por orienta√ß√£o ({slices_per_view*3} total)")
        print(f"{'='*80}\n")

        # Filter the full metadata BEFORE splitting, so both splits only
        # contain participants with enough slices.
        print("Validando disponibilidade de slices...")
        valid_metadata = self._filter_valid_patients(self.metadata_df, slices_per_view)

        if len(valid_metadata) == 0:
            print("Nenhum paciente v√°lido encontrado!")
            return None

        print(f"‚úì {len(valid_metadata)} pacientes com slices suficientes")
        print(f"   Distribui√ß√£o: {valid_metadata['diagnosis'].value_counts().to_dict()}\n")

        train_df, val_df = train_test_split(
            valid_metadata,
            test_size=0.2,
            random_state=42,
            stratify=valid_metadata['diagnosis']
        )

        # validate=False: availability was already checked above.
        train_dataset = FMRISkimmedDataset(
            train_df, self.base_path, self.transform, slices_per_view,
            validate=False
        )
        val_dataset = FMRISkimmedDataset(
            val_df, self.base_path, self.transform, slices_per_view,
            validate=False
        )

        from torch.utils.data import WeightedRandomSampler

        # Oversample the minority class. Each training sample is drawn with
        # probability proportional to 1 / (size of its class).
        train_class_counts = train_df['diagnosis'].value_counts()
        print(f"Distribui√ß√£o original: {train_class_counts.to_dict()}")

        weights_per_class = {
            label: 1.0 / count for label, count in train_class_counts.items()
        }
        sample_weights = train_df['diagnosis'].map(weights_per_class).values

        sampler = WeightedRandomSampler(
            weights=torch.DoubleTensor(sample_weights),
            num_samples=len(train_df),  # one epoch draws as many samples as the split holds
            replacement=True
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=2,
            pin_memory=True
        )

        # Validation loader: no oversampling, fixed order.
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

        print(f"Dataset: {len(train_dataset)} treino, {len(val_dataset)} valida√ß√£o")

        total_slices = slices_per_view * 3
        model = MultiSliceViT(num_slices=total_slices, num_classes=2).to(self.device)

        # Inverse-frequency class weights over all valid participants, ordered
        # by label id. Class names are looked up explicitly: indexing the
        # string-indexed value_counts() with integers (as before) relied on
        # pandas' deprecated positional fallback.
        class_counts = valid_metadata['diagnosis'].value_counts().sort_index()
        total = class_counts.sum()
        label_order = sorted(train_dataset.label_map, key=train_dataset.label_map.get)
        weights = torch.tensor(
            [total / class_counts[name] for name in label_order], dtype=torch.float
        ).to(self.device)

        print(total)
        print(weights)

        # Unweighted loss; the training sampler already oversamples the minority
        # class. Weighted alternative: nn.CrossEntropyLoss(weight=weights)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

        print("\nIniciando treinamento...\n")
        model, history = train_model(
            model, train_loader, val_loader, criterion, optimizer,
            num_epochs, self.device, early_stopping_patience=5
        )

        # NOTE: the same validation split drove early stopping, so these
        # metrics are optimistic estimates of held-out performance.
        final_metrics = evaluate(model, val_loader, criterion, self.device)

        print(f"\n{'='*60}")
        print(f"RESULTADOS FINAIS - {slices_per_view} slices por orienta√ß√£o:")
        print(f"{'='*60}")
        print(f"Acur√°cia:       {final_metrics['accuracy']:.4f}")
        print(f"Precis√£o:       {final_metrics['precision']:.4f}")
        print(f"Recall:         {final_metrics['recall']:.4f}")
        print(f"Sensibilidade:  {final_metrics['sensitivity']:.4f}")
        print(f"Especificidade: {final_metrics['specificity']:.4f}")
        print(f"F1-Score:       {final_metrics['f1']:.4f}")
        print(f"AUC-ROC:        {final_metrics['auc']:.4f}")
        print(f"{'='*60}\n")

        results = {
            'slices_per_view': slices_per_view,
            'total_slices': total_slices,
            'metrics': final_metrics,
            'history': history,
            'config': {
                'batch_size': batch_size,
                'num_epochs': num_epochs,
                'learning_rate': lr,
                'train_size': len(train_dataset),
                'val_size': len(val_dataset)
            }
        }

        self.all_results.append(results)
        self._save_experiment_results(results, slices_per_view)

        model_path = os.path.join(self.results_dir, f'model_{slices_per_view}slices.pth')
        torch.save(model.state_dict(), model_path)

        return results

    def _filter_valid_patients(self, metadata_df, min_slices_per_view):
        """
        Return the rows of ``metadata_df`` whose participant folder holds at
        least ``min_slices_per_view`` slice files in every orientation
        (sagital, coronal, axial). The result is re-indexed 0..n-1.

        Rows are selected by position. The old code passed ``iterrows()``
        index *labels* to ``.iloc``, which picked wrong rows (or raised
        IndexError) whenever ``metadata_df`` did not have a 0..n-1 index.
        """
        def has_enough_slices(participant_id):
            slices_dir = os.path.join(self.base_path, f"sub-{participant_id}")
            if not os.path.exists(slices_dir):
                return False
            return all(
                len(glob.glob(os.path.join(slices_dir, orientation,
                                           f"slice_{orientation}_*_idx*.png")))
                >= min_slices_per_view
                for orientation in ('sagital', 'coronal', 'axial')
            )

        valid_positions = [
            pos
            for pos, participant_id in enumerate(metadata_df['participant_id'])
            if has_enough_slices(participant_id)
        ]

        return metadata_df.iloc[valid_positions].reset_index(drop=True)

    def _save_experiment_results(self, results, slices_per_view):
        """
        Persist one experiment under ``<results_dir>/exp_<N>slices``.

        Writes ``metrics.json`` (scalar metrics, confusion matrix and config),
        calls ``print_confusion_details`` with the confusion matrix, and saves
        the learning-curve plot.

        NOTE(review): ``print_confusion_details`` is not defined in the visible
        part of this class, and run_experiment has its own call commented out.
        Confirm that it exists; otherwise this raises AttributeError after
        metrics.json is written.
        """
        exp_dir = os.path.join(self.results_dir, f'exp_{slices_per_view}slices')
        os.makedirs(exp_dir, exist_ok=True)

        metric_keys = ('accuracy', 'precision', 'recall', 'sensitivity',
                       'specificity', 'f1', 'auc', 'confusion_matrix')

        with open(os.path.join(exp_dir, 'metrics.json'), 'w') as fh:
            summary = {
                'slices_per_view': results['slices_per_view'],
                'total_slices': results['total_slices'],
                **{key: results['metrics'][key] for key in metric_keys},
                'config': results['config'],
            }
            json.dump(summary, fh, indent=4)

        self.print_confusion_details(results['metrics']['confusion_matrix'])

        self._plot_learning_curves(results['history'], exp_dir)


    def _plot_learning_curves(self, history, save_dir):
        """
        Draw per-epoch train/validation loss (left) and accuracy (right), then
        save the figure as ``learning_curves.png`` (300 dpi) in ``save_dir``.

        ``history`` must provide 'train_loss', 'val_loss', 'train_acc' and
        'val_acc' sequences.
        """
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # (axis, history-key suffix, y label, title) for each panel.
        panels = (
            (axes[0], 'loss', 'Loss', 'Curva de Loss'),
            (axes[1], 'acc', 'Acur√°cia', 'Curva de Acur√°cia'),
        )
        for ax, key, y_label, title in panels:
            ax.plot(history[f'train_{key}'], label='Treino', marker='o')
            ax.plot(history[f'val_{key}'], label='Valida√ß√£o', marker='s')
            ax.set_xlabel('√âpoca')
            ax.set_ylabel(y_label)
            ax.set_title(title)
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'learning_curves.png'), dpi=300, bbox_inches='tight')
        plt.close()

    def run_all_experiments(self, slice_counts=None,
                            batch_size=8, num_epochs=20, lr=1e-4):
        """
        Run ``run_experiment`` for each slice count in order, then build the
        comparative report.

        Args:
            slice_counts: Slices per orientation to evaluate. None (default)
                means [2, 4, 6, 8, 10, 12, 14, 16].
            batch_size, num_epochs, lr: Forwarded to every ``run_experiment``.

        A failing experiment is reported (message plus traceback) and skipped,
        so the remaining configurations still run.
        """
        import traceback

        # None sentinel instead of a shared mutable list default.
        if slice_counts is None:
            slice_counts = [2, 4, 6, 8, 10, 12, 14, 16]

        print("\nINICIANDO PIPELINE DE EXPERIMENTOS")
        print(f"Configura√ß√µes variadas de slices: {slice_counts}")
        print(f"Batch size: {batch_size}, √âpocas: {num_epochs}, LR: {lr}\n")

        start_time = datetime.now()

        for slices in slice_counts:
            try:
                self.run_experiment(slices, batch_size, num_epochs, lr)
            except Exception as e:
                # Best effort across configurations, but keep the traceback
                # visible instead of reducing the failure to one line.
                print(f"\nERRO no experimento com {slices} slices: {e}\n")
                traceback.print_exc()

        duration = datetime.now() - start_time

        print("\nPIPELINE COMPLETO!")
        print(f"Tempo total: {duration}")

        self.generate_comparative_report()

    def generate_comparative_report(self):
        """
        Print, save and plot a comparison of every finished experiment.

        Entries of ``self.all_results`` that are None (experiments without a
        result) are skipped, consistently with ``_plot_comparative_metrics``.
        Writes ``comparative_results.csv`` into ``self.results_dir`` (always with
        a header row, even when no experiment succeeded), delegates the plots to
        ``self._plot_comparative_metrics`` and prints the configuration with the
        highest accuracy.
        """
        print(f"\n{'='*80}")
        print("RELAT√ìRIO COMPARATIVO DE TODOS OS EXPERIMENTOS")
        print(f"{'='*80}\n")

        # A None entry previously crashed the report with a TypeError
        valid_results = [r for r in self.all_results if r is not None]

        # (CSV/table column, key in result['metrics']) for the formatted metric columns
        metric_columns = (
            ('Accuracy', 'accuracy'),
            ('Precision', 'precision'),
            ('Recall', 'recall'),
            ('F1-Score', 'f1'),
            ('AUC', 'auc'),
            ('Sensitivity', 'sensitivity'),
            ('Specificity', 'specificity'),
        )

        comparison_data = []
        for result in valid_results:
            row = {
                'Slices/View': result['slices_per_view'],
                'Total Slices': result['total_slices'],
            }
            for column, key in metric_columns:
                row[column] = f"{result['metrics'][key]:.4f}"
            comparison_data.append(row)

        columns = ['Slices/View', 'Total Slices'] + [column for column, _ in metric_columns]
        df_comparison = pd.DataFrame(comparison_data, columns=columns)
        print(df_comparison.to_string(index=False))

        # Save CSV
        csv_path = os.path.join(self.results_dir, 'comparative_results.csv')
        df_comparison.to_csv(csv_path, index=False)
        print(f"\n‚úì Resultados salvos em: {csv_path}")

        # Comparative plot (handles the empty case itself)
        self._plot_comparative_metrics()

        # Best configuration by accuracy
        if valid_results:
            best_result = max(valid_results, key=lambda x: x['metrics']['accuracy'])
            print(f"\nMELHOR CONFIGURA√á√ÉO:")
            print(f"   Slices por orienta√ß√£o: {best_result['slices_per_view']}")
            print(f"   Acur√°cia: {best_result['metrics']['accuracy']:.4f}")
            print(f"   AUC: {best_result['metrics']['auc']:.4f}\n")
        else:
            print("\nN√£o foi poss√≠vel identificar a melhor configura√ß√£o, pois nenhum experimento foi conclu√≠do com sucesso.")


    def print_confusion_details(self, cm, save_path=None):
        """
        Print a per-class breakdown of a binary confusion matrix and render it as a figure.

        The figure has two heatmaps, absolute counts and percentages of each
        true class, plus a text box with the same breakdown. Previously the
        figure was built and closed without being shown or saved, and nothing was
        printed. Now the text summary is always printed, and the figure is written
        to ``save_path`` when one is given.

        Args:
            cm: 2x2 confusion matrix as an array or nested lists. Rows are true
                classes and columns are predicted classes; index 0 = CONTROL,
                1 = SCHZ.
            save_path: optional output file for the figure. The default None
                writes no file, which is the previous behaviour.

        Returns:
            The printed summary text (callers that ignored the old None return
            are unaffected).

        Raises:
            ValueError: if ``cm`` is not 2x2.
        """
        # Results may carry the matrix as nested lists (it is also serialized to
        # JSON); tuple indexing such as cm[0, :] only works on an ndarray.
        cm = np.asarray(cm)
        if cm.shape != (2, 2):
            raise ValueError(f"Expected a 2x2 confusion matrix, got shape {cm.shape}")

        # Totals: row 0 = true CONTROL, row 1 = true SCHZ
        total = cm.sum()
        total_control = cm[0, :].sum()
        total_schz = cm[1, :].sum()

        # Cells of the matrix
        tn, fp = cm[0, 0], cm[0, 1]  # true negative, false positive
        fn, tp = cm[1, 0], cm[1, 1]  # false negative, true positive

        # Percentages relative to each true class (0 when the class is absent)
        tn_pct = (tn / total_control * 100) if total_control > 0 else 0
        fp_pct = (fp / total_control * 100) if total_control > 0 else 0
        fn_pct = (fn / total_schz * 100) if total_schz > 0 else 0
        tp_pct = (tp / total_schz * 100) if total_schz > 0 else 0
        accuracy_pct = ((tn + tp) / total * 100) if total > 0 else 0.0

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        # Panel 1: absolute counts
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax1,
                    xticklabels=['CONTROL', 'SCHZ'],
                    yticklabels=['CONTROL', 'SCHZ'],
                    cbar_kws={'label': 'Contagem'},
                    annot_kws={'size': 16, 'weight': 'bold'})
        ax1.set_title('Matriz de Confus√£o - Valores Absolutos', fontsize=14, weight='bold')
        ax1.set_ylabel('Classe Real', fontsize=12)
        ax1.set_xlabel('Classe Predita', fontsize=12)

        # Panel 2: percentages per true class
        cm_pct = np.array([[tn_pct, fp_pct],
                           [fn_pct, tp_pct]])

        sns.heatmap(cm_pct, annot=True, fmt='.1f', cmap='Greens', ax=ax2,
                    xticklabels=['CONTROL', 'SCHZ'],
                    yticklabels=['CONTROL', 'SCHZ'],
                    cbar_kws={'label': '% da Classe Real'},
                    annot_kws={'size': 16, 'weight': 'bold'})
        ax2.set_title('Matriz de Confus√£o - Percentuais por Classe', fontsize=14, weight='bold')
        ax2.set_ylabel('Classe Real', fontsize=12)
        ax2.set_xlabel('Classe Predita', fontsize=12)

        # Text summary (printed and drawn under the plots)
        textstr = (
            f'AN√ÅLISE DETALHADA\n'
            f'‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ\n'
            f'Total de amostras: {total}\n\n'
            f'CONTROL (Real: {total_control} amostras):\n'
            f'  ‚úì Acertos: {tn} ({tn_pct:.1f}%)\n'
            f'  ‚úó Erros:   {fp} ({fp_pct:.1f}%) ‚Üí preditos como SCHZ\n\n'
            f'SCHZ (Real: {total_schz} amostras):\n'
            f'  ‚úì Acertos: {tp} ({tp_pct:.1f}%)\n'
            f'  ‚úó Erros:   {fn} ({fn_pct:.1f}%) ‚Üí preditos como CONTROL\n\n'
            f'‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ\n'
            f'Acur√°cia Geral: {accuracy_pct:.2f}%'
        )

        # Text box below the axes (y < 0, outside the default canvas)
        fig.text(0.5, -0.15, textstr,
                 ha='center', va='top',
                 fontsize=11, family='monospace',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

        plt.tight_layout()
        if save_path is not None:
            # bbox_inches='tight' keeps the text box placed below the axes in the saved image
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(textstr)
        return textstr


    def _plot_comparative_metrics(self):
        """
        Plot each evaluation metric against the number of slices per orientation.

        Entries of ``self.all_results`` that are None are ignored. When nothing
        remains, a notice is printed and no figure is produced. Otherwise a 2x2
        grid is saved as 'comparative_metrics.png' in ``self.results_dir``:
        accuracy, F1, AUC, and sensitivity/specificity together.
        """
        usable = [res for res in self.all_results if res is not None]
        if not usable:
            print("\nN√£o h√° resultados v√°lidos para gerar gr√°ficos comparativos.")
            return

        x_values = [res['slices_per_view'] for res in usable]
        # Collect every series up front, before any figure exists
        series = {
            name: [res['metrics'][name] for res in usable]
            for name in ('accuracy', 'f1', 'auc', 'sensitivity', 'specificity')
        }

        x_label = 'Slices por Orienta√ß√£o'
        line_style = {'linewidth': 2, 'markersize': 8}
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # (axis, metric, marker/colour kwargs, y label, title) for the single-metric panels
        single_panels = (
            (axes[0, 0], 'accuracy', {'marker': 'o'},
             'Acur√°cia', 'Acur√°cia vs N√∫mero de Slices'),
            (axes[0, 1], 'f1', {'marker': 's', 'color': 'green'},
             'F1-Score', 'F1-Score vs N√∫mero de Slices'),
            (axes[1, 0], 'auc', {'marker': '^', 'color': 'red'},
             'AUC-ROC', 'AUC-ROC vs N√∫mero de Slices'),
        )
        for ax, metric, look, y_label, title in single_panels:
            ax.plot(x_values, series[metric], **look, **line_style)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)
            ax.set_xticks(x_values)

        # Sensitivity and specificity share the last panel
        combo = axes[1, 1]
        combo.plot(x_values, series['sensitivity'], marker='o', label='Sensibilidade', **line_style)
        combo.plot(x_values, series['specificity'], marker='s', label='Especificidade', **line_style)
        combo.set_xlabel(x_label)
        combo.set_ylabel('Score')
        combo.set_title('Sensibilidade & Especificidade vs Slices')
        combo.grid(True, alpha=0.3)
        combo.set_xticks(x_values)
        combo.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(self.results_dir, 'comparative_metrics.png'), dpi=300, bbox_inches='tight')
        plt.close()

        print(f"‚úì Gr√°ficos comparativos salvos em: {self.results_dir}/comparative_metrics.png")

In [None]:
def load_participants_metadata(participants_tsv):
    """
    Load the participants table and keep only CONTROL and SCHZ subjects.

    Steps:
      * strip the literal 'sub-' from ``participant_id``;
      * keep rows whose ``diagnosis`` is 'CONTROL' or 'SCHZ';
      * keep only IDs whose first character is '1' or '5';
      * add an integer ``label`` column (CONTROL -> 0, SCHZ -> 1).

    Args:
        participants_tsv: path to the tab-separated participants file.

    Returns:
        A new DataFrame (original row index preserved) with the extra ``label`` column.
    """
    metadata = pd.read_csv(participants_tsv, sep='\t')
    metadata['participant_id'] = metadata['participant_id'].str.replace('sub-', '', regex=False)

    # .copy() so later column assignments act on an independent frame rather than
    # a slice of the original (the SettingWithCopy pattern, hidden here by the
    # global warnings filter).
    metadata = metadata[metadata['diagnosis'].isin(['CONTROL', 'SCHZ'])].copy()

    first_digit = metadata['participant_id'].astype(str).str[0]
    metadata = metadata[first_digit.isin(['1', '5'])].copy()

    metadata['label'] = metadata['diagnosis'].map({'CONTROL': 0, 'SCHZ': 1})
    return metadata


def inspect_slice_directory(base_path):
    """
    Print a diagnostic overview of the slice folders under ``base_path``.

    Counts the ``sub-*`` folders. For the alphabetically first one, it reports
    per orientation ('sagital', 'coronal', 'axial') how many .png/.npy files it
    holds and how many files match ``slice_<orientation>_*_idx*.png``.
    Only prints; returns None.
    """
    print("DEBUG: Verificando estrutura de diret√≥rios...")
    print(f"Base path: {base_path}")

    if not os.path.exists(base_path):
        print(f"Caminho n√£o existe: {base_path}")
        return

    # sorted(): os.listdir order is arbitrary, so sort to make the reported
    # "first 3" and the analysed example deterministic
    subdirs = sorted(
        d for d in os.listdir(base_path)
        if d.startswith('sub-') and os.path.isdir(os.path.join(base_path, d))
    )
    print(f"‚úì Total de pastas sub-* encontradas: {len(subdirs)}")
    print(f"Primeiras 3 pastas: {subdirs[:3]}\n")

    if not subdirs:
        return

    example_sub = subdirs[0]
    example_path = os.path.join(base_path, example_sub)
    print(f"Analisando: {example_sub}")
    print(f"Caminho completo: {example_path}\n")

    subfolders = [d for d in os.listdir(example_path) if os.path.isdir(os.path.join(example_path, d))]
    print(f"Subpastas encontradas: {subfolders}\n")

    print("="*80)
    for orientation in ['sagital', 'coronal', 'axial']:
        orientation_dir = os.path.join(example_path, orientation)

        if os.path.exists(orientation_dir):
            entries = sorted(os.listdir(orientation_dir))
            png_files = [f for f in entries if f.endswith('.png')]
            npy_files = [f for f in entries if f.endswith('.npy')]

            print(f"üìÇ Orienta√ß√£o '{orientation}':")
            print(f"   Caminho: {orientation_dir}")
            print(f"   Arquivos .png: {len(png_files)}")
            print(f"   Arquivos .npy: {len(npy_files)}")

            if png_files:
                print(f"   Primeiros 3 .png: {png_files[:3]}")
            if npy_files:
                print(f"   Primeiros 3 .npy: {npy_files[:3]}")

            # Same file-name pattern the dataset's slice lookup expects
            pattern = os.path.join(orientation_dir, f"slice_{orientation}_*_idx*.png")
            matched_files = glob.glob(pattern)
            print(f"   Padr√£o glob encontrou: {len(matched_files)} arquivos")
            if matched_files:
                print(f"   Exemplo: {os.path.basename(matched_files[0])}")
        else:
            print(f"Pasta '{orientation}' n√£o existe!")

        print()


if __name__ == "__main__":
    # Paths on the mounted Google Drive
    BASE_PATH = "/content/drive/MyDrive/TCC/resultados_global/slices"
    RESULTS_DIR = "/content/drive/MyDrive/TCC/experiment_results"
    PARTICIPANTS_TSV = "/content/drive/MyDrive/TCC/dataset/ds000030-download/participants.tsv"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Usando device: {device}\n")

    # Metadata: only CONTROL (IDs 1XXXX) and SCHZ (IDs 5XXXX)
    print("Carregando metadados...")
    print("Filtrando apenas pacientes CONTROL e SCHZ...")
    metadata_df = load_participants_metadata(PARTICIPANTS_TSV)

    print(f"Total de participantes CONTROL/SCHZ: {len(metadata_df)}")
    print(f"Distribui√ß√£o: {metadata_df['diagnosis'].value_counts().to_dict()}\n")

    # Sanity check of the slice folder layout before training
    inspect_slice_directory(BASE_PATH)

    # NOTE: nothing actually pauses here; the message only gives a chance to
    # interrupt the cell before the (long) experiments start.
    print("="*80 + "\n")
    print("Pressione Ctrl+C para interromper ou aguarde para continuar...")
    print("="*80 + "\n")

    pipeline = ExperimentPipeline(
        metadata_df=metadata_df,
        base_path=BASE_PATH,
        results_dir=RESULTS_DIR,
        device=device
    )

    # This run uses only 4 slices per view; the full sweep is
    # slice_counts=[2, 4, 6, 8, 10, 12, 14, 16].
    pipeline.run_all_experiments(
        slice_counts=[4],
        batch_size=2,
        num_epochs=20,
        lr=1e-5
    )

    print("TODOS OS EXPERIMENTOS CONCLU√çDOS!")
    print(f"Resultados salvos em: {RESULTS_DIR}")

Usando device: cpu

Carregando metadados...
Filtrando apenas pacientes CONTROL e SCHZ...
Total de participantes CONTROL/SCHZ: 180
Distribui√ß√£o: {'CONTROL': 130, 'SCHZ': 50}

DEBUG: Verificando estrutura de diret√≥rios...
Base path: /content/drive/MyDrive/TCC/resultados_global/slices
‚úì Total de pastas sub-* encontradas: 176
Primeiras 3 pastas: ['sub-10159', 'sub-10171', 'sub-10189']

Analisando: sub-10159
Caminho completo: /content/drive/MyDrive/TCC/resultados_global/slices/sub-10159

Subpastas encontradas: ['sagital', 'coronal', 'axial']

üìÇ Orienta√ß√£o 'sagital':
   Caminho: /content/drive/MyDrive/TCC/resultados_global/slices/sub-10159/sagital
   Arquivos .png: 16
   Arquivos .npy: 16
   Primeiros 3 .png: ['slice_sagital_01_idx040.png', 'slice_sagital_02_idx041.png', 'slice_sagital_04_idx043.png']
   Primeiros 3 .npy: ['slice_sagital_02_idx041.npy', 'slice_sagital_01_idx040.npy', 'slice_sagital_04_idx043.npy']
   Padr√£o glob encontrou: 16 arquivos
   Exemplo: slice_sagital_01_

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T
from PIL import Image
import cv2

def generate_gradcam(model, input_tensor, target_class=None, device='cuda', slice_idx=0):
    """
    Overlay last-layer ViT attention for one slice on top of that slice and plot it.

    Despite the name, no gradients are used. The map is the head-averaged
    attention of the CLS token to every patch, captured by a forward hook on
    ``model.encoder.encoder.layers[-1].attention.attn_drop``. The previous
    version called ``backward()`` without ever reading the gradients, and set
    ``requires_grad`` in place on the input. Both were removed, and the forward
    pass now runs under ``torch.no_grad()``.

    Args:
        model: trained MultiSliceViT whose ``encoder`` exposes the attribute
            path above.
        input_tensor: tensor of shape (1, num_slices, 1, H, W).
        target_class: class index shown in the title; if None, the predicted class.
        device: 'cuda' or 'cpu'.
        slice_idx: index of the slice to visualize.

    Returns:
        The RGB overlay as a float32 array of shape (side, side, 3) scaled so its
        maximum is 1, where side x side is the patch grid.

    Raises:
        RuntimeError: if the hooked module was never called during the forward
            pass, so no attention was captured.
    """
    def _normalize(arr):
        # Min-max scale to [0, 1]; a constant array becomes zeros instead of NaN.
        arr_min, arr_max = arr.min(), arr.max()
        if arr_max > arr_min:
            return (arr - arr_min) / (arr_max - arr_min)
        return np.zeros_like(arr)

    model.eval()
    input_tensor = input_tensor.to(device)

    captured = []

    def hook_attention(module, inputs, output):
        # nn.MultiheadAttention-style modules return (output, weights); a plain
        # module applied to the attention matrix returns that tensor directly.
        # The old code always took output[1], which for a tensor indexes the batch axis.
        weights = output[1] if isinstance(output, (tuple, list)) else output
        captured.append(weights.detach())

    vit = model.encoder
    handle = vit.encoder.layers[-1].attention.attn_drop.register_forward_hook(hook_attention)
    try:
        with torch.no_grad():
            output = model(input_tensor)
    finally:
        handle.remove()  # never leave the hook attached, even if the forward pass fails

    if not captured:
        raise RuntimeError(
            "No attention weights were captured: the hooked attention module "
            "was not called during the forward pass."
        )

    pred_class = output.argmax(dim=1).item() if target_class is None else target_class

    # Stack all hook calls along the batch axis (one batched call or one call per
    # slice both end up as rows); 3-D weights get a singleton head axis.
    attn = torch.cat([w if w.dim() == 4 else w.unsqueeze(1) for w in captured], dim=0)
    attn = attn.mean(dim=1)  # average heads -> (rows, tokens, tokens)

    # NOTE(review): assumes row i of the encoder batch is slice i of input_tensor
    # (batch size 1); confirm against MultiSliceViT.forward. Uses the last row
    # when there are fewer rows than slice_idx + 1.
    row = min(slice_idx, attn.size(0) - 1)
    attn_map = attn[row, 0, 1:].float().cpu().numpy()  # CLS query row, patch keys only
    side = int(np.sqrt(attn_map.shape[0]))
    attn_map = _normalize(attn_map.reshape(side, side))

    # Original slice, scaled to [0, 1] and resized to the patch grid
    img_np = input_tensor[0, slice_idx, 0].detach().cpu().numpy().astype(np.float32)
    img_np = _normalize(img_np)
    img_np = cv2.resize(img_np, (attn_map.shape[1], attn_map.shape[0]))

    # Blend attention heatmap with the image; applyColorMap returns BGR, matplotlib expects RGB
    heatmap = cv2.applyColorMap(np.uint8(255 * attn_map), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    cam = 0.5 * heatmap + np.repeat(img_np[..., np.newaxis], 3, axis=2)
    cam = cam / cam.max()

    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(img_np, cmap='gray')
    plt.title(f'Slice {slice_idx} Original')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(cam)
    plt.title(f'Grad-CAM (Classe {pred_class})')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    return cam