In [1]:
import os
import glob
import numpy as np
import cv2
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt


class Dataset5(Dataset):
    def __init__(self, root):
        """
        :param root: image folder
        """
        img_paths = glob.glob(os.path.join(root, '**/*.png'), recursive=True)
        img_paths.sort(key=str)
        data_array = np.empty(shape=(len(img_paths), 256, 256, 3), dtype=np.float32)
        for i, n in enumerate(img_paths):
            img = cv2.imread(n, cv2.IMREAD_REDUCED_COLOR_4)
            data_array[i, ...] = img
        self.data = np.moveaxis(data_array, -1, 1)
        self.root = root

    def __getitem__(self, item):
        img = self.data[item, ...]
        return img

    def __len__(self):
        return len(self.data)


### Autoencoder

class ConvAutoEncoder(nn.Module):
    def __init__(self, in_channels: int, hid_channels: int, code_channels: int,
                 kernel_size: int, stride: int = 1, pooling: int = 2,
                 activation: nn.Module = nn.ReLU()):

        super().__init__()

        self.encoder_hidden = nn.Conv2d(in_channels=in_channels,
                                        out_channels=hid_channels,
                                        kernel_size=kernel_size,
                                        stride=stride)
        self.encoder_pool = nn.AvgPool2d(kernel_size=pooling)
        self.encoder_output = nn.Conv2d(in_channels=hid_channels,
                                        out_channels=code_channels,
                                        kernel_size=kernel_size,
                                        stride=stride)

        self.act = activation

        self.decoder_hidden = nn.ConvTranspose2d(in_channels=code_channels,
                                                 out_channels=hid_channels,
                                                 kernel_size=kernel_size,
                                                 stride=stride)
        self.decoder_unpool = nn.ConvTranspose2d(in_channels=hid_channels,
                                                 out_channels=hid_channels,
                                                 kernel_size=pooling,
                                                 stride=pooling)

        self.decoder_unpool.weight.data.fill_(1 / (pooling * pooling))
        self.decoder_unpool.weight.requires_grad = False
        self.decoder_unpool.bias.requires_grad = False

        self.decoder_output = nn.ConvTranspose2d(in_channels=hid_channels,
                                                 out_channels=in_channels,
                                                 kernel_size=kernel_size,
                                                 stride=stride)

    def forward(self, x):
        x = self.encoder_hidden(x)
        x = self.act(x)
        x = self.encoder_pool(x)
        x = self.encoder_output(x)
        x = self.act(x)
        x = self.decoder_hidden(x)
        x = self.decoder_unpool(x)
        x = self.act(x)
        x = self.decoder_output(x)
        return x

    @torch.no_grad()
    def reconstruct(self, x: torch.Tensor) -> torch.Tensor:
        """
        x : torch.Tensor
            Inputs to be reconstructed.
        y : torch.Tensor
            Result of reconstruction, with values
            in the same range as the targets.
        """
        logits = self.forward(x)
        return torch.clamp(logits, x.min(), x.max())


def _forward(network: nn.Module, data: DataLoader, metric: callable):
    device = next(network.parameters()).device

    for x in data:
        x, y = x.to(device), x.to(device)
        logits = network(x)
        res = metric(logits, y)
        yield res


@torch.no_grad()
def evaluate(network: nn.Module, data: DataLoader, metric: callable) -> list:
    network.eval()

    results = _forward(network, data, metric)
    return [res.item() for res in results]


@torch.enable_grad()
def update(network: nn.Module, data: DataLoader, loss: nn.Module,
           opt: optim.Optimizer) -> list:
    network.train()

    errs = []
    for err in _forward(network, data, loss):
        errs.append(err.item())

        opt.zero_grad()
        err.backward()
        opt.step()

    return errs


def train_auto_encoder(auto_encoder: nn.Module, loader: DataLoader,
                       objective: nn.Module, optimiser: optim.Optimizer,
                       num_epochs: int = 10, vis_every: int = 5):

    # take random batch for visualising reconstructions
    #ref_inputs, _ = next(iter(loader))
    ref_inputs = next(iter(loader))

    # evaluate random performance
    errs = evaluate(auto_encoder, loader, objective)
    print(f"Epoch {0: 2d} - avg loss: {sum(errs) / len(errs):.6f}")
    #display_result(auto_encoder, ref_inputs)

    # train for some epochs
    for epoch in range(1, num_epochs + 1):
        errs = update(auto_encoder, loader, objective, optimiser)
        print(f"Epoch {epoch: 2d} - avg loss: {sum(errs) / len(errs):.6f}")

        #if epoch % vis_every == 0:
            #display_result(auto_encoder, ref_inputs)




root_dir = r"preprocessed_data"
dataset = Dataset5(root=root_dir)

### Hyperparameters

epochs = 10
batch_size = 7
lr = 1e-4

#examples_to_show = 8

### parameters for autoencoder
in_channels = 3
hid_channels = 128
code_channels = 8
kernel_size = 5
pooling = 2


### Start of training

print(dataset.root)
print(len(dataset))

train_loader = torch.utils.data.DataLoader(dataset, batch_size=7, shuffle=True, num_workers=0)

print(next(iter(train_loader)).shape)

loss_func = nn.MSELoss()

model = ConvAutoEncoder(in_channels = in_channels, hid_channels = hid_channels, code_channels = code_channels, kernel_size = kernel_size, pooling = pooling)
opt = optim.Adam(model.parameters(), lr = lr)

print("Code size:", int(code_channels*((28-kernel_size+1)/pooling-kernel_size+1)**2))

train_auto_encoder(auto_encoder = model, loader = train_loader,
                   objective = loss_func, optimiser = opt, num_epochs = epochs)

torch.save(model, 'trained_model.pt')
print("finished training")




preprocessed_data
0


ValueError: num_samples should be a positive integer value, but got num_samples=0