오토인코더 만들기.

1. Dataset
    - CIFAR-10
2. 이미지를 8x8 또는 4x4 블록으로 나누어 블록별로 DCT 변환한 뒤, 그 DCT 계수(coefficient)를 입력·출력으로 사용하는 AutoEncoder.
3. 모델의 출력(DCT 계수)을 IDCT하여 원래 이미지가 잘 복원되는지 확인

In [1]:
import os
# These must be set before torch initializes CUDA: enumerate GPUs in PCI bus
# order and expose only physical GPU 6 to this process (it then appears as
# device index 0 inside the process).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import torch
print(f"현재 사용 가능한 GPU 개수: {torch.cuda.device_count()}")
# current_device()/get_device_name() raise when no CUDA device is visible
# (e.g. GPU 6 does not exist on this machine); guard them so the notebook can
# still run on CPU, matching the CPU fallback used for `device` later on.
if torch.cuda.is_available():
    print(f"현재 선택된 GPU 번호: {torch.cuda.current_device()}")
    print(f"장치 이름: {torch.cuda.get_device_name(0)}")
else:
    print("사용 가능한 CUDA 장치가 없습니다 (CPU 사용)")

현재 사용 가능한 GPU 개수: 1
현재 선택된 GPU 번호: 0
장치 이름: NVIDIA GeForce RTX 3090


In [2]:
# Banner cell: echo the notebook topic to confirm the kernel is executing.
print("오토인코더", end="\n")

오토인코더


In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torchvision import datasets, transforms
# import numpy as np
# from scipy.fftpack import dct, idct

# # ---------------------------
# # DCT / IDCT utilities
# # ---------------------------

# def dct2(block):
#     return dct(dct(block.T, norm='ortho').T, norm='ortho')

# def idct2(block):
#     return idct(idct(block.T, norm='ortho').T, norm='ortho')

# def blockify(img, block_size=8):
#     blocks = []
#     C, H, W = img.shape
#     for c in range(C):
#         for i in range(0, H, block_size):
#             for j in range(0, W, block_size):
#                 blocks.append(img[c, i:i+block_size, j:j+block_size])
#     return blocks

# def deblockify(blocks, img_shape, block_size=8):
#     C, H, W = img_shape
#     img = torch.zeros(img_shape)
#     idx = 0
#     for c in range(C):
#         for i in range(0, H, block_size):
#             for j in range(0, W, block_size):
#                 img[c, i:i+block_size, j:j+block_size] = blocks[idx]
#                 idx += 1
#     return img

# # ---------------------------
# # AutoEncoder
# # ---------------------------

# class DCTAutoEncoder(nn.Module):
#     def __init__(self, dim=64):
#         super().__init__()
#         self.encoder = nn.Sequential(
#             nn.Linear(dim, 32),
#             nn.ReLU()
#         )
#         self.decoder = nn.Sequential(
#             nn.Linear(32, dim)
#         )

#     def forward(self, x):
#         z = self.encoder(x)
#         out = self.decoder(z)
#         return out

# # ---------------------------
# # Dataset
# # ---------------------------

# transform = transforms.Compose([
#     transforms.ToTensor()
# ])

# dataset = datasets.CIFAR10(
#     root="./data",
#     train=True,
#     download=True,
#     transform=transform
# )

# loader = torch.utils.data.DataLoader(
#     dataset,
#     batch_size=1,
#     shuffle=True
# )

# # ---------------------------
# # Model / Optim
# # ---------------------------

# model = DCTAutoEncoder()
# criterion = nn.MSELoss()
# optimizer = optim.Adam(model.parameters(), lr=1e-3)

# # ---------------------------
# # Training (very small demo)
# # ---------------------------

# model.train()

# for epoch in range(1):
#     for img, _ in loader:
#         img = img.squeeze(0)  # [3,32,32]

#         blocks = blockify(img)
#         dct_blocks = []

#         for b in blocks:
#             b_np = b.numpy()
#             dct_b = dct2(b_np)
#             dct_blocks.append(torch.tensor(dct_b).flatten())

#         dct_blocks = torch.stack(dct_blocks)

#         output = model(dct_blocks)
#         loss = criterion(output, dct_blocks)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         print("loss:", loss.item())
#         break
#     break

# # ---------------------------
# # Reconstruction check
# # ---------------------------

# model.eval()
# with torch.no_grad():
#     img, _ = next(iter(loader))
#     img = img.squeeze(0)

#     blocks = blockify(img)
#     recon_blocks = []

