Skip to content

Commit

Permalink
Update fashionClassParallelTrainer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Dec 21, 2019
1 parent 3511f9d commit 6fbc4bf
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions jdit/trainer/instances/fashionClassParallelTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ def __init__(self, logdir, nepochs, gpu_ids, net, opt, dataset, num_class):

def compute_loss(self):
var_dic = {}
var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())
var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.ground_truth.squeeze().long())

_, predict = torch.max(self.output.detach(), 1) # 0100=>1 0010=>2
total = predict.size(0) * 1.0
labels = self.labels.squeeze().long()
labels = self.ground_truth.squeeze().long()
correct = predict.eq(labels).cpu().sum().float()
acc = correct / total
var_dic["ACC"] = acc
Expand Down

0 comments on commit 6fbc4bf

Please sign in to comment.