<a href="https://colab.research.google.com/github/miramira227/DeformableConv1_pytorch/blob/master/Deformable_Conv%EC%9D%98_%EC%82%AC%EB%B3%B8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **0. Device Setting**

In [1]:
import torch 

# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

# Floating-point type used for the tensors created in the cells below.
dtype = torch.float

print(f'torch device is {device}')

torch device is cuda:0


## **1. Resize Data Shape**

In [2]:
from google.colab.patches import cv2_imshow
import os 
import cv2
import torch
import numpy as np 
import json 

# Parsed annotation files keyed by path, so each (large) JSON file is read only once.
_ANNOTATION_CACHE = {}


def _load_annotation_index(json_path):
  """Parse an `instances_*.json` annotation file once and index it for O(1) lookups.

  Returns:
    (images_by_name, first_anno_by_id): the image records keyed by 'file_name',
    and, for every image id, the first annotation listed for it in the file.
  """
  if json_path not in _ANNOTATION_CACHE:
    with open(json_path) as f:
      json_data = json.load(f)

    images_by_name = {entry['file_name']: entry for entry in json_data['images']}
    first_anno_by_id = {}
    for anno in json_data['annotations']:
      # setdefault keeps the first annotation of each image (same pick as a scan-until-match)
      first_anno_by_id.setdefault(anno['image_id'], anno)

    _ANNOTATION_CACHE[json_path] = (images_by_name, first_anno_by_id)
  return _ANNOTATION_CACHE[json_path]


def get_input_img(idx, mode):
  """Load one image of the `{mode}2017` split and turn it into a network input.

  Args:
    idx: index into the alphabetically sorted list of '.jpg' files of the split.
      For annotated splits, images without an annotation are skipped by
      advancing `idx` until an annotated image is found.
    mode: split name; 'test' has no annotations, any other value is looked up
      in `/content/data/annotations/instances_{mode}2017.json`.

  Returns:
    (image, img):
      image - dict with 'img_h' and 'img_w'. For annotated splits these are the
        original image size and the dict also holds 'img_id', 'img_gt_bboxes'
        (rounded [x, y, w, h] of the first annotation) and 'img_class'.
        For 'test' the size is the resized 224 x 224.
      img - tensor of shape (1, 3, 224, 224) on the notebook-level `device`
        with the notebook-level `dtype`.

  Raises:
    IndexError: no annotated image exists at or after `idx`.
    FileNotFoundError: the selected file could not be decoded by OpenCV.
  """
  path = f'/content/data/{mode}2017/'

  # os.listdir order is arbitrary; sorting makes `idx` refer to the same file on every run
  file_list_jpg = sorted(name for name in os.listdir(path) if name.endswith('.jpg'))

  # the network input is fixed to 224 x 224 with 3 colour channels
  W = 224
  H = 224
  CH = 3

  if mode == 'test':
    img_jpg = file_list_jpg[idx]
    image = {}
    image['img_h'] = H
    image['img_w'] = W
  else:
    json_path = f'/content/data/annotations/instances_{mode}2017.json'
    images_by_name, first_anno_by_id = _load_annotation_index(json_path)

    # advance to the first image (starting at idx, idx == 0 included) that has an annotation
    while idx < len(file_list_jpg):
      img_jpg = file_list_jpg[idx]        # e.g. '000000368456.jpg'
      info = images_by_name.get(img_jpg)
      anno = first_anno_by_id.get(info['id']) if info is not None else None
      if anno is not None:
        break
      idx += 1
    else:
      raise IndexError(f'no annotated image at or after index {idx} in {path}')

    print(f'{idx}번째 이미지')

    image = {}
    image['img_h'] = info['height']
    image['img_w'] = info['width']
    image['img_id'] = info['id']
    image['img_gt_bboxes'] = [int(np.round(v)) for v in anno['bbox']]
    image['img_class'] = anno['category_id']    # category id, e.g. 33: suitcase, 7: train, 38: kite

  img_path = path + img_jpg
  img = cv2.imread(img_path)
  if img is None:
    raise FileNotFoundError(f'could not read image: {img_path}')

  # Draw the ground-truth box and class on a copy, so the annotation is only
  # shown in the preview and never leaks into the pixels fed to the network.
  preview = img.copy()
  if mode != 'test':
    box_x, box_y, box_w, box_h = image['img_gt_bboxes']
    cv2.rectangle(preview, (box_x, box_y), (box_x + box_w, box_y + box_h), (0, 0, 255), 2)
    cv2.putText(preview, str(image['img_class']), (box_x, box_y), 1, 1, (255, 0, 0))
  cv2_imshow(preview)

  # NOTE(review): pixels are passed on as read by cv2.imread with no colour
  # conversion or normalisation — confirm this matches what the backbone expects.
  img = cv2.resize(img, (W, H), interpolation = cv2.INTER_AREA)
  # (H, W, CH) -> (1, CH, H, W): the axes must be transposed; a bare reshape
  # would reinterpret the memory and mix pixels across channels.
  img = np.ascontiguousarray(img.transpose(2, 0, 1)).reshape(1, CH, H, W)
  img = torch.tensor(img, dtype=dtype, device=device)

  return image, img

