Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
marsggbo committed Jan 19, 2023
1 parent 077c28c commit a13c3d6
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions examples/RandomNAS_MNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "3rm2g1DCTXG6"
},
"source": [
"# Install Dependencies and Import Modules"
"# Install Dependencies and Import Modules\n",
"\n",
"You shoud restart the runtime after running the following pip commands."
]
},
{
Expand Down Expand Up @@ -174,7 +177,7 @@
],
"source": [
"!pip install hyperbox==1.3.1\n",
"!pip install rich wandb loguru"
"!pip install pytorch-lightning==1.8."
]
},
{
Expand Down Expand Up @@ -564,7 +567,7 @@
" loss = F.nll_loss(output, target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" if batch_idx % args.log_interval == 0:\n",
" if (batch_idx + 1) % args.log_interval == 0:\n",
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
" epoch, batch_idx * len(data), len(train_loader.dataset),\n",
" 100. * batch_idx / len(train_loader), loss.item()))\n",
Expand Down Expand Up @@ -633,7 +636,14 @@
" 'conv2': F.one_hot(index2, num_classes=3).view(-1).bool()\n",
" }\n",
" search_space.append(arch)\n",
"print(f\"Search space includes {len(search_space)} candidate models.\")"
"print(f\"Search space includes {len(search_space)} candidate models.\")\n",
"\n",
"def mask_to_arch_str(mask: dict):\n",
" conv_names = np.array(['conv3x3', 'conv5x5', 'conv7x7'])\n",
" arch = ''\n",
" for key, one_hot_mask in mask.items():\n",
" arch += f\"{conv_names[one_hot_mask][0]}, \"\n",
" return arch"
]
},
{
Expand Down

0 comments on commit a13c3d6

Please sign in to comment.