-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
122 lines (102 loc) · 3.96 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python
import sys
import json
sys.path.insert(0, 'src/data')
sys.path.insert(0, 'src/main')
sys.path.insert(0, 'src/helper')
import pandas as pd
from animation_loader import AnimationLoader
import etl
import features
import models
import forecast
import clean
def print_df_indented(df_string, indent):
    """
    Helper function to print a DataFrame's string form with indentation.

    The first line (the column header) is printed, then a dashed rule the
    same length as that header, then every remaining line. Each line is
    printed as ``print(indent, line)``, so ``print`` inserts one space
    between the indent and the text.

    Parameters
    ----------
    df_string: str
        String representation of the DataFrame (e.g. the output of
        ``DataFrame.to_string``).
    indent: str
        Indentation string.
    """
    # partition() rather than index(): a header-only string (no newline)
    # previously raised ValueError; now it prints the header and rule only.
    col_names, newline, remaining_rows = df_string.partition('\n')
    print(indent, col_names)
    print(indent, '-' * len(col_names))
    if not newline:
        return
    for row in remaining_rows.split('\n'):
        print(indent, row)
_VALID_TARGETS = ('data', 'features', 'models', 'forecast', 'clean', 'all')


def _run_stage(spinner, task, func, *args):
    """
    Run one pipeline stage, calling ``func(*args)`` while ``spinner`` shows ``task``.

    On success the spinner's ``finished`` flag is set. On any exception its
    ``failed`` flag is set first and the exception is re-raised, mirroring
    the failure handling the forecast stage already had.
    """
    print()
    spinner.show(task, finish_message=f'{task} done', failed_message=f'{task} failed')
    try:
        func(*args)
    except Exception:
        spinner.failed = True
        raise
    else:
        spinner.finished = True


def main(targets):
    """
    Runs the main project pipeline logic, given the targets.

    Parameters
    ----------
    targets: list of str
        Target names, typically taken from the command line. Valid names are
        'data', 'features', 'models', 'forecast', 'clean' and 'all'. 'all'
        runs data, features, models and forecast (in that order) but NOT
        clean; 'clean' must be requested explicitly and always runs last.

    Raises
    ------
    Exception
        With a 'TargetNotFoundException' message if ``targets`` is empty or
        contains an unknown name.
    """
    # An empty target list used to fall through and do nothing; the message
    # already promises "at least one valid target", so enforce it.
    if not targets or any(target not in _VALID_TARGETS for target in targets):
        raise Exception('TargetNotFoundException: input at least one valid target')

    spinner_animation = AnimationLoader()
    indent = ' '
    with open('config/config.json', 'r', encoding='utf-8') as fh:
        params = json.load(fh)
    run_all = 'all' in targets

    if run_all or 'data' in targets:
        _run_stage(spinner_animation, 'loading data:', etl.run, params)
        print()
        print(indent, "loaded data located at 'src/data/temp'")

    if run_all or 'features' in targets:
        _run_stage(spinner_animation, 'transforming features:', features.run)
        print()
        print(indent, "processed data located at 'src/data/temp'")

    if run_all or 'models' in targets:
        _run_stage(spinner_animation, 'training-evaluating models:', models.run)
        print()
        # Presumably written by models.run(); read back to show the evaluations.
        model_evaluations_df = pd.read_csv('src/data/temp/model_evaluations.csv')
        print_df_indented(model_evaluations_df.to_string(index=False), indent)
        print()
        print(indent, "3-year out forecast located at 'out/plots/final_forecasts.jpg'")

    if run_all or 'forecast' in targets:
        print()
        forecast_year = int(input(indent + 'generate lstm forecasts up to (ex: 2050): '))
        try:
            _run_stage(spinner_animation, 'generating forecasts:', forecast.run, forecast_year)
        except Exception as e:
            # Forecast failures are reported and end the run (so 'clean' is
            # skipped) instead of propagating.
            print(e)
            return
        print()
        print(indent, "non-lstm forecast plots located at 'out/plots/model_year_zcta_forecast.jpg'")
        print()
        print(indent, "non-lstm forecast tables located at 'out/forecast_tables/model_zcta_forecast.csv'")
        print()
        print(indent, f"lstm forecasts out to {forecast_year} located at 'out/plots/feedback_{forecast_year}_forecasts.jpg'")

    # 'clean' is deliberately not part of 'all'.
    if 'clean' in targets:
        _run_stage(spinner_animation, 'removing temporary files:', clean.run)
        print()
if __name__ == '__main__':
    # Everything after the script name is treated as a pipeline target.
    main(sys.argv[1:])