Support open_clip with NPU backend #813

Open · wants to merge 2 commits into base: main
7 changes: 7 additions & 0 deletions requirements-npu.txt
@@ -0,0 +1,7 @@
torch == 2.1.0
torchvision <= 0.16.0
attrs
decorator
scipy
psutil
torch_npu
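
A minimal smoke test after installing these packages, assuming an Ascend device and the CANN driver are present (illustration only, not part of the PR): importing `torch_npu` registers the `npu` device type, after which the tensor API mirrors CUDA.

```python
import torch
import torch_npu  # registers the "npu" device type with PyTorch

# Quick check that the Ascend backend is usable.
print(torch.npu.is_available())   # True when an NPU is visible
print(torch.npu.device_count())   # number of visible NPU devices

# Tensors move to the NPU the same way they move to CUDA.
x = torch.randn(2, 3).to("npu:0")
print(x.device)
```
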
2 changes: 2 additions & 0 deletions src/open_clip/hf_model.py
@@ -26,6 +26,8 @@ class PretrainedConfig:

from .hf_configs import arch_dict

from training.distributed import try_import_npu
try_import_npu()

# utils
def _camel2snake(s):
3 changes: 3 additions & 0 deletions src/open_clip/model.py
@@ -13,6 +13,9 @@
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from training.distributed import try_import_npu
try_import_npu()

from functools import partial

from .hf_model import HFTextEncoder
2 changes: 2 additions & 0 deletions src/open_clip/openai.py
@@ -12,6 +12,8 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
from training.distributed import try_import_npu
try_import_npu()

__all__ = ["list_openai_models", "load_openai_model"]

2 changes: 2 additions & 0 deletions src/open_clip/timm_model.py
@@ -24,6 +24,8 @@

from .utils import freeze_batch_norm_2d

from training.distributed import try_import_npu
try_import_npu()

class TimmModel(nn.Module):
    """ timm model adapter
2 changes: 2 additions & 0 deletions src/open_clip/transformer.py
@@ -7,6 +7,8 @@
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from training.distributed import try_import_npu
try_import_npu()

from .utils import to_2tuple
from .pos_embed import get_2d_sincos_pos_embed
19 changes: 19 additions & 0 deletions src/training/distributed.py
@@ -9,6 +9,19 @@
hvd = None


def try_import_npu():
    try:
        import torch_npu
        # Workaround for a bug in torch_npu; revert once it is fixed, see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
        torch.npu.set_device(0)
        return True
    except ImportError:
        return False

TORCH_NPU_AVAILABLE = False
if try_import_npu():
    TORCH_NPU_AVAILABLE = True

def is_global_master(args):
    return args.rank == 0

@@ -107,6 +120,12 @@ def init_distributed_device(args):
        else:
            device = 'cuda:0'
        torch.cuda.set_device(device)
    elif TORCH_NPU_AVAILABLE and torch.npu.is_available():
        if args.distributed and not args.no_set_device_rank:
            device = 'npu:%d' % args.local_rank
        else:
            device = "npu:0"
        torch.npu.set_device(device)
    else:
        device = 'cpu'
    args.device = device
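
With this change `init_distributed_device` can hand back an `npu:N` device string instead of `cuda:N`. Downstream code that only consumes `args.device` stays backend-agnostic; a hedged usage sketch (standard PyTorch, not part of this diff):

```python
# Once init_distributed_device(args) has set args.device, the rest of the
# training code treats "cuda:0", "npu:0", and "cpu" uniformly.
device = torch.device(args.device)
model = model.to(device)
images = images.to(device=device, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
```
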
11 changes: 9 additions & 2 deletions src/training/main.py
@@ -11,7 +11,6 @@
import numpy as np
import torch
from torch import optim
from torch.cuda.amp import GradScaler

try:
import wandb
@@ -30,13 +29,16 @@

from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss
from training.data import get_data
from training.distributed import is_master, init_distributed_device, broadcast_object
from training.distributed import is_master, init_distributed_device, broadcast_object, try_import_npu
from training.logger import setup_logging
from training.params import parse_args
from training.scheduler import cosine_lr, const_lr, const_lr_cooldown
from training.train import train_one_epoch, evaluate
from training.file_utils import pt_load, check_exists, start_sync_process, remote_sync

TORCH_NPU_AVAILABLE = False
if try_import_npu():
    TORCH_NPU_AVAILABLE = True

LATEST_CHECKPOINT_NAME = "epoch_latest.pt"

@@ -329,6 +331,11 @@ def main(args):
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    if args.precision == "amp":
        if TORCH_NPU_AVAILABLE and torch.npu.is_available():
            from torch.npu.amp import GradScaler
        else:
            from torch.cuda.amp import GradScaler
    scaler = GradScaler() if args.precision == "amp" else None

# optionally resume from a checkpoint
Expand Down
2 changes: 1 addition & 1 deletion src/training/params.py
@@ -314,7 +314,7 @@ def parse_args(args):
help="url used to set up distributed training",
)
parser.add_argument(
"--dist-backend", default="nccl", type=str, help="distributed backend"
"--dist-backend", default="nccl", type=str, help="distributed backend. \"nccl\" for GPU, \"hccl\" for Ascend NPU"
)
parser.add_argument(
"--report-to",
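
The `--dist-backend` value is presumably forwarded to `torch.distributed.init_process_group` inside `init_distributed_device`, so Ascend runs select HCCL the same way CUDA runs select NCCL. A hedged illustration of that hand-off (not code from this diff):

```python
import torch.distributed as dist

# backend comes from the --dist-backend flag:
# "nccl" on GPUs, "hccl" on Ascend NPUs.
dist.init_process_group(
    backend=args.dist_backend,
    init_method=args.dist_url,
    world_size=args.world_size,
    rank=args.rank,
)
```
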
2 changes: 2 additions & 0 deletions src/training/precision.py
@@ -1,5 +1,7 @@
import torch
from contextlib import suppress
from training.distributed import try_import_npu
try_import_npu()


def get_autocast(precision):
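
The diff only shows the `try_import_npu` import for `precision.py`. A hedged sketch of what an NPU-aware `get_autocast` could look like; that `torch.npu.amp.autocast` mirrors `torch.cuda.amp.autocast` is an assumption based on the `GradScaler` swap in `main.py`:

```python
import torch
from contextlib import suppress
from training.distributed import try_import_npu

TORCH_NPU_AVAILABLE = try_import_npu()


def get_autocast(precision):
    # Pick the autocast context manager that matches the available backend.
    if precision == 'amp':
        if TORCH_NPU_AVAILABLE and torch.npu.is_available():
            return torch.npu.amp.autocast
        return torch.cuda.amp.autocast
    elif precision in ('amp_bfloat16', 'amp_bf16'):
        # bfloat16 autocast; shown for the CUDA path only, as in upstream open_clip.
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress
```
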
7 changes: 7 additions & 0 deletions src/training/profiler.py
@@ -1,6 +1,11 @@
import argparse

import torch
from training.distributed import try_import_npu
TORCH_NPU_AVAILABLE = False
if try_import_npu():
    TORCH_NPU_AVAILABLE = True

import open_clip
import pandas as pd
from torch.utils.flop_counter import FlopCounterMode
@@ -133,6 +138,8 @@ def profile_model(model_name, batch_size=1, profiler='torch'):
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    elif TORCH_NPU_AVAILABLE and torch.npu.is_available():
        model = model.npu()

if isinstance(model.visual.image_size, (tuple, list)):
image_input_size = (3,) + tuple(model.visual.image_size[-2:])