Skip to content

Commit

Permalink
a few updates on risk curve
Browse files Browse the repository at this point in the history
  • Loading branch information
lilianweng committed Mar 21, 2019
1 parent c2293e6 commit 7593d04
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions risk_curve.py
Expand Up @@ -24,7 +24,7 @@ def __init__(self, loss_type, max_epochs, n_train_sample):
logging.info(f"critical_n_units: {critical_n_units}")
self.n_units_to_test = sorted(set(
list(range(critical_n_units - 7, critical_n_units + 4)) +
list(range(5, 105, 10))
list(range(5, 55, 5)) + list(range(50, 105, 10)) + [120, 150, 200]
))
logging.info(f"n_units_to_test: {self.n_units_to_test}")

Expand All @@ -43,7 +43,7 @@ def run(self):
'--loss-type', str(self.loss_type),
]

if old_n_units and total_params < self.n_train_sample * 10:
if old_n_units: # and total_params < self.n_train_sample * 10:
args.extend(['--old-n-units', str(old_n_units)])

proc = subprocess.run(
Expand Down Expand Up @@ -78,5 +78,5 @@ def plot(self):


if __name__ == '__main__':
exp = NewRiskCurveExperiment(loss_type='mse', max_epochs=500, n_train_sample=2500)
exp = NewRiskCurveExperiment(loss_type='mse', max_epochs=500, n_train_sample=4000)
exp.run()
6 changes: 3 additions & 3 deletions risk_curve_evaluate_model.py
Expand Up @@ -108,12 +108,12 @@ def report_performance(sess, n_units, old_n_units, max_epochs, loss_type, lr, ba
@click.option('--old-n-units', default=None, type=int, help="")
@click.option('--loss-type', default='mse', type=str, help="type of loss func.")
@click.option('--max-epochs', default=500, type=int, help="num. training epochs.")
@click.option('--n-train-samples', default=2500, type=int, help="num. training samples")
def main(n_units=1, old_n_units=None, loss_type='mse', max_epochs=500, n_train_samples=2500):
@click.option('--n-train-samples', default=4000, type=int, help="num. training samples")
def main(n_units=1, old_n_units=None, loss_type='mse', max_epochs=500, n_train_samples=4000):
assert old_n_units is None or old_n_units < n_units
logging.info(f"n_units:{n_units} max_epochs:{max_epochs}")
sess = make_session()
lr = 0.005
lr = 0.001
batch_size = 128

epoch, step, train_loss, train_acc, eval_loss, eval_acc = report_performance(
Expand Down

0 comments on commit 7593d04

Please sign in to comment.