In [1]:
import torch
import sys
import os, json, cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.transforms import functional as F
from utils_scripts import utils
from utils_scripts.classes.metric_logger import MetricLogger

# import transforms, utils, engine, train

import copy
import io
from contextlib import redirect_stdout

# import pycocotools.mask as mask_util
# from pycocotools.coco import COCO
# from pycocotools.cocoeval import COCOeval
# from coco_utils import get_coco_api_from_dataset
import torchvision.models.detection.mask_rcnn
import math
import time

from torchvision.models.detection import keypointrcnn_resnet50_fpn
import warnings

In [2]:
# Lower-case aliases of the Python booleans
# (presumably so JSON-style literals can be pasted into cells verbatim - TODO confirm).
true, false = True, False

# Number of keypoints annotated per object; used as the default for
# get_model(num_keypoints=...) and for reshaping augmented keypoints in the dataset.
KEYPOINTS_NUM = 17

In [3]:
class DatasetClass(Dataset):
    """Keypoint-detection dataset read from a JSON annotation file.

    The annotation file must contain an ``'images'`` list (each entry with a
    ``'file_name'``) and an ``'annotations'`` list (each entry with ``'bbox'``
    and ``'keypoints'``). Annotation ``i`` is used for image ``i``, so both
    lists are assumed to be aligned one-to-one.

    Each sample is ``(image_tensor, target)`` where ``target`` is a dict with
    the keys ``boxes``, ``labels``, ``image_id``, ``area``, ``iscrowd`` and
    ``keypoints``. With ``demo=True`` the un-augmented image tensor and target
    are returned as well: ``(img, target, img_original, target_original)``.
    """

    def __init__(self,
                 root:str, # root folder the image file names are relative to
                 annos:str, # annotations file (JSON)
                 split:list, # boolean mask over all samples selecting this subset
                 transform=None, demo=False, only_image=False):
        """Load the annotations and keep only the samples selected by ``split``.

        Args:
            root: Folder that is prepended to every ``file_name``.
            annos: Path to the JSON annotation file.
            split: Mask/index applied to both the image and annotation arrays.
            transform: Optional callable invoked as
                ``transform(image=..., bboxes=..., bboxes_labels=..., keypoints=...)``
                returning a dict with ``'image'``, ``'bboxes'`` and ``'keypoints'``.
            demo: When true, ``__getitem__`` also returns the original sample.
            only_image: When set to an image path contained in ``self.imgs``,
                every index returns that single image.
        """
        self.root = root
        with open(annos) as json_file:
            data = json.load(json_file)
        self.annos = np.asarray(data['annotations'])[split]
        self.imgs = np.asarray(
            [os.path.join(self.root, img_dict['file_name']) for img_dict in data['images']]
        )[split]
        self.transform = transform
        self.demo = demo
        self.only_image = only_image

        # Convert every box from [x, y, w, h] (assumed input layout) to
        # [x_min, y_min, x_max, y_max]. The conversion is done in place, so
        # self.annos[i]['bbox'] holds the converted box as well.
        self.bboxes = []
        for anno in self.annos:
            box = anno['bbox']
            box[2] += box[0]
            box[3] += box[1]
            self.bboxes.append([box])

    @staticmethod
    def _build_target(bboxes, keypoints, idx):
        """Pack boxes and keypoints of one image into the target dict of tensors."""
        boxes = torch.as_tensor(bboxes, dtype=torch.float32)
        return {
            "boxes": boxes,
            # Single foreground class: every object gets label 1.
            "labels": torch.ones(len(boxes), dtype=torch.int64),
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros(len(boxes), dtype=torch.int64),
            # Going through one float32 array avoids torch's slow
            # list-of-ndarrays conversion path.
            "keypoints": torch.as_tensor(np.asarray(keypoints, dtype=np.float32)),
        }

    def _augment(self, image, bboxes, keypoints):
        """Run ``self.transform`` and re-attach the original visibility flags.

        Raises:
            ValueError: If the transform changed the number of objects
                (for example by dropping a box that left the image).
        """
        # The transform works on a flat list of [x, y] points, so the
        # per-object nesting and the visibility column are stripped first.
        flat_xy = [kp[0:2] for obj in keypoints for kp in obj]
        transformed = self.transform(image=image, bboxes=bboxes,
                                     bboxes_labels=['Socket' for _ in bboxes],
                                     keypoints=flat_xy)

        # Restore the per-object nesting: (objects, KEYPOINTS_NUM, 2).
        nested_xy = np.reshape(
            np.array(transformed['keypoints']),
            (-1, KEYPOINTS_NUM, 2)
        ).tolist()

        # Append the original visibility flag to every transformed [x, y].
        new_keypoints = [
            [kp + [keypoints[o_idx][k_idx][2]] for k_idx, kp in enumerate(obj)]
            for o_idx, obj in enumerate(nested_xy)
        ]

        new_bboxes = transformed['bboxes']
        if not (len(new_bboxes) == len(new_keypoints) == len(bboxes)):
            raise ValueError(
                f'transform returned {len(new_bboxes)} boxes and {len(new_keypoints)} '
                f'keypoint groups for {len(bboxes)} objects'
            )
        return transformed['image'], new_bboxes, new_keypoints

    def __getitem__(self, idx):
        """Return the sample at ``idx`` (or the ``only_image`` sample if configured).

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        if self.only_image:
            img_path = self.only_image
            idx = list(self.imgs).index(img_path)
        else:
            img_path = self.imgs[idx]

        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        img_original = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        bboxes_original = self.bboxes[idx]
        # One object per image; keypoints are stored flat as [x, y, visibility, ...].
        keypoints_original = [np.asarray(self.annos[idx]['keypoints']).reshape(-1, 3)]

        img, bboxes, keypoints = img_original, bboxes_original, keypoints_original
        if self.transform:
            try:
                img, bboxes, keypoints = self._augment(
                    img_original, bboxes_original, keypoints_original)
            except Exception as e:
                # Best effort: a failed augmentation must not stop training,
                # but it should not go unnoticed either.
                warnings.warn(f'Augmentation failed for {img_path}: {e!r}. '
                              f'Falling back to the un-augmented sample.')
                img, bboxes, keypoints = img_original, bboxes_original, keypoints_original

        target = self._build_target(bboxes, keypoints, idx)
        img = F.to_tensor(img)
        if not self.demo:
            return img, target

        # The original sample is only needed (and only built) in demo mode.
        target_original = self._build_target(bboxes_original, keypoints_original, idx)
        return img, target, F.to_tensor(img_original), target_original

    def __len__(self):
        """Number of samples selected by the split mask."""
        return len(self.imgs)

In [4]:

def get_model(num_keypoints=KEYPOINTS_NUM,
              weights_path=None,
              load_device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),
              pretrained_net=False):
    """Create a Keypoint R-CNN (ResNet-50 FPN) detector with a custom anchor generator.

    Args:
        num_keypoints: Number of keypoints predicted per object.
        weights_path: Optional path to a saved state dict that is loaded into the model.
        load_device: Device passed to ``torch.load`` as ``map_location``.
        pretrained_net: Forwarded as ``pretrained`` to ``keypointrcnn_resnet50_fpn``.

    Returns:
        The constructed model (with the state dict loaded when ``weights_path`` is given).
    """
    starts_from_scratch = not (pretrained_net or weights_path)
    if starts_from_scratch:
        warnings.warn('model is not pretrained and no state dict provided, the learning process will be long!')

    anchor_sizes = (32, 64, 128, 256, 512)
    anchor_ratios = (0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0)
    detector = keypointrcnn_resnet50_fpn(
        pretrained=pretrained_net,
        pretrained_backbone=True,
        num_keypoints=num_keypoints,
        # Two classes: index 0 is the background, index 1 is the object.
        num_classes=2,
        rpn_anchor_generator=AnchorGenerator(sizes=anchor_sizes, aspect_ratios=anchor_ratios),
    )

    if not weights_path:
        return detector

    # NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
    detector.load_state_dict(torch.load(weights_path, map_location=load_device))
    return detector

In [None]:
# --- Configuration ---
# DATASET_LEN is the length of the boolean split mask below, so it has to match
# the number of entries in the annotation file.
DATASET_LEN = 3468
TRAIN_SIZE = 3100
imgs_folder = 'data/big_experiment'
annotations_file = 'data/big_experiment/extended_17.json'

num_epochs = 450
BATCH_SIZE = 2

# Seed for the train/test split only. A fixed seed keeps the split identical
# across kernel restarts, so resumed runs do not mix test images into training.
SPLIT_SEED = 0

# --- Train/test split: a mask with TRAIN_SIZE True entries in random positions ---
split = np.full(DATASET_LEN, False)
split[:TRAIN_SIZE] = True
np.random.default_rng(SPLIT_SEED).shuffle(split)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Data ---
dataset_train = DatasetClass(
    root=imgs_folder,
    annos=annotations_file,
    split=split,
    transform=utils.train_transform(), demo=False)

dataset_test = DatasetClass(
    root=imgs_folder,
    annos=annotations_file,
    split=~split,
    transform=None, demo=False)

data_loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=utils.collate_fn)
data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=utils.collate_fn)

# --- Model and optimisation ---
model = get_model(num_keypoints=KEYPOINTS_NUM)#, weights_path=None, pretrained_net=True)
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params,
                            lr=8*1e-3,
                            momentum=0.9,
                            weight_decay=0.00012)
# Multiply the learning rate by 0.9 every 10 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=10,
                                               gamma=0.9)

# Create the checkpoint folder up front so the first save cannot fail after
# a long training epoch.
os.makedirs('outputs', exist_ok=True)

# --- Training loop ---
for epoch in range(num_epochs):
    response = utils.train_one_epoch(model,
                    optimizer,
                    data_loader_train,
                    device,
                    epoch,
                    print_freq=100)
    print(f'Epoch {epoch} finished', end='')
    lr_scheduler.step()
#     utils.evaluate(model, data_loader_test, device)   # TODO: fix the validation loss
    # Evaluation is disabled above, so do not claim that validation ran.
    print(', validation skipped')
    # Checkpoint every 5th epoch (epoch 0 excluded).
    if epoch % 5 == 0 and epoch != 0:
        torch.save(model.state_dict(), f'outputs/keypointsrcnn_weights_epoch_{epoch}.pth')

# Save model weights after training
torch.save(model.state_dict(), 'outputs/keypointsrcnn_weights_final.pth')



Epoch: [0]	[   0/1550]	eta: 8:35:54	lr: 0.000016	loss: 9.4578 (9.4578)	loss_classifier: 0.6875 (0.6875)	loss_box_reg: 0.0077 (0.0077)	loss_keypoint: 8.0666 (8.0666)	loss_objectness: 0.6938 (0.6938)	loss_rpn_box_reg: 0.0021 (0.0021)	time: 19.9710	data: 0.1250
Epoch: [0]	[ 100/1550]	eta: 8:46:07	lr: 0.000815	loss: 4.3372 (6.8686)	loss_classifier: 0.0458 (0.2044)	loss_box_reg: 0.0691 (0.0527)	loss_keypoint: 4.1883 (6.2106)	loss_objectness: 0.0185 (0.3937)	loss_rpn_box_reg: 0.0066 (0.0072)	time: 23.6815	data: 0.1078
