Skip to content

Commit

Permalink
bug_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Nov 4, 2018
1 parent 7b7b249 commit 09d7109
Show file tree
Hide file tree
Showing 13 changed files with 183 additions and 166 deletions.
10 changes: 9 additions & 1 deletion autokeras/cnn_module.py
Expand Up @@ -79,7 +79,15 @@ def final_fit(self, train_data, test_data, trainer_args=None, retrain=False):

if retrain:
graph.weighted = False
_, _1, graph = train((graph, train_data, test_data, trainer_args, None, self.metric, self.loss, self.verbose))
_, _1, graph = train((graph,
train_data,
test_data,
trainer_args,
None,
self.metric,
self.loss,
self.verbose,
self.path))
searcher.replace_model(graph, searcher.get_best_model_id())

@property
Expand Down
96 changes: 76 additions & 20 deletions autokeras/nn/model_trainer.py
@@ -1,3 +1,4 @@
import os
import abc
import sys
from copy import deepcopy
Expand All @@ -8,7 +9,7 @@
from tqdm.autonotebook import tqdm

from autokeras.constant import Constant
from autokeras.utils import EarlyStop, get_device
from autokeras.utils import get_device


class ModelTrainerBase(abc.ABC):
Expand Down Expand Up @@ -57,7 +58,7 @@ class ModelTrainer(ModelTrainerBase):
verbose: Verbosity mode.
"""

def __init__(self, model, **kwargs):
def __init__(self, model, path, **kwargs):
"""Init the ModelTrainer with `model`, `x_train`, `y_train`, `x_test`, `y_test`, `verbose`"""
super().__init__(**kwargs)
self.model = model
Expand All @@ -66,6 +67,7 @@ def __init__(self, model, **kwargs):
self.early_stop = None
self.current_epoch = 0
self.current_metric_value = 0
self.temp_model_path = os.path.join(path, 'temp_model')

def train_model(self,
max_iter_num=None,
Expand Down Expand Up @@ -98,10 +100,16 @@ def train_model(self,
test_metric_value_list.append(metric_value)
test_loss_list.append(test_loss)
decreasing = self.early_stop.on_epoch_end(test_loss)

if self.early_stop.no_improvement_count == 0:
self._save_model()

if not decreasing:
if self.verbose:
print('\nNo loss decrease after {} epochs.\n'.format(max_no_improvement_num))
self._load_model()
break

last_num = min(max_no_improvement_num, max_iter_num)
return (sum(test_loss_list[-last_num:]) / last_num,
sum(test_metric_value_list[-last_num:]) / last_num)
Expand All @@ -113,12 +121,17 @@ def _train(self):

if self.verbose:
progress_bar = tqdm(total=len(loader),
desc='Epoch-' + str(self.current_epoch) + ', Current Metric - ' + str(self.current_metric_value),
desc='Epoch-'
+ str(self.current_epoch)
+ ', Current Metric - '
+ str(self.current_metric_value),
file=sys.stdout,
leave=False,
ncols=100,
position=0,
unit=' batch')
else:
progress_bar = None

for batch_idx, (inputs, targets) in enumerate(deepcopy(loader)):
inputs, targets = inputs.to(self.device), targets.to(self.device)
Expand Down Expand Up @@ -152,6 +165,12 @@ def _test(self):
all_targets = reduce(lambda x, y: np.concatenate((x, y)), all_targets)
return test_loss, self.metric.compute(all_predicted, all_targets)

def _save_model(self):
    """Write the model's current state dict to `self.temp_model_path`."""
    current_weights = self.model.state_dict()
    torch.save(current_weights, self.temp_model_path)

def _load_model(self):
    """Load the state dict stored at `self.temp_model_path` into the model."""
    saved_weights = torch.load(self.temp_model_path)
    self.model.load_state_dict(saved_weights)


