<a href="https://colab.research.google.com/github/njmarko/gat-or/blob/master/gat_or.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## GATor

GATor is a Graph Attention Network for object detection with relational reasoning.

### Installing dependencies and downloading dataset

In [None]:
# Print the installed PyTorch version; the PyG wheel index used below is the one for torch 1.10.0 + CUDA 11.3.
!python -c "import torch; print(torch.__version__)"
# Swap whatever headless OpenCV build is present for a pinned release.
!pip uninstall opencv-python-headless -y
!pip install opencv-python-headless==4.1.2.30
# NOTE(review): fiftyone and torch-geometric are installed unpinned — consider pinning versions for reproducibility.
!pip install fiftyone
# torch-scatter / torch-sparse are fetched from the PyG wheel index given by -f.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install torch-geometric

In [None]:

%%shell

# Download TorchVision repo to use some files from
# references/detection
git clone https://github.com/pytorch/vision.git
cd vision
# Use the reference scripts as they were at the v0.8.2 tag.
git checkout v0.8.2

# Copy the detection helper modules into the working directory so the notebook
# can import them directly (it imports `utils`, `transforms` and `engine`).
cp references/detection/utils.py ../
cp references/detection/transforms.py ../
cp references/detection/coco_eval.py ../
cp references/detection/engine.py ../
cp references/detection/coco_utils.py ../

# Only the copied files are needed; remove the clone.
cd ..
rm -r vision

In [None]:
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.utils.coco as fouc
from fiftyone import ViewField as VF
import numpy as np
from itertools import product
from tqdm import tqdm
import seaborn as sns
import networkx as nx
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.transforms import functional as TVF
from torchvision.ops import box_iou
from transforms import Compose, ToTensor
from PIL import Image
import matplotlib.pylab as plt
import matplotlib.patches as patches
from matplotlib.pyplot import figure
import torch_geometric as pyg
import torch_geometric.data as pyg_data
from torch_geometric.nn import GATConv


from engine import train_one_epoch, evaluate
import utils

In [None]:
# Load at most 200 samples of the COCO-2017 "train" split from the FiftyOne dataset zoo.
dataset = foz.load_zoo_dataset(
    "coco-2017",
    split="train",
    max_samples=200
)
# Mark the dataset as persistent (presumably so it is kept in FiftyOne's database
# after this session ends — confirm against the FiftyOne docs).
dataset.persistent = True

# Open the FiftyOne app on the dataset; `session` is reused at the end of the
# notebook to switch the displayed view.
session = fo.launch_app(dataset)

### Visualizing co-occurrence matrix

In [None]:
# Label vocabulary. The position of a name in this list is the integer class index
# used everywhere below (index 0 is 'background').
# NOTE(review): the ordering looks like the COCO category-id ordering, presumably so
# indices line up with the pretrained detector's label ids — confirm.
class_names = ['background','person','bicycle','car','motorcycle','airplane','bus','train','truck','boat','traffic light','fire hydrant','street sign','stop sign','parking meter','bench','bird','cat','dog','horse',
'sheep','cow','elephant','bear','zebra','giraffe','hat','backpack','umbrella','shoe','eye glasses','handbag','tie','suitcase','frisbee','skis','snowboard','sports ball','kite','baseball bat','baseball glove','skateboard',
'surfboard','tennis racket','bottle','plate','wine glass','cup','fork','knife','spoon','bowl','banana','apple','sandwich','orange','broccoli','carrot','hot dog','pizza','donut','cake','chair','couch','potted plant','bed',
'mirror','dining table','window','desk','toilet','door','tv','laptop','mouse','remote','keyboard','cell phone','microwave','oven','toaster','sink','refrigerator','blender','book','clock','vase','scissors','teddy bear',
'hair drier','toothbrush','hair brush'
]

# Reverse lookup: class name -> integer index into `class_names`.
class_name_to_idx = {elem:idx for idx, elem in enumerate(class_names)}
n_classes = len(class_names)
# co_occurrence_matrix[i][j] accumulates how often classes i and j appear in the same image.
co_occurrence_matrix = np.zeros((n_classes, n_classes))

