In [None]:
"""Check the PyTorch installation: import torch and show whether CUDA is usable."""
import torch


# The bare name on the last line is rendered as this cell's output (True/False).
has_cuda = torch.cuda.is_available()
has_cuda


In [12]:
""" Train the model on a synthetic dataset composed of squares with different colors """
import logging
import random

import torch

from src.model.unet import UNetModel
from src.dataset.square_dataset import SquareDataset
from src.benchmark.train import train, HyperParameters
from src.benchmark.test import test

# --- configuration -----------------------------------------------------------
NB_EPOCHS = 1000
BATCH_SIZE = 4
IMG_WIDTH = 204
IMG_HEIGHT = 204
LEARNING_RATE = 0.0001
# Seed for torch and Python's `random`, so weight initialisation and any
# torch/`random` draws are repeatable across "Restart & Run All".
SEED = 42

MODEL_PATH = "./model_dict.pth"

# Log everything at DEBUG and above to stderr (no filename => no file handler,
# so a `filemode` argument would have no effect).
logging.basicConfig(level=logging.DEBUG)

# Seed before the model and the dataset are created.
# NOTE(review): SquareDataset's random source is not visible here; if it draws
# from numpy, seed numpy as well. Bit-exact determinism on CUDA may additionally
# require torch.use_deterministic_algorithms(True).
random.seed(SEED)
torch.manual_seed(SEED)

# select device for computation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Selected device : {device}")

# create model
# NOTE(review): UNetModel(3, 1) presumably means 3 input channels -> 1 output
# channel; confirm against UNetModel's signature.
model = UNetModel(3, 1)

# create datasets
nb_images = 4
# NOTE(review): the positional arguments 3, 20 and 80 are interpreted by
# SquareDataset (not shown here) — presumably channel count and min/max square
# size; confirm and promote them to named constants above.
train_dataset = SquareDataset(3, IMG_WIDTH, IMG_HEIGHT, nb_images, 20, 80)
hyperparameters = HyperParameters(NB_EPOCHS, BATCH_SIZE, LEARNING_RATE)

train(device, model, train_dataset, hyperparameters)

# Persist only the learned parameters; the test cell rebuilds the architecture.
torch.save(model.state_dict(), MODEL_PATH)


Selected device : cuda


INFO:root:build model
INFO:root:initialize loss and optimizer
INFO:root:Epoch [1/1000], Loss: 0.7086
INFO:root:Epoch 0, Learning Rate: 0.0001
INFO:root:Epoch [2/1000], Loss: 0.7068
INFO:root:Epoch 1, Learning Rate: 0.0001
INFO:root:Epoch [3/1000], Loss: 0.7049
INFO:root:Epoch 2, Learning Rate: 0.0001
INFO:root:Epoch [4/1000], Loss: 0.7029
INFO:root:Epoch 3, Learning Rate: 0.0001
INFO:root:Epoch [5/1000], Loss: 0.7009
INFO:root:Epoch 4, Learning Rate: 0.0001
INFO:root:Epoch [6/1000], Loss: 0.6988
INFO:root:Epoch 5, Learning Rate: 0.0001
INFO:root:Epoch [7/1000], Loss: 0.6967
INFO:root:Epoch 6, Learning Rate: 0.0001
INFO:root:Epoch [8/1000], Loss: 0.6945
INFO:root:Epoch 7, Learning Rate: 0.0001
INFO:root:Epoch [9/1000], Loss: 0.6919
INFO:root:Epoch 8, Learning Rate: 0.0001
INFO:root:Epoch [10/1000], Loss: 0.6890
INFO:root:Epoch 9, Learning Rate: 0.0001
INFO:root:Epoch [11/1000], Loss: 0.6858
INFO:root:Epoch 10, Learning Rate: 0.0001
INFO:root:Epoch [12/1000], Loss: 0.6820
INFO:root:Epoch

In [17]:
""" test the model on a synthetic dataset composed of squares with different colors """
# Rebuild the architecture on the selected device and load the saved weights.
# `map_location=device` remaps tensors that were saved from a CUDA run, so the
# checkpoint also loads on a CPU-only kernel (and vice versa).
model = UNetModel(3, 1).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
# Put the model in inference mode (affects layers such as dropout/batch-norm,
# if UNetModel has any). Calling eval() again inside test() would be a no-op.
model.eval()

# NOTE(review): this evaluates on `train_dataset`, i.e. the images the model was
# fitted on, so it measures fit rather than generalisation. Build a fresh
# SquareDataset here if a held-out score is wanted.
test(device, model, train_dataset, verbose=1)


INFO:root:test batch 0
INFO:root:test batch 1
INFO:root:test batch 2
INFO:root:test batch 3
