In [1]:
import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
import skimage.io as io
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as transforms

%matplotlib inline

In [2]:
def display_image_annot(img, annts):
    """Plot an image with its COCO annotation boxes drawn in green.

    Args:
        img: an image array, or a COCO image-metadata dict; for a dict the
            image is read from its ``'coco_url'`` entry.
        annts: list of COCO annotation dicts; entries that carry a ``'bbox'``
            given as ``[x, y, width, height]`` are drawn.

    Raises:
        ValueError: if ``annts`` is empty.
    """
    # Check before fetching the image so an empty list does not trigger a read.
    if not annts:
        raise ValueError('Annotation is empty')
    if isinstance(img, dict):
        img = io.imread(img['coco_url'])
    else:
        # cv2.rectangle draws in place; work on a C-contiguous copy so the
        # caller's array (e.g. a cached dataset image) is left untouched.
        img = np.array(img, order='C')
    for annt in annts:
        if 'bbox' in annt:
            # int() truncates toward zero, like the original integer cast.
            x, y, w, h = (int(v) for v in annt['bbox'])
            img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 3)
    plt.axis('off')
    plt.imshow(img)
    plt.show()
    
def display_bbox_annot(img, bboxes):
    """Plot an image with corner-format boxes drawn in green.

    Args:
        img: an image array, or a COCO image-metadata dict; for a dict the
            image is read from its ``'coco_url'`` entry.
        bboxes: iterable of ``(x1, y1, x2, y2)`` boxes (e.g. an ``(N, 4)``
            array). ``None`` is treated as "no boxes".
    """
    if isinstance(img, dict):
        img = io.imread(img['coco_url'])
    else:
        # cv2.rectangle draws in place; copy so the caller's array is not modified.
        img = np.array(img, order='C')
    for bbox in (bboxes if bboxes is not None else ()):
        # Plain Python ints: np.int no longer exists and OpenCV wants int points.
        x1, y1, x2, y2 = (int(v) for v in bbox)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 3)
    plt.axis('off')
    plt.imshow(img)
    plt.show()

class CocoDataset(Dataset):
    """COCO detection dataset restricted to images annotated with 'person'.

    Each item is ``(image_tensor, boxes)``:

    * ``image_tensor`` is ``transforms(img).float()``.
    * ``boxes`` is a float32 tensor of shape ``(num_boxes, 4)`` in corner form
      ``(x1, y1, x2, y2)``, converted from COCO's ``(x, y, width, height)``.
      Images without a usable box yield an empty ``(0, 4)`` tensor, so the
      default DataLoader collate function can still batch them.

    Args:
        root: dataset root; annotations are read from
            ``{root}/annotations/instances_{split}.json``.
        data_type: split identifier; only its last path component
            (e.g. ``'train2017'``) selects the annotation file.
        transforms: callable applied to the image array; when falsy,
            ``torchvision.transforms.ToTensor()`` is used.
        in_memory: if True, every image is read from its ``coco_url`` during
            construction; otherwise it is read on each ``__getitem__``.
        is_debug: if True, plot annotations for visual inspection.
    """

    def __init__(self, root, data_type, transforms, in_memory=False, is_debug=False):
        self.data_type = data_type
        self.transforms = transforms
        self.in_memory = in_memory
        self.is_debug = is_debug
        # Only the split name (last path component) picks the annotation file;
        # basename also handles Windows separators produced by os.path.join.
        split = os.path.basename(data_type)
        annts_file = f'{ root }/annotations/instances_{ split }.json'
        self.coco = COCO(annts_file)

        # Restrict to images tagged with the 'person' category. Because crowd
        # annotations are filtered out below, an image can still end up with
        # no boxes.
        category_ids = self.coco.getCatIds(catNms=['person'])
        image_ids = self.coco.getImgIds(catIds=category_ids)
        self.image_meta = self.coco.loadImgs(image_ids)

        self.images = []
        self.annts = []
        for image_meta_data in self.image_meta:
            annts_ids = self.coco.getAnnIds(imgIds=image_meta_data['id'], catIds=category_ids, iscrowd=False)
            img = io.imread(image_meta_data['coco_url']) if self.in_memory else image_meta_data
            self.images.append(img)
            self.annts.append(self.coco.loadAnns(annts_ids))

        if self.is_debug and self.images:
            # randint's upper bound is exclusive: this can pick every index,
            # including the last, and works with a single image.
            idx = np.random.randint(len(self.images))
            if self.annts[idx]:
                img = self.images[idx]
                # Draw on a copy so a cached in-memory image is not modified.
                display_image_annot(img if isinstance(img, dict) else img.copy(), self.annts[idx])

        self.n = len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx] if self.in_memory else io.imread(self.images[idx]['coco_url'])
        if img.ndim == 2:
            # Single-channel image: replicate to 3 channels so every sample has
            # the same channel count (creates a new array; the cache is untouched).
            img = np.stack([img] * 3, axis=-1)

        # reshape keeps an (N, 4) shape even when there are no boxes.
        bboxes = np.array(
            [annt['bbox'] for annt in self.annts[idx] if 'bbox' in annt],
            dtype=np.float32,
        ).reshape(-1, 4)
        # (x, y, w, h) -> (x1, y1, x2, y2); bboxes is a fresh array, so the
        # stored annotations are not modified.
        bboxes[:, 2:] += bboxes[:, :2]

        if self.is_debug and bboxes.shape[0] > 0:
            # Reuse the already-loaded image; copy because drawing is in place.
            display_bbox_annot(img.copy(), bboxes)

        transform = self.transforms if self.transforms else transforms.ToTensor()
        img_tensor = transform(img).float()
        bboxes_tensor = torch.from_numpy(bboxes)

        return (img_tensor, bboxes_tensor)

    def __len__(self):
        return self.n

