# Training deep learning models for varroa mite detection

This Jupyter notebook allows one to reproduce the results presented in the article "Analysis of Varroa mite colony infestation level using new open software based on deep learning techniques" by Jose Divasón, Ana Romero, Francisco Javier Martínez de Pisón, Miguel A. Silvestre, Pilar Santolaria and Jesús L. Yániz.

### About the dataset

Note that, before executing this Jupyter notebook, one has to download the dataset. It is available through Zenodo: https://zenodo.org/doi/10.5281/zenodo.10231844. 

Concretely, you can download it from:
- Dataset + labels: https://zenodo.org/records/10231845/files/dataset.zip?download=1
- CSV file for splitting the data into training and validation sets: https://zenodo.org/records/10231845/files/df_dataset.csv?download=1

If you do not want to apply deblurGAN techniques on your own, you can also download the dataset version with deblurGAN from: https://zenodo.org/records/10231845/files/images_deblurGAN.zip?download=1

Unzip the dataset without changing the names. That is, the structure should be as follows:
* dataset
    * images
    * labels
    * images_deblurGAN
* train_varroa_mite_detector.ipynb
* rest of the ipynb and py files

## Imports

In [1]:
from torch.utils.data import Dataset
import os, random, time
import cv2 as cv2
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np

import torch
import torchvision
from torchvision import transforms 
from torchvision import io
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN, fasterrcnn_resnet50_fpn_v2, fasterrcnn_resnet50_fpn, fasterrcnn_mobilenet_v3_large_320_fpn
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
import torchvision.transforms.functional as F_vision
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler
from torchmetrics.detection import MeanAveragePrecision
import timm

from tqdm.notebook import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import albumentations.pytorch
import pandas as pd
from PIL import Image
from fasterrcnn_vitdet import create_VIT_model

## Training architecture and hyperparameters

In this cell one can choose among different models and combinations of hyperparameters (crop size, backbone weights, confidence threshold, number of epochs, whether deblurGAN is used, batch size, ...).

Please note that the batch size should be chosen according to the memory of the available GPU.


In [2]:
# Notebook file name; the output folder where the model is saved is named after it
currentNotebook = "train_varroa_mite_detector.ipynb"

# --- Data and model options ---
# Crop size (ideally the input size of the backbone, but this is not mandatory)
CROP_X = CROP_Y = 224
BACKBONE_NAME = "resnet18_fpn"  # Backbone name
WEIGHTS_BACKBONE = 'DEFAULT'
THRESHOLD_CONFIDENCE = 0.50  # Predictions must score above this value to be kept during validation
DEBLURGAN = False  # Whether the deblurGAN version of the images is loaded

# --- Training schedule ---
NUM_EPOCHS = 600  # Maximum number of epochs
EARLY_STOP = 60  # Stop training after EARLY_STOP consecutive epochs without improvement
START_VALID_AT_EPOCH = 90  # Validation only runs for epochs greater than this value
SEED = 42

# --- Batching ---
BATCH_TRAIN = 6
SAMPLES_BY_EPOCH = 42  # previously 64
NUM_WORKERS = 0
BATCH_VALID = 10
BATCH_TEST = 20

# --- Learning rate (reduce-on-plateau) ---
LR = 0.01  # Learning rate (previously 1e-5)
FACTOR_REDUCE_ONPLATEAU = 0.75
PATIENCE = 10

# Bounding boxes are given to albumentations in 'pascal_voc' format and carry a 'labels' field
bbox_params = A.BboxParams(
    format='pascal_voc',
    min_visibility=0.6,
    label_fields=['labels'],
)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

plt.rcParams["savefig.bbox"] = 'tight'

cuda:0


## Functions

In [3]:
def seed_everything(seed):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED`` and configures cuDNN with ``deterministic=True`` and
    ``benchmark=False``.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [4]:
# To parse the annotations
import xml.etree.ElementTree as ET
import torchvision.transforms.functional as FT

# Label map: class name -> integer id (id 0 is reserved for the background)
voc_labels = ('varroa', 'pupe')
label_map = {k: v + 1 for v, k in enumerate(voc_labels)}
label_map['background'] = 0
rev_label_map = {v: k for k, v in label_map.items()}  # Inverse mapping


