In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import pathlib
import shutil
import cv2
import os
import json
from PIL import Image
# import utils

import torch
import torchvision
import torchvision.transforms as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [73]:
class SugarcaneWeedsDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transforms=None):
        super().__init__()
        
        self.img_ids = list(sorted(os.listdir(os.path.join(root, "jpg_images"))))
        self.img_dir = pathlib.Path(data_dir) / "jpg_images"
        with open(os.path.join(data_dir, "labels.json")) as f:
            self.annot = json.load(f)
        self.transforms = transforms
        
    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        annotations = self.annot[img_id]
        print(annotations)
        
        img = Image.open(self.img_dir/img_id).convert('RGB')
        img = T.ToTensor()(img)
        
        box_dict = annotations['regions']
        obj_classes = [x['region_attributes']['label'] for x in box_dict]
        num_classes = len(set(obj_classes))
        
        # get bounding box coordinates for each mask
        num_objs = len(obj_classes)
        boxes = []
        for i in range(num_objs):
            bd   = box_dict[i]['shape_attributes']
            xmin = bd['x']
            xmax = bd['x'] + bd['width']
            ymin = bd['y']
            ymax = bd['y'] + bd['height']
            boxes.append([xmin, ymin, xmax, ymax])
        boxes = torch.tensor(boxes, dtype=torch.int64)
        
        # instance labels
        labels = torch.ones((num_objs,), dtype=torch.int64)
        
        # class labels - TODO
        # or just use weed labels
        
        target = {}
        target['boxes'] = boxes
        target['labels'] = labels
        target['img_id'] = torch.tensor([idx])
        
        return img, target, img_id
    
    def __len__(self):
        return len(self.img_ids)
    

In [77]:
# get the pretrained model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [None]:
# 

In [None]:
# Filter out if too many boxes for a single instance
keep = torchvision.ops.nms(boxes, scores, 0.3)

In [74]:
data_dir = '/Users/mschoder/weeding_project/box_labeled_data'
test = SugarcaneWeedsDataset(data_dir)
test

<__main__.SugarcaneWeedsDataset at 0x13cdfba90>

In [75]:
test.__getitem__(302)

{'filename': 'i3092.jpg', 'size': 6182682, 'regions': [{'shape_attributes': {'name': 'rect', 'x': 887, 'y': 3249, 'width': 1577, 'height': 1352}, 'region_attributes': {'label': 'weed'}}, {'shape_attributes': {'name': 'rect', 'x': 7, 'y': 0, 'width': 3058, 'height': 1522}, 'region_attributes': {'label': 'weed'}}, {'shape_attributes': {'name': 'rect', 'x': 608, 'y': 1563, 'width': 2444, 'height': 1673}, 'region_attributes': {'label': 'weed'}}, {'shape_attributes': {'name': 'rect', 'x': 2499, 'y': 3639, 'width': 560, 'height': 963}, 'region_attributes': {'label': 'weed'}}, {'shape_attributes': {'name': 'rect', 'x': 20, 'y': 20, 'width': 3017, 'height': 4560}, 'region_attributes': {'label': 'sugarcane'}}], 'file_attributes': {}}


(tensor([[[0.1843, 0.1882, 0.1961,  ..., 0.3725, 0.3608, 0.3569],
          [0.1843, 0.1922, 0.1961,  ..., 0.3451, 0.3373, 0.3490],
          [0.1843, 0.1961, 0.2000,  ..., 0.3333, 0.3373, 0.3608],
          ...,
          [0.3608, 0.3725, 0.3961,  ..., 0.2902, 0.3294, 0.3569],
          [0.2824, 0.3059, 0.3294,  ..., 0.3569, 0.3725, 0.3569],
          [0.2039, 0.2314, 0.2471,  ..., 0.3882, 0.3686, 0.3529]],
 
         [[0.1725, 0.1765, 0.1843,  ..., 0.5137, 0.4941, 0.4745],
          [0.1725, 0.1804, 0.1843,  ..., 0.4784, 0.4549, 0.4588],
          [0.1725, 0.1843, 0.1882,  ..., 0.4431, 0.4353, 0.4471],
          ...,
          [0.3216, 0.3333, 0.3608,  ..., 0.3608, 0.4039, 0.4314],
          [0.2549, 0.2784, 0.3020,  ..., 0.4314, 0.4510, 0.4353],
          [0.1882, 0.2157, 0.2314,  ..., 0.4706, 0.4588, 0.4431]],
 
         [[0.1451, 0.1490, 0.1569,  ..., 0.1294, 0.1255, 0.1216],
          [0.1451, 0.1529, 0.1569,  ..., 0.1098, 0.1098, 0.1255],
          [0.1451, 0.1569, 0.1608,  ...,