In [None]:
import tensorflow as tf

# Turn on memory growth for every visible GPU. An empty device list simply
# skips the loop, so no separate "are there any GPUs?" guard is required.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    # Per the TensorFlow docs, memory growth must be configured before the
    # GPUs are initialized; when that fails, report it and keep going.
    print(err)

# Distribution strategy whose scope wraps the training calls later on.
strategy = tf.distribute.MirroredStrategy()

In [None]:
from data_loader.create_dataset import create_tf_datasets

from models.attention_unet import AttentionUNet3D
from models.unet import UNet3D

from trainer.train import train_model

import numpy as np


In [None]:
# Directory handed to create_tf_datasets as the source of the training data
# (relative path, so it resolves against the kernel's working directory).
train_dir = 'data/ImageSegmentation/BC/training_data'

# Arguments for create_tf_datasets, named so they are easy to find and tune.
# NOTE(review): presumably 64x64x64 patches taken every 32 voxels with a
# 20% validation hold-out -- confirm against create_tf_datasets.
PERCENT_VAL = 0.2
PATCH_SHAPE = (64, 64, 64)
PATCH_STEP = 32
RANDOM_STATE = 42

train_dataset, val_dataset = create_tf_datasets(
    train_dir,
    percent_val=PERCENT_VAL,
    patch_shape=PATCH_SHAPE,
    patch_step=PATCH_STEP,
    random_state=RANDOM_STATE,
)

In [None]:
# Keyword arguments forwarded unchanged to train_model for the UNet3D run.
unet_settings = dict(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy', 'precision', 'recall'],
    epochs=1,
    batch_size=16,
    filename="unet3d.keras",
)

# The model is built and trained inside the strategy scope, so both the
# build_model() call and train_model() run under the MirroredStrategy.
with strategy.scope():
    train_model(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        model=UNet3D().build_model(),
        **unet_settings,
    )

In [None]:
# Train the attention variant with the same datasets and hyperparameters as
# the UNet3D run. The model is built inside the strategy scope so that
# build_model() and train_model() both run under the MirroredStrategy.
with strategy.scope():
    train_model(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        model=AttentionUNet3D().build_model(),
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy', 'precision', 'recall'],
        epochs=1,
        batch_size=16,
        # Distinct filename: the UNet3D run already uses "unet3d.keras", and
        # reusing it here would (assuming train_model saves to `filename`)
        # overwrite that model with this one.
        filename="attention_unet3d.keras",
    )