# Imports

In [2]:
# Imports
import os

from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from src.models.vitnet import VitNet
from src.utils import (
    Preprocessor,
    plot_image_and_prediction,
    predict_image,
    predict_patch,
    seed_everyting,
    train,
    test,
    save,
    get_splits,
    get_datasets,
    get_dataloaders,
    get_device,
    loss,
)

# Global parameters

In [3]:
# Run configuration: values a reader is likely to tune.
patch_size = 64  # also sizes VitNet below and selects the `<patch_dir>/<patch_size>` dataset dir
img_dir = "data/images"
model_dir = "models"
patch_dir = "data/patches"
gedi_dir = "data/gedi"
random_state = 42
# NOTE(review): batch_size=2 and epochs=2 look like smoke-test values — raise for a real run.
batch_size = 2
# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to 0 rather than forwarding None. Assuming get_dataloaders passes this to
# torch's DataLoader, 0 means "load in the main process".
num_workers = os.cpu_count() or 0
learning_rate = 1e-4
epochs = 2
# 0, 5, ..., 50. NOTE(review): not referenced by any cell shown here — confirm or remove.
bins = list(range(0, 55, 5))
device = get_device()

# (sic) helper name as exported by src.utils.
seed_everyting(random_state)

Using mps device


# Preprocess labels and patches

In [4]:
# Build a preprocessor over the configured image / patch / GEDI directories
# for this patch size, run it, then read the resulting patch table from its
# `patches` attribute. Per the logged output, run() can reuse an existing
# patch info file instead of re-processing the images.
preprocessor = Preprocessor(img_dir, patch_dir, gedi_dir, patch_size)
preprocessor.run()
patches = preprocessor.patches

INFO:root:Starting preprocessing...
INFO:root:Directories validated.
INFO:root:Images loaded.
INFO:root:Number of images: 48
INFO:root:GEDI data loaded.
INFO:root:Loaded existing patch info file. Skipping image processing.
INFO:root:Number of patches: 101423
INFO:root:Number of labels: 629074


Total number of patches: 101423


# Create datasets & dataloaders

In [4]:
# Partition the patch table into train / validation / test frames.
train_df, val_df, test_df = get_splits(patches)

# Wrap each frame in a dataset rooted at `<patch_dir>/<patch_size>`.
train_ds, val_ds, test_ds = get_datasets(
    train_df,
    val_df,
    test_df,
    f"{patch_dir}/{patch_size}",
)

# Batch each dataset with the configured batch size and worker count.
train_dl, val_dl, test_dl = get_dataloaders(
    train_ds,
    val_ds,
    test_ds,
    batch_size,
    num_workers,
)

# Create & train model

In [5]:
# Vision transformer whose input size and layer widths are all derived from
# patch_size (hidden = 2x, intermediate = 4x).
model = VitNet(
    image_size=patch_size,
    hidden_size=2 * patch_size,
    intermediate_size=4 * patch_size,
).to(device)

# AdamW over every model parameter at the configured learning rate.
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Cosine learning-rate decay. T_max is measured in scheduler.step() calls.
# NOTE(review): T_max=epochs is only correct if `train` steps the scheduler
# once per call; if it steps once per batch this should be
# epochs * len(train_dl). TODO confirm against src.utils.train.
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

# Each epoch: call `train` on the training loader, then `test` on the
# validation loader (test_dl is not used in this loop).
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}\n-------------------------------")
    train(train_dl, model, loss, device, optimizer, scheduler)
    test(val_dl, model, loss, device)

Epoch 1
-------------------------------


Training:   0%|          | 1/350 [00:15<1:28:36, 15.23s/it]

Train loss: 10.599902  [    2/  700]


Training:  15%|█▍        | 52/350 [00:29<01:15,  3.93it/s] 

In [None]:
# Persist the trained model to `<model_dir>/<model.name>.pt`. Off by default so
# that "Run All" never writes (and possibly overwrites) a checkpoint; set
# save_model = True to opt in.
save_model = False
if save_model:
    # Ensure the target directory exists before writing into it.
    os.makedirs(model_dir, exist_ok=True)
    save(model, os.path.join(model_dir, f"{model.name}.pt"))

# Visualise results

In [None]:
# Run the model on one fixed sample from the test split and plot the input
# next to its prediction.
idx = 42  # arbitrary sample; must be a valid index into test_ds
patch = test_ds[idx]
img, pred = predict_patch(model, patch, device)
# NOTE(review): the meaning of the third argument (3) is defined in
# src.utils.plot_image_and_prediction — confirm and name it in the config cell.
plot_image_and_prediction(img, pred, 3)

In [None]:
# Run the model over a single .tif from img_dir, using the same patch_size the
# model was built with. NOTE(review): hard-coded file name — consider moving
# it to the config cell.
image, prediction = predict_image(
    model, device, f"{img_dir}/L15-1059E-1348N.tif", patch_size
)

In [None]:
# Plot the image and prediction returned by predict_image in the previous cell.
# NOTE(review): third argument (3) is passed through to plot_image_and_prediction
# unchanged; its meaning is defined in src.utils — confirm.
plot_image_and_prediction(image, prediction, 3)