Skip to content

Commit

Permalink
update pbt example to master (#2515)
Browse files Browse the repository at this point in the history
  • Loading branch information
colorjam committed Jun 4, 2020
1 parent 5a911b3 commit 131fb2c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
14 changes: 5 additions & 9 deletions examples/trials/mnist-pbt-tuner-pytorch/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
logger = logging.getLogger('mnist_pbt_tuner_pytorch_AutoML')

class Net(nn.Module):
def __init__(self, hidden_size):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, hidden_size)
self.fc2 = nn.Linear(hidden_size, 10)
self.fc1 = nn.Linear(4*4*50, 512)
self.fc2 = nn.Linear(512, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
Expand Down Expand Up @@ -104,9 +104,7 @@ def main(args):
])),
batch_size=1000, shuffle=True, **kwargs)

hidden_size = args['hidden_size']

model = Net(hidden_size=hidden_size).to(device)
model = Net().to(device)

save_checkpoint_dir = args['save_checkpoint_dir']
save_checkpoint_path = os.path.join(save_checkpoint_dir, 'model.pth')
Expand Down Expand Up @@ -146,11 +144,9 @@ def get_params():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument("--data_dir", type=str,
default='./tmp/pytorch/mnist/input_data', help="data directory")
default='/tmp/pytorch/mnist/input_data', help="data directory")
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
help='hidden layer size (default: 512)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
Expand Down
1 change: 0 additions & 1 deletion examples/trials/mnist-pbt-tuner-pytorch/search_space.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
{
"batch_size": {"_type":"choice", "_value": [16, 32, 64, 128]},
"hidden_size":{"_type":"choice","_value":[128, 256, 512, 1024]},
"lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]},
"momentum":{"_type":"uniform","_value":[0, 1]}
}

0 comments on commit 131fb2c

Please sign in to comment.