def parse_annotation(annotation_path):
    """Parse one XML annotation file into boxes, labels and difficulty flags.

    Args:
        annotation_path: Path (or file object) accepted by ``ET.parse``.

    Returns:
        A dict with three parallel lists:
        ``'boxes'`` (``[xmin, ymin, xmax, ymax]`` with 1 subtracted from every
        coordinate read from the file), ``'labels'`` (ids from ``label_map``)
        and ``'difficulties'`` (1 when the ``<difficult>`` tag is ``'1'``,
        otherwise 0). Objects whose name is not in ``label_map`` are skipped.
    """
    root = ET.parse(annotation_path).getroot()

    boxes = []
    labels = []
    difficulties = []
    for obj in root.iter('object'):
        label = obj.find('name').text.lower().strip()
        if label not in label_map:
            continue

        # The <difficult> tag may be absent; treat a missing tag as "not difficult"
        # instead of failing with an AttributeError.
        difficult_node = obj.find('difficult')
        difficult = int(difficult_node is not None and difficult_node.text == '1')

        bbox = obj.find('bndbox')
        # Coordinates are shifted by one with respect to the values stored in the file.
        box = [int(bbox.find(tag).text) - 1 for tag in ('xmin', 'ymin', 'xmax', 'ymax')]

        boxes.append(box)
        labels.append(label_map[label])
        difficulties.append(difficult)
    return {'boxes': boxes, 'labels': labels, 'difficulties': difficulties}


In [5]:
# Imports for use from within a Jupyter notebook
from notebook import notebookapp
import urllib
import json
import ipykernel
from pathlib import Path

def create_output_dir(display = True, create=True, dirbase = '../results'):
    """Return the output folder for this notebook, creating it on request.

    The folder is ``dirbase/<notebook name>``, where the notebook name is
    ``currentNotebook`` without its extension and with any remaining dots
    replaced by underscores.

    Args:
        display: When true, print whether the folder was created, already
            existed, or is missing.
        create: When true, create the folder (and any missing parents).
        dirbase: Base directory that holds the per-notebook folders.

    Returns:
        The ``Path`` of the output folder (it may not exist when
        ``create`` is false).
    """
    base_dir = Path(dirbase)
    # Path.stem drops only the real extension, so names that do not end in
    # '.ipynb' are no longer truncated by a fixed six characters.
    dir_models = Path(currentNotebook).stem.replace('.', '_')
    output_dir = base_dir / dir_models

    already_exists = output_dir.exists()
    if create and not already_exists:
        # parents/exist_ok: tolerate a missing base directory and concurrent creation.
        output_dir.mkdir(parents=True, exist_ok=True)

    if display:
        if already_exists:
            print("Directory " , output_dir ,  " already exists")
        elif create:
            print("Directory " , output_dir ,  " Created ")
        else:
            print("Directory " , output_dir ,  " does not exist")
    return output_dir

## Create output folder

In [6]:
# Folder under 'results/' that receives the saved model, the CSV log and the training plots.
output_dir = create_output_dir(display=True, create=True, dirbase = 'results/')
print(output_dir)

Directory  results/train_varroa_mite_detector  already exists
results/train_varroa_mite_detector


## Load dataset

In [7]:
# Annotation table. Judging by its use below, it holds several rows per image, with the
# image/label paths, the box corners (x_min, y_min, x_max, y_max) and the is_train flag.
df = pd.read_csv('df_dataset.csv')
print(df.shape)

(807, 13)


In [8]:
# One row per image: keep the image-level columns, drop the repeated rows and
# order the result by the image position.
df_posic = (
    df.loc[:, ['file_image', 'name', 'file_boxes', 'pos_img',
               'is_train', 'size_img_x', 'size_img_y']]
    .drop_duplicates()
    .sort_values('pos_img')
    .reset_index(drop=True)
)
print(df_posic.shape)
df_posic.head()

(64, 7)


Unnamed: 0,file_image,name,file_boxes,pos_img,is_train,size_img_x,size_img_y
0,dataset/images/IMG_5993.jpg,IMG_5993,dataset/labels/IMG_5993.xml,0,True,6048,8064
1,dataset/images/IMG_5994.jpg,IMG_5994,dataset/labels/IMG_5994.xml,1,False,6048,8064
2,dataset/images/IMG_5700.jpg,IMG_5700,dataset/labels/IMG_5700.xml,2,True,6048,8064
3,dataset/images/IMG_5815.jpg,IMG_5815,dataset/labels/IMG_5815.xml,3,True,6048,8064
4,dataset/images/IMG_5652.jpg,IMG_5652,dataset/labels/IMG_5652.xml,4,False,6048,8064


### Load the full dataset into RAM

In [9]:
import rawpy
import imageio
from pathlib import Path

# Re-number the images consecutively (in df_posic order) and propagate that
# number to both tables, so 'pos_img' can index list_imagenes / list_boxes.
lista_names_imagenes = df_posic['file_image'].values
numero_imagen = dict(zip(lista_names_imagenes, np.arange(len(lista_names_imagenes))))
df_posic['pos_img'] = df_posic['file_image'].map(numero_imagen)
df['pos_img'] = df['file_image'].map(numero_imagen)

