Skip to content

Commit

Permalink
Update README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Dec 21, 2019
1 parent 831df32 commit 3511f9d
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ from jdit import Model
from jdit.optimizer import Optimizer
from jdit.dataset import FashionMNIST
# This is your model. Defined by torch.nn.Module
class SimpleModel(nn.Module):
def __init__(self, depth=64, num_class=10):
super(SimpleModel, self).__init__()
Expand All @@ -100,46 +100,58 @@ class SimpleModel(nn.Module):
out = out.view(-1, self.num_class)
return out
# A trainer, you need to rewrite the loss and valid function.
class FashingClassTrainer(ClassificationTrainer):
def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets, num_class):
super(FashingClassTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, datasets, num_class)
data, label = self.datasets.samples_train
# plot samples of dataset in tensorboard.
self.watcher.embedding(data, data, label, 1)
def compute_loss(self):
var_dic = {}
var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.ground_truth)
var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.ground_truth.squeeze().long())
predict = torch.argmax(self.output.detach(), 1) # 0100=>1 0010=>2
correct = predict.eq(self.ground_truth).sum().float()
acc = correct / predict.size()[0]
_, predict = torch.max(self.output.detach(), 1) # 0100=>1 0010=>2
total = predict.size(0) * 1.0
labels = self.ground_truth.squeeze().long()
correct = predict.eq(labels).cpu().sum().float()
acc = correct / total
var_dic["ACC"] = acc
return loss, var_dic
def compute_valid(self):
_, var_dic = self.compute_loss()
var_dic = {}
var_dic["CEP"] = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())
_, predict = torch.max(self.output.detach(), 1) # 0100=>1 0010=>2
total = predict.size(0) * 1.0
labels = self.labels.squeeze().long()
correct = predict.eq(labels).cpu().sum().float()
acc = correct / total
var_dic["ACC"] = acc
return var_dic
def start_fashingClassTrainer(gpus=(), nepochs=10, run_type="train"):
"""" An example of fashing-mnist classification
"""
num_class = 10
depth = 32
gpus = gpus
batch_size = 64
batch_size = 4
nepochs = nepochs
logdir = "log/fashion_classify"
opt_hpm = {"optimizer": "Adam",
"lr_decay": 0.94,
"decay_position": 10,
"decay_type": "epoch",
"lr": 1e-3,
"position_type": "epoch",
"lr_reset": {2: 5e-4, 3: 1e-3},
"lr": 1e-4,
"weight_decay": 2e-5,
"betas": (0.9, 0.99)}
print('===> Build dataset')
mnist = FashionMNIST(batch_size=batch_size)
# mnist.dataset_train = mnist.dataset_test
torch.backends.cudnn.benchmark = True
print('===> Building model')
net = Model(SimpleModel(depth=depth), gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=1)
Expand All @@ -148,12 +160,11 @@ def start_fashingClassTrainer(gpus=(), nepochs=10, run_type="train"):
print('===> Training')
print("using `tensorboard --logdir=log` to see learning curves and net structure."
"training and valid_epoch data, configures info and checkpoint were save in `log` directory.")
Trainer = FashingClassTrainer(logdir, nepochs, gpus, net, opt, mnist, num_class)
Trainer = FashingClassTrainer("log/fashion_classify", nepochs, gpus, net, opt, mnist, num_class)
if run_type == "train":
Trainer.train()
elif run_type == "debug":
Trainer.debug()
if __name__ == '__main__':
start_fashingClassTrainer()
```
Expand Down

0 comments on commit 3511f9d

Please sign in to comment.