Commit

md

dingguanglei committed Dec 25, 2018
1 parent 173ddb3 commit e7aa175
Showing 7 changed files with 152 additions and 40 deletions.
125 changes: 113 additions & 12 deletions README.md
@@ -1,4 +1,4 @@
![logo](https://github.com/dingguanglei/jdit/blob/master/logo.png)
![logo](https://github.com/dingguanglei/jdit/blob/master/resources/logo.png)

---

@@ -45,6 +45,104 @@ this code in an ipython shell. (Creating a main.py file is also acceptable.)
from jdit.trainer.instances.fashingClassification import start_fashingClassTrainer
start_fashingClassTrainer()
```
The following is the full implementation of ``start_fashingClassTrainer()``:

``` {.sourceCode .python}
# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from jdit.trainer.classification import ClassificationTrainer
from jdit import Model
from jdit.optimizer import Optimizer
from jdit.dataset import FashionMNIST


# This is your model, defined by torch.nn.Module.
class SimpleModel(nn.Module):
    def __init__(self, depth=64, num_class=10):
        super(SimpleModel, self).__init__()
        self.num_class = num_class
        self.layer1 = nn.Conv2d(1, depth, 3, 1, 1)
        self.layer2 = nn.Conv2d(depth, depth * 2, 4, 2, 1)
        self.layer3 = nn.Conv2d(depth * 2, depth * 4, 4, 2, 1)
        self.layer4 = nn.Conv2d(depth * 4, depth * 8, 4, 2, 1)
        self.layer5 = nn.Conv2d(depth * 8, num_class, 4, 1, 0)

    def forward(self, input):
        out = F.relu(self.layer1(input))
        out = F.relu(self.layer2(out))
        out = F.relu(self.layer3(out))
        out = F.relu(self.layer4(out))
        out = self.layer5(out)
        out = out.view(-1, self.num_class)
        return out


# A trainer. You only need to rewrite the loss and valid functions.
class FashingClassTrainer(ClassificationTrainer):
    def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets, num_class):
        super(FashingClassTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, datasets, num_class)
        data, label = self.datasets.samples_train
        # plot samples of the dataset in tensorboard.
        self.watcher.embedding(data, data, label, 1)

    def compute_loss(self):
        var_dic = {}
        var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())

        _, predict = torch.max(self.output.detach(), 1)  # 0100=>1  0010=>2
        total = predict.size(0) * 1.0
        labels = self.labels.squeeze().long()
        correct = predict.eq(labels).cpu().sum().float()
        acc = correct / total
        var_dic["ACC"] = acc
        return loss, var_dic

    def compute_valid(self):
        var_dic = {}
        var_dic["CEP"] = cep = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())

        _, predict = torch.max(self.output.detach(), 1)  # 0100=>1  0010=>2
        total = predict.size(0) * 1.0
        labels = self.labels.squeeze().long()
        correct = predict.eq(labels).cpu().sum().float()
        acc = correct / total
        var_dic["ACC"] = acc
        return var_dic


def start_fashingClassTrainer(gpus=(), nepochs=10, run_type="train"):
    num_class = 10
    depth = 32
    batch_size = 64
    opt_hpm = {"optimizer": "Adam",
               "lr_decay": 0.94,
               "decay_position": 10,
               "decay_type": "epoch",
               "lr": 1e-3,
               "weight_decay": 2e-5,
               "betas": (0.9, 0.99)}

    print('===> Build dataset')
    mnist = FashionMNIST(batch_size=batch_size)
    torch.backends.cudnn.benchmark = True
    print('===> Building model')
    net = Model(SimpleModel(depth=depth), gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=1)
    print('===> Building optimizer')
    opt = Optimizer(net.parameters(), **opt_hpm)
    print('===> Training')
    print("Use `tensorboard --logdir=log` to see learning curves and the net structure. "
          "Training and validation data, config info and checkpoints are saved in the `log` directory.")
    Trainer = FashingClassTrainer("log/fashion_classify", nepochs, gpus, net, opt, mnist, num_class)
    if run_type == "train":
        Trainer.train()
    elif run_type == "debug":
        Trainer.debug()


if __name__ == '__main__':
    start_fashingClassTrainer()
```
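
A quick usage sketch (assuming a machine with CUDA device 0 is available; otherwise keep ``gpus=()``): the entry function accepts the GPU ids, number of epochs and run type shown in its signature above.

``` {.sourceCode .python}
from jdit.trainer.instances.fashingClassification import start_fashingClassTrainer

# short sanity run on GPU 0; use run_type="debug" for a dry run of the pipeline
start_fashingClassTrainer(gpus=(0,), nepochs=2, run_type="train")
```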

Then you will see output like the following.

@@ -68,6 +166,20 @@ training and valid data, config info and checkpoints are saved in the `log`
0%| | 0/10 [00:00<?, ?epoch/s]
0step [00:00, step?/s]
```
To see learning curves in TensorBoard, pay attention to the ``var_dic["ACC"]`` and ``var_dic["CEP"]`` entries in your code: whatever you put into ``var_dic`` is what gets plotted in TensorBoard.
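For example, any extra scalar added to ``var_dic`` in ``compute_loss`` gets its own curve. A minimal sketch based on the trainer above (the ``"ERR"`` key is just an illustration, not part of the instance):

``` {.sourceCode .python}
def compute_loss(self):
    var_dic = {}
    var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())

    _, predict = torch.max(self.output.detach(), 1)
    acc = predict.eq(self.labels.squeeze().long()).cpu().sum().float() / predict.size(0)
    var_dic["ACC"] = acc
    var_dic["ERR"] = 1.0 - acc  # extra entry: appears as one more learning curve
    return loss, var_dic
```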
For learning curves:

![tb_curves](https://github.com/dingguanglei/jdit/blob/master/resources/tb_scalars.png)

For the model structure:

![tb_graphs](https://github.com/dingguanglei/jdit/blob/master/resources/tb_graphs.png)

For the dataset, you need to call ``self.watcher.embedding(data, data, label)``:

![tb_projector](https://github.com/dingguanglei/jdit/blob/master/resources/tb_projector.png)
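
A sketch of where that call lives, mirroring the trainer ``__init__`` from the example above (the extra arguments follow the instance code; treat them as illustrative):

``` {.sourceCode .python}
def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets, num_class):
    super(FashingClassTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, datasets, num_class)
    # take a batch of training samples and send them to the TensorBoard projector
    data, label = self.datasets.samples_train
    self.watcher.embedding(data, data, label, 1)
```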

- It will fetch a Fashion-MNIST dataset.
- Then it builds a resnet18 for classification.
@@ -291,13 +403,8 @@ Something like this:

def train():
    for epoch in range(nepochs):
        self._record_configs()  # record info
        self.train_epoch()
        self.valid_epoch()
        # do learning rate decay
        self._change_lr()
        # save model check point
        self._check_point()
    self.test()
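
This commit drops the explicit ``_record_configs``, ``_change_lr`` and ``_check_point`` calls from that loop: they are now triggered from ``SupTrainer.__setattr__`` whenever ``step`` and ``current_epoch`` are assigned (see the ``jdit/trainer/super.py`` diff below). A rough, simplified sketch of that hook pattern, with illustrative names rather than the real class:

``` {.sourceCode .python}
class SupTrainerSketch:
    """Illustration only, not the real jdit SupTrainer."""

    def __setattr__(self, key, value):
        object.__setattr__(self, key, value)
        # Assigning current_epoch fires the per-epoch bookkeeping,
        # so train() no longer calls it explicitly.
        if key == "current_epoch" and value != 0:
            self._change_lr("epoch", value)
            self._check_point()

    def _change_lr(self, decay_type, position):
        print("lr decay hook:", decay_type, position)

    def _check_point(self):
        print("checkpoint hook")

    def train(self, nepochs=2):
        for epoch in range(1, nepochs + 1):
            self.current_epoch = epoch  # triggers the hooks above


SupTrainerSketch().train()
```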

Every method will be rewritten by the second-level templates. It only
@@ -357,15 +464,9 @@ Up to this level everything is clear. So, inherit the

``` {.sourceCode .python}
class FashingClassTrainer(ClassificationTrainer):
    mode = "L"  # used by tensorboard display
    num_class = 10
    every_epoch_checkpoint = 20
    every_epoch_changelr = 10

    def __init__(self, logdir, nepochs, gpu_ids, net, opt, dataset):
        super(FashingClassTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, dataset)
        # to print the network on tensorboard
        self.watcher.graph(net, (4, 1, 32, 32), self.use_gpu)

    def compute_loss(self):
        var_dic = {}
2 changes: 1 addition & 1 deletion jdit/trainer/instances/fashingClassification.py
@@ -59,7 +59,7 @@ def compute_valid(self):
        return var_dic


def start_fashingClassTrainer(gpus=(), nepochs=100, run_type="debug"):
def start_fashingClassTrainer(gpus=(), nepochs=10, run_type="train"):
    """An example of fashing-mnist classification
    """
65 changes: 38 additions & 27 deletions jdit/trainer/super.py
@@ -81,31 +81,40 @@ def __setattr__(self, key, value):
        super(SupTrainer, self).__setattr__(key, value)

        if key == "step" and value != 0:
            is_change = self._change_lr("step", value)
            # is_change = self._change_lr("step", value)
            is_change = super(SupTrainer, self).__getattribute__("_change_lr")("step", value)
            if is_change:
                self._record_configs("optimizer")
                # self._record_configs("optimizer")
                super(SupTrainer, self).__getattribute__("_record_configs")("optimizer")
        elif key == "current_epoch" and value != 0:
            is_change = self._change_lr("epoch", value)
            is_change = super(SupTrainer, self).__getattribute__("_change_lr")("epoch", value)
            if is_change:
                self._record_configs("optimizer")
                self._check_point()
                self._record_configs("performance")
                super(SupTrainer, self).__getattribute__("_record_configs")("optimizer")
                super(SupTrainer, self).__getattribute__("_check_point")()
                super(SupTrainer, self).__getattribute__("_record_configs")("performance")
        elif isinstance(value, Model):
            self._models.update({key: value})
            super(SupTrainer, self).__getattribute__("_models").update({key: value})
        elif isinstance(value, Optimizer):
            self._opts.update({key: value})
            super(SupTrainer, self).__getattribute__("_opts").update({key: value})
        elif isinstance(value, DataLoadersFactory):
            self._datasets.update({key: value})
            super(SupTrainer, self).__getattribute__("_datasets").update({key: value})
        else:
            pass

    def __delattr__(self, item):
        if isinstance(item, Model):
            self._models.pop(item)
            super(SupTrainer, self).__getattribute__("_models").pop(item)
        elif isinstance(item, Optimizer):
            self._opts.pop(item)
            super(SupTrainer, self).__getattribute__("_opts").pop(item)
        elif isinstance(item, DataLoadersFactory):
            self._datasets.pop(item)
            super(SupTrainer, self).__getattribute__("_datasets").pop(item)

    def __getattribute__(self, name):
        v = super(SupTrainer, self).__getattribute__(name)
        if name == "get_data_from_batch":
            new_fc = super(SupTrainer, self).__getattribute__("_mv_device")(v)
            return new_fc
        return v

    def debug(self):
        """Debug the trainer.
@@ -209,18 +218,12 @@ def train_epoch(self, subbar_disable=False):
        """
        pass

    def __getattribute__(self, name):
        v = super().__getattribute__(name)
        if name == "get_data_from_batch":
            new_fc = self._mv_device(v)
            return new_fc
        return v

    def _mv_device(self, f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            variables = f(*args, **kwargs)
            variables = tuple(v.to(self.device) if hasattr(v, "to") else v for v in variables)
            device = super(SupTrainer, self).__getattribute__("device")
            variables = tuple(v.to(device) if hasattr(v, "to") else v for v in variables)
            return variables

        return wrapper
@@ -286,13 +289,16 @@ def _record_configs(self, configs_names=None):
        :return:
        """
        if (configs_names is None) or "model" in configs_names:
            for name, model in self._models.items():
            _models = super(SupTrainer, self).__getattribute__("_models")
            for name, model in _models.items():
                self.loger.regist_config(model, self.current_epoch, self.step, config_filename=name)
        if (configs_names is None) or "dataset" in configs_names:
            for name, dataset in self._datasets.items():
            _datasets = super(SupTrainer, self).__getattribute__("_datasets")
            for name, dataset in _datasets.items():
                self.loger.regist_config(dataset, self.current_epoch, self.step, config_filename=name)
        if (configs_names is None) or "optimizer" in configs_names:
            for name, opt in self._opts.items():
            _opts = super(SupTrainer, self).__getattribute__("_opts")
            for name, opt in _opts.items():
                self.loger.regist_config(opt, self.current_epoch, self.step, config_filename=name)
        if (configs_names is None) or "trainer" in configs_names or (configs_names is None):
            self.loger.regist_config(self, config_filename=self.__class__.__name__)
@@ -305,16 +311,21 @@ def plot_graphs_lazy(self):
        :return:
        """
        for name, model in self._models.items():
        _models = super(SupTrainer, self).__getattribute__("_models")
        for name, model in _models.items():
            self.watcher.graph_lazy(model, name)

    def _check_point(self):
        for name, model in self._models.items():
            model.is_checkpoint(name, self.current_epoch, self.logdir)
        _models = super(SupTrainer, self).__getattribute__("_models")
        current_epoch = super(SupTrainer, self).__getattribute__("current_epoch")
        logdir = super(SupTrainer, self).__getattribute__("logdir")
        for name, model in _models.items():
            model.is_checkpoint(name, current_epoch, logdir)

    def _change_lr(self, decay_type="step", position=2):
        is_change = False
        for name, opt in self._opts.items():
        _opts = super(SupTrainer, self).__getattribute__("_opts")
        for name, opt in _opts.items():
            if opt.decay_type == decay_type and opt.is_lrdecay(position):
                opt.do_lr_decay()
                is_change = True
File renamed without changes
Binary file added resources/tb_graphs.png
Binary file added resources/tb_projector.png
Binary file added resources/tb_scalars.png
