Skip to content
This repository has been archived by the owner on Oct 13, 2023. It is now read-only.

Commit

Permalink
Merge pull request #9 from georgianpartners/6_local_training
Browse files Browse the repository at this point in the history
6 Local Training
  • Loading branch information
coder46 committed Oct 13, 2020
2 parents 1d0b82e + 6489642 commit 43e53ed
Show file tree
Hide file tree
Showing 19 changed files with 612 additions and 14 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[run]
omit = tests/*, setup.py
4 changes: 4 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install pytest-cov
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
Expand All @@ -37,3 +38,6 @@ jobs:
- name: Test with pytest
run: |
pytest
- name: Display test coverage
run: |
pytest --cov=. tests/
6 changes: 6 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
FROM continuumio/miniconda3

ADD executor.sh /home
WORKDIR /home

ENTRYPOINT ["sh", "executor.sh"]
9 changes: 9 additions & 0 deletions docker/executor.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
mkdir project
cd project

git clone https://$OAUTH_TOKEN:x-oauth-basic@$GIT_URL .
git checkout $COMMIT_SHA

conda env create -f environment.yml

conda run -n hydra $PREFIX_PARAMS python3 $MODEL_PATH
21 changes: 21 additions & 0 deletions docker/local_execution.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

DIR="$( dirname "${BASH_SOURCE[0]}" )"

# Add random Hash
LOG_NAME=$(date +'%Y_%m_%d_%H_%M_%S')

cd $DIR
docker build -t hydra_image .

docker run \
-e GIT_URL=$1 \
-e COMMIT_SHA=$2 \
-e OAUTH_TOKEN=$3 \
-e MODEL_PATH=$4 \
-e PREFIX_PARAMS=$5 \
hydra_image:latest 2>&1 | tee ${LOG_NAME}.log

# Move Log file to where the program is being called
cd -
mkdir -p tmp/hydra
mv ${DIR}/${LOG_NAME}.log tmp/hydra/
42 changes: 34 additions & 8 deletions hydra/cli.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
import os
import click
from hydra.utils import *
from hydra.cloud.local_platform import LocalPlatform
from hydra.cloud.fast_local_platform import FastLocalPlatform
from hydra.version import __version__


@click.group()
@click.version_option(__version__)
def cli():
pass

@click.command()
@click.argument('name')
def hello(name):
click.echo('Hello %s!' % name)

@cli.command()
@click.option('--project_name')
@click.option('--model_name')
@click.option('--cpu')
@click.option('--memory')
@click.option('--options')
def train(project_name, model_name, cpu, memory, options):
click.echo("This is the training command")
@click.option('-m', '--model_path', required=True, type=str)
@click.option('-c', '--cpu', default=16, type=click.IntRange(0, 128), help='Number of CPU cores required')
@click.option('-r', '--memory', default=8, type=click.IntRange(0, 128), help='GB of RAM required')
@click.option('--cloud', default='local', required=True, type=click.Choice(['fast_local','local', 'aws', 'gcp', 'azure'], case_sensitive=False))
@click.option('--github_token', envvar='GITHUB_TOKEN') # Takes either an option or environment var
@click.option('-o', '--options', default='{}', type=str, help='Environmental variables for the script')
def train(model_path, cpu, memory, github_token, cloud, options):
prefix_params = json_to_string(options)

if cloud == 'fast_local':
platform = FastLocalPlatform(model_path, prefix_params)
platform.train()

return 0

check_repo(github_token)
git_url = get_repo_url()
commit_sha = get_commit_sha()

if cloud == 'local':
platform = LocalPlatform(model_path, prefix_params, git_url, commit_sha, github_token)
platform.train()

return 0

raise Exception("Reached parts of Hydra that are not yet implemented.")
Empty file added hydra/cloud/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions hydra/cloud/abstract_platform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@


class AbstractPlatform():
def __init__(self, model_path, prefix_params):
self.model_path = model_path
self.prefix_params = prefix_params

def train(self):
raise Exception("Not Implemented: Please implement this function in the subclass.")

def serve(self):
raise Exception("Not Implemented: Please implement this function in the subclass.")
13 changes: 13 additions & 0 deletions hydra/cloud/fast_local_platform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os
from hydra.cloud.abstract_platform import AbstractPlatform

class FastLocalPlatform(AbstractPlatform):
def __init__(self, model_path, prefix_params):
super().__init__(model_path, prefix_params)

def train(self):
os.system(" ".join([self.prefix_params, 'python3', self.model_path]))
return 0

def serve(self):
pass
14 changes: 14 additions & 0 deletions hydra/cloud/google_cloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from hydra.cloud.abstract_platform import AbstractPlatform

class GoogleCloud(AbstractPlatform):
def __init__(self, model_path, prefix_params, git_url, commit_sha, github_token):
self.git_url = git_url
self.commit_sha = commit_sha
self.github_token = github_token
super().__init__(model_path, prefix_params)

def train(self):
pass

def serve(self):
pass
21 changes: 21 additions & 0 deletions hydra/cloud/local_platform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os
import subprocess
from hydra.cloud.abstract_platform import AbstractPlatform

class LocalPlatform(AbstractPlatform):
def __init__(self, model_path, prefix_params, git_url, commit_sha, github_token):
self.git_url = git_url
self.commit_sha = commit_sha
self.github_token = github_token
super().__init__(model_path, prefix_params)

def train(self):
execution_script_path = os.path.join(os.path.dirname(__file__), '../../docker/local_execution.sh')
command = ['sh', execution_script_path, self.git_url, self.commit_sha,
self.github_token, self.model_path, self.prefix_params]

subprocess.run(command)
return 0

def serve(self):
pass
21 changes: 21 additions & 0 deletions hydra/git_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

class GitRepo():
def __init__(self, repo):
self.repo = repo

def is_empty(self):
return self.repo.bare

def is_untracked(self):
return len(self.repo.untracked_files) > 0

def is_modified(self):
return len(self.repo.index.diff(None)) > 0

def is_uncommitted(self):
return len(self.repo.index.diff("HEAD")) > 0

def is_unsynced(self):
branch_name = self.repo.active_branch.name
count_unpushed_commits = len(list(self.repo.iter_commits('origin/{}..{}'.format(branch_name, branch_name))))
return count_unpushed_commits > 0
55 changes: 55 additions & 0 deletions hydra/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import re
import os
import git
import json
import warnings
import subprocess
from collections import OrderedDict
from hydra.git_repo import GitRepo


def json_to_string(packet):
od = json.loads(packet, object_pairs_hook=OrderedDict)

params = ""
for key, value in od.items():
params += key + "=" + str(value) + " "

return params.strip()


def get_repo_url():
git_url = subprocess.check_output("git config --get remote.origin.url", shell=True).decode("utf-8").strip()
git_url = re.compile(r"https?://(www\.)?").sub("", git_url).strip().strip('/')
return git_url


def get_commit_sha():
commit_sha = subprocess.check_output("git log --pretty=tformat:'%h' -n1 .", shell=True).decode("utf-8").strip()
return commit_sha


def check_repo(github_token, repo=None):
if github_token == None:
raise Exception("GITHUB_TOKEN not found in environment variable or as argument.")

if repo is None:
repo = git.Repo(os.getcwd())
repo = GitRepo(repo)

if repo.is_empty():
raise Exception("Hydra is not being called in the root of a git repo.")

if repo.is_untracked():
warnings.warn("Some files are not tracked by git.", UserWarning)

if repo.is_modified():
raise Exception("Some modified files are not staged for commit.")

if repo.is_uncommitted():
raise Exception("Some staged files are not commited.")

if repo.is_unsynced():
raise Exception("Some commits are not pushed to the remote repo.")

return 0
2 changes: 1 addition & 1 deletion hydra/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.0'
__version__ = '0.1.0'
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
click==7.1.2
click==7.1.2
pytest==6.1.1
pytest_mock==3.3.1
GitPython==3.1.9
56 changes: 56 additions & 0 deletions tests/test_cil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest
from hydra.cli import *
from click.testing import CliRunner

VALID_MODEL_PATH = "d3bug.py"
VALID_REPO_URL = "https://georgian.io/"
VALID_COMMIT_SHA = "m1rr0r1ng"
VALID_FILE_PATH = "ones/and/zer0es"
VALID_GITHUB_TOKEN = "Georgian"
VALID_PREFIX_PARAMS = "{'epoch': 88}"

def test_hello_world():
runner = CliRunner()
result = runner.invoke(hello, ['Peter'])
assert result.exit_code == 0
assert result.output == 'Hello Peter!\n'

def test_train_local(mocker):
def stub(dummy):
pass

mocker.patch(
"hydra.cli.check_repo",
stub
)
mocker.patch(
"hydra.cli.get_repo_url",
return_value=VALID_REPO_URL
)
mocker.patch(
"hydra.cli.get_commit_sha",
return_value=VALID_COMMIT_SHA
)
mocker.patch(
"hydra.cli.os.path.join",
return_value=VALID_FILE_PATH
)
mocker.patch(
"hydra.cli.json_to_string",
return_value=VALID_PREFIX_PARAMS
)

mocker.patch(
'hydra.cli.subprocess.run',
)

runner = CliRunner()
result = runner.invoke(train, ['--model_path', VALID_MODEL_PATH, '--cloud', 'local', '--github_token', VALID_GITHUB_TOKEN])


subprocess.run.assert_called_once_with(
['sh', VALID_FILE_PATH,
VALID_REPO_URL, VALID_COMMIT_SHA, VALID_GITHUB_TOKEN,
VALID_MODEL_PATH, VALID_PREFIX_PARAMS])

assert result.exit_code == 0
4 changes: 0 additions & 4 deletions tests/test_dummy.py

This file was deleted.

Loading

0 comments on commit 43e53ed

Please sign in to comment.