Skip to content

Commit

Permalink
Update README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Nov 19, 2019
1 parent b4cfce8 commit 8836b14
Showing 1 changed file with 6 additions and 16 deletions.
22 changes: 6 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ The following is the accomplishment of ``start_fashingClassTrainer()``
import torch
import torch.nn as nn
import torch.nn.functional as F
from jdit.trainer.classification import ClassificationTrainer
from jdit.trainer.single.classification import ClassificationTrainer
from jdit import Model
from jdit.optimizer import Optimizer
from jdit.dataset import FashionMNIST
Expand Down Expand Up @@ -110,26 +110,16 @@ class FashingClassTrainer(ClassificationTrainer):
def compute_loss(self):
    """Compute the training loss and accuracy for the current batch.

    Applies cross-entropy between ``self.output`` and ``self.ground_truth``
    and derives top-1 accuracy from the arg-max of the detached logits.

    :return: ``(loss, var_dic)`` — the cross-entropy loss tensor, and a
        dict mapping ``"CEP"`` (the same loss) and ``"ACC"`` (accuracy)
        to tensors for logging.
    """
    # NOTE(review): assumes self.output is (batch, num_classes) logits and
    # self.ground_truth is a LongTensor of class indices — confirm upstream.
    var_dic = {}
    var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.ground_truth)
    # detach() keeps the accuracy computation out of the autograd graph.
    predict = torch.argmax(self.output.detach(), 1)  # 0100=>1 0010=>2
    correct = predict.eq(self.ground_truth).sum().float()
    acc = correct / predict.size()[0]
    var_dic["ACC"] = acc
    return loss, var_dic
def compute_valid(self):
    """Compute validation metrics for the current batch.

    Delegates to :meth:`compute_loss` so training and validation share one
    metric definition (the previous version duplicated the loss/accuracy
    code and could drift out of sync); the loss tensor itself is discarded.

    :return: dict mapping ``"CEP"`` and ``"ACC"`` to tensors for logging.
    """
    _, var_dic = self.compute_loss()
    return var_dic
Expand Down

0 comments on commit 8836b14

Please sign in to comment.