Not preserving spatial dimensions #39
Ah yes, this will happen with odd-length signals. I think your input is quite small, particularly for a 3-scale transform, but let's consider a larger example and fewer scales:

```python
import torch
from pytorch_wavelets import DWTForward, DWTInverse

j = 1
wave = 'db1'
mode = 'symmetric'
layer0 = DWTForward(J=j, wave=wave, mode=mode)
layer1 = DWTInverse(wave=wave, mode=mode)

test_input = torch.arange(27 * 9).reshape(1, 3, 9, 9).to(torch.float32)
low, high = layer0(test_input)
test_output = layer1((low, high))
print(test_input.shape, test_output.shape)
```
>> torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 10, 10])

Let's look at the sizes of low and high:

```python
print(low.shape)
# >> torch.Size([1, 3, 5, 5])
print(high[0].shape)
# >> torch.Size([1, 3, 3, 5, 5])
```

What's happening here? Because we decimate by two as part of the transform, the signals need to be even length, so the input is effectively padded using the padding mode you've selected (symmetric). If we look at the reconstructed tensor, you'll see what happens:

```python
import numpy as np

np.set_printoptions(linewidth=120, suppress=True, precision=2)
print(test_output.numpy())
```
```
[[[[ -0.   1.   2.   3.   4.   5.   6.   7.   8.   8.]
  [  9.  10.  11.  12.  13.  14.  15.  16.  17.  17.]
  [ 18.  19.  20.  21.  22.  23.  24.  25.  26.  26.]
  [ 27.  28.  29.  30.  31.  32.  33.  34.  35.  35.]
  [ 36.  37.  38.  39.  40.  41.  42.  43.  44.  44.]
  [ 45.  46.  47.  48.  49.  50.  51.  52.  53.  53.]
  [ 54.  55.  56.  57.  58.  59.  60.  61.  62.  62.]
  [ 63.  64.  65.  66.  67.  68.  69.  70.  71.  71.]
  [ 72.  73.  74.  75.  76.  77.  78.  79.  80.  80.]
  [ 72.  73.  74.  75.  76.  77.  78.  79.  80.  80.]]
 [[ 81.  82.  83.  84.  85.  86.  87.  88.  89.  89.]
  [ 90.  91.  92.  93.  94.  95.  96.  97.  98.  98.]
  [ 99. 100. 101. 102. 103. 104. 105. 106. 107. 107.]
  [108. 109. 110. 111. 112. 113. 114. 115. 116. 116.]
  [117. 118. 119. 120. 121. 122. 123. 124. 125. 125.]
  [126. 127. 128. 129. 130. 131. 132. 133. 134. 134.]
  [135. 136. 137. 138. 139. 140. 141. 142. 143. 143.]
  [144. 145. 146. 147. 148. 149. 150. 151. 152. 152.]
  [153. 154. 155. 156. 157. 158. 159. 160. 161. 161.]
  [153. 154. 155. 156. 157. 158. 159. 160. 161. 161.]]
 [[162. 163. 164. 165. 166. 167. 168. 169. 170. 170.]
  [171. 172. 173. 174. 175. 176. 177. 178. 179. 179.]
  [180. 181. 182. 183. 184. 185. 186. 187. 188. 188.]
  [189. 190. 191. 192. 193. 194. 195. 196. 197. 197.]
  [198. 199. 200. 201. 202. 203. 204. 205. 206. 206.]
  [207. 208. 209. 210. 211. 212. 213. 214. 215. 215.]
  [216. 217. 218. 219. 220. 221. 222. 223. 224. 224.]
  [225. 226. 227. 228. 229. 230. 231. 232. 233. 233.]
  [234. 235. 236. 237. 238. 239. 240. 241. 242. 242.]
  [234. 235. 236. 237. 238. 239. 240. 241. 242. 242.]]]]
```

Note the duplicated last row and column in each channel. If you have to use odd-length signals, you can crop the top left of the output. But otherwise it's good to have an input whose size is an integer multiple of 2**J.
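The crop mentioned above can be sketched as a one-liner on the last two axes. This is a minimal illustration, not part of pytorch_wavelets, and the helper name `crop_to` is my own:

```python
import numpy as np

def crop_to(recon, ref_shape):
    """Keep the top-left corner of the last two (spatial) axes."""
    h, w = ref_shape[-2], ref_shape[-1]
    return recon[..., :h, :w]

# e.g. cropping a (1, 3, 10, 10) reconstruction back to the
# original (1, 3, 9, 9) input size
recon = np.zeros((1, 3, 10, 10))
print(crop_to(recon, (1, 3, 9, 9)).shape)  # (1, 3, 9, 9)
```

The same slicing works on torch tensors, since `...` and slice indexing behave the same way there.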
Good day!
Thanks a lot for your efforts with this lib!
Recently I encountered a problem with preserving the spatial dimensions of a tensor: I expected to get (1, 3, 3, 3) but got (1, 3, 4, 4).
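For what it's worth, these shapes are consistent with a single level of a decimated transform: an odd spatial size is padded to even before downsampling, so 3 becomes ceil(3/2) = 2 coefficients, which reconstruct to 2 × 2 = 4 samples. A minimal sketch of that bookkeeping, assuming one transform level (the helper names are mine, not part of pytorch_wavelets):

```python
import math

def coeff_size(n):
    """Coefficient length after one decimated DWT level: odd
    lengths are padded to even before halving, i.e. ceil(n / 2)."""
    return math.ceil(n / 2)

def recon_size(n):
    """One inverse level doubles the coefficient length, so an
    odd-length input comes back one sample longer."""
    return 2 * coeff_size(n)

print(coeff_size(3), recon_size(3))  # 2 4   (the shapes reported here)
print(coeff_size(9), recon_size(9))  # 5 10  (the 9x9 example above)
```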