add npu support
MengqingCao committed Feb 5, 2024
1 parent 3ff1faf commit a6a2032
Showing 5 changed files with 36 additions and 2 deletions.
7 changes: 7 additions & 0 deletions requirements-npu.txt
@@ -0,0 +1,7 @@
torch <= 2.1.0
torchvision <= 0.16.0
attrs
decorator
scipy
psutil
torch_npu
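
With these pinned requirements installed, NPU availability can be probed with the same optional-import pattern the rest of this commit relies on. A minimal, self-contained sketch (torch_npu patches a torch.npu namespace onto torch when it imports):

import torch

try:
    import torch_npu  # registers the NPU backend on import
except ImportError:
    torch_npu = None

if torch_npu is not None and torch.npu.is_available():
    print(f"NPU ready: {torch.npu.device_count()} device(s)")
else:
    print("No NPU detected; CUDA/CPU code paths will be used.")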
9 changes: 9 additions & 0 deletions src/training/distributed.py
@@ -1,6 +1,11 @@
import os

import torch
try:
    import torch_npu
except ImportError:
    torch_npu = None

import torch.distributed as dist

try:
@@ -107,6 +112,10 @@ def init_distributed_device(args):
        else:
            device = 'cuda:0'
        torch.cuda.set_device(device)
    elif torch_npu is not None and torch.npu.is_available():
        # TODO: add distributed code for npu
        device = "npu:0"
        torch_npu.npu.set_device(device)
    else:
        device = 'cpu'
    args.device = device
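Taken together, the hunk gives init_distributed_device a CUDA → NPU → CPU fallthrough. A simplified sketch of just that selection logic, reusing the torch/torch_npu imports shown above (_pick_device is illustrative, not part of the commit, and omits the distributed rank handling the real function does):

def _pick_device():
    if torch.cuda.is_available():
        device = "cuda:0"
        torch.cuda.set_device(device)
    elif torch_npu is not None and torch.npu.is_available():
        device = "npu:0"  # distributed NPU support is still a TODO
        torch_npu.npu.set_device(device)
    else:
        device = "cpu"
    return device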
10 changes: 9 additions & 1 deletion src/training/main.py
@@ -11,7 +11,10 @@
import numpy as np
import torch
from torch import optim
from torch.cuda.amp import GradScaler
try:
    import torch_npu
except ImportError:
    torch_npu = None

try:
    import wandb
@@ -329,6 +332,11 @@ def main(args):
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    if args.precision == "amp":
        if torch_npu is not None and torch.npu.is_available():
            from torch.npu.amp import GradScaler
        else:
            from torch.cuda.amp import GradScaler
    scaler = GradScaler() if args.precision == "amp" else None

    # optionally resume from a checkpoint
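Because GradScaler now comes from a backend-specific module, its import has to wait until the backend is known, which is why it moved from the module top into main(). The same selection, wrapped as a helper for clarity (make_grad_scaler is illustrative, not in the commit; it reuses the torch/torch_npu imports above):

def make_grad_scaler(precision):
    if precision != "amp":
        return None
    if torch_npu is not None and torch.npu.is_available():
        from torch.npu.amp import GradScaler  # provided by torch_npu
    else:
        from torch.cuda.amp import GradScaler
    return GradScaler()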
5 changes: 4 additions & 1 deletion src/training/precision.py
@@ -1,6 +1,9 @@
import torch
from contextlib import suppress

try:
    import torch_npu
except ImportError:
    torch_npu = None

def get_autocast(precision):
    if precision == 'amp':
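The body of get_autocast is collapsed in this view. Assuming torch_npu exposes torch.npu.amp.autocast (an assumption, not the commit's verbatim change), the NPU branch plausibly looks like this sketch:

def get_autocast(precision):
    if precision == 'amp':
        if torch_npu is not None and torch.npu.is_available():
            return torch.npu.amp.autocast  # assumed torch_npu API
        return torch.cuda.amp.autocast
    # bf16 and pure fp16/bf16 cases omitted in this sketch
    return suppress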
7 changes: 7 additions & 0 deletions src/training/profiler.py
@@ -1,6 +1,11 @@
import argparse

import torch
try:
    import torch_npu
except ImportError:
    torch_npu = None

import open_clip
import pandas as pd
from torch.utils.flop_counter import FlopCounterMode
@@ -133,6 +138,8 @@ def profile_model(model_name, batch_size=1, profiler='torch'):
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    elif torch_npu is not None and torch.npu.is_available():
        model = model.npu()

    if isinstance(model.visual.image_size, (tuple, list)):
        image_input_size = (3,) + tuple(model.visual.image_size[-2:])
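For context on the profiler path above: FlopCounterMode, imported at the top of the file, counts FLOPs for whatever device the model landed on. A minimal standalone example, independent of this commit:

import torch
from torch.utils.flop_counter import FlopCounterMode

model = torch.nn.Linear(512, 512)   # toy stand-in for the CLIP model
x = torch.randn(1, 512)

counter = FlopCounterMode(display=False)
with counter:
    model(x)
print(counter.get_total_flops())    # FLOPs for one forward pass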
