-
Notifications
You must be signed in to change notification settings - Fork 2
142 lines (129 loc) · 4.78 KB
/
train-agent.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
name: Start of agent training

# Manually triggered workflow. The jobs below read these values through
# `github.event.inputs.*`, where they are always strings.
on:  # yamllint disable-line rule:truthy
  workflow_dispatch:
    inputs:
      training_location:
        description: Train 'local' or 'gcp'?
        required: true
        default: 'gcp'
        # Restrict to the two values the jobs' `if:` conditions test for;
        # any other free-text value would match no job and run nothing.
        type: choice
        options:
          - gcp
          - local
      n_envs:
        description: Number of parallel environments?
        required: true
        default: '8'
      env_kwargs:
        description: Optional **kwargs for constructor?
        required: true
        default: '{}'
      save_freq:
        description: Model storage at each n-step?
        required: true
        default: '500000'
      eval_freq:
        description: Model evaluation at each n-step?
        required: true
        default: '50000'
      total_timesteps:
        description: Total steps to train?
        required: true
        default: '5000000'
      n_steps:
        description: Model update after each n-step?
        required: true
        default: '4096'
      policy_kwargs:
        description: Additional arguments for policy creation?
        required: true
        default: '{}'
      learning_rate:
        description: Learning rate?
        required: true
        default: '0.0003'
      gamma:
        description: Discount factor?
        required: true
        default: '0.99'
jobs:
  gcp:
    # Final training should be here, but it could be expensive
    if: ${{ github.event.inputs.training_location == 'gcp' }}
    name: Create custom Vertex AI job
    permissions:
      contents: read
      id-token: write
    runs-on: ubuntu-latest
    env:
      PROJECT_ID: ${{ secrets.PROJECT_ID }}
      BUCKET_NAME: super-mario-bros
      REGION: us-east1
      MACHINE_TYPE: e2-standard-4
      IMAGE_URI: us-east1-docker.pkg.dev/${{ secrets.PROJECT_ID }}/super-mario-bros/train:latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - id: auth
        name: Authenticate to Google Cloud
        uses: google-github-actions/auth@v1
        with:
          workload_identity_provider: ${{ secrets.GOOGLE_WORKLOAD_IDENTITY_PROVIDER }}
          service_account: ${{ secrets.GOOGLE_SERVICE_ACCOUNT }}
      - name: Setup the Google Cloud SDK
        uses: google-github-actions/setup-gcloud@v1
      - name: Create Cloud Storage bucket
        # Failures are tolerated, presumably so re-runs do not stop when the
        # bucket already exists. NOTE(review): this also hides auth errors.
        continue-on-error: true
        run: |
          gcloud storage buckets create "gs://$BUCKET_NAME" \
            --location="$REGION" \
            --public-access-prevention \
            --uniform-bucket-level-access
      - name: Create custom Vertex AI job
        # Inputs are passed via env vars and quoted in the script instead of
        # being pasted in with ${{ }}, so spaces/quotes in values cannot split
        # arguments or inject shell code.
        env:
          N_ENVS: ${{ github.event.inputs.n_envs }}
          ENV_KWARGS: ${{ github.event.inputs.env_kwargs }}
          SAVE_FREQ: ${{ github.event.inputs.save_freq }}
          EVAL_FREQ: ${{ github.event.inputs.eval_freq }}
          TOTAL_TIMESTEPS: ${{ github.event.inputs.total_timesteps }}
          N_STEPS: ${{ github.event.inputs.n_steps }}
          POLICY_KWARGS: ${{ github.event.inputs.policy_kwargs }}
          LEARNING_RATE: ${{ github.event.inputs.learning_rate }}
          GAMMA: ${{ github.event.inputs.gamma }}
        # The `^*^` prefix switches gcloud's list delimiter from ',' to '*',
        # so commas inside the dict-valued kwargs stay within one argument.
        run: |
          gcloud ai custom-jobs create \
            --project="$PROJECT_ID" \
            --region="$REGION" \
            --display-name="smb-train" \
            --worker-pool-spec="replica-count=1,machine-type=$MACHINE_TYPE,container-image-uri=$IMAGE_URI" \
            "--args=--n_envs=$N_ENVS" \
            "--args=^*^--env_kwargs=$ENV_KWARGS" \
            "--args=--save_freq=$SAVE_FREQ" \
            "--args=--eval_freq=$EVAL_FREQ" \
            "--args=--total_timesteps=$TOTAL_TIMESTEPS" \
            "--args=--n_steps=$N_STEPS" \
            "--args=^*^--policy_kwargs=$POLICY_KWARGS" \
            "--args=--learning_rate=$LEARNING_RATE" \
            "--args=--gamma=$GAMMA"
  local:
    # Training here should be short and used for testing purposes only
    if: ${{ github.event.inputs.training_location == 'local' }}
    name: Train locally on GitHub Action
    permissions:
      contents: read
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.8.15"
      - name: Install dependencies
        run: pip install -r requirements.txt
      - name: Create artifact folder
        run: mkdir -p data
      - name: Train locally on GitHub Action
        # Inputs are passed via env vars and quoted, so dict-valued kwargs
        # containing spaces or quotes reach train.py as a single argument.
        env:
          TRAINING_LOCATION: ${{ github.event.inputs.training_location }}
          N_ENVS: ${{ github.event.inputs.n_envs }}
          ENV_KWARGS: ${{ github.event.inputs.env_kwargs }}
          SAVE_FREQ: ${{ github.event.inputs.save_freq }}
          EVAL_FREQ: ${{ github.event.inputs.eval_freq }}
          TOTAL_TIMESTEPS: ${{ github.event.inputs.total_timesteps }}
          N_STEPS: ${{ github.event.inputs.n_steps }}
          POLICY_KWARGS: ${{ github.event.inputs.policy_kwargs }}
          LEARNING_RATE: ${{ github.event.inputs.learning_rate }}
          GAMMA: ${{ github.event.inputs.gamma }}
        run: |
          python super_mario_bros/train.py \
            --training_location="$TRAINING_LOCATION" \
            --artifact_folder=data \
            --n_envs="$N_ENVS" \
            --env_kwargs="$ENV_KWARGS" \
            --save_freq="$SAVE_FREQ" \
            --eval_freq="$EVAL_FREQ" \
            --total_timesteps="$TOTAL_TIMESTEPS" \
            --n_steps="$N_STEPS" \
            --policy_kwargs="$POLICY_KWARGS" \
            --learning_rate="$LEARNING_RATE" \
            --gamma="$GAMMA"
      - name: Upload artifacts
        # Pinned to a release tag rather than the moving `main` branch.
        uses: actions/upload-artifact@v4
        with:
          name: artifacts
          path: data
          retention-days: 3