In [None]:
import logging
import os
import sys
from pathlib import Path
from random_word import RandomWords


import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler, RandomSampler
from tqdm import tqdm
import pandas as pd
import numpy as np

# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.realpath(os.path.abspath(""))))

from unet.dataset import DeadwoodDataset
from unet.dice_score import dice_loss
from unet.evaluate import evaluate
from unet.unet_model import UNet

In [None]:
# import importlib

# importlib.reload(sys.modules["unet.dataset"])

In [None]:
# Name of this run: passed to wandb as the run name and used as the
# sub-directory (under checkpoint_dir) for checkpoints and the results CSV.
experiment_name = "100k_samples_biome_stratified_treefold_smallval_testset"

In [None]:
# Debug switch: when True, both wandb logging and checkpoint saving are turned off.
debug: bool = False

# NOTE: `use_wanb` (sic) is referenced under this spelling by the other cells.
use_wanb: bool = not debug
save_checkpoint: bool = not debug

# data paths
register_file = "/net/scratch/jmoehring/tiles_register_biome_bin.csv"
checkpoint_dir = "/net/scratch/jmoehring/checkpoints"

# data params
no_folds: int = 3
test_size: float = 0.2
random_seed: int = 100
batch_size: int = 64
# NOTE: `balaning_factor` (sic) is referenced under this spelling by the other cells;
# it is forwarded to the dataset's train-sample-weight computation.
balaning_factor: float = 0.5
epochs: int = 20
epoch_train_samples: int = 100000
# A sample count must be a whole number; the bare product is a float (20000.0),
# which contradicts the annotation and is what previously ended up in the run config.
epoch_val_samples: int = round(epoch_train_samples * test_size)
# Bin edges handed to the dataset (0.00, 0.04, ..., 0.20).
train_bins: np.ndarray = np.arange(0, 0.21, 0.04)

# model params
learning_rate: float = 1e-5
# enables autocast / gradient scaling in the training loop
amp: bool = False
weight_decay: float = 1e-8
momentum: float = 0.999
# max gradient norm used for clipping
gradient_clipping: float = 1.0

In [None]:
# Read the tile register CSV into a DataFrame; it is the input for DeadwoodDataset.
register_df = pd.read_csv(register_file)

In [None]:
# Build the dataset from the register using the fold count, seed and bins from the
# config cell. The training loop later asks this object for the per-fold train/val
# split and for the per-fold train sample weights.
dataset = DeadwoodDataset(
    register_df=register_df,
    no_folds=no_folds,
    random_seed=random_seed,
    bins=train_bins,
)

In [None]:
def initialize_experiment(run_name: str):
    """Open a Weights & Biases run and store the notebook settings in its config.

    The settings are read from the module-level configuration variables and are
    grouped under the two top-level config keys ``data`` and ``model``.

    Args:
        run_name: Name under which the run is created in the
            ``standing-deadwood-unet-pro`` project.

    Returns:
        The run object returned by ``wandb.init``.
    """
    run = wandb.init(
        project="standing-deadwood-unet-pro",
        resume="allow",
        name=run_name,
    )

    # everything related to data selection / sampling
    data_settings = {
        "register_file": register_file,
        "no_folds": no_folds,
        "test_size": test_size,
        "random_seed": random_seed,
        "batch_size": batch_size,
        "save_checkpoint": save_checkpoint,
        "balaning_factor": balaning_factor,
        "epochs": epochs,
        "epoch_train_samples": epoch_train_samples,
        "epoch_val_samples": epoch_val_samples,
        "train_bins": train_bins,
    }

    # optimisation hyperparameters
    model_settings = {
        "epochs": epochs,
        "learning_rate": learning_rate,
        "amp": amp,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "gradient_clipping": gradient_clipping,
    }

    run.config.update({"data": data_settings, "model": model_settings})
    return run

In [None]:
# Keyword arguments shared by the train and validation DataLoader instances.
loader_args = dict(
    batch_size=batch_size,
    num_workers=12,
    pin_memory=True,
    shuffle=False,
)

In [None]:
# Start the wandb run only when logging is enabled (use_wanb is False in debug mode).
# `experiment` is left undefined otherwise, so later uses are guarded by use_wanb too.
if use_wanb:
    experiment = initialize_experiment(experiment_name)

In [None]:
# Per-epoch metric records across all folds; later cells read this list to pick
# the best epoch and to export the results table.
logs = []

