diff --git a/requirements-npu.txt b/requirements-npu.txt
new file mode 100644
index 000000000..8130ddb40
--- /dev/null
+++ b/requirements-npu.txt
@@ -0,0 +1,7 @@
+torch == 2.1.0
+torchvision <= 0.16.0
+attrs
+decorator
+scipy
+psutil
+torch_npu
diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py
index 281a06cc5..3b9bc2d18 100644
--- a/src/open_clip/hf_model.py
+++ b/src/open_clip/hf_model.py
@@ -26,6 +26,8 @@ class PretrainedConfig:
 
 from .hf_configs import arch_dict
 
+from training.distributed import try_import_npu
+try_import_npu()
 
 # utils
 def _camel2snake(s):
diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index 469d7f5a9..cbf684db9 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -13,6 +13,9 @@ import torch.nn.functional as F
 from torch import nn
 from torch.utils.checkpoint import checkpoint
 
+from training.distributed import try_import_npu
+try_import_npu()
+
 from functools import partial
 
 from .hf_model import HFTextEncoder
diff --git a/src/open_clip/openai.py b/src/open_clip/openai.py
index 6c2c02352..69bf26e6a 100644
--- a/src/open_clip/openai.py
+++ b/src/open_clip/openai.py
@@ -12,6 +12,8 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
 from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
 
+from training.distributed import try_import_npu
+try_import_npu()
 
 __all__ = ["list_openai_models", "load_openai_model"]
 
diff --git a/src/open_clip/timm_model.py b/src/open_clip/timm_model.py
index 5ddb9a76b..1f489f884 100644
--- a/src/open_clip/timm_model.py
+++ b/src/open_clip/timm_model.py
@@ -24,6 +24,8 @@
 
 from .utils import freeze_batch_norm_2d
 
+from training.distributed import try_import_npu
+try_import_npu()
 
 class TimmModel(nn.Module):
     """ timm model adapter
diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 6d4e604d8..02b9cb90e 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -7,6 +7,8 @@
 from torch import nn
 from torch.nn import functional as F
 from torch.utils.checkpoint import checkpoint
+from training.distributed import try_import_npu
+try_import_npu()
 
 from .utils import to_2tuple
 from .pos_embed import get_2d_sincos_pos_embed
diff --git a/src/training/distributed.py b/src/training/distributed.py
index 268a6c7ad..1b4773986 100644
--- a/src/training/distributed.py
+++ b/src/training/distributed.py
@@ -9,6 +9,19 @@
     hvd = None
 
 
+def try_import_npu():
+    try:
+        import torch_npu
+        # Workaround for a bug in torch_npu; revert once it is fixed, see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
+        torch.npu.set_device(0)
+        return True
+    except ImportError:
+        return False
+
+TORCH_NPU_AVAILABLE = False
+if try_import_npu():
+    TORCH_NPU_AVAILABLE = True
+
 def is_global_master(args):
     return args.rank == 0
 
@@ -107,6 +120,12 @@ def init_distributed_device(args):
         else:
             device = 'cuda:0'
         torch.cuda.set_device(device)
+    elif TORCH_NPU_AVAILABLE and torch.npu.is_available():
+        if args.distributed and not args.no_set_device_rank:
+            device = 'npu:%d' % args.local_rank
+        else:
+            device = 'npu:0'
+        torch.npu.set_device(device)
     else:
         device = 'cpu'
     args.device = device
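For orientation, here is a minimal, self-contained sketch (not part of the patch) of how the `try_import_npu()` guard and the CUDA/NPU/CPU device priority added in `src/training/distributed.py` behave at runtime. The `pick_device` helper and its `local_rank` argument are illustrative names, not code from the repository:

```python
# Sketch only: mirrors the optional torch_npu import and device-priority logic
# added above. torch_npu may be absent; everything degrades to CUDA or CPU.
import torch


def try_import_npu() -> bool:
    try:
        import torch_npu  # importing it registers the torch.npu namespace
        return True
    except ImportError:
        return False


TORCH_NPU_AVAILABLE = try_import_npu()


def pick_device(local_rank: int = 0) -> torch.device:
    # Same priority as init_distributed_device(): CUDA first, then Ascend NPU, then CPU.
    if torch.cuda.is_available():
        return torch.device('cuda:%d' % local_rank)
    if TORCH_NPU_AVAILABLE and torch.npu.is_available():
        return torch.device('npu:%d' % local_rank)
    return torch.device('cpu')


print(pick_device(0))
```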
diff --git a/src/training/main.py b/src/training/main.py
index 94496999f..4d8712aaf 100644
--- a/src/training/main.py
+++ b/src/training/main.py
@@ -11,7 +11,6 @@
 import numpy as np
 import torch
 from torch import optim
-from torch.cuda.amp import GradScaler
 
 try:
     import wandb
@@ -30,13 +29,16 @@
 from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss
 from training.data import get_data
-from training.distributed import is_master, init_distributed_device, broadcast_object
+from training.distributed import is_master, init_distributed_device, broadcast_object, try_import_npu
 from training.logger import setup_logging
 from training.params import parse_args
 from training.scheduler import cosine_lr, const_lr, const_lr_cooldown
 from training.train import train_one_epoch, evaluate
 from training.file_utils import pt_load, check_exists, start_sync_process, remote_sync
 
+TORCH_NPU_AVAILABLE = False
+if try_import_npu():
+    TORCH_NPU_AVAILABLE = True
 
 LATEST_CHECKPOINT_NAME = "epoch_latest.pt"
 
 
@@ -329,6 +331,11 @@ def main(args):
             hvd.broadcast_parameters(model.state_dict(), root_rank=0)
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
+    if args.precision == "amp":
+        if TORCH_NPU_AVAILABLE and torch.npu.is_available():
+            from torch.npu.amp import GradScaler
+        else:
+            from torch.cuda.amp import GradScaler
     scaler = GradScaler() if args.precision == "amp" else None
 
     # optionally resume from a checkpoint
diff --git a/src/training/params.py b/src/training/params.py
index c3d19302d..829b63817 100644
--- a/src/training/params.py
+++ b/src/training/params.py
@@ -314,7 +314,7 @@ def parse_args(args):
         help="url used to set up distributed training",
     )
     parser.add_argument(
-        "--dist-backend", default="nccl", type=str, help="distributed backend"
+        "--dist-backend", default="nccl", type=str, help="distributed backend. \"nccl\" for GPU, \"hccl\" for Ascend NPU"
     )
     parser.add_argument(
         "--report-to",
diff --git a/src/training/precision.py b/src/training/precision.py
index a63b92256..14e49419b 100644
--- a/src/training/precision.py
+++ b/src/training/precision.py
@@ -1,5 +1,7 @@
 import torch
 from contextlib import suppress
 
+from training.distributed import try_import_npu
+try_import_npu()
 
 def get_autocast(precision):
diff --git a/src/training/profiler.py b/src/training/profiler.py
index 1805ca693..c33a6e02b 100644
--- a/src/training/profiler.py
+++ b/src/training/profiler.py
@@ -1,6 +1,11 @@
 import argparse
 
 import torch
+from training.distributed import try_import_npu
+TORCH_NPU_AVAILABLE = False
+if try_import_npu():
+    TORCH_NPU_AVAILABLE = True
+
 import open_clip
 import pandas as pd
 from torch.utils.flop_counter import FlopCounterMode
@@ -133,6 +138,8 @@ def profile_model(model_name, batch_size=1, profiler='torch'):
     model.eval()
     if torch.cuda.is_available():
         model = model.cuda()
+    elif TORCH_NPU_AVAILABLE and torch.npu.is_available():
+        model = model.npu()
 
     if isinstance(model.visual.image_size, (tuple, list)):
         image_input_size = (3,) + tuple(model.visual.image_size[-2:])
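And a hedged sketch of the backend-dependent `GradScaler` choice that `src/training/main.py` makes after this change when `--precision amp` is requested. `make_grad_scaler` and `device_type` are illustrative names, not patch code:

```python
# Sketch only: picks an AMP GradScaler for the active backend, as main.py now does.
import torch


def make_grad_scaler(device_type: str):
    """Return a GradScaler for 'cuda' or 'npu', or None when AMP is not used."""
    if device_type == 'npu':
        # torch_npu ships its own scaler under torch.npu.amp
        from torch.npu.amp import GradScaler
        return GradScaler()
    if device_type == 'cuda':
        from torch.cuda.amp import GradScaler
        return GradScaler()
    return None


device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
scaler = make_grad_scaler(device_type)
print(type(scaler).__name__ if scaler is not None else 'no scaler (CPU path)')
```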