Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Argparse support added. #188

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 21 additions & 10 deletions examples/src/tasks/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms as T

from gradsflow import AutoImageClassifier
from gradsflow.data.common import random_split_dataset

# Replace dataloaders with your custom dataset and you are all set to train your model
image_size = (64, 64)
batch_size = 4

to_rgb = lambda x: x.convert("RGB")
def main(image_size=(64, 64), batch_size=4, max_epochs=5, optimization_metric="train_loss", max_steps=1, n_trials=1):
# Replace dataloaders with your custom dataset and you are all set to train your model

# TODO: Add argument parser
if __name__ == "__main__":
to_rgb = lambda x: x.convert("RGB")
augs = T.Compose([to_rgb, T.AutoAugment(), T.Resize(image_size), T.ToTensor()])
data = torchvision.datasets.CIFAR10("~/data", download=True, transform=augs)
train_data, val_data = random_split_dataset(data, 0.01)
Expand All @@ -39,11 +38,23 @@
train_dataloader=train_dl,
val_dataloader=val_dl,
num_classes=num_classes,
max_epochs=5,
optimization_metric="train_loss",
max_steps=1,
n_trials=1,
max_epochs=max_epochs,
optimization_metric=optimization_metric,
max_steps=max_steps,
n_trials=n_trials,
)
print("AutoImageClassifier initialised!")

model.hp_tune()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="AutoImageClassifier")
parser.add_argument("--image_size", type=tuple, default=(64, 64), help="image size")
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
parser.add_argument("--max_epochs", type=int, default=5, help="max epochs")
parser.add_argument("--optimization_metric", type=str, default="train_loss", help="optimization metric")
parser.add_argument("--max_steps", type=int, default=1, help="max steps")
parser.add_argument("--n_trials", type=int, default=1, help="number of trials")
args = parser.parse_args()
main(**vars(args))