# preferably use GPU (does not depend on the fold, so resolved once)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for fold in range(no_folds):
    # fresh model per fold: three input channels (RGB), one output channel
    model = UNet(n_channels=3, n_classes=1, bilinear=True)
    model = model.to(memory_format=torch.channels_last)
    model.to(device=device)

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        # Train on up to three GPUs, but never name a device index that does not
        # exist (a hard-coded [0, 1, 2] fails on a two-GPU machine).
        model = nn.DataParallel(model, device_ids=list(range(min(gpu_count, 3))))

    # loss function (binary cross entropy on logits); a dice term is added per batch
    criterion = nn.BCEWithLogitsLoss()

    # optimizer
    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        momentum=momentum,
    )
    # "max" mode: the scheduler is stepped with the validation dice score
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=5)
    grad_scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled=amp)
    if use_wanb:
        wandb.watch(model, log="all")

    train_set, val_set = dataset.get_train_val_fold(fold)

    # draw epoch_train_samples indices per epoch according to the dataset's weights
    train_sampler = WeightedRandomSampler(
        dataset.get_train_sample_weights(fold=fold, balancing_factor=balaning_factor),
        epoch_train_samples,
        replacement=False,
    )

    # NOTE(review): this sampler is not passed to any loader in this cell; the
    # validation loader below iterates val_set sequentially. Kept in case a
    # later cell relies on the name.
    test_sampler = RandomSampler(val_set)

    train_loader = DataLoader(train_set, sampler=train_sampler, **loader_args)
    val_loader = DataLoader(val_set, **loader_args)

    step = 0
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0

        with tqdm(
            # the bar is advanced by the real batch size, so its total is the number
            # of drawn samples (len(train_loader) * batch_size overcounts a partial
            # final batch)
            total=len(train_sampler),
            desc=f"Epoch {epoch}/{epochs}",
            unit="img",
        ) as pbar:
            for images, masks_true, _ in train_loader:
                images = images.to(
                    device=device,
                    dtype=torch.float32,
                    memory_format=torch.channels_last,
                )
                masks_true = masks_true.to(device=device, dtype=torch.long).squeeze(1)
                masks_target = masks_true.float()

                # only the forward pass and the loss run under autocast
                with torch.amp.autocast(
                    device.type if device.type != "mps" else "cpu", enabled=amp
                ):
                    # drop the channel dimension once; squeezing dim 1 again would
                    # remove a real spatial dimension whenever it has size 1
                    masks_pred = model(images).squeeze(1)

                    loss = criterion(masks_pred, masks_target)
                    loss += dice_loss(torch.sigmoid(masks_pred), masks_target)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                # gradients must be unscaled before clipping, otherwise the clipping
                # threshold is applied to scaled gradients when amp is enabled
                # (no-op while the scaler is disabled)
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                epoch_loss += loss.item()
                step += 1

        val_dice_score, val_dice_loss = evaluate(
            model, criterion, val_loader, device, amp
        )
        scheduler.step(val_dice_score)
        log = {
            # mean of the per-batch training losses of this epoch
            "train loss": epoch_loss / len(train_loader),
            "val loss": val_dice_loss,
            "val dice score": val_dice_score.item(),
            "epoch": epoch,
            "fold": fold,
        }
        logs.append(log)
        if use_wanb:
            experiment.log(log)
        if save_checkpoint:
            run_dir = os.path.join(checkpoint_dir, experiment_name)
            # parents=True also creates checkpoint_dir when it is missing
            Path(run_dir).mkdir(parents=True, exist_ok=True)
            # NOTE(review): when the model is wrapped in nn.DataParallel the saved
            # keys carry a "module." prefix; whatever loads these files has to
            # account for that.
            state_dict = model.state_dict()
            torch.save(
                state_dict,
                os.path.join(run_dir, f"fold_{fold}_epoch_{epoch}.pth"),
            )
            logging.info(f"Checkpoint {epoch} saved!")

In [None]:
# Select the epoch record (across all folds) with the highest validation dice
# score, report it to wandb and close the run.
def _val_dice_score(entry):
    """Return the validation dice score stored in one epoch record."""
    return entry["val dice score"]


best_model = max(logs, key=_val_dice_score)
if use_wanb:
    experiment.log(dict(best_model=best_model))
    experiment.finish()

In [None]:
# Export the per-epoch metrics of all folds as a CSV inside the run directory.
run_dir = os.path.join(checkpoint_dir, experiment_name)
# The run directory is otherwise only created when checkpoints are saved, so it
# does not exist in debug mode; create it here so the export cannot fail on a
# missing directory.
Path(run_dir).mkdir(parents=True, exist_ok=True)
run_results_df = pd.DataFrame.from_records(logs)
run_results_df.to_csv(os.path.join(run_dir, f"{experiment_name}.csv"))

In [None]:
# Preview the first rows of the per-epoch results table.
run_results_df.head()