<a href="https://colab.research.google.com/github/mot1122/study_pytorch/blob/main/2_4_2_5__SSD_model_forward.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
from math import sqrt
from itertools import product

import pandas as pd
import torch
from torch.autograd import Function
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

In [4]:
def make_vgg():
  """Build the VGG-style backbone as a flat ``nn.ModuleList`` of 35 modules.

  The plan below yields thirteen 3x3 conv + ReLU pairs interleaved with four
  2x2/stride-2 max-pools (the "MC" one uses ``ceil_mode=True``), followed by a
  3x3/stride-1 max-pool, a dilated 3x3 conv to 1024 channels and a 1x1 conv,
  each of the last two followed by ReLU.
  """
  def conv_relu(c_in,c_out,**conv_kwargs):
    # One convolution immediately followed by an in-place ReLU.
    return [nn.Conv2d(c_in,c_out,**conv_kwargs),nn.ReLU(inplace=True)]

  # Integers are conv output widths; "M"/"MC" are pooling markers.
  plan=(64,64,"M",128,128,"M",256,256,256,"MC",512,512,512,"M",512,512,512)
  modules=[]
  channels=3
  for entry in plan:
    if isinstance(entry,int):
      modules.extend(conv_relu(channels,entry,kernel_size=3,padding=1))
      channels=entry
    else:
      modules.append(nn.MaxPool2d(kernel_size=2,stride=2,ceil_mode=(entry=="MC")))

  # Tail of the backbone: size-preserving pool, dilated conv, 1x1 conv.
  modules.append(nn.MaxPool2d(kernel_size=3,stride=1,padding=1))
  modules.extend(conv_relu(channels,1024,kernel_size=3,padding=6,dilation=6))
  modules.extend(conv_relu(1024,1024,kernel_size=1))
  return nn.ModuleList(modules)


In [5]:
# Smoke test: build the VGG module list and print its layers for inspection.
vgg_test=make_vgg()
print(vgg_test)