list_imagenes = []  # full images as float32 arrays scaled to [0, 1]
list_boxes = []     # per image: array of [x_min, y_min, x_max, y_max] rows
for idx in tqdm(df_posic.index):
    df_img = df_posic.loc[idx]
    file_image = df_img['file_image']
    file_boxes = df_img['file_boxes']

    # Optionally read the deblurred version of the image; the boxes are still
    # looked up with the original file name.
    image_path = file_image
    if DEBLURGAN:
        image_path = "dataset/images_deblurGAN/"+Path(file_image).stem +".jpg"

    # cv2.imread returns None (it does not raise) when the file cannot be read;
    # fail here with a clear message. Channels are kept as returned by OpenCV.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {image_path}")
    image = image.astype(np.float32)
    image /= 255.0
    # astype already produced a fresh array, so no extra copy is needed.
    list_imagenes.append(image)

    boxes = df.loc[df['file_image']==file_image, ['x_min', 'y_min', 'x_max', 'y_max']].values
    list_boxes.append(boxes)

  0%|          | 0/64 [00:00<?, ?it/s]

### Tiling the dataset

In [10]:
print(image.shape)
valid_imagenes = []
valid_boxes = []
valid_labels = []
valid_names = []
# For each sub-image, its corresponding (x, y) position in the full image.
# Tiles are generated row by row, so every image starts with the (0, 0) tile.
positions = []
for nrow in tqdm(range(len(df_posic))):
    row = df_posic.iloc[nrow]
    if row['is_train']:
        continue  # only the validation images are tiled

    size_img_x = row['size_img_x']
    size_img_y = row['size_img_y']
    image_orig = list_imagenes[nrow]
    boxes_orig = list_boxes[nrow]
    labels = torch.ones(len(boxes_orig), dtype=torch.int64)

    for pos_y in range(0, size_img_y, CROP_Y):
        # Tiles on the bottom/right border are clipped to the image size.
        y_max = min(pos_y + CROP_Y, size_img_y)
        for pos_x in range(0, size_img_x, CROP_X):
            x_max = min(pos_x + CROP_X, size_img_x)
            transform = A.Compose(
                [A.Crop(x_min=pos_x, y_min=pos_y, x_max=x_max, y_max=y_max),
                 ToTensorV2()],
                bbox_params=bbox_params)
            result_transform = transform(image=image_orig, bboxes=boxes_orig, labels=labels)
            tile_boxes = result_transform['bboxes']

            valid_imagenes.append(result_transform['image'])
            valid_boxes.append(torch.as_tensor(tile_boxes, dtype=torch.float32))
            valid_labels.append(torch.ones(len(tile_boxes), dtype=torch.int64))
            valid_names.append(row['name'])
            positions.append((pos_x, pos_y))

(8064, 6048, 3)


  0%|          | 0/64 [00:00<?, ?it/s]

## DataLoader

