In [1]:
import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from dataset import CoralDataModule
from model import CoralSegFormer

In [2]:
# UPDATE ME!
# Paths: `dataset_dir` is handed to the data module as its root directory,
# `results_dir` is the base folder under which checkpoints are written.
user = "linneamw"
dataset_dir = "/home/{}/sadow_koastore/shared/coral_seg/processed_images_real2/".format(user)
results_dir = "../../results/"

# Hyperparameters
batch_size, epochs = 8, 30          # samples per batch, maximum training epochs
split_ratio = 0.8                   # fraction of images used for training (rest is validation)
num_workers = 4                     # forwarded to the data module
samples_per_image, crop_size = 100, (512, 512)   # forwarded to the data module

In [3]:
# Seed the Python, NumPy and PyTorch RNGs before any data is split or sampled,
# so that Restart Kernel -> Run All reproduces the same run. `workers=True`
# additionally makes Lightning derive seeds for DataLoader worker processes.
seed = 42
pl.seed_everything(seed, workers=True)

# Initialize the data module from the paths/hyperparameters set in the config cell
data_module = CoralDataModule(
    root_dir=dataset_dir,
    batch_size=batch_size,
    split_ratio=split_ratio,
    num_workers=num_workers,
    samples_per_image=samples_per_image,
    crop_size=crop_size,
)

# Call setup() by hand so the loaders can be inspected before training.
# NOTE(review): trainer.fit(...) presumably runs the data module's setup hook
# again; confirm CoralDataModule.setup() is safe to call twice.
data_module.setup()
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

Scanning /home/linneamw/sadow_koastore/shared/coral_seg/processed_images_real2/...
Found 118 valid image/mask pairs.
Training on 94 images.
Validating on 24 images.




In [None]:
# Pull a single batch from the training loader and record the per-sample
# shape of its 'image' entry (everything after the leading batch axis).
train_batches = iter(train_loader)
example_batch = next(train_batches)
example_images = example_batch['image']
input_shape = example_images.shape[1:]

print(f"Input shape: {input_shape}")



In [None]:
# Initialize the model
model = CoralSegFormer(learning_rate=3e-4)

# Callbacks
# Both callbacks watch the 'val_loss' metric (lower is better).
# NOTE(review): assumes CoralSegFormer logs a metric named 'val_loss' — confirm in model.py.

# Keep the two checkpoints with the lowest val_loss. The directory is built
# with os.path.join so it stays correct whether or not `results_dir` ends
# with a path separator.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=os.path.join(results_dir, 'checkpoints'),
    filename='coral-segformer-{epoch:02d}-{val_loss:.2f}',
    save_top_k=2,
    mode='min',
)

# Stop training once val_loss has gone 5 validation checks without improving.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min',
)

callbacks = [checkpoint_callback, early_stop_callback]

In [None]:
# Initialize PyTorch Lightning Trainer
# Collect the trainer settings first so they can be inspected or tweaked in one place.
trainer_options = dict(
    max_epochs=epochs,
    accelerator="auto",  # Auto-detects GPU/CPU
    devices=1,
    callbacks=callbacks,
    log_every_n_steps=10,
)
trainer = pl.Trainer(**trainer_options)

# Train
print("Starting Training...")
trainer.fit(model, data_module)