In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import PIL
import numpy as np
import matplotlib.pylab as plt
import os

%matplotlib inline

In [2]:
import modules.custom_transformers as custom_transformers

## TODO \#2: Make a transform that turns the image into a 4d block

### 27 channels (9 basis functions * 3 original channels) x scale x height x width


In [3]:
# Per-image preprocessing: PIL image -> tensor -> project-local colour transform
# (presumably RGB -> YCbCr, going by its name; confirm in modules/custom_transformers).
ycbcr_transform = transforms.Compose([
    transforms.ToTensor(),
    custom_transformers.ToYCbYr(),
])

# CIFAR-10 training split; download=False means the files must already be under ./image_files.
trainset_ycbcr = torchvision.datasets.CIFAR10(
    root='./image_files', train=True, download=False, transform=ycbcr_transform
)

# Unshuffled, single-process loader, so the first batch is the same on every run.
trainloader_ycbcr = torch.utils.data.DataLoader(
    trainset_ycbcr,
    batch_size=128,
    shuffle=False,
    num_workers=0,
)

# Keep the iterator in a variable so further batches can be pulled from it later.
gen_ycbcr = iter(trainloader_ycbcr)
images_cpu_ycbcr, labels_cpu_ycbcr = next(gen_ycbcr)

In [4]:
class StackDCTs(nn.Module):
    """Multi-scale bank of fixed (non-trainable) 2-D cosine filters.

    For every patch length in ``lengths`` the module holds ``num_bases ** 2``
    separable cosine filters of size ``length x length`` and applies them to
    each input channel independently (grouped convolution). The input is
    zero-padded so every scale keeps the input's height and width, and the
    per-scale results are stacked along a new dimension.

    Args:
        num_bases: number of cosine frequencies per axis (``num_bases ** 2``
            2-D filters per scale).
        lengths: iterable of filter side lengths, one per scale.
        num_channels: number of channels in the input images. Defaults to 3.
    """

    @staticmethod
    def make_bases(length, num_bases):
        """Return the 2-D cosine filters for one patch size.

        The 1-D bases are the unnormalised DCT-II cosines
        ``cos(pi * p * (2 * x + 1) / (2 * length))`` for ``p`` in
        ``range(num_bases)``; each 2-D filter is the outer product of a
        row-varying and a column-varying 1-D basis.

        Returns:
            Float tensor of shape ``(num_bases ** 2, length, length)``, ordered
            with the column frequency varying fastest.
        """
        xs = torch.arange(length, dtype=torch.float32)
        bases = [
            torch.cos(np.pi * p * (2. * xs + 1) / (2 * length))
            for p in range(num_bases)
        ]
        # Broadcasted outer product: entry [i, j] is b1[i] * b2[j]. Same values
        # as multiplying the two 'ij' meshgrid outputs, without calling
        # torch.meshgrid (which warns when `indexing` is not given).
        return torch.stack([
            b1.unsqueeze(1) * b2.unsqueeze(0)
            for b1 in bases
            for b2 in bases
        ])

    @staticmethod
    def _buffer_name(length):
        """Name of the registered buffer holding the filters for ``length``."""
        return 'basis_convolution_weights_{0}'.format(length)

    def __init__(self, num_bases, lengths, num_channels=3):
        super(StackDCTs, self).__init__()
        self.num_bases = num_bases
        self.num_channels = num_channels
        # list(...) copies the argument and also accepts tuples / ranges.
        self.lengths = list(lengths)
        for length in self.lengths:
            # conv2d weight layout for groups=num_channels:
            # (num_channels * num_bases**2, 1, length, length).
            # Registered as a buffer so it follows .to(device) but is not trained.
            self.register_buffer(
                self._buffer_name(length),
                StackDCTs.make_bases(length, num_bases)
                .repeat(num_channels, 1, 1)
                .unsqueeze(1)
            )

    def forward(self, minibatch):
        """Filter ``minibatch`` at every scale.

        Args:
            minibatch: tensor of shape ``(N, num_channels, H, W)``.

        Returns:
            Tensor of shape
            ``(N, num_channels * num_bases ** 2, len(lengths), H, W)``.
        """
        scales = []
        for length in self.lengths:
            # Read the buffer directly; state_dict() would rebuild a dict of
            # every buffer on each iteration just to look one up.
            weights = getattr(self, self._buffer_name(length))
            kernel_height, kernel_width = weights.shape[-2:]
            # Pad so that out = in + pad_before + pad_after - kernel + 1 == in.
            # For even kernels the extra pixel goes on the right / bottom.
            left_padding = (kernel_width - 1) // 2
            right_padding = kernel_width - 1 - left_padding
            top_padding = (kernel_height - 1) // 2
            bottom_padding = kernel_height - 1 - top_padding
            minibatch_padded = F.pad(
                minibatch,
                (left_padding, right_padding, top_padding, bottom_padding)
            )
            scales.append(F.conv2d(
                input=minibatch_padded,
                weight=weights,
                groups=self.num_channels
            ))
        return torch.stack(scales, dim=2)


