In [None]:
""" Check pytorch installation """
import torch

# Jupyter echoes the final expression: True means the device-selection lines in the
# training cells will pick "cuda", False means everything runs on the CPU.
has_cuda = torch.cuda.is_available()
has_cuda


In [None]:
""" Train the model on a synthetic dataset composed of squares with different colors (1-label problem) """
import logging
import random

import torch

from src.model.unet import UNetModel
from src.dataset.square_dataset import SquareDataset
from src.benchmark.train import train, HyperParameters, EarlyStoppingParams


NB_EPOCHS = 20
BATCH_SIZE = 16
IMG_WIDTH = 128
IMG_HEIGHT = 128
LEARNING_RATE = 0.0005
BASE_FM_NUMBER = 24
SEED = 42  # fixed seed so re-running this cell gives reproducible results

MODEL_PATH = "./square_model_dict.pth"

# force=True: without it, basicConfig silently does nothing when the root logger already
# has handlers (e.g. configured by the kernel or by an earlier run of the notebook).
# No filename is given, so records go to stderr; filemode only matters with a filename.
logging.basicConfig(level=logging.DEBUG, force=True)

# Seed before building the model so torch-based randomness (e.g. weight init) is reproducible.
# NOTE(review): SquareDataset's random source is not visible here; if it draws from numpy,
# seed numpy as well.
random.seed(SEED)
torch.manual_seed(SEED)

# select device for computation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Selected device : {device}")

# create model -- args presumably (input channels, number of labels, base feature maps);
# TODO confirm against UNetModel
model = UNetModel(3, 1, BASE_FM_NUMBER)

# create datasets: two separate SquareDataset instances built with identical parameters
nb_images = 100  # images per split
train_dataset = SquareDataset(3, IMG_WIDTH, IMG_HEIGHT, nb_images, 20, 80)
validation_dataset = SquareDataset(3, IMG_WIDTH, IMG_HEIGHT, nb_images, 20, 80)
hyperparameters = HyperParameters(
    NB_EPOCHS, BATCH_SIZE, LEARNING_RATE, EarlyStoppingParams(3, 0.001)
)

train(device, model, None, train_dataset, validation_dataset, hyperparameters, MODEL_PATH)


In [None]:
""" test the model on a synthetic dataset composed of squares with different colors (1-label problem)"""
from src.benchmark.test import test, TestMetrics

# Relies on torch, device, MODEL_PATH, BASE_FM_NUMBER, IMG_WIDTH, IMG_HEIGHT, UNetModel and
# SquareDataset from the square-dataset training cell: run that cell first.

nb_images = 10  # test set of 10 images
model = UNetModel(3, 1, BASE_FM_NUMBER).to(device)
# map_location=device loads the checkpoint tensors directly onto the selected device, so
# weights saved from a GPU run can still be restored on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
test_dataset = SquareDataset(3, IMG_WIDTH, IMG_HEIGHT, nb_images, 20, 80)

test(device, model, test_dataset, [TestMetrics.ACCURACY], verbose=1)


In [None]:
""" Train the model on a jsrt dataset composed of Chest X-ray images (4-labels problem) """
import logging
import os
import random
from pathlib import Path

import torch

from src.model.unet import UNetModel
from src.dataset.custom_image_mask_dataset import CustomImageMaskDataset
from src.dataset.data_augmentation.color_augmentation import ColorimetricAugmentationParams
from src.dataset.data_augmentation.spatial_augmentation import SpatialAugmentationParams
from src.dataset.data_augmentation.augmented_semantic_segmentation import (
    AugmentedSemanticSegmentationDataset,
)
from src.benchmark.train import train, HyperParameters, EarlyStoppingParams

# Upper bound on epochs; EarlyStoppingParams below presumably (patience, min_delta) --
# TODO confirm -- is what is expected to end training.
NB_EPOCHS = 1000
BATCH_SIZE = 8
LEARNING_RATE = 0.0005
BASE_FM_NUMBER = 24
NB_LABELS = 4
SEED = 42  # fixed seed so re-running this cell gives reproducible results

MODEL_PATH = "./jsrt_model_dict.pth"

# Keep PIL's DEBUG messages from polluting the notebook output.
pil_logger = logging.getLogger("PIL")
pil_logger.setLevel(logging.INFO)

# force=True is required: the square-dataset cell earlier in the notebook already called
# basicConfig, and without force this call would be a silent no-op (root stays at DEBUG).
logging.basicConfig(level=logging.INFO, force=True)

# Seed torch-based randomness (weight init, and augmentation if it uses torch/random).
random.seed(SEED)
torch.manual_seed(SEED)

# select device for computation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Selected device : {device}")

# create model: 1 input channel (grayscale X-ray, presumably), NB_LABELS output labels
model = UNetModel(1, NB_LABELS, BASE_FM_NUMBER)

# create datasets
current_dir = Path(os.getcwd())
DATA_DIR = current_dir.joinpath("../data/JSRT-segmentation")

# Augment the training split only. The validation loss drives early stopping, so it should
# be measured on the same kind of un-augmented images the test cell evaluates, not on
# (presumably random) augmented views that change from epoch to epoch.
spatial_params = SpatialAugmentationParams()
colorimetric_params = ColorimetricAugmentationParams()
train_dataset = AugmentedSemanticSegmentationDataset(
    CustomImageMaskDataset(DATA_DIR.joinpath("train")), spatial_params, colorimetric_params
)
validation_dataset = CustomImageMaskDataset(DATA_DIR.joinpath("val"))

hyperparameters = HyperParameters(
    NB_EPOCHS, BATCH_SIZE, LEARNING_RATE, EarlyStoppingParams(5, 0.001)
)

train(
    device,
    model,
    None,
    train_dataset,
    validation_dataset,
    hyperparameters,
    MODEL_PATH,
)


In [None]:
""" test the model on the jsrt dataset composed of Chest X-ray images (4-labels problem)  """
from src.benchmark.test import test, TestMetrics

# Relies on torch, device, MODEL_PATH, NB_LABELS, BASE_FM_NUMBER, current_dir, UNetModel and
# CustomImageMaskDataset from the JSRT training cell: run that cell first.

model = UNetModel(1, NB_LABELS, BASE_FM_NUMBER).to(device)
# map_location=device loads the checkpoint tensors directly onto the selected device, so
# weights saved from a GPU run can still be restored on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))

# The test split is used as-is, without data augmentation.
test_dataset = CustomImageMaskDataset(current_dir.joinpath("../data/JSRT-segmentation/test"))

test(
    device,
    model,
    test_dataset,
    [
        TestMetrics.ACCURACY,
        TestMetrics.RECALL_PER_LABEL,
        TestMetrics.PRECISION_PER_LABEL,
        TestMetrics.F1_SCORE_PER_LABEL,
        TestMetrics.IOU_PER_LABEL,
    ],
    verbose=0,
)