# Prédiction

In [None]:
import os
import sys
import matplotlib.pyplot as plt
import yaml

# Add the parent directory to the import path; presumably this is where the
# ``src`` package used by the ``from src...`` imports in later cells lives.
sys.path.append('..')

# Load the training configuration.
config_path = '../configs/config_training.yaml'
# Explicit UTF-8: without it the file is decoded with the platform's locale
# encoding (e.g. cp1252 on Windows), which can mis-read non-ASCII characters.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("Configuration chargée:")
print(f"| Modèle: {config['model']['name']} ({config['model']['in_channels']} -> {config['model']['out_channels']})")

## Données

In [None]:
import warnings

# Test image / label directories, read from the config.
test_img_path = config['data']['test_img_path']
test_label_path = config['data']['test_label_path']

# os.listdir returns entries in arbitrary order. Sort both listings so that
# test_img[i] and test_label[i] are paired deterministically. This assumes that
# image and label file names sort into the same order; TODO confirm the
# naming scheme.
test_img_files = sorted(os.listdir(test_img_path))
test_label_files = sorted(os.listdir(test_label_path))

test_img = [os.path.join(test_img_path, file) for file in test_img_files]
test_label = [os.path.join(test_label_path, file) for file in test_label_files]

print(f'Number of testing images: {len(test_img)}')
print(f'Number of testing labels: {len(test_label)}')

# Images and labels are paired by position later on, so a count mismatch means
# at least some pairs are wrong.
if len(test_img) != len(test_label):
    warnings.warn(
        "Image and label counts differ; positional image/label pairing is likely misaligned."
    )

In [None]:
# Utils pour visualiser les images
from PIL import Image 

def load_img(file_path):
    """Load the image at ``file_path`` fully into memory, for display.

    The file is opened in a ``with`` block and the pixel data is copied out, so the
    file handle is released before returning. A bare ``Image.open`` keeps the file
    open until the pixels are loaded, and for multi-frame formats such as GIF/TIFF
    even after that.
    """
    with Image.open(file_path) as img:
        # copy() forces the pixel data to load and returns an in-memory image
        # that stays valid after the file is closed.
        return img.copy()

def load_label(file_path):
    """Load the label at ``file_path`` as a single-channel 8-bit grayscale ('L') image.

    The source file is opened in a ``with`` block so its handle is released before
    returning. ``convert`` produces a new in-memory image that does not depend on
    the closed file.
    """
    with Image.open(file_path) as img:
        return img.convert('L')

## Dataloader test

In [None]:
import torchvision.transforms as transforms
from src.dataset.dataset import VesselDataset
from torch.utils.data import DataLoader

# Target side length, taken from the config.
image_size = config['data']['image_size']

# Resize to a square image_size x image_size, then convert to a tensor.
# Whether VesselDataset applies this to the labels as well is decided inside
# VesselDataset; confirm there.
transform = transforms.Compose([transforms.Resize((image_size, image_size)), transforms.ToTensor()])

# How many images to evaluate (adjust as needed). Only this prefix of the
# path lists is used.
nb_images_to_test = 200

test_dataset = VesselDataset(
    test_img[:nb_images_to_test], test_label[:nb_images_to_test], transform=transform
)
# One sample per batch, no shuffling: the loader walks the dataset in index order.
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

print(f"Test dataset: {len(test_dataset)} images")

## Chargement du modèle

In [None]:
import torch
from src.utils.resume_training import load_previous_model

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

# Checkpoint to load, taken from the config (replace with a custom path if needed).
model_path = config['training']['resume_model_path']

# Build the model with the channel counts from the config; presumably
# load_previous_model restores the checkpoint at model_path onto `device`.
# NOTE(review): confirm that the model is put in eval mode (here or inside
# `predict`) before inference.
model = load_previous_model(
    path=model_path,
    device=device,
    in_channels=config['model']['in_channels'],
    out_channels=config['model']['out_channels'],
)

## Prédictions

In [None]:
from src.models.predict import predict

# Run the model over test_loader on `device`. The later cells index the result
# per image (preds[idx]) and treat each entry as a tensor; presumably there is one
# prediction per test sample, in loader order. Confirm this in src.models.predict.
preds = predict(model, test_loader, device)
print(f"Prédictions générées: {len(preds)} images")

## Visualisation

In [None]:
# Show one prediction next to its input image and ground truth.
idx = 0

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Prediction tensor -> NumPy array on the CPU, with size-1 dimensions dropped.
pred = preds[idx].detach().cpu().squeeze().numpy()

# (data, title, colormap) for each panel, from left to right. A colormap of None
# is matplotlib's default, exactly as when no cmap is passed.
panels = [
    (load_img(test_img[idx]), 'Image originale', None),
    (load_label(test_label[idx]), 'Ground truth', 'gray'),
    (pred, 'Prédiction', 'gray'),
]
for ax, (data, title, cmap) in zip(axes, panels):
    ax.imshow(data, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Affichage de plusieurs prédictions
def visualize_predictions(test_imgs, test_labels, predictions, indices=(0, 1, 2, 3)):
    """Plot (input image, ground truth, prediction) triplets, one row per index.

    Args:
        test_imgs: Paths of the input images. Entry ``i`` must correspond to
            ``predictions[i]``.
        test_labels: Paths of the ground-truth masks, aligned the same way.
        predictions: Indexable sequence of prediction tensors.
        indices: Indices to display. Any index outside ``[0, len(predictions))``
            is skipped.

    Returns:
        None. Nothing is drawn when no index is valid.
    """
    # Filter up front so that skipped indices do not leave blank rows.
    # Negative indices are rejected as well: ``predictions`` may cover only a
    # prefix of ``test_imgs``, so e.g. -1 would pair the last prediction with a
    # different image.
    valid = [idx for idx in indices if 0 <= idx < len(predictions)]
    if not valid:
        # plt.subplots(0, 3) would raise, so there is nothing to draw.
        print("No valid indices to display.")
        return

    n = len(valid)
    # squeeze=False keeps ``axes`` 2-D even when there is a single row.
    _, axes = plt.subplots(n, 3, figsize=(15, 5 * n), squeeze=False)

    for row, idx in zip(axes, valid):
        ax_img, ax_label, ax_pred = row

        ax_img.imshow(load_img(test_imgs[idx]))
        ax_img.set_title(f'Image {idx}')

        ax_label.imshow(load_label(test_labels[idx]), cmap='gray')
        ax_label.set_title('Ground truth')

        pred = predictions[idx].detach().cpu().squeeze().numpy()
        ax_pred.imshow(pred, cmap='gray')
        ax_pred.set_title('Prédiction')

        for ax in row:
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Show the first four test samples.
visualize_predictions(test_img, test_label, preds, indices=list(range(4)))