Commit
Fix bugs in instances
dingguanglei committed Dec 22, 2019
1 parent 6fbc4bf commit 4f98733
Showing 8 changed files with 35 additions and 47 deletions.
32 changes: 15 additions & 17 deletions README.md

````diff
@@ -34,6 +34,12 @@ Install requirement.
 ``` {.sourceCode .bash}
 pip install -r requirements.txt
 ```
+
+### From pip
+``` {.sourceCode .bash}
+pip install jdit
+```
+
 ### From source
 This method is recommended, because you can keep the newest version.
 1. Clone from github
@@ -50,13 +56,10 @@ This method is recommended, because you can keep the newest version.
 3. Install
 You will find packages in `jdit/dist/`. Use pip to install.
 ``` {.sourceCode .bash}
-pip install dist/jdit-0.x.0-py3-none-any.whl
+pip install dist/jdit-x.y.z-py3-none-any.whl
 ```
-
-### From pip
-``` {.sourceCode .bash}
-pip install jdit
-```

 ## Quick start
````
````diff
@@ -109,23 +112,15 @@ class FashingClassTrainer(ClassificationTrainer):
     def compute_loss(self):
         var_dic = {}
-        var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.ground_truth.squeeze().long())
         _, predict = torch.max(self.output.detach(), 1)  # 0100=>1 0010=>2
         total = predict.size(0) * 1.0
         labels = self.ground_truth.squeeze().long()
         correct = predict.eq(labels).cpu().sum().float()
         acc = correct / total
         var_dic["ACC"] = acc
+        var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, labels)
         return loss, var_dic

     def compute_valid(self):
-        var_dic = {}
-        var_dic["CEP"] = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())
+        _, var_dic = self.compute_loss()
+        labels = self.ground_truth.squeeze().long()
         _, predict = torch.max(self.output.detach(), 1)  # 0100=>1 0010=>2
-        total = predict.size(0) * 1.0
-        labels = self.labels.squeeze().long()
+        total = predict.size(0)
         correct = predict.eq(labels).cpu().sum().float()
         acc = correct / total
         var_dic["ACC"] = acc
````
````diff
@@ -134,6 +129,7 @@ class FashingClassTrainer(ClassificationTrainer):

 def start_fashingClassTrainer(gpus=(), nepochs=10, run_type="train"):
     """" An example of fashing-mnist classification
     """
     num_class = 10
     depth = 32
@@ -165,8 +161,10 @@ def start_fashingClassTrainer(gpus=(), nepochs=10, run_type="train"):
         Trainer.train()
+    elif run_type == "debug":
+        Trainer.debug()


 if __name__ == '__main__':
     start_fashingClassTrainer()
 ```

 Then you will see something like this:
````
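Given the `run_type` switch above, the quick-start example can be driven in either mode. A usage sketch (assuming `jdit` is installed and the FashionMNIST download is available; this mirrors the pattern the unit tests below use):

```python
from jdit.trainer.instances import start_fashingClassTrainer

start_fashingClassTrainer(gpus=(), nepochs=10, run_type="train")  # full training run
start_fashingClassTrainer(run_type="debug")                       # quick smoke test
```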
2 changes: 1 addition & 1 deletion jdit/trainer/instances/__init__.py

````diff
@@ -2,7 +2,7 @@
 from .fashingGenerateGan import FashingGenerateGenerateGanTrainer, start_fashingGenerateGanTrainer
 from .cifarPix2pixGan import start_cifarPix2pixGanTrainer
 from .fashionClassParallelTrainer import start_fashingClassPrarallelTrainer
-from .fashingAutoencoder import start_fashingAotoencoderTrainer, FashingAutoEncoderTrainer
+from .fashingAutoencoder import FashingAutoEncoderTrainer, start_fashingAotoencoderTrainer
 __all__ = ['FashingClassTrainer', 'start_fashingClassTrainer',
            'FashingGenerateGenerateGanTrainer', 'start_fashingGenerateGanTrainer',
            'cifarPix2pixGan', 'start_cifarPix2pixGanTrainer', 'start_fashingClassPrarallelTrainer',
````
3 changes: 1 addition & 2 deletions jdit/trainer/instances/fashingAutoencoder.py

````diff
@@ -9,9 +9,8 @@


 class SimpleModel(nn.Module):
-    def __init__(self, depth=64, num_class=10):
+    def __init__(self, depth=32):
         super(SimpleModel, self).__init__()
-        self.num_class = num_class
         self.layer1 = nn.Conv2d(1, depth, 3, 1, 1)
         self.layer2 = nn.Conv2d(depth, depth * 2, 4, 2, 1)
         self.layer3 = nn.Conv2d(depth * 2, depth * 4, 4, 2, 1)
````
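The constructor now defaults to `depth=32` and drops the unused `num_class` argument. A quick shape check of the three layers shown (a sketch only; the rest of the model is hidden behind the fold, and the 1×28×28 input assumes FashionMNIST):

```python
import torch
import torch.nn as nn

depth = 32
layer1 = nn.Conv2d(1, depth, 3, 1, 1)              # 1x28x28  -> 32x28x28
layer2 = nn.Conv2d(depth, depth * 2, 4, 2, 1)      # 32x28x28 -> 64x14x14
layer3 = nn.Conv2d(depth * 2, depth * 4, 4, 2, 1)  # 64x14x14 -> 128x7x7

x = torch.randn(1, 1, 28, 28)  # dummy FashionMNIST-sized batch
print(layer3(layer2(layer1(x))).shape)  # torch.Size([1, 128, 7, 7])
```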
16 changes: 4 additions & 12 deletions jdit/trainer/instances/fashingClassification.py

````diff
@@ -36,23 +36,15 @@ def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets, num_class):

     def compute_loss(self):
         var_dic = {}
-        var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.ground_truth.squeeze().long())
-
         _, predict = torch.max(self.output.detach(), 1)  # 0100=>1 0010=>2
         total = predict.size(0) * 1.0
         labels = self.ground_truth.squeeze().long()
         correct = predict.eq(labels).cpu().sum().float()
         acc = correct / total
         var_dic["ACC"] = acc
+        var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, labels)
         return loss, var_dic

     def compute_valid(self):
-        var_dic = {}
-        var_dic["CEP"] = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())
-
+        _, var_dic = self.compute_loss()
+        labels = self.ground_truth.squeeze().long()
         _, predict = torch.max(self.output.detach(), 1)  # 0100=>1 0010=>2
-        total = predict.size(0) * 1.0
-        labels = self.labels.squeeze().long()
+        total = predict.size(0)
         correct = predict.eq(labels).cpu().sum().float()
         acc = correct / total
         var_dic["ACC"] = acc
````
12 changes: 2 additions & 10 deletions jdit/trainer/instances/fashionClassParallelTrainer.py

````diff
@@ -46,15 +46,7 @@ def compute_loss(self):
         return loss, var_dic

     def compute_valid(self):
-        var_dic = {}
-        var_dic["CEP"] = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())
-
-        _, predict = torch.max(self.output.detach(), 1)  # 0100=>1 0010=>2
-        total = predict.size(0) * 1.0
-        labels = self.labels.squeeze().long()
-        correct = predict.eq(labels).cpu().sum().float()
-        acc = correct / total
-        var_dic["ACC"] = acc
+        _,var_dic = self.compute_loss()
         return var_dic


@@ -102,7 +94,7 @@ def trainerParallel():
     return tp


-def start_fashingClassPrarallelTrainer():
+def start_fashingClassPrarallelTrainer(run_type="debug"):
     tp = trainerParallel()
     tp.train()
````
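With the new default, the parallel example can be launched exactly as the added unit test does; note that in the lines shown, `run_type` is accepted but not visibly forwarded to the trainers. A usage sketch:

```python
from jdit.trainer.instances import start_fashingClassPrarallelTrainer

start_fashingClassPrarallelTrainer()  # run_type defaults to "debug"
```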
7 changes: 4 additions & 3 deletions jdit/trainer/single/sup_single.py

````diff
@@ -21,9 +21,10 @@ def __init__(self, logdir, nepochs, gpu_ids_abs, net: Model, opt: Optimizer, dat
         self.net = net
         self.opt = opt
         self.datasets = datasets
-        self.fake = None
+        self.fixed_input = None
+        self.input = None
         self.output = None
         self.ground_truth = None

     def train_epoch(self, subbar_disable=False):
         for iteration, batch in tqdm(enumerate(self.datasets.loader_train, 1), unit="step", disable=subbar_disable):
@@ -108,7 +109,7 @@ def _watch_images(self, tag: str, grid_size: tuple = (3, 3), shuffle=False, save

     def compute_loss(self) -> (torch.Tensor, dict):
         """ Rewrite this method to compute your own loss Discriminator.
+        Use self.input, self.output and self.ground_truth to compute loss.
         You should return a **loss** for the first position.
         You can return a ``dict`` of loss that you want to visualize on the second position.like
@@ -126,7 +127,7 @@ def compute_loss(self) -> (torch.Tensor, dict):

     def compute_valid(self) -> dict:
         """ Rewrite this method to compute your validation values.
+        Use self.input, self.output and self.ground_truth to compute valid loss.
         You can return a ``dict`` of validation values that you want to visualize.

         Example::
````
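Taken together, the renamed placeholders and the updated docstrings spell out the contract for subclasses: populate `self.input`, `self.output`, and `self.ground_truth`, then override the two compute methods. A minimal sketch of such a subclass (illustrative only; the class name `SupSingleModelTrainer` and `MSELoss` as the criterion are assumptions, not shown in the diff):

```python
import torch
import torch.nn as nn
# Assumed import path and class name; the diff shows only the file, not the class line.
from jdit.trainer.single.sup_single import SupSingleModelTrainer

class MySupTrainer(SupSingleModelTrainer):
    def compute_loss(self) -> (torch.Tensor, dict):
        # Assumption: a plain supervised regression loss over the placeholders.
        loss = nn.MSELoss()(self.output, self.ground_truth)
        return loss, {"MSE": loss}

    def compute_valid(self) -> dict:
        # Reuse compute_loss, mirroring the pattern the instance trainers now follow.
        _, var_dic = self.compute_loss()
        return var_dic
```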
2 changes: 1 addition & 1 deletion setup.py

````diff
@@ -2,7 +2,7 @@

 setup(
     name="jdit",  # name used on PyPI, by pip or easy_install, and for the generated egg file
-    version="0.1.1",
+    version="0.1.3",
     author="Guanglei Ding",
     author_email="dingguanglei.bupt@qq.com",
     maintainer='Guanglei Ding',
````
8 changes: 7 additions & 1 deletion unittests/test_instences.py

````diff
@@ -1,6 +1,6 @@
 from unittest import TestCase
 from jdit.trainer.instances import start_fashingClassTrainer, start_fashingGenerateGanTrainer, \
-    start_cifarPix2pixGanTrainer
+    start_cifarPix2pixGanTrainer, start_fashingAotoencoderTrainer, start_fashingClassPrarallelTrainer
 import shutil
 import os
@@ -15,6 +15,12 @@ def test_start_fashingGenerateGanTrainer(self):
     def test_start_cifarPix2pixGanTrainer(self):
         start_cifarPix2pixGanTrainer(run_type="debug")

+    def test_start_fashingAotoencoderTrainer(self):
+        start_fashingAotoencoderTrainer(run_type="debug")
+
+    def test_start_fashingClassPrarallelTrainer(self):
+        start_fashingClassPrarallelTrainer()
+
     def setUp(self):
         dir = "log_debug/"
         if os._exists(dir):
````
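One caveat in the fixture shown above: `os._exists` is a private `os` helper that reports whether a *name* exists in the `os` module's namespace, not whether a path exists on disk, so the `log_debug/` cleanup likely never runs. A sketch of the presumably intended check:

```python
import os
import shutil

def clean_log_dir(path="log_debug/"):
    # os.path.exists checks the filesystem; os._exists(path) is always False for a path.
    if os.path.exists(path):
        shutil.rmtree(path)
```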