#     for b in blocks:
#         dct_b = dct2(b.numpy())
#         dct_flat = torch.tensor(dct_b).flatten()
#         out = model(dct_flat)
#         out_block = out.view(8, 8).numpy()
#         recon = idct2(out_block)
#         recon_blocks.append(torch.tensor(recon))

#     recon_img = deblockify(recon_blocks, img.shape)

# print("Original min/max:", img.min().item(), img.max().item())
# print("Reconstructed min/max:", recon_img.min().item(), recon_img.max().item())


Files already downloaded and verified
loss: 0.4836006164550781
Original min/max: 0.0 1.0
Reconstructed min/max: -0.631921648979187 0.4574740529060364


In [None]:
# dct_ae_cifar10.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import numpy as np
from scipy.fftpack import dct, idct
import matplotlib.pyplot as plt

# -----------------------------
# 1. DCT / IDCT utilities
# -----------------------------

def dct2(block):
    """Return the orthonormal 2-D DCT-II of ``block``.

    The transform is separable: a 1-D DCT runs along axis 0 first, and that
    intermediate result is then transformed along axis 1.
    """
    cols_done = dct(block, axis=0, norm='ortho')
    return dct(cols_done, axis=1, norm='ortho')

def idct2(block):
    """Return the orthonormal 2-D inverse DCT of ``block`` (undoes ``dct2``).

    Like the forward transform it is applied one axis at a time: first
    axis 0, then axis 1.
    """
    partial = idct(block, axis=0, norm='ortho')
    return idct(partial, axis=1, norm='ortho')

def image_to_dct_blocks(img, block_size=8):
    """Split a 2-D image into non-overlapping square tiles and DCT each tile.

    Args:
        img: (H, W) array. H and W must be positive multiples of block_size.
        block_size: tile side length (e.g. 8 or 4).

    Returns:
        (num_blocks, block_size * block_size) array. Each row is one tile's
        orthonormal 2-D DCT-II coefficients, flattened row-major. Tiles are
        ordered left-to-right within a tile row, tile rows top-to-bottom.

    Raises:
        ValueError: if img is not 2-D, is empty, or its sides are not
            multiples of block_size.
    """
    img = np.asarray(img)
    if img.ndim != 2:
        raise ValueError(f"expected a 2-D (H, W) image, got shape {img.shape}")
    H, W = img.shape
    if H == 0 or W == 0 or H % block_size or W % block_size:
        raise ValueError(
            f"image shape {img.shape} must be non-empty and a multiple of "
            f"block_size={block_size} in both dimensions"
        )
    nh, nw = H // block_size, W // block_size
    # (nh, bs, nw, bs) -> (nh, nw, bs, bs): tiles[r, c] is the tile whose
    # top-left pixel is img[r * bs, c * bs].
    tiles = img.reshape(nh, block_size, nw, block_size).swapaxes(1, 2)
    # Separable orthonormal DCT over the two in-tile axes (rows, then
    # columns), so every tile is transformed in one call instead of a Python
    # loop over tiles.
    coeffs = dct(dct(tiles, axis=-2, norm='ortho'), axis=-1, norm='ortho')
    return coeffs.reshape(nh * nw, block_size * block_size)

def dct_blocks_to_image(blocks, H, W, block_size=8):
    """Inverse-DCT each flattened tile and reassemble the (H, W) image.

    Args:
        blocks: (num_blocks, block_size * block_size) array of flattened
            orthonormal DCT coefficients, tiles in row-major order
            (left-to-right, then top-to-bottom). num_blocks must equal
            (H // block_size) * (W // block_size).
        H, W: output height and width; non-negative multiples of block_size.
        block_size: tile side length.

    Returns:
        (H, W) float64 reconstructed image.

    Raises:
        ValueError: if H/W are negative or not multiples of block_size, or
            if blocks does not have the expected shape.
    """
    if H < 0 or W < 0 or H % block_size or W % block_size:
        raise ValueError(
            f"H={H}, W={W} must be non-negative multiples of block_size={block_size}"
        )
    nh, nw = H // block_size, W // block_size
    coeffs = np.asarray(blocks, dtype=np.float64)
    expected = (nh * nw, block_size * block_size)
    # Previously a short `blocks` raised IndexError and extra blocks were
    # silently dropped; fail loudly on any mismatch instead.
    if coeffs.shape != expected:
        raise ValueError(f"blocks has shape {coeffs.shape}, expected {expected}")
    tiles = coeffs.reshape(nh, nw, block_size, block_size)
    # Orthonormal 2-D inverse DCT over the in-tile axes, all tiles at once.
    pixels = idct(idct(tiles, axis=-2, norm='ortho'), axis=-1, norm='ortho')
    # (nh, nw, bs, bs) -> (nh, bs, nw, bs) -> (H, W)
    return pixels.swapaxes(1, 2).reshape(H, W)