ModuleList(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
  (17): Conv2d(256, 512, kernel_siz

In [6]:
def make_extras():
  """Build the eight extra convolutions that follow the 1024-channel backbone output.

  Layers alternate 1x1 and 3x3 kernels; the first two 3x3 layers use
  stride 2 / padding 1, the last two use stride 1 / no padding.
  No activation modules are included in the returned ``nn.ModuleList``.
  """
  # (out_channels, kernel_size, stride, padding) for each layer, in order.
  specs=(
      (256,1,1,0),(512,3,2,1),
      (128,1,1,0),(256,3,2,1),
      (128,1,1,0),(256,3,1,0),
      (128,1,1,0),(256,3,1,0),
  )
  modules=[]
  prev_channels=1024
  for width,kernel,stride,padding in specs:
    modules.append(nn.Conv2d(prev_channels,width,kernel_size=kernel,stride=stride,padding=padding))
    prev_channels=width
  return nn.ModuleList(modules)

In [7]:
# Smoke test: build the extra feature layers and print them for inspection.
extras_test=make_extras()
print(extras_test)

ModuleList(
  (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
  (1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (2): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
  (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (4): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
  (5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
  (6): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
  (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
)


In [8]:
def make_loc_conf(num_classes=21, bbox_aspect_num=[4, 6, 6, 6, 4, 4]):
  """Build the per-source prediction heads.

  For each source feature map (input widths 512, 1024, 512, 256, 256, 256) a
  3x3/padding-1 convolution is created for box offsets (``boxes * 4`` output
  channels) and another for class scores (``boxes * num_classes`` channels),
  where ``boxes`` is the matching entry of ``bbox_aspect_num``.

  Returns:
    (loc_layers, conf_layers) as two ``nn.ModuleList`` objects.
  """
  source_channels=(512,1024,512,256,256,256)
  head_kwargs=dict(kernel_size=3,padding=1)
  loc_heads=[]
  conf_heads=[]
  for channels,boxes_per_cell in zip(source_channels,bbox_aspect_num):
    # Offsets first, then scores, for every source in turn.
    loc_heads.append(nn.Conv2d(channels,boxes_per_cell*4,**head_kwargs))
    conf_heads.append(nn.Conv2d(channels,boxes_per_cell*num_classes,**head_kwargs))
  return nn.ModuleList(loc_heads),nn.ModuleList(conf_heads)

In [9]:
# Smoke test: build the offset and class-score heads with default arguments and print both.
loc_test,conf_test=make_loc_conf()
print(loc_test)
print(conf_test)

ModuleList(
  (0): Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (2): Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
ModuleList(
  (0): Conv2d(512, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): Conv2d(1024, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (2): Conv2d(512, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): Conv2d(256, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)


In [10]:
class L2Norm(nn.Module):
  """Channel-wise L2 normalization with a learnable per-channel scale.

  Every spatial position of the input is divided by the L2 norm of its channel
  vector and then multiplied by ``weight`` (one value per channel, initialized
  to ``scale``).

  Args:
    in_channels: number of channels of the input feature map.
    scale: initial value of every entry of ``weight``.
  """
  def __init__(self,in_channels=512,scale=20):
    super().__init__()
    self.in_channels=in_channels
    self.scale=scale
    self.eps=1e-10  # keeps the division finite for all-zero channel vectors
    self.weight=nn.Parameter(torch.Tensor(in_channels))
    self.reset_parameters()
  def reset_parameters(self):
    """Fill ``weight`` with the constant ``scale``."""
    init.constant_(self.weight,self.scale)
  def forward(self,x):
    """Normalize ``x`` of shape (batch, in_channels, H, W) and rescale it.

    Returns a new tensor of the same shape; ``x`` itself is not modified.
    """
    norm=x.pow(2).sum(dim=1,keepdim=True).sqrt()+self.eps
    # The quotient must be kept: torch.div is out-of-place, so discarding its
    # result would skip the normalization entirely.
    normalized=torch.div(x,norm)
    weights=self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(normalized)
    return normalized*weights

In [11]:
# Configuration for the 300x300 SSD model defined in this notebook.
ssd_cfg=dict(
    num_classes=21,                      # width of each class-score vector
    image_size=300,                      # divisor used to normalize box sizes
    bbox_aspect_num=[4,6,6,6,4,4],       # default boxes per cell, per source
    feature_maps=[38, 19, 10, 5, 3, 1],  # grid side length of each source
    steps=[8, 16, 32, 64, 100, 300],     # per-source step used for box centers
    min_sizes=[30, 60, 111, 162, 213, 264],   # small square box size, per source
    max_sizes=[60, 111, 162, 213, 264, 315],  # used for the larger square box
    aspect_ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]],  # extra w/h ratios, per source
)

In [12]:
class DBox(object):
  """Generator of default (prior) boxes from a configuration dict.

  Reads ``image_size``, ``feature_maps``, ``steps``, ``min_sizes``,
  ``max_sizes`` and ``aspect_ratios`` from ``cfg``; the five per-source lists
  are iterated in parallel.
  """
  def __init__(self,cfg):
    super().__init__()
    self.image_size=cfg["image_size"]
    self.feature_maps=cfg["feature_maps"]
    self.num_priors=len(cfg["feature_maps"])
    self.steps=cfg["steps"]
    self.min_sizes=cfg["min_sizes"]
    self.max_sizes=cfg["max_sizes"]
    self.aspect_ratios=cfg["aspect_ratios"]
  def make_dbox_list(self):
    """Return all default boxes as a float tensor of shape (num_boxes, 4).

    Each row is ``[cx, cy, w, h]`` expressed as a fraction of the image size
    and clamped to ``[0, 1]``. Every grid cell contributes two square boxes
    plus two boxes (wide and tall) per configured aspect ratio.
    """
    mean=[]
    for fmap,step,min_size,max_size,ratios in zip(self.feature_maps,self.steps,self.min_sizes,self.max_sizes,self.aspect_ratios):
      # Per-source constants, computed once instead of once per cell.
      # The grid scale follows the configured image size rather than a fixed 300.
      f_k=self.image_size/step
      s_k=min_size/self.image_size
      s_k_max=sqrt(s_k*max_size/self.image_size)
      for i,j in product(range(fmap),repeat=2):
        # j indexes the x direction, i the y direction.
        cx,cy=(j+0.5)/f_k, (i+0.5)/f_k
        mean+=[cx,cy,s_k,s_k]
        mean+=[cx,cy,s_k_max,s_k_max]
        for e in ratios:
          mean+=[cx,cy,s_k*sqrt(e),s_k/sqrt(e)]
          mean+=[cx,cy,s_k/sqrt(e),s_k*sqrt(e)]
    output=torch.Tensor(mean).view(-1,4)
    output.clamp_(min=0,max=1)
    return output

In [13]:
# Generate the default boxes for ssd_cfg and show them as a table
# (one row per box; columns are cx, cy, w, h in the order DBox emits them).
dbox=DBox(ssd_cfg)
dbox_list=dbox.make_dbox_list()
pd.DataFrame(dbox_list.numpy())

Unnamed: 0,0,1,2,3
0,0.013333,0.013333,0.100000,0.100000
1,0.013333,0.013333,0.141421,0.141421
2,0.013333,0.013333,0.141421,0.070711
3,0.013333,0.013333,0.070711,0.141421
4,0.040000,0.013333,0.100000,0.100000
...,...,...,...,...
8727,0.833333,0.833333,0.502046,1.000000
8728,0.500000,0.500000,0.880000,0.880000
8729,0.500000,0.500000,0.961249,0.961249
8730,0.500000,0.500000,1.000000,0.622254


In [14]:
class SSD(nn.Module):
  """Constructor-only SSD container, used to inspect the assembled sub-networks.

  A later cell of this notebook redefines ``SSD`` with a ``forward`` method.

  Args:
    phase: ``"inference"`` additionally attaches a ``Detect`` post-processor;
      any other value leaves it out.
    cfg: configuration dict providing ``num_classes``, ``bbox_aspect_num`` and
      the keys read by ``DBox``.
  """
  def __init__(self,phase,cfg):
    super().__init__()
    self.phase=phase
    self.num_classes=cfg["num_classes"]
    self.vgg=make_vgg()
    self.extras=make_extras()
    self.L2Norm=L2Norm()
    self.loc,self.conf=make_loc_conf(cfg["num_classes"],cfg["bbox_aspect_num"])
    self.dbox=DBox(cfg)
    # Use this instance's DBox (built from cfg), not a notebook-level variable.
    self.dbox_list=self.dbox.make_dbox_list()
    # Detect is only looked up when phase is "inference".
    if phase=="inference":self.detect=Detect()

In [15]:
# Instantiate the model in "train" phase (no Detect post-processor is attached).
ssd_test=SSD("train",cfg=ssd_cfg)

In [16]:
# Print the registered sub-modules of the assembled model.
print(ssd_test)

SSD(
  (vgg): ModuleList(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, cei

In [17]:
def decode(loc,dbox_list):
  """Turn predicted offsets into corner-form boxes.

  Args:
    loc: tensor whose columns 0-1 are center offsets and columns 2+ are
      log-size offsets (scaled here by the fixed factors 0.1 and 0.2).
    dbox_list: default boxes with rows ``[cx, cy, w, h]``.

  Returns:
    A new tensor with rows ``[xmin, ymin, xmax, ymax]``.
  """
  prior_center,prior_size=dbox_list[:,:2],dbox_list[:,2:]
  center=prior_center+loc[:,:2]*0.1*prior_size
  size=prior_size*torch.exp(loc[:,2:]*0.2)
  # Center/size form -> corner form.
  top_left=center-size/2
  bottom_right=size+top_left
  return torch.cat((top_left,bottom_right),dim=1)

In [18]:
def nm_supression(boxes,scores,overlap=0.45,top_k=200):
  """Greedy non-maximum suppression.

  Repeatedly keeps the highest-scoring remaining box and drops every other
  remaining box whose IoU with it exceeds ``overlap``.

  Args:
    boxes: (N, 4) tensor with rows ``[xmin, ymin, xmax, ymax]``.
    scores: (N,) tensor of scores, one per box.
    overlap: IoU threshold; boxes with IoU <= overlap survive a round.
    top_k: only the ``top_k`` highest-scoring boxes are considered.

  Returns:
    (keep, count): ``keep`` is a long tensor of length N whose first ``count``
    entries are the indices of the kept boxes in descending score order; the
    remaining entries are zero.
  """
  count=0
  keep=torch.zeros(scores.size(0),dtype=torch.long,device=scores.device)
  x_min=boxes[:,0]
  y_min=boxes[:,1]
  x_max=boxes[:,2]
  y_max=boxes[:,3]
  area=(x_max-x_min)*(y_max-y_min)
  _,idx=scores.sort(0)  # ascending, so the best candidate is last
  idx=idx[-top_k:]
  while idx.numel()>0:
    i=idx[-1]
    keep[count]=i
    count+=1
    if idx.size(0)==1:break
    idx=idx[:-1]
    # Intersection rectangle of box i with each remaining box: the larger of
    # the two mins and the smaller of the two maxes. Each coordinate gets its
    # own tensor so no gather can overwrite another.
    inter_x_min=torch.max(x_min[idx],x_min[i])
    inter_y_min=torch.max(y_min[idx],y_min[i])
    inter_x_max=torch.min(x_max[idx],x_max[i])
    inter_y_max=torch.min(y_max[idx],y_max[i])
    # Disjoint boxes give negative extents; clamp so their intersection is 0
    # rather than a positive product of two negatives.
    inter_w=torch.clamp(inter_x_max-inter_x_min,min=0.0)
    inter_h=torch.clamp(inter_y_max-inter_y_min,min=0.0)
    inter=inter_w*inter_h
    union=(area[idx]-inter)+area[i]
    IoU=inter/union
    idx=idx[IoU.le(overlap)]
  return keep,count

In [19]:
class Detect():
  """Inference-time post-processing: softmax, box decoding and per-class NMS.

  Args:
    conf_thresh: only boxes whose class probability is greater than this
      value are passed on to NMS.
    top_k: maximum number of boxes considered by NMS and stored per class.
    nms_thresh: IoU threshold forwarded to ``nm_supression``.
  """
  def __init__(self,conf_thresh=0.1,top_k=200,nms_thresh=0.45):
    self.softmax=nn.Softmax(dim=-1)
    self.conf_thresh=conf_thresh
    self.top_k=top_k
    self.nms_thresh=nms_thresh
  def __call__(self,loc_data,conf_data,dbox_list):
    """Allow ``detect(loc, conf, dbox_list)``; delegates to ``forward``."""
    return self.forward(loc_data,conf_data,dbox_list)
  def forward(self,loc_data,conf_data,dbox_list):
    """Produce the final detections for a batch.

    Args:
      loc_data: (batch, num_dbox, 4) predicted offsets.
      conf_data: (batch, num_dbox, num_classes) raw class scores.
      dbox_list: (num_dbox, 4) default boxes passed to ``decode``.

    Returns:
      Tensor of shape (batch, num_classes, top_k, 5). Each filled row is
      ``[score, xmin, ymin, xmax, ymax]``; unused rows and all of class
      index 0 stay zero.
    """
    num_batch=loc_data.size(0)
    num_classes=conf_data.size(2)
    conf_data=self.softmax(conf_data)
    # Allocate on the same device/dtype as the predictions.
    output=conf_data.new_zeros(num_batch,num_classes,self.top_k,5)
    conf_preds=conf_data.transpose(2,1)  # -> (batch, num_classes, num_dbox)
    for i in range(num_batch):
      decoded_boxes=decode(loc_data[i],dbox_list)
      conf_scores=conf_preds[i].clone()
      # Class index 0 is skipped by the loop.
      for cl in range(1,num_classes):
        c_mask=conf_scores[cl].gt(self.conf_thresh)
        scores=conf_scores[cl][c_mask]
        if scores.nelement()==0:continue
        l_mask=c_mask.unsqueeze(1).expand_as(decoded_boxes)
        boxes=decoded_boxes[l_mask].view(-1,4)
        ids,count=nm_supression(boxes,scores,self.nms_thresh,self.top_k)
        output[i,cl,:count]=torch.cat((
            scores[ids[:count]].unsqueeze(1),boxes[ids[:count]])
        ,1)
    return output

In [20]:
class SSD(nn.Module):
  """SSD detector: VGG backbone, extra layers, prediction heads and default boxes.

  Args:
    phase: ``"inference"`` makes ``forward`` return post-processed detections;
      any other value makes it return the raw predictions.
    cfg: configuration dict providing ``num_classes``, ``bbox_aspect_num`` and
      the keys read by ``DBox``.
  """
  def __init__(self,phase,cfg):
    super().__init__()
    self.phase=phase
    self.num_classes=cfg["num_classes"]
    self.vgg=make_vgg()
    self.extras=make_extras()
    self.L2Norm=L2Norm()
    self.loc,self.conf=make_loc_conf(cfg["num_classes"],cfg["bbox_aspect_num"])
    self.dbox=DBox(cfg)
    # Use this instance's DBox (built from cfg), not a notebook-level variable.
    self.dbox_list=self.dbox.make_dbox_list()
    if phase=="inference":self.detect=Detect()
  def forward(self,x):
    """Run the network on an image batch.

    Returns:
      In ``"inference"`` phase, the result of ``self.detect``. Otherwise the
      tuple ``(loc, conf, dbox_list)`` where ``loc`` is (batch, N, 4),
      ``conf`` is (batch, N, num_classes) and ``dbox_list`` is the default
      box tensor.
    """
    sources,loc,conf=[],[],[]
    # Source 1: output of vgg modules 0-22 (through the ReLU after the third
    # 512-channel conv), L2-normalized.
    for k in range(23):
      x=self.vgg[k](x)
    sources.append(self.L2Norm(x))
    # Source 2: output of the remaining vgg modules.
    for k in range(23,len(self.vgg)):
      x=self.vgg[k](x)
    sources.append(x)
    # Sources 3-6: output after every second extra layer. Each layer is
    # called on x and followed by ReLU.
    for k,v in enumerate(self.extras):
      x=F.relu(v(x),inplace=True)
      if k%2==1:sources.append(x)
    # Apply the heads and move channels last so each box's values are contiguous.
    for (src,l,c) in zip(sources,self.loc,self.conf):
      loc.append(l(src).permute(0,2,3,1).contiguous())
      conf.append(c(src).permute(0,2,3,1).contiguous())
    loc=torch.cat([o.view(o.size(0),-1) for o in loc],1)
    conf=torch.cat([o.view(o.size(0),-1) for o in conf],1)
    loc=loc.view(loc.size(0),-1,4)
    conf=conf.view(conf.size(0),-1,self.num_classes)
    output=(loc,conf,self.dbox_list)
    if self.phase=="inference": return self.detect(output[0], output[1], output[2])
    else: return output