## **1. Deformable Conv2d**

In [3]:
import torch.nn as nn 
import torch.nn.functional as F 
import torch 
import math
from torch.autograd import Variable
import time
import torch.optim as optim

class DeformConv2d(nn.Module):
  """Deformable 2-D convolution evaluated with explicit Python loops.

  `off_conv` predicts, for every input location, a (dy, dx) offset for each of
  the k_size * k_size kernel taps. For every output location the taps are
  sampled from the zero-padded input at their offset positions with bilinear
  interpolation, and the sampled k_size x k_size patch is passed through
  `deform_conv`, an ordinary convolution whose kernel covers the whole patch.

  Args:
    D_in: number of input channels.
    D_out: number of output channels.
    k_size: side length of the square kernel.
    stride: step between neighbouring output locations, in input pixels.
    padding: zero padding added on every side of the input.
    dilation: spacing between kernel taps before the offsets are added.
  """

  def __init__(self, D_in, D_out, k_size, stride = 1, padding = 0, dilation=1):
    super(DeformConv2d, self).__init__()
    self.D_in = D_in          # input channels
    self.D_out = D_out        # output channels
    self.k_size = k_size      # kernel size
    self.padding = padding
    self.stride = stride
    self.dilation = dilation
    # two offset channels (dy, dx) per kernel tap: 2 * 3 * 3 = 18 for a 3x3 kernel
    self.off_conv = nn.Conv2d(self.D_in, 2 * self.k_size * self.k_size, 3, padding=1)
    # applied to one sampled k_size x k_size patch at a time, hence no padding
    self.deform_conv = nn.Conv2d(self.D_in, self.D_out, self.k_size)

  def bilinear(self, h, w, float_h, float_w):
    """Bilinear weight of integer pixel (h, w) for the sample point (float_h, float_w)."""
    return max(0, 1-abs(float_w - w)) * max(0, 1-abs(float_h - h))

  def bilinear_value(self, float_h, float_w, padded_img):
    """Bilinearly sample every channel of one image at a fractional position.

    Args:
      float_h, float_w: row / column of the sample point (numbers or 1-element tensors).
      padded_img: tensor of shape (channels, rows, cols).

    Returns:
      Tensor of shape (channels,). Neighbours that fall outside the image
      contribute zero.
    """
    n_rows, n_cols = padded_img.size(1), padded_img.size(2)
    # Use floor and floor + 1 (not floor and ceil): at an exactly integer
    # coordinate floor == ceil and the same pixel would be counted twice.
    h0 = int(math.floor(float(float_h)))
    w0 = int(math.floor(float(float_w)))

    value = torch.zeros(padded_img.size(0), dtype=padded_img.dtype, device=padded_img.device)

    for h in (h0, h0 + 1):
      if not 0 <= h < n_rows:
        continue
      for w in (w0, w0 + 1):
        if not 0 <= w < n_cols:
          continue
        # the weight keeps float_h / float_w in the graph, so offsets receive gradients
        value = value + padded_img[:, h, w] * self.bilinear(h, w, float_h, float_w)
    return value

  def deformable_grid(self, i, j, off_field, padded_img):
    """Sample the deformed k_size x k_size patch for one output location.

    Args:
      i: column index of the output location.
      j: row index of the output location.
      off_field: offsets of shape (batch, 2 * k_size * k_size, H, W); the first
        half of the channels holds the row offsets, the second half the column
        offsets, one per kernel tap in row-major order.
      padded_img: padded input of shape (batch, D_in, rows, cols).

    Returns:
      Tensor of shape (batch, D_in, k_size, k_size).
    """
    batch = padded_img.size(0)

    # the patch holds *input* channels; deform_conv maps them to D_out afterwards
    grid = torch.zeros((batch, self.D_in,) + (self.k_size,) * 2,
                       dtype=padded_img.dtype, device=padded_img.device)
    row0 = j * self.stride
    col0 = i * self.stride
    h = [row0 + self.dilation * k for k in range(self.k_size)]
    w = [col0 + self.dilation * k for k in range(self.k_size)]

    # offsets predicted at this location (clamped so large paddings cannot index past the field)
    off_row = min(row0, off_field.size(2) - 1)
    off_col = min(col0, off_field.size(3) - 1)
    offset = off_field[:, :, off_row, off_col].reshape(batch, 2, -1)

    for b in range(batch):
      idx = 0
      for ih, hei in enumerate(h):
        for iw, wid in enumerate(w):
          grid[b, :, ih, iw] = self.bilinear_value(hei+offset[b, 0, idx], wid+offset[b, 1, idx], padded_img[b])
          idx += 1
    return grid

  def deformable_conv(self,in_img,padded_img, off_field):
    """Evaluate the deformable convolution at every output location.

    Returns:
      Tensor of shape (batch, D_out, out_h, out_w).
    """
    # NOTE(review): dividing by `dilation` is not the usual convolution
    # output-size formula; it is left as written because callers may depend on
    # the output size it produces — confirm before changing.
    out_h = int((in_img.size(2) + self.padding * 2 - self.k_size) / (self.stride * self.dilation)) + 1
    out_w = int((in_img.size(3) + self.padding * 2 - self.k_size) / (self.stride * self.dilation)) + 1

    batch = in_img.size(0)
    out_img = torch.zeros((batch, self.D_out, out_h, out_w), dtype=in_img.dtype, device=in_img.device)

    for h in range(out_h):
      for w in range(out_w):
        # deformable_grid takes (column, row)
        patch = self.deformable_grid(w, h, off_field, padded_img)
        out_img[:, :, h, w] = self.deform_conv(patch).reshape(batch, self.D_out)
    return out_img

  def forward(self, x, bias=None):
    """Apply the layer to `x` of shape (batch, D_in, H, W); `bias` is accepted but unused."""
    off_field = self.off_conv(x)   # (batch, 2 * k_size * k_size, H, W)
    padded_img = F.pad(x, (self.padding, ) * 4)
    output_img = self.deformable_conv(x, padded_img, off_field)

    return output_img
  