class GANModelTrainer(ModelTrainerBase):
def __init__(self,
Expand Down Expand Up @@ -183,32 +202,36 @@ def train_model(self,
self.optimizer_d = torch.optim.Adam(self.d_model.parameters())
self.optimizer_g = torch.optim.Adam(self.g_model.parameters())
if self.verbose:
pbar = tqdm(total=max_iter_num,
desc=' Model ',
file=sys.stdout,
ncols=75,
position=1,
unit=' epoch')
progress_bar = tqdm(total=max_iter_num,
desc=' Model ',
file=sys.stdout,
ncols=75,
position=1,
unit=' epoch')
else:
progress_bar = None
for epoch in range(max_iter_num):
self._train(epoch)
if self.verbose:
pbar.update(1)
progress_bar.update(1)
if self.verbose:
pbar.close()
progress_bar.close()

def _train(self, epoch):
# put model into train mode
self.d_model.train()
# TODO: why?
cp_loader = deepcopy(self.train_loader)
if self.verbose:
pbar = tqdm(total=len(cp_loader),
desc='Current Epoch',
file=sys.stdout,
leave=False,
ncols=75,
position=0,
unit=' Batch')
progress_bar = tqdm(total=len(cp_loader),
desc='Current Epoch',
file=sys.stdout,
leave=False,
ncols=75,
position=0,
unit=' Batch')
else:
progress_bar = None
real_label = 1
fake_label = 0
for batch_idx, inputs in enumerate(cp_loader):
Expand Down Expand Up @@ -241,12 +264,45 @@ def _train(self, epoch):

if self.verbose:
if batch_idx % 10 == 0:
pbar.update(10)
progress_bar.update(10)
if self.outf is not None and batch_idx % 100 == 0:
fake = self.g_model(self.sample_noise)
vutils.save_image(
fake.detach(),
'%s/fake_samples_epoch_%03d.png' % (self.outf, epoch),
normalize=True)
if self.verbose:
pbar.close()
progress_bar.close()


class EarlyStop:
    """Track a per-epoch loss and tell the caller when to stop training.

    A loss counts as an improvement only when it is lower than the best loss
    seen so far by at least `min_loss_dec`.

    Attributes:
        training_losses: Every loss passed to `on_epoch_end`, in call order.
        minimum_loss: The most recent loss that was accepted as an improvement
            (`float('inf')` until the first one is recorded).
        no_improvement_count: Number of consecutive epochs without an
            improvement; reset to 0 whenever an improvement is recorded.
    """

    def __init__(self,
                 max_no_improvement_num=Constant.MAX_NO_IMPROVEMENT_NUM,
                 min_loss_dec=Constant.MIN_LOSS_DEC):
        """Create the tracker.

        Args:
            max_no_improvement_num: Number of consecutive non-improving epochs
                that must be exceeded before stopping is armed.
            min_loss_dec: Minimum decrease of the loss that counts as an
                improvement.
        """
        super().__init__()
        self.training_losses = []
        # Start from +inf instead of None: with None, calling on_epoch_end()
        # before on_train_begin() raised a TypeError on
        # `minimum_loss - min_loss_dec`.
        self.minimum_loss = float('inf')
        self.no_improvement_count = 0
        self._max_no_improvement_num = max_no_improvement_num
        self._done = False
        self._min_loss_dec = min_loss_dec

    def on_train_begin(self):
        """Reset all tracking state so the instance can be reused for a new run."""
        self.training_losses = []
        self.no_improvement_count = 0
        self._done = False
        self.minimum_loss = float('inf')

    def on_epoch_end(self, loss):
        """Record `loss` for the finished epoch.

        Args:
            loss: Loss value of the epoch that just ended.

        Returns:
            False when stopping has already been armed (the no-improvement
            count exceeded the limit on an earlier call) and `loss` is again
            not an improvement; True otherwise.
        """
        self.training_losses.append(loss)
        no_improvement = loss > (self.minimum_loss - self._min_loss_dec)

        # Once armed, the first non-improving epoch ends training; the counter
        # is intentionally left untouched in that case.
        if self._done and no_improvement:
            return False

        if no_improvement:
            self.no_improvement_count += 1
        else:
            self.no_improvement_count = 0
            self.minimum_loss = loss

        # Arming only takes effect on a later call, so this call still
        # returns True.
        if self.no_improvement_count > self._max_no_improvement_num:
            self._done = True

        return True
5 changes: 3 additions & 2 deletions autokeras/search.py
Expand Up @@ -178,7 +178,7 @@ def search(self, train_data, test_data, timeout=60 * 60 * 24):
try:
train_results = pool.map_async(train, [(graph, train_data, test_data, self.trainer_args,
os.path.join(self.path, str(model_id) + '.png'),
self.metric, self.loss, self.verbose)])
self.metric, self.loss, self.verbose, self.path)])

# Do the search in current thread.
searched = False
Expand Down Expand Up @@ -272,11 +272,12 @@ def get_dict(self, u=None):