with fo.ProgressBar() as pb:
  for sample in pb(dataset):
    ground_truth = sample.ground_truth
    # A sample may carry no label object at all; skip it instead of failing on `.detections`.
    if ground_truth is None:
      continue
    # Resolve each detection's class index once, rather than once per pair.
    label_indices = [class_name_to_idx[det.label] for det in ground_truth.detections]
    # Every ordered pair of detections (including a detection with itself) is visited and
    # both symmetric cells are incremented, exactly as before, so the counts are unchanged.
    for i, j in product(label_indices, repeat=2):
      co_occurrence_matrix[i][j] += 1
      co_occurrence_matrix[j][i] += 1

In [None]:
# Scale the counts for display. The +1.0 keeps the divisor non-zero for classes that never occur.
# NOTE(review): np.amax(..., axis=1) returns a 1-D vector of per-row maxima; dividing the 2-D
# matrix by it broadcasts along the last axis, so column j is divided by the maximum of row j.
# The matrix is symmetric by construction, so this is a per-column normalisation. If per-row
# normalisation was intended, `keepdims=True` would be needed — confirm the intent.
normalized_matrix = co_occurrence_matrix / np.amax(co_occurrence_matrix + 1.0, axis=1)

# Draw the class-by-class heatmap with class names on both axes.
figure(figsize=(20, 20), dpi=80)
ax = sns.heatmap(normalized_matrix, linewidth=0.5, xticklabels=class_names, yticklabels=class_names, cmap="Blues")
plt.show()

### Defining PyTorch dataset

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class Resize(object):
  """Detection transform: resize an image tensor and rescale its boxes to match.

  The constructor arguments are forwarded unchanged to
  ``torchvision.transforms.functional.resize``.

  Calling the transform with ``(image, target)`` returns ``(resized_image, target)``.
  ``target`` is the same dict object that was passed in; its ``'boxes'`` entry
  (``[x_min, y_min, x_max, y_max]`` rows) is replaced by the rescaled boxes.
  No other entry of ``target`` is modified.
  """

  def __init__(self, size, interpolation=TVF.InterpolationMode.BILINEAR, max_size=None, antialias=None):
    self.size = size
    self.max_size = max_size
    self.interpolation = interpolation
    self.antialias = antialias

  def __call__(self, image, target):
    assert torch.is_tensor(image), "Image is expected to be of type torch.tensor"
    # Height and width are the last two dimensions; negative indexing keeps this
    # correct for (C, H, W) images and for tensors with extra leading dimensions.
    original_height, original_width = image.shape[-2], image.shape[-1]
    image = TVF.resize(image, self.size, self.interpolation, self.max_size, self.antialias)
    x_scale = image.shape[-1] / original_width
    y_scale = image.shape[-2] / original_height

    boxes = target['boxes']
    # An image without annotations can carry an empty boxes tensor (possibly 1-D);
    # there is nothing to rescale and broadcasting against a 4-wide scale would fail.
    if boxes.numel() > 0:
      # x coordinates scale with the width ratio, y coordinates with the height ratio.
      # The scale is created on the boxes' device so the product works off-CPU as well.
      scale = torch.tensor([x_scale, y_scale, x_scale, y_scale], device=boxes.device)
      target['boxes'] = torch.mul(boxes, scale)
    return image, target