## **2. ResNet101 1st~4th block**

In [4]:
import torchvision.models as models 

# Pretrained ResNet-101 with its last three child modules removed, so that it
# can be used as a feature extractor in front of the custom blocks below.
model = models.resnet101(progress=True, pretrained=True)
removed = list(model.children())
removed = removed[:len(removed) - 3]
model = torch.nn.Sequential(*removed)

Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /root/.cache/torch/checkpoints/resnet101-5d3b4d8f.pth


HBox(children=(FloatProgress(value=0.0, max=178728960.0), HTML(value='')))




## **3. ResNet101 5th block**

In [5]:
import torchvision.models as models
from torchvision.models.resnet import Bottleneck
import time
import torch.nn as nn 

class MyConv5ofResnet101(nn.Module):
  """Three bottleneck-style stages built around `DeformConv2d`.

  Each stage is: 1x1 conv to 512 channels -> BN -> deformable 3x3 conv -> BN
  -> 1x1 conv to 2048 channels -> BN -> ReLU. The first stage uses the
  dilation-2 deformable conv and adds a strided 1x1 projection of the block
  input; the other two use the dilation-1 deformable conv. A final 1x1 conv
  reduces the result to 1024 channels.

  NOTE(review): `bn`, `bn_downsample`, `conv_no_begin`, `conv3rd` and
  `conv_deform_d1` are single modules reused at several points of the forward
  pass, so those points share weights and statistics, and stages 2 and 3 have
  no skip connection. Left as is to keep the parameter set unchanged — confirm
  that this is intended.
  """

  def __init__(self):
    super(MyConv5ofResnet101, self).__init__()
    self.conv_deform_d1 = DeformConv2d(512, 512, 3, padding=1, dilation=1)
    self.conv_deform_d2 = DeformConv2d(512, 512, 3, padding=1, dilation=2)

    self.conv_begin = nn.Conv2d(1024, 512, 1, bias=False)       # entry of stage 1
    self.conv_no_begin = nn.Conv2d(2048, 512, 1, bias=False)    # entry of stages 2 and 3
    self.conv3rd = nn.Conv2d(512, 2048, 1, bias=False)
    self.conv_downsample = nn.Conv2d(1024, 2048, 1, stride = 2, bias=False)
    self.bn_downsample = nn.BatchNorm2d(2048)
    self.bn = nn.BatchNorm2d(512)
    self.relu = nn.ReLU(inplace=False)
    self.conv_last = nn.Conv2d(2048, 1024, 1, bias=False)

  def forward(self, x):
    """Map a (batch, 1024, H, W) feature map to a 1024-channel feature map."""
    identity = x    # block input, reused by the projection shortcut of stage 1

    for i in range(3):
      if i == 0:
        x = self.conv_begin(x)
      else:
        x = self.conv_no_begin(x)

      x = self.bn(x)
      if i == 0:
        x = self.conv_deform_d2(x)
      else:
        x = self.conv_deform_d1(x)

      x = self.bn(x)
      x = self.conv3rd(x)
      x = self.bn_downsample(x)
      x = self.relu(x)

      if i == 0:
        # Out-of-place add: ReLU's backward pass reads its own output, so
        # modifying that tensor with `+=` makes backward() fail.
        x = x + self.conv_downsample(identity)
        x = self.bn_downsample(x)
    x = self.conv_last(x)

    return x

