Commit c9d8c9b: redundant fix

dingguanglei committed Jan 11, 2019
1 parent 63e0e6c
Showing 7 changed files with 23 additions and 56 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -26,13 +26,13 @@ E-mail: dingguanglei.bupt@qq.com

## Install
Requires:
-```
+``` {.sourceCode .bash}
tensorboard >= 1.12.0
tensorboardX >= 1.4
pytorch >= 0.4.1
```
Install requirement.
-```
+``` {.sourceCode .bash}
pip install -r requirements.txt
```
### From source
@@ -277,4 +277,4 @@ Guide List:

* [Parallel Task](https://dingguanglei.com/jdit_parallel)

-* ......
\ No newline at end of file
+* ......
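
Note: the `{.sourceCode .bash}` form added above is pandoc's fenced-code attribute syntax, which pandoc-based site generators use to highlight the block as bash. Plain GitHub-flavored markdown would spell the same fence with a bare language tag:

```bash
pip install -r requirements.txt
```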
26 changes: 0 additions & 26 deletions jdit/trainer/classification.py
@@ -68,19 +68,6 @@ def compute_loss(self):
return loss, var_dic
"""
-var_dic = {}
-# Input: (N,C) where C = number of classes
-# Target: (N) where each value is 0≤targets[i]≤C−1
-# ground_truth = self.ground_truth.long().squeeze()
-var_dic["CEP"] = loss = CrossEntropyLoss()(self.output, self.labels.squeeze().long())
-
-_, predict = torch.max(self.output.detach(), 1)  # 0100=>1 0010=>2
-total = predict.size(0) * 1.0
-labels = self.labels.squeeze().long()
-correct = predict.eq(labels).cpu().sum().float()
-acc = correct / total
-var_dic["ACC"] = acc
-return loss, var_dic

@abstractmethod
def compute_valid(self):
@@ -112,19 +99,6 @@ def compute_valid(self):
return var_dic
"""
-var_dic = dict()
-# Input: (N,C) where C = number of classes
-# Target: (N) where each value is 0≤targets[i]≤C−1
-# ground_truth = self.ground_truth.long().squeeze()
-var_dic["CEP"] = CrossEntropyLoss()(self.output, self.labels.squeeze().long())
-
-_, predict = torch.max(self.output.detach(), 1)  # 0100=>1 0010=>2
-total = predict.size(0) * 1.0
-labels = self.labels.squeeze().long()
-correct = predict.eq(labels).cpu().sum().float()
-acc = correct / total
-var_dic["ACC"] = acc
-return var_dic

def valid_epoch(self):
avg_dic = dict()
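Both deleted bodies computed cross-entropy and top-1 accuracy straight from the logits, leaving only the abstract method stubs behind. A standalone sketch of that pattern, with made-up shapes (N=8 samples, C=10 classes):

import torch
from torch.nn import CrossEntropyLoss

output = torch.randn(8, 10)            # fake logits, shape (N, C)
labels = torch.randint(0, 10, (8,))    # integer class targets, shape (N)

loss = CrossEntropyLoss()(output, labels)
_, predict = torch.max(output.detach(), 1)   # argmax over the C axis
acc = predict.eq(labels).float().mean()      # fraction of correct predictions
var_dic = {"CEP": loss, "ACC": acc}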
2 changes: 1 addition & 1 deletion jdit/trainer/gan/generate.py
@@ -91,7 +91,7 @@ def compute_valid(self):
return var_dic

def test(self):
-self.input = Variable(torch.randn((16, *self.latent_shape))).to(self.device)
+self.input = Variable(torch.randn((self.datasets.batch_size, *self.latent_shape))).to(self.device)
self.netG.eval()
with torch.no_grad():
fake = self.netG(self.input).detach()
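The test batch of latent noise is now sized by the dataset's configured batch size rather than a hard-coded 16, so the generated grid always matches the loader. A minimal sketch of the sampling, with an assumed latent shape (note that on PyTorch >= 0.4 the `Variable` wrapper is a no-op, so a plain tensor behaves the same):

import torch

latent_shape = (100, 1, 1)   # assumed latent shape, e.g. for a DCGAN-style netG
batch_size = 64              # stands in for self.datasets.batch_size

noise = torch.randn((batch_size, *latent_shape))  # one latent vector per test image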
14 changes: 7 additions & 7 deletions jdit/trainer/gan/pix2pix.py
@@ -2,6 +2,7 @@
from abc import abstractmethod
import torch

+
class Pix2pixGanTrainer(SupGanTrainer):
d_turn = 1

@@ -19,7 +20,7 @@ def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset
"""
super(Pix2pixGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets)

-def get_data_from_batch(self,batch_data: list, device: torch.device):
+def get_data_from_batch(self, batch_data: list, device: torch.device):
input_tensor, ground_truth_tensor = batch_data[0], batch_data[1]
return input_tensor, ground_truth_tensor

@@ -113,10 +114,10 @@ def valid_epoch(self):
self.netG.eval()
self.netD.eval()
if self.fixed_input is None:
-for iteration, batch in enumerate(self.datasets.loader_test, 1):
-if isinstance(batch, list):
+for batch in self.datasets.loader_test:
+if isinstance(batch, (list, tuple)):
self.fixed_input, fixed_ground_truth = self.get_data_from_batch(batch, self.device)
-self.watcher.image(self.fixed_input, self.current_epoch, tag="Fixed/groundtruth",
+self.watcher.image(fixed_ground_truth, self.current_epoch, tag="Fixed/groundtruth",
grid_size=(6, 6),
shuffle=False)
else:
@@ -125,6 +126,7 @@
grid_size=(6, 6),
shuffle=False)
break
+
# watching the variation during training by a fixed input
with torch.no_grad():
fake = self.netG(self.fixed_input).detach()
@@ -152,10 +154,8 @@ def test(self):
self.watcher.image(fake, self.current_epoch, tag="Test/fake", grid_size=(4, 4), shuffle=False)
self.netG.train()
"""
-for index, batch in enumerate(self.datasets.loader_test, 1):
-# For test only have input without groundtruth
+for batch in self.datasets.loader_test:
self.input, _ = self.get_data_from_batch(batch, self.device)
-# self.input = batch.to(self.device) if isinstance(batch,tuple) else batch[0].to(self.device)
self.netG.eval()
with torch.no_grad():
fake = self.netG(self.input).detach()
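Both loops now iterate the test loader directly and unpack batches through get_data_from_batch, with the type check widened so loaders yielding either lists or tuples are handled. A minimal sketch of that unpacking contract, with made-up tensors (shapes are illustrative only; the unused device argument mirrors the trainer's signature):

import torch

def get_data_from_batch(batch_data: list, device: torch.device):
    input_tensor, ground_truth_tensor = batch_data[0], batch_data[1]
    return input_tensor, ground_truth_tensor

batch = [torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)]  # (input, ground_truth) pair
inp, _ = get_data_from_batch(batch, torch.device("cpu"))        # test() keeps only the input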
2 changes: 1 addition & 1 deletion jdit/trainer/instances/fashingClassification.py
@@ -48,7 +48,7 @@ def compute_loss(self):

def compute_valid(self):
var_dic = {}
var_dic["CEP"] = cep = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())
var_dic["CEP"] = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())

_, predict = torch.max(self.output.detach(), 1) # 0100=>1 0010=>2
total = predict.size(0) * 1.0
9 changes: 3 additions & 6 deletions unittests/test_instences.py
@@ -6,16 +6,13 @@


class TestInstances(TestCase):
-@staticmethod
-def test_start_fashingClassTrainer():
+def test_start_fashingClassTrainer(self):
start_fashingClassTrainer(run_type="debug")

-@staticmethod
-def test_start_fashingGenerateGanTrainer():
+def test_start_fashingGenerateGanTrainer(self):
start_fashingGenerateGanTrainer(run_type="debug")

-@staticmethod
-def test_start_cifarPix2pixGanTrainer():
+def test_start_cifarPix2pixGanTrainer(self):
start_cifarPix2pixGanTrainer(run_type="debug")

def setUp(self):
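Dropping @staticmethod turns these back into ordinary unittest instance methods, which receive self and can share state prepared in setUp (defined just below). A minimal sketch of that fixture pattern:

from unittest import TestCase

class TestExample(TestCase):
    def setUp(self):
        self.value = 42                        # fixture rebuilt before every test method

    def test_value(self):
        self.assertEqual(self.value, 42)       # assertions need self, hence no @staticmethod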
20 changes: 8 additions & 12 deletions unittests/test_optimizer.py
@@ -9,21 +9,17 @@ def setUp(self):
self.opt = Optimizer(param, "RMSprop", 0.5, 3, "step", lr=2)

def test_do_lr_decay(self):
-param = torch.nn.Linear(10, 1).parameters()
-opt = Optimizer(param, "RMSprop", 0.5, 3, "step", lr=0.999)
-opt.do_lr_decay(reset_lr=2, reset_lr_decay=0.3)
-self.assertEqual(opt.lr, 2)
-self.assertEqual(opt.lr_decay, 0.3)
-opt.do_lr_decay()
-self.assertEqual(opt.lr, 0.6)
+self.opt.do_lr_decay(reset_lr=2, reset_lr_decay=0.3)
+self.assertEqual(self.opt.lr, 2)
+self.assertEqual(self.opt.lr_decay, 0.3)
+self.opt.do_lr_decay()
+self.assertEqual(self.opt.lr, 0.6)

def test_is_lrdecay(self):
self.assert_(not self.opt.is_lrdecay(2))
self.assert_(self.opt.is_lrdecay(3))

def test_configure(self):
-param = torch.nn.Linear(10, 1).parameters()
-opt = Optimizer(param, "RMSprop", 0.5, 3, "step", lr=0.999)
-opt.do_lr_decay(reset_lr=2, reset_lr_decay=0.3)
-self.assertEqual(opt.configure['lr'], 2)
-self.assertEqual(opt.configure['lr_decay'], '0.3')
+self.opt.do_lr_decay(reset_lr=2, reset_lr_decay=0.3)
+self.assertEqual(self.opt.configure['lr'], 2)
+self.assertEqual(self.opt.configure['lr_decay'], '0.3')
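
Both tests now reuse the self.opt fixture built in setUp instead of constructing a fresh Optimizer. Pieced together from those assertions, the behavior they exercise looks roughly like this (the import path and the meaning of the positional arguments are assumptions, not confirmed by this diff):

import torch
from jdit import Optimizer  # assumed import path

params = torch.nn.Linear(10, 1).parameters()
opt = Optimizer(params, "RMSprop", 0.5, 3, "step", lr=2)  # 0.5 and 3 read as lr_decay and decay position

opt.is_lrdecay(2)                                # False: epoch 2 is not a decay point
opt.is_lrdecay(3)                                # True: "step" decay fires at epoch 3
opt.do_lr_decay(reset_lr=2, reset_lr_decay=0.3)  # overrides lr and lr_decay
opt.do_lr_decay()                                # lr becomes lr * lr_decay = 0.6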