In [None]:
class TorchCocoDataset(torch.utils.data.Dataset):
  """PyTorch dataset that serves images and detection targets from a FiftyOne collection.

  Args:
    fiftyone_dataset: FiftyOne dataset or view; samples are looked up by file path.
    transforms: optional callable ``(image, target) -> (image, target)``.
    gt_field: name of the sample field holding the ground-truth detections.

  Each item is ``(image, target)`` where ``target`` is a dict with the keys
  ``boxes`` (float32, shape ``(N, 4)``, ``[x_min, y_min, x_max, y_max]``),
  ``labels`` (int64), ``image_id`` (the dataset index), ``area`` (float32) and
  ``iscrowd`` (int64).
  """

  def __init__(self, fiftyone_dataset, transforms=None, gt_field="ground_truth"):
    self.samples = fiftyone_dataset
    self.transforms = transforms
    self.gt_field = gt_field
    self.img_paths = self.samples.values("filepath")
    # Class vocabulary and name->index map come from the notebook-level definitions.
    self.classes = class_names
    self.labels_map_rev = class_name_to_idx

  def __getitem__(self, idx):
    img_path = self.img_paths[idx]
    sample = self.samples[img_path]
    metadata = sample.metadata
    img = Image.open(img_path).convert("RGB")

    boxes = []
    labels = []
    area = []
    iscrowd = []
    # A sample may have no label object in the ground-truth field; treat that
    # the same as an empty list of detections.
    ground_truth = sample[self.gt_field]
    detections = ground_truth.detections if ground_truth is not None else []

    for det in detections:
      category_id = self.labels_map_rev[det.label]
      coco_obj = fouc.COCOObject.from_label(
          det, metadata, category_id=category_id,
      )
      # COCO boxes are [x, y, width, height]; convert to corner format.
      x, y, w, h = coco_obj.bbox
      boxes.append([x, y, x + w, y + h])
      labels.append(coco_obj.category_id)
      area.append(coco_obj.area)
      iscrowd.append(coco_obj.iscrowd)

    target = {}
    # reshape(-1, 4) keeps the (N, 4) layout even when N == 0; without it an
    # empty list would become a 1-D tensor of shape (0,).
    target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
    target["image_id"] = torch.as_tensor([idx])
    target["area"] = torch.as_tensor(area, dtype=torch.float32)
    target["iscrowd"] = torch.as_tensor(iscrowd, dtype=torch.int64)

    if self.transforms:
      img, target = self.transforms(img, target)

    return img, target

  def __len__(self):
    return len(self.img_paths)

  def get_classes(self):
    """Return the list of class names used by this dataset."""
    return self.classes

In [None]:
# Keep only ground-truth detections whose label is one of `class_names`.
# NOTE(review): presumably this also drops samples left with no detections — confirm
# against FiftyOne's `filter_labels` defaults.
has_bounding_box_view = dataset.filter_labels(
        "ground_truth",
        VF("label").is_in(class_names)
)

# Training images are converted to tensors and resized to 512x512; test images are only
# converted to tensors, so they keep their original resolution.
train_transforms = Compose([ToTensor(), Resize(size=(512, 512))])
test_transforms = Compose([ToTensor()])

# Take 150 samples (fixed seed) for training; every remaining sample goes to the test view.
train_view = has_bounding_box_view.take(150, seed=51)
test_view = has_bounding_box_view.exclude([s.id for s in train_view])

print(f"Number of training samples: {len(train_view)}")
print(f"Number of test samples: {len(test_view)}")

torch_dataset_train = TorchCocoDataset(train_view, train_transforms)
torch_dataset_test = TorchCocoDataset(test_view, test_transforms)

In [None]:
def draw_image_with_bounding_boxes(image, target):
  """Show an image tensor with its target boxes outlined in red.

  Args:
    image: tensor that is permuted with ``(1, 2, 0)`` before being passed to ``imshow``.
    target: dict whose ``'boxes'`` entry holds ``[x_min, y_min, x_max, y_max]`` rows.
  """
  _, axes = plt.subplots()
  axes.imshow(image.permute(1, 2, 0))
  for box in target['boxes']:
    left, top = box[0], box[1]
    # Rectangle takes the anchor corner plus a width and a height.
    box_width = box[2] - left
    box_height = box[3] - top
    outline = patches.Rectangle((left, top), box_width, box_height, linewidth=1, edgecolor='r', facecolor='none')
    axes.add_patch(outline)
  plt.show()

In [None]:
# Sanity check: display the first training sample (image and target after the training transforms).
draw_image_with_bounding_boxes(*torch_dataset_train[0])

### Defining the model

