# Detector Linked-CNN Fish4Knowledge

En primer lugar, establecemos las variables de entorno necesarias para usar las dos GPUs.

In [1]:
import os; os.environ['CUDA_VISIBLE_DEVICES']='0,1'

Importamos las librerías necesarias para ejecutar este cuaderno.

In [2]:
from pathlib import Path
import os
import numpy as np
import torch
from PIL import Image
from xml.dom import minidom
import torchvision
from faster_rcnn import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import transforms as T
import numpy as np
import pandas as pd

from engine_linked import train_one_epoch, evaluate
import utils

from fastai.vision.all import *

Listamos el directorio padre del proyecto para comprobar su contenido.

In [3]:
!ls /media/Shared/jossalgon/fishly/detector/Fish4Knowledge/

Fish4Knowledge.csv  fish_image	mask_image  models  models-smallbndbox


Definimos la ruta padre del proyecto donde se encuentran los datos.

In [4]:
root = Path('/media/Shared/jossalgon/fishly/detector/Fish4Knowledge/')

Cargamos el archivo CSV en forma de _dataframe_ de Pandas con las clases jerárquicas de cada especie y su ID.

In [5]:
classes = pd.read_csv(str(root/'Fish4Knowledge.csv'), delimiter=';').set_index('ID')
classes

Unnamed: 0_level_0,Order,Family,Subfamily,Genus,Specie
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,Perciformes,Pomacentridae,Pomacentrinae,Dascyllus,Dascyllus reticulatus
2,Perciformes,Pomacentridae,Pomacentrinae,Plectroglyphidodon,Plectroglyphidodon dickii
3,Perciformes,Pomacentridae,Pomacentrinae,Chromis,Chromis chrysura
4,Perciformes,Pomacentridae,Amphiprioninae,Amphiprion,Amphiprion clarkii
5,Perciformes,Chaetodontidae,,Chaetodon,Chaetodon lunulatus
6,Perciformes,Chaetodontidae,,Chaetodon,Chaetodon trifascialis
7,Beryciformes,Holocentridae,Myripristinae,Myripristis,Myripristis kuntee
8,Perciformes,Acanthuridae,Acanthurinae,Acanthurus,Acanthurus nigrofuscus
9,Perciformes,Labridae,Corinae,Hemigymnus,Hemigymnus fasciatus
10,Beryciformes,Holocentridae,Holocentrinae,Neoniphon,Neoniphon sammara


Mostramos la cantidad de clases por nivel de especificación con _describe_.

In [6]:
classes.describe()

Unnamed: 0,Order,Family,Subfamily,Genus,Specie
count,23,23,15,22,22
unique,3,13,8,20,22
top,Perciformes,Pomacentridae,Pomacentrinae,Chaetodon,Canthigaster valentini
freq,19,7,6,2,1


Guardamos las familias del conjunto de datos y el número de clases de especie.

In [7]:
FAMILIES = list(set(classes['Family']))
FAMILIES

['Chaetodontidae',
 'Pempheridae',
 'Pomacentridae',
 'Scaridae',
 'Tetraodontidae',
 'Zanclidae',
 'Lutjanidae',
 'Holocentridae',
 'Balistidae',
 'Siganidae',
 'Nemipteridae',
 'Labridae',
 'Acanthuridae']

In [8]:
NUM_CLASSES = len(set(classes['Specie']))+1

Definimos una función para obtener el ID de la especie mediante el fichero.

In [9]:
def get_fish_id_by_filename(fish_filename):
    fish_path = glob.glob(str(root/'fish_image'/'**'/fish_filename))
    if fish_path:
        fish_id = str(Path(fish_path[0]).parent).split('fish_')[-1]
        return int(fish_id)
    else:
        return None

Definimos una función para obtener el ID de la familia mediante el ID de la especie.

In [10]:
def get_family_id_by_fish_id(fish_id):
    family = classes.loc[fish_id]['Family']
    return FAMILIES.index(family)+1

Defimos el objeto _Dataset_ para obtener y gestionar los elementos.

In [11]:
class FishDataset(object):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(glob.iglob(str(root/'fish_image'/'**/*.png'), recursive=True)))
        self.masks = list(sorted(glob.iglob(str(root/'mask_image'/'**/*.png'), recursive=True)))
        self.classes = classes


    def __getitem__(self, idx):
        # load images ad masks
        img_path = os.path.join(self.root, "fish_image", self.imgs[idx])
        mask_path = os.path.join(self.root, "mask_image", self.masks[idx])
        fish_id = get_fish_id_by_filename(img_path)
        family_id = get_family_id_by_fish_id(fish_id)
        img = Image.open(img_path).convert("RGB")
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance
        # with 0 being background
        mask = Image.open(mask_path)
        # convert the PIL Image into a numpy array
        mask = np.array(mask)
        # instances are encoded as different colors
        obj_ids = np.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]

        # split the color-encoded mask into a set
        # of binary masks
        masks = mask == obj_ids[:, None, None]

        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.as_tensor([(fish_id, family_id)]*num_objs)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

