In [1]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from dataset import CoralDataModule
from model import CoralSegFormer

In [2]:
# UPDATE ME!
# Configure paths
user = "linneamw"
dataset_dir = f"/home/{user}/sadow_koastore/shared/coral_seg/processed_images_real2/"
results_dir = "../../results/"

# Configure hyperparameters
batch_size = 8 
epochs = 30
split_ratio = 0.8
num_workers = 4
samples_per_image = 100
crop_size = (512, 512)

In [3]:
# Initialize the data module
data_module = CoralDataModule(
    root_dir=dataset_dir, 
    batch_size=batch_size, 
    split_ratio=split_ratio,
    num_workers=num_workers,
    samples_per_image=samples_per_image,
    crop_size=crop_size
)

data_module.setup() 
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

Scanning /home/linneamw/sadow_koastore/shared/coral_seg/processed_images_real2/...
Found 118 valid image/mask pairs.
Training on 94 images.
Validating on 24 images.




In [8]:
# Load an example batch to determine input shape
example_batch = next(iter(train_loader))
print(example_batch)



[tensor([[[[ 0.5193,  0.5707,  0.7077,  ..., -2.0837, -2.0837, -2.0837],
          [ 0.6392,  0.6392,  0.6906,  ..., -2.0837, -2.0837, -2.0837],
          [ 0.6906,  0.6734,  0.6906,  ..., -2.0837, -2.0837, -2.0837],
          ...,
          [-1.6555, -1.4500, -1.2103,  ..., -1.0390, -1.0048, -1.0562],
          [-1.6042, -1.3815, -1.1418,  ..., -1.1418, -1.0390, -1.0219],
          [-1.4843, -1.2274, -1.0390,  ..., -1.2274, -1.0733, -1.0219]],

         [[ 0.7654,  0.8179,  0.9405,  ..., -1.9307, -1.9307, -1.9307],
          [ 0.8354,  0.8354,  0.9055,  ..., -1.9307, -1.9307, -1.9307],
          [ 0.8529,  0.8354,  0.8880,  ..., -1.9307, -1.9307, -1.9307],
          ...,
          [-1.4930, -1.2829, -1.0203,  ..., -0.8978, -0.8627, -0.8803],
          [-1.4230, -1.1954, -0.9503,  ..., -1.0028, -0.8978, -0.8627],
          [-1.3004, -1.0378, -0.8452,  ..., -1.0903, -0.9328, -0.8452]],

         [[ 0.4091,  0.4788,  0.6531,  ..., -1.6476, -1.6476, -1.6476],
          [ 0.4265,  0.4439, 

In [None]:
# Initialize the model
model = CoralSegFormer(learning_rate=3e-4)

# Callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath= results_dir + 'checkpoints',
    filename='coral-segformer-{epoch:02d}-{val_loss:.2f}',
    save_top_k=2,
    mode='min',
)

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min'
)

callbacks = [checkpoint_callback, early_stop_callback]

In [None]:
# Initialize PyTorch Lightning Trainer
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="auto", # Auto-detects GPU/CPU
    devices=1,
    callbacks=callbacks,
    log_every_n_steps=10
)

# Train
print("Starting Training...")
trainer.fit(model, data_module)

In [None]:
import torch
from model import CoralSegFormer

ckpt_path = "results/checkpoints/coral-segformer-epoch=12-val_loss=0.34.ckpt"

model = CoralSegFormer.load_from_checkpoint(ckpt_path)
model.eval()
model.cuda()  # remove if CPU-only

In [None]:
from pathlib import Path
import random
from PIL import Image
import torchvision.transforms as T
import torch

# root directory containing many subfolders
root_dir = Path("/path/to/test_root")

# find all image.png files recursively
image_paths = list(root_dir.rglob("image.png"))

assert len(image_paths) > 0, "No image.png files found!"

# randomly select one
img_path = random.choice(image_paths)
print(f"Using image: {img_path}")

# IMPORTANT: match training transforms
transform = T.Compose([
    T.Resize((512, 512)),
    T.ToTensor(),
    # Uncomment if used during training
    # T.Normalize(mean=[0.485, 0.456, 0.406],
    #             std=[0.229, 0.224, 0.225]),
])

img = Image.open(img_path).convert("RGB")

x = transform(img).unsqueeze(0)
x = x.cuda() if torch.cuda.is_available() else x

In [None]:
with torch.no_grad():
    logits = model(x)
    pred = torch.argmax(logits, dim=1)  # [1, H, W]

pred_mask = pred.squeeze().cpu().numpy()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

img_np = np.array(img.resize((512, 512)))

plt.figure(figsize=(12,4))

plt.subplot(1,3,1)
plt.imshow(img_np)
plt.title("Input image")
plt.axis("off")

plt.subplot(1,3,2)
plt.imshow(pred_mask, cmap="tab10", vmin=0, vmax=4)
plt.title("Predicted mask")
plt.axis("off")

plt.subplot(1,3,3)
plt.imshow(img_np)
plt.imshow(pred_mask, cmap="tab10", alpha=0.5, vmin=0, vmax=4)
plt.title("Overlay")
plt.axis("off")

plt.tight_layout()
plt.show()