mindspore.nn.Sigmoid() causes the code to run successfully on the CPU but fails on the GPU #283

PhyllisJi opened this issue May 17, 2024
Environment

Hardware Environment(Ascend/GPU/CPU):

/device gpu

/device cpu

Software Environment:

  • MindSpore version (source or binary): 2.2.14 binary
  • Python version (e.g., Python 3.7.5): 3.9
  • OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
  • GCC/Compiler version (if compiled from source):

Describe the current behavior

The same code runs successfully on the CPU but fails to run on the GPU.

Describe the expected behavior

The running process and results should be consistent between CPU and GPU.

Steps to reproduce the issue

import mindspore
import numpy as np
import os

mindspore.context.set_context(device_target='CPU')
class Model_dvw8Cn7J9ffhLheAbuf_EV7kuo5GpDcJ(mindspore.nn.Cell):
    def __init__(self):
        super(Model_dvw8Cn7J9ffhLheAbuf_EV7kuo5GpDcJ, self).__init__()
        self.conv1_mutated = mindspore.nn.Conv2dTranspose(in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), output_padding=(0, 0), dilation=(1, 1), group=1, has_bias=True)
        self.relu1 = mindspore.nn.ReLU()
        self.pool1_mutated = mindspore.nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 2), pad_mode="pad", padding=(0, 0), dilation=1, return_indices=False, ceil_mode=False, data_format="NCHW")
        self.conv2_mutated = mindspore.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), stride=(8, 8), pad_mode="pad", padding=(0, 0, 0, 0), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.relu2 = mindspore.nn.ReLU()
        self.pool2 = mindspore.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), pad_mode="pad", padding=(0, 0), dilation=1, return_indices=False, ceil_mode=False, data_format="NCHW")
        self.flatten = mindspore.nn.Flatten(start_dim=1, end_dim=-1)
        self.linear1_mutated = mindspore.nn.Dense(in_channels=32, out_channels=120)
        self.relu3 = mindspore.nn.ReLU()
        self.linear2 = mindspore.nn.Dense(in_channels=120, out_channels=84)
        self.relu4_mutated = mindspore.nn.Sigmoid()
        self.tail_flatten = mindspore.nn.Flatten(start_dim=1, end_dim=-1)
        self.tail_fc = mindspore.nn.Dense(in_channels=84, out_channels=10)

    def construct(self, input):
        conv1_output = self.conv1_mutated(input)
        relu1_output = self.relu1(conv1_output)
        maxpool1_output = self.pool1_mutated(relu1_output)
        conv2_output = self.conv2_mutated(maxpool1_output)
        relu2_output = self.relu2(conv2_output)
        maxpool2_output = self.pool2(relu2_output)
        flatten_output = self.flatten(maxpool2_output)
        fc1_output = self.linear1_mutated(flatten_output)
        relu3_output = self.relu3(fc1_output)
        fc2_output = self.linear2(relu3_output)
        relu4_output = self.relu4_mutated(fc2_output)
        tail_flatten_output = self.tail_flatten(relu4_output)
        tail_fc_output = self.tail_fc(tail_flatten_output)

        return tail_fc_output


def go():
    try:
        ms_model = Model_dvw8Cn7J9ffhLheAbuf_EV7kuo5GpDcJ()
        ms_input = mindspore.Tensor(np.random.randn(1, 1, 28, 28).astype(np.float32))
        ms_output = ms_model(ms_input)
        flag = True
    except Exception:
        flag = False
    return flag


def train(inp, label):
    ms_model = Model_dvw8Cn7J9ffhLheAbuf_EV7kuo5GpDcJ()
    initialize(ms_model)
    ms_input = mindspore.Tensor(inp.astype(np.float32))
    def forward_fn(label):
        ms_output = ms_model(ms_input)
        label = label.astype(np.int32)
        ms_targets = mindspore.Tensor(label)
        loss = mindspore.nn.CrossEntropyLoss(reduction='mean')(ms_output, ms_targets)
        return loss, ms_output

    (ms_loss, ms_output), ms_gradients = mindspore.value_and_grad(forward_fn, None, ms_model.trainable_params(), has_aux=True)(label)
    ms_gradients_dic = {}
    for var, gradient in zip(ms_model.trainable_params(), ms_gradients):
        ms_gradients_dic.setdefault(var.name, gradient.numpy())
    return ms_gradients_dic, ms_loss.numpy().item(), ms_output.numpy()

def initialize(model):
    module_dir = os.path.dirname(__file__)
    for name, param in model.parameters_and_names():
        layer_name, matrix_name = name.rsplit('.', 1)
        matrix_path = module_dir + '/../initializer/' + layer_name + '/' + matrix_name + '.npz'
        data = np.load(matrix_path)
        data = data['matrix']
        weight_tensor = mindspore.Tensor(data).float()
        param.set_data(weight_tensor)
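
For convenience, here is a minimal driver sketch (not part of the original report; the run_once name and the random data are ours) that exercises the same value_and_grad call as train() but skips initialize(), so it does not need the external initializer .npz files:

def run_once():
    # Hypothetical helper for illustration only: build the model with its
    # default initialization and run the same value_and_grad call that
    # fails on the GPU in the log below.
    ms_model = Model_dvw8Cn7J9ffhLheAbuf_EV7kuo5GpDcJ()
    ms_input = mindspore.Tensor(np.random.randn(1, 1, 28, 28).astype(np.float32))
    ms_targets = mindspore.Tensor(np.random.randint(0, 10, size=(1,)).astype(np.int32))

    def forward_fn():
        ms_output = ms_model(ms_input)
        loss = mindspore.nn.CrossEntropyLoss(reduction='mean')(ms_output, ms_targets)
        return loss, ms_output

    (loss, output), grads = mindspore.value_and_grad(
        forward_fn, None, ms_model.trainable_params(), has_aux=True)()
    return loss.numpy().item()

Calling run_once() once with device_target='CPU' and once with device_target='GPU' gives a quick way to compare the two backends.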

To reproduce the failure, change mindspore.context.set_context(device_target='CPU') to
mindspore.context.set_context(device_target='GPU') and run the same script.
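
For example (a hedged variant of ours, not in the original script), the device could also be selected at run time so the same file can be run on both backends without editing the source:

import os
# MS_DEVICE_TARGET is our own hypothetical environment variable, not one
# recognized by MindSpore; it simply chooses which backend to request here.
device = os.environ.get('MS_DEVICE_TARGET', 'CPU')
mindspore.context.set_context(device_target=device)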

Related log / screenshot

When running on the GPU, we obtained this error:

Traceback (most recent call last):
  File "/mnt/AA_MoCoDiff/MoCoDiff/Components/performer.py", line 84, in perform
    grad, loss, output = case_file.train(inp, label)
  File "/mnt/AA_MoCoDiff/MoCoDiff/./tree/tree_LeNet_n4/11/399/mindspore_gpu/LeNet-11-399_mindspore_gpu.py", line 67, in train
    (ms_loss, ms_output), ms_gradients = mindspore.value_and_grad(forward_fn, None, ms_model.trainable_params(), has_aux=True)(label)
  File "/root/miniconda3/envs/mocodiff/lib/python3.9/site-packages/mindspore/ops/composite/base.py", line 625, in after_grad
    return grad_(fn_, weights)(*args, **kwargs)
  File "/root/miniconda3/envs/mocodiff/lib/python3.9/site-packages/mindspore/common/api.py", line 121, in wrapper
    results = fn(*arg, **kwargs)
  File "/root/miniconda3/envs/mocodiff/lib/python3.9/site-packages/mindspore/ops/composite/base.py", line 601, in after_grad
    pynative_executor.grad(fn, grad, weights, grad_position, *args, **kwargs)
  File "/root/miniconda3/envs/mocodiff/lib/python3.9/site-packages/mindspore/common/api.py", line 1249, in grad
    self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values()))
RuntimeError: key not found

Special notes for this issue

This may be related to mindspore.nn.Sigmoid(): after removing this layer, the code runs normally and there is no difference in behaviour between CPU and GPU.
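
For illustration only (our own sketch, not part of the original report), "removing this layer" amounts to a one-line change in __init__ that swaps the activation for a pass-through, leaving the rest of the network unchanged:

# Original line:
# self.relu4_mutated = mindspore.nn.Sigmoid()
# Hypothetical isolation step: replace the Sigmoid with an identity mapping.
self.relu4_mutated = mindspore.nn.Identity()

With this change in place, the behaviour reported above is the same on CPU and GPU.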
