diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 2f808b64b..f9ee2f138 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -388,7 +388,7 @@ def download_fastmri(data_dir, def extract(source, dest, mode='r:xz'): if not os.path.exists(dest): - os.path.makedirs(dest) + os.makedirs(dest) logging.info(f'Extracting {source} to {dest}') tar = tarfile.open(source, mode) logging.info('Opened tar') diff --git a/setup.cfg b/setup.cfg index 9aa4ffb5f..a00da91fc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -132,13 +132,13 @@ jax_gpu = # PyTorch CPU pytorch_cpu = - torch==2.0.1 - torchvision==0.15.2 + torch==2.1.0 + torchvision==0.16.0 # PyTorch GPU pytorch_gpu = - torch==2.0.1+cu118 - torchvision==0.15.2+cu118 + torch==2.1.0+cu118 + torchvision==0.16.0+cu118 # wandb wandb = diff --git a/submission_runner.py b/submission_runner.py index 7f6150d19..12494cd6e 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -220,7 +220,9 @@ def train_once( model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: - compile_error_workloads = ['librispeech_conformer', 'ogbg', 'criteo1tb'] + compile_error_workloads = [ + 'librispeech_conformer', 'ogbg', 'criteo1tb', 'imagenet_vit' + ] eager_backend_workloads = ['librispeech_deepspeech'] aot_eager_backend_workloads = [] if FLAGS.workload in compile_error_workloads: