Migration to Python 3.6, pytorch 0.4/1.0, pynvrtc 9.+
ip0001 committed Oct 7, 2018
1 parent ccdf8fb commit 0653ac7
Showing 4 changed files with 25 additions and 27 deletions.
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,8 +1,9 @@
 cupy
-torch >= 0.1.11
+torch >= 0.4
 numpy
 scikit-cuda
 scipy
-pynvrtc
+pynvrtc >= 9.0
 tqdm
 git+git://github.com/pytorch/tnt.git#egg=tnt
+torch_testing
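
For context, the new version floors can be sanity-checked at import time. The snippet below is only an illustrative sketch, not part of this commit; it assumes release-style version strings such as "0.4.1" or "1.0.0".

    # Illustrative environment check for the bumped requirements (not part of this commit).
    import torch
    import torch_testing  # new test-only dependency  # noqa: F401

    # Assumes a release-style version string like "0.4.1" or "1.0.0".
    major, minor = (int(p) for p in torch.__version__.split('.')[:2])
    assert (major, minor) >= (0, 4), "this branch targets pytorch >= 0.4"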
7 changes: 3 additions & 4 deletions scatwave/scattering.py
@@ -9,8 +9,7 @@
 import torch
 from .utils import cdgmm, Modulus, Periodize, Fft
 from .filters_bank import filters_bank
-from torch.legacy.nn import SpatialReflectionPadding as pad_function
-
+from torch.nn import ReflectionPad2d as pad_function
 
 class Scattering(object):
     """Scattering module.
@@ -44,7 +43,7 @@ def __init__(self, M, N, J, pre_pad=False, jit=True):
 
     def _type(self, _type):
         for key, item in enumerate(self.Psi):
-            for key2, item2 in self.Psi[key].iteritems():
+            for key2, item2 in self.Psi[key].items():
                 if torch.is_tensor(item2):
                     self.Psi[key][key2] = item2.type(_type)
         self.Phi = [v.type(_type) for v in self.Phi]
@@ -77,7 +76,7 @@ def _pad(self, input):
             output = input.new(input.size(0), input.size(1), input.size(2), input.size(3), 2).fill_(0)
             output.narrow(output.ndimension()-1, 0, 1).copy_(input)
         else:
-            out_ = self.padding_module.updateOutput(input)
+            out_ = self.padding_module(input)
             output = input.new(out_.size(0), out_.size(1), out_.size(2), out_.size(3), 2).fill_(0)
             output.narrow(4, 0, 1).copy_(out_.unsqueeze(4))
         return output
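
For readers migrating similar code: torch.legacy.nn modules needed an explicit updateOutput call, whereas torch.nn.ReflectionPad2d is simply called on the tensor. A minimal sketch of the new call follows; the pad size of 32 and the input shape are chosen purely for illustration and are not tied to a particular Scattering configuration.

    # Sketch of the padding API change; pad size and input shape are arbitrary examples.
    import torch
    from torch.nn import ReflectionPad2d as pad_function

    padding_module = pad_function(32)      # replaces torch.legacy.nn.SpatialReflectionPadding
    x = torch.randn(1, 3, 128, 128)

    # Old API:  out_ = padding_module.updateOutput(x)
    # New API:  the module is callable directly.
    out_ = padding_module(x)
    print(out_.shape)                      # torch.Size([1, 3, 192, 192])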
16 changes: 8 additions & 8 deletions scatwave/utils.py
@@ -92,7 +92,7 @@ def __call__(self, input, k):
             kernel = Template(kernel).substitute(B=B, H=H, W=W, k=k, Dtype=getDtype(input))
             name = str(input.get_device())+'-'+str(B)+'-'+str(k)+'-'+str(H)+'-'+str(W)+'-periodize.cu'
             print(name)
-            prog = Program(kernel, name.encode())
+            prog = Program(kernel, name)
             ptx = prog.compile(['-arch='+get_compute_arch(input)])
             module = Module()
             module.load(bytes(ptx.encode()))
@@ -120,7 +120,7 @@ def GET_BLOCKS(self, N):
     def __call__(self, input):
         if not self.jit or not isinstance(input, torch.cuda.FloatTensor):
             norm = input.norm(2, input.dim() - 1)
-            return torch.cat([norm, norm.new(norm.size()).zero_()], input.dim() - 1)
+            return torch.stack([norm, norm.new(norm.size()).zero_()], -1)
 
         out = input.new(input.size())
         input = input.contiguous()
@@ -129,7 +129,7 @@ def __call__(self, input):
             raise TypeError('The input and outputs should be complex')
 
         if (self.modulus_cache[input.get_device()] is None):
-            kernel = b"""
+            kernel = """
            extern "C"
            __global__ void abs_complex_value(const float * x, float2 * z, int n)
            {
@@ -141,7 +141,7 @@ def __call__(self, input):
            }
            """
            print('modulus.cu')
-            prog = Program(kernel, b'modulus.cu')
+            prog = Program(kernel, 'modulus.cu')
            ptx = prog.compile(['-arch='+get_compute_arch(input)])
            module = Module()
            module.load(bytes(ptx.encode()))
@@ -257,11 +257,11 @@ def cdgmm(A, B, jit=True, inplace=False):
     if not jit or isinstance(A, (torch.FloatTensor, torch.DoubleTensor)):
         C = A.new(A.size())
 
-        A_r = A[..., 0].contiguous().view(-1, A.size(-2)*A.size(-3))
-        A_i = A[..., 1].contiguous().view(-1, A.size(-2)*A.size(-3))
+        A_r = A[..., 0]
+        A_i = A[..., 1]
 
-        B_r = B[...,0].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_i)
-        B_i = B[..., 1].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_r)
+        B_r = B[..., 0].unsqueeze(0)
+        B_i = B[..., 1].unsqueeze(0)
 
         C[..., 0].copy_(A_r * B_r - A_i * B_i)
         C[..., 1].copy_(A_r * B_i + A_i * B_r)
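
The cdgmm rewrite above drops the explicit view/expand_as reshuffling and relies on broadcasting for the complex pointwise product. A minimal sketch of the new CPU path, with shapes assumed purely for illustration: A is (batch, M, N, 2) and B is (M, N, 2), the last axis holding (real, imag).

    # Sketch of the broadcasting-based complex pointwise product (CPU path of cdgmm).
    # Shapes are assumptions for illustration: A is (batch, M, N, 2), B is (M, N, 2).
    import torch

    A = torch.randn(4, 8, 8, 2)            # batched complex input
    B = torch.randn(8, 8, 2)               # complex filter
    C = A.new(A.size())

    A_r, A_i = A[..., 0], A[..., 1]
    B_r = B[..., 0].unsqueeze(0)           # (1, M, N) broadcasts over the batch axis
    B_i = B[..., 1].unsqueeze(0)

    C[..., 0].copy_(A_r * B_r - A_i * B_i) # real part
    C[..., 1].copy_(A_r * B_i + A_i * B_r) # imaginary part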
24 changes: 11 additions & 13 deletions test/test_scattering.py
@@ -2,12 +2,10 @@
 
 import torch
 import unittest
+import torch_testing as tt
 from scatwave.scattering import Scattering
 from scatwave import utils as sl
 
-def linfnorm(x,y):
-    return torch.max(torch.abs(x-y))
-
 class TestScattering(unittest.TestCase):
     def testFFTCentralFreq(self):
         # Checked the 0 frequency
@@ -21,7 +19,7 @@ def testFFTCentralFreq(self):
         fft = sl.Fft()
         fft(x, inplace=True)
         b = x[0,0,0]
-        self.assertAlmostEqual(a, b, places=6)
+        tt.assert_almost_equal(a.cpu(), b.cpu(), decimal=6)
 
     def testFFTCentralFreqBatch(self):
         # Same for batches
@@ -35,7 +33,7 @@ def testFFTCentralFreqBatch(self):
         fft = sl.Fft()
         fft(x, inplace=True)
         c = x[:,0,0,0].sum()
-        self.assertEqual(a, c)
+        tt.assert_equal(a.cpu(), c.cpu())
 
     def testFFTUnormalized(self):
         # Check for a random tensor:
@@ -55,7 +53,7 @@ def testFFTUnormalized(self):
         z /= 17*3 # FFTs are unnormalized
 
 
-        self.assertAlmostEqual(linfnorm(x.select(3,0), z),0,places=6)
+        tt.assert_allclose(x.select(3,0).cpu(), z.cpu(), atol=1e-6)
 
 
 
@@ -67,9 +65,9 @@ def testModulus(self):
         x = torch.cuda.FloatTensor(100,10,4,2).copy_(torch.rand(100,10,4,2))
         y = modulus(x)
         u = torch.squeeze(torch.sqrt(torch.sum(x * x, 3)))
-        v = y.narrow(3, 0, 1)
+        v = y[..., 0]
 
-        self.assertLess((u - v).abs().max(), 1e-6)
+        tt.assert_allclose(u.cpu(), v.cpu(), atol=1e-6)
 
 
     def testPeriodization(self):
@@ -88,10 +86,10 @@ def testPeriodization(self):
             periodize = sl.Periodize(jit=jit)
 
             z = periodize(x, k=16)
-            self.assertLess((y - z).abs().max(), 1e-8)
+            tt.assert_allclose(y.cpu(), z.cpu(), atol=1e-8)
 
             z = periodize(x.cpu(), k=16)
-            self.assertLess((y.cpu() - z).abs().max(), 1e-8)
+            tt.assert_allclose(y.cpu(), z, atol=1e-8)
 
 
     # Check the CUBLAS routines
@@ -108,7 +106,7 @@ def testCublas(self):
                 y[i, :, :, 1] = x[i, :, :, 1] * filter[:, :, 0] + x[i, :, :, 0] *filter[:, :, 1]
             z = sl.cdgmm(x, filter, jit=jit)
 
-            self.assertLess((y-z).abs().max(), 1e-6)
+            tt.assert_allclose(y.cpu(), z.cpu(), atol=1e-6)
 
     def testScattering(self):
         data = torch.load('test/test_data.pt')
@@ -118,7 +116,7 @@ def testScattering(self):
         scat.cuda()
         x = x.cuda()
         S = S.cuda()
-        self.assertLess(((S - scat(x))).abs().max(), 1e-6)
+        tt.assert_allclose(S.cpu(), scat(x).cpu(), atol=1e-6)
 
         scat = Scattering(128, 128, 4, pre_pad=False, jit=False)
         Sg = []
@@ -134,7 +132,7 @@
         Sc = scat(x)
         """there are huge round off errors with fftw, numpy fft, cufft...
         and the kernels of periodization. We do not wish to play with that as it is meaningless."""
-        self.assertLess((Sg.cpu()-Sc).abs().max(), 1e-1)
+        tt.assert_allclose(Sg.cpu(), Sc.cpu(), atol=1e-1)
 
 
 
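
The test suite now leans on torch_testing instead of reducing tensors to Python scalars for unittest's assertions; the helpers compare whole tensors element-wise. A minimal usage sketch follows, with shapes and tolerance chosen arbitrarily.

    # Sketch of the torch_testing helpers the tests switch to; values are arbitrary.
    import torch
    import torch_testing as tt

    a = torch.rand(10, 10)
    b = a + 1e-8 * torch.rand(10, 10)      # tiny perturbation

    # Previously: self.assertLess((a - b).abs().max(), 1e-6)
    # Now: element-wise comparison with an absolute tolerance.
    tt.assert_allclose(a, b, atol=1e-6)
    tt.assert_equal(a, a)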
