# Imports

In [None]:
import os

from tqdm import tqdm

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

from src.utils.preprocessing import Preprocessor
from src.utils.dataset import ForestDataset
from src.models.unet import UNet

from src.utils.loss import loss as criterion

# Notebook settings

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir logs
# %matplotlib inline

# Global parameters

In [None]:
# --- Filesystem layout -------------------------------------------------------
IMG_DIR = "data/images"
PATCH_DIR = "data/patches"
LOGS_DIR = "logs"
MODEL_DIR = "models"
GEDI_FILE = "data/gedi/gedi_complete.fth"
PATCHES_FILE = "data/info/patches.fth"

# --- Training hyper-parameters -----------------------------------------------
BATCH_SIZE = 64
NUM_WORKERS = 6
LEARNING_RATE = 1e-4
EPOCHS = 1

# --- Compute device ----------------------------------------------------------
# Preference order: CUDA first, then Apple's MPS backend, finally the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"Using {DEVICE} device")

# Preprocess labels and patches

| image | patch | n_labels |
| - | - | - | 
| L15-1059E-1355N	 | 0	 | 28 |

In [None]:
# Build the preprocessor from the paths configured in the global parameters.
preprocessor = Preprocessor(
    img_dir=IMG_DIR,
    patch_dir=PATCH_DIR,
    patches_file=PATCHES_FILE,
    gedi_file=GEDI_FILE,
)

# Execute the preprocessing step, then keep a handle on the resulting patches
# so the following cells can use them.
preprocessor.run()
patches = preprocessor.patches

print(f"Total number of patches: {len(patches)}")

## Filter patches

In [None]:
# TODO

# Create training & validation splits

In [None]:
# TODO

## Create dataset & dataloader

In [None]:
# Dataset built from the preprocessed patches.
dataset = ForestDataset(
    img_dir=IMG_DIR,
    patch_dir=PATCH_DIR,
    patches=patches,
)

# Shuffled loader for training. The worker count comes from the NUM_WORKERS
# global so that changing the constant actually takes effect (it used to be
# hard-coded to 4 here, leaving NUM_WORKERS unused).
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Create U-Net model

In [None]:
# Build the network and move its parameters to the selected device
model = UNet()
model = model.to(DEVICE)

# Adam optimizer over all model parameters
# (the loss function is imported at the top of the notebook as `criterion`)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# TensorBoard writer; event files go to LOGS_DIR
writer = SummaryWriter(log_dir=LOGS_DIR)

# Train U-Net model

In [None]:
# Training loop: one optimizer step per batch; the per-step loss is logged to
# TensorBoard and the mean of the per-batch losses is printed after each epoch.
steps_per_epoch = len(dataloader)

# Explicitly select training mode. This is a no-op on a freshly created module,
# but matters if the model was switched to eval mode before this cell is run.
model.train()

try:
    for epoch in range(EPOCHS):
        epoch_loss = 0.0

        for i, (inputs, targets) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}")):
            # Move inputs and targets to the appropriate device
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            # Forward pass
            outputs = model(inputs)

            # Compute loss
            loss = criterion(outputs, targets)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Convert to a Python float once and reuse it for both the running
            # total and the TensorBoard log.
            loss_value = loss.item()
            epoch_loss += loss_value

            # Log the loss value to TensorBoard (global step = batches seen so far)
            writer.add_scalar("loss/train", loss_value, epoch * steps_per_epoch + i)

        # Print the mean batch loss of the epoch. Guard against an empty
        # dataloader, where no loss was ever computed.
        if steps_per_epoch == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}], no batches to train on")
        else:
            print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {epoch_loss / steps_per_epoch}")
finally:
    # Close the SummaryWriter even if training is interrupted, so that pending
    # events are flushed to disk.
    writer.close()

# Export model

In [None]:
# Make sure the output directory exists: torch.save does not create missing
# parent directories and would otherwise fail after training has finished.
os.makedirs(MODEL_DIR, exist_ok=True)

# Persist only the learned parameters (state dict), not the whole module.
torch.save(model.state_dict(), os.path.join(MODEL_DIR, "unet.pt"))

# Load pre-trained model

In [None]:
# model = load_model(model, "models/unet.pt", DEVICE)
# model

# Visualise results

In [None]:
# predict_image(model, "data/images/L15-1059E-1348N.tif", DEVICE)

In [None]:
# np.mean(prediction[prediction > 0.5])

In [None]:
# plot_image("data/images/L15-1059E-1348N.tif")