In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# DDPM 스코어 모델 (간단한 UNet 기반)
class SimpleScoreModel(nn.Module):
    def __init__(self):
        super(SimpleScoreModel, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(28*28 + 1, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 28*28)
        )

    def forward(self, x, t):
        t = t.view(-1, 1)  # 타임스텝을 배치 차원에 맞게 변형
        x_t = torch.cat([x, t], dim=1)  # 입력 데이터와 타임스텝 결합
        return self.net(x_t)

# Score-based 모델 학습 (단순 예제)
def train_score_model():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
    dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

    model = SimpleScoreModel()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    num_epochs = 5
    for epoch in range(num_epochs):
        for x, _ in dataloader:
            t = torch.rand(x.shape[0])  # 랜덤한 타임스텝 샘플링
            noise = torch.randn_like(x)  # 가우시안 노이즈 추가
            x_t = x + noise  # 노이즈 추가된 데이터

            pred_noise = model(x_t, t)
            loss = criterion(pred_noise, noise)  # 노이즈 예측 손실 계산

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

    return model

# 로그 가능도 추정 함수
def compute_likelihood(model, x, num_timesteps=50):
    """ Score-based Likelihood Estimation """
    likelihoods = []
    for t in torch.linspace(0, 1, num_timesteps):  # 여러 타임스텝에 대해 로그 가능도 계산
        noise = torch.randn_like(x)
        x_t = x + noise  # 노이즈 추가된 데이터 생성
        score = model(x_t, t.expand(x.shape[0]))  # 스코어 추정
        likelihood = -((score - noise) ** 2).sum(dim=1)  # Stein’s Method 기반 로그 가능도
        likelihoods.append(likelihood)
    
    return torch.stack(likelihoods).mean(dim=0)  # 평균 로그 가능도 반환

# 이상 탐지 수행
def anomaly_detection(model, normal_data, anomaly_data):
    normal_likelihood = compute_likelihood(model, normal_data)
    anomaly_likelihood = compute_likelihood(model, anomaly_data)

    print(f"Normal Data Log-Likelihood Mean: {normal_likelihood.mean().item():.4f}")
    print(f"Anomalous Data Log-Likelihood Mean: {anomaly_likelihood.mean().item():.4f}")

    return normal_likelihood, anomaly_likelihood

# 실행 예제
model = train_score_model()

# 정상 데이터 (MNIST 0~4) 및 이상 데이터 (MNIST 5~9) 샘플링
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

normal_data = torch.stack([dataset[i][0] for i in range(len(dataset)) if dataset[i][1] < 5])
anomaly_data = torch.stack([dataset[i][0] for i in range(len(dataset)) if dataset[i][1] >= 5])

# 이상 탐지 실행
normal_likelihood, anomaly_likelihood = anomaly_detection(model, normal_data, anomaly_data)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|████████████████████████████| 9912422/9912422 [00:06<00:00, 1568004.20it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|█████████████████████████████████| 28881/28881 [00:00<00:00, 141025.34it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|████████████████████████████| 1648877/1648877 [00:01<00:00, 1111376.49it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████| 4542/4542 [00:00<00:00, 1942543.98it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch [1/5], Loss: 0.8572
Epoch [2/5], Loss: 0.8698
Epoch [3/5], Loss: 0.8636
Epoch [4/5], Loss: 0.8595
Epoch [5/5], Loss: 0.8399
Normal Data Log-Likelihood Mean: -674.7410
Anomalous Data Log-Likelihood Mean: -674.8512