In [11]:
def collate_fn(batch):
    """Regroup a batch of samples field by field.

    A list of ``(a, b, c)`` samples becomes ``((a0, a1, ...), (b0, b1, ...),
    (c0, c1, ...))``; nothing is stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)
    
class varroa_dataloader(Dataset):
    """Dataset yielding ``(image, target, image_name)`` triples for the detector.

    Two modes are supported:

    * ``'train'``: a random augmented crop is taken from the full image
      referenced by the ``pos_img`` column of ``df`` (images and boxes are read
      from the module-level ``list_imagenes`` / ``list_boxes``).
    * ``'valid'``: the pre-computed tile number ``index`` is returned (read from
      the module-level ``valid_imagenes`` / ``valid_boxes`` / ``valid_labels`` /
      ``valid_names``).

    ``target`` is a dict with the keys ``boxes``, ``labels``, ``image_id``,
    ``area`` and ``iscrowd``.
    """

    def __init__(self, df: pd.DataFrame, transforms=None, modo='train'):
        super().__init__()
        self.df = df
        self.transforms = transforms  # stored for callers; not used by this class
        self.modo = modo

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, index: int):
        if self.modo == 'train':
            return self._get_train_item(index)
        if self.modo == 'valid':
            return self._get_valid_item(index)
        # Previously an unknown mode failed later with an UnboundLocalError.
        raise ValueError(f"Unknown modo {self.modo!r}; expected 'train' or 'valid'")

    @staticmethod
    def _box_areas(boxes):
        """Return the area of each ``[x_min, y_min, x_max, y_max]`` row."""
        return (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

    def _get_train_item(self, index):
        """Return a random augmented crop of the image behind row ``index``."""
        df_img = self.df.iloc[index]
        pos_img = df_img['pos_img']
        # NOTE(review): this copies the whole full-resolution image for every
        # sample; confirm whether the copy is needed before removing it.
        image = list_imagenes[pos_img].copy()
        boxes = list_boxes[pos_img]
        image_name = df_img['name']
        # there is only one class
        labels = torch.ones(len(boxes), dtype=torch.int64)

        transform = A.Compose([
            # CROP_X is the horizontal (width) extent and CROP_Y the vertical one,
            # matching the tiling of the validation images.
            A.RandomCrop(height=CROP_Y, width=CROP_X),
            A.Rotate(),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05),
            ToTensorV2()], bbox_params=bbox_params)

        # Balance between crops with and without varroa mite during training (p=0.50):
        # keep re-sampling the crop until it has the desired kind of content.
        result_transform = transform(image=image, bboxes=boxes, labels=labels)
        # An image without any box can never yield a crop with a mite, so it is
        # not asked for one (this would loop forever).
        want_mite = np.random.random() > 0.50 and len(boxes) > 0
        if want_mite:
            while len(result_transform['bboxes']) == 0:
                result_transform = transform(image=image, bboxes=boxes, labels=labels)
        else:
            while len(result_transform['bboxes']) != 0:
                result_transform = transform(image=image, bboxes=boxes, labels=labels)

        num_box = len(result_transform['bboxes'])
        image = result_transform['image']

        target = {}
        target['image_id'] = torch.tensor([pos_img])
        if num_box > 0:
            target['boxes'] = torch.as_tensor(result_transform['bboxes'], dtype=torch.float32)
        else:
            target['boxes'] = torch.zeros((0, 4), dtype=torch.float32)
        target["labels"] = torch.ones((num_box,), dtype=torch.int64)
        # Areas are computed from the boxes of the crop, so they stay aligned with
        # target['boxes'] (they used to be computed from the full-image boxes).
        target["area"] = self._box_areas(target['boxes'])
        target["iscrowd"] = torch.zeros((num_box,), dtype=torch.int64)
        return image, target, image_name

    def _get_valid_item(self, index):
        """Return the pre-computed validation tile number ``index``."""
        boxes = valid_boxes[index]
        image = valid_imagenes[index]
        labels = valid_labels[index]
        if len(boxes) == 0:
            # Tiles without boxes need explicit (0, 4) / (0,) shaped tensors.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.ones((0,), dtype=torch.int64)
        num_box = len(boxes)

        target = {}
        target['boxes'] = boxes
        target['labels'] = labels
        target['image_id'] = torch.tensor([index])
        target["area"] = self._box_areas(boxes)
        target["iscrowd"] = torch.zeros((num_box,), dtype=torch.int64)
        image_name = valid_names[index]
        return image, target, image_name


## Model creation

In [12]:
def return_backbone(backbone_name, weights_backbone='DEFAULT'):
    """Build a torchvision feature extractor usable as a Faster R-CNN backbone.

    Args:
        backbone_name: One of ``'resnet_18'``, ``'resnet_34'``, ``'resnet_50'``,
            ``'resnet_101'``, ``'resnet_152'``, ``'resnet_50_modified_stride_1'``,
            ``'resnext101_32x8d'``, ``'regnet_y_400mf'`` or
            ``'efficientnet_b0'`` ... ``'efficientnet_b7'``.
        weights_backbone: Passed as ``weights`` to the ResNet/ResNeXt/RegNet
            constructors. For EfficientNet, the default weights are loaded only
            when this equals ``"DEFAULT"``; any other value builds the network
            without weights.

    Returns:
        A module with an ``out_channels`` attribute giving the number of
        channels of its output feature map.

    Raises:
        ValueError: If ``backbone_name`` is not one of the supported names.
    """
    # Classification networks whose last two children are dropped:
    # name -> (torchvision.models constructor, output channels).
    truncated_models = {
        'resnet_18': ('resnet18', 512),
        'resnet_34': ('resnet34', 512),
        'resnet_50': ('resnet50', 2048),
        'resnet_101': ('resnet101', 2048),
        'resnet_152': ('resnet152', 2048),
        # NOTE(review): despite its name, no stride is modified for this entry;
        # it builds a plain ResNet-50. The original branch called an undefined
        # name `resnet50` and therefore could not run.
        'resnet_50_modified_stride_1': ('resnet50', 2048),
        'resnext101_32x8d': ('resnext101_32x8d', 2048),
        'regnet_y_400mf': ('regnet_y_400mf', 440),
    }
    # EfficientNet variants use their `.features` sub-module: name -> output channels.
    efficientnet_channels = {
        'efficientnet_b0': 1280,
        'efficientnet_b1': 1280,
        'efficientnet_b2': 1408,
        'efficientnet_b3': 1536,
        'efficientnet_b4': 1792,
        'efficientnet_b5': 2048,
        'efficientnet_b6': 2304,
        'efficientnet_b7': 2560,
    }

    if backbone_name in truncated_models:
        constructor_name, out_channels = truncated_models[backbone_name]
        network = getattr(torchvision.models, constructor_name)(weights=weights_backbone)
        modules = list(network.children())[:-2]
        backbone = torch.nn.Sequential(*modules)
    elif backbone_name in efficientnet_channels:
        out_channels = efficientnet_channels[backbone_name]
        constructor = getattr(torchvision.models, backbone_name)
        if weights_backbone == "DEFAULT":
            # e.g. 'efficientnet_b3' -> torchvision.models.EfficientNet_B3_Weights
            variant = backbone_name[-2:].upper()
            weights_enum = getattr(torchvision.models, f"EfficientNet_{variant}_Weights")
            backbone = constructor(weights=weights_enum.DEFAULT).features
        else:
            backbone = constructor().features
    else:
        supported = sorted(list(truncated_models) + list(efficientnet_channels))
        raise ValueError(
            f"Unknown backbone name {backbone_name!r}; supported names: {supported}")

    backbone.out_channels = out_channels
    return backbone

# Build the Faster R-CNN detector selected by BACKBONE_NAME.
if BACKBONE_NAME == 'fasterrcnn_resnet50_fpn':
    # torchvision's ready-made Faster R-CNN with a ResNet-50 FPN backbone
    model = fasterrcnn_resnet50_fpn(num_classes=2)
elif BACKBONE_NAME == "vitdet":
    model = create_VIT_model(num_classes=2)
elif BACKBONE_NAME.endswith("_fpn"):
    # FPN-based backbones, see:
    # https://github.com/pytorch/vision/blob/main/torchvision/models/detection/backbone_utils.py
    # Possible values (before the "_fpn" suffix) are 'resnet18', 'resnet34', 'resnet50',
    # 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2',
    # 'wide_resnet101_2'.
    # The suffix is removed from BACKBONE_NAME itself, so the global keeps the short name.
    BACKBONE_NAME = BACKBONE_NAME[:-len("_fpn")]
    backbone = resnet_fpn_backbone(backbone_name=BACKBONE_NAME, weights=WEIGHTS_BACKBONE)
    model = FasterRCNN(backbone, num_classes=2)
else:
    # Single feature-map backbone with an explicit anchor generator and RoI pooler
    backbone = return_backbone(BACKBONE_NAME, WEIGHTS_BACKBONE)
    anchor_generator = AnchorGenerator(
        sizes=((16, 32, 64, 128, 256),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    )
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['0'],
        output_size=7,
        sampling_ratio=2,
    )
    model = FasterRCNN(
        backbone,
        num_classes=2,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
    )

model.to(device)

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=1e-05)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=1e-05)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=1e-05)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride

## Train/Val Functions

In [13]:
class Averager:
    """Running mean of the values passed to ``send`` (used for the losses)."""

    def __init__(self):
        self.current_total = 0.0
        self.iterations = 0.0

    def send(self, value):
        """Add one more value to the running mean."""
        self.iterations += 1
        self.current_total += value

    @property
    def value(self):
        """Mean of the values sent so far; 0 when nothing has been sent."""
        if self.iterations != 0:
            return 1.0 * self.current_total / self.iterations
        return 0

    def reset(self):
        """Forget every value sent so far."""
        self.iterations = 0.0
        self.current_total = 0.0

In [17]:
def train_epoch(model, loader, optimizer, scheduler, epoch, loss_obj='loss_objectness'):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Detection model returning a dict of losses in training mode.
        loader: Iterable of ``(images, targets, image_names)`` batches.
        optimizer: Optimizer stepped once per batch on the sum of all losses.
        scheduler: Accepted for interface compatibility; not used here.
        epoch: Epoch number, only shown in the progress bar.
        loss_obj: Key of the loss shown in the progress bar.

    Returns:
        The epoch means of ``loss_classifier``, ``loss_box_reg``,
        ``loss_objectness`` and ``loss_rpn_box_reg``, in that order.
    """
    model.train()
    tracked_losses = ('loss_classifier', 'loss_box_reg', 'loss_objectness', 'loss_rpn_box_reg')
    running = {name: Averager() for name in tracked_losses}

    progress = tqdm(loader)
    for images, targets, _image_names in progress:
        images = [img.to(device) for img in images]
        targets = [{key: val.to(device) for key, val in tgt.items()} for tgt in targets]
        loss_dict = model(images, targets)  # dict of loss tensors

        # Accumulate each tracked loss for the epoch averages
        for name in tracked_losses:
            running[name].send(loss_dict[name].item())

        # The optimised quantity is the sum of every loss returned by the model
        total_loss = sum(loss_dict.values())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        progress.set_description(f"Train E:{epoch} - Loss:{loss_dict[loss_obj].item():0.4f}")
    progress.close()
    torch.cuda.empty_cache()
    return tuple(running[name].value for name in tracked_losses)

