Skip to content

Commit

Permalink
Merge pull request #19 from deeprender/master
Browse files Browse the repository at this point in the history
replace nn.Parameter with buffers
  • Loading branch information
fbcotter committed Oct 28, 2020
2 parents 9d7018b + adab9f6 commit 7f3163d
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 42 deletions.
2 changes: 1 addition & 1 deletion pytorch_wavelets/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# IMPORTANT: before release, remove the 'devN' tag from the release name
__version__ = '1.2.3'
__version__ = '1.2.4'
48 changes: 24 additions & 24 deletions pytorch_wavelets/dtcwt/transform2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,22 @@ def __init__(self, biort='near_sym_a', qshift='qshift_a',
self.mode = mode
if isinstance(biort, str):
h0o, _, h1o, _ = _biort(biort)
self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False)
self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False)
self.register_buffer('h0o', prep_filt(h0o, 1))
self.register_buffer('h1o', prep_filt(h1o, 1))
else:
self.h0o = torch.nn.Parameter(prep_filt(biort[0], 1), False)
self.h1o = torch.nn.Parameter(prep_filt(biort[1], 1), False)
self.register_buffer('h0o', prep_filt(biort[0], 1))
self.register_buffer('h1o', prep_filt(biort[1], 1))
if isinstance(qshift, str):
h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
self.h0a = torch.nn.Parameter(prep_filt(h0a, 1), False)
self.h0b = torch.nn.Parameter(prep_filt(h0b, 1), False)
self.h1a = torch.nn.Parameter(prep_filt(h1a, 1), False)
self.h1b = torch.nn.Parameter(prep_filt(h1b, 1), False)
self.register_buffer('h0a', prep_filt(h0a, 1))
self.register_buffer('h0b', prep_filt(h0b, 1))
self.register_buffer('h1a', prep_filt(h1a, 1))
self.register_buffer('h1b', prep_filt(h1b, 1))
else:
self.h0a = torch.nn.Parameter(prep_filt(qshift[0], 1), False)
self.h0b = torch.nn.Parameter(prep_filt(qshift[1], 1), False)
self.h1a = torch.nn.Parameter(prep_filt(qshift[2], 1), False)
self.h1b = torch.nn.Parameter(prep_filt(qshift[3], 1), False)
self.register_buffer('h0a', prep_filt(qshift[0], 1))
self.register_buffer('h0b', prep_filt(qshift[1], 1))
self.register_buffer('h1a', prep_filt(qshift[2], 1))
self.register_buffer('h1b', prep_filt(qshift[3], 1))

