Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
marsggbo committed May 15, 2023
1 parent 5983bb9 commit 744ee00
Show file tree
Hide file tree
Showing 4 changed files with 1,668 additions and 32 deletions.
100 changes: 71 additions & 29 deletions examples/DARTS_CIFAR10.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,7 @@
}
],
"source": [
"# !pip install hyperbox==1.3.2\n",
"!pip uninstall -y hyperbox\n",
"!pip install git+https://github.com/marsggbo/hyperbox.git\n",
"!pip install pytorch-lightning==1.8.6"
"!pip install git+https://github.com/marsggbo/hyperbox.git"
]
},
{
Expand Down Expand Up @@ -315,7 +312,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
Expand All @@ -328,8 +325,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[36m[2023-02-01 04:22:20]\u001b[0m \u001b[32m[INFO]\u001b[0m \u001b[31m[/usr/local/lib/python3.8/dist-packages/hyperbox/utils/logger.py:27 (hyperbox.utils.logger)]\u001b[0m Logger is configured: <loguru.logger handlers=[(id=1, level=20, sink=stderr), (id=2, level=20, sink='/content/exp.log')]> 139696350142672\n",
"\u001b[36m[2023-02-01 04:22:20]\u001b[0m \u001b[32m[INFO]\u001b[0m \u001b[31m[/usr/local/lib/python3.8/dist-packages/hyperbox/utils/logger.py:27 (hyperbox.utils.logger)]\u001b[0m Logger is configured: <loguru.logger handlers=[(id=3, level=20, sink=stderr), (id=4, level=20, sink='/content/exp.log')]> 139696350142672\n"
"/datasets/xihe/miniconda3/envs/colossal/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
Expand All @@ -342,11 +339,10 @@
"import torch.optim as optim\n",
"import torchvision\n",
"\n",
"from torch.optim.lr_scheduler import StepLR\n",
"from torchvision import datasets, transforms\n",
"\n",
"from hyperbox.mutables import spaces, ops\n",
"from hyperbox.mutator import DartsMutator\n",
"from hyperbox.mutator import DartsMutator, RandomMutator, OnehotMutator\n",
"from hyperbox.networks.base_nas_network import BaseNASNetwork\n"
]
},
Expand Down Expand Up @@ -394,7 +390,7 @@
" transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
"])\n",
"\n",
"batch_size = 64\n",
"batch_size = 32\n",
"\n",
"all_train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)\n",
"length_set = len(all_train_set)\n",
Expand Down Expand Up @@ -449,8 +445,6 @@
"import numpy as np\n",
"\n",
"# functions to show an image\n",
"\n",
"\n",
"def imshow(img):\n",
" img = img / 2 + 0.5 # unnormalize\n",
" npimg = img.numpy()\n",
Expand All @@ -469,12 +463,13 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "YA1FtaOqXcFe"
},
"source": [
"# Define a Supernet"
"# Define a Supernet and Search Strategy\n"
]
},
{
Expand All @@ -497,8 +492,8 @@
}
],
"source": [
"from hyperbox.networks.resnet import resnet18, resnet50\n",
"from torchvision import models"
"from hyperbox.networks.mobilenet.mobile_net import MobileNet\n",
"from hyperbox.mutator import DartsMutator, RandomMutator, OnehotMutator"
]
},
{
Expand All @@ -521,9 +516,25 @@
}
],
"source": [
"from hyperbox.mutator import DartsMutator, RandomMutator, OnehotMutator\n",
"net = resnet18(ratios=[0.5, 0.8, 1], num_classes=10)\n",
"# net = models.resnet50()\n",
"\n",
"op_list=[\n",
" '3x3_MBConv6',\n",
" '3x3_MBConv7',\n",
" '3x3_MBConv8',\n",
" '5x5_MBConv3',\n",
" '5x5_MBConv6',\n",
" '7x7_MBConv3',\n",
" '7x7_MBConv6',\n",
"]\n",
"net = MobileNet(\n",
" first_stride=1,\n",
" op_list=op_list,\n",
" stride_stages=[1,1,1,2,1,2,1],\n",
" width_stages=[32,64,96,192,256,320,640],\n",
" classes=10,\n",
" dropout_rate=0.,\n",
" n_cell_stages=[1,2,3,4,3,3,1]\n",
")\n",
"# dm = DartsMutator(net)\n",
"# dm = RandomMutator(net)\n",
"dm = OnehotMutator(net)\n",
Expand All @@ -549,8 +560,10 @@
"outputs": [],
"source": [
"w_opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
"a_opt = torch.optim.Adam(dm.parameters(), lr=3e-4, weight_decay=1e-3)\n",
"# a_opt = None"
"if not isinstance(dm, RandomMutator):\n",
" a_opt = torch.optim.Adam(dm.parameters(), lr=3e-4, weight_decay=1e-3)\n",
"else:\n",
" a_opt = None"
]
},
{
Expand Down Expand Up @@ -1084,16 +1097,19 @@
"model = net.to(device)\n",
"mutator = dm.to(device)\n",
"is_rolled = False\n",
"for epoch in range(1, 50):\n",
"search_epochs = 50\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(w_opt, T_max=60)\n",
"for epoch in range(1, search_epochs):\n",
" train(train_loader, val_loader, model, mutator, criterion, w_opt, a_opt, is_rolled, device, epoch)\n",
" mask = mutator.export()\n",
" mutator.sample_by_mask(mask)\n",
" # arch = model.arch\n",
" val_loss, val_acc = validate(test_loader, model, criterion, device, verbose=True)\n",
" print(f\"acc={val_acc} loss={val_loss}\")\n",
" # print(f\"{arch} acc={val_acc} loss={val_loss}\")\n",
" torch.save(model.state_dict(), 'resnet18_supernet.pt')\n",
" torch.save(mutator.state_dict(), 'resnet18_darts.pt')"
" torch.save(model.state_dict(), 'mbnet_supernet.pt')\n",
" torch.save(mutator.state_dict(), 'mbnet_darts.pt')\n",
" scheduler.step()"
]
},
{
Expand All @@ -1114,6 +1130,24 @@
"# Export the best model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def weights2opname(weights, op_list):\n",
" \"\"\"\n",
" 将操作权重转化为操作名称\n",
" \"\"\"\n",
" opname_list = {}\n",
" for layer, weight in weights.items():\n",
" max_idx = weight.float().argmax().item()\n",
" opname = op_list[max_idx]\n",
" opname_list[layer] = opname\n",
" return opname_list"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -1137,7 +1171,8 @@
],
"source": [
"mask = mutator.export()\n",
"subnet = super(model.__class__, model).build_subnet(mask).to(device)\n",
"print(weights2opname(mask, op_list))\n",
"subnet = model.build_subnet(mask).to(device)\n",
"val_loss, val_acc = validate(test_loader, model, criterion, device, verbose=True)\n",
"# print(subnet)"
]
Expand Down Expand Up @@ -1518,15 +1553,14 @@
}
],
"source": [
"lr = 0.025 if search_epochs == 1 else 0.01\n",
"# w_opt = torch.optim.Adam(subnet.parameters(), lr=0.001, weight_decay=5e-4)\n",
"w_opt = torch.optim.SGD(subnet.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)\n",
"w_opt = torch.optim.SGD(subnet.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)\n",
"for epoch in range(1, 30):\n",
" finetune(all_train_loader, subnet, criterion, w_opt, device, epoch)\n",
" # arch = model.arch\n",
" val_loss, val_acc = validate(test_loader, subnet, criterion, device, verbose=True)\n",
" print(f\"acc={val_acc} loss={val_loss}\")\n",
" # print(f\"{arch} acc={val_acc} loss={val_loss}\")\n",
" torch.save(model.state_dict(), 'resnet18_subnet.pt')"
" torch.save(model.state_dict(), 'mbnet_subnet.pt')"
]
},
{
Expand Down Expand Up @@ -1576,8 +1610,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.7.13"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"vscode": {
"interpreter": {
Expand Down

0 comments on commit 744ee00

Please sign in to comment.