This repository has been archived by the owner on Oct 13, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from georgianpartners/6_local_training
6 Local Training
- Loading branch information
Showing
19 changed files
with
612 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[run] | ||
omit = tests/*, setup.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
FROM continuumio/miniconda3 | ||
|
||
ADD executor.sh /home | ||
WORKDIR /home | ||
|
||
ENTRYPOINT ["sh", "executor.sh"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
mkdir project | ||
cd project | ||
|
||
git clone https://$OAUTH_TOKEN:x-oauth-basic@$GIT_URL . | ||
git checkout $COMMIT_SHA | ||
|
||
conda env create -f environment.yml | ||
|
||
conda run -n hydra $PREFIX_PARAMS python3 $MODEL_PATH |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
DIR="$( dirname "${BASH_SOURCE[0]}" )" | ||
|
||
# Add random Hash | ||
LOG_NAME=$(date +'%Y_%m_%d_%H_%M_%S') | ||
|
||
cd $DIR | ||
docker build -t hydra_image . | ||
|
||
docker run \ | ||
-e GIT_URL=$1 \ | ||
-e COMMIT_SHA=$2 \ | ||
-e OAUTH_TOKEN=$3 \ | ||
-e MODEL_PATH=$4 \ | ||
-e PREFIX_PARAMS=$5 \ | ||
hydra_image:latest 2>&1 | tee ${LOG_NAME}.log | ||
|
||
# Move Log file to where the program is being called | ||
cd - | ||
mkdir -p tmp/hydra | ||
mv ${DIR}/${LOG_NAME}.log tmp/hydra/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,44 @@ | ||
import os | ||
import click | ||
from hydra.utils import * | ||
from hydra.cloud.local_platform import LocalPlatform | ||
from hydra.cloud.fast_local_platform import FastLocalPlatform | ||
from hydra.version import __version__ | ||
|
||
|
||
@click.group() | ||
@click.version_option(__version__) | ||
def cli(): | ||
pass | ||
|
||
@click.command() | ||
@click.argument('name') | ||
def hello(name): | ||
click.echo('Hello %s!' % name) | ||
|
||
@cli.command() | ||
@click.option('--project_name') | ||
@click.option('--model_name') | ||
@click.option('--cpu') | ||
@click.option('--memory') | ||
@click.option('--options') | ||
def train(project_name, model_name, cpu, memory, options): | ||
click.echo("This is the training command") | ||
@click.option('-m', '--model_path', required=True, type=str) | ||
@click.option('-c', '--cpu', default=16, type=click.IntRange(0, 128), help='Number of CPU cores required') | ||
@click.option('-r', '--memory', default=8, type=click.IntRange(0, 128), help='GB of RAM required') | ||
@click.option('--cloud', default='local', required=True, type=click.Choice(['fast_local','local', 'aws', 'gcp', 'azure'], case_sensitive=False)) | ||
@click.option('--github_token', envvar='GITHUB_TOKEN') # Takes either an option or environment var | ||
@click.option('-o', '--options', default='{}', type=str, help='Environmental variables for the script') | ||
def train(model_path, cpu, memory, github_token, cloud, options): | ||
prefix_params = json_to_string(options) | ||
|
||
if cloud == 'fast_local': | ||
platform = FastLocalPlatform(model_path, prefix_params) | ||
platform.train() | ||
|
||
return 0 | ||
|
||
check_repo(github_token) | ||
git_url = get_repo_url() | ||
commit_sha = get_commit_sha() | ||
|
||
if cloud == 'local': | ||
platform = LocalPlatform(model_path, prefix_params, git_url, commit_sha, github_token) | ||
platform.train() | ||
|
||
return 0 | ||
|
||
raise Exception("Reached parts of Hydra that are not yet implemented.") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
|
||
|
||
class AbstractPlatform(): | ||
def __init__(self, model_path, prefix_params): | ||
self.model_path = model_path | ||
self.prefix_params = prefix_params | ||
|
||
def train(self): | ||
raise Exception("Not Implemented: Please implement this function in the subclass.") | ||
|
||
def serve(self): | ||
raise Exception("Not Implemented: Please implement this function in the subclass.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import os | ||
from hydra.cloud.abstract_platform import AbstractPlatform | ||
|
||
class FastLocalPlatform(AbstractPlatform): | ||
def __init__(self, model_path, prefix_params): | ||
super().__init__(model_path, prefix_params) | ||
|
||
def train(self): | ||
os.system(" ".join([self.prefix_params, 'python3', self.model_path])) | ||
return 0 | ||
|
||
def serve(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from hydra.cloud.abstract_platform import AbstractPlatform | ||
|
||
class GoogleCloud(AbstractPlatform): | ||
def __init__(self, model_path, prefix_params, git_url, commit_sha, github_token): | ||
self.git_url = git_url | ||
self.commit_sha = commit_sha | ||
self.github_token = github_token | ||
super().__init__(model_path, prefix_params) | ||
|
||
def train(self): | ||
pass | ||
|
||
def serve(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import os | ||
import subprocess | ||
from hydra.cloud.abstract_platform import AbstractPlatform | ||
|
||
class LocalPlatform(AbstractPlatform): | ||
def __init__(self, model_path, prefix_params, git_url, commit_sha, github_token): | ||
self.git_url = git_url | ||
self.commit_sha = commit_sha | ||
self.github_token = github_token | ||
super().__init__(model_path, prefix_params) | ||
|
||
def train(self): | ||
execution_script_path = os.path.join(os.path.dirname(__file__), '../../docker/local_execution.sh') | ||
command = ['sh', execution_script_path, self.git_url, self.commit_sha, | ||
self.github_token, self.model_path, self.prefix_params] | ||
|
||
subprocess.run(command) | ||
return 0 | ||
|
||
def serve(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
class GitRepo(): | ||
def __init__(self, repo): | ||
self.repo = repo | ||
|
||
def is_empty(self): | ||
return self.repo.bare | ||
|
||
def is_untracked(self): | ||
return len(self.repo.untracked_files) > 0 | ||
|
||
def is_modified(self): | ||
return len(self.repo.index.diff(None)) > 0 | ||
|
||
def is_uncommitted(self): | ||
return len(self.repo.index.diff("HEAD")) > 0 | ||
|
||
def is_unsynced(self): | ||
branch_name = self.repo.active_branch.name | ||
count_unpushed_commits = len(list(self.repo.iter_commits('origin/{}..{}'.format(branch_name, branch_name)))) | ||
return count_unpushed_commits > 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import re | ||
import os | ||
import git | ||
import json | ||
import warnings | ||
import subprocess | ||
from collections import OrderedDict | ||
from hydra.git_repo import GitRepo | ||
|
||
|
||
def json_to_string(packet): | ||
od = json.loads(packet, object_pairs_hook=OrderedDict) | ||
|
||
params = "" | ||
for key, value in od.items(): | ||
params += key + "=" + str(value) + " " | ||
|
||
return params.strip() | ||
|
||
|
||
def get_repo_url(): | ||
git_url = subprocess.check_output("git config --get remote.origin.url", shell=True).decode("utf-8").strip() | ||
git_url = re.compile(r"https?://(www\.)?").sub("", git_url).strip().strip('/') | ||
return git_url | ||
|
||
|
||
def get_commit_sha(): | ||
commit_sha = subprocess.check_output("git log --pretty=tformat:'%h' -n1 .", shell=True).decode("utf-8").strip() | ||
return commit_sha | ||
|
||
|
||
def check_repo(github_token, repo=None): | ||
if github_token == None: | ||
raise Exception("GITHUB_TOKEN not found in environment variable or as argument.") | ||
|
||
if repo is None: | ||
repo = git.Repo(os.getcwd()) | ||
repo = GitRepo(repo) | ||
|
||
if repo.is_empty(): | ||
raise Exception("Hydra is not being called in the root of a git repo.") | ||
|
||
if repo.is_untracked(): | ||
warnings.warn("Some files are not tracked by git.", UserWarning) | ||
|
||
if repo.is_modified(): | ||
raise Exception("Some modified files are not staged for commit.") | ||
|
||
if repo.is_uncommitted(): | ||
raise Exception("Some staged files are not commited.") | ||
|
||
if repo.is_unsynced(): | ||
raise Exception("Some commits are not pushed to the remote repo.") | ||
|
||
return 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = '0.1.0' | ||
__version__ = '0.1.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
click==7.1.2 | ||
click==7.1.2 | ||
pytest==6.1.1 | ||
pytest_mock==3.3.1 | ||
GitPython==3.1.9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import pytest | ||
from hydra.cli import * | ||
from click.testing import CliRunner | ||
|
||
VALID_MODEL_PATH = "d3bug.py" | ||
VALID_REPO_URL = "https://georgian.io/" | ||
VALID_COMMIT_SHA = "m1rr0r1ng" | ||
VALID_FILE_PATH = "ones/and/zer0es" | ||
VALID_GITHUB_TOKEN = "Georgian" | ||
VALID_PREFIX_PARAMS = "{'epoch': 88}" | ||
|
||
def test_hello_world(): | ||
runner = CliRunner() | ||
result = runner.invoke(hello, ['Peter']) | ||
assert result.exit_code == 0 | ||
assert result.output == 'Hello Peter!\n' | ||
|
||
def test_train_local(mocker): | ||
def stub(dummy): | ||
pass | ||
|
||
mocker.patch( | ||
"hydra.cli.check_repo", | ||
stub | ||
) | ||
mocker.patch( | ||
"hydra.cli.get_repo_url", | ||
return_value=VALID_REPO_URL | ||
) | ||
mocker.patch( | ||
"hydra.cli.get_commit_sha", | ||
return_value=VALID_COMMIT_SHA | ||
) | ||
mocker.patch( | ||
"hydra.cli.os.path.join", | ||
return_value=VALID_FILE_PATH | ||
) | ||
mocker.patch( | ||
"hydra.cli.json_to_string", | ||
return_value=VALID_PREFIX_PARAMS | ||
) | ||
|
||
mocker.patch( | ||
'hydra.cli.subprocess.run', | ||
) | ||
|
||
runner = CliRunner() | ||
result = runner.invoke(train, ['--model_path', VALID_MODEL_PATH, '--cloud', 'local', '--github_token', VALID_GITHUB_TOKEN]) | ||
|
||
|
||
subprocess.run.assert_called_once_with( | ||
['sh', VALID_FILE_PATH, | ||
VALID_REPO_URL, VALID_COMMIT_SHA, VALID_GITHUB_TOKEN, | ||
VALID_MODEL_PATH, VALID_PREFIX_PARAMS]) | ||
|
||
assert result.exit_code == 0 |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.