In [3]:
# COCO dataset root; annotation files are read from `<root>/annotations/`.
root = 'coco'
img_paths = 'images'

# Training split: every constructor option spelled out explicitly.
train_dataset = CocoDataset(
    root,
    data_type=os.path.join(root, 'train2017'),
    transforms=None,
    in_memory=False,
    is_debug=False,
)
# Validation split: relies on the constructor defaults (in_memory=False, is_debug=False).
val_dataset = CocoDataset(
    root,
    data_type=os.path.join(root, 'val2017'),
    transforms=None,
)

loading annotations into memory...
Done (t=9.81s)
creating index...
index created!
loading annotations into memory...
Done (t=0.25s)
creating index...
index created!


In [6]:
import multiprocessing as mp
import torch.backends.cudnn as cudnn

print(f'Usable threads: { torch.get_num_threads() }')
print(f'cuda version: { torch.version.cuda }\tcudnn version: { cudnn.version() }')

# DataLoader needs an *integer* worker count (a float breaks worker start-up).
# Use ~40% of the logical CPUs; int() floors, so a single-CPU machine gets 0
# workers and loads in the main process.
num_workers = int(0.4 * mp.cpu_count())
# Pinned memory only helps host->GPU copies; skip it on CPU-only machines.
pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=pin_memory, num_workers=num_workers)
#val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=pin_memory, num_workers=num_workers)

Usable threads: 12
cuda version: 11.1	cudnn version: 8005


In [None]:
# Architecture

class Conv2d(nn.Module):
    """Convolution -> BatchNorm -> ReLU building block.

    Args:
        in_chs: number of input channels.
        out_chs: number of output channels.
        ksize: square kernel size.
        stride: convolution stride (default 1).
        padding: zero padding on each side. ``None`` (default) uses
            ``ksize // 2``, which keeps the spatial size unchanged for an odd
            ``ksize`` at stride 1.
    """

    def __init__(self, in_chs, out_chs, ksize, stride=1, padding=None):
        super(Conv2d, self).__init__()
        if padding is None:
            padding = ksize // 2

        # The conv bias is redundant: BatchNorm adds its own learnable shift.
        self.conv = nn.Sequential(
            nn.Conv2d(in_chs, out_chs, ksize, stride, padding, bias=False),
            nn.BatchNorm2d(out_chs)
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))
    
