### Minimal Example of the Scattering Transform

This notebook goes through the basic usage of the scattering transform.

Imports:

In [1]:
import torch
from scattering_transform.scattering_transform import ScatteringTransform2d, reduce_coefficients
from scattering_transform.filters import Morlet

First, create some mock data to test with. Note that this scattering transform only works on square inputs, and that the first dimension is a batch dimension.

In [2]:
batch_size = 32
field_size = 128
data = torch.randn((batch_size, field_size, field_size))
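Because of these input requirements, it can help to normalize inputs before calling the transform. The helper below is a minimal sketch and is hypothetical (`as_square_batch` is not part of `scattering_transform`). It adds a batch dimension to a single 2D field and rejects non-square fields, mirroring only the requirements stated above.

```python
import torch

def as_square_batch(x: torch.Tensor) -> torch.Tensor:
    """Return x as a (batch, N, N) tensor, or raise if it is not square.

    Hypothetical helper, not part of scattering_transform.
    """
    if x.dim() == 2:
        # A single field: add a leading batch dimension of size 1.
        x = x.unsqueeze(0)
    if x.dim() != 3:
        raise ValueError(f"expected (N, N) or (batch, N, N), got {tuple(x.shape)}")
    if x.shape[-1] != x.shape[-2]:
        raise ValueError(f"fields must be square, got {x.shape[-2]}x{x.shape[-1]}")
    return x

single = torch.randn(128, 128)
print(as_square_batch(single).shape)  # torch.Size([1, 128, 128])
```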

To compute the scattering transform, we need to specify which filters to use. The standard Morlet wavelets are built in; we just have to specify how many scales ($J$) and angles ($L$) to use to construct the wavelet bank $\psi_{jl}$.

In [3]:
num_scales = 4
num_angles = 4
wavelets = Morlet(field_size, num_scales, num_angles)

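To give an idea of what a wavelet bank $\psi_{jl}$ looks like, here is a minimal sketch of a Morlet-like bank built directly in Fourier space. It is an illustration only: `morlet_bank_fourier`, the central frequency `xi` and the width `sigma` are hypothetical choices, and the package's `Morlet` class may use different parameters and normalization. Each filter is a Gaussian centred on frequency $2^{-j}\xi(\cos\theta_l, \sin\theta_l)$ with $\theta_l = \pi l / L$, minus a small Gaussian at the origin so that every wavelet has zero mean.

```python
import math
import torch

def morlet_bank_fourier(size: int, J: int, L: int,
                        xi: float = 3 * math.pi / 4,
                        sigma: float = 1.0) -> torch.Tensor:
    """Fourier-space Morlet-like filters psi_hat[j, l], shape (J, L, size, size).

    Hypothetical sketch, not the package's Morlet implementation.
    """
    # Angular frequencies in FFT order (zero frequency at index 0).
    k = torch.fft.fftfreq(size) * 2 * math.pi
    kx, ky = torch.meshgrid(k, k, indexing="ij")
    # beta makes psi_hat vanish at k = 0, i.e. each wavelet has zero mean.
    beta = math.exp(-xi ** 2 / (2 * sigma ** 2))
    bank = torch.empty(J, L, size, size)
    for j in range(J):
        for l in range(L):
            theta = math.pi * l / L
            # psi_hat_jl(k) = psi_hat(2^j R_{-theta} k): dilate and rotate.
            u = 2 ** j * (math.cos(theta) * kx + math.sin(theta) * ky)
            v = 2 ** j * (-math.sin(theta) * kx + math.cos(theta) * ky)
            shifted = torch.exp(-((u - xi) ** 2 + v ** 2) / (2 * sigma ** 2))
            centred = torch.exp(-(u ** 2 + v ** 2) / (2 * sigma ** 2))
            bank[j, l] = shifted - beta * centred
    return bank

bank = morlet_bank_fourier(32, 4, 4)
print(bank.shape)  # torch.Size([4, 4, 32, 32])
```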


Next, we set up our scattering transform object that will use these wavelets:

In [4]:
st = ScatteringTransform2d(wavelets)

Running the scattering transform on our data is simple:

In [5]:
s0, s1, s2 = st.scattering_transform(data)
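For intuition about what this call computes, the sketch below is a naive reference implementation under the standard scattering definitions, which we assume (but this notebook does not confirm) the package follows: $s_0 = \langle I \rangle$, $s_1(j,l) = \langle |I \ast \psi_{jl}| \rangle$ and $s_2(j_1,l_1,j_2,l_2) = \langle ||I \ast \psi_{j_1 l_1}| \ast \psi_{j_2 l_2}| \rangle$ for $j_2 > j_1$. `naive_scattering` is hypothetical; it does no Fourier truncation and its normalization may differ from `ScatteringTransform2d`.

```python
import torch

def naive_scattering(fields: torch.Tensor, filters_hat: torch.Tensor):
    """fields: (B, N, N) real; filters_hat: (J, L, N, N) Fourier-space filters.

    Convolutions are done by FFT, so they are circular (periodic boundaries).
    """
    J, L = filters_hat.shape[:2]
    s0 = fields.mean(dim=(-2, -1))                              # (B,)
    f_hat = torch.fft.fft2(fields)                              # (B, N, N)
    # First layer: modulus of every wavelet convolution, (B, J, L, N, N).
    u1 = torch.fft.ifft2(f_hat[:, None, None] * filters_hat).abs()
    s1 = u1.mean(dim=(-2, -1))                                  # (B, J, L)
    u1_hat = torch.fft.fft2(u1)
    s2 = []
    for j1 in range(J):
        for j2 in range(j1 + 1, J):
            # (B, L1, 1, N, N) * (1, 1, L2, N, N) -> (B, L1, L2, N, N)
            prod = u1_hat[:, j1, :, None] * filters_hat[j2][None, None]
            u2 = torch.fft.ifft2(prod).abs()
            s2.append(u2.mean(dim=(-2, -1)).flatten(1))          # (B, L * L)
    s2 = torch.cat(s2, dim=1) if s2 else fields.new_zeros(fields.shape[0], 0)
    return s0, s1, s2

fields = torch.randn(2, 32, 32)
filters_hat = torch.rand(4, 4, 32, 32)  # placeholder filters, for shape checking only
s0, s1, s2 = naive_scattering(fields, filters_hat)
print(s0.shape, s1.shape, s2.shape)  # torch.Size([2]) torch.Size([2, 4, 4]) torch.Size([2, 96])
```

With $J = L = 4$ this gives $1 + 16 + 96 = 113$ coefficients per field, which is consistent with the unreduced output printed further down.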

The scattering transform returns the zeroth-order ($s_0$), first-order ($s_1$) and second-order ($s_2$) scattering coefficients as a tuple. To convert these into a single tensor, we can use the `reduce_coefficients` function. Three reduction schemes are available: no reduction (`reduction=None`, i.e. all the coefficients), rotational averaging (`'rot_avg'`, which averages over all rotations) and angular averaging (`'ang_avg'`, which averages over all separation angles).

In [6]:
s_all = reduce_coefficients(s0, s1, s2, reduction=None)
print('All:', s_all.shape)

s_rot = reduce_coefficients(s0, s1, s2, reduction='rot_avg')
print('Rot:', s_rot.shape)

s_ang = reduce_coefficients(s0, s1, s2, reduction='ang_avg')
print('Ang:', s_ang.shape)


All: torch.Size([32, 113])
Rot: torch.Size([32, 11])
Ang: torch.Size([32, 23])
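The printed sizes can be checked by counting. Assuming $s_1$ has one coefficient per $(j, l)$ and $s_2$ one per $(j_1 < j_2, l_1, l_2)$, the unreduced vector has $1 + JL + \tfrac{J(J-1)}{2}L^2 = 1 + 16 + 96 = 113$ entries. Averaging $s_1$ over $l$ and $s_2$ over both $l_1$ and $l_2$ leaves $1 + J + \tfrac{J(J-1)}{2} = 11$, matching the `'rot_avg'` size. The sketch below is hypothetical (`rot_avg_sketch` and the tensor layouts are assumptions); it only shows one reduction consistent with that shape, not how `reduce_coefficients` is implemented.

```python
import torch

def rot_avg_sketch(s0: torch.Tensor, s1: torch.Tensor, s2: torch.Tensor) -> torch.Tensor:
    """Assumed layouts: s0 (B,), s1 (B, J, L), s2 (B, P, L, L) with P = J(J-1)/2.

    Averages every angle index away and concatenates the results.
    """
    return torch.cat([s0[:, None], s1.mean(dim=-1), s2.mean(dim=(-2, -1))], dim=1)

B, J, L = 32, 4, 4
P = J * (J - 1) // 2  # number of scale pairs j1 < j2
s0, s1, s2 = torch.randn(B), torch.randn(B, J, L), torch.randn(B, P, L, L)
print(rot_avg_sketch(s0, s1, s2).shape)  # torch.Size([32, 11])
print(1 + J * L + P * L * L)             # 113
```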


This implementation is fast thanks to truncations in Fourier space, so we can compute many large scattering transforms quickly, even on a CPU!

In [7]:
data = torch.randn((128, 512, 512))

In [8]:
%%time
output = reduce_coefficients(*st.scattering_transform(data))

CPU times: total: 4.39 s
Wall time: 575 ms
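To illustrate the general idea behind Fourier-space truncation (this is not necessarily how the package does it): wavelets are band-limited, so after filtering, the Fourier modes outside a filter's band are close to zero. Keeping only the central block of the spectrum then represents the field on a smaller grid, and every later FFT is cheaper. `fourier_downsample` is a hypothetical helper written for this sketch.

```python
import torch

def fourier_downsample(field: torch.Tensor, out_size: int) -> torch.Tensor:
    """Keep only the lowest out_size x out_size Fourier modes of square fields.

    field: (..., N, N) real tensor. Hypothetical helper, not part of
    scattering_transform.
    """
    n = field.shape[-1]
    if field.shape[-2] != n or out_size > n:
        raise ValueError("expected square fields with out_size <= N")
    # Move the zero frequency to the centre so the low modes form one block.
    f = torch.fft.fftshift(torch.fft.fft2(field), dim=(-2, -1))
    start = n // 2 - out_size // 2
    f_crop = f[..., start:start + out_size, start:start + out_size]
    f_crop = torch.fft.ifftshift(f_crop, dim=(-2, -1))
    # ifft2 divides by out_size**2 instead of n**2; rescale to preserve the mean.
    return torch.fft.ifft2(f_crop).real * (out_size / n) ** 2

x = torch.randn(2, 64, 64)
print(fourier_downsample(x, 16).shape)  # torch.Size([2, 16, 16])
```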


If a GPU is available (and PyTorch was built with CUDA support), we can easily move the calculation to the GPU with the `to` method: the scattering transform object behaves like a `torch.nn.Module`.

In [9]:
device = torch.device('cuda')
st.to(device)
data_cuda = data.to(device)
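One common way a module makes precomputed filters follow `.to(device)` is to register them as buffers. The class below is a hypothetical stand-in (`FilterHolder` is not part of the package, and we do not know how `ScatteringTransform2d` stores its wavelets); it shows that buffers move with the module without becoming trainable parameters. The device falls back to the CPU so the sketch also runs without a GPU.

```python
import torch
from torch import nn

class FilterHolder(nn.Module):
    """Hypothetical module holding Fourier-space filters as a buffer."""

    def __init__(self, filters: torch.Tensor):
        super().__init__()
        # Buffers are moved/cast by .to() and saved in state_dict, but are not parameters.
        self.register_buffer("filters", filters)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, N, N) -> (B, 1, N, N) * (K, N, N) -> (B, K, N, N); x must be on the filters' device.
        return torch.fft.ifft2(torch.fft.fft2(x)[:, None] * self.filters).real

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
holder = FilterHolder(torch.ones(3, 8, 8)).to(device)
out = holder(torch.randn(2, 8, 8, device=device))
print(out.shape)              # torch.Size([2, 3, 8, 8])
print(holder.filters.device)
```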

In [10]:
%%time
output = reduce_coefficients(*st.scattering_transform(data_cuda))

CPU times: total: 781 ms
Wall time: 704 ms
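When comparing CPU and GPU timings, keep in mind that CUDA kernels run asynchronously and that the first GPU call also pays one-off initialization costs, so a single `%%time` run can be misleading. A minimal sketch of a fairer measurement follows; `time_call` is a hypothetical helper that does warm-up calls and synchronizes the GPU before reading the clock.

```python
import statistics
import time

import torch

def time_call(fn, *args, repeats: int = 3, warmup: int = 1) -> float:
    """Median wall time in seconds of fn(*args) over `repeats` runs."""
    def sync():
        # Wait for queued CUDA kernels so the timer measures finished work.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(warmup):
        fn(*args)  # absorb one-off costs (CUDA context, kernel loading, caches)
    sync()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        sync()
        times.append(time.perf_counter() - start)
    return statistics.median(times)

# In this notebook one could time, e.g.:
# time_call(lambda d: reduce_coefficients(*st.scattering_transform(d)), data_cuda)
print(time_call(torch.fft.fft2, torch.randn(16, 256, 256)))
```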
