deepcopy doesn't work with cplxmodule modules #18

Closed
pfeatherstone opened this issue Jul 6, 2021 · 3 comments

@pfeatherstone

pfeatherstone commented Jul 6, 2021

from copy import deepcopy
import cplxmodule as cplx

n1 = cplx.nn.CplxConv2d(1,1,1)
n2 = deepcopy(n1)

This fails with:

n2 = deepcopy(n1)
  File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.6/copy.py", line 274, in _reconstruct
    y = func(*args)
  File "/usr/lib/python3.6/copyreg.py", line 88, in __newobj__
    return cls.__new__(cls, *args)
TypeError: __new__() missing 1 required positional argument: 'real'

Process finished with exit code 1
@pfeatherstone
Author

Whereas something like:

import torch

n1 = torch.nn.Conv2d(1, 1, 1)
n2 = deepcopy(n1)

works fine.

@pfeatherstone
Author

I need this to do EMA:

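import math
from copy import deepcopy

import torch


# NOTE: is_parallel is not defined in the snippet as posted; this definition
# is an assumption about what the helper does.
def is_parallel(model):
    return isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))


class ModelEMA:
    def __init__(self, model, decay=0.9999, updates=0):
        # This deepcopy is the call that fails for cplxmodule layers.
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def to(self, device):
        self.device = device
        self.ema.to(device)

    def update(self, model):
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            ema_msd = self.ema.module.state_dict() if is_parallel(self.ema) else self.ema.state_dict()
            # Blend each floating-point entry in place: ema <- d * ema + (1 - d) * model
            for k, v in ema_msd.items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()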
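
For readers unfamiliar with the update rule in the class above, here is a minimal pure-Python sketch of the decay ramp and a single EMA step on a scalar. The names `make_decay_ramp`, `ema_step`, and `tau` are hypothetical and chosen only for illustration. The constants 0.9999 and 2000 are taken from the snippet.

```python
import math


def make_decay_ramp(decay=0.9999, tau=2000):
    """Return d(x) = decay * (1 - exp(-x / tau)).

    `tau` is a hypothetical name for the constant 2000 in the snippet.
    """
    return lambda x: decay * (1 - math.exp(-x / tau))


def ema_step(ema_value, model_value, d):
    """One EMA update on a scalar: ema <- d * ema + (1 - d) * model."""
    return d * ema_value + (1.0 - d) * model_value


ramp = make_decay_ramp()

# Early in training d is close to 0, so the EMA tracks the model almost
# exactly; as the update count grows, d approaches `decay`.
ema = 0.0
for step in range(1, 6):
    ema = ema_step(ema, 1.0, ramp(step))
```

Because the ramp starts near zero, the averaged weights are not dominated by the initial values during the first few hundred updates.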

@ivannz
Owner

ivannz commented Jul 6, 2021

I patched in support for copy.deepcopy and added some tests, which should resolve the problem. The issue was that the cplx.Cplx object did not implement the copying API outlined in the copy module.
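
To illustrate the kind of failure and fix described here, the following is a minimal sketch using only the standard library. It is not the actual cplxmodule patch. `BrokenPair`, `FixedPair`, and `try_deepcopy` are hypothetical names, and the assumption is that the error came from a class whose `__new__` requires an argument, as the traceback's `__new__() missing 1 required positional argument: 'real'` suggests. The default reconstruction path in `copy` calls `cls.__new__(cls)` with no extra arguments, which fails for such a class unless it defines `__copy__` and `__deepcopy__` (or `__getnewargs__`).

```python
import copy


class BrokenPair:
    """Hypothetical stand-in for a class whose __new__ requires arguments."""

    def __new__(cls, real, imag=None):
        self = super().__new__(cls)
        self.real = real
        self.imag = imag if imag is not None else [0.0] * len(real)
        return self


class FixedPair(BrokenPair):
    """Same class, plus the hooks that the copy module looks for."""

    def __copy__(self):
        # Shallow copy: the new object shares the underlying containers.
        return type(self)(self.real, self.imag)

    def __deepcopy__(self, memo):
        # Deep copy: recurse into the parts, passing memo along so that
        # objects referenced more than once are copied only once.
        return type(self)(copy.deepcopy(self.real, memo),
                          copy.deepcopy(self.imag, memo))


def try_deepcopy(obj):
    """Return (copy, None) on success or (None, error message) on failure."""
    try:
        return copy.deepcopy(obj), None
    except TypeError as exc:
        return None, str(exc)


broken_copy, broken_err = try_deepcopy(BrokenPair([1.0, 2.0]))
fixed_orig = FixedPair([1.0, 2.0], [3.0, 4.0])
fixed_copy, fixed_err = try_deepcopy(fixed_orig)
```

In this sketch `broken_err` carries a `TypeError` message about the missing `real` argument, while `fixed_copy` is an independent object whose parts no longer alias those of `fixed_orig`.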

ivannz added the bug label on Jul 7, 2021
ivannz self-assigned this on Jul 7, 2021