def valid_epoch(model, loader):
    """Evaluate ``model`` on the tiled validation images.

    The tiles produced by ``loader`` are expected to come in the same order as
    the module-level ``positions`` list. Predictions scoring above
    ``THRESHOLD_CONFIDENCE`` are shifted back to full-image coordinates, merged
    per image and compared with the annotations parsed from the XML files of
    the non-training rows of ``df_posic``.

    Args:
        model: Detection model returning one prediction dict per image.
        loader: Iterable of ``(images, targets, image_names)`` batches; only
            the images are used.

    Returns:
        The dict computed by ``MeanAveragePrecision`` (IoU threshold 0.50).

    Raises:
        ValueError: If the number of predicted tiles differs from
            ``len(positions)``.
    """
    model.eval()
    metric = MeanAveragePrecision(iou_type="bbox", iou_thresholds=[0.50])

    # 1) Confident predictions for every tile, in loader order
    seleccion_list = []
    for images, _targets, _image_names in tqdm(loader):
        images = [image.to(device) for image in images]
        with torch.no_grad():
            preds = model(images)
        for pred_uniq in preds:
            cuales = pred_uniq['scores'] > THRESHOLD_CONFIDENCE
            seleccion_list.append(dict(boxes=pred_uniq['boxes'][cuales].float(),
                                       scores=pred_uniq['scores'][cuales].float(),
                                       labels=pred_uniq['labels'][cuales]))

    # zip() would silently truncate on a mismatch and mix tiles of different images.
    if len(seleccion_list) != len(positions):
        raise ValueError(
            f"Got predictions for {len(seleccion_list)} tiles but {len(positions)} "
            "tile positions are registered")

    # 2) Merge the tiles of each image; a (0, 0) offset marks the first tile of an image
    all_images_preds = []
    full_image_bboxes = []
    full_image_scores = []
    for i, ((x, y), b) in enumerate(zip(positions, seleccion_list)):
        if (x, y) == (0, 0):  # New image
            full_image_bboxes = []
            full_image_scores = []
        if b['boxes'].numel() != 0:
            for (x_min, y_min, x_max, y_max) in b['boxes'].cpu().numpy():
                full_image_bboxes.append((x_min + x, y_min + y, x_max + x, y_max + y))
            full_image_scores.extend(b['scores'].cpu().numpy())
        if i + 1 == len(positions) or positions[i + 1] == (0, 0):
            num_preds = len(full_image_scores)
            # Explicit shapes/dtypes keep images without detections well-formed:
            # boxes are (0, 4) and labels are integers.
            all_images_preds.append(dict(
                boxes=torch.tensor(full_image_bboxes, dtype=torch.float32).reshape(num_preds, 4),
                scores=torch.tensor(full_image_scores, dtype=torch.float32),
                labels=torch.ones(num_preds, dtype=torch.int64)))

    # 3) Real annotations (ground truth) of the validation images
    all_targets = []
    for nrow in tqdm(range(len(df_posic))):
        row = df_posic.iloc[nrow]
        if not row['is_train']:
            a = parse_annotation(row['file_boxes'])
            num_boxes = len(a['boxes'])
            all_targets.append(dict(
                boxes=torch.tensor(a['boxes'], dtype=torch.float32).reshape(num_boxes, 4),
                labels=torch.tensor(a['labels'], dtype=torch.int64)))

    metric.update(all_images_preds, all_targets)
    metricas = metric.compute()
    return metricas

