In [1]:
import math
import argparse
import sys
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable, Function

## Define Module

In [13]:
def fftfreqs(res, onesided=True, norm=True):
    """
    Build the grid of FFT sample frequencies for an n-dimensional signal.

    :param res: n_dims int tuple with the number of samples along each dimension
    :param onesided (bool): if True the last dimension uses the half spectrum
        produced by a real FFT (``np.fft.rfftfreq``); all other dimensions
        always use the full spectrum (``np.fft.fftfreq``)
    :param norm (bool): if True frequencies are scaled by 2*pi (radians per
        sample); otherwise ``d=1/n`` is used, which yields whole-number mode
        indices
    :return: float32 array of shape ``(*spectrum_shape, n_dims)``. For
        n_dims >= 2 the first two components are swapped, so ``[..., 0]`` holds
        the frequency along axis 1 and ``[..., 1]`` the frequency along axis 0.
    :raises ValueError: if ``res`` is empty
    """
    n_dims = len(res)
    if n_dims == 0:
        raise ValueError("res must contain at least one dimension")

    freqs = []
    for dim, r_ in enumerate(res):
        # Only the last axis may be one-sided, mirroring the layout of a real FFT.
        half_spectrum = onesided and dim == n_dims - 1
        freq_fn = np.fft.rfftfreq if half_spectrum else np.fft.fftfreq
        if norm:
            freqs.append(freq_fn(r_) * 2 * np.pi)
        else:
            freqs.append(freq_fn(r_, d=1 / r_))

    omega = list(np.meshgrid(*freqs, indexing='ij'))
    # Swap the first two components (axis-1 frequency first). A 1-D grid has
    # nothing to swap; the unconditional swap used to raise IndexError there.
    if n_dims >= 2:
        omega[0], omega[1] = omega[1], omega[0]

    return np.stack(omega, axis=-1).astype(np.float32)