In [7]:
# Build the filter bank (3 cosine frequencies per axis, seven patch sizes) and
# apply it to the first CPU mini-batch; the last line displays the result's shape.
sd = StackDCTs(3, [3, 4, 6, 8, 11, 16, 22])
my_block = sd(images_cpu_ycbcr)
my_block.shape

torch.Size([128, 27, 7, 32, 32])

In [8]:
# Same filter bank and mini-batch, moved to the GPU.
# NOTE(review): the device is hard-coded to 'cuda', so this cell needs a CUDA-capable machine.
sd_cuda = StackDCTs(3, [3, 4, 6, 8, 11, 16, 22]).to('cuda')
images_ycbcr = images_cpu_ycbcr.to('cuda')
my_block_cuda = sd_cuda(images_ycbcr)

In [11]:
# Run the CPU forward pass 10 times on the same batch; only the last result is kept.
# NOTE(review): looks like a leftover manual timing loop — confirm intent (%%timeit would fit better).
for i in range(10):
    my_block = sd(images_cpu_ycbcr)
my_block.shape

torch.Size([128, 27, 7, 32, 32])

In [12]:
# Run the GPU forward pass 10 times on the same batch; only the last result is kept.
# NOTE(review): looks like a leftover manual timing loop — confirm intent (%%timeit would fit better).
for i in range(10):
    my_block_cuda = sd_cuda(images_ycbcr)
my_block_cuda.shape

torch.Size([128, 27, 7, 32, 32])

In [None]:
def make_bases(length, num_bases):
    """Return ``num_bases ** 2`` 2-D cosine filters of size ``length x length``.

    Same construction as ``StackDCTs.make_bases``; kept as a free function for
    the exploratory cells below.

    The 1-D bases are the unnormalised DCT-II cosines
    ``cos(pi * p * (2 * x + 1) / (2 * length))`` for ``p`` in
    ``range(num_bases)``. Each 2-D filter is the outer product of two of them.

    Args:
        length: side length of each square filter.
        num_bases: number of cosine frequencies per axis.

    Returns:
        Float tensor of shape ``(num_bases ** 2, length, length)``, ordered
        with the column frequency varying fastest.
    """
    xs = torch.arange(length, dtype=torch.float32)
    bases = [
        torch.cos(np.pi * p * (2. * xs + 1) / (2 * length))
        for p in range(num_bases)
    ]

    def mesh_bases(b1, b2):
        # Broadcasted outer product: entry [i, j] is b1[i] * b2[j]. Same values
        # as multiplying the two 'ij' meshgrid outputs, without calling
        # torch.meshgrid (which warns when `indexing` is not given).
        return b1.unsqueeze(1) * b2.unsqueeze(0)

    full_bases = torch.stack([
        mesh_bases(b1, b2)
        for b1 in bases
        for b2 in bases
    ])
    return full_bases


In [None]:
# Show the num_bases x num_bases grid of 2-D cosine filters for two patch sizes.
num_bases = 3
for length in (6, 16):
    bases = make_bases(length, num_bases)
    fig, axes = plt.subplots(
        nrows=num_bases,
        ncols=num_bases,
        figsize=(6, 6),
        subplot_kw={'xticks': [], 'yticks': []},
    )
    # One filter per panel, in the order make_bases returns them.
    for ax, basis in zip(axes.flat, bases):
        ax.imshow(basis)
    plt.tight_layout()
    plt.show()
    print('-' * 60)


