<a href="https://colab.research.google.com/github/ell-hol/stonks-wid-codex/blob/main/few_shot_object_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
Defines an end-to-end PyTorch object detection model.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.detection import fasterrcnn_resnet50_fpn

# Number of classes handed to the detector's ``num_classes`` argument.
# NOTE(review): torchvision detection models presumably count the background
# as one of these classes, i.e. 2 == background + one object class - confirm.
NUM_CLASSES = 2

class Model(nn.Module):
    """Thin ``nn.Module`` wrapper around torchvision's Faster R-CNN (ResNet-50 FPN).

    The wrapped detector is built without pretrained weights (neither the
    detection head nor the backbone) and with ``NUM_CLASSES`` output classes.
    """

    def __init__(self):
        super().__init__()

        # NOTE(review): recent torchvision releases replace ``pretrained=`` /
        # ``pretrained_backbone=`` with ``weights=`` / ``weights_backbone=``;
        # the original keywords are kept so older installs keep working.
        self.model = fasterrcnn_resnet50_fpn(
            pretrained=False,
            num_classes=NUM_CLASSES,
            pretrained_backbone=False,
        )

    def forward(self, x, targets=None):
        """Run the wrapped detector.

        Args:
            x: Input forwarded unchanged to the wrapped model.
            targets: Optional ground-truth targets. torchvision detection
                models need them in training mode to compute losses; when
                omitted the model is called with ``x`` only, exactly as before.

        Returns:
            Whatever the wrapped model returns.
        """
        if targets is None:
            return self.model(x)
        return self.model(x, targets)

    def load(self, path, map_location="cpu"):
        """Load a state dict saved for the wrapped detector from ``path``.

        Args:
            path: Checkpoint location understood by ``torch.load``.
            map_location: Device remapping passed to ``torch.load``. Defaults
                to ``"cpu"`` so checkpoints written on a GPU machine also load
                on CPU-only hosts; ``load_state_dict`` then copies the values
                onto whatever device the parameters already live on.
        """
        state_dict = torch.load(path, map_location=map_location)
        self.model.load_state_dict(state_dict)

In [12]:
        
"""
Defines a few-shot object detection Dataset based on COCO.
"""

import json
from pathlib import Path
import pickle

import numpy as np
import PIL
from sklearn.neighbors import NearestNeighbors
import torch
from torch.utils.data import Dataset

from utils import get_image_path, get_image_url

class CocoDataset(Dataset):
  """
  Minimal placeholder dataset that only records its configuration.

  A later class with the same name in this file rebinds ``CocoDataset``,
  so this definition is only reachable before that point.
  """

  def __init__(self, dataset_name, split, num_classes, num_examples):
    """Store the configuration; no images are loaded by this class."""
    super().__init__()
    self.dataset_name = dataset_name
    self.split = split
    self.num_classes = num_classes
    self.num_examples = num_examples
    # __len__ reads this attribute; start empty so len() works instead of
    # raising AttributeError on a freshly constructed instance.
    self.image_filenames = []

  def __len__(self):
    """Return the number of known image filenames."""
    return len(self.image_filenames)