# Get the function to do the DTCWT
if isinstance(skip_hps, (list, tuple, ndarray)):
Expand Down Expand Up @@ -173,22 +173,22 @@ def __init__(self, biort='near_sym_a', qshift='qshift_a', o_dim=2,
self.mode = mode
if isinstance(biort, str):
_, g0o, _, g1o = _biort(biort)
self.g0o = torch.nn.Parameter(prep_filt(g0o, 1), False)
self.g1o = torch.nn.Parameter(prep_filt(g1o, 1), False)
self.register_buffer('g0o', prep_filt(g0o, 1))
self.register_buffer('g1o', prep_filt(g1o, 1))
else:
self.g0o = torch.nn.Parameter(prep_filt(biort[0], 1), False)
self.g1o = torch.nn.Parameter(prep_filt(biort[1], 1), False)
self.register_buffer('g0o', prep_filt(biort[0], 1))
self.register_buffer('g1o', prep_filt(biort[1], 1))
if isinstance(qshift, str):
_, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
self.g0a = torch.nn.Parameter(prep_filt(g0a, 1), False)
self.g0b = torch.nn.Parameter(prep_filt(g0b, 1), False)
self.g1a = torch.nn.Parameter(prep_filt(g1a, 1), False)
self.g1b = torch.nn.Parameter(prep_filt(g1b, 1), False)
self.register_buffer('g0a', prep_filt(g0a, 1))
self.register_buffer('g0b', prep_filt(g0b, 1))
self.register_buffer('g1a', prep_filt(g1a, 1))
self.register_buffer('g1b', prep_filt(g1b, 1))
else:
self.g0a = torch.nn.Parameter(prep_filt(qshift[0], 1), False)
self.g0b = torch.nn.Parameter(prep_filt(qshift[1], 1), False)
self.g1a = torch.nn.Parameter(prep_filt(qshift[2], 1), False)
self.g1b = torch.nn.Parameter(prep_filt(qshift[3], 1), False)
self.register_buffer('g0a', prep_filt(qshift[0], 1))
self.register_buffer('g0b', prep_filt(qshift[1], 1))
self.register_buffer('g1a', prep_filt(qshift[2], 1))
self.register_buffer('g1b', prep_filt(qshift[3], 1))

def forward(self, coeffs):
"""
Expand Down
10 changes: 5 additions & 5 deletions pytorch_wavelets/dwt/swt_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ def __init__(self, wave='db1', mode='zero', separable=True):
# Prepare the filters
if separable:
filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
self.g0_col = nn.Parameter(filts[0], requires_grad=False)
self.g1_col = nn.Parameter(filts[1], requires_grad=False)
self.g0_row = nn.Parameter(filts[2], requires_grad=False)
self.g1_row = nn.Parameter(filts[3], requires_grad=False)
self.register_buffer('g0_col', filts[0])
self.register_buffer('g1_col', filts[1])
self.register_buffer('g0_row', filts[2])
self.register_buffer('g1_row', filts[3])
else:
filts = lowlevel.prep_filt_sfb2d_nonsep(
g0_col, g1_col, g0_row, g1_row)
self.h = nn.Parameter(filts, requires_grad=False)
self.register_buffer('h', filts)
self.mode = mode
self.separable = separable

Expand Down
24 changes: 12 additions & 12 deletions pytorch_wavelets/dwt/transform2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def __init__(self, J=1, wave='db1', mode='zero'):

# Prepare the filters
filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
self.h0_col = nn.Parameter(filts[0], requires_grad=False)
self.h1_col = nn.Parameter(filts[1], requires_grad=False)
self.h0_row = nn.Parameter(filts[2], requires_grad=False)
self.h1_row = nn.Parameter(filts[3], requires_grad=False)
self.register_buffer('h0_col', filts[0])
self.register_buffer('h1_col', filts[1])
self.register_buffer('h0_row', filts[2])
self.register_buffer('h1_row', filts[3])
self.J = J
self.mode = mode

Expand Down Expand Up @@ -98,10 +98,10 @@ def __init__(self, wave='db1', mode='zero'):
g0_row, g1_row = wave[2], wave[3]
# Prepare the filters
filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
self.g0_col = nn.Parameter(filts[0], requires_grad=False)
self.g1_col = nn.Parameter(filts[1], requires_grad=False)
self.g0_row = nn.Parameter(filts[2], requires_grad=False)
self.g1_row = nn.Parameter(filts[3], requires_grad=False)
self.register_buffer('g0_col', filts[0])
self.register_buffer('g1_col', filts[1])
self.register_buffer('g0_row', filts[2])
self.register_buffer('g1_row', filts[3])
self.mode = mode

def forward(self, coeffs):
Expand Down Expand Up @@ -175,10 +175,10 @@ def __init__(self, J=1, wave='db1', mode='periodization'):

# Prepare the filters
filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
self.h0_col = nn.Parameter(filts[0], requires_grad=False)
self.h1_col = nn.Parameter(filts[1], requires_grad=False)
self.h0_row = nn.Parameter(filts[2], requires_grad=False)
self.h1_row = nn.Parameter(filts[3], requires_grad=False)
self.register_buffer('h0_col', filts[0])
self.register_buffer('h1_col', filts[1])
self.register_buffer('h0_row', filts[2])
self.register_buffer('h1_row', filts[3])

self.J = J
self.mode = mode
Expand Down

0 comments on commit 7f3163d

Please sign in to comment.