Skip to content

Commit

Permalink
Multi-scale training
Browse files Browse the repository at this point in the history
  • Loading branch information
eriklindernoren committed Apr 24, 2019
1 parent fd04b68 commit 6a67fba
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
19 changes: 9 additions & 10 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,8 @@ def create_modules(module_defs):
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
anchors = [anchors[i] for i in anchor_idxs]
num_classes = int(module_def["classes"])
img_height = int(hyperparams["height"])
# Define detection layer
yolo_layer = YOLOLayer(anchors, num_classes, img_height)
yolo_layer = YOLOLayer(anchors, num_classes)
modules.add_module("yolo_%d" % i, yolo_layer)
# Register module list and number of output filters
module_list.append(modules)
Expand Down Expand Up @@ -114,7 +113,7 @@ def __init__(self):
class YOLOLayer(nn.Module):
"""Detection layer"""

def __init__(self, anchors, num_classes, img_dim):
def __init__(self, anchors, num_classes):
super(YOLOLayer, self).__init__()
self.anchors = anchors
self.num_anchors = len(anchors)
Expand All @@ -123,16 +122,16 @@ def __init__(self, anchors, num_classes, img_dim):
self.ignore_thres = 0.5
self.mse_loss = nn.MSELoss()
self.bce_loss = nn.BCELoss()
# Set to balance confidence loss for objects and non-objects
self.obj_scale = 1
self.noobj_scale = 5
self.noobj_scale = 10
self.metrics = {}

def forward(self, x, targets=None):
def forward(self, x, targets, img_dim):
nA = self.num_anchors
nB = x.size(0)
nG = x.size(2)
img_size = x.shape[-1]
stride = img_size / nG
stride = img_dim / nG

# Tensors for cuda support
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
Expand Down Expand Up @@ -187,7 +186,6 @@ def forward(self, x, targets=None):
num_classes=self.num_classes,
grid_size=nG,
ignore_thres=self.ignore_thres,
img_dim=img_size,
)

# Masks
Expand Down Expand Up @@ -254,6 +252,7 @@ def __init__(self, config_path, img_size=416):

def forward(self, x, targets=None):
is_training = targets is not None
img_dim = x.shape[2]
output = []
loss = 0
layer_outputs = []
Expand All @@ -268,10 +267,10 @@ def forward(self, x, targets=None):
x = layer_outputs[-1] + layer_outputs[layer_i]
elif module_def["type"] == "yolo":
if is_training:
x, layer_loss = module[0](x, targets)
x, layer_loss = module[0](x, targets, img_dim=img_dim)
loss += layer_loss
else:
x = module(x)
x = module(x, img_dim=img_dim)
output.append(x)
layer_outputs.append(x)

Expand Down
2 changes: 1 addition & 1 deletion utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):


def build_targets(
pred_boxes, pred_conf, pred_cls, target, anchors, num_anchors, num_classes, grid_size, ignore_thres, img_dim
pred_boxes, pred_conf, pred_cls, target, anchors, num_anchors, num_classes, grid_size, ignore_thres
):
nB = target.size(0)
nA = num_anchors
Expand Down

0 comments on commit 6a67fba

Please sign in to comment.