In [2]:
import json
import os
from collections import defaultdict
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from pycocotools import mask
from PIL import Image
from rasterio.plot import show
import matplotlib.pyplot as plt
import base64

First attempt at making this into a pytorch dataset. I'm sure there are off-the-shelf solutions for this, but I'm too lazy to look it up right now.

Biggest issue: annotations are not in a format that PyTorch dataloaders can process. So this only works when you grab one example at a time.

In [4]:
class CocoLoader(Dataset):
    def __init__(self, annotation_file, image_dir, max_masks=64):
        self.annotation_file = annotation_file
        self.image_dir = image_dir
        self.max_masks = max_masks
        with open(self.annotation_file, 'r') as f:
            self.data = json.load(f)
        self.index_annotations()
        self.transform = transforms.ToTensor()

    def index_annotations(self):
        self.images = {}
        self.imgToAnns = defaultdict(list)
        for image in self.data['images']:
            self.images[image['id']] = image
        
        self.annotations = {}
        for ann in self.data['annotations']:
            self.annotations[ann['id']] = ann
            self.imgToAnns[ann['image_id']].append(ann)
            self.imgToAnns[ann['image_id']] = sorted(self.imgToAnns[ann['image_id']], key=lambda x: x['area'], reverse=True)
        
        self.categories = {}
        for cat in self.data['categories']:
            self.categories[cat['id']] = cat

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_record = self.images[idx]
        image = Image.open(os.path.join(self.image_dir, image_record['file_name']))
        annotations = self.imgToAnns[idx]
        return (self.transform(image), annotations)



In [5]:
dl = CocoLoader()

In [12]:
dl.images

{397133: {'license': 4,
  'file_name': '000000397133.jpg',
  'coco_url': 'http://images.cocodataset.org/val2017/000000397133.jpg',
  'height': 427,
  'width': 640,
  'date_captured': '2013-11-14 17:02:52',
  'flickr_url': 'http://farm7.staticflickr.com/6116/6255196340_da26cf2c9e_z.jpg',
  'id': 397133},
 37777: {'license': 1,
  'file_name': '000000037777.jpg',
  'coco_url': 'http://images.cocodataset.org/val2017/000000037777.jpg',
  'height': 230,
  'width': 352,
  'date_captured': '2013-11-14 20:55:31',
  'flickr_url': 'http://farm9.staticflickr.com/8429/7839199426_f6d48aa585_z.jpg',
  'id': 37777},
 252219: {'license': 4,
  'file_name': '000000252219.jpg',
  'coco_url': 'http://images.cocodataset.org/val2017/000000252219.jpg',
  'height': 428,
  'width': 640,
  'date_captured': '2013-11-14 22:32:02',
  'flickr_url': 'http://farm4.staticflickr.com/3446/3232237447_13d84bd0a1_z.jpg',
  'id': 252219},
 87038: {'license': 1,
  'file_name': '000000087038.jpg',
  'coco_url': 'http://images.

In [14]:

img, anns = dl.__getitem__(252219)

img = img.numpy().transpose((1,2,0))
img = (img * 255).astype(np.uint8)
masks = []
for an in anns:
    # decode from JSON string to bytes
    an['segmentation']['counts'] = base64.b64decode(an['segmentation']['counts'].encode('utf-8'))
    print(an['segmentation'])
    masks.append(mask.decode(an['segmentation']))

plt.figure(figsize=(12,9))
plt.imshow(img)

reds = [i for i in range(50, 220, (220-50)//len(masks))]
blues = reds.copy()
blues.reverse()

for i, m in enumerate(masks):
    mplot = np.zeros(img.shape, dtype=np.uint8)
    mplot[:,:,3] = np.where(m == 1, 255,0)
    mplot[:,:,0] = reds[i]
    mplot[:,:,1] = 30
    mplot[:,:,2] = blues[i]
    plt.imshow(mplot, alpha=0.5)  # Adjust alpha for transparency

# Display the result
plt.axis('off')
plt.show()

AttributeError: 'bytes' object has no attribute 'encode'