In [None]:
class GATorGraph(object):
  """Turn per-image box features into one batched PyG graph.

  For every image, each box is a node whose features are ``x[index]``, and an
  edge ``(i, j)`` with ``i < j`` is added when the IoU of boxes ``i`` and ``j``
  is greater than 0.5.
  """

  def __init__(self):
    pass

  def __call__(self, x, batched_boxes):
    """Build the batch.

    Args:
      x: indexable by image position; ``x[k]`` holds the node features of image ``k``.
      batched_boxes: iterable of per-image box tensors accepted by ``box_iou``.

    Returns:
      The ``Batch`` built from one ``Data`` object per image.
    """
    graphs = []
    for image_idx, image_boxes in enumerate(batched_boxes):
      # Keep the strict upper triangle only: no self pairs, each pair listed once.
      # NOTE(review): edges therefore run in one direction (lower index -> higher index);
      # confirm that this is intended rather than an undirected graph.
      pairwise_iou = torch.triu(box_iou(image_boxes, image_boxes), diagonal=1)
      overlapping_pairs = (pairwise_iou > 0.5).nonzero(as_tuple=False)
      # nonzero yields one (i, j) row per pair; transpose to the 2 x E edge_index layout.
      graphs.append(pyg_data.Data(x=x[image_idx], edge_index=overlapping_pairs.T))
    return pyg_data.Batch.from_data_list(graphs)

In [None]:
class GATor(nn.Module):
  """Stack of GATConv layers followed by a small MLP applied to every node.

  Args:
    input_dim: feature size of the incoming nodes.
    hidden_dim: per-head output size of every GATConv layer.
    output_dim: feature size produced for every node.
    layers: number of GATConv layers applied in ``forward``.
    heads: number of attention heads per layer.
    dropout: dropout probability used after each layer and inside the MLP.
    edge_dim: forwarded to every ``GATConv``.
  """

  def __init__(self, input_dim, hidden_dim, output_dim, layers, heads, dropout, edge_dim=None):
    super().__init__()
    # The first layer consumes the raw features; every later layer consumes the
    # heads * hidden_dim wide output of the layer before it.
    layer_input_dims = [input_dim] + [heads * hidden_dim] * (layers - 1)
    self.convs = nn.ModuleList([
      GATConv(in_channels=in_dim, out_channels=hidden_dim, heads=heads, edge_dim=edge_dim)
      for in_dim in layer_input_dims
    ])
    # post-message-passing
    self.post_mp = nn.Sequential(
      nn.Linear(heads * hidden_dim, hidden_dim),
      nn.Dropout(dropout),
      nn.Linear(hidden_dim, output_dim),
    )
    self.dropout = dropout
    self.num_layers = layers

  def forward(self, data):
    """Run message passing over ``data`` and return the per-node outputs."""
    node_features = data.x
    edge_index = data.edge_index
    _batch = data.batch  # read from the input but not used by the layers below
    for layer_idx in range(self.num_layers):
      conv = self.convs[layer_idx]
      node_features = F.relu(conv(node_features, edge_index))
      node_features = F.dropout(node_features, p=self.dropout, training=self.training)
    return self.post_mp(node_features)

In [None]:
class MultiScaleRoIAlignWrapper(nn.Module):
  """Pass-through around an RoI pooling module that also records the boxes it was given.

  After each successful forward pass the ``boxes`` argument is stored under
  ``cached['boxes']`` in the notebook-level ``cached`` dict.
  """

  def __init__(self, roi_align):
    super().__init__()
    self.roi_align = roi_align

  def forward(self, x, boxes, image_shapes):
    """Delegate to the wrapped module and return its result unchanged."""
    pooled = self.roi_align(x, boxes, image_shapes)
    # Remember the boxes used for this pass; the cache is only written once pooling succeeded.
    cached.update(boxes=boxes)
    return pooled

In [None]:
class TwoMLPHeadWrapper(nn.Module):
  """Box head that refines per-box features with a graph network.

  The wrapped ``box_head`` turns the pooled features into one feature vector per
  box. These vectors are grouped per image, ``gator_graph`` builds a graph from
  them and the boxes stored in the notebook-level ``cached['boxes']``, and
  ``gator`` produces the final per-box features.

  Args:
    box_head: module mapping pooled features to per-box feature vectors.
    gator_graph: callable ``(per_image_features, boxes) -> graph``.
    gator: callable ``graph -> per-box features``.
  """

  def __init__(self, box_head, gator_graph, gator):
    super(TwoMLPHeadWrapper, self).__init__()
    self.box_head = box_head
    self.gator_graph = gator_graph
    self.gator = gator

  def forward(self, x):
    node_features = self.box_head(x)
    boxes_per_image = cached['boxes']
    # Split by the real number of boxes of each image. A plain
    # view(B, -1, F) is only correct when every image has the same number of
    # boxes; otherwise it fails or assigns features to the wrong image.
    # torch.split also raises if the cached boxes do not add up to the number of rows.
    box_counts = [image_boxes.shape[0] for image_boxes in boxes_per_image]
    per_image_features = torch.split(node_features, box_counts, dim=0)
    graph = self.gator_graph(per_image_features, boxes_per_image)
    return self.gator(graph)

In [None]:
# Notebook-level scratch dict shared by the two wrappers: MultiScaleRoIAlignWrapper stores
# the current boxes under 'boxes' and TwoMLPHeadWrapper reads them back.
cached = {}
# Start from torchvision's Faster R-CNN (ResNet-50 FPN) with pretrained weights.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
gator_graph = GATorGraph()
# NOTE(review): input_dim/output_dim of 1024 presumably match the feature width of the stock
# box head and box predictor — confirm against the torchvision model definition.
gator = GATor(input_dim=1024, hidden_dim=1024, output_dim=1024, layers=2, heads=8, dropout=0.5, edge_dim=None)
# gat = GATModule(model.roi_heads.box_head, gator_graph)
# model.rpn = RPNWrapper(model.rpn)
# Wrap the RoI pooling and box head in place so the graph network runs between the
# original box head and the box predictor.
model.roi_heads.box_roi_pool = MultiScaleRoIAlignWrapper(model.roi_heads.box_roi_pool)
model.roi_heads.box_head = TwoMLPHeadWrapper(model.roi_heads.box_head, gator_graph, gator)
# model.box_predictor.reset_parameters()
# model.roi_heads.box_head = gat
model.to(device)

### Defining PyTorch dataloaders

In [None]:
def collate_fn(batch):
  """Transpose a list of samples into one tuple per sample field.

  A batch of ``(image, target)`` pairs becomes ``(images, targets)``, each a tuple;
  nothing is stacked.
  """
  per_field = zip(*batch)
  return tuple(per_field)

In [None]:
# Training loader: shuffled batches of 4; collate_fn keeps images and targets as tuples
# instead of stacking them.
data_loader_train = DataLoader(torch_dataset_train, batch_size=4, shuffle=True, collate_fn=collate_fn, pin_memory=True)
# Test loader: one image per batch, in dataset order.
data_loader_test = DataLoader(torch_dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)

### Training

In [None]:
def train(model, train_loader, test_loader, n_epochs):
  """Fine-tune only the box head and box predictor of ``model``.

  All parameters are frozen first, then those of ``model.roi_heads.box_head`` and
  ``model.roi_heads.box_predictor`` are made trainable again. The model is trained
  with SGD and a step learning-rate schedule, and evaluated on ``test_loader``
  after every epoch.

  Args:
    model: detection model exposing ``roi_heads.box_head`` and ``roi_heads.box_predictor``.
    train_loader: loader passed to ``train_one_epoch``.
    test_loader: loader passed to ``evaluate``.
    n_epochs: number of train/evaluate rounds.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model.train()

  # Freeze everything, then re-enable gradients for the two trainable sub-modules.
  model.requires_grad_(False)
  for trainable_module in (model.roi_heads.box_head, model.roi_heads.box_predictor):
    trainable_module.requires_grad_(True)
  model.to(device)

  trainable_params = [param for param in model.parameters() if param.requires_grad]
  optimizer = torch.optim.SGD(trainable_params, lr=0.005, momentum=0.9, weight_decay=0.0005)
  # Multiply the learning rate by 0.1 every 3 epochs.
  lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

  for epoch in range(n_epochs):
    # One pass over the training data, logging every 10 iterations.
    train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=10)
    lr_scheduler.step()
    # Evaluate on the test data after each epoch.
    evaluate(model, test_loader, device=device)

In [None]:
# Fine-tune for two epochs; the model is evaluated on the test loader after each one.
train(model, data_loader_train, data_loader_test, n_epochs=2)

### Evaluation

In [None]:
def convert_torch_predictions(preds, det_id, s_id, w, h, classes):
  """Convert one image's model output into a FiftyOne ``Detections`` object.

  Args:
    preds: dict with tensor entries ``"boxes"`` (``[x0, y0, x1, y1]`` rows),
      ``"labels"`` and ``"scores"``.
    det_id: id given to the first detection; incremented once per detection.
    s_id: id of the sample the predictions belong to.
    w, h: image width and height, passed as ``(w, h)`` to ``to_detection``.
    classes: class list passed to ``to_detection``.

  Returns:
    ``(detections, next_det_id)``.
  """
  boxes, labels, scores = (
    preds[key].cpu().detach().numpy() for key in ("boxes", "labels", "scores")
  )
  converted = []
  for (x0, y0, x1, y1), label, score in zip(boxes, labels, scores):
    # COCO boxes are [x, y, width, height], so corners become sizes here.
    coco_obj = fouc.COCOObject(det_id, s_id, int(label), [x0, y0, x1 - x0, y1 - y0])
    detection = coco_obj.to_detection((w, h), classes)
    detection["confidence"] = float(score)
    converted.append(detection)
    det_id += 1

  return fo.Detections(detections=converted), det_id

In [None]:
def add_detections(model, torch_dataset, view, field_name="predictions"):
  """Run inference over ``torch_dataset`` and save the predictions on the FiftyOne samples.

  Args:
    model: detection model; it is switched to eval mode and moved to the available device.
    torch_dataset: iterable of ``(image, target)`` pairs that also exposes ``img_paths``
      and ``classes``; ``target["image_id"]`` indexes into ``img_paths``.
    view: FiftyOne collection in which samples are looked up by file path.
    field_name: sample field the detections are written to.
  """
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

  model.eval()
  model.to(device)
  image_paths = torch_dataset.img_paths
  classes = torch_dataset.classes
  det_id = 0

  print("Adding detections to FiftyOne view...")
  with fo.ProgressBar() as pb:
    for img, targets in pb(torch_dataset):
      # Get FiftyOne sample indexed by unique image filepath
      img_id = int(targets["image_id"][0])
      img_path = image_paths[img_id]
      sample = view[img_path]
      s_id = sample.id
      w = sample.metadata["width"]
      h = sample.metadata["height"]

      # Inference only: disable autograd so no graph is built for the
      # parameters that were left trainable.
      with torch.no_grad():
        preds = model(img.unsqueeze(0).to(device))[0]

      detections, det_id = convert_torch_predictions(preds, det_id, s_id, w, h, classes)

      sample[field_name] = detections
      sample.save()

In [None]:
# Write the model's predictions to the default "predictions" field of every test sample.
add_detections(model, torch_dataset_test, test_view)

# Compare the "predictions" field against the ground truth; results are stored under the
# "eval" key and mAP computation is enabled.
results = fo.evaluate_detections(test_view, "predictions", eval_key="eval", compute_mAP=True) 

In [None]:
# Mean average precision of the predictions on the test view.
print(results.mAP())

In [None]:
# Confusion matrix of predicted vs. ground-truth classes.
results.plot_confusion_matrix()

In [None]:
# Precision-recall curves from the evaluation results.
results.plot_pr_curves()

In [None]:
# Point the already-open FiftyOne app at the test view, which now carries the predictions.
session.view = test_view