<a href="https://colab.research.google.com/github/no-akatsu/training/blob/main/241025_AutoEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import

In [1]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

# Model

In [8]:
# Minimal fully-connected autoencoder.
class SimpleAutoencoder(nn.Module):
    """Single-hidden-layer autoencoder.

    The encoder maps ``input_dim`` features to ``hidden_dim`` features with a
    linear layer followed by ReLU. The decoder maps them back to
    ``input_dim`` features with a linear layer followed by a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input (and of the
            reconstruction).
        hidden_dim: Size of the encoded (bottleneck) representation.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Bottleneck: Linear -> ReLU, so every encoded value is >= 0.
        self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        # Reconstruction: Linear -> Sigmoid, which squashes each output into
        # the 0-1 range.
        self.decoder = nn.Sequential(nn.Linear(hidden_dim, input_dim), nn.Sigmoid())

    def forward(self, x):
        """Return the reconstruction of ``x`` (last dimension ``input_dim``)."""
        return self.decoder(self.encoder(x))

# Data
- 入力データは one-hot エンコードされたカテゴリデータを想定 (input rows are assumed to be one-hot encoded categories)

In [9]:
# Toy dataset: four categories, one-hot encoded. Row i is the one-hot vector
# for category i (a single 1 in column i), stored as float32.
category_data = torch.eye(4, dtype=torch.float32)

# Training

In [11]:
# Model dimensions: one input feature per category column, squeezed through a
# 2-unit bottleneck.
input_dim = category_data.size(1)
hidden_dim = 2

model = SimpleAutoencoder(input_dim, hidden_dim)

# Optimizer and loss: Adam with lr=0.01, minimising the mean-squared error
# between each reconstruction and the row it came from.
optimizer = optim.Adam(params=model.parameters(), lr=0.01)
loss_fn = nn.MSELoss(reduction='mean')

# NOTE(review): no RNG seed is set in the visible cells, so the initial
# weights (and hence the logged losses and encoded vectors) vary from run to
# run. Consider torch.manual_seed(...) before building the model.

# Training loop: 1000 full-batch steps; the input doubles as the target.
for epoch in range(1000):
    optimizer.zero_grad()                  # clear gradients from the previous step
    output = model(category_data)          # reconstruct every row at once
    loss = loss_fn(output, category_data)
    loss.backward()
    optimizer.step()

    # Report on epochs 100, 200, ..., 1000. The value shown is the loss
    # measured before this epoch's parameter update.
    if epoch % 100 == 99:
        print(f'Epoch {epoch + 1}, Loss: {loss.item():.4f}')

# After training, look at the code the encoder assigns to each row.
# NOTE(review): the encoder ends in ReLU, so codes are >= 0 and a row can be
# encoded as all zeros (the recorded run shows one such row). Confirm that
# this is acceptable.
with torch.no_grad():  # inference only; no autograd graph needed
    encoded_data = model.encoder(category_data)
print(encoded_data)

Epoch 100, Loss: 0.0870
Epoch 200, Loss: 0.0343
Epoch 300, Loss: 0.0166
Epoch 400, Loss: 0.0097
Epoch 500, Loss: 0.0064
Epoch 600, Loss: 0.0045
Epoch 700, Loss: 0.0034
Epoch 800, Loss: 0.0026
Epoch 900, Loss: 0.0021
Epoch 1000, Loss: 0.0017
tensor([[0.0000, 0.0000],
        [0.0000, 2.4478],
        [2.5139, 0.0000],
        [4.7702, 5.0890]])
