In [None]:
# Copyright (c) 2023 William Locke

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

This notebook is intended to be run in Google Colab with access to corresponding Google Drive files. If running locally or on another service, change import and install code accordingly.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lu-liang-geo/UAV_Tree_Detection/blob/main/notebooks/Train_Box_Decoder.ipynb)

In [None]:
from google.colab import drive

# The data, prompt and model paths used later in this notebook all live under
# this Google Drive mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
%%capture
# Install third-party packages that are imported later in this notebook
# (`rasterio`, `supervision`). `%%capture` must stay on the first line of the
# cell; it hides pip's verbose output.
!pip install rasterio
!pip install supervision

In [None]:
#@title Copy personal SAM (can restart here for testing changes)

%cd /content
import os
if os.path.exists('/content/segment-anything'):
  !rm -r /content/segment-anything
!git clone https://github.com/lu-liang-geo/UAV_Tree_Detection.git
%cd /content/segment-anything
!pip install -q .
%cd /content

In [None]:
#@title Fixes an occasional bug upon restarting runtime

import locale


def getpreferredencoding(do_setlocale=True):
    """Return "UTF-8" unconditionally.

    Drop-in replacement for `locale.getpreferredencoding`; the
    `do_setlocale` argument is accepted for signature compatibility and
    ignored.
    """
    return "UTF-8"


# Monkeypatch the standard-library lookup so every caller sees UTF-8.
# NOTE(review): presumably works around Colab's "A UTF-8 locale is required"
# failure on shell commands after a runtime restart -- confirm.
setattr(locale, "getpreferredencoding", getpreferredencoding)

In [None]:
#@title Import Modules
%cd /content
from segment_and_detect_anything import NEONTreeDataset, VectorDataset, train_one_epoch
from segment_and_detect_anything.detr import HungarianMatcher, SetCriterion
from segment_and_detect_anything.detr import misc as utils
from segment_and_detect_anything.modeling import BoxDecoder, TwoWayTransformer
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import supervision as sv
import rasterio
import torch
from torch.utils.data import DataLoader
from PIL import Image
import xml.etree.ElementTree as ET

In [None]:
# Locations of the preprocessed inputs consumed by VectorDataset. All three
# folders sit side by side under the same Evaluation directory on Drive.
evaluation_dir = '/content/drive/MyDrive/UAV/Data/NEONTreeEvaluation/Evaluation'

vector_path = f'{evaluation_dir}/Image Vectors'
ann_cxcywh_path = f'{evaluation_dir}/Annotations'
prompt_path = f'{evaluation_dir}/Prompts'

In [None]:
# Dataset over the preprocessed image vectors, prompts and annotations.
vector_ds = VectorDataset(
    image_path=vector_path,
    prompt_path=prompt_path,
    ann_path=ann_cxcywh_path,
)

# Batches of two samples, served in dataset order (no shuffling); the project's
# own collate function assembles each batch.
vector_dl = DataLoader(
    vector_ds,
    batch_size=2,
    shuffle=False,
    collate_fn=utils.vector_collate_fn,
)

In [None]:
# Warm-up pass: the first full iteration over the Dataset (or Dataloader) takes
# about 5 minutes, while later passes take about 5 seconds. Paying that cost
# here keeps it out of the measured training time.
for vector in vector_ds:
    continue

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Box decoder built around a two-way transformer. It is moved to `device`
# because the image vectors and prompts are moved there before the forward
# pass; leaving the weights on the CPU would raise a device-mismatch error on
# a GPU runtime.
box_decoder = BoxDecoder(
    transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=256,
        mlp_dim=2048,
        num_heads=8
    ),
    transformer_dim = 256,
    num_boxes = 200
).to(device)

# NOTE(review): cost_class=0 presumably removes the classification term from
# the matching cost -- confirm against HungarianMatcher.
matcher = HungarianMatcher(cost_class=0)

# Loss over labels, cardinality and boxes; 'cardinality' has weight 0 in
# weight_dict. The criterion is moved to `device` alongside the model so any
# tensors it holds live on the same device as the predictions.
criterion = SetCriterion(num_classes=1,
                         matcher=matcher,
                         weight_dict={'loss_ce':1, 'cardinality':0, 'loss_bbox':1, 'loss_giou':1},
                         eos_coef=1,
                         losses=['labels','cardinality','boxes']).to(device)

# Build the optimizer after the model has been moved to its final device.
optimizer = torch.optim.AdamW(box_decoder.parameters())

utils.model_size(box_decoder)

