In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import PIL
import numpy as np
import matplotlib.pylab as plt
import os

%matplotlib inline

In [3]:
import modules.custom_transformers as custom_transformers

In [4]:
class StackDCTs(nn.Module):
    
    @staticmethod
    def make_bases(length, num_bases):
        xs = torch.tensor(range(length)).type(torch.FloatTensor)
        bases = [
            torch.cos(np.pi * p * (2. * xs + 1) / (2 * length))
            for p in range(num_bases)
        ]

        def mesh_bases(b1, b2):
            rr, cc = torch.meshgrid([b1, b2])
            return rr * cc

        full_bases = torch.stack([
            mesh_bases(b1, b2)
            for b1 in bases
            for b2 in bases
        ])
        return full_bases

    def __init__(self, num_bases, lengths):
        super(StackDCTs, self).__init__()
        self.num_bases = num_bases
        self.lengths = lengths.copy()
        for length in self.lengths:
            buffer_name = 'basis_convolution_weights_{0}'.format(length)
            self.register_buffer(
                buffer_name,
                StackDCTs.make_bases(length, num_bases).repeat(3,1,1).unsqueeze(1)
            )
            
    def forward(self, minibatch):
        scales = []
        for length in self.lengths:
            buffer_name = 'basis_convolution_weights_{0}'.format(length)
            repeated_bases = self.state_dict()[buffer_name]
            left_padding = (repeated_bases.shape[-1] - 1) // 2
            right_padding = repeated_bases.shape[-1] - 1 - left_padding
            top_padding = (repeated_bases.shape[-2] - 1) // 2
            bottom_padding = repeated_bases.shape[-2] - 1 - top_padding
            minibatch_padded = F.pad(
                minibatch,
                (left_padding, right_padding, top_padding, bottom_padding)
            )
            scales.append(F.conv2d(
                input=minibatch_padded,
                weight=repeated_bases,
                groups=3
            ))
        return torch.stack(scales, dim=2)


In [23]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.stack = StackDCTs(3, [3, 4, 6, 8])
        self.pool_flat = nn.MaxPool3d(
            kernel_size=(1, 2, 2),
            stride=(1, 2, 2)
        )
        self.pool_deep = nn.MaxPool3d(
            kernel_size=2,
            stride=2
        )

        # Size is (n, 3 * (3 * 3), 4, 32, 32)
        self.conv1 = nn.Conv3d(
            in_channels=3 * (3*3),
            out_channels=64,
            kernel_size=3,
            padding=1,
        )
        self.conv2 = nn.Conv3d(
            in_channels=64,
            out_channels=128,
            kernel_size=3,
            padding=1,
        )
        self.conv3 = nn.Conv3d(
            in_channels=128,
            out_channels=128,
            kernel_size=3,
            padding=1,
        )
        self.conv4 = nn.Conv3d(
            in_channels=128,
            out_channels=256,
            kernel_size=3,
            padding=1,
        )
        self.fc1 = nn.Linear(256 * 2 * 2 * 2, 256)
        self.fc2 = nn.Linear(256, 10)
        
    def forward(self, x):
        x = self.stack(x)
        x = self.pool_flat(F.relu(self.conv1(x)))
        x = self.pool_flat(F.relu(self.conv2(x)))
        x = self.pool_flat(F.relu(self.conv3(x)))
        x = self.pool_deep(F.relu(self.conv4(x)))
        x = x.view(-1, 256 * 2 * 2 * 2)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [24]:
trainset_ycbcr = torchvision.datasets.CIFAR10(
    root='./image_files',
    train=True,
    download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.2, 0.2, 0.2)),
        custom_transformers.ToYCbYr()
    ])
)
trainloader_ycbcr = torch.utils.data.DataLoader(trainset_ycbcr, batch_size=16, shuffle=True, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [25]:
net = Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)

In [26]:
for epoch in range(5):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader_ycbcr, 0):
        # get the inputs
        inputs_cpu, labels_cpu = data
        inputs, labels = inputs_cpu.cuda(), labels_cpu.cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0


[1,   100] loss: 2.168
[1,   200] loss: 2.012
[1,   300] loss: 1.922
[1,   400] loss: 1.875
[1,   500] loss: 1.863
[1,   600] loss: 1.828
[1,   700] loss: 1.781
[1,   800] loss: 1.727
[1,   900] loss: 1.690
[1,  1000] loss: 1.679
[1,  1100] loss: 1.701
[1,  1200] loss: 1.609
[1,  1300] loss: 1.606
[1,  1400] loss: 1.552
[1,  1500] loss: 1.565
[1,  1600] loss: 1.633
[1,  1700] loss: 1.625
[1,  1800] loss: 1.563
[1,  1900] loss: 1.590
[1,  2000] loss: 1.556
[1,  2100] loss: 1.604
[1,  2200] loss: 1.549
[1,  2300] loss: 1.489
[1,  2400] loss: 1.547
[1,  2500] loss: 1.507
[1,  2600] loss: 1.530
[1,  2700] loss: 1.522
[1,  2800] loss: 1.621
[1,  2900] loss: 1.490
[1,  3000] loss: 1.497
[1,  3100] loss: 1.498
[2,   100] loss: 1.499
[2,   200] loss: 1.420
[2,   300] loss: 1.430
[2,   400] loss: 1.447
[2,   500] loss: 1.471
[2,   600] loss: 1.441
[2,   700] loss: 1.445
[2,   800] loss: 1.424
[2,   900] loss: 1.481
[2,  1000] loss: 1.528
[2,  1100] loss: 1.425
[2,  1200] loss: 1.441
[2,  1300] 

In [27]:
testset = torchvision.datasets.CIFAR10(root='./image_files', train=False,
                                       download=False,     transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.2, 0.2, 0.2)),
        custom_transformers.ToYCbYr()
    ])
)
testloader = torch.utils.data.DataLoader(testset, batch_size=200,
                                         shuffle=False, num_workers=2)


In [28]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images_cpu, labels_cpu = data
        images, labels = images_cpu.cuda(), labels_cpu.cuda()
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 47 %


In [29]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images_cpu, labels_cpu = data
        images, labels = images_cpu.cuda(), labels_cpu.cuda()
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))


Accuracy of plane : 40 %
Accuracy of   car : 52 %
Accuracy of  bird : 21 %
Accuracy of   cat : 47 %
Accuracy of  deer : 29 %
Accuracy of   dog : 46 %
Accuracy of  frog : 54 %
Accuracy of horse : 81 %
Accuracy of  ship : 44 %
Accuracy of truck : 85 %