## **4. RPN**

In [6]:
class MyRPN(nn.Module):
  """Region-proposal head: a shared 3x3 conv followed by two 1x1 branches.

  For 9 anchors per location the classification branch emits 18 channels
  (a 2-way softmax per anchor) and the regression branch 36 channels
  (4 box deltas per anchor).
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(1024, 256, 3, padding=1)
    self.relu = nn.ReLU(inplace=False)

    self.cls_layer = nn.Conv2d(256, 18, 1)      # objectness scores
    self.reg_layer = nn.Conv2d(256, 36, 1)      # box deltas

    self.softmax = nn.Softmax(dim = 2)

  def forward(self, x):
    """Return `(cls, reg)` of shapes (batch, 18, H, W) and (batch, 36, H, W)."""
    features = self.relu(self.conv1(x))

    logits = self.cls_layer(features)
    batch, _, feat_h, feat_w = logits.size()
    # Group the 18 channels as (9 anchors, 2 classes) so the softmax on dim 2
    # normalises each anchor's pair of scores, then restore the flat layout.
    paired = logits.view(batch, 9, 2, feat_h, feat_w)
    cls = self.softmax(paired).view(logits.size())

    reg = self.reg_layer(features)
    return cls, reg

## **5. Make Anchor Boxes**

In [7]:
# 0. anchor scale and ratios
def make_anchor_boxes(scales, ratios, image, cls):
  """Build every anchor box for one image.

  Anchor centres are placed in the middle of each feature-map cell, using the
  stride `image size // feature-map size`. At every centre one anchor is
  created per (ratio, scale) pair with width `scale * ratio` and height
  `scale`. Anchors are ordered cell by cell (row-major) and, inside a cell,
  ratio-major / scale-minor.

  Args:
    scales: 1-D tensor of anchor scales.
    ratios: 1-D tensor of width/height ratios.
    image: dict providing 'img_h' and 'img_w'; other keys are ignored, so the
      reduced dict returned for test images works too.
    cls: classification output of shape (batch, channels, feat_h, feat_w);
      only its spatial size, device and dtype are used.

  Returns:
    (anchor_xyxy, anchor_xywh): two tensors of shape
    (feat_h * feat_w * len(scales) * len(ratios), 4), holding [x1, y1, x2, y2]
    corners and [x_ctr, y_ctr, w, h] respectively, on the device of `cls`.
  """
  # create everything on the device / dtype of the network output so that
  # CPU-created `scales` / `ratios` do not clash with a GPU grid
  dev, dt = cls.device, cls.dtype
  scales = scales.to(device=dev, dtype=dt)
  ratios = ratios.to(device=dev, dtype=dt)
  num_anchors = len(scales) * len(ratios)

  scales_list = scales.repeat(len(ratios))                            # e.g. [128, 256, 512, 128, 256, 512, ...]
  ratios_list = ratios.view(-1, 1).repeat(1, len(scales)).view(-1)    # e.g. [0.5, 0.5, 0.5, 1, 1, 1, 2, 2, 2]

  img_h = image['img_h']
  img_w = image['img_w']
  feat_h, feat_w = cls.size(2), cls.size(3)

  # 1. stride: image pixels covered by one feature-map cell
  stride_h = int(img_h / feat_h)
  stride_w = int(img_w / feat_w)

  # 2. anchor centres: start in the middle of the first stride; exactly one
  #    centre per feature-map cell, whatever remainder the image size leaves
  grid_h = torch.arange(feat_h, dtype=dt, device=dev) * stride_h + int(stride_h / 2)
  grid_w = torch.arange(feat_w, dtype=dt, device=dev) * stride_w + int(stride_w / 2)

  grid_hh = grid_h.view(-1, 1).repeat(1, feat_w)
  grid_ww = grid_w.repeat(feat_h, 1)

  centers = torch.stack([grid_ww, grid_hh, grid_ww, grid_hh], dim=-1).view(-1, 1, 4)   # (cells, 1, 4)
  grid = centers.repeat(1, num_anchors, 1).view(-1, 4)                                 # (cells * anchors, 4)

  # 3. half extents of every anchor shape, tiled over all cells
  anchor_w = scales_list * ratios_list
  anchor_h = anchor_w / ratios_list
  half = torch.stack([-anchor_w / 2, -anchor_h / 2, anchor_w / 2, anchor_h / 2], dim=-1)   # (anchors, 4)
  anchor = half.repeat(feat_h * feat_w, 1)                                                 # (cells * anchors, 4)

  anchor_xyxy = grid + anchor    # all anchor boxes as corners

  # 4. the same boxes in centre / size form, used for the regression targets
  anchor_grid_x1 = anchor_xyxy[:, 0]
  anchor_grid_y1 = anchor_xyxy[:, 1]
  anchor_grid_x2 = anchor_xyxy[:, 2]
  anchor_grid_y2 = anchor_xyxy[:, 3]

  anchor_x_ctr = (anchor_grid_x1 + anchor_grid_x2) / 2
  anchor_y_ctr = (anchor_grid_y1 + anchor_grid_y2) / 2
  anchor_w = (anchor_grid_x2 - anchor_grid_x1)
  anchor_h = (anchor_grid_y2 - anchor_grid_y1)

  anchor_xywh = torch.stack([anchor_x_ctr, anchor_y_ctr, anchor_w, anchor_h], dim = -1)
  return anchor_xyxy, anchor_xywh

## **6. Get IoU, regressed_gt box**

In [8]:
def get_gt_xyxy(gt_bboxes):
  """Convert a box given as [x, y, w, h] into a tensor [x1, y1, x2, y2]."""
  left, top = gt_bboxes[0], gt_bboxes[1]
  right = left + gt_bboxes[2]
  bottom = top + gt_bboxes[3]
  return torch.tensor([left, top, right, bottom])


def get_IoU(gt_xyxy, anchor_xyxy):
  """Intersection-over-union of one ground-truth box with every anchor.

  Args:
    gt_xyxy: tensor [x1, y1, x2, y2] of the ground-truth box.
    anchor_xyxy: tensor of shape (N, 4) with anchor corners.

  Returns:
    Tensor of shape (N,) with the IoU of each anchor.
  """
  ax1, ay1, ax2, ay2 = (anchor_xyxy[:, k] for k in range(4))
  gx1, gy1, gx2, gy2 = (gt_xyxy[k] for k in range(4))

  gt_area = (gx2 - gx1) * (gy2 - gy1)
  anchor_area = (ax2 - ax1) * (ay2 - ay1)

  # signed extent of the overlap along each axis; negative means the boxes are apart
  overlap_w = torch.min(ax2, gx2.float()) - torch.max(ax1, gx1.float())
  overlap_h = torch.min(ay2, gy2.float()) - torch.max(ay1, gy1.float())

  overlapping = (overlap_w >= 0) & (overlap_h >= 0)
  # zero wherever the boxes do not overlap on both axes
  inter_area = (overlap_w * overlap_h).where(overlapping, torch.tensor(0).float())

  return inter_area / (gt_area + anchor_area - inter_area)


def make_label(IoU):
  """Assign a training label to every anchor from its IoU.

  Returns a tensor shaped like `IoU` holding 1 for positives (IoU > 0.7, plus
  the anchor with the highest IoU), -1 for negatives (IoU < 0.3) and 0 for
  anchors that are ignored.
  """
  positive = torch.ones_like(IoU)
  negative = torch.full_like(IoU, -1)
  ignored = torch.zeros_like(IoU)

  label = torch.where(IoU > 0.7, positive, torch.where(IoU < 0.3, negative, ignored))

  # the best-matching anchor is always positive, even below the 0.7 threshold
  label[IoU.argmax()] = 1
  return label


# regress 
def get_gt_reg(gt_xyxy, anchor_xywh, label=None):
  """Regression targets that map each positive anchor onto the ground-truth box.

  Args:
    gt_xyxy: ground-truth box as [x1, y1, x2, y2].
    anchor_xywh: tensor (N, 4) of anchors as [x_ctr, y_ctr, w, h].
    label: tensor (N,) of anchor labels; anchors with label 1 are used.
      When omitted, the notebook-level variable `label` is used, which is how
      this function was originally called.

  Returns:
    Tensor of targets [tx, ty, tw, th]: shape (P, 4) for P > 1 positive
    anchors, shape (4,) when exactly one anchor is positive.

  Raises:
    ValueError: `label` was not passed and no notebook-level `label` exists.
  """
  if label is None:
    label = globals().get('label')
    if label is None:
      raise ValueError('get_gt_reg needs `label` (pass it explicitly)')

  # work in the anchors' dtype / device so integer or CPU boxes mix cleanly
  gt = torch.as_tensor(gt_xyxy).to(device=anchor_xywh.device, dtype=anchor_xywh.dtype)

  # true division: floor division would shift odd-sized boxes by half a pixel
  gt_x_ctr = (gt[0] + gt[2]) / 2
  gt_y_ctr = (gt[1] + gt[3]) / 2
  gt_w = gt[2] - gt[0]
  gt_h = gt[3] - gt[1]

  positive = torch.as_tensor(label).to(anchor_xywh.device) == 1
  pos_labeled_anchor = anchor_xywh[positive].reshape(-1, 4)
  pos_x_ctr, pos_y_ctr, pos_w, pos_h = pos_labeled_anchor.unbind(dim=1)

  gt_tx = (gt_x_ctr - pos_x_ctr) / pos_w
  gt_ty = (gt_y_ctr - pos_y_ctr) / pos_h
  gt_tw = torch.log(gt_w / pos_w)
  gt_th = torch.log(gt_h / pos_h)

  gt_reg = torch.stack([gt_tx, gt_ty, gt_tw, gt_th], dim = -1)   # (P, 4)
  if gt_reg.size(0) == 1:
    # keep the historical return shape for a single positive anchor
    gt_reg = gt_reg.squeeze(0)
  return gt_reg
  

## **Training RPN Network**

In [None]:
import torch.optim as optim
import os 
import torch.nn as nn
import time 

# Anchor configuration, created on the same device / dtype as the network output.
scales = torch.tensor([128, 256, 512], dtype=dtype, device=device)
ratios = torch.tensor([0.5, 1, 2], dtype=dtype, device=device)

# The backbone is frozen; only the custom block and the RPN head are trained.
for param in model.parameters():
  param.requires_grad = False

# merge modules and move them to the device the input tensors live on
customResNet101 = nn.Sequential(model, MyConv5ofResnet101(), MyRPN()).to(device)

trainable_params = [param for param in customResNet101.parameters() if param.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)

# loss modules are stateless, so build them once instead of once per step
cls_criterion = nn.BCELoss()
reg_criterion = nn.SmoothL1Loss()

startTime = time.time()

epochs = 50

for epoch in range(epochs):
  running_loss = 0

  for i in range(4000):
    image, x = get_input_img(i, 'train')    # 'val'
    x = x.to(device)

    cls, reg = customResNet101(x)
    anchor_xyxy, anchor_xywh = make_anchor_boxes(scales, ratios, image, cls)

    # probability of the second class of every anchor, in anchor order
    pred_cls = cls.permute(0, 2, 3, 1).reshape(-1, 2)[:, 1]

    gt_bboxes = image['img_gt_bboxes']
    gt_xyxy = get_gt_xyxy(gt_bboxes).to(device)
    IoU = get_IoU(gt_xyxy, anchor_xyxy)
    label = make_label(IoU)

    # make_label marks positives 1, negatives -1 and ignored anchors 0.
    # BCELoss needs targets in [0, 1]: drop ignored anchors, map negatives to 0.
    labelled = label != 0
    cls_target = (label[labelled] == 1).to(pred_cls.dtype)
    cls_loss = cls_criterion(pred_cls[labelled], cls_target)

    # Targets are defined against centre/size anchors; get_gt_reg picks the
    # positive anchors from the notebook-level `label` set just above.
    gt_reg = get_gt_reg(gt_xyxy, anchor_xywh).reshape(-1, 4)
    pred_reg = reg.permute(0, 2, 3, 1).reshape(-1, 4)[label == 1]

    # (prediction, target) order: the target must not require gradients
    reg_loss = reg_criterion(pred_reg, gt_reg)
    total_loss = cls_loss / 256 + reg_loss * 10 / len(label)

    optimizer.zero_grad()

    total_loss.backward()

    optimizer.step()

    running_loss += total_loss.item()
    if i % 100 == 99:
      print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 100))
      running_loss = 0.0

print('Finished Training')

## **Regression**

In [None]:
# regressed anchor

def regression(anchor, pred_reg):
  """Apply predicted box deltas to anchors.

  Args:
    anchor: anchor(s) as [x_ctr, y_ctr, w, h]; shape (4,) or (N, 4).
    pred_reg: deltas [tx, ty, tw, th] with the same shape as `anchor`.

  Returns:
    Regressed box(es) as [x_ctr, y_ctr, w, h], shaped like the squeezed input.
  """
  anchor = anchor.squeeze()
  pred_reg = pred_reg.squeeze()

  # split along the last axis so that a single box (4,) and a batch (N, 4)
  # are both decoded column-wise; indexing with [i] would pick rows of a batch
  pos_x_ctr, pos_y_ctr, pos_w, pos_h = anchor.unbind(dim=-1)
  t_x, t_y, t_w, t_h = pred_reg.unbind(dim=-1)

  rgrsd_x_ctr = t_x * pos_w + pos_x_ctr
  rgrsd_y_ctr = t_y * pos_h + pos_y_ctr
  rgrsd_w = torch.exp(t_w) * pos_w
  rgrsd_h = torch.exp(t_h) * pos_h

  rgrsd_anchor = torch.stack([rgrsd_x_ctr, rgrsd_y_ctr, rgrsd_w, rgrsd_h], dim = -1)
  return rgrsd_anchor

## **Inference**

In [None]:
import torch.optim as optim
import os 
import torch.nn as nn
import cv2
from google.colab.patches import cv2_imshow 
import numpy as np

path = "/content/data/test2017/"

file_list = os.listdir(path)
# sorted so that an index always refers to the same file
file_list_jpg = sorted(file for file in file_list if file.endswith('.jpg'))

scales = torch.tensor([128, 256, 512], dtype=dtype, device=device)
ratios = torch.tensor([0.5, 1, 2], dtype=dtype, device=device)

i = 7   # index of the test image to run on (arbitrary choice)

image, x = get_input_img(i, 'test')

# inference only: fixed batch-norm statistics and no autograd bookkeeping
customResNet101.eval()
with torch.no_grad():
  x = x.to(next(customResNet101.parameters()).device)
  cls, reg = customResNet101(x)
anchor_xyxy, anchor_xywh = make_anchor_boxes(scales, ratios, image, cls)

pred_cls = cls.permute(0, 2, 3, 1).reshape(-1, 2)[:, 1]
pred_reg = reg.permute(0, 2, 3, 1).reshape(-1, 4)

# Decode the highest-scoring anchor. `regression` works on centre/size boxes
# and needs the deltas of that same anchor, not the whole prediction table.
best = int(torch.argmax(pred_cls))
anchor = anchor_xywh[best].cpu()
rgrsd_anchor = regression(anchor, pred_reg[best].cpu()).reshape(-1, 4)

# centre/size -> integer corner coordinates for drawing
rgrsd_coord = torch.stack([rgrsd_anchor[:, 0] - (rgrsd_anchor[:, 2] / 2),
                           rgrsd_anchor[:, 1] - (rgrsd_anchor[:, 3] / 2),
                           rgrsd_anchor[:, 0] + (rgrsd_anchor[:, 2] / 2),
                           rgrsd_anchor[:, 1] + (rgrsd_anchor[:, 3] / 2)], dim = -1).round().int()

# Re-read the picture for drawing, so the overlay is an ordinary uint8 image
# and does not depend on the memory layout of the network input tensor.
img = cv2.imread(path + file_list_jpg[i])
img = cv2.resize(img, (224, 224), interpolation = cv2.INTER_AREA)

for k in range(len(rgrsd_coord)):
  # OpenCV needs plain Python ints, not tensor elements
  x1, y1, x2, y2 = rgrsd_coord[k].tolist()
  cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
  cv2.putText(img, f'{float(pred_cls[best]):.2f}', (x1, y1), 1, 1, (0, 0, 255))

cv2_imshow(img)

## **NMS**

In [None]:
from torchvision.ops import nms 

def nms(anchor_xyxy, pred_cls, label):
  """Non-maximum suppression over the positively labelled anchors.

  Args:
    anchor_xyxy: tensor (N, 4) of anchors as [x1, y1, x2, y2].
    pred_cls: tensor (N,) of objectness scores.
    label: tensor (N,) of anchor labels; only anchors with label 1 are kept.

  Returns:
    Tensor (K, 4) with the surviving boxes, ordered by decreasing score.
  """
  # Imported locally under another name: this function is itself called `nms`
  # and therefore hides the module-level `from torchvision.ops import nms`.
  from torchvision.ops import nms as box_nms

  positive = label == 1
  pos_labeled_anchor = anchor_xyxy[positive]
  score = pred_cls[positive]
  assert len(score) == len(pos_labeled_anchor)

  # boxes overlapping a higher-scoring box with IoU > 0.7 are discarded
  idx = box_nms(pos_labeled_anchor.float(), score.float().to(pos_labeled_anchor.device), 0.7)

  proposals = pos_labeled_anchor[idx]
  return proposals