In [None]:
"""Sanity check: PyTorch imports cleanly; the cell output shows whether a CUDA GPU is usable."""
import torch

cuda_available = torch.cuda.is_available()
cuda_available


In [None]:
""" Train the model on a synthetic dataset composed of squares with different colors

Builds a UNet, trains it on a freshly generated SquareDataset (validated on a
second one), then saves the learned weights to MODEL_PATH so the next cell can
reload them for evaluation.
"""
import logging
import random

import torch

from src.model.unet import UNetModel
from src.dataset.square_dataset import SquareDataset 
from src.benchmark.train import train, HyperParameters
from src.benchmark.test import test

NB_EPOCHS = 100
BATCH_SIZE = 2
IMG_WIDTH = 204
IMG_HEIGHT = 204
LEARNING_RATE = 0.001
# Fixed seed so weight initialisation is identical on every Restart & Run All.
# NOTE(review): this also fixes the synthetic datasets only if SquareDataset
# draws from Python's `random` or torch RNGs — TODO confirm (seed numpy too if
# it uses numpy).
SEED = 42

MODEL_PATH = "./model_dict.pth"

# Log to stderr at DEBUG level. `force=True` (Python 3.8+) replaces any handlers
# already attached to the root logger; without it basicConfig is a silent no-op
# when handlers exist (e.g. when this cell is re-run).
logging.basicConfig(level=logging.DEBUG, force=True)

random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

# select device for computation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Selected device : {device}")

# create model — positional args presumably (in_channels, out_channels); TODO confirm.
# Moved to `device` explicitly, matching the test cell; nn.Module.to() moves the
# module in place, so this is harmless if `train` also moves it.
model = UNetModel(3, 1).to(device)

# create datasets
# NOTE(review): the meaning of SquareDataset's positional args (3, ..., 20, 80)
# is defined in src.dataset.square_dataset — presumably channel count and a
# min/max square size; confirm and lift into named constants.
nb_images = 100
train_dataset = SquareDataset(3, IMG_WIDTH, IMG_HEIGHT, nb_images, 20, 80)
validation_dataset = SquareDataset(3, IMG_WIDTH, IMG_HEIGHT, nb_images, 20, 80)
hyperparameters = HyperParameters(NB_EPOCHS, BATCH_SIZE, LEARNING_RATE)

train(device, model, train_dataset, validation_dataset, hyperparameters)

# Persist only the learned parameters (state_dict), not the pickled module object.
torch.save(model.state_dict(), MODEL_PATH)


In [None]:
""" test the model on a synthetic dataset composed of squares with different colors

Reloads the weights saved by the training cell into a fresh model and
evaluates it on the validation dataset. Relies on `device`, `MODEL_PATH` and
`validation_dataset` defined in the training cell, so run that cell first.
"""
model = UNetModel(3, 1).to(device)
# map_location remaps the saved tensors onto the currently selected device, so a
# checkpoint written from a GPU still loads on a CPU-only kernel.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
# Switch any dropout / batch-norm layers to inference behaviour before
# evaluating; a no-op if `test` already calls eval() itself.
model.eval()

test(device, model, validation_dataset, verbose=1)
