# Loading images and computing relevant statistics

Faisal Z. Qureshi     
http://vclab.science.uoit.ca

You can find excellent documentation for PyTorch at [https://pytorch.org/docs/stable/index.html](https://pytorch.org/docs/stable/index.html)

## Loading images

- [PIL image concepts](https://pillow.readthedocs.io/en/stable/handbook/concepts.html)
- [Numpy image concepts](https://scikit-image.org/docs/dev/user_guide/numpy_images.html)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

### Loading an RGB image

In [None]:
filename = './3063.jpg'
image = Image.open(filename)
print('Image:', image.size, image.getbands())
plt.imshow(image)
plt.show();

### Loading a single channel image (mask)

In [None]:
filename = './3063.png'
image = Image.open(filename)
print('Image:', image.size, image.getbands())
plt.imshow(image)
plt.show();

### Loading an RGBA image

In [None]:
filename = './frog.png'
image = Image.open(filename)
print('Image:', image.size, image.getbands())
plt.imshow(image)
plt.show();

### Loading a grayscale image

In [None]:
filename = './soldiers.jpg'
image = Image.open(filename)
print('Image:', image.size, image.getbands())
plt.imshow(image)
plt.show();

### Converting the image to a numpy array

In [None]:
import numpy as np

filenames = ['frog.png', '3063.png', '3063.jpg', 'soldiers.jpg']
image = Image.open(filenames[1])

numpy_im = np.array(image)
print('Image: ', 'Numpy shape=', numpy_im.shape, ' [PIL bands=', image.getbands(),']')
plt.imshow(numpy_im)
plt.show()
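
As a minimal sketch of how PIL bands map to NumPy shapes, the snippet below uses small synthetic images created with `Image.new` (an assumption made so that it runs without the files above). Note that PIL reports `size` as `(width, height)`, while the NumPy array is `(height, width)` or `(height, width, channels)`.

```python
import numpy as np
from PIL import Image

# Synthetic in-memory images stand in for the files used in this notebook.
size = (4, 3)  # PIL convention: (width, height)
shapes = {}
for mode in ('L', 'RGB', 'RGBA'):
    im = Image.new(mode, size)
    arr = np.array(im)
    shapes[mode] = arr.shape
    # NumPy convention: (height, width[, channels])
    print(mode, im.size, im.getbands(), arr.shape, arr.dtype)
```

A single-channel image yields a 2-D array, so code that indexes the channel dimension should check `arr.ndim` first.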

### Converting an image to torch tensor

#### Method 1

Image -> Numpy array -> Torch tensor 

In [None]:
import torch

t = torch.from_numpy(numpy_im)
print(t.shape)
plt.imshow(t)
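
The following sketch illustrates what Method 1 preserves. It uses a synthetic `uint8` array (an assumption, standing in for `numpy_im`) to show that `torch.from_numpy` keeps the shape and dtype of the array and shares its memory rather than copying it.

```python
import numpy as np
import torch

arr = np.zeros((3, 4, 4), dtype=np.uint8)  # stand-in for an H x W x C image array
t = torch.from_numpy(arr)

print(t.shape, t.dtype)   # layout and dtype carry over unchanged
arr[0, 0, 0] = 255        # write through the NumPy array ...
print(t[0, 0, 0].item())  # ... and the tensor sees the same memory
```

Because no scaling takes place, the tensor still holds integer values in the range 0 to 255 and keeps the channels-last layout.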

#### Method 2

Image -> Torch tensor

We use `torchvision.transforms.ToTensor()` to convert the PIL image to a tensor.  Note that torch stores images channels-first, i.e., `(channels, height, width)`.  This differs from the usual image layout used by NumPy and matplotlib, where channels are the last dimension, i.e., `(height, width, channels)`.  We therefore have to move the channel dimension to the end before we can display the tensor using matplotlib.  `ToTensor()` also converts 8-bit pixel values to `float` values between 0.0 and 1.0, which is the range matplotlib expects for floating-point images.

In [None]:
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor()])
t = transform(image)
print(t.shape)

t1 = t.transpose(0, 2); print(t1.shape)   # C, H, W -> W, H, C
t1 = t1.transpose(0, 1); print(t1.shape)  # W, H, C -> H, W, C
plt.imshow(t1)
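
The two transposes above can also be written as a single `permute` call. The sketch below checks this on a small synthetic tensor (the sizes are arbitrary and chosen only for illustration).

```python
import torch

c, h, w = 4, 2, 3
t = torch.arange(c * h * w, dtype=torch.float32).reshape(c, h, w)  # (C, H, W)

two_step = t.transpose(0, 2).transpose(0, 1)  # (C, H, W) -> (W, H, C) -> (H, W, C)
one_step = t.permute(1, 2, 0)                 # (C, H, W) -> (H, W, C) directly

print(two_step.shape, one_step.shape)
print(torch.equal(two_step, one_step))
```

Both forms return views of the same data, so which one to use is mostly a question of readability.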

## Computing Image Mean and Variance

We compute the mean and variance of the red, green, and blue channels (plus alpha, when the image has one).  Conceptually, an image is `(num_rows) x (num_columns) x (num_channels)`, where num_rows refers to the height of the image and num_columns refers to its width.  The tensor returned by `ToTensor()`, however, is `(num_channels) x (num_rows) x (num_columns)`.  We will use `torch.mean` to compute the mean value of each channel.  This operation is routinely performed during data preprocessing.

In [None]:
filenames = ['frog.png', '3063.png', '3063.jpg', 'soldiers.jpg']
image = Image.open(filenames[0])
t = transforms.Compose([transforms.ToTensor()])(image)
plt.imshow(t.transpose(0,2).transpose(0,1))

In [None]:
n_channels = t.shape[0]
print(f'Num of channels = {n_channels}')

# t is (C, H, W): flatten each channel into one row, giving (C, H*W)
t_flattened = t.reshape(n_channels, -1)
print(t_flattened.shape)

# Reduce along each row to get one value per channel
means = torch.mean(t_flattened, dim=1, keepdim=True)
print(means)
var = torch.var(t_flattened, dim=1, unbiased=True, keepdim=True)
print(var)
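
As a sanity check on per-channel statistics, here is a minimal sketch on a synthetic channels-first tensor (an assumption, used so the values are easy to verify by hand). It compares two equivalent ways of reducing over pixels and shows a layout pitfall: reshaping a `(C, H, W)` tensor to `(-1, C)` mixes values from different channels.

```python
import torch

# Synthetic channels-first "image": 3 channels, 2 rows, 4 columns.
t = torch.arange(24, dtype=torch.float32).reshape(3, 2, 4)
n_channels = t.shape[0]

# One row per channel, then reduce along each row.
per_channel = t.reshape(n_channels, -1)
means = per_channel.mean(dim=1)
variances = per_channel.var(dim=1)  # unbiased (divides by N - 1) by default

# Equivalent reduction over the two spatial dimensions.
means_hw = t.mean(dim=(1, 2))

# Pitfall: each column of this layout draws values from several channels.
mixed = t.reshape(-1, n_channels).mean(dim=0)

print(means, variances)
print(means_hw)
print(mixed)
```

Channel 0 holds the values 0 to 7, so its mean should be 3.5; the `mixed` result does not reproduce it.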

In [None]:
std_dev = torch.sqrt(var)
print(std_dev)

In [None]:
normalize_image = transforms.Normalize(means.squeeze(), std_dev.squeeze())
normalized_t = normalize_image(t)

In [None]:
plt.title('Plotting the normalized image')
tmp = normalized_t.transpose(0,2).transpose(0,1)[:,:,:3] # (C, H, W) -> (H, W, C), keeping only the first three channels
plt.imshow(tmp);
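
`transforms.Normalize` is documented to compute `output[channel] = (input[channel] - mean[channel]) / std[channel]`. The sketch below reproduces that arithmetic with plain broadcasting on a synthetic random tensor (an assumption, standing in for the image tensor) and checks that each channel ends up with mean close to 0 and standard deviation close to 1.

```python
import torch

torch.manual_seed(0)
t = torch.rand(3, 8, 8)  # synthetic (C, H, W) image with values in [0, 1)

means = t.mean(dim=(1, 2))
stds = t.std(dim=(1, 2))  # unbiased by default

# Broadcast the per-channel statistics over height and width.
normalized = (t - means[:, None, None]) / stds[:, None, None]

print(normalized.mean(dim=(1, 2)))  # close to 0 for every channel
print(normalized.std(dim=(1, 2)))   # close to 1 for every channel
print(normalized.min().item() < 0)  # values now fall outside [0, 1]
```

Since normalized values leave the range 0.0 to 1.0, matplotlib clips them when displaying a floating-point RGB image, so the plot of a normalized image will look different from the original.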

## Timing with and without CUDA

In [None]:
import time

In [None]:
is_cuda = torch.cuda.is_available()

In [None]:
print('Without cuda')
tim = t  # assumption: `tim` is the image tensor `t` computed above
start_time_1 = time.time()
for i in range(100):
    tim.var()
end_time_1 = time.time()
print(end_time_1 - start_time_1)

In [None]:
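if is_cuda:
    print('With cuda')
    tim_ = tim.cuda()
    torch.cuda.synchronize()  # make sure the copy to the GPU has finished
    start_time_2 = time.time()
    for i in range(100):
        tim_.var()
    torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the clock
    end_time_2 = time.time()
    print(end_time_2 - start_time_2)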
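
The timing cells above can be folded into one reusable function. This is a sketch under stated assumptions: `time_var` is a hypothetical helper name, and a random tensor stands in for the image tensor. It uses `time.perf_counter` and calls `torch.cuda.synchronize()` around the loop, because CUDA kernels are launched asynchronously and the clock would otherwise stop before the GPU has finished.

```python
import time
import torch

def time_var(x, repeats=100):
    """Hypothetical helper: seconds taken by `repeats` calls to x.var()."""
    if x.is_cuda:
        torch.cuda.synchronize()  # finish pending GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(repeats):
        x.var()
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for the queued kernels to finish
    return time.perf_counter() - start

x = torch.rand(3, 256, 256)  # synthetic stand-in for an image tensor
cpu_seconds = time_var(x)
print(f'CPU:  {cpu_seconds:.4f} s')

if torch.cuda.is_available():
    x_gpu = x.cuda()
    time_var(x_gpu, repeats=1)  # warm-up: the first CUDA call pays one-time setup costs
    print(f'CUDA: {time_var(x_gpu):.4f} s')
```

For a tensor as small as a single image, the GPU is not guaranteed to be faster; launch overhead can dominate the measurement.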