In [None]:
bases.shape

In [None]:
images_cpu_ycbcr.shape

In [None]:
bases.repeat(3,1,1).unsqueeze(1).shape

In [None]:
# Tile the filter stack 3 times along dim 0 and insert a singleton dim at index 1,
# giving the weight passed to the groups=3 convolution further down.
repeated_bases = bases.repeat(3,1,1).unsqueeze(1)

In [None]:
repeated_bases.shape

output width = input width + left padding + right padding - kernel width + 1

input width = images_cpu_ycbcr.shape\[-1\]

kernel width = repeated_bases.shape\[-1\]

want output width = input width

padding should be (kernel width - 1)/2

if kernel width is odd (so that this is a whole number), use it on both sides; then:

output width = input width + 2*(kernel width - 1)/2  - kernel width + 1 = input width

if kernel width is even, then use

left padding = kernel width / 2 - 1

right padding = kernel width / 2

then:

output width = input width + (kernel width / 2 - 1) + (kernel width / 2) - kernel width + 1

= input width + kernel width - 1 - kernel width + 1

= input width


in all cases, want left_padding + right_padding - kernel width + 1 == 0

so given left_padding, set right_padding = kernel width - 1 - left_padding

In [None]:
# Size-preserving padding: total padding per axis is kernel size - 1, with the
# extra pixel on the right / bottom when the kernel size is even.
kernel_height, kernel_width = repeated_bases.shape[-2:]
left_padding = (kernel_width - 1) // 2
right_padding = kernel_width - 1 - left_padding
top_padding = (kernel_height - 1) // 2
bottom_padding = kernel_height - 1 - top_padding
images_padded = F.pad(
    images_cpu_ycbcr,
    (left_padding, right_padding, top_padding, bottom_padding),
)

In [None]:
# Grouped convolution (3 groups) of the padded batch with the tiled filter stack.
maybe_convolved = F.conv2d(images_padded, repeated_bases, groups=3)

In [None]:
maybe_convolved.shape

In [None]:
images_cpu_ycbcr[0,1,:16,:16].sum()

In [None]:
maybe_convolved.shape

In [None]:
# Second filter set for comparison: 6x6 patches, 3 cosine frequencies per axis.
bases_2 = make_bases(6, 3)

In [None]:
# Tile the second filter stack 3 times along dim 0 and insert a singleton dim at
# index 1, giving the weight passed to the groups=3 convolution below.
repeated_bases_2 = bases_2.repeat(3,1,1).unsqueeze(1)

In [None]:
# Size-preserving padding for the second filter set: total padding per axis is
# kernel size - 1, with the extra pixel on the right / bottom for even kernels.
kernel_height_2, kernel_width_2 = repeated_bases_2.shape[-2:]
left_padding_2 = (kernel_width_2 - 1) // 2
right_padding_2 = kernel_width_2 - 1 - left_padding_2
top_padding_2 = (kernel_height_2 - 1) // 2
bottom_padding_2 = kernel_height_2 - 1 - top_padding_2
images_padded_2 = F.pad(
    images_cpu_ycbcr,
    (left_padding_2, right_padding_2, top_padding_2, bottom_padding_2),
)

In [None]:
# Grouped convolution (3 groups) of the padded batch with the second filter stack.
maybe_convolved_2 = F.conv2d(images_padded_2, repeated_bases_2, groups=3)

In [None]:
maybe_convolved_2.shape

In [None]:
maybe_convolved.mean(dim=(0,2,3))

In [None]:
maybe_convolved_2.mean(dim=(0,2,3))

In [None]:
# Stack the two filtered results along a new dim 2 and display the combined shape.
stacked_scales = torch.stack((maybe_convolved, maybe_convolved_2), dim=2)
stacked_scales.shape

In [None]:
# Inspect one registered filter bank. Buffers are named after the patch lengths
# given to the constructor, and `sd` has no length-5 bank, so take the key from
# `sd.lengths` instead of hard-coding a length (which raised KeyError).
sd.state_dict()['basis_convolution_weights_{0}'.format(sd.lengths[0])].shape