In [None]:
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim

# Project layout. The root is resolved relative to the current working directory,
# so this assumes the notebook is run from a direct child of the project root
# (presumably a `notebooks/` folder) — adjust if it lives elsewhere.
PROJECT_FOLDER = Path("..").resolve()
SRC_PATH = PROJECT_FOLDER / "src"      # local package sources (provides `tortoise`)
DATA_FOLDER = PROJECT_FOLDER / "data"
TILES_DIR = DATA_FOLDER / "tiles"      # input tiles handed to build_dataloaders

# Make the local `src/` packages importable without installing them. The entry is
# prepended so it wins over any same-named installed package, and only added once
# so re-running this cell does not grow sys.path.
_src_entry = str(SRC_PATH)
if _src_entry not in sys.path:
    sys.path.insert(0, _src_entry)

print(f"Project Folder: {PROJECT_FOLDER}", f"Tiles Directory: {TILES_DIR}", sep="\n")

In [None]:
from tortoise.dataloader import build_dataloaders
from tortoise.model import R2AttU_Net
from tortoise.train import train_model

In [None]:
# Build train/val/test loaders over the tiles. With use_ms=True / use_rgb=False the
# inputs are presumably multispectral-only — TODO confirm the resulting channel
# count against tortoise.dataloader (the model cell uses img_ch=13).
BATCH_SIZE = 4
SPLIT_SEED = 42     # passed as `seed`; presumably makes the split reproducible
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1     # the remainder (1 - TRAIN_RATIO - VAL_RATIO) presumably forms the test split

train_loader, val_loader, test_loader = build_dataloaders(
    tiles_dir=TILES_DIR,
    batch_size=BATCH_SIZE,
    use_ms=True,
    use_rgb=False,
    seed=SPLIT_SEED,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
)

# Report every split (including test) so the ratios can be sanity-checked.
split_sizes = {
    "Train": len(train_loader.dataset),
    "Val": len(val_loader.dataset),
    "Test": len(test_loader.dataset),
}
for split_name, n_samples in split_sizes.items():
    print(f"{split_name} samples: {n_samples}")

# Fail fast: an empty train or val split would otherwise likely surface later as an
# opaque error or a meaningless metric inside train_model.
assert split_sizes["Train"] > 0 and split_sizes["Val"] > 0, (
    f"Empty split: {split_sizes} — check TILES_DIR and the split ratios"
)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed torch's RNGs (CPU and every CUDA device) so the model's random weight
# initialisation is reproducible; the dataloader `seed` presumably only covers the split.
MODEL_SEED = 42
torch.manual_seed(MODEL_SEED)

# img_ch must equal the channel count of the tiles produced with use_ms=True,
# use_rgb=False — 13 presumably matches the multispectral bands; TODO confirm.
# output_ch=1: a single output channel, presumably a binary-mask logit (see pos_weight).
model = R2AttU_Net(img_ch=13, output_ch=1, t=2).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Divide the LR by 10 once the monitored quantity (mode='min', presumably val loss)
# has not improved for 2 epochs. The `verbose` flag is deprecated since PyTorch 2.2;
# read the current rate from optimizer.param_groups[0]["lr"] instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

# Positive-class weight, presumably forwarded to BCEWithLogitsLoss in train_model;
# 1.0 means no re-weighting. Created directly on the target device.
pos_weight = torch.tensor([1.0], device=device)

In [None]:
# Checkpoint target for train_model. Create models/ up front so a save there cannot
# fail on a fresh checkout where the folder does not exist yet.
checkpoint_path = PROJECT_FOLDER / "models" / "example_model.pth"
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

NUM_EPOCHS = 2                # short demo run; increase for real training
EARLY_STOPPING_PATIENCE = 5   # NOTE: larger than NUM_EPOCHS, so early stopping cannot fire in this demo

# Mixed precision (FP16) targets CUDA; on a CPU-only machine it is unsupported, so
# only request it when a GPU is actually in use.
USE_AMP = device.type == "cuda"

model, train_losses, val_losses, train_ious, val_ious = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    pos_weight=pos_weight,
    num_epochs=NUM_EPOCHS,
    checkpoint_path=checkpoint_path,
    use_amp=USE_AMP,                                  # Mixed Precision (FP16), GPU only
    early_stopping_patience=EARLY_STOPPING_PATIENCE,  # stop if no improvement for this many epochs
)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# Learning curves: loss (left panel) and IoU (right panel) for train and val.
# Assumes each history list holds one value per completed epoch; epochs are
# numbered from 1 so the x-axis matches the "Epoch" label.
curves = {
    "Loss": {"Train": train_losses, "Val": val_losses},
    "IoU": {"Train": train_ious, "Val": val_ious},
}

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, (metric, series_by_split) in zip(axes, curves.items()):
    for split, values in series_by_split.items():
        epochs = range(1, len(values) + 1)
        # Markers keep short runs (e.g. 2 epochs) readable as discrete points.
        ax.plot(epochs, values, marker="o", label=f"{split} {metric}")
    ax.set(title=f"Training & Validation {metric}", xlabel="Epoch", ylabel=metric)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))  # no fractional epoch ticks
    ax.legend()

fig.tight_layout()
plt.show()