## Training phase

In [None]:
# Plain SGD; the learning rate is multiplied by FACTOR_REDUCE_ONPLATEAU after PATIENCE
# scheduler steps without improvement. mode='max' because the scheduler is stepped with
# the validation metric (see scheduler.step(map_metric) below), which should increase.
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', 
                                                       factor=FACTOR_REDUCE_ONPLATEAU, patience=PATIENCE)

# Validation set: one item per pre-computed tile, in a fixed order (shuffle=False).
validset = varroa_dataloader(df = pd.DataFrame({'name':valid_names}), modo='valid')
valid_data_loader = DataLoader(validset, batch_size=BATCH_VALID, shuffle=False, drop_last=False, 
                           num_workers=NUM_WORKERS, collate_fn=collate_fn)

start = time.time()
# Metrics start at -0.01 as a placeholder until the first validation is run.
map_metric = -0.01
best_map_metric = -0.01
map_50 = -0.01
best_map_50 = -0.01
map_75 = -0.01
best_map_75 = -0.01
mar_1 = -0.01
best_mar_1 = -0.01
mar_10 = -0.01
best_mar_10 = -0.01
mar_100 = -0.01
best_mar_100 = -0.01

best_epoch = 0
early_stop_count = 0
# Key of the validation metrics dict used for model selection, early stopping and the scheduler.
obj_metric = 'map_50'
metricas = []
best_metricas = []
res = []  # one dict per epoch; rewritten to CSV after every epoch
for epoch in range(NUM_EPOCHS):
    # A different but reproducible seed for every epoch.
    seed_everything(SEED*(epoch+1))
    # Each epoch trains on SAMPLES_BY_EPOCH rows sampled (without replacement) from the training rows of df.
    trainset = varroa_dataloader(df = df[df.is_train].sample(SAMPLES_BY_EPOCH, replace=False), modo='train')
    train_data_loader = DataLoader(trainset, batch_size=BATCH_TRAIN, shuffle=True, drop_last=False, collate_fn=collate_fn)
    #                                num_workers=NUM_WORKERS, 
    start_epoch = time.time()
    
    # Learning rate in use for this epoch (read before training and before any scheduler step).
    if scheduler.__class__ ==  torch.optim.lr_scheduler.OneCycleLR:
        lr = scheduler.get_last_lr()[0]
    else:
        lr = optimizer.param_groups[0]['lr']
    
    loss_classifier, loss_box_reg, loss_objectness, loss_rpn_box_reg = train_epoch(model, train_data_loader,  
                                                                               optimizer, scheduler, epoch)
    torch.cuda.empty_cache()
    print('MEAN=', np.mean(np.array([loss_classifier, loss_box_reg, loss_objectness, loss_rpn_box_reg])), 
          '[',loss_classifier, loss_box_reg, loss_objectness, loss_rpn_box_reg, ']')
    
    # During the first epochs do not perform validation or update the scheduler.
    # The comparison is strict, so the first validated epoch is START_VALID_AT_EPOCH + 1.
    # --------------------------------------------------------------
    if epoch > START_VALID_AT_EPOCH:
        # Validation
        metricas = valid_epoch(model, valid_data_loader)
        torch.cuda.empty_cache()
        
        # Save best model
        map_metric = metricas[obj_metric].item()
        map_50 = metricas['map_50'].item()
        map_75 = metricas['map_75'].item()
        
        mar_1 = metricas['mar_1'].item()
        mar_10 = metricas['mar_10'].item()
        mar_100 = metricas['mar_100'].item()

        if map_metric > best_map_metric:
            print(f"########## >>>>>>>> Model Improved {obj_metric} From {best_map_metric} to {map_metric}")
            torch.save(model.state_dict(), output_dir / f'modelo.bin')
            best_map_metric = map_metric
            best_epoch = epoch
            early_stop_count = 0
            best_metricas = metricas
        else: 
            early_stop_count += 1
            # Early stopping leaves the loop immediately: the log, CSV and plot below
            # are not updated for this last epoch.
            if early_stop_count>=EARLY_STOP:
                break
        
        if scheduler is not None:
            scheduler.step(map_metric)


    # Elapsed minutes: total since training started and for this epoch only.
    tiempo = round(((time.time() - start)/60),2)
    tiempo_epoch = round(((time.time() - start_epoch)/60),2)
    #             clear_output(wait=True) #to clean warnings
    print(f'Epoch={epoch:02d} LR={lr:0.08f} min={tiempo_epoch:.01f}/{tiempo:.01f} {obj_metric}={map_metric:.03f}\n{metricas}')
    print(f'BestE={best_epoch:02d} {obj_metric}={best_map_metric:.03f}\n{best_metricas}')
    res.append(dict({'epoch':epoch, 'lr':lr, 'tiempo':tiempo,
                     'map_metric':map_metric, 'best_map_metric':best_map_metric, 
                     'map_50': map_50, 'map_75': map_75, 
                     'mar_1': mar_1, 'mar_10': mar_10, 'mar_100': mar_100
                     }))
    # The whole history is rewritten to disk after every epoch.
    res_df = pd.DataFrame(res)
    res_df.to_csv(output_dir / f'modelo_res.csv')

    # Draw curves if not in console
    # The 2x2 figure (objective metric, mAP curves, mAR curves, learning rate) is saved
    # to the output folder and closed, overwriting the previous epoch's file.
    # -----------------------------
    fig, axs = plt.subplots(2,2, figsize=(15,15))

    axs[1,0].plot(res_df['map_metric'].values, label='map')
    axs[1,0].set_xlabel('Epochs')
    axs[1,0].set_ylabel('MAP')
    axs[1,0].set_title(f'MAP={map_metric:.4f} BestMap={best_map_metric:.4f} in Epoch{best_epoch}')

    axs[0,0].plot(res_df['map_50'].values, label='map_50')
    axs[0,0].plot(res_df['map_75'].values, label='map_75')
    axs[0,0].set_xlabel('Epochs')
    axs[0,0].set_ylabel('MAPS')
    axs[0,0].set_title(f'map_50={map_50:.4f}({np.max(res_df["map_50"].values):.4f}) map_75={map_75:.4f}({np.max(res_df["map_75"].values):.4f})')

    axs[0,1].plot(res_df['mar_10'].values, label='mar_10')
    axs[0,1].plot(res_df['mar_100'].values, label='mar_100')

    axs[0,1].set_xlabel('Epochs')
    axs[0,1].set_ylabel('MARS')
    axs[0,1].set_title(f'mar_10={mar_10:.4f}({np.max(res_df["mar_10"].values):.4f}) mar_100={mar_100:.4f}({np.max(res_df["mar_100"].values):.4f})')
    
    axs[1,1].plot(res_df['lr'].values)
    axs[1,1].set_xlabel('Epochs')
    axs[1,1].set_ylabel('Learning Rate')
    axs[1,1].set_title(f'Learning Rate={lr:.8f} Max={res_df.lr.max():.8f} Min={res_df.lr.min():.8f}')
    fig.savefig(output_dir / f'modelo_errors.png',facecolor='white', edgecolor='white')
    plt.close(fig)

  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.10561758452760321 [ 0.09000826054917914 0.02476246069584574 0.30455915204116274 0.0031404648242252214 ]