class Residual(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 ``Conv2d`` plus identity.

    The output has the same channel count as the input (``out_chs == in_chs``),
    so the input can be added back. Every convolution uses stride 1 with
    size-preserving padding, so the spatial size is unchanged as well.

    Args:
        in_chs: number of input (and output) channels.
    """

    def __init__(self, in_chs):
        super(Residual, self).__init__()

        out_chs = in_chs
        # Bottleneck widths; clamped to 1 so small channel counts never give
        # zero-channel convolutions.
        out1 = max(1, out_chs // 4)
        out2 = max(1, out1 // 4)
        # stride/padding are passed explicitly: 1x1 convs need no padding and
        # padding=1 keeps the 3x3 conv size-preserving.
        self.conv1 = Conv2d(in_chs, out1, 1, 1, 0)
        self.conv2 = Conv2d(out1, out2, 3, 1, 1)
        self.conv3 = Conv2d(out2, out_chs, 1, 1, 0)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        # Out-of-place add: `out` is a ReLU output that autograd keeps for the
        # backward pass, so an in-place `+=` would break backpropagation.
        return out + x

# Original design note: network input 511 x 511, feature output 128 x 128.
# NOTE(review): this module preserves spatial size (output H x W == input
# H x W); a 511 -> 128 reduction would have to happen in a stem placed before it.
class Hourglass(nn.Module):
    """Five-level encoder/decoder ("hourglass") with residual skip branches.

    Encoder: five ``Conv2d + Residual`` stages, each followed by a 2x2 max-pool,
    so the deepest features are at roughly 1/32 of the input resolution.

    Decoder: walks back up level by level. At each level it:

    * upsamples to the size of the matching encoder feature map,
    * applies ``Residual + Conv2d``,
    * adds that level's skip branch (two ``Residual`` blocks applied to the
      encoder feature map).

    The coarse decoder path carries global context and the skip path carries
    local detail, which gives the module both global and local receptive fields.

    Args:
        in_chs: number of channels of the input tensor.
        ksize: kernel size of the plain convolutions (default 7, as in the
            original draft). Padding is ``ksize // 2``, so an odd ``ksize``
            keeps the spatial size.

    Shapes:
        input:  (N, in_chs, H, W), with H and W at least 32 (five 2x pools).
        output: (N, 256, H, W).
    """

    # Channel width of each level, shallowest first.
    LEVEL_CHS = (256, 512, 512, 512, 512)

    def __init__(self, in_chs, ksize=7):
        super(Hourglass, self).__init__()
        pad = ksize // 2
        chs = self.LEVEL_CHS

        # Encoder level i maps the previous level's channels to chs[i].
        enc_in = (in_chs,) + chs[:-1]
        self.enc = nn.ModuleList([
            nn.Sequential(Conv2d(c_in, c_out, ksize, 1, pad), Residual(c_out))
            for c_in, c_out in zip(enc_in, chs)
        ])
        self.pool = nn.MaxPool2d(2, 2)

        # Skip branch per level; keeps that level's channel count.
        self.skips = nn.ModuleList([
            nn.Sequential(Residual(c), Residual(c)) for c in chs
        ])

        # Decoder level i maps the deeper level's channels back to chs[i]; the
        # pooled bottleneck below the deepest level has chs[-1] channels.
        dec_in = chs[1:] + (chs[-1],)
        self.dec = nn.ModuleList([
            nn.Sequential(Residual(c_in), Conv2d(c_in, c_out, ksize, 1, pad))
            for c_in, c_out in zip(dec_in, chs)
        ])

    def forward(self, x):
        """Run the encoder, then the decoder with per-level skip additions."""
        skip_outs = []
        out = x
        for stage, skip in zip(self.enc, self.skips):
            out = stage(out)
            skip_outs.append(skip(out))
            out = self.pool(out)

        # Deepest level first.
        for level in reversed(range(len(self.dec))):
            skip_out = skip_outs[level]
            # Upsample to the exact encoder-map size: pooling floors odd sizes,
            # so a fixed scale factor of 2 would not always line up.
            out = F.interpolate(out, size=skip_out.shape[-2:], mode='nearest')
            out = self.dec[level](out) + skip_out
        return out