# -----------------------------
# 2. AutoEncoder definition
# -----------------------------

class DCTAutoEncoder(nn.Module):
    """MLP autoencoder over flattened per-block DCT coefficient vectors.

    The encoder maps dim -> hidden -> latent and the decoder mirrors it
    (latent -> hidden -> dim), with ReLU after each hidden layer and linear
    outputs. Defaults reproduce the original 64 -> 32 -> 16 -> 32 -> 64 net
    for 8x8 blocks.

    Args:
        dim: length of each coefficient vector (block_size ** 2, e.g. 64 for
            8x8 blocks or 16 for 4x4 blocks).
        hidden: width of the hidden layer on each side of the bottleneck.
        latent: size of the bottleneck code. Pick latent < dim for actual
            compression (e.g. for 4x4 blocks the default 16 would not compress).
    """

    def __init__(self, dim=64, hidden=32, latent=16):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, latent),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        """Encode then decode ``x``; the last dimension of ``x`` must be ``dim``."""
        z = self.encoder(x)
        return self.decoder(z)

# -----------------------------
# 3. Dataset (CIFAR-10)
# -----------------------------

# Convert to grayscale (for simplicity) so each sample is a single-channel
# tensor, then to a float tensor.
transform = transforms.Compose(
    [transforms.Grayscale(), transforms.ToTensor()]
)

# CIFAR-10 training split, downloaded into ./data if not already present.
dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, transform=transform, download=True
)

# One image per step; the training loop squeezes the batch dimension away.
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

# -----------------------------
# 4. Training setup
# -----------------------------

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = DCTAutoEncoder()
model = model.to(device)
criterion = nn.MSELoss()  # reconstruction loss, measured on DCT coefficients
optimizer = optim.Adam(model.parameters(), lr=0.001)

# -----------------------------
# 5. Training loop
# -----------------------------

epochs = 5
model.train()

for epoch in range(epochs):
    total_loss = 0.0
    num_steps = 0
    for img, _ in loader:
        # batch_size=1 + grayscale: (1, 1, H, W) -> (H, W), e.g. (32, 32)
        img = img.squeeze().numpy()

        dct_blocks = image_to_dct_blocks(img)  # (16, 64) for 32x32 with 8x8 tiles
        dct_blocks = torch.as_tensor(dct_blocks, dtype=torch.float32, device=device)

        # Input and target are the same coefficients: the autoencoder learns
        # to reproduce them through its bottleneck.
        optimizer.zero_grad()
        recon = model(dct_blocks)
        loss = criterion(recon, dct_blocks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_steps += 1

    # Report the mean per-step loss. The raw sum over ~50k steps grows with
    # dataset size and is not comparable with per-batch MSE values.
    mean_loss = total_loss / max(num_steps, 1)
    print(f"[Epoch {epoch+1}] loss: {mean_loss:.6f}")

# -----------------------------
# 6. Visualization (fixed image)
# -----------------------------

model.eval()
with torch.no_grad():
    img, _ = dataset[0]
    img = img.squeeze().numpy()  # (1, H, W) -> (H, W)
    H, W = img.shape  # use the real size instead of hard-coding 32x32

    dct_blocks = image_to_dct_blocks(img)
    dct_blocks_t = torch.tensor(dct_blocks, dtype=torch.float32).to(device)

    # Model output is DCT coefficients; IDCT them back to pixels.
    recon_blocks = model(dct_blocks_t).cpu().numpy()
    recon_img = dct_blocks_to_image(recon_blocks, H, W)

print("Original min/max:", img.min(), img.max())
print("Reconstructed min/max:", recon_img.min(), recon_img.max())
# Pixel-domain error: the direct check of whether IDCT(model output)
# reproduces the original image.
print("Pixel MSE:", float(np.mean((recon_img - img) ** 2)))

plt.figure(figsize=(6, 3))
# Pin both panels to the same [0, 1] range (ToTensor output range). Without
# vmin/vmax, imshow rescales each image to its own min/max, which hides
# brightness/offset errors in the reconstruction.
plt.subplot(1, 2, 1)
plt.title("Original")
plt.imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title("Reconstructed")
plt.imshow(recon_img, cmap='gray', vmin=0.0, vmax=1.0)
plt.axis('off')

plt.tight_layout()
plt.show()
