In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Seed PyTorch's global RNG so that DataLoader shuffling, weight initialisation
# and any sampled noise are repeatable across "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Convert PIL images to tensors, then shift/scale the single channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

to_pil_image = transforms.ToPILImage()

# MNIST training split; downloaded into `input/data` on first use.
train_data = datasets.MNIST(
    root='input/data',
    train=True,
    download=True,
    transform=transform
)
print(len(train_data))

# Fixed reference sample used to visualise progress. It lives on `device`,
# so it must be copied back to the CPU before being handed to matplotlib.
testimg = train_data[0][0].to(device)
batch_size = 2048
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Autoencoder
class Autoencoder(nn.Module):
    """Single-layer linear autoencoder.

    One affine map projects the input to a hidden code and a second affine
    map projects that code back to the input dimensionality. No activation
    function is applied in between.

    Args:
        input_size: number of features of the input (and of the output).
        hidden_size: number of features of the hidden code.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Attribute names are read from outside the class (e.g. `model.decoder`).
        self.encoder = nn.Linear(in_features=input_size, out_features=hidden_size, bias=True)
        self.decoder = nn.Linear(in_features=hidden_size, out_features=input_size, bias=True)

    def forward(self, x):
        """Return the pair ``(code, reconstruction)`` for the input ``x``."""
        code = self.encoder(x)
        return code, self.decoder(code)

input_size = 784  # input size (flattened 28x28 image)
hidden_size = 128  # size of the hidden layer
model = Autoencoder(input_size, hidden_size).to(device)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.01)


def show_reconstruction(model, image):
    """Plot an image, its reconstruction and the sorted absolute code values.

    Layout: input image (top left), reconstruction (top right), absolute
    values of the hidden code sorted in ascending order (bottom).

    Args:
        model: module whose forward pass returns ``(encoded, decoded)`` for
            a flattened image.
        image: tensor whose first entry ``image[0]`` is shown as the input
            and whose reconstruction is reshaped to 28x28; it may live on
            any device.

    Returns:
        The ``(encoded, decoded)`` pair, computed without gradient tracking.
    """
    model_device = next(model.parameters()).device
    with torch.no_grad():  # visualisation only, no graph needed
        code, recon = model(image.flatten().to(model_device))

    plt.figure(figsize=[8, 8])
    plt.subplot(2, 2, 1)
    # matplotlib cannot convert CUDA tensors, so copy to host memory first.
    plt.imshow(image[0].detach().cpu().numpy())
    plt.subplot(2, 2, 2)
    plt.imshow(recon.reshape(28, 28).cpu().numpy())
    plt.subplot(2, 1, 2)
    plt.plot(torch.sort(torch.abs(code)).values.cpu().numpy())
    plt.show()
    return code, recon


# Reconstruction of the reference image before any training.
enc, test_out = show_reconstruction(model, testimg)

# Training loop
num_epochs = 10
plot_every = 1   # show the reference reconstruction every `plot_every` epochs
noise_std = 0.0  # std of the Gaussian noise added to the inputs; 0.0 = no corruption
criterion = nn.MSELoss()

# ------------------------------------------------------
# -------------- Choose the sparsity measure -----------
sparsity_weight = 0.1  # weight of the sparsity penalty in the total loss
sparsity_measures = {
    'mae': lambda x: torch.mean(torch.abs(x)),
    'mse': lambda x: torch.mean(torch.abs(x)**2),
    # log_softmax avoids log(0) when a softmax probability underflows
    'entropy': lambda x: -torch.mean(torch.softmax(2*x, dim=-1)*torch.log_softmax(2*x, dim=-1)),
    'none': lambda x: 0.0,  # no sparsity penalty
}
sparsity_measure = 'mae'  # one of the keys of `sparsity_measures`
# ------------------------------------------------------


def sparsity(x):
    """Return the weighted sparsity penalty of the hidden code ``x``."""
    return sparsity_weight * sparsity_measures[sparsity_measure](x)


for epoch in range(num_epochs):
    for img, _labels in tqdm(train_loader):
        img = img.reshape(-1, input_size).to(device)
        # Sample the noise directly on the device/dtype of the batch.
        img_corrupted = img + noise_std * torch.randn_like(img)
        encoded, reconstructed = model(img_corrupted)

        # Reconstruct the clean image from the (possibly corrupted) input.
        loss = criterion(reconstructed, img) + sparsity(encoded)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if epoch % plot_every == 0:
        encoded, reconstructed = show_reconstruction(model, testimg)
    # Note: this is the loss of the last mini-batch of the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# Reconstruction of the reference image after training.
encoded, reconstructed = show_reconstruction(model, testimg)


In [None]:
# For the first ten training images show, side by side:
# (possibly noisy) input | reconstruction | clean original.
eval_noise_std = 0.0  # std of the Gaussian noise added to the inputs; 0.0 = no corruption
model_device = next(model.parameters()).device
for i in range(10):
    clean_img = train_data[i][0]  # fetch (and transform) the sample only once
    testimg_corrupted = clean_img + eval_noise_std*torch.randn_like(clean_img)
    with torch.no_grad():
        # The model may live on the GPU: move the input to the model's device
        # and bring the reconstruction back to the CPU for plotting.
        recon = model(testimg_corrupted.flatten().to(model_device))[1]
    plt.figure()
    plt.subplot(1,3,1)
    plt.imshow(testimg_corrupted[0].cpu().numpy())
    plt.subplot(1,3,2)
    plt.imshow(recon.reshape(28,28).cpu().numpy())
    plt.subplot(1,3,3)
    plt.imshow(clean_img[0].cpu().numpy())
    plt.show()

In [None]:
from torchvision.utils import make_grid

# Show the decoder weights as images: after the transpose each row holds the
# weights that connect one hidden unit to all output pixels.
# NOTE: despite its name, `enc` holds *decoder* weights here.
# clone() is essential: the normalisation below must work on a copy, otherwise
# it would overwrite the trained decoder weights of `model`.
enc = model.decoder.weight.detach().t().clone().cpu()
print(enc.shape)
# Shift every row so its minimum is 0, then scale by the global maximum so
# all values lie in [0, 1] for display.
enc = enc - torch.min(enc, dim=-1, keepdim=True).values
enc = enc / torch.max(enc)
enc = enc.reshape(-1, 1, 28, 28)
image_enc = make_grid(enc)
print(image_enc.shape)
plt.figure(figsize=[10,20])
# make_grid returns channels first; matplotlib expects channels last.
plt.imshow(image_enc.permute(1, 2, 0).numpy())
plt.show()

# Decoder bias shown as a 28x28 image.
plt.figure()
plt.imshow(model.decoder.bias.detach().cpu().reshape(28,28).numpy())
plt.show()