class CocoDataset(Dataset):
  """
  A few-shot object detection dataset based on COCO.

  Each item is ``(image, bboxes, category_id, filename)`` where ``image`` is a
  224x224 RGB ``uint8`` numpy array, ``bboxes`` is a float32 tensor of COCO
  style ``[x, y, w, h]`` boxes normalised by the original image width/height,
  and ``category_id`` is an int32 tensor with one id per box. At most
  ``num_examples`` boxes are returned per image.
  """

  # Directory holding the COCO ``instances_*.json`` files (original default).
  DEFAULT_ANNOTATIONS_DIR = '/mnt/c/Users/James/Downloads/annotations'

  def __init__(self, dataset_name, split, num_classes, num_examples, annotations_dir=None):
    """
    Parse the COCO annotation file and index boxes/labels by image id.

    Args:
      dataset_name: ``'train'`` selects the train2017 annotations, anything
        else the val2017 ones; also forwarded to ``get_image_path``.
      split: Stored on the instance; not otherwise used here.
      num_classes: Maximum number of categories to keep.
      num_examples: Maximum number of boxes returned per image.
      annotations_dir: Optional directory containing the annotation JSON
        files; defaults to ``DEFAULT_ANNOTATIONS_DIR``.
    """
    super().__init__()
    self.dataset_name = dataset_name
    self.split = split
    self.num_classes = num_classes
    self.num_examples = num_examples

    if annotations_dir is None:
      annotations_dir = self.DEFAULT_ANNOTATIONS_DIR

    # NOTE(review): the annotation file is picked from ``dataset_name`` rather
    # than ``split`` - confirm this is intended.
    if self.dataset_name == 'train':
      path = Path(annotations_dir) / 'instances_train2017.json'
    else:
      path = Path(annotations_dir) / 'instances_val2017.json'

    # Close the file deterministically instead of leaking the handle.
    with path.open() as f:
      data = json.load(f)

    # NOTE(review): this keeps only categories literally named '0', '1', ...;
    # standard COCO category names ('person', 'bicycle', ...) never match, in
    # which case every image ends up with zero boxes. Confirm the intended
    # selection rule before relying on this dataset.
    wanted_names = {str(c) for c in range(num_classes)}
    self.categories = [cat for cat in data['categories'] if cat['name'] in wanted_names][:num_classes]
    self.category_ids = [cat['id'] for cat in self.categories]
    self.category_names = [cat['name'] for cat in self.categories]

    images = data['images']
    self.image_ids = [image['id'] for image in images]
    self.image_id_to_filename = {image['id']: image['file_name'] for image in images}
    self.image_id_to_url = {image['id']: get_image_url(image['file_name']) for image in images}
    self.image_filenames = [image['file_name'] for image in images]

    # Group boxes and labels per image; a set makes the category test O(1).
    wanted_ids = set(self.category_ids)
    boxes_by_image = {}
    labels_by_image = {}
    for ann in data['annotations']:
      if ann['category_id'] not in wanted_ids:
        continue
      boxes_by_image.setdefault(ann['image_id'], []).append(ann['bbox'])
      labels_by_image.setdefault(ann['image_id'], []).append(ann['category_id'])

    self.image_id_to_bboxes = {k: np.array(v) for k, v in boxes_by_image.items()}
    self.image_id_to_category_id = {k: np.array(v) for k, v in labels_by_image.items()}

    self.nearest_neighbors = NearestNeighbors(n_neighbors=num_examples)

  def __len__(self):
    """Return the number of images listed in the annotation file."""
    return len(self.image_filenames)

  def __getitem__(self, index):
    """
    Load one image with (a random subset of) its boxes, randomly flipped.

    Returns:
      ``(image, bboxes, category_id, filename)`` as described on the class.
    """
    # ``import PIL`` alone does not load the ``Image`` submodule, so import
    # it explicitly where it is used.
    from PIL import Image

    image_id = self.image_ids[index]
    filename = self.image_id_to_filename[image_id]

    # Force three channels: the flip below indexes a channel axis, which
    # grayscale files would not have.
    with Image.open(get_image_path(self.dataset_name, filename)) as img:
      image = np.array(img.convert('RGB'))

    # Images without annotations in the selected categories are absent from
    # the lookup tables; treat them as having zero boxes.
    bboxes = self.image_id_to_bboxes.get(image_id)
    if bboxes is None or len(bboxes) == 0:
      bboxes = np.zeros((0, 4), dtype=np.float32)
      category_id = np.zeros((0,), dtype=np.int32)
    else:
      category_id = self.image_id_to_category_id[image_id]
      height, width = image.shape[:2]
      # Normalise [x, y, w, h] by the image size; this also yields a fresh
      # array, so the cached boxes are never mutated below.
      bboxes = bboxes / np.array([[width, height, width, height]])

    # Randomly select num_examples bounding boxes
    if len(bboxes) > self.num_examples:
      idxs = np.random.choice(len(bboxes), self.num_examples, replace=False)
      bboxes = bboxes[idxs]
      category_id = category_id[idxs]

    # Randomly flip the image horizontally
    if np.random.rand() < 0.5:
      image = image[:, ::-1, :]
      # Boxes are [x, y, w, h]: mirroring moves the left edge to
      # 1 - (x + w) and leaves the width unchanged.
      bboxes[:, 0] = 1.0 - bboxes[:, 0] - bboxes[:, 2]

    # Resize the image; boxes are normalised so they need no rescaling.
    image = Image.fromarray(image)
    image = image.resize((224, 224), resample=Image.BILINEAR)
    image = np.array(image)

    bboxes = torch.tensor(bboxes, dtype=torch.float32)
    category_id = torch.tensor(category_id, dtype=torch.int32)

    return image, bboxes, category_id, filename

IndentationError: ignored