Epoch=00 LR=0.01000000 min=0.1/0.1 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.05335484154056757 [ 0.10486388499183315 0.058472695627382824 0.04731254492487226 0.00277024061818208 ]
Epoch=01 LR=0.01000000 min=0.2/0.3 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.0373574011443582 [ 0.05666079106075423 0.0558941164719207 0.03437674737402371 0.0024979496707341503 ]
Epoch=02 LR=0.01000000 min=0.2/0.5 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.03741548860645188 [ 0.056832353451422284 0.06620009749063424 0.023807055982095853 0.002822447501655136 ]
Epoch=03 LR=0.01000000 min=0.2/0.7 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.02848701334525166 [ 0.04743823009942259 0.047425808651106696 0.016466155648231506 0.002617858982245837 ]
Epoch=04 LR=0.01000000 min=0.2/0.8 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.019241351957524397 [ 0.028787184772746905 0.0334027408223067 0.012877819261380605 0.0018976629736633705 ]
Epoch=05 LR=0.01000000 min=0.1/1.0 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.02382920861938536 [ 0.032665000430175235 0.04718231436397348 0.012819996702351741 0.002649522981040978 ]
Epoch=06 LR=0.01000000 min=0.2/1.2 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.02718536490907094 [ 0.038878529997808595 0.053407368649329455 0.013977105729281902 0.0024784552598638193 ]
Epoch=07 LR=0.01000000 min=0.2/1.4 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.026483453785268857 [ 0.034922591443838816 0.054562525025435855 0.01264394773170352 0.0038047509400972296 ]
Epoch=08 LR=0.01000000 min=0.2/1.6 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.027706270842047936 [ 0.04539698548614979 0.05342158967895167 0.009981683788022824 0.002024824415067477 ]
Epoch=09 LR=0.01000000 min=0.3/1.9 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.028240554751911468 [ 0.04075551897819553 0.05615568320666041 0.01315432381150978 0.0028966930112801492 ]
Epoch=10 LR=0.01000000 min=0.3/2.3 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.027477828371047508 [ 0.0462259440017598 0.04876852833798954 0.012563549647373813 0.0023532914970668833 ]
Epoch=11 LR=0.01000000 min=0.3/2.6 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.02409028102868303 [ 0.03451697288879326 0.050957903265953064 0.0087884979854737 0.002097749974512096 ]
Epoch=12 LR=0.01000000 min=0.3/2.9 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.023915012527530574 [ 0.032869967365903516 0.04859990120998451 0.011375106605035918 0.00281507492919835 ]
Epoch=13 LR=0.01000000 min=0.3/3.2 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.0229837648886522 [ 0.02885410681899105 0.05061797824289117 0.01034968625754118 0.0021132882351854016 ]
Epoch=14 LR=0.01000000 min=0.3/3.5 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.025240340552824946 [ 0.03530736440526588 0.05192537432802575 0.010155593843332358 0.0035730296346758094 ]
Epoch=15 LR=0.01000000 min=0.2/3.6 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.024146883316071968 [ 0.030843390950134823 0.05490754118987492 0.008226568211934395 0.0026100329123437405 ]
Epoch=16 LR=0.01000000 min=0.2/3.9 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.022298034919783407 [ 0.0306672961118498 0.04835241234728268 0.0073764756255384 0.0027959555944627418 ]
Epoch=17 LR=0.01000000 min=0.3/4.2 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.024347012978978455 [ 0.035611379093357494 0.053325611033609936 0.007120123964601329 0.0013309378243450607 ]
Epoch=18 LR=0.01000000 min=0.3/4.5 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.028363605433176935 [ 0.03853653557598591 0.0620160877172436 0.010443715511688165 0.0024580829277900712 ]
Epoch=19 LR=0.01000000 min=0.2/4.7 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

MEAN= 0.024452639318561915 [ 0.03517950206462826 0.05034496236060347 0.009582265446494733 0.002703827402521191 ]
Epoch=20 LR=0.01000000 min=0.2/4.9 map_50=-0.010
[]
BestE=00 map_50=-0.010
[]


  0%|          | 0/7 [00:00<?, ?it/s]

**Once here, the training is over!**