add alexnet model

lb1100 authored and libo committed Nov 29, 2018
1 parent e50e179 commit 5ce18a4
Showing 2 changed files with 57 additions and 22 deletions.
59 changes: 42 additions & 17 deletions code/net.py
@@ -7,29 +7,34 @@
 import torch.nn.functional as F
 
 
-class SiamRPNBIG(nn.Module):
-    def __init__(self, feat_in=512, feature_out=512, anchor=5):
-        super(SiamRPNBIG, self).__init__()
-        self.anchor = anchor
-        self.feature_out = feature_out
+class SiamRPN(nn.Module):
+    def __init__(self, size=2, feature_out=512, anchor=5):
+        configs = [3, 96, 256, 384, 384, 256]
+        configs = list(map(lambda x: 3 if x==3 else x*size, configs))
+        feat_in = configs[-1]
+        super(SiamRPN, self).__init__()
         self.featureExtract = nn.Sequential(
-            nn.Conv2d(3, 192, 11, stride=2),
-            nn.BatchNorm2d(192),
+            nn.Conv2d(configs[0], configs[1], kernel_size=11, stride=2),
+            nn.BatchNorm2d(configs[1]),
+            nn.MaxPool2d(kernel_size=3, stride=2),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(3, stride=2),
-            nn.Conv2d(192, 512, 5),
-            nn.BatchNorm2d(512),
+            nn.Conv2d(configs[1], configs[2], kernel_size=5),
+            nn.BatchNorm2d(configs[2]),
+            nn.MaxPool2d(kernel_size=3, stride=2),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(3, stride=2),
-            nn.Conv2d(512, 768, 3),
-            nn.BatchNorm2d(768),
+            nn.Conv2d(configs[2], configs[3], kernel_size=3),
+            nn.BatchNorm2d(configs[3]),
             nn.ReLU(inplace=True),
-            nn.Conv2d(768, 768, 3),
-            nn.BatchNorm2d(768),
+            nn.Conv2d(configs[3], configs[4], kernel_size=3),
+            nn.BatchNorm2d(configs[4]),
             nn.ReLU(inplace=True),
-            nn.Conv2d(768, 512, 3),
-            nn.BatchNorm2d(512),
+            nn.Conv2d(configs[4], configs[5], kernel_size=3),
+            nn.BatchNorm2d(configs[5]),
         )
 
+        self.anchor = anchor
+        self.feature_out = feature_out
+
         self.conv_r1 = nn.Conv2d(feat_in, feature_out*4*anchor, 3)
         self.conv_r2 = nn.Conv2d(feat_in, feature_out, 3)
         self.conv_cls1 = nn.Conv2d(feat_in, feature_out*2*anchor, 3)
@@ -39,6 +44,8 @@ def __init__(self, feat_in=512, feature_out=512, anchor=5):
         self.r1_kernel = []
         self.cls1_kernel = []
 
+        self.cfg = {}
+
     def forward(self, x):
         x_f = self.featureExtract(x)
         return self.regress_adjust(F.conv2d(self.conv_r2(x_f), self.r1_kernel)), \
@@ -51,3 +58,21 @@ def temple(self, z):
         kernel_size = r1_kernel_raw.data.size()[-1]
         self.r1_kernel = r1_kernel_raw.view(self.anchor*4, self.feature_out, kernel_size, kernel_size)
         self.cls1_kernel = cls1_kernel_raw.view(self.anchor*2, self.feature_out, kernel_size, kernel_size)
+
+
+class SiamRPNBIG(SiamRPN):
+    def __init__(self):
+        super(SiamRPNBIG, self).__init__(size=2)
+        self.cfg = {'lr': 0.295, 'window_influence': 0.42, 'penalty_k': 0.055, 'instance_size': 271, 'adaptive': True}  # 0.383
+
+
+class SiamRPNvot(SiamRPN):
+    def __init__(self):
+        super(SiamRPNvot, self).__init__(size=1, feature_out=256)
+        self.cfg = {'lr': 0.45, 'window_influence': 0.44, 'penalty_k': 0.04, 'instance_size': 271, 'adaptive': False}  # 0.355
+
+
+class SiamRPNotb(SiamRPN):
+    def __init__(self):
+        super(SiamRPNotb, self).__init__(size=1, feature_out=256)
+        self.cfg = {'lr': 0.30, 'window_influence': 0.40, 'penalty_k': 0.22, 'instance_size': 271, 'adaptive': False}  # 0.655
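
As a quick reading aid (not part of the commit): the new `size` argument scales the AlexNet-style channel widths, and the old hard-coded SiamRPNBIG numbers fall out as the `size=2` case. A minimal standalone sketch, with only `configs` and the scaling rule taken from the diff above:

    # Channel widths produced by SiamRPN.__init__ for a given `size`.
    configs = [3, 96, 256, 384, 384, 256]

    def scaled(size):
        # The input channel count (3) is left alone; every other width scales.
        return [3 if c == 3 else c * size for c in configs]

    print(scaled(1))  # [3, 96, 256, 384, 384, 256]  -> SiamRPNvot / SiamRPNotb
    print(scaled(2))  # [3, 192, 512, 768, 768, 512] -> SiamRPNBIG (old hard-coded widths)

Since `feat_in = configs[-1]`, the RPN heads automatically match the backbone's final width: 512 for `size=2`, 256 for `size=1`.
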
20 changes: 15 additions & 5 deletions code/run_SiamRPN.py
@@ -17,6 +17,7 @@ def generate_anchor(total_stride, scales, ratios, score_size):
     size = total_stride * total_stride
     count = 0
     for ratio in ratios:
+        # ws = int(np.sqrt(size * 1.0 / ratio))
        ws = int(np.sqrt(size / ratio))
        hs = int(ws * ratio)
        for scale in scales:
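
The added comment just preserves the explicit-float form of the width formula; with the float ratios used here both versions compute the same anchor widths. For a feel of what this loop produces, a hedged sketch (the `total_stride`, `ratios`, and `scales` values are assumptions in line with TrackerConfig's usual defaults, not part of this hunk):

    import numpy as np

    # Each ratio reshapes a fixed base area (total_stride**2); each scale
    # then blows the anchor up to pixel units.
    total_stride = 8
    size = total_stride * total_stride
    for ratio in [0.33, 0.5, 1, 2, 3]:
        ws = int(np.sqrt(size / ratio))
        hs = int(ws * ratio)
        for scale in [8]:
            print(ratio, ws * scale, hs * scale)  # anchor (width, height) in pixels
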
@@ -54,6 +55,13 @@ class TrackerConfig(object):
     penalty_k = 0.055
     window_influence = 0.42
     lr = 0.295
+    # adaptive change search region #
+    adaptive = True
+
+    def update(self, cfg):
+        for k, v in cfg.items():
+            setattr(self, k, v)
+        self.score_size = (self.instance_size - self.exemplar_size) / self.total_stride + 1
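
To see what `update` does with a model's `cfg` dict, a self-contained sketch; the default field values here are assumptions standing in for the TrackerConfig defaults that the hunk above truncates:

    class TrackerConfig(object):
        exemplar_size = 127  # template size (assumed default)
        instance_size = 271  # search region size (assumed default)
        total_stride = 8     # backbone output stride (assumed default)
        adaptive = True

        def update(self, cfg):
            # Overwrite matching attributes, then recompute the score map size.
            for k, v in cfg.items():
                setattr(self, k, v)
            self.score_size = (self.instance_size - self.exemplar_size) / self.total_stride + 1

    p = TrackerConfig()
    p.update({'instance_size': 287, 'adaptive': False})
    print(p.score_size)  # (287 - 127) / 8 + 1 = 21

Note that under Python 3 the `/` is true division, so `score_size` comes out as a float (21.0); callers that use it as an array dimension need an `int` cast.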


def tracker_eval(net, x_crop, target_pos, target_sz, window, scale_z, p):
Expand Down Expand Up @@ -109,15 +117,17 @@ def sz_wh(wh):
def SiamRPN_init(im, target_pos, target_sz, net):
state = dict()
p = TrackerConfig()
p.update(net.cfg)
state['im_h'] = im.shape[0]
state['im_w'] = im.shape[1]

if ((target_sz[0] * target_sz[1]) / float(state['im_h'] * state['im_w'])) < 0.004:
p.instance_size = 287 # small object big search region
else:
p.instance_size = 271
if p.adaptive:
if ((target_sz[0] * target_sz[1]) / float(state['im_h'] * state['im_w'])) < 0.004:
p.instance_size = 287 # small object big search region
else:
p.instance_size = 271

p.score_size = (p.instance_size - p.exemplar_size) / p.total_stride + 1
p.score_size = (p.instance_size - p.exemplar_size) / p.total_stride + 1

p.anchor = generate_anchor(p.total_stride, p.scales, p.ratios, p.score_size)
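
Taken together, the two files now wire per-model hyperparameters through `net.cfg`. A hedged usage sketch; the checkpoint filename, frame path, and box values are placeholders, not part of this commit:

    import cv2
    import numpy as np
    import torch

    from net import SiamRPNvot
    from run_SiamRPN import SiamRPN_init

    net = SiamRPNvot()
    net.load_state_dict(torch.load('SiamRPNVOT.model'))  # assumed weight file
    net.eval()

    im = cv2.imread('first_frame.jpg')     # H x W x 3 BGR frame
    target_pos = np.array([320.0, 240.0])  # (cx, cy) of the initial box
    target_sz = np.array([60.0, 80.0])     # (w, h) of the initial box

    # SiamRPN_init builds a TrackerConfig, merges net.cfg via update(), and,
    # because SiamRPNvot sets adaptive=False, keeps instance_size at 271.
    state = SiamRPN_init(im, target_pos, target_sz, net)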