class SpConv2d(nn.Module):
    """
    Spectral 2-D convolution layer.

    Each (out_channel, in_channel) kernel is a learned linear combination of a
    fixed basis of operators applied in the Fourier domain. The basis built by
    ``get_diff_operators`` is currently the identity and the Laplacian, so
    ``ncoeff == 2``. The spatial size of the output equals that of the input.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param degree: must be one of 0, 1, 2, 3; it is validated and stored but the
        operator basis does not depend on it at the moment
    :param bias: if True a learnable per-output-channel bias is added
    :param device: device on which the parameters are created. If a CUDA device
        is requested but CUDA is unavailable the layer falls back to the CPU
        instead of failing in the constructor.
    """

    def __init__(self, in_channels, out_channels, degree=2, bias=True, device='cuda'):
        assert degree in [0,1,2,3]
        super(SpConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.degree = degree

        requested = torch.device(device)
        if requested.type == 'cuda' and not torch.cuda.is_available():
            # Keep CPU-only runs (e.g. the notebook's --no-cuda path) working.
            requested = torch.device('cpu')
        self.device = requested

        # Bias and coefficients are created on the same device so the layer is
        # usable even without a subsequent ``.to(...)`` call.
        if bias:
            self.bias = Parameter(torch.empty(out_channels, device=self.device))
        else:
            self.register_parameter('bias', None)
        self.ncoeff = 2
        self.coeffs = Parameter(
            torch.empty(out_channels, in_channels, self.ncoeff, device=self.device))
        self.set_coeffs()

        # Operator basis is built lazily on the first forward pass because it
        # depends on the spatial size of the input.
        self.ops = None
        self.in_shape = None

    def set_coeffs(self):
        """Initialise coefficients (and bias) uniformly in [-1/sqrt(n), 1/sqrt(n)],
        where n = in_channels * ncoeff."""
        n = self.in_channels * self.ncoeff
        stdv = 1. / math.sqrt(n)
        self.coeffs.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def get_diff_operators(self, input):
        """
        Build the Fourier-domain operator basis for the spatial size of ``input``.

        Sets ``self.in_shape`` (last two input dims), ``self.k_shape`` (shape of
        the one-sided spectrum), ``self.ncoeff`` and ``self.ops``, a tensor of
        shape ``(ncoeff, *k_shape, 2)`` whose last axis is (real, imag). The
        basis is placed on the device of ``input``.
        """
        self.in_shape = list(input.size())[-2:]
        k = torch.tensor(fftfreqs(self.in_shape))
        self.k_shape = list(k.size())[:2]
        u, v = k[..., 0], k[..., 1]
        zeros = torch.zeros(*self.k_shape)
        ones = torch.ones(*self.k_shape)
        ops = [
            # I (identity)
            torch.stack([ones, zeros], dim=-1),
            # Nabla^2: the symbol of the Laplacian is -(u^2 + v^2), purely real
            torch.stack([-u**2-v**2, zeros], dim=-1),
        ]
        self.ncoeff = len(ops)
        # Follow the input's device rather than the constructor argument, which
        # goes stale once the module has been moved with ``.to(...)``.
        self.ops = torch.stack(ops, dim=0).to(input.device)

    @staticmethod
    def _rfft2(x):
        """One-sided 2-D real FFT returned as a real tensor with a trailing
        (real, imag) axis. Uses the legacy ``torch.rfft`` when it exists and
        ``torch.fft.rfft2`` otherwise."""
        if hasattr(torch, 'rfft'):
            return torch.rfft(x, 2)
        return torch.view_as_real(torch.fft.rfft2(x))

    @staticmethod
    def _irfft2(X, signal_sizes):
        """Inverse of ``_rfft2`` for a tensor with a trailing (real, imag) axis;
        ``signal_sizes`` is the spatial size of the real output."""
        if hasattr(torch, 'irfft'):
            return torch.irfft(X, 2, signal_sizes=signal_sizes)
        return torch.fft.irfft2(torch.view_as_complex(X.contiguous()), s=signal_sizes)

    def forward(self, input):
        """
        Apply the spectral convolution.

        ``input`` is indexed as (batch, in_channels, H, W) by the code below;
        the result has ``out_channels`` channels and the same H, W.
        """
        in_shape = list(input.size())[-2:]
        # (Re)build the basis when missing, or when the input's spatial size or
        # device no longer matches the cached one.
        if (self.ops is None or self.in_shape != in_shape
                or self.ops.device != input.device):
            self.get_diff_operators(input)

        F_input = self._rfft2(input)
        # Combine the basis into per-(out, in) spectral kernels:
        # (out, in, ncoeff, 1, 1, 1) * (ncoeff, *k_shape, 2) summed over ncoeff.
        c = self.coeffs.view(*self.coeffs.size(), 1, 1, 1)
        self.weight = torch.sum(torch.mul(c, self.ops), dim=2)

        # Convolution == complex multiplication in the Fourier domain, summed
        # over input channels (dim=2 after the unsqueeze).
        Fr, Fi = F_input[..., 0].unsqueeze(1), F_input[..., 1].unsqueeze(1)
        Wr, Wi = self.weight[..., 0], self.weight[..., 1]
        out_real = torch.sum(torch.mul(Fr, Wr) - torch.mul(Fi, Wi), dim=2)
        out_imag = torch.sum(torch.mul(Fr, Wi) + torch.mul(Fi, Wr), dim=2)
        F_output = torch.stack([out_real, out_imag], dim=-1)

        f_output = self._irfft2(F_output, self.in_shape)
        if self.bias is not None:
            f_output = f_output + self.bias.view(1, -1, 1, 1)
        return f_output

## Define Model

In [14]:
class Net(nn.Module):
    """
    Classifier made of two SpConv2d layers followed by two linear layers.

    The flattened feature size is fixed at 64*7*7, i.e. 64 channels on a 7x7
    grid after two 2x2 max-poolings (presumably 28x28 inputs such as MNIST --
    confirm against the data loader).
    """

    def __init__(self):
        super().__init__()
        self.csize = 64 * 7 * 7
        # Spectral layers take the place of ordinary 3x3 nn.Conv2d layers.
        self.conv1 = SpConv2d(1, 32, degree=2)
        self.conv2 = SpConv2d(32, 64, degree=2)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(self.csize, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        # Stage 1: spectral conv -> 2x2 max-pool -> ReLU
        h = self.conv1(x)
        h = F.relu(F.max_pool2d(h, 2))
        # Stage 2: spectral conv -> channel dropout -> 2x2 max-pool -> ReLU
        h = self.conv2_drop(self.conv2(h))
        h = F.relu(F.max_pool2d(h, 2))
        # Classifier head
        h = h.view(-1, self.csize)
        h = F.dropout(F.relu(self.fc1(h)), training=self.training)
        return F.log_softmax(self.fc2(h), dim=1)

## Train and Test

In [15]:
def train(args, model, device, train_loader, optimizer, epoch):
    """
    Run one training epoch with negative log-likelihood loss.

    :param args: object exposing ``log_interval`` (batches between status lines)
    :param model: module returning log-probabilities
    :param device: device the batches are moved to
    :param train_loader: iterable of (data, target) batches that supports
        ``len()`` and has a ``dataset`` attribute
    :param optimizer: optimizer updating ``model``'s parameters
    :param epoch: epoch number, used only in the status line
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % args.log_interval != 0:
            continue
        # The trailing '\r' makes successive status lines overwrite each other.
        seen = step * len(inputs)
        percent = 100. * step / len(train_loader)
        status = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} \r'.format(
            epoch, seen, len(train_loader.dataset), percent, loss.item())
        print(status, end='', flush=True)

def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%) \r'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

## Run

In [16]:
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of epochs to train (default: 50)')
parser.add_argument('--lr', type=float, default=1e-2, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
# Parse an explicit empty argument list so the defaults above are used and the
# notebook kernel's own command line is never inspected.
args = parser.parse_args([])
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

# Extra DataLoader options are only passed when running on the GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# The same ToTensor + Normalize pipeline is applied to both splits.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,)),
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

# One pass over the training set followed by a full evaluation, per epoch.
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

Test set: Average loss: 0.2821, Accuracy: 9139/10000 (91%) 
Test set: Average loss: 0.2079, Accuracy: 9393/10000 (94%) 
Test set: Average loss: 0.1978, Accuracy: 9374/10000 (94%) 
Test set: Average loss: 0.1457, Accuracy: 9533/10000 (95%) 
Test set: Average loss: 0.1346, Accuracy: 9563/10000 (96%) 
Test set: Average loss: 0.1226, Accuracy: 9609/10000 (96%) 
Test set: Average loss: 0.1098, Accuracy: 9655/10000 (97%) 
Test set: Average loss: 0.1015, Accuracy: 9686/10000 (97%) 
Test set: Average loss: 0.1005, Accuracy: 9672/10000 (97%) 
Test set: Average loss: 0.0951, Accuracy: 9706/10000 (97%) 
Test set: Average loss: 0.0937, Accuracy: 9723/10000 (97%) 
Test set: Average loss: 0.0901, Accuracy: 9708/10000 (97%) 
Test set: Average loss: 0.0852, Accuracy: 9722/10000 (97%) 
Test set: Average loss: 0.0845, Accuracy: 9733/10000 (97%) 
Test set: Average loss: 0.0824, Accuracy: 9741/10000 (97%) 
Test set: Average loss: 0.0814, Accuracy: 9755/10000 (98%) 
Test set: Average loss: 0.0777, Accuracy

In [None]:
import matplotlib.pyplot as plt

# Flatten each layer's coefficients to one row per (out_channel, in_channel)
# pair and one column per basis operator. The column count is read from the
# tensor itself: a hard-coded 10 does not divide the 2-coefficient tensors and
# made reshape() raise ValueError.
coeffs1 = model.conv1.coeffs.detach().cpu().numpy()
coeffs2 = model.conv2.coeffs.detach().cpu().numpy()
c1 = coeffs1.reshape(-1, coeffs1.shape[-1])
c2 = coeffs2.reshape(-1, coeffs2.shape[-1])

# Heat-maps of coefficient magnitudes; aspect='auto' keeps the very tall,
# narrow matrices legible.
plt.figure()
plt.imshow(np.absolute(c1), aspect='auto')
plt.colorbar()
plt.show()

plt.figure(figsize=(5, 30))
plt.imshow(np.absolute(c2), aspect='auto')
plt.colorbar()
plt.show()

In [None]:
# One bar chart per layer: mean absolute coefficient for each column of the
# flattened coefficient matrix.
for coeff_matrix in (c1, c2):
    mean_magnitude = np.absolute(coeff_matrix).mean(0)
    plt.figure()
    plt.bar(np.arange(coeff_matrix.shape[1]), mean_magnitude)
    plt.show()

In [None]:
# Map each parameter name to its tensor for inspection below.
params = {name: tensor for name, tensor in model.named_parameters()}

In [None]:
# List every learnable tensor by name together with its shape.
for k, param in params.items():
    print(k, "shape: ", param.shape)

In [None]:
import torch
from torch.autograd import Variable
from torch import nn
from torchviz import make_dot, make_dot_from_trace

# Only the first training batch is needed to trace the computation graph.
batch_idx, (data, target) = next(enumerate(train_loader))
data, target = data.to(device), target.to(device)

output = model(data)
# Render the autograd graph of the forward pass, labelled with parameter names.
make_dot(output, params=dict(model.named_parameters()))

In [None]:
# (Stray expression `z` removed: the name is never defined in this notebook, so
# the cell raised NameError under Restart Kernel -> Run All.)