In [None]:
import tensorflow as tf
# Enable on-demand GPU memory allocation so the TensorFlow runtime does not
# grab all device memory at initialization. This has to happen before any GPU
# is initialized; if the runtime is already up, set_memory_growth raises
# RuntimeError, which is reported here and otherwise ignored (best effort).
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

# Synchronous data-parallel training across all visible GPUs (TensorFlow
# falls back to the CPU when no GPU is found). Report the replica count so
# the actual device setup is visible in the cell output.
strategy = tf.distribute.MirroredStrategy()
print(f"Number of replicas in sync: {strategy.num_replicas_in_sync}")

In [None]:
from src.data_loader.create_dataset import create_tf_datasets

from src.models.attention_unet import AttentionUNet3D
from src.models.unet import UNet3D

from src.trainer.train import train_model

# Input Parameters

In [None]:
# --- Dataset parameters ------------------------------------------------------
# Source folders: the original training data and its augmented counterpart.
# These are placeholders and must be pointed at real data before running.
train_dir = '/path/to/training_data/'
train_dir_augm = '/path/to/training_data_augm/'

# Patch extraction: cubic patches with an edge of 64, taken every 64 steps.
patch_shape = (64,) * 3
patch_step = 64

# Train/validation split of the extracted patches.
percent_val = 0.2    # fraction of patches held out for validation
random_state = 42    # fixed seed so the split is reproducible

In [None]:
# --- Training parameters -----------------------------------------------------
# Architectures to train and the names they are saved under. The training
# cells pair the two lists by position with zip(), so keep them aligned.
models = ['UNet3D', 'AttentionUNet3D']
model_names = ['bc_unet3d', 'bc_attentionunet3d']

# Optimizer, loss and metrics passed on to train_model.
optimizer = 'adam'
loss = 'binary_crossentropy'
metrics = ['accuracy', 'precision', 'recall']

# Schedule.
epochs = 50        # upper bound on the number of training epochs
batch_size = 4

# NOTE(review): out_dir is not passed to train_model anywhere in this
# notebook (only a bare filename is), so it has no effect as written.
# Confirm where train_model actually writes the models.
out_dir = '/content/save_models/'

# Create datasets and train models on the original data

In [None]:
# Build the training/validation datasets from the original data in train_dir,
# using the patch and split settings from the parameter cell.
train_dataset, val_dataset = create_tf_datasets(
    train_dir,
    patch_shape=patch_shape,
    patch_step=patch_step,
    percent_val=percent_val,
    random_state=random_state,
)

In [None]:
# Train models for the original data.
# Map each supported architecture id to its model class. An id in `models`
# that is missing here is reported and skipped, instead of being handed to
# train_model as `model=None` (which the previous if/elif chain did).
model_builders = {
    'UNet3D': UNet3D,
    'AttentionUNet3D': AttentionUNet3D,
}

# zip() silently drops unmatched trailing entries; fail loudly instead so a
# model is never trained without a save name (or vice versa).
if len(models) != len(model_names):
    raise ValueError(
        f"'models' has {len(models)} entries but 'model_names' has "
        f"{len(model_names)}; they must be paired one-to-one."
    )

for modelId, model_name in zip(models, model_names):
    builder = model_builders.get(modelId)
    if builder is None:
        print(f"Error: The model '{modelId}' is not defined. Please check the model name.")
        continue

    with strategy.scope():
        # Build the model (and let train_model set it up with the optimizer,
        # loss and metrics) inside the strategy scope, so its variables are
        # created under the MirroredStrategy.
        model = builder().build_model()

        print(f"Training model '{modelId}' and saving as '{model_name}'...")

        train_model(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            model=model,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            epochs=epochs,
            batch_size=batch_size,
            filename=model_name + ".keras",
            model_name=model_name,
        )

# Create datasets and train models on the augmented data

In [None]:
# Build the training/validation datasets from the augmented data in
# train_dir_augm. This rebinds train_dataset / val_dataset, so the cells
# below train on the augmented data.
train_dataset, val_dataset = create_tf_datasets(
    train_dir_augm,
    patch_shape=patch_shape,
    patch_step=patch_step,
    percent_val=percent_val,
    random_state=random_state,
)

In [None]:
# Train models for the augmented data.
# Map each supported architecture id to its model class. An id in `models`
# that is missing here is reported and skipped, instead of being handed to
# train_model as `model=None` (which the previous if/elif chain did).
model_builders = {
    'UNet3D': UNet3D,
    'AttentionUNet3D': AttentionUNet3D,
}

# zip() silently drops unmatched trailing entries; fail loudly instead.
if len(models) != len(model_names):
    raise ValueError(
        f"'models' has {len(models)} entries but 'model_names' has "
        f"{len(model_names)}; they must be paired one-to-one."
    )

for modelId, model_name in zip(models, model_names):
    builder = model_builders.get(modelId)
    if builder is None:
        print(f"Error: The model '{modelId}' is not defined. Please check the model name.")
        continue

    # The original-data training cell already saved under
    # `model_name + ".keras"`. Reusing that name here would overwrite those
    # models, so augmented runs get a distinct suffix.
    augm_model_name = f"{model_name}_augm"

    with strategy.scope():
        # Build the model (and let train_model set it up with the optimizer,
        # loss and metrics) inside the strategy scope, so its variables are
        # created under the MirroredStrategy.
        model = builder().build_model()

        print(f"Training model '{modelId}' and saving as '{augm_model_name}'...")

        train_model(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            model=model,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            epochs=epochs,
            batch_size=batch_size,
            filename=augm_model_name + ".keras",
            model_name=augm_model_name,
        )

In [None]:
# Completion marker. Under "Run All", Jupyter stops at the first cell that
# raises, so seeing this message means no earlier cell raised.
print("All calculations are successfully finished")