Creamos una clase para nuestro modelo que divida los _encoder_ en los distintos bloques. En el método _forward_ establecemos las conexiones entre los bloques para formar la arquitectura propuesta Linked-CNN, y devolver ambas salidas en forma de diccionario.

In [12]:
class Resnet50CustomModelBackbone(Module):
    def __init__(self, coarse_encoder, fine_encoder):
        self.coarse_stem = coarse_encoder[:4]
        self.coarse_block1 = coarse_encoder[4]
        self.coarse_block2 = coarse_encoder[5]
        self.coarse_block3 = coarse_encoder[6]
        self.coarse_block4 = coarse_encoder[7]
        
        self.fine_stem = fine_encoder[:4]
        self.fine_block1 = fine_encoder[4]
        self.fine_block2 = fine_encoder[5]
        self.fine_block3 = fine_encoder[6]
        self.fine_block4 = fine_encoder[7]
        
    
    def forward(self, x):
        x_c = self.coarse_stem(x)
        x_f = self.fine_stem(x)
        x_f = x_f + x_c
        
        x_c = self.coarse_block1(x_c)
        x_f = self.fine_block1(x_f)
        x_f = x_f + x_c
        
        x_c = self.coarse_block2(x_c)
        x_f = self.fine_block2(x_f)
        x_f = x_f + x_c

        x_c = self.coarse_block3(x_c)
        x_f = self.fine_block3(x_f)

        x_c = self.coarse_block4(x_c)
        x_f = self.fine_block4(x_f)
        
        return {
                'fine': x_f,
                'coarse': x_c
               }

Creamos el modelo con la clase definida pasándole los cuerpos de ambos modelos con distintos niveles de especificación.

In [13]:
arch = resnet50
fine_body = create_body(arch, cut=-2, pretrained=True)
coarse_body = create_body(arch, cut=-2, pretrained=True)
backbone = Resnet50CustomModelBackbone(fine_body, coarse_body)

Establecemos las dimensiones de salida del _backbone_, creamos el _anchor_, _ROI_ y modelo FasterRCNN.

In [14]:
# load a pre-trained model for classification and return
# only the features
# backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# FasterRCNN needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
backbone.out_channels = 2048

# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and
# aspect ratios
anchor_generator = AnchorGenerator(sizes=((256, 512, 640, 800),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

# let's define what are the feature maps that we will
# use to perform the region of interest cropping, as well as
# the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be [0]. More generally, the backbone should return an
# OrderedDict[Tensor], and in featmap_names you can choose which
# feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                output_size=7,
                                                sampling_ratio=2)

# put the pieces together inside a FasterRCNN model
model = FasterRCNN(backbone,
                   num_classes=NUM_CLASSES,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)

in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, NUM_CLASSES)

Definimos la función de aumento de datos.

In [15]:
def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

Creamos los _dataloaders_, definimos las particiones de entrenamiento y validación, construimos el optimizador, entrenamos nuestro modelo durante 10 épocas y evaluamos.

In [16]:
models_dir = Path(root/'models')
# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# use our dataset and defined transformations
dataset = FishDataset(root, get_transform(train=True))
dataset_test = FishDataset(root, get_transform(train=False))


# split the dataset in train and test set
n_test = int(len(dataset)*0.20)
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-n_test])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-n_test:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=True, num_workers=4,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn) 


# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)
# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=3,
                                               gamma=0.1)

# let's train it for 10 epochs
num_epochs = 10

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=5000)
    torch.save(model.state_dict(), str(models_dir/f"faster_rcnn_linked_resnet50-{epoch}epochs.pth"))
    
    # update the learning rate
    lr_scheduler.step()
    
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)



Epoch: [0]  [    0/21896]  eta: 13:47:34  lr: 0.000010  loss: 7.1627 (7.1627)  loss_classifier: 6.3913 (6.3913)  loss_box_reg: 0.0001 (0.0001)  loss_objectness: 0.7294 (0.7294)  loss_rpn_box_reg: 0.0420 (0.0420)  time: 2.2677  data: 0.0994  max mem: 10394
Epoch: [0]  [ 5000/21896]  eta: 2:03:23  lr: 0.005000  loss: 0.2294 (0.4671)  loss_classifier: 0.1224 (0.2964)  loss_box_reg: 0.1066 (0.1362)  loss_objectness: 0.0018 (0.0239)  loss_rpn_box_reg: 0.0036 (0.0107)  time: 0.4083  data: 0.0022  max mem: 10394
Epoch: [0]  [10000/21896]  eta: 1:23:34  lr: 0.005000  loss: 0.1913 (0.3777)  loss_classifier: 0.1255 (0.2389)  loss_box_reg: 0.0610 (0.1158)  loss_objectness: 0.0025 (0.0151)  loss_rpn_box_reg: 0.0032 (0.0078)  time: 0.4009  data: 0.0020  max mem: 10394
Epoch: [0]  [15000/21896]  eta: 0:47:33  lr: 0.005000  loss: 0.1784 (0.3328)  loss_classifier: 0.0937 (0.2108)  loss_box_reg: 0.0703 (0.1040)  loss_objectness: 0.0012 (0.0115)  loss_rpn_box_reg: 0.0015 (0.0065)  time: 0.3810  data: 0.