In [None]:
# Sanity check: push a single batch through the decoder and the criterion so
# the losses can be inspected before starting a full training run.
box_decoder.train()
criterion.train()

# Logger settings in the style of a training loop; this cell does not use them
# afterwards.
metric_logger = utils.MetricLogger(delimiter="  ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
header = 'Epoch: [0]'
print_freq = 10


def _to_device(obj, device):
    """Move a tensor, or every tensor inside a (nested) dict, to `device`.

    Any other kind of object is returned unchanged, so annotation formats that
    are not tensor-based pass through untouched.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: _to_device(value, device) for key, value in obj.items()}
    return obj


batch = next(iter(vector_dl))
batch_outputs = []
for vector in batch:
    # A sample may carry an RGB vector, a multi vector, or both; a missing one
    # is replaced by an empty tensor so the concatenation still works.
    rgb_vector = vector.get('rgb', torch.empty(0))
    multi_vector = vector.get('multi', torch.empty(0))
    image_vector = torch.cat((rgb_vector, multi_vector), dim=1).to(device)
    if image_vector.numel()==0:
        raise ValueError('Either RGB or Multi vector must be provided to model, but both are empty.')
    sparse_prompt = vector['prompt']['sparse'].to(device)
    position_prompt = vector['prompt']['position'].to(device)

    outputs = box_decoder(image_vector,
                      position_prompt,
                      sparse_prompt)
    batch_outputs.append(outputs)

# Stack the per-sample outputs into batch-level prediction tensors.
preds = {k : torch.cat([output[k] for output in batch_outputs]) for k in ['pred_boxes', 'pred_logits']}
# Predictions live on `device`; put the targets there too so the criterion does
# not mix CPU and GPU tensors.
# NOTE(review): assumes annotations are tensors or dicts of tensors -- confirm.
targets = [_to_device(vector['annotation'], device) for vector in batch]
loss_dict = criterion(preds, targets)

In [None]:
# Display the losses computed on the single sanity-check batch.
loss_dict

In [None]:
# Train for one epoch (epoch 1) and keep the returned metrics for logging.
metrics = train_one_epoch(
    box_decoder,
    criterion,
    vector_dl,
    optimizer,
    device,
    epoch=1,
)

In [None]:
# Save the first epoch's metrics.
model_folder = '/content/drive/MyDrive/UAV/Models'
model_name = '???'  # placeholder: replace with a real run name before saving

# `i` holds the number of the last completed epoch. It has to be set here:
# the training loop that normally assigns it has not run yet, and the
# checkpoint cells further down also read it.
i = 1

# Make sure the log directory exists so opening the file cannot fail on a
# fresh Drive folder.
os.makedirs(os.path.join(model_folder, 'Logs'), exist_ok=True)

# Append mode: each epoch adds its block to the same per-model log file.
with open(os.path.join(model_folder, 'Logs', f'{model_name}.txt'), 'a') as f:
  f.write(f'Epoch {i}\n')
  for k, v in metrics.items():
    f.write(f'{k:<30} {v:.10}\n')
  f.write('\n')

In [None]:
# Continue training through epoch `num_epochs` once the first epoch shows the
# model is learning. Metrics are appended to the log after every epoch; the
# checkpoint is written once, after the final epoch.
num_epochs = 10
log_file = os.path.join(model_folder, 'Logs', f'{model_name}.txt')

for i in range(2, num_epochs + 1):
    metrics = train_one_epoch(box_decoder, criterion, vector_dl, optimizer, device, epoch=i)
    with open(log_file, 'a') as f:
        f.write(f'Epoch {i}\n')
        for name, value in metrics.items():
            f.write(f'{name:<30} {value:.10}\n')
        f.write('\n')

# Bundle the last epoch number, model/criterion/optimizer states and the last
# epoch's metrics into a single checkpoint file.
checkpoint = {
    'epoch': i,
    'box_decoder_state_dict': box_decoder.state_dict(),
    'criterion_state_dict': criterion.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'metrics': metrics,
}
torch.save(checkpoint, os.path.join(model_folder, f'{model_name}.pth'))

In [None]:
# Manual checkpoint: run this cell to save the current state when the training
# cell did not finish and therefore never reached its own save step.
# `i` and `metrics` come from the most recently completed epoch.
checkpoint_file = os.path.join(model_folder, f'{model_name}.pth')
checkpoint = {
    'epoch': i,
    'box_decoder_state_dict': box_decoder.state_dict(),
    'criterion_state_dict': criterion.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'metrics': metrics,
}
torch.save(checkpoint, checkpoint_file)