Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support musa #992

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions mmrotate/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mmcv.parallel import collate, scatter
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmengine.device import is_musa_available

from mmrotate.core import get_multiscale_patch, merge_results, slide_window

Expand Down Expand Up @@ -71,6 +72,8 @@ def inference_detector_by_patches(model,
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
if is_musa_available() and next(model.parameters()).is_musa:
data = scatter(data, [device])[0]
else:
for m in model.modules():
assert not isinstance(
Expand Down
13 changes: 11 additions & 2 deletions mmrotate/core/bbox/samplers/rotate_random_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from mmdet.core.bbox.samplers.base_sampler import BaseSampler
from mmdet.core.bbox.samplers.sampling_result import SamplingResult
from mmengine.device import is_cuda_available, is_musa_available

from ..builder import ROTATED_BBOX_SAMPLERS

Expand Down Expand Up @@ -48,8 +49,16 @@ def random_choice(self, gallery, num):

is_tensor = isinstance(gallery, torch.Tensor)
if not is_tensor:
gallery = torch.tensor(
gallery, dtype=torch.long, device=torch.cuda.current_device())
if is_cuda_available():
gallery = torch.tensor(
gallery,
dtype=torch.long,
device=torch.cuda.current_device())
elif is_musa_available():
gallery = torch.tensor(
gallery,
dtype=torch.long,
device=torch.musa.current_device())
perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
rand_inds = gallery[perm]
if not is_tensor:
Expand Down
25 changes: 15 additions & 10 deletions mmrotate/core/bbox/utils/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import torch
from mmengine.device import is_cuda_available, is_musa_available


class GaussianMixture():
Expand Down Expand Up @@ -49,6 +50,10 @@ def _init_params(self, mu_init=None, var_init=None):
if var_init is not None:
self.var_init = var_init

if is_musa_available():
device = 'musa'
elif is_cuda_available():
device = 'cuda'
if self.requires_grad:
if self.mu_init is not None:
assert torch.is_tensor(self.mu_init)
Expand All @@ -57,11 +62,11 @@ def _init_params(self, mu_init=None, var_init=None):
), 'Input mu_init does not have required tensor dimensions' \
' (%i, %i, %i)' % (
self.T, self.n_components, self.n_features)
self.mu = self.mu_init.clone().requires_grad_().cuda()
self.mu = self.mu_init.clone().requires_grad_().to(device)
else:
self.mu = torch.randn(
(self.T, self.n_components, self.n_features),
requires_grad=True).cuda()
requires_grad=True).to(device)

if self.var_init is not None:
assert torch.is_tensor(self.var_init)
Expand All @@ -72,16 +77,16 @@ def _init_params(self, mu_init=None, var_init=None):
(self.T, self.n_components,
self.n_features,
self.n_features)
self.var = self.var_init.clone().requires_grad_().cuda()
self.var = self.var_init.clone().requires_grad_().to(device)
else:
self.var = torch.eye(self.n_features).reshape(
(1, 1, self.n_features, self.n_features))\
.repeat(self.T, self.n_components, 1, 1)\
.requires_grad_().cuda()
.requires_grad_().to(device)

self.pi = torch.empty(
(self.T, self.n_components,
1)).fill_(1. / self.n_components).requires_grad_().cuda()
1)).fill_(1. / self.n_components).requires_grad_().to(device)
else:
if self.mu_init is not None:
assert torch.is_tensor(self.mu_init)
Expand All @@ -90,10 +95,10 @@ def _init_params(self, mu_init=None, var_init=None):
), 'Input mu_init does not have required tensor dimensions' \
' (%i, %i, %i)' % (
self.T, self.n_components, self.n_features)
self.mu = self.mu_init.clone().cuda()
self.mu = self.mu_init.clone().to(device)
else:
self.mu = torch.randn(
(self.T, self.n_components, self.n_features)).cuda()
(self.T, self.n_components, self.n_features)).to(device)

if self.var_init is not None:
assert torch.is_tensor(self.var_init)
Expand All @@ -104,15 +109,15 @@ def _init_params(self, mu_init=None, var_init=None):
(self.T, self.n_components,
self.n_features,
self.n_features)
self.var = self.var_init.clone().cuda()
self.var = self.var_init.clone().to(device)
else:
self.var = torch.eye(self.n_features).reshape(
(1, 1, self.n_features,
self.n_features)).repeat(self.T, self.n_components, 1,
1).cuda()
1).to(device)

self.pi = torch.empty((self.T, self.n_components,
1)).fill_(1. / self.n_components).cuda()
1)).fill_(1. / self.n_components).to(device)

self.params_fitted = False

Expand Down
6 changes: 5 additions & 1 deletion mmrotate/datasets/dota.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch
from mmcv.ops import nms_rotated
from mmdet.datasets.custom import CustomDataset
from mmengine.device import is_cuda_available, is_musa_available

from mmrotate.core import eval_rbbox_map, obb2poly_np, poly2obb_np
from .builder import ROTATED_DATASETS
Expand Down Expand Up @@ -373,7 +374,10 @@ def _merge_func(info, CLASSES, iou_thr):
big_img_results.append(dets[labels == i])
else:
try:
cls_dets = torch.from_numpy(dets[labels == i]).cuda()
if is_cuda_available():
cls_dets = torch.from_numpy(dets[labels == i]).cuda()
elif is_musa_available():
cls_dets = torch.from_numpy(dets[labels == i]).musa()
except: # noqa: E722
cls_dets = torch.from_numpy(dets[labels == i])
nms_dets, keep_inds = nms_rotated(cls_dets[:, :5], cls_dets[:, -1],
Expand Down
16 changes: 11 additions & 5 deletions mmrotate/models/losses/convex_giou_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import torch.nn as nn
from mmcv.ops import convex_giou
from mmengine.device import is_cuda_available, is_musa_available
from torch.autograd import Function
from torch.autograd.function import once_differentiable

Expand Down Expand Up @@ -229,11 +230,16 @@ def forward(ctx,

target_aspect = AspectRatio(target)
smooth_loss_weight = torch.exp((-1 / 4) * target_aspect)
loss = \
smooth_loss_weight * (diff_mean_loss.reshape(-1, 1).cuda() +
diff_corners_loss.reshape(-1, 1).cuda()) + \
1 - (1 - 2 * smooth_loss_weight) * convex_gious

if is_cuda_available():
loss = \
smooth_loss_weight * (diff_mean_loss.reshape(-1, 1).cuda() +
diff_corners_loss.reshape(-1, 1).cuda()) + \
1 - (1 - 2 * smooth_loss_weight) * convex_gious
elif is_musa_available():
loss = \
smooth_loss_weight * (diff_mean_loss.reshape(-1, 1).musa() +
diff_corners_loss.reshape(-1, 1).musa()) + \
1 - (1 - 2 * smooth_loss_weight) * convex_gious
if weight is not None:
loss = loss * weight
grad = grad * weight.reshape(-1, 1)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmrotate
known_third_party = PIL,cv2,e2cnn,matplotlib,mmcv,mmdet,numpy,pytest,pytorch_sphinx_theme,terminaltables,torch,ts,yaml
known_third_party = PIL,cv2,e2cnn,matplotlib,mmcv,mmdet,mmengine,numpy,pytest,pytorch_sphinx_theme,terminaltables,torch,ts,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

Expand Down