You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
# Minimal reproduction: build a cplxmodule complex 2D convolution layer and
# try to deep-copy it with the standard-library `copy.deepcopy`.
from copy import deepcopy
import cplxmodule as cplx
n1 = cplx.nn.CplxConv2d(1,1,1)
# On the cplxmodule version in this report, this call raised
# "TypeError: __new__() missing 1 required positional argument: 'real'".
n2 = deepcopy(n1)
This fails with the following traceback:
n2 = deepcopy(n1)
File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/usr/lib/python3.6/copy.py", line 280, in _reconstruct
state = deepcopy(state, memo)
File "/usr/lib/python3.6/copy.py", line 150, in deepcopy
y = copier(x, memo)
File "/usr/lib/python3.6/copy.py", line 240, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/usr/lib/python3.6/copy.py", line 306, in _reconstruct
value = deepcopy(value, memo)
File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/usr/lib/python3.6/copy.py", line 280, in _reconstruct
state = deepcopy(state, memo)
File "/usr/lib/python3.6/copy.py", line 150, in deepcopy
y = copier(x, memo)
File "/usr/lib/python3.6/copy.py", line 240, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/usr/lib/python3.6/copy.py", line 180, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/usr/lib/python3.6/copy.py", line 274, in _reconstruct
y = func(*args)
File "/usr/lib/python3.6/copyreg.py", line 88, in __newobj__
return cls.__new__(cls, *args)
TypeError: __new__() missing 1 required positional argument: 'real'
Process finished with exit code 1
The text was updated successfully, but these errors were encountered:
class ModelEMA:
    """Exponential moving average (EMA) of a model's ``state_dict``.

    Keeps a frozen deep copy of ``model`` (unwrapped from its parallel wrapper
    when ``is_parallel(model)`` is true). On every ``update`` call it blends each
    floating-point entry of the live model's ``state_dict`` into that copy::

        ema_v = d * ema_v + (1 - d) * model_v

    Here ``d = decay * (1 - exp(-updates / tau))`` ramps up from 0 towards
    ``decay``, so early updates follow the live model closely.

    Args:
        model: The module to track. It must support ``copy.deepcopy``.
        decay: Asymptotic EMA decay factor.
        updates: Number of EMA updates already performed (e.g. when resuming).
        tau: Time constant, in updates, of the decay ramp. Defaults to 2000,
            the value that was previously hard-coded.
    """

    def __init__(self, model, decay=0.9999, updates=0, tau=2000):
        # deepcopy preserves the model's dtype: an FP32 model gives an FP32 EMA.
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates  # number of EMA updates performed so far
        # Decay ramp: starts near 0 and approaches `decay` (helps early epochs).
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        self.device = None  # set by `to()`; None until the EMA is explicitly moved
        for p in self.ema.parameters():
            p.requires_grad_(False)  # the EMA copy is never trained directly

    def to(self, device):
        """Move the EMA copy to ``device``, remember it, and return ``self``."""
        self.device = device
        self.ema.to(device)
        return self

    def update(self, model):
        """Blend the current weights of ``model`` into the EMA copy in place.

        Increments ``self.updates``. Only floating-point ``state_dict`` entries
        are averaged; all other entries of the EMA copy are left unchanged.
        """
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()
            ema_msd = self.ema.module.state_dict() if is_parallel(self.ema) else self.ema.state_dict()
            for k, v in ema_msd.items():
                if v.dtype.is_floating_point:
                    v *= d
                    # The EMA may live on a different device than `model`
                    # (see `to`). Move the source tensor first, because an
                    # in-place op across devices raises a RuntimeError.
                    v += (1. - d) * msd[k].detach().to(v.device)
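I patched in support for `copy.deepcopy` and added some tests, which should resolve the problem. The issue was that the `cplx.Cplx` object did not implement the protocol the `copy` module relies on. Without a `__deepcopy__` method (or a `__getnewargs__` supplying constructor arguments), `copy` reconstructs the object via `copyreg.__newobj__`. That calls `Cplx.__new__` with no arguments, so it fails on the required `real` argument.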
This fails with
The text was updated successfully, but these errors were encountered: