Merged
33 commits
0f20678  set max split size (priyakasimbeg, Sep 26, 2023)
6cf192a  tune max split size (priyakasimbeg, Sep 28, 2023)
b2f8ff9  typo (priyakasimbeg, Sep 28, 2023)
70625b0  add back deleted block (priyakasimbeg, Sep 28, 2023)
179abba  undo disable torch compile for conformer (priyakasimbeg, Sep 28, 2023)
a2beafb  remove whitespace (priyakasimbeg, Sep 28, 2023)
255d835  remove trailing whitespace (priyakasimbeg, Sep 28, 2023)
b7f4cbc  isort fix (priyakasimbeg, Sep 28, 2023)
bb29602  formatting (priyakasimbeg, Sep 28, 2023)
3738f35  print step hint (priyakasimbeg, Sep 28, 2023)
0e4dd85  make pytorch cuda alloc config specific to conformer (priyakasimbeg, Oct 6, 2023)
da89a8b  tune max split size (priyakasimbeg, Oct 6, 2023)
416b88d  fix (priyakasimbeg, Oct 7, 2023)
4600d78  reduce max split size (priyakasimbeg, Oct 7, 2023)
7a764e1  move env var (priyakasimbeg, Oct 7, 2023)
1dbf3e4  logging (priyakasimbeg, Oct 7, 2023)
04f5c94  debugging (priyakasimbeg, Oct 7, 2023)
b0b9f40  debugging (priyakasimbeg, Oct 7, 2023)
318202e  debug logging (priyakasimbeg, Oct 7, 2023)
3cec8c5  update (priyakasimbeg, Oct 7, 2023)
4fc6e1c  update_logging (priyakasimbeg, Oct 7, 2023)
557bf0d  fix (priyakasimbeg, Oct 7, 2023)
2598d39  fix (priyakasimbeg, Oct 7, 2023)
9418f4f  fix (priyakasimbeg, Oct 7, 2023)
931337d  remove logging (priyakasimbeg, Oct 9, 2023)
aeed475  revert checkpoint utils debugging (priyakasimbeg, Oct 9, 2023)
7098843  extend max_allowed_runtime_sec for conformer (priyakasimbeg, Oct 9, 2023)
cb68dba  Merge branch 'dev' into conformer_oom_debugging_2 (priyakasimbeg, Oct 9, 2023)
24edc3b  Merge branch 'dev' into conformer_oom_debugging_2 (priyakasimbeg, Oct 11, 2023)
09ceeec  remove conformer oom fixes from this branch (priyakasimbeg, Oct 11, 2023)
a0b624e  lint (priyakasimbeg, Oct 11, 2023)
061d5b3  pr feedback (priyakasimbeg, Oct 13, 2023)
a4bb0f0  isort (priyakasimbeg, Oct 13, 2023)
27 changes: 12 additions & 15 deletions algorithmic_efficiency/logger_utils.py
@@ -9,7 +9,7 @@
 import shutil
 import subprocess
 import sys
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 from absl import flags
 from clu import metric_writers
@@ -96,14 +96,14 @@ def write_hparams(hparams: spec.Hyperparameters,
   return hparams


-def write_json(name: str, log_dict: dict, indent: int = 2) -> None:
+def write_json(name: str, log_dict: Dict, indent: int = 2) -> None:
   if RANK == 0:
     with open(name, 'w') as f:
       f.write(json.dumps(log_dict, indent=indent))


 def write_to_csv(
-    metrics: dict,
+    metrics: Dict,
     csv_path: str,
 ) -> None:
   try:
@@ -120,7 +120,7 @@ def write_to_csv(
   return


-def _get_utilization() -> dict:
+def _get_utilization() -> Dict:
   util_data = {}

   # CPU
@@ -180,7 +180,7 @@ def _get_utilization() -> dict:
   return util_data


-def _get_system_hardware_info() -> dict:
+def _get_system_hardware_info() -> Dict:
   system_hardware_info = {}
   try:
     system_hardware_info['cpu_model_name'] = _get_cpu_model_name()
@@ -200,7 +200,7 @@ def _get_system_hardware_info() -> dict:
   return system_hardware_info


-def _get_system_software_info() -> dict:
+def _get_system_software_info() -> Dict:
   system_software_info = {}

   system_software_info['os_platform'] = \
@@ -243,7 +243,7 @@ def _is_primitive_type(item: Any) -> bool:
   return isinstance(item, primitive)


-def _get_workload_properties(workload: spec.Workload) -> dict:
+def _get_workload_properties(workload: spec.Workload) -> Dict:
   workload_properties = {}
   skip_list = ['param_shapes', 'model_params_types']
   keys = [
@@ -262,7 +262,8 @@ def _get_workload_properties(workload: spec.Workload) -> dict:
   return workload_properties


-def get_meta_data(workload: spec.Workload) -> dict:
+def get_meta_data(workload: spec.Workload,
+                  rng_seed: Optional[int] = None) -> Dict:
   meta_data = {}
   workload_properties = _get_workload_properties(workload)
   meta_data.update(workload_properties)
@@ -272,15 +273,11 @@ def get_meta_data(workload: spec.Workload,
   meta_data.update(system_software_info)
   system_hardware_info = _get_system_hardware_info()
   meta_data.update(system_hardware_info)
+  if rng_seed is not None:
+    meta_data.update({'rng_seed': rng_seed})
   return meta_data


-def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str):
-  meta_data = get_meta_data(workload)
-  meta_data.update({'rng_seed': rng_seed})
-  write_json(meta_file_name, meta_data)
-
-
 class MetricLogger(object):
   """Used to log all measurements during training.

@@ -308,7 +305,7 @@ def __init__(self,
       wandb.config.update(hyperparameters._asdict())

   def append_scalar_metrics(self,
-                            metrics: dict,
+                            metrics: Dict,
                             global_step: int,
                             preemption_count: Optional[int] = None,
                             is_eval: bool = False) -> None:
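The net effect in logger_utils.py: `save_meta_data` is removed, `get_meta_data` grows an optional `rng_seed` parameter, and callers write the JSON themselves via `write_json`. (The `dict` to `Dict` swaps are purely a typing-style change; a bare `typing.Dict` annotation is equivalent to `dict`.) A minimal sketch of the new call pattern, assuming the import path follows the file layout and `workload` is an already-constructed `spec.Workload`; the seed and output path below are placeholders, not from this PR:

# Sketch only: placeholder values, import path assumed from the repo layout.
from algorithmic_efficiency import logger_utils

meta_data = logger_utils.get_meta_data(workload, rng_seed=1996)
# 'rng_seed' appears in the dict only when a seed is passed;
# logger_utils.get_meta_data(workload) alone omits it.
logger_utils.write_json('/tmp/meta_data_0.json', meta_data)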
4 changes: 2 additions & 2 deletions submission_runner.py
@@ -237,7 +237,6 @@ def train_once(
   else:
     logging.info('Performing `torch.compile`.')
     model_params = torch.compile(model_params)
-
   logging.info('Initializing optimizer.')
   with profiler.profile('Initializing optimizer'):
     optimizer_state = init_optimizer_state(workload,
@@ -284,7 +283,8 @@ def train_once(
       checkpoint_dir=log_dir)
   meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json')
   logging.info(f'Saving meta data to {meta_file_name}.')
-  logger_utils.save_meta_data(workload, rng_seed, preemption_count)
+  meta_data = logger_utils.get_meta_data(workload, rng_seed)
+  logger_utils.write_json(meta_file_name, meta_data)
   flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json')
   logging.info(f'Saving flags to {flag_file_name}.')
   logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict())
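One detail the diff makes visible: `save_meta_data`'s third parameter was `meta_file_name` (see its deleted definition in logger_utils.py above), yet the old call site passed `preemption_count` in that slot, so the metadata appears to have been written to the wrong path before this change. A sketch of the before/after at this call site, using only names from the diff:

# Before (old signature: save_meta_data(workload, rng_seed, meta_file_name);
# the call put preemption_count in the file-name slot):
#   logger_utils.save_meta_data(workload, rng_seed, preemption_count)

# After: the target path is explicit, and the seed is folded into get_meta_data.
meta_data = logger_utils.get_meta_data(workload, rng_seed)
logger_utils.write_json(meta_file_name, meta_data)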