# Handling images with PyTorch

## Image dataset

In [2]:
from torchvision.datasets import ImageFolder
from torchvision import transforms

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

import torch.nn as nn

In [None]:
# Baseline preprocessing: convert each PIL image to a tensor, then resize it
# to 128x128 so every sample in a batch has the same spatial size.
train_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((128, 128)),
    ]
)

# One class per sub-directory of the training folder.
dataset_train = ImageFolder(root="data/clouds_train", transform=train_transforms)

## Data augmentation in PyTorch

In [None]:
# Augmented training pipeline: random horizontal flips and random rotations
# (up to +/-45 degrees) are applied to the PIL image, which is then converted
# to a tensor and resized to 128x128.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    transforms.Resize((128, 128)),
])

dataset_train = ImageFolder(
    "data/clouds_train",
    transform=train_transforms,
)

dataloader_train = DataLoader(
    dataset_train, shuffle=True, batch_size=1,
)

# Preview one augmented sample. With batch_size=1 the batch has shape
# (1, C, 128, 128); drop only the batch axis (squeeze(0)) instead of every
# size-1 axis, then move channels last because imshow expects (H, W, C).
image, label = next(iter(dataloader_train))
image = image.squeeze(0).permute(1, 2, 0)
plt.imshow(image)
plt.show()

# Convolutional neural networks

## Building convolutional networks

In [3]:
class Net(nn.Module):
    """Small two-stage convolutional image classifier.

    Each stage is a 3x3 convolution with padding=1 (spatial size preserved),
    an ELU activation and a 2x2 max-pool that halves height and width. The
    flattened feature map feeds a single linear layer that outputs one logit
    per class. The first convolution expects 3 input channels.

    Args:
        num_classes: Number of output logits.
        image_size: Side length of square input images, or an
            ``(height, width)`` tuple. Defaults to 64, which reproduces the
            original ``64 * 16 * 16`` classifier input. Pass 128 for images
            resized with ``transforms.Resize((128, 128))``.

    Raises:
        ValueError: If either spatial dimension is smaller than 4, which
            would leave no features after the two pooling stages.
    """

    def __init__(self, num_classes, image_size=64):
        super().__init__()
        if isinstance(image_size, int):
            height, width = image_size, image_size
        else:
            height, width = image_size
        if height // 4 == 0 or width // 4 == 0:
            raise ValueError(
                f"image_size must be at least 4x4, got {height}x{width}"
            )

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ELU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ELU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
        )
        # Two floor-mode 2x2 pools shrink each spatial dim to dim // 4, and
        # the last conv emits 64 channels, so this is the flattened width.
        flat_features = 64 * (height // 4) * (width // 4)
        self.classifier = nn.Linear(flat_features, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for batch ``x``."""
        x = self.feature_extractor(x)
        x = self.classifier(x)
        return x