In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import PIL
import numpy as np
import matplotlib.pylab as plt
import os

%matplotlib inline

In [25]:
import modules.custom_transformers as custom_transformers

## TODO #1: Convert the images from RGB to YCbCr

In [3]:
# Load the CIFAR-10 training split from ./image_files and pull out one
# unshuffled batch of 16 RGB images (download=False, so the dataset is
# presumably expected to already be on disk).
trainset = torchvision.datasets.CIFAR10(
    './image_files',
    train=True,
    transform=transforms.ToTensor(),
    download=False,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)
gen = iter(trainloader)
images_cpu, labels_cpu = next(gen)

In [4]:
# Fifth image (index 4) of the first batch, laid out as (C, H, W).
one_image = images_cpu.select(0, 4)

In [5]:
# ToTensor() in the dataset transform should hand us a tensor, not a PIL image.
assert isinstance(one_image, torch.Tensor)
type(one_image)

torch.Tensor

In [6]:
# One CIFAR-10 sample: 3 channels of 32x32 pixels. Later cells rely on the
# channel axis being dim 0.
assert one_image.shape == (3, 32, 32)
one_image.shape

torch.Size([3, 32, 32])

In [7]:
# 3x3 RGB -> YCbCr matrix: each row produces one output channel, and the
# columns weight the R, G, B inputs. The name keeps its original "ycbyr"
# spelling because later cells reference it.
rgb_to_ycbyr = torch.tensor(
    [[0.299, 0.587, 0.114],     # Y  (luma)
     [-0.169, -0.331, 0.5],     # Cb (blue-difference, row sums to 0)
     [0.5, -0.419, -0.081]]     # Cr (red-difference, row sums to 0)
)

In [16]:
# Sanity check on a single colour with R = G = 0.5 and B = 0.
rgb_to_ycbyr @ torch.tensor([0.5, 0.5, 0.0])

tensor([ 0.4430, -0.2500,  0.0405])

In [17]:
# RGB channel values of the top-left pixel (row 0, column 0).
one_image[..., 0, 0]

tensor([0.6667, 0.7059, 0.7765])

In [19]:
# Reference value: convert just that pixel with a matrix-vector product.
rgb_to_ycbyr @ one_image[..., 0, 0]

tensor([ 0.7022,  0.0419, -0.0253])

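$R$ is the $3 \times 3$ RGB → YCbCr conversion matrix and $P$ is the picture tensor with shape $(C, H, W)$. Converting the whole image means applying $R$ to the channel vector of every pixel $(k, l)$:

$$y_{ikl} = \sum_j R_{ij} P_{jkl}$$

This is exactly what the `einsum` subscript string `'ij,jkl->ikl'` expresses.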

In [20]:
# Convert every pixel at once by contracting the matrix's column index j with
# the channel axis of the (C, H, W) image.
one_image_ycbcr = torch.einsum('ij,jkl->ikl', rgb_to_ycbyr, one_image)

In [21]:
# The whole-image einsum must agree with a direct matrix-vector product on the
# top-left pixel (small atol: summation order may differ in the last bits).
assert torch.allclose(
    one_image_ycbcr[:, 0, 0],
    torch.matmul(rgb_to_ycbyr, one_image[:, 0, 0]),
    atol=1e-6,
)
one_image_ycbcr[:, 0, 0]

tensor([ 0.7022,  0.0419, -0.0253])

In [22]:
class ToYCbYr:
    """Transform that converts an RGB image tensor to YCbCr.

    Applies ``rgb_to_ycbyr_matrix`` to the channel axis of every pixel:
    ``out[..., i, h, w] = sum_j M[i, j] * pic[..., j, h, w]``.

    Accepts a single image of shape (3, H, W) or a batch of shape
    (..., 3, H, W); the output has the same shape as the input. No offset is
    added, so the Cb and Cr channels are signed (they are 0 for gray pixels).

    The class name keeps its original "YCbYr" spelling because existing
    callers use it.
    """

    # Rows produce Y, Cb, Cr; columns weight the R, G, B inputs.
    rgb_to_ycbyr_matrix = torch.tensor([
        [0.299, 0.587, 0.114],
        [-0.169, -0.331, 0.5],
        [0.5, -0.419, -0.081],
    ])

    def __call__(self, pic):
        """Return ``pic`` converted from RGB to YCbCr.

        Args:
            pic: tensor of shape (..., 3, H, W) holding R, G, B channels.

        Returns:
            Tensor of the same shape as ``pic``. For floating-point input it
            also has ``pic``'s dtype and device.

        Raises:
            RuntimeError: from ``torch.einsum`` if the channel axis is not 3,
                or if ``pic`` is not floating point.
        """
        # Follow a floating-point input's dtype/device so float64 or CUDA
        # tensors work instead of hitting a mismatch inside einsum. Integer
        # input keeps the float32 matrix (and still fails, as before) rather
        # than truncating the coefficients to integers.
        dtype = pic.dtype if pic.is_floating_point() else self.rgb_to_ycbyr_matrix.dtype
        matrix = self.rgb_to_ycbyr_matrix.to(dtype=dtype, device=pic.device)
        # '...' lets the same contraction handle optional leading batch dims.
        return torch.einsum('ij,...jkl->...ikl', matrix, pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'
    

In [23]:
# The class must reproduce the einsum result over the whole image, not only
# the top-left pixel that is displayed.
assert torch.allclose(ToYCbYr()(one_image), one_image_ycbcr, atol=1e-6)
ToYCbYr()(one_image)[:, 0, 0]

tensor([ 0.7022,  0.0419, -0.0253])

In [26]:
# The packaged transform from modules.custom_transformers must agree with the
# notebook's own ToYCbYr over the whole image.
assert torch.allclose(
    custom_transformers.ToYCbYr()(one_image), ToYCbYr()(one_image), atol=1e-6
)
custom_transformers.ToYCbYr()(one_image)[:, 0, 0]

tensor([ 0.7022,  0.0419, -0.0253])

In [27]:
# Same CIFAR-10 training split and unshuffled 16-image batch as `trainset`,
# except each image goes through ToTensor and then the YCbCr conversion.
trainset_ycbcr = torchvision.datasets.CIFAR10(
    './image_files',
    train=True,
    transform=transforms.Compose([transforms.ToTensor(), custom_transformers.ToYCbYr()]),
    download=False,
)
trainloader_ycbcr = torch.utils.data.DataLoader(
    trainset_ycbcr, batch_size=16, shuffle=False, num_workers=0
)
gen_ycbcr = iter(trainloader_ycbcr)
images_cpu_ycbcr, labels_cpu_ycbcr = next(gen_ycbcr)

In [28]:
# First image of the YCbCr batch (channels Y, Cb, Cr).
images_cpu_ycbcr[0, ...]

tensor([[[ 2.4007e-01,  1.7643e-01,  1.8835e-01,  ...,  5.3740e-01,
           5.1157e-01,  5.0503e-01],
         [ 7.3741e-02,  0.0000e+00,  3.9522e-02,  ...,  3.7138e-01,
           3.5295e-01,  3.6880e-01],
         [ 9.3949e-02,  3.4875e-02,  1.2318e-01,  ...,  3.5408e-01,
           3.5642e-01,  3.1463e-01],
         ...,
         [ 6.7814e-01,  6.0308e-01,  6.1440e-01,  ...,  5.2506e-01,
           1.4015e-01,  1.4935e-01],
         [ 5.7395e-01,  5.0477e-01,  5.6299e-01,  ...,  5.9846e-01,
           2.7166e-01,  2.3453e-01],
         [ 5.9088e-01,  5.3596e-01,  5.7566e-01,  ...,  7.3942e-01,
           4.8624e-01,  3.8819e-01]],

        [[ 3.9490e-03,  2.7448e-05, -1.1129e-02,  ..., -6.4290e-02,
          -6.2992e-02, -5.7082e-02],
         [ 2.6510e-03,  0.0000e+00, -2.2314e-02,  ..., -8.7902e-02,
          -8.8565e-02, -8.2020e-02],
         [-6.5451e-03, -1.9690e-02, -5.1835e-02,  ..., -8.9200e-02,
          -9.0526e-02, -8.4643e-02],
         ...,
         [-1.7028e-01, -2

In [29]:
# Converting the first RGB image by hand must reproduce the first image of the
# YCbCr dataset. Both loaders use shuffle=False, so index 0 is the same sample.
manual_ycbcr = custom_transformers.ToYCbYr()(images_cpu[0])
assert torch.allclose(manual_ycbcr, images_cpu_ycbcr[0], atol=1e-6)
manual_ycbcr

tensor([[[ 2.4007e-01,  1.7643e-01,  1.8835e-01,  ...,  5.3740e-01,
           5.1157e-01,  5.0503e-01],
         [ 7.3741e-02,  0.0000e+00,  3.9522e-02,  ...,  3.7138e-01,
           3.5295e-01,  3.6880e-01],
         [ 9.3949e-02,  3.4875e-02,  1.2318e-01,  ...,  3.5408e-01,
           3.5642e-01,  3.1463e-01],
         ...,
         [ 6.7814e-01,  6.0308e-01,  6.1440e-01,  ...,  5.2506e-01,
           1.4015e-01,  1.4935e-01],
         [ 5.7395e-01,  5.0477e-01,  5.6299e-01,  ...,  5.9846e-01,
           2.7166e-01,  2.3453e-01],
         [ 5.9088e-01,  5.3596e-01,  5.7566e-01,  ...,  7.3942e-01,
           4.8624e-01,  3.8819e-01]],

        [[ 3.9490e-03,  2.7448e-05, -1.1129e-02,  ..., -6.4290e-02,
          -6.2992e-02, -5.7082e-02],
         [ 2.6510e-03,  0.0000e+00, -2.2314e-02,  ..., -8.7902e-02,
          -8.8565e-02, -8.2020e-02],
         [-6.5451e-03, -1.9690e-02, -5.1835e-02,  ..., -8.9200e-02,
          -9.0526e-02, -8.4643e-02],
         ...,
         [-1.7028e-01, -2