<h1 align="center">
  Nienadzorowana reprezentacja danych: autoenkodery i modele generatywne
</h1>

<h4 align="center">
  11.10.2023
</h4>
<br/>


# Nienadzorowana reprezentacja danych - AutoEnkoder (AE)
### Importowanie niezbędnych modułów

In [None]:
import pickle
import numpy as np
import torch
import torchvision
from torchvision import datasets, models, transforms
from tqdm.notebook import tqdm

## Prosta implementacja PCA za pomocą warstwy liniowej

In [None]:
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt


def _show_digit_row(images):
    """Draw ten flattened 8x8 digit images side by side and display the figure."""
    fig, axs = plt.subplots(nrows=1, ncols=10, figsize=(6, 6))
    for ax, image in zip(axs.ravel(), images):
        ax.imshow(image.reshape((8, 8)), cmap=plt.cm.gray)
        ax.axis("off")
    plt.show()
    plt.close()


# scikit-learn's 8x8 digits: each row of x is a flattened image, y holds labels.
mnist = load_digits(n_class=10)
x = mnist.data
y = mnist.target

# Fit a 10-component PCA on the whole data set.
pca = PCA(n_components=10)
pca.fit(x)

_show_digit_row(x[:10])

# The projection written out by hand: centre with the fitted mean, then
# project onto the principal axes (the rows of components_).
v = pca.transform(x[:10])
x_reduced = (x[:10] - pca.mean_) @ pca.components_.T
assert np.allclose(v, x_reduced)

# Mapping back to pixel space must agree with sklearn's inverse_transform.
x_original = x_reduced @ pca.components_ + pca.mean_
assert np.allclose(pca.inverse_transform(v), x_original)

_show_digit_row(x_original)

### PyTorch

In [None]:
class OwnPCA(torch.nn.Module):
    """PCA expressed as a pair of bias-free linear layers.

    ``transform`` projects inputs onto the principal axes and
    ``inverse_transform`` maps codes back to input space. Both layers have
    ``bias=False``, so the mean is neither subtracted nor added here: callers
    centre the data themselves and add the mean back after decoding.

    Subclassing ``torch.nn.Module`` makes ``parameters()``, ``.to(device)``
    and ``state_dict()`` work.

    Args:
        input_dim: Dimensionality of the input vectors.
        output_dim: Number of principal components (size of the code).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.encoder = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.decoder = torch.nn.Linear(output_dim, input_dim, bias=False)

    @property
    def dencoder(self):
        """Backward-compatible alias for the originally misspelled attribute name."""
        return self.decoder

    def set_weight(self, W):
        """Load a principal-component matrix into both layers.

        Args:
            W: Array-like of shape ``(output_dim, input_dim)`` whose rows are
                the principal axes (e.g. sklearn's ``PCA.components_``).
                NumPy arrays and torch tensors are accepted; values are cast
                to the layers' dtype and device.
        """
        w = torch.as_tensor(W)
        # Write the weights outside of autograd; the layers are not trained.
        with torch.no_grad():
            self.encoder.weight.copy_(w)
            self.decoder.weight.copy_(w.T)

    def transform(self, x):
        """Project (already centred) ``x`` onto the principal axes."""
        return self.encoder(x)

    def inverse_transform(self, x):
        """Map codes ``x`` back to (centred) input space."""
        return self.decoder(x)

In [None]:
def _plot_row_of_digits(flat_images):
    """Show ten flattened 8x8 images in a single row and display the figure."""
    _, axes = plt.subplots(nrows=1, ncols=10, figsize=(6, 6))
    for axis, flat in zip(axes.ravel(), flat_images):
        axis.imshow(flat.reshape((8, 8)), cmap=plt.cm.gray)
        axis.axis("off")
    plt.show()
    plt.close()


# Load the sklearn principal axes into the linear-layer PCA.
pca_own = OwnPCA(64, 10)
pca_own.set_weight(pca.components_)

# OwnPCA has no bias terms, so centring with the sklearn mean happens here.
centred = (x[:10] - pca.mean_).astype(np.float32)
v = pca_own.transform(torch.from_numpy(centred))

_plot_row_of_digits(x[:10])

# Decode, drop the autograd graph, and add the mean back before plotting.
x_ = pca_own.inverse_transform(v).detach().numpy() + pca.mean_
_plot_row_of_digits(x_)

## Uczenie sieci AE

Klasa ,,*AverageMeter*'' przechowuje oraz przetwarza częściowe wyniki zapisywane w poszczególnych etapach uczenia modelu. Funkcja ,,*count_parameters*'' zlicza parametry sieci podlegające uczeniu, zaś funkcja ,,*show*'' rysuje obrazki ze zbioru danych i ich rekonstrukcje.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Global seaborn look for the plots that follow: fonts scaled by 2.5, white grid.
sns.set(font_scale=2.5)
sns.set_style("whitegrid")


class AverageMeter:
    """Running (weighted) average of a scalar metric such as a loss.

    Attributes:
        val: The most recently recorded value.
        sum: Weighted sum of every recorded value.
        count: Total weight recorded so far.
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all recorded values."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the average."""
        self.val = val
        weighted = val * n
        self.sum = self.sum + weighted
        self.count = self.count + n
        self.avg = self.sum / self.count


def count_parameters(model):
    """Return the number of scalar weights in ``model`` that have ``requires_grad``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def _stack_for_grid(img, recon_img, num_col):
    """Return ``(images, ncol)`` ready to be tiled by ``make_grid``.

    Without reconstructions the (detached) batch is returned as-is. With
    reconstructions the first ``num_col`` originals are followed by the first
    ``num_col`` reconstructions, so a grid with ``num_col`` columns shows the
    originals in the top row and their reconstructions in the bottom row.
    """
    if recon_img is None:
        # make_grid cannot take nrow=None; fall back to its default of 8.
        return img.detach(), (8 if num_col is None else num_col)

    n = img.shape[0]
    if num_col is None:
        num_col = n
    if num_col > n:
        raise ValueError(f"num_col ({num_col}) exceeds the batch size ({n})")

    stacked = torch.empty((2 * num_col, *img.shape[1:]))
    # detach(): the reconstructions may still carry an autograd graph, and
    # .numpy() later refuses tensors that require grad.
    stacked[:num_col] = img[:num_col].detach().cpu()
    stacked[num_col:] = recon_img[:num_col].detach().cpu()
    return stacked, num_col


def show(img, recon_img, num_col=None):
    """Plot a batch of images (and optionally their reconstructions) as one grid.

    Args:
        img: Image batch, e.g. ``(N, C, H, W)`` as yielded by the data loaders.
        recon_img: Reconstructions of ``img`` with the same per-image shape,
            or ``None`` to plot ``img`` alone.
        num_col: Images per grid row. With ``recon_img`` it is also the number
            of image/reconstruction pairs drawn and defaults to the whole
            batch; without ``recon_img`` it defaults to 8.

    Raises:
        ValueError: If ``recon_img`` is given and ``num_col`` exceeds the
            batch size.
    """
    rec_images, ncol = _stack_for_grid(img, recon_img, num_col)

    plt.figure(figsize=[16, 8])
    grid = torchvision.utils.make_grid(
        rec_images, nrow=ncol, padding=1, normalize=True, scale_each=True
    )
    np_grid = grid.cpu().numpy()
    plt.axis("off")
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(np_grid, (1, 2, 0)), interpolation="nearest")

# Dataloader

W tej części przygotowujemy zbiór danych do trenowania i walidacji modelu. Przetwarzamy obrazki ze zbioru *MNIST* do tensorów, które są pobierane iteracyjnie w batchach podczas trenowania sieci (zmienne: ,,*train_loader*'', ,,*test_loader*'').

In [None]:
# MNIST is read from this directory; download=True lets torchvision fetch it there.
root = "../datasets"
download = True

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

# The only preprocessing step: convert each image to a tensor (no normalisation).
transform = transforms.Compose([transforms.ToTensor()])

# Training split: shuffled mini-batches of 128 images.
train_dataset = torchvision.datasets.MNIST(
    root, download=download, train=True, transform=transform
)
# NOTE(review): pin_memory could be (device.type == "cuda") to speed up host->GPU
# copies; num_workers=4 starts worker processes — consider 0 if DataLoader
# workers misbehave in the notebook environment.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=False
)

# Test split: fixed order (no shuffling), batches of 200; used for evaluation.
test_dataset = torchvision.datasets.MNIST(
    root, download=download, train=False, transform=transform
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=200, shuffle=False, num_workers=4, pin_memory=False
)

In [None]:
# Inspect one sample: indexing the dataset gives an (image, label) pair.
d = train_dataset[1]
print(f"Rozmiar obrazka: {d[0].shape} oraz jego etykieta {d[1]}")

# Pull one mini-batch from the training loader to check the batched shapes.
it = iter(train_loader)
d = next(it)
# Last expression of the cell: the notebook displays the batch shape and labels.
d[0].shape, d[1]

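Poniżej tworzymy dodatkowe klasy, których będziemy używać zarówno do budowy sieci neuronowej (klasa ,,*View*'' — zmiana kształtu tensora wewnątrz ,,*Sequential*''), jak i do jej uczenia (klasa ,,*LambdaLR*'' — harmonogram współczynnika uczenia ograniczony od dołu wartością ,,*min_val*'').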

In [None]:
class View(torch.nn.Module):
    """Wrap ``Tensor.view`` as a module so a reshape can sit inside ``nn.Sequential``.

    Args:
        *shape: Target shape handed unchanged to ``Tensor.view`` (``-1`` allowed).
    """

    def __init__(self, *shape) -> None:
        super().__init__()
        self.shape = shape

    def forward(self, input_x: torch.Tensor) -> torch.Tensor:
        target_shape = self.shape
        return input_x.view(*target_shape)


class LambdaLR(torch.optim.lr_scheduler.LambdaLR):
    """PyTorch ``LambdaLR`` whose learning rate never drops below ``min_val``.

    After every step each param group's lr is the value computed by the base
    ``LambdaLR`` (``base_lr * lr_lambda(epoch)``), raised to ``min_val`` when
    it falls below it. Setting ``self.change = False`` freezes the scheduler:
    further ``step()`` calls leave the learning rates untouched.

    Args:
        optimizer: Wrapped optimizer.
        lr_lambda: Multiplicative-factor function(s), as in the base class.
        last_epoch: Index of the last epoch; ``-1`` starts from scratch.
        verbose: If true, print every group's (clamped) lr after each step.
        min_val: Lower bound applied to every learning rate.
    """

    def __init__(
        self, optimizer, lr_lambda, last_epoch=-1, verbose=False, min_val=1e-5
    ):
        # These must exist before the base __init__, which already calls step().
        self.min_val = min_val
        self.change = True
        self._print_lr_updates = verbose

        # ``verbose`` is handled here rather than forwarded: recent PyTorch
        # releases deprecate that argument and warn when it is passed.
        super().__init__(optimizer, lr_lambda, last_epoch)

    def step(self, epoch=None):
        """Advance the schedule, then clamp every group's lr at ``min_val``."""
        if not self.change:
            return

        if epoch is None:
            super().step()
        else:
            # Explicit epochs are deprecated upstream; kept for compatibility.
            super().step(epoch)

        for i, param_group in enumerate(self.optimizer.param_groups):
            lr = param_group["lr"]
            param_group["lr"] = lr if lr > self.min_val else self.min_val
            if self._print_lr_updates:
                print(f"Adjusting learning rate of group {i} to {param_group['lr']:.4e}.")

        # Keep get_last_lr() consistent with the clamped values.
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


sns.set(font_scale=1.5)
sns.set_style("whitegrid")


def _record_lrs(optimizer, scheduler, epochs):
    """Step ``scheduler`` ``epochs`` times and return the lr of every epoch.

    The lr of the first param group is read after ``optimizer.step()`` and
    before ``scheduler.step()``, i.e. entry ``i`` is the lr used in epoch ``i``.
    """
    lrs = []
    for _ in range(epochs):
        optimizer.step()
        lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()
    return lrs


def _plot_lrs(lrs):
    """Plot a learning-rate schedule against the epoch index."""
    plt.figure(figsize=(10, 4))
    plt.plot(range(len(lrs)), lrs)
    plt.show()
    plt.close()


# Toy model: only its parameters are needed to build an optimizer.
model = torch.nn.Linear(2, 1)
epochs = 300
fun = lambda epoch: 0.9 ** epoch

# Stock PyTorch LambdaLR: lr = 0.1 * 0.9**epoch keeps decaying towards zero.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=fun, last_epoch=-1)
lrs = _record_lrs(optimizer, scheduler, epochs)
_plot_lrs(lrs)

# ============================================================================

# The notebook's custom LambdaLR: same decay, but clamped from below at 1e-2.
# Recorded the same way as above (a float per epoch) so both plots are comparable.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=fun, min_val=1e-2)
lrs = _record_lrs(optimizer, scheduler, epochs)
_plot_lrs(lrs)

# AutoEncoder

Klasa ,,*AE*'' definiuje sieć autoenkodera składającą się z dwóch części: kodującej (enkodera) i dekodującej (dekodera).

### Zadanie
Napisać klasę AE, której części enkoder i dekoder składają się z dwóch warstw w pełni połączonych (fully-connected). Funkcja ,,forward'' powinna zwracać parę: reprezentację ukrytą (kod) oraz zrekonstruowany obrazek o rozmiarze (1, 28, 28).

In [None]:
class AE(torch.nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    Encoder: 784 -> ``dim_hidden`` (ReLU) -> ``latent_dim``.
    Decoder: ``latent_dim`` -> ``dim_hidden`` (ReLU) -> 784, followed by a
    sigmoid and a reshape to ``(batch, 1, 28, 28)``.

    ``forward`` returns ``(z, reconstruction)``.
    """

    def __init__(self, latent_dim, dim_hidden):
        super().__init__()
        # Encoder layers.
        self.fc1 = torch.nn.Linear(784, dim_hidden)
        self.fc2 = torch.nn.Linear(dim_hidden, latent_dim)
        # Decoder layers.
        self.fc3 = torch.nn.Linear(latent_dim, dim_hidden)
        self.fc4 = torch.nn.Linear(dim_hidden, 784)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        flat = x.flatten(start_dim=1)
        z = self.fc2(self.relu(self.fc1(flat)))
        logits = self.fc4(self.relu(self.fc3(z)))
        recon = logits.sigmoid().reshape(-1, 1, 28, 28)
        return z, recon

[Jak działa warstwa konwolucyjna?](https://bfirst.tech/konwolucyjne-sieci-neuronowe/)

In [None]:
from typing import Tuple


class AE(torch.nn.Module):
    """Convolutional autoencoder designed for 1x28x28 images (e.g. MNIST).

    Encoder: four ``Conv2d`` layers (kernel 4, stride 2, padding 1) shrink the
    spatial size 28 -> 14 -> 7 -> 3 -> 1 while widening the channels to
    ``8 * dim_hidden``; a linear layer maps those ``8 * dim_hidden``
    flattened features to ``latent_dim``. The final ``Linear`` relies on the
    convolutions reducing the image to 1x1, so the input size must allow that.

    Decoder: a linear layer expands the code to ``8 * dim_hidden x 7 x 7``;
    transposed convolutions grow it 7 -> 10 -> 13 -> 28, and a sigmoid maps
    every pixel into (0, 1). The reconstruction is always 28x28.

    Args:
        latent_dim: Size of the latent code ``z``.
        dim_hidden: Base number of convolution channels.
    """

    def __init__(self, latent_dim: int, dim_hidden: int) -> None:
        super(AE, self).__init__()

        self.latent_dim = latent_dim
        self.dim_h = dim_hidden

        # Encoder
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, self.dim_h, 4, 2, 1, bias=False),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(self.dim_h, self.dim_h * 2, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(self.dim_h * 2),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(self.dim_h * 2, self.dim_h * 4, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(self.dim_h * 4),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(self.dim_h * 4, self.dim_h * 8, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(self.dim_h * 8),
            torch.nn.ReLU(True),
            torch.nn.Flatten(),
            # Spatial size is 1x1 here, so the flattened width is just the channel count.
            torch.nn.Linear(self.dim_h * (2**3), latent_dim),
        )

        # Decoder
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, self.dim_h * 8 * 7 * 7),
            torch.nn.ReLU(True),
            View(-1, self.dim_h * 8, 7, 7),
            torch.nn.ConvTranspose2d(self.dim_h * 8, self.dim_h * 4, 4),  # 7 -> 10
            torch.nn.BatchNorm2d(self.dim_h * 4),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(self.dim_h * 4, self.dim_h * 2, 4),  # 10 -> 13
            torch.nn.BatchNorm2d(self.dim_h * 2),
            torch.nn.ReLU(True),
            torch.nn.ConvTranspose2d(self.dim_h * 2, 1, 4, stride=2),  # 13 -> 28
            torch.nn.Sigmoid(),
        )

    def forward(self, input_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``input_x`` and decode the code again.

        Returns:
            ``(z, reconstruction)``: ``z`` has shape ``(batch, latent_dim)`` and
            ``reconstruction`` has shape ``(batch, 1, 28, 28)``.
        """
        z = self.encoder(input_x)
        return z, self.decoder(z)

Uczenie modelu i jego walidacja.

In [None]:
# Hyper-parameters of the convolutional autoencoder.
latent_dim = 8
dim_hidden = 16

model = AE(latent_dim=latent_dim, dim_hidden=dim_hidden)
model = model.to(device)
print(model)
print(f"Number of parameters: {count_parameters(model)}")

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Exponential lr decay (x0.8 per epoch), clamped from below at 1e-5.
use_scheduler = True
scheduler = None
if use_scheduler:
    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.8**epoch, min_val=1e-5)

# Pixel-wise reconstruction loss (mean over all elements of the batch).
mse_loss = torch.nn.MSELoss()

# Mean loss of every epoch, per split; plotted against the epoch index later.
scores = {"train": {"loss": []}, "test": {"loss": []}}

epochs = 25
for epoch in range(epochs):

    # training
    model.train()

    losses = AverageMeter()

    train_tqdm = tqdm(train_loader, total=len(train_loader), leave=False)
    for image, _ in train_tqdm:
        image = image.to(device)

        _, recon = model(image)
        loss = mse_loss(recon, image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size: the last batch can be smaller than the others,
        # and an unweighted mean of batch losses would over-weight it.
        losses.update(loss.item(), image.size(0))

        train_tqdm.set_description(f"TRAIN loss: {losses.val:.4g} ({losses.avg:.4g})")

    scores["train"]["loss"].append(losses.avg)
    if use_scheduler:
        scheduler.step()

    # validating
    model.eval()

    losses = AverageMeter()

    with torch.no_grad():
        eval_tqdm = tqdm(test_loader, total=len(test_loader), leave=False)
        for image, _ in eval_tqdm:
            image = image.to(device)

            _, recon = model(image)
            loss = mse_loss(recon, image)

            losses.update(loss.item(), image.size(0))

            eval_tqdm.set_description(f"TEST loss: {losses.val:.4g} ({losses.avg:.4g})")

    scores["test"]["loss"].append(losses.avg)

    print(
        f"Epoch: [{epoch + 1}/{epochs}]; "
        f"train: {scores['train']['loss'][-1]:.4g}; "
        f"test: {scores['test']['loss'][-1]:.4g}"
        f"{f'; lr: {scheduler.get_last_lr()[0]:.4g}' if use_scheduler else ''}"
    )

torch.save(
    model.state_dict(), "ae.pth"
)  # save the weights for later evaluation, e.g. computing the FID score

Poniżej przedstawiamy zmianę funkcji kosztu sieci AE w trakcie jej uczenia.

In [None]:
# Per-epoch loss curves: dashed line = test split, solid line = training split.
fig = plt.figure(figsize=(22, 8))
ax = fig.add_subplot(111)
ax.plot(
    scores["test"]["loss"],
    "r--",
    linewidth=4,
    markersize=12,
    label="Loss na zbiorze testowym",
)
ax.plot(
    scores["train"]["loss"],
    "r-",
    linewidth=4,
    markersize=12,
    label="Loss na zbiorze treningowym",
)
ax.tick_params(
    axis="both",
    which="both",
    direction="out",
    length=6,
    width=2,
    colors="k",
    grid_alpha=0.5,
)
ax.grid(which="both")
ax.grid(which="major", color="#CCCCCC", linestyle="--", alpha=0.8)
ax.grid(which="minor", color="#CCCCCC", linestyle=":", alpha=0.8)

ax.legend(loc=0)
ax.set_ylabel("Losses")
ax.set_xlabel("Epoka")
plt.tight_layout(pad=0.5)

plt.show()
plt.close()

Czas teraz pokazać, jak wyuczona sieć AE rekonstruuje obrazki ze zbioru MNIST. W tym celu przepuszczamy przez sieć obrazki ze zbioru testowego (górny wiersz obrazka), a sieć zwraca ich rekonstrukcje (dolny wiersz obrazka).

In [None]:
# validating
# Final pass over the test split; the last batch is also visualised
# (top row: originals, bottom row: reconstructions).
model.eval()

mses = AverageMeter()

with torch.no_grad():
    eval_tqdm = tqdm(enumerate(test_loader), total=len(test_loader), leave=False)
    for i, (image, _) in eval_tqdm:
        image = image.to(device)

        _, recon = model(image)
        loss = mse_loss(recon, image)

        # Weight each batch mean by its size so mses.avg is the mean over the
        # whole test set, not a mean of batch means.
        mses.update(loss.item(), image.size(0))

        eval_tqdm.set_description(f"mse: {mses.val:.4g} ({mses.avg:.4g})")

        if i == len(test_loader) - 1:
            show(image, recon, 10)

print(f"Ewaluation MSE: {mses.avg:.4g}")