             ) from e
         logger.info(f"torch_device overrode to {torch_device}")
     else:
-        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            torch_device = "cuda"
+        elif torch.xpu.is_available():
+            torch_device = "xpu"
+        else:
+            torch_device = "cpu"
     is_torch_higher_equal_than_1_12 = version.parse(
         version.parse(torch.__version__).base_version
     ) >= version.parse("1.12")
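The hunk above widens the default device pick from a two-way CUDA/CPU ternary to a CUDA, then XPU, then CPU preference order. A minimal standalone sketch of that order follows; the hasattr() guard is our own defensive assumption for PyTorch builds that ship without a torch.xpu module, and is not part of the diff:

import torch

def pick_default_test_device() -> str:
    # Prefer CUDA, then Intel XPU, then fall back to CPU -- same order as the diff.
    if torch.cuda.is_available():
        return "cuda"
    # hasattr() guard is an assumption: older torch wheels lack torch.xpu entirely.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

print(pick_default_test_device())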
@@ -1067,12 +1072,51 @@ def _is_torch_fp64_available(device):
 # Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
 if is_torch_available():
     # Behaviour flags
-    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}

     # Function definitions
-    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
-    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
-    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "xpu": torch.xpu.empty_cache,
+        "cpu": None,
+        "mps": torch.mps.empty_cache,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "xpu": torch.xpu.device_count,
+        "cpu": lambda: 0,
+        "mps": lambda: 0,
+        "default": 0,
+    }
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "xpu": torch.xpu.manual_seed,
+        "cpu": torch.manual_seed,
+        "mps": torch.mps.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_RESET_PEAK_MEMORY_STATS = {
+        "cuda": torch.cuda.reset_peak_memory_stats,
+        "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.reset_max_memory_allocated,
+        "xpu": None,
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.max_memory_allocated,
+        "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+        "cpu": 0,
+        "mps": 0,
+        "default": 0,
+    }


 # This dispatches a defined function according to the accelerator from the function definitions.
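The `_device_agnostic_dispatch` helper named in the comment above is not shown in this hunk. As a rough sketch of the pattern these tables feed, written from the table shapes alone (the body below is an approximation, not the file's actual implementation): look the device key up, fall back to "default", call callables, and pass static entries through unchanged.

from typing import Any, Callable, Dict

def _dispatch_sketch(device: str, table: Dict[str, Any], *args, **kwargs):
    # Approximation of _device_agnostic_dispatch: unknown devices use "default".
    entry = table.get(device, table["default"])
    if callable(entry):
        return entry(*args, **kwargs)
    # Static entries (0, True, None) are returned as-is.
    return entry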
@@ -1103,6 +1147,18 @@ def backend_device_count(device: str):
     return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)


+def backend_reset_peak_memory_stats(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_reset_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
 # These are callables which return boolean behaviour flags and can be used to specify some
 # device agnostic alternative where the feature is unsupported.
 def backend_supports_training(device: str):
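Taken together, the three new wrappers enable a device-agnostic memory-regression pattern in tests. A hedged sketch of such a test follows; only the backend_* helpers and `torch_device` come from this file, while run_workload and the 4 GiB bound are placeholders (note that on cpu/mps the table returns a static 0, so the assertion passes trivially there):

def test_peak_memory_stays_bounded():
    # Assumes the helpers above and `torch_device` from diffusers.utils.testing_utils.
    backend_reset_peak_memory_stats(torch_device)
    backend_reset_max_memory_allocated(torch_device)

    run_workload()  # placeholder for the pipeline call under test

    peak_bytes = backend_max_memory_allocated(torch_device)
    # Illustrative threshold only; a real test would pick a model-specific bound.
    assert peak_bytes < 4 * 1024**3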
@@ -1159,3 +1215,6 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
     update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
     update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
     update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
+    update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
+    update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
+    update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
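With these registrations, a custom device spec module (the one loaded via the DIFFUSERS_TEST_DEVICE_SPEC environment variable) can override the memory hooks as well. An illustrative spec module follows; the three *_FN attribute names come from this diff, while the device name and stand-in callables are hypothetical:

# my_device_spec.py -- hypothetical spec module for an out-of-tree accelerator.
DEVICE_NAME = "my_accel"  # illustrative backend name

RESET_PEAK_MEMORY_STATS_FN = lambda: None  # stand-in; a real backend exposes its own hook
RESET_MAX_MEMORY_ALLOCATED_FN = None       # None marks the feature as unsupported
MAX_MEMORY_ALLOCATED_FN = lambda: 0        # stand-in returning "no memory tracked"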