Notebook sencillo que muestra cómo hemos construido y entrenado la red.

In [None]:
# IMPORTANT: install the dataset following the instructions in the README
# and set these two variables to the directories that contain it.
TRAIN_DATA_PATH = "TrainReal"
TEST_DATA_PATH = "TestReal"

In [None]:
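# When running on Colab, uncomment this cell to download and unpack the
# data automatically.
# NOTE: only the training archive is fetched below; the data expected at
# TEST_DATA_PATH is not downloaded by this cell.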

# from IPython.display import clear_output

# !wget "https://lilablobssc.blob.core.windows.net/conservationdrones/v01/conservation_drones_train_real.zip"
# !unzip conservation_drones_train_real.zip 
# clear_output()

In [7]:
import os
import numpy as np
import pandas as pd
import matplotlib.pylab as plt

from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Pretrained checkpoint selection, shared by get_tunned_model() (to load the
# network) and get_transform() (to obtain the matching preprocessing).
WEIGHTS = FasterRCNN_ResNet50_FPN_Weights.DEFAULT

def get_tunned_model(num_classes):
  """Build a Faster R-CNN (ResNet-50 FPN) with a replaced box predictor.

  The network is created from the checkpoint selected by the module-level
  ``WEIGHTS``; only ``roi_heads.box_predictor`` is swapped for a newly
  constructed ``FastRCNNPredictor``.

  Args:
    num_classes: Number of classes handed to ``FastRCNNPredictor``.

  Returns:
    The model with its box predictor replaced.
  """
  detector = fasterrcnn_resnet50_fpn(weights=WEIGHTS)

  # Size the new head from the input width of the stock classification layer.
  stock_head = detector.roi_heads.box_predictor
  detector.roi_heads.box_predictor = FastRCNNPredictor(
    stock_head.cls_score.in_features, num_classes)
  return detector

def get_transform():
  """Return the preprocessing callable that ships with ``WEIGHTS``."""
  preprocessing = WEIGHTS.transforms()
  return preprocessing

## Comprobamos que podemos inferir sobre la red sin entrenar

In [None]:
import torch

from UAVIR_tools import BIRDSAIDataset, imgBoxes

# Smoke test: run the freshly built (not yet fine-tuned) model on one sample.
transforms = get_transform()

# The dataset is built without a transform here; the sample is preprocessed
# by hand and the untouched `img` is reused further down for drawing.
dataset = BIRDSAIDataset(TRAIN_DATA_PATH)
img, _ = dataset[0]
images = [transforms(img)]

model = get_tunned_model(3)
model_eval = model.eval()

# Inference only: disable autograd so no graph is built and the returned
# tensors are not attached to one.
with torch.no_grad():
  outputs = model_eval(images)

# Show the result
img_box = imgBoxes(img, outputs[0]["boxes"], outputs[0]["labels"])

fig = plt.figure()
ax = fig.add_subplot()
ax.axis('off')

ax.imshow(img_box)
plt.show()

## Entrenamos el modelo

In [None]:
import torch
import torchvision

from UAVIR_tools.detection.engine import train_one_epoch, evaluate
import UAVIR_tools.detection.utils as utils

# ------- MAIN --------- #
# Fine-tunes the detector on the training set and evaluates it on the test
# set after every epoch.
# NOTE: BIRDSAIDataset is not imported in this cell; it comes from the
# earlier inference cell, which must have been run first.

# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device: {}\n".format(device))

# three classes, background included
num_classes = 3 # label ids: 2 human, 1 animal, 0 background

# use our dataset and defined transformations
# NOTE(review): the meaning of the third argument (4000) is defined by
# BIRDSAIDataset - presumably a cap on the number of samples; confirm.
dataset = BIRDSAIDataset(TRAIN_DATA_PATH, get_transform(), 4000)
dataset_test = BIRDSAIDataset(TEST_DATA_PATH, get_transform())

# Disabled alternative: random hold-out split instead of the separate test directory.
# indices = torch.randperm(len(dataset)).tolist()
# dataset = torch.utils.data.Subset(dataset, indices[:-50])
# dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
  dataset, batch_size=10, shuffle=True, num_workers=4,
  collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
  dataset_test, batch_size=1, shuffle=False, num_workers=4,
  collate_fn=utils.collate_fn)

# get the model using our helper function
model = get_tunned_model(num_classes) 

# move model to the right device
model.to(device)

# construct an optimizer over the parameters that require gradients
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)
# and a learning rate scheduler
# NOTE(review): with step_size=3 and num_epochs=3 the decay is only reached
# by the scheduler step that follows the last training epoch, so every epoch
# trains at the initial learning rate.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=3,
                                                gamma=0.1)

# train for 3 epochs
num_epochs = 3

for epoch in range(num_epochs):
  # train for one epoch, printing every 10 iterations
  train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
  # update the learning rate
  lr_scheduler.step()
  # evaluate on the test dataset
  evaluate(model, data_loader_test, device=device)

## Comprobamos resultados realizando alguna inferencia

In [None]:
import torch
from torchvision import transforms as T

# Not used below; kept so the notebook-level name `transforms` stays defined.
transforms = get_transform()

# Select the video and the frame within it
movie_id = 21
frame = 200

# NOTE(review): presumably the first/last dataset index and the frame count
# of the selected movie - confirm against BIRDSAIDataset. frame_end and
# n_frames are not used, so `frame` is not range-checked here.
frame_init, frame_end, n_frames = dataset.getMovieBoxInfo(movie_id)

# Fetch the image from the dataset; the model receives a copy on `device`
# while `img` itself is kept for drawing.
img, _ = dataset[frame_init + frame]
images = [img.to(device)]

# Inference only: disable autograd so no graph is built and the returned
# tensors are not attached to one.
model_eval = model.eval()
with torch.no_grad():
  outputs = model_eval(images)

# Show the result
# 16 x 9 inches at RES / FIGSIZE[0] = 80 dpi gives a 1280 x 720 px figure.
FIGSIZE = [16, 9]
RES = 1280 # 720p

# Convert the sample back to a PIL image before drawing the boxes on it.
transform_toImg = T.ToPILImage()
img = transform_toImg(img)
img_box = imgBoxes(img, outputs[0]["boxes"], outputs[0]["labels"])

fig = plt.figure(figsize=FIGSIZE, dpi=RES/FIGSIZE[0])
ax = fig.add_subplot()
ax.axis('off')

ax.imshow(img_box)
plt.show()