In [11]:
import torch
import numpy as np
import json
import config as C
import os
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from utils import parse_camera_intrinsic as parse_camera_intrinsic
from scipy.spatial.transform import Rotation as R
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import transforms as T
from torch import nn
from tqdm import tqdm
import seaborn

In [12]:
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [13]:
with open(C.TRAIN_CROPS_JSON, "r") as f:
    train_gt = json.load(f)

In [16]:
import io
class PKURegressionDataset(Dataset):
    
    def __init__(self, orientation_field, images_dir, whole_images_dir, max_size, max_whole_size, transforms=None):
        super().__init__()
        self.orientation_field = orientation_field
        self.images_dir = images_dir
        self.whole_images_dir = whole_images_dir
        self.max_size = max_size
        self.max_whole_size = max_whole_size
        self.transforms = transforms
        
        with open(C.TRAIN_CROPS_JSON, "r") as f:
            self.gt = json.load(f)
        
        annotations =[]
        for i in range(len(self.gt['annotations'])):
            wx, wy, wz = self.gt['annotations'][i]['position']
            if ((-50<wx<50) and (0<wy<50) and (0<wz<200) and (np.sqrt(wx**2 + wy**2 +wz**2) < 100)):
                annotations.append(self.gt['annotations'][i])
        self.gt['annotations'] = annotations
    
        cat_ids = set(ann['category_id'] for ann in self.gt['annotations'])
        categories = [cat for cat in self.gt['categories'] if cat['id'] in cat_ids]
        self.category_id_to_label = {
            cat["id"]: label
            for label, cat in enumerate(sorted(categories, key=lambda x: x["id"]))
        }
        self.images_jpeg = self.load_images()
        self.whole_images_jpeg, self.ann_id_to_whole_image_filename = self.load_whole_images()
        
        self.p = parse_camera_intrinsic()
        for k in self.p:
            self.p[k] = float(self.p[k])
    
    def load_images(self):
        images = {}
        for image in tqdm(self.gt['images']):
            path = os.path.join(self.images_dir, image['file_name'])
            data = open(path, 'rb').read()
            images[image['id']] = io.BytesIO(data)
        return images
    
    def load_whole_images(self):
        ann_id_to_whole_image_filename = {}
        filenames = []
        for image in self.gt['images']:
            name, ext = image['file_name'].split('.')
            ID, whole_image_name, ann_id = name.split("_")
            whole_image_name = ID + "_" + whole_image_name
            filename = whole_image_name + '.' + ext
            ann_id_to_whole_image_filename[int(ann_id)] = filename
            filenames.append(filename)
        filenames = list(set(filenames))
        whole_images_jpeg = {}
        for filename in tqdm(filenames):
            path = os.path.join(self.whole_images_dir, filename)
            data = open(path, 'rb').read()
            whole_images_jpeg[filename] = io.BytesIO(data)
        return whole_images_jpeg, ann_id_to_whole_image_filename 
    
    def __len__(self):
        return len(self.gt["annotations"])
    
    def __getitem__(self, idx):
        image = self.load_image(idx)
        whole_image,ann_id_to_whole_image_filename = self.load_whole_image(idx)
        label = self.get_label(idx)
        bbox_x, bbox_y, bbox_w, bbox_h = self.get_bbox(idx)
        bbox_center_x, bbox_center_y = bbox_x + bbox_w / 2, bbox_y + bbox_h / 2
        wx, wy, wz = self.get_position(idx)
        orientation = self.get_orientation(idx)
        filename = ann_id_to_whole_image_filename[idx]
        
        result =  dict(
            image=image,
            whole_image=whole_image,
            label=label,
            bbox=np.array([(bbox_center_x-self.p['cx'])/self.p['fx'], (bbox_center_y-self.p['cy'])/self.p['fy'], bbox_w/self.p['fx'], bbox_h/self.p['fy']]),
            position=np.array([wx, wy, wz]),
            distance=np.sqrt(wx**2 + wy**2 + wz**2),
            orientation=np.array(orientation),
            filename=filename
            )
        if self.transforms is not None:
            result['image'] = self.transforms(result['image'])
            result['whole_image'] = self.transforms(result['whole_image'])
        return result
    
    def load_image(self, idx):
        image_id = self.gt["annotations"][idx]["image_id"]
        image = self.decode_image(self.images_jpeg[image_id])
        w, h = image.size
        scale = self.max_size / max(w, h)
        w_new, h_new = int(w * scale), int(h * scale)
        image = image.resize((w_new, h_new), Image.LANCZOS)
        new_image = Image.new("RGB", (self.max_size, self.max_size))
        new_image.paste(image, ((self.max_size - w_new) // 2, (self.max_size - h_new) // 2))
        return new_image

    def load_whole_image(self, idx):
        ann_id = self.gt["annotations"][idx]["id"]
        image = self.decode_image(self.whole_images_jpeg[self.ann_id_to_whole_image_filename[ann_id]])
        w, h = image.size
        scale = self.max_whole_size / max(w, h)
        w_new, h_new = int(w * scale), int(h * scale)
        image = image.resize((w_new, h_new), Image.LANCZOS)
        new_image = Image.new("RGB", (self.max_whole_size, self.max_whole_size))
        new_image.paste(image, ((self.max_whole_size - w_new) // 2, (self.max_whole_size - h_new) // 2))
        return new_image

    @staticmethod
    def decode_image(bytes_io):
        image = Image.open(bytes_io)
        image.load()
        return image
        
    def get_label(self, idx):
        return self.category_id_to_label[self.gt["annotations"][idx]["category_id"]]
    
    def get_bbox(self, idx):
        return self.gt["annotations"][idx]["bbox"]
    
    def get_position(self, idx):
        return self.gt["annotations"][idx]["position"]
    
    def get_orientation(self, idx):
        euler_angles = -1*np.array(self.gt["annotations"][idx][self.orientation_field])
        prom = euler_angles[0]*1
        euler_angles[0] = euler_angles[1]*1
        euler_angles[1] = prom*1
        rotation = R.from_euler("YXZ", euler_angles)
        q = rotation.as_quat()
        e = np.array([0,0,1])
        q[:3] = q[:3] * np.dot(e, q[:3]) / (np.abs(np.dot(q[:3], e)))
        return q

In [17]:
transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [18]:
ds = PKURegressionDataset("orientation_relative", C.TRAIN_CROPS_CALIBRATED, C.TRAIN_IMAGES, 256, 512, transforms)

100%|██████████| 49115/49115 [00:01<00:00, 48590.54it/s]
100%|██████████| 4219/4219 [00:02<00:00, 1770.69it/s]


In [19]:
batch_size = 32

In [20]:
dl = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=2)

In [21]:
class Model(torch.nn.Module):
    
    def __init__(self):
        super().__init__()
        self.big_backbone = models.resnet101(pretrained=True)
    
    def extract_nl_features(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        
        return x
    
    def forward(self,whole_image):
        
        nl_features = self.extract_nl_features(whole_image)
        
        return dict(nl)

In [22]:
model = Model()

In [None]:
model.cuda()

In [23]:
pbar = tqdm(total=len(dl))
for batch in dl:
    for k in batch:
        if k == 'label' or k == 'filename':
            batch[k] = batch[k].long()
        else:
            batch[k] = batch[k].float()
        batch[k] = batch[k].cuda()
    with torch.no_grad():
        outputs = model(batch['whole_image'])
    pbar.update()

  0%|          | 0/1383 [00:00<?, ?it/s]

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ivb/.conda/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ivb/.conda/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ivb/.conda/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-16-3fa9103233aa>", line 67, in __getitem__
    whole_image,ann_id_to_whole_image_filename = self.load_whole_image(idx)
TypeError: cannot unpack non-iterable Image object
