Commit 4d95017

minor changes: (1) only save the latest model (save every 10 epochs), (2) do not evaluate OOD detection performance during the training phase, (3) fixed learning rate scheduling of supervised CSI and SupCLR
jihoontack committed Sep 2, 2020
1 parent 08ef9fb commit 4d95017
Showing 4 changed files with 16 additions and 84 deletions.
common/common.py (4 changes: 1 addition & 3 deletions)
@@ -32,10 +32,8 @@ def parse_args(default=False):
                         default=None, type=str)
     parser.add_argument('--error_step', help='Epoch steps to compute errors',
                         default=5, type=int)
-    parser.add_argument('--eval_step', help='Epoch steps to evaluate values',
-                        default=50, type=int)
     parser.add_argument('--save_step', help='Epoch steps to save models',
-                        default=100, type=int)
+                        default=10, type=int)

     ##### Training Configurations #####
     parser.add_argument('--epochs', help='Epochs',
train.py (56 changes: 7 additions & 49 deletions)
@@ -1,10 +1,9 @@
 from utils.utils import Logger
 from utils.utils import save_checkpoint
-from utils.utils import save_checkpoint_epoch
 from utils.utils import save_linear_checkpoint

 from common.train import *
-from evals import test_classifier, eval_ood_detection
+from evals import test_classifier

 if 'sup' in P.mode:
     from training.sup import setup
@@ -39,61 +38,20 @@

     model.eval()

-    if epoch % P.error_step == 0:
-        error = test_classifier(P, model, test_loader, epoch, logger=logger)
-
+    if epoch % P.save_step == 0 and P.local_rank == 0:
         if P.multi_gpu:
             save_states = model.module.state_dict()
         else:
             save_states = model.state_dict()
+        save_checkpoint(epoch, save_states, optimizer.state_dict(), logger.logdir)
+        save_linear_checkpoint(linear_optim.state_dict(), logger.logdir)

+    if epoch % P.error_step == 0 and ('sup' in P.mode):
+        error = test_classifier(P, model, test_loader, epoch, logger=logger)
+
         is_best = (best > error)
         if is_best:
             best = error

-        if P.local_rank == 0:
-            save_checkpoint(epoch, best, save_states, optimizer.state_dict(), logger.logdir, is_best)
-            save_linear_checkpoint(linear_optim.state_dict(), logger.logdir, is_best)
-            if epoch % P.save_step == 0:
-                save_checkpoint_epoch(epoch, save_states, optimizer.state_dict(), logger.logdir)
-
         logger.scalar_summary('eval/best_error', best, epoch)
         logger.log('[Epoch %3d] [Test %5.2f] [Best %5.2f]' % (epoch, error, best))
-
-    if epoch % P.eval_step == 0:
-        logger.log('[Epoch %3d Evaluation]' % epoch)
-
-        # OOD detection
-        if 'sup' in P.mode:
-            ood_scores = ['baseline']
-        else:
-            ood_scores = ['clean_norm', 'similar']
-
-        if 'simclr' in P.mode:
-            kwargs = {'simclr_aug': simclr_aug}
-        else:
-            kwargs = {}
-
-        auroc_dict = eval_ood_detection(P, model, test_loader, ood_test_loader, ood_scores, **kwargs)
-
-        if P.one_class_idx is not None:
-            mean_dict = dict()
-            for ood_score in ood_scores:
-                mean = 0
-                for ood in auroc_dict.keys():
-                    mean += auroc_dict[ood][ood_score]
-                mean_dict[ood_score] = mean / len(auroc_dict.keys())
-            auroc_dict['one_class_mean'] = mean_dict
-
-        for ood in auroc_dict.keys():
-            message = ''
-            best_auroc = 0
-            for ood_score, auroc in auroc_dict[ood].items():
-                message += '[%s %s %.4f] ' % (ood, ood_score, auroc)
-                logger.scalar_summary(f'eval_ood_{ood}/{ood_score}', auroc, epoch)
-                if auroc > best_auroc:
-                    best_auroc = auroc
-            message += '[%s %s %.4f] ' % (ood, 'best', best_auroc)
-            logger.scalar_summary(f'eval_ood_{ood}/best', best_auroc, epoch)
-            logger.log(message)
-
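Note: with the eval_step block removed, OOD detection is no longer measured during training; it has to be run separately on a saved checkpoint. A minimal sketch of such a post-training evaluation, reusing only the call pattern visible in the deleted block above (P, model, the loaders and simclr_aug are whatever `from common.train import *` provides to train.py; the checkpoint path is a made-up example, and the single-GPU case is assumed):

    # Hedged sketch (not part of this commit): run OOD detection once after
    # training, from the saved checkpoint, instead of inside the epoch loop.
    # Assumes the same setup train.py gets from common.train (single GPU);
    # the log directory below is an example, not a real path.
    import torch

    from common.train import *
    from evals import eval_ood_detection

    model.load_state_dict(torch.load('logs/example_run/last.model'))
    model.eval()

    ood_scores = ['baseline'] if 'sup' in P.mode else ['clean_norm', 'similar']
    kwargs = {'simclr_aug': simclr_aug} if 'simclr' in P.mode else {}

    auroc_dict = eval_ood_detection(P, model, test_loader, ood_test_loader,
                                    ood_scores, **kwargs)
    for ood in auroc_dict.keys():
        message = ' '.join('[%s %s %.4f]' % (ood, score, auroc)
                           for score, auroc in auroc_dict[ood].items())
        print(message)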
training/sup/sup_linear.py (9 changes: 6 additions & 3 deletions)
@@ -100,9 +100,7 @@ def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
         loss_joint.backward()
         P.joint_linear_optim.step()

-        ### optimizer learning rate scheduler ###
-        P.linear_scheduler.step(epoch - 1 + n / len(loader))
-        P.rot_scheduler.step(epoch - 1 + n / len(loader))
+        ### optimizer learning rate ###
         lr = P.linear_optim.param_groups[0]['lr']

         batch_time.update(time.time() - check)
@@ -118,6 +116,11 @@ def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
                   losses['cls'].value, losses['rot'].value))
         check = time.time()

+    ### optimizer learning rate scheduler ###
+    P.linear_scheduler.step()
+    P.rot_scheduler.step()
+    P.joint_scheduler.step()
+
     log_('[DONE] [Time %.3f] [Data %.3f] [LossC %f] [LossR %f]' %
          (batch_time.average, data_time.average,
           losses['cls'].average, losses['rot'].average))
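Note: the fix above stops stepping the schedulers with a fractional epoch on every batch and instead advances them once per epoch, after the loader loop (the joint scheduler is now stepped as well). A generic PyTorch sketch of the two patterns, using a toy model, optimizer and scheduler that are not from this repository:

    # Toy illustration (not repository code): per-batch fractional stepping
    # versus one scheduler step per epoch.
    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    num_batches = 50
    for epoch in range(1, 101):
        for n in range(num_batches):
            optimizer.zero_grad()
            # ... forward / backward on the current batch ...
            optimizer.step()
            # old pattern: scheduler.step(epoch - 1 + n / num_batches)
            # (an explicit, fractional epoch passed on every batch; the epoch
            # argument to step() is deprecated in recent PyTorch versions)
        scheduler.step()  # new pattern: advance the schedule once per epoch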
utils/utils.py (31 changes: 2 additions & 29 deletions)
@@ -125,36 +125,13 @@ def load_checkpoint(logdir, mode='last'):
     return model_state, optim_state, cfg


-def save_checkpoint(epoch, best, model_state, optim_state, logdir, is_best):
+def save_checkpoint(epoch, model_state, optim_state, logdir):
     last_model = os.path.join(logdir, 'last.model')
-    best_model = os.path.join(logdir, 'best.model')
     last_optim = os.path.join(logdir, 'last.optim')
-    best_optim = os.path.join(logdir, 'best.optim')
     last_config = os.path.join(logdir, 'last.config')
-    best_config = os.path.join(logdir, 'best.config')

     opt = {
         'epoch': epoch,
-        'best': best
     }
     torch.save(model_state, last_model)
     torch.save(optim_state, last_optim)
-    with open(last_config, 'wb') as handle:
-        pickle.dump(opt, handle, protocol=pickle.HIGHEST_PROTOCOL)
-    if is_best:
-        shutil.copyfile(last_model, best_model)
-        shutil.copyfile(last_optim, best_optim)
-        shutil.copyfile(last_config, best_config)
-
-
-def save_checkpoint_epoch(epoch, model_state, optim_state, logdir):
-    last_model = os.path.join(logdir, f'epoch{epoch}.model')
-    last_optim = os.path.join(logdir, f'epoch{epoch}.optim')
-    last_config = os.path.join(logdir, f'epoch{epoch}.config')
-
-    opt = {
-        'epoch': epoch,
-        'best': None
-    }
-    torch.save(model_state, last_model)
-    torch.save(optim_state, last_optim)
@@ -178,14 +155,10 @@ def load_linear_checkpoint(logdir, mode='last'):
     return None


-def save_linear_checkpoint(linear_optim_state, logdir, is_best):
+def save_linear_checkpoint(linear_optim_state, logdir):
     last_linear_optim = os.path.join(logdir, 'last.linear_optim')
-    best_linear_optim = os.path.join(logdir, 'best.linear_optim')
     torch.save(linear_optim_state, last_linear_optim)

-    if is_best:
-        shutil.copyfile(last_linear_optim, best_linear_optim)
-

 def set_random_seed(seed):
     random.seed(seed)
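Note: after this change a run keeps a single rolling checkpoint (last.model, last.optim) that is overwritten every save_step epochs, instead of best.* copies and per-epoch snapshots. A minimal sketch of reading it back outside the training script; the log directory is a made-up example, and model and optimizer are assumed to be constructed the same way as during training. The repository's own load_checkpoint(logdir, mode='last'), whose tail is visible at the top of the previous hunk, returns the corresponding model_state and optim_state for resuming.

    # Hedged sketch (not repository code): reload the rolling checkpoint
    # written by the simplified save_checkpoint(); model and optimizer are
    # assumed to be defined elsewhere, and the path is an example.
    import os

    import torch

    logdir = 'logs/example_run'

    model_state = torch.load(os.path.join(logdir, 'last.model'))
    optim_state = torch.load(os.path.join(logdir, 'last.optim'))

    model.load_state_dict(model_state)
    optimizer.load_state_dict(optim_state)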
