In [None]:
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim

# Project layout. Path("..") is resolved against the kernel's working
# directory; this assumes the notebook runs from a direct subdirectory of the
# project root (e.g. notebooks/) — confirm if the notebook is moved.
PROJECT_FOLDER = Path("..").resolve()
DATA_FOLDER = PROJECT_FOLDER / "data"
SRC_PATH = PROJECT_FOLDER / "src"
TILES_DIR = DATA_FOLDER / "tiles"

# Global torch seed so weight initialisation and other torch-side randomness
# are reproducible across Restart & Run All. build_dataloaders receives its
# own `seed=` argument for the data split.
SEED = 42
torch.manual_seed(SEED)

# Make the project's src/ directory importable (for the `tortoise` package)
# without installing it; guarded so re-running the cell does not add duplicates.
if str(SRC_PATH) not in sys.path:
    sys.path.insert(0, str(SRC_PATH))

print(f"Project Folder: {PROJECT_FOLDER}")
print(f"Tiles Directory: {TILES_DIR}")

In [None]:
from tortoise.train import *
from tortoise.dataloader import *
from tortoise.hparams import *
from tortoise.checkpoints import *
from tortoise.utils import *

In [None]:
from tortoise.dataloader import build_dataloaders

# Build train/val/test loaders from the tile index. The tiles location is the
# TILES_DIR constant from the configuration cell; there is no lowercase
# `tiles_dir` variable in this notebook.
train_loader, val_loader, test_loader, set_maps = build_dataloaders(
    tiles_dir=TILES_DIR,
    csv_file=DATA_FOLDER / "tile_index.csv",
    batch_size=128,
    # NOTE(review): presumably selects multispectral vs. RGB input bands — confirm in tortoise.dataloader.
    use_ms=True,
    use_rgb=False,
    seed=42,
    train_ratio=0.8,
    # Presumably the remaining 0.1 of the tiles forms the test split — confirm in build_dataloaders.
    val_ratio=0.1,
    num_workers=4,
)
print(f"Train samples: {len(train_loader.dataset)}")
print(f"Val samples: {len(val_loader.dataset)}")

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Load the training hyperparameters. Make sure to put the appropriate
# hyperparameters in configs/hyperparams.yml before running this cell.
# NOTE(review): load_hparams is not imported by name here; it presumably comes
# from the `from tortoise.hparams import *` star import. Later cells index the
# result as a nested mapping (hparams['train']['epochs']).
hparams = load_hparams()

In [None]:

# Optimizer, learning-rate schedule, mixed precision, and the training run.

# NOTE(review): `model` is not created in any visible cell; it must come from a
# cell run earlier (or since deleted). Fail with a clear message on a fresh
# kernel instead of an opaque NameError inside build_optimizer.
if "model" not in globals():
    raise NameError(
        "`model` is not defined; build the network in a cell above before training."
    )

num_epochs = hparams['train']['epochs']

# Mixed-precision toggle. AMP is only enabled on CUDA because the GradScaler
# below targets CUDA. The same flag is forwarded to train_model so it always
# agrees with whether a scaler exists.
USE_AMP = True
use_amp = USE_AMP and device.type == "cuda"

optimizer = build_optimizer(model, hparams)
# T_max follows the configured epoch count, so the cosine curve reaches its
# minimum on the final epoch instead of at a fixed, hard-coded epoch.
# NOTE(review): assumes train_model calls scheduler.step() once per epoch — confirm.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
scaler = torch.amp.GradScaler("cuda") if use_amp else None

# Positive-class weight. It is presumably fed to a BCE-with-logits style loss
# inside train_model to counter class imbalance — confirm how 2.75 was derived.
pos_weight = torch.tensor([2.75], device=device)
# Forwarded as train_model(alpha=...). Its meaning is defined in tortoise.train
# (presumably a loss-mixing weight) — confirm.
LOSS_ALPHA = 0.8
# Forwarded as train_model(threshold=...). Presumably the probability cut-off
# used to binarise predictions for IoU — confirm.
PRED_THRESHOLD = 0.6

checkpoint_path = PROJECT_FOLDER / "checkpoints" / "best_model.pth"
# Create the directory up front. Saving to a path whose parent directory is
# missing (e.g. via torch.save) would otherwise fail at the first checkpoint.
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

model, train_losses, val_losses, train_ious, val_ious = train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    device,
    pos_weight,
    num_epochs=num_epochs,
    checkpoint_path=checkpoint_path,
    use_amp=use_amp,
    scaler=scaler,
    alpha=LOSS_ALPHA,
    threshold=PRED_THRESHOLD,
    early_stopping_patience=None,
)




In [None]:
import matplotlib.pyplot as plt

# Learning curves side by side: loss in the left panel, IoU in the right.
# Each row: (train series, train label, val series, val label, title, y-label).
_curve_panels = (
    (train_losses, 'Train Loss', val_losses, 'Val Loss',
     'Training & Validation Loss', 'Loss'),
    (train_ious, 'Train IoU', val_ious, 'Val IoU',
     'Training & Validation IoU', 'IoU'),
)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, (train_curve, train_label, val_curve, val_label, title, y_label) in zip(axes, _curve_panels):
    ax.plot(train_curve, label=train_label)
    ax.plot(val_curve, label=val_label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()

fig.tight_layout()
plt.show()