Added precision to train log
eriklindernoren committed Oct 1, 2018
1 parent 3f5f07b commit 844c6e3
Showing 2 changed files with 26 additions and 30 deletions.
8 changes: 6 additions & 2 deletions models.py
@@ -156,6 +156,7 @@ def forward(self, x, targets=None):
if x.is_cuda:
    self.mse_loss = self.mse_loss.cuda()
    self.bce_loss = self.bce_loss.cuda()
+   self.ce_loss = self.ce_loss.cuda()

nGT, nCorrect, mask, conf_mask, tx, ty, tw, th, tconf, tcls = build_targets(
    pred_boxes=pred_boxes.cpu().data,
@@ -170,8 +171,9 @@
    img_dim=self.image_dim,
)

-nProposals = int((pred_conf > 0.25).sum().item())
+nProposals = int((pred_conf > 0.5).sum().item())
recall = float(nCorrect / nGT) if nGT else 1
+precision = float(nCorrect / nProposals)

# Handle masks
mask = Variable(mask.type(ByteTensor))
@@ -210,6 +212,7 @@ def forward(self, x, targets=None):
    loss_conf.item(),
    loss_cls.item(),
    recall,
+   precision,
)

else:
@@ -235,7 +238,7 @@ def __init__(self, config_path, img_size=416):
self.img_size = img_size
self.seen = 0
self.header_info = np.array([0, 0, 0, self.seen, 0])
-self.loss_names = ["x", "y", "w", "h", "conf", "cls", "recall"]
+self.loss_names = ["x", "y", "w", "h", "conf", "cls", "recall", "precision"]

def forward(self, x, targets=None):
    is_training = targets is not None
@@ -264,6 +267,7 @@ def forward(self, x, targets=None):
layer_outputs.append(x)

self.losses["recall"] /= 3
+self.losses["precision"] /= 3
return sum(output) if is_training else torch.cat(output, 1)

def load_weights(self, weights_path):
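For reference, a minimal standalone sketch of the metrics this commit adds to the training log. It assumes the counts produced by build_targets (nCorrect matched predictions out of nGT ground-truth boxes) and treats every prediction whose objectness clears the confidence threshold as a proposal. The layer_metrics name and the zero-proposal guard are assumptions of this sketch, not part of the diff, which divides by nProposals directly; the Darknet forward pass then averages the logged values over the three YOLO detection scales (the /= 3 lines above).

import torch

def layer_metrics(pred_conf, n_correct, n_gt, conf_thres=0.5):
    # Proposals: predictions whose objectness score clears the threshold.
    n_proposals = int((pred_conf > conf_thres).sum().item())
    # Recall: fraction of ground-truth boxes that were matched.
    recall = float(n_correct / n_gt) if n_gt else 1
    # Precision: fraction of proposals that were correct (guard is an assumption here).
    precision = float(n_correct / n_proposals) if n_proposals else 0
    return recall, precision

# Example with a random objectness map and made-up match counts.
pred_conf = torch.rand(2, 3, 13, 13)
print(layer_metrics(pred_conf, n_correct=40, n_gt=50))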
48 changes: 20 additions & 28 deletions test.py
@@ -20,26 +20,26 @@
import torch.optim as optim

parser = argparse.ArgumentParser()
-parser.add_argument('--batch_size', type=int, default=16, help='size of each image batch')
-parser.add_argument('--model_config_path', type=str, default='config/yolov3.cfg', help='path to model config file')
-parser.add_argument('--data_config_path', type=str, default='config/coco.data', help='path to data config file')
-parser.add_argument('--weights_path', type=str, default='weights/yolov3.weights', help='path to weights file')
-parser.add_argument('--class_path', type=str, default='data/coco.names', help='path to class label file')
-parser.add_argument('--iou_thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
-parser.add_argument('--conf_thres', type=float, default=0.5, help='object confidence threshold')
-parser.add_argument('--nms_thres', type=float, default=0.45, help='iou thresshold for non-maximum suppression')
-parser.add_argument('--n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation')
-parser.add_argument('--img_size', type=int, default=416, help='size of each image dimension')
-parser.add_argument('--use_cuda', type=bool, default=True, help='whether to use cuda if available')
+parser.add_argument("--batch_size", type=int, default=16, help="size of each image batch")
+parser.add_argument("--model_config_path", type=str, default="config/yolov3.cfg", help="path to model config file")
+parser.add_argument("--data_config_path", type=str, default="config/coco.data", help="path to data config file")
+parser.add_argument("--weights_path", type=str, default="weights/yolov3.weights", help="path to weights file")
+parser.add_argument("--class_path", type=str, default="data/coco.names", help="path to class label file")
+parser.add_argument("--iou_thres", type=float, default=0.5, help="iou threshold required to qualify as detected")
+parser.add_argument("--conf_thres", type=float, default=0.5, help="object confidence threshold")
+parser.add_argument("--nms_thres", type=float, default=0.45, help="iou thresshold for non-maximum suppression")
+parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
+parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
+parser.add_argument("--use_cuda", type=bool, default=True, help="whether to use cuda if available")
opt = parser.parse_args()
print(opt)

cuda = torch.cuda.is_available() and opt.use_cuda

# Get data configuration
-data_config = parse_data_config(opt.data_config_path)
-test_path = data_config['valid']
-num_classes = int(data_config['classes'])
+data_config = parse_data_config(opt.data_config_path)
+test_path = data_config["valid"]
+num_classes = int(data_config["classes"])

# Initiate model
model = Darknet(opt.model_config_path)
@@ -52,19 +52,11 @@

# Get dataloader
dataset = ListDataset(test_path)
-dataloader = torch.utils.data.DataLoader(dataset,
-    batch_size=opt.batch_size, shuffle=False, num_workers=opt.n_cpu)
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.n_cpu)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

-n_gt = 0
-correct = 0
-
-print ('Compute mAP...')
-
-outputs = []
-targets = None
-APs = []
+print("Compute mAP...")

all_detections = []
all_annotations = []
@@ -102,10 +94,10 @@

# Reformat to x1, y1, x2, y2 and rescale to image dimensions
annotation_boxes = np.empty_like(_annotation_boxes)
-annotation_boxes[:, 0] = (_annotation_boxes[:, 0] - _annotation_boxes[:, 2] / 2)
-annotation_boxes[:, 1] = (_annotation_boxes[:, 1] - _annotation_boxes[:, 3] / 2)
-annotation_boxes[:, 2] = (_annotation_boxes[:, 0] + _annotation_boxes[:, 2] / 2)
-annotation_boxes[:, 3] = (_annotation_boxes[:, 1] + _annotation_boxes[:, 3] / 2)
+annotation_boxes[:, 0] = _annotation_boxes[:, 0] - _annotation_boxes[:, 2] / 2
+annotation_boxes[:, 1] = _annotation_boxes[:, 1] - _annotation_boxes[:, 3] / 2
+annotation_boxes[:, 2] = _annotation_boxes[:, 0] + _annotation_boxes[:, 2] / 2
+annotation_boxes[:, 3] = _annotation_boxes[:, 1] + _annotation_boxes[:, 3] / 2
annotation_boxes *= opt.img_size

for label in range(num_classes):
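The last hunk's box reshaping converts normalized (x_center, y_center, width, height) annotations to (x1, y1, x2, y2) corner coordinates before scaling to pixel space. A small self-contained sketch with a made-up box, using the 416 default from --img_size above:

import numpy as np

# One made-up ground-truth box in normalized (x_center, y_center, w, h) format.
_annotation_boxes = np.array([[0.5, 0.5, 0.2, 0.4]])

# Convert to corner format, mirroring the four assignments in the diff.
annotation_boxes = np.empty_like(_annotation_boxes)
annotation_boxes[:, 0] = _annotation_boxes[:, 0] - _annotation_boxes[:, 2] / 2  # x1
annotation_boxes[:, 1] = _annotation_boxes[:, 1] - _annotation_boxes[:, 3] / 2  # y1
annotation_boxes[:, 2] = _annotation_boxes[:, 0] + _annotation_boxes[:, 2] / 2  # x2
annotation_boxes[:, 3] = _annotation_boxes[:, 1] + _annotation_boxes[:, 3] / 2  # y2

# Rescale from normalized coordinates to pixel coordinates.
annotation_boxes *= 416
print(annotation_boxes)  # [[166.4 124.8 249.6 291.2]]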