def train(args):
graph, train_data, test_data, trainer_args, path, metric, loss, verbose = args
graph, train_data, test_data, trainer_args, path, metric, loss, verbose, path = args
model = graph.produce_model()
# if path is not None:
# plot_model(model, to_file=path, show_shapes=True)
loss, metric_value = ModelTrainer(model=model,
path=path,
train_data=train_data,
test_data=test_data,
metric=metric,
Expand Down
35 changes: 0 additions & 35 deletions autokeras/utils.py
Expand Up @@ -9,47 +9,12 @@
import requests
import torch

from autokeras.constant import Constant


class NoImprovementError(Exception):
    """Exception that carries its description in the `message` attribute.

    NOTE(review): judging by the name, this is presumably raised when a
    search or training step fails to improve -- confirm against the callers.
    """

    def __init__(self, message):
        """Store `message` and forward it to `Exception`.

        Args:
            message: Human-readable description of the error.
        """
        # Forward to the base class so that str(exc) and exc.args contain the
        # message even when it is passed as a keyword argument; without this
        # call, NoImprovementError(message=...) had empty args and str() == ''.
        super().__init__(message)
        self.message = message


class EarlyStop:
    """Track a per-epoch loss and tell the caller when to stop training.

    A loss counts as an improvement only when it is lower than the best loss
    seen so far by at least `min_loss_dec`.

    Attributes:
        training_losses: Every loss passed to `on_epoch_end`, in call order.
        minimum_loss: The most recent loss that was accepted as an improvement
            (`float('inf')` until the first one is recorded).
    """

    def __init__(self,
                 max_no_improvement_num=Constant.MAX_NO_IMPROVEMENT_NUM,
                 min_loss_dec=Constant.MIN_LOSS_DEC):
        """Create the tracker.

        Args:
            max_no_improvement_num: Number of consecutive non-improving epochs
                that must be exceeded before stopping is armed.
            min_loss_dec: Minimum decrease of the loss that counts as an
                improvement.
        """
        super().__init__()
        self.training_losses = []
        # Start from +inf instead of None: with None, calling on_epoch_end()
        # before on_train_begin() raised a TypeError on
        # `minimum_loss - min_loss_dec`.
        self.minimum_loss = float('inf')
        # Consecutive epochs without an improvement.
        self._no_improvement_count = 0
        self._max_no_improvement_num = max_no_improvement_num
        self._done = False
        self._min_loss_dec = min_loss_dec

    def on_train_begin(self):
        """Reset all tracking state so the instance can be reused for a new run."""
        self.training_losses = []
        self._no_improvement_count = 0
        self._done = False
        self.minimum_loss = float('inf')

    def on_epoch_end(self, loss):
        """Record `loss` for the finished epoch.

        Args:
            loss: Loss value of the epoch that just ended.

        Returns:
            False when stopping has already been armed (the no-improvement
            count exceeded the limit on an earlier call) and `loss` is again
            not an improvement; True otherwise.
        """
        self.training_losses.append(loss)
        no_improvement = loss > (self.minimum_loss - self._min_loss_dec)

        # Once armed, the first non-improving epoch ends training; the counter
        # is intentionally left untouched in that case.
        if self._done and no_improvement:
            return False

        if no_improvement:
            self._no_improvement_count += 1
        else:
            self._no_improvement_count = 0
            self.minimum_loss = loss

        # Arming only takes effect on a later call, so this call still
        # returns True.
        if self._no_improvement_count > self._max_no_improvement_num:
            self._done = True

        return True


def ensure_dir(directory):
"""Create directory if it does not exist"""
if not os.path.exists(directory):
Expand Down
2 changes: 2 additions & 0 deletions tests/common.py
Expand Up @@ -9,6 +9,8 @@
StubDense, StubConcatenate, StubAdd, StubPooling
from autokeras.preprocessor import ImageDataTransformer

TEST_TEMP_DIR = 'tests/resources/temp'


def get_concat_skip_model():
graph = Graph((32, 32, 3), False)
Expand Down
4 changes: 2 additions & 2 deletions tests/image/test_dcgan.py
Expand Up @@ -4,7 +4,7 @@

from autokeras.constant import Constant
from autokeras.image.gan import DCGAN
from tests.common import clean_dir
from tests.common import clean_dir, TEST_TEMP_DIR


def mock_train(**kwargs):
Expand All @@ -19,7 +19,7 @@ def test_fit_generate(_):
Constant.SEARCH_MAX_ITER = 1
Constant.T_MIN = 0.8
Constant.DATA_AUGMENTATION = False
image_path, size = 'tests/resources/temp', 32
image_path, size = TEST_TEMP_DIR, 32
clean_dir(image_path)
dcgan = DCGAN(gen_training_result=(image_path, size))
train_x = np.random.rand(100, 32, 32, 3)
Expand Down

0 comments on commit 09d7109

Please sign in to comment.