Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tensor] add Parameter inheritance for ColoParameter #1041

Merged
merged 5 commits into from
May 30, 2022

Conversation

Wesley-Jzy
Copy link
Contributor

@Wesley-Jzy Wesley-Jzy commented May 30, 2022

image

实验了一下钻石继承。
由于mro机制的存在,Python的多继承中所有通过super()调用的子类的方法都是通过对类的mro列表使用bfs遍历,保证只调用一次,只需要多继承顺序中将ColoTensor写在Parameter左边即可。这里的test()方法加在了Parameter和ColoTensor中,ColoParameter没有overload这个方法,可以看到因为默认的overload是super().xxx,确实是ColoTensor被调用了而Parameter没有调用。
然后是init时候的对象属性冗余测试。init时候我们没有调用super而是直接使用torch的make_subclass,可以看到init只执行ColoParameter的方法,也不会重复去调用Parameter的init方法。
目前看下来这样写是没有副作用的。
一个简单的example。

class Grand(object):
    def __init__(self, name):
        print('Grand init')
        self.name = name

class Parent1(Grand):
    def __init__(self, name):
        print('Parent1 init')
        Grand.__init__(self, name)

class Parent2(Grand):
    def __init__(self, name):
        print('Parent2 init')
        Grand.__init__(self, name)

class Son(Parent1, Parent2):
    def __init__(self, name):
        print('Son init')
        super().__init__(name)

son = Son('test')

而Python钻石继承会出现多次调用问题的场景如下:

class Grand(object):
    def __init__(self, name):
        print('Grand init')
        self.name = name

class Parent1(Grand):
    def __init__(self, name):
        print('Parent1 init')
        Grand.__init__(self, name)

class Parent2(Grand):
    def __init__(self, name):
        print('Parent2 init')
        Grand.__init__(self, name)

class Son(Parent1, Parent2):
    def __init__(self, name):
        print('Son init')
        Parent1.__init__(self, name)
        Parent2.__init__(self, name)

son = Son('test')

只有显式指定调用多个来自多继承的方法才会重复调用。

@Wesley-Jzy
Copy link
Contributor Author

@feifeibear Add explanation for multi inheritance.



class ColoParameter(ColoTensor):
class ColoParameter(ColoTensor, torch.nn.Parameter):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can ColoParameter use torch_function?
Can you add a test for this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works.
image

Copy link
Contributor

@feifeibear feifeibear left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Wesley-Jzy Wesley-Jzy merged commit 7c530b9 into hpcaitech:main May 30, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants