In [None]:
%matplotlib inline
import os
import numpy as np 
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.utils.data as td 
import torchvision as tv
from PIL import Image
import matplotlib.pyplot as plt
import nntools as nt
import math
import xml.etree.ElementTree as ET
import cv2

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Root of the PASCAL VOC 2012 copy (must contain ImageSets/, Annotations/, JPEGImages/).
# Set the VOC2012_ROOT environment variable to use a different location.
dataset_root_dir = os.environ.get('VOC2012_ROOT', '/datasets/ee285f-public/PascalVOC2012')

In [None]:
class VOCDataset(td.Dataset):
    """PASCAL VOC detection dataset yielding YOLO-style (S, S, 30) targets.

    Each item is ``(img, target)``:
      * ``img``: 3 x H x W float tensor, resized to ``image_size`` and
        normalized with mean 0.5 / std 0.5 per channel.
      * ``target``: (S, S, 30) tensor produced by :meth:`encoder`.

    Args:
        root_dir: VOC root containing ``ImageSets/Main``, ``Annotations`` and
            ``JPEGImages``.
        mode: ``'train'`` reads ``trainval.txt`` and enables random colour
            augmentation; any other value reads ``val.txt``.
        image_size: size passed to ``cv2.resize`` (OpenCV takes (width, height)).
        S: number of grid cells along each image axis.
    """

    # VOC class names; position i is the one-hot slot used in the target.
    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')
    class_dict = {name: idx for idx, name in enumerate(CLASSES)}

    def __init__(self, root_dir, mode='train', image_size=(448, 448), S=7):
        super().__init__()
        self.mode = mode
        self.image_size = image_size
        self.S = S
        split_file = 'trainval.txt' if mode == 'train' else 'val.txt'
        self.list_file = os.path.join(root_dir, 'ImageSets/Main', split_file)
        self.annot_dir = os.path.join(root_dir, 'Annotations')
        self.images_dir = os.path.join(root_dir, 'JPEGImages')

        # Take the first whitespace-separated token of each non-blank line
        # instead of a fixed-width slice: robust to trailing blank lines and
        # to list files that carry extra columns after the image id.
        with open(self.list_file) as f:
            self.image_names = [line.split()[0] for line in f if line.strip()]

    def __len__(self):
        return len(self.image_names)

    def __repr__(self):
        return "VOCDataset(mode={}, image_size={}, S={})".format(
            self.mode, self.image_size, self.S)

    def __getitem__(self, idx):
        """Return ``(img, target)`` for the idx-th image id in the split list."""
        iname = self.image_names[idx]
        img_path = os.path.join(self.images_dir, iname + '.jpg')

        boxes, labels, isize = self._parse_annotation(iname)

        # Read image, optionally augment (train only), resize and normalize.
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('Could not read image: {}'.format(img_path))
        if self.mode == 'train':
            img = self.RandomBrightness(img)
            img = self.RandomSaturation(img)
        img = cv2.resize(img, self.image_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        transform = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
        ])
        img = transform(img)

        target = self.encoder(boxes, labels, isize)
        return img, target

    def _parse_annotation(self, iname):
        """Read boxes, class names and image (width, height) from ``<iname>.xml``.

        Boxes come back as an (N, 4) float tensor in [xmax, ymax, xmin, ymin]
        order, the layout ``encoder`` expects; N may be 0.
        """
        tree = ET.parse(os.path.join(self.annot_dir, iname + '.xml'))
        boxes, labels = [], []
        for obj in tree.iter(tag='object'):
            bndbox = obj.find('bndbox')
            # int(float(...)) also tolerates coordinates written with a fractional part.
            boxes.append([int(float(bndbox.find(key).text))
                          for key in ('xmax', 'ymax', 'xmin', 'ymin')])
            labels.append(obj.find('name').text.strip())
        size = tree.find('size')
        isize = (int(size.find('width').text), int(size.find('height').text))
        # reshape keeps a well-formed (0, 4) tensor when there are no objects.
        return torch.FloatTensor(boxes).reshape(-1, 4), labels, isize

    def encoder(self, boxes, labels, isize):
        '''Encode boxes and labels into an (S, S, 30) tensor (7x7x30 by default).

        Args:
            boxes: (N, 4) float tensor of [xmax, ymax, xmin, ymin] pixel
                coordinates; an empty tensor yields an all-zero target.
            labels: N VOC class names, parallel to ``boxes``.
            isize: original image (width, height) in pixels.

        Returns:
            Tensor of shape (S, S, 30). The first index is the grid cell along
            x (width), the second along y (height). For a cell holding an
            object the 30 values are
            [ 20 class one-hot | 1 (C) | 1 (C) | 4 box | 4 box (same) ],
            where each box is (w, h, center_x, center_y), all as fractions of
            the image width/height.

        Raises:
            KeyError: if a label is not one of ``CLASSES``.

        NOTE(review): objects whose centres share a cell overwrite each other's
        box, and every such object's class slot is set.
        '''
        S = self.S
        target = torch.zeros((S, S, 30))
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        wh = boxes[:, :2] - boxes[:, 2:]
        cxcy = (boxes[:, :2] + boxes[:, 2:]) / 2
        scale = torch.tensor([float(isize[0]), float(isize[1])])
        cell_w = isize[0] / S
        cell_h = isize[1] / S
        for i in range(cxcy.size(0)):
            center = cxcy[i]
            # Clamp so a centre lying exactly on the right/bottom border lands
            # in the last cell instead of indexing out of range.
            xi = min(max(int(center[0] / cell_w), 0), S - 1)
            yi = min(max(int(center[1] / cell_h), 0), S - 1)
            target[xi, yi, 20] = 1
            target[xi, yi, 21] = 1
            target[xi, yi, self.class_dict[labels[i]]] = 1
            normalized_wh = wh[i] / scale
            normalized_center = center / scale
            target[xi, yi, 22:24] = normalized_wh
            target[xi, yi, 24:26] = normalized_center
            target[xi, yi, 26:28] = normalized_wh
            target[xi, yi, 28:30] = normalized_center
        return target

    @staticmethod
    def _scale_hsv_channel(bgr, channel):
        """With probability 0.5, scale one HSV channel of a BGR uint8 image by 0.5 or 1.5.

        The scaled channel is clipped to [0, 255]; the image is returned in BGR.
        """
        if np.random.random() < 0.5:
            hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
            adjust = np.random.choice([0.5, 1.5])
            hsv[:, :, channel] = np.clip(hsv[:, :, channel] * adjust, 0, 255).astype(hsv.dtype)
            bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
        return bgr

    def RandomBrightness(self, bgr):
        """Randomly darken/brighten a BGR image by scaling its V (value) channel."""
        return self._scale_hsv_channel(bgr, 2)

    def RandomSaturation(self, bgr):
        """Randomly (de)saturate a BGR image by scaling its S (saturation) channel."""
        return self._scale_hsv_channel(bgr, 1)

    

In [None]:
def myimshow(image, ax=None):
    """Display a normalized CHW image tensor.

    Args:
        image: 3-D tensor laid out as (channels, height, width). Values are
            mapped with (x + 1) / 2, i.e. the inverse of
            Normalize((.5, .5, .5), (.5, .5, .5)), then clipped to [0, 1].
        ax: object exposing ``imshow`` and ``axis`` (a matplotlib Axes or the
            pyplot module). Defaults to ``plt``.

    Returns:
        Whatever ``ax.imshow`` returns (the image handle).
    """
    if ax is None:
        ax = plt
    # detach() so tensors that require grad (e.g. model outputs) can become numpy arrays.
    array = image.detach().to('cpu').numpy()
    # (C, H, W) -> (H, W, C), the layout imshow expects for colour images.
    array = np.moveaxis(array, 0, -1)
    array = np.clip((array + 1) / 2, 0, 1)
    handle = ax.imshow(array)
    ax.axis('off')
    return handle

In [None]:
# Build the training split and preview one (possibly colour-augmented) sample.
train_data = VOCDataset(root_dir=dataset_root_dir, mode='train')
img, target = train_data[5]
myimshow(img, ax=plt);