# Computer Vision

# 0. Computer vision libraries
* `torchvision` - base library for computer vision in PyTorch
* `torchvision.datasets` - datasets and data loading functions
* `torchvision.models` - pretrained computer vision models
* `torchvision.transforms` - functions for manipulating vision data
* `torch.utils.data.Dataset` - base dataset class for PyTorch
* `torch.utils.data.DataLoader` - creates a Python iterable over a dataset

In [None]:
import torch
from torch import nn

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt

# Report the installed library versions, one per line.
print(torch.__version__, torchvision.__version__, sep="\n")

# 1. Dataset

In [None]:
# FashionMNIST train and test splits. With download=True the files are
# fetched into `root` if they are not already present there.
train_data = datasets.FashionMNIST(
    root='data',  # directory where the data is stored / looked up
    train=True,  # training split
    download=True,
    transform=ToTensor(),  # convert each image to a torch tensor
    target_transform=None,  # leave labels untransformed
)

test_data = datasets.FashionMNIST(
    root='data',  # directory where the data is stored / looked up
    train=False,  # test split
    download=True,
    transform=ToTensor(),  # convert each image to a torch tensor
    target_transform=None,  # leave labels untransformed
)




In [None]:
tuple(map(len, (train_data, test_data)))  # (train size, test size)

In [None]:
# Peek at the first training sample, which unpacks into (image, label).
(image, label) = train_data[0]
(image, label)

In [None]:
# Class names, indexed by integer label (used as class_names[label] later on).
class_names = train_data.classes
class_names[0]  # name of the class with label 0

In [None]:
# Mapping from class name to label index; bind it and display it in one step.
(class_to_idx := train_data.class_to_idx)

In [None]:
# Report the image tensor layout and the human-readable label, one per line.
print(
    f"image.shape: {image.shape} -> [color_channels, height, width] ",
    f"Image label: {class_names[label]}",
    sep="\n",
)

# Visualizing data

In [None]:
image, label = train_data[0]
print(f"image shape: {image.shape}")
# imshow cannot draw a [color_channels, height, width] tensor directly;
# squeeze() drops the size-1 channel dimension so a single-channel image
# becomes [height, width].
plt.imshow(image.squeeze())
# Title with the class name (as the grid plot does), not the raw index.
plt.title(class_names[label])

In [None]:
# Same image rendered in grayscale, titled with its class name, axes hidden.
plt.imshow(image.squeeze(), cmap="gray")
plt.title(class_names[label])
plt.axis(False)

In [None]:
# plt more images
torch.manual_seed(42)
fig = plt.figure(figsize=(9,9))
rows, cols = 4, 4
for i in range(1, rows*cols+1):
  random_idx = torch.randint(0, len(train_data), size=[1]).item()
  img, label = train_data[random_idx]
  fig.add_subplot(rows, cols, i)
  plt.imshow(img.squeeze(), cmap="gray")
  plt.title(class_names[label])
  plt.axis(False)