In [0]:
# Sanity check: warn when the notebook is attached to a cluster whose Unity
# Catalog access mode is anything other than "USER_ISOLATION". Nothing is
# enforced here; the notebook continues either way.
cluster_access_mode = spark.conf.get(
    "spark.databricks.clusterUsageTags.clusterUnityCatalogMode"
)
is_shared_cluster = cluster_access_mode == "USER_ISOLATION"
if not is_shared_cluster:
    print("Dette er ikke et felles cluster!")

In [0]:
%run "/Workspace/Users/fabian.heflo@kartverket.no/Snuplasser/src/dataProcessing/transform"

In [0]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

import sys
import os
import mlflow
import mlflow.pytorch

from dataProcessing.dataset import SnuplassDataset, load_numpy_split_stack
from model.unet import UNet
from dataProcessing.augmentation_config import augmentation_profiles

In [0]:
def _match_mask_shape(logits, masks):
    """Give ``masks`` the same rank as ``logits`` so an element-wise loss accepts them.

    The previous loop compared ``outputs`` with ``masks`` while training but
    ``outputs.squeeze(1)`` with ``masks`` while validating. ``BCEWithLogitsLoss``
    requires identical shapes, so for a single mask layout one of those two
    calls had to fail. Adding the channel axis only when it is missing works
    for both layouts and is applied identically in both phases.

    Args:
        logits: Raw model output.
        masks: Target tensor with either the same rank as ``logits`` or one
            axis fewer (channel axis missing at position 1).

    Returns:
        ``masks`` unchanged, or with a singleton axis inserted at position 1.
    """
    if masks.dim() == logits.dim() - 1:
        return masks.unsqueeze(1)
    return masks


def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    Args:
        model: Network to evaluate or update.
        loader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss callable taking ``(logits, targets)``.
        device: Device the batches are moved to.
        desc: Label shown on the tqdm progress bar.
        optimizer: When given, the pass is a training pass (gradients are
            computed and a step is taken per batch). When ``None`` the pass is
            evaluation only, with gradients disabled.

    Returns:
        Mean of the per-batch loss values, or NaN if ``loader`` yielded nothing.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    with torch.set_grad_enabled(is_training):
        for images, masks in tqdm(loader, desc=desc):
            images = images.to(device).float()
            masks = masks.to(device).float()

            if is_training:
                optimizer.zero_grad()

            logits = model(images)
            loss = criterion(logits, _match_mask_shape(logits, masks))

            if is_training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()

    batch_count = len(loader)
    # An empty loader used to raise ZeroDivisionError (for validation: only
    # after a full training epoch had already been spent). Report NaN instead.
    if batch_count == 0:
        return float("nan")
    return loss_sum / batch_count


def main(
    profile="default",
    batch_size=8,
    num_epochs=1,
    learning_rate=1e-3,
    data_root="/Volumes/land_topografisk-gdb_dev/external_dev/static_data/DL_SNUPLASSER",
):
    """Train the 4-channel UNet baseline and record the run in MLflow.

    Called without arguments this reproduces the previous hard-coded setup.

    NOTE(review): ``get_train_transforms`` and ``get_val_transforms`` are not
    imported in this cell; presumably they come from the ``transform`` notebook
    pulled in via ``%run`` earlier -- confirm that cell has been executed.

    Args:
        profile: Key into ``augmentation_profiles`` selecting the augmentation
            configuration passed to ``get_train_transforms``.
        batch_size: Batch size for both the training and validation loaders.
        num_epochs: Number of passes over the training set.
        learning_rate: Learning rate handed to the Adam optimizer.
        data_root: Directory containing the ``img/``, ``lab/`` and ``dom/``
            sub-directories.
    """
    # mlflow.pytorch.autolog only performs full autologging for PyTorch
    # Lightning's Trainer.fit. This loop is hand-written, so params, metrics
    # and the model are logged explicitly inside the run below.
    mlflow.pytorch.autolog()

    cfg = augmentation_profiles[profile]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Single source of truth for the three data directories. The trailing
    # slash is kept because the original paths were passed with one.
    root = data_root.rstrip("/")
    data_dirs = {
        "image_dir": f"{root}/img/",
        "mask_dir": f"{root}/lab/",
        "dom_dir": f"{root}/dom/",
    }

    train_ids, val_ids, _ = load_numpy_split_stack(**data_dirs)

    train_dataset = SnuplassDataset(
        file_list=train_ids,
        # ratio=None is the baseline; to enable augmentation set ratio to a
        # value between 0 and 1.
        transform=get_train_transforms(cfg, ratio=None),
        **data_dirs,
    )
    val_dataset = SnuplassDataset(
        file_list=val_ids,
        transform=get_val_transforms(),
        **data_dirs,
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Swap in a different architecture here if desired.
    model = UNet(n_channels=4, n_classes=1, bilinear=False).to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    with mlflow.start_run(run_name="UNet_baseline_4ch"):
        mlflow.log_params(
            {
                "augmentation_profile": profile,
                "batch_size": batch_size,
                "num_epochs": num_epochs,
                "learning_rate": learning_rate,
            }
        )

        for epoch in range(num_epochs):
            # Training pass
            avg_train_loss = _run_epoch(
                model,
                train_loader,
                criterion,
                device,
                desc=f"Epoch {epoch+1}/{num_epochs} - Training",
                optimizer=optimizer,
            )
            print(f"\nTrain loss: {avg_train_loss:.4f}")

            # Validation pass (no optimizer, so no gradients and no updates)
            avg_val_loss = _run_epoch(
                model,
                val_loader,
                criterion,
                device,
                desc=f"Epoch {epoch+1}/{num_epochs} - Validation",
            )
            print(f"Val loss: {avg_val_loss:.4f}")

            mlflow.log_metric("train_loss", avg_train_loss, step=epoch)
            mlflow.log_metric("val_loss", avg_val_loss, step=epoch)

        # Stored under the run's artifacts so it can be restored with
        # mlflow.pytorch.load_model("runs:/<run_id>/model").
        mlflow.pytorch.log_model(model, "model")

    print("✅ Trening ferdig")


# Start training as soon as this cell is executed.
main()