In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import mnist_reader
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis



# === 1. Data preparation ===

# Read both splits through the project-local mnist_reader helper
# ('train' first, then 't10k').
x_train, y_train = mnist_reader.load_data('../dim_reduction/data/oracle', kind='train')
x_test, y_test = mnist_reader.load_data('../dim_reduction/data/oracle', kind='t10k')

# Images: reshape to (N, 1, 28, 28), cast to float32 and divide by 255.
x_train, x_test = (
    pixels.reshape(-1, 1, 28, 28).astype(np.float32) / 255.0
    for pixels in (x_train, x_test)
)

# Labels: cast to int64.
y_train, y_test = (classes.astype(np.int64) for classes in (y_train, y_test))

# Wrap each split in a TensorDataset and a DataLoader; only the training
# loader shuffles.
train_dataset = TensorDataset(torch.tensor(x_train), torch.tensor(y_train))
test_dataset = TensorDataset(torch.tensor(x_test), torch.tensor(y_test))

train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(dataset=test_dataset, batch_size=64)

# === 2. CNN 模型定義 ===

class SimpleCNN(nn.Module):
    """Two-stage convolutional classifier with an optional feature output.

    Each stage is a 3x3 convolution followed by ReLU and 2x2 max-pooling
    (1 -> 32 -> 64 channels). The flattened result goes through a
    128-unit hidden layer (``fc1``) and a final linear layer (``fc2``)
    producing ``num_classes`` scores.

    ``fc1`` is built for ``64 * 5 * 5`` inputs, which is the flattened size
    obtained from 1x28x28 images after the two conv/pool stages.

    Args:
        num_classes: Number of output scores produced by ``fc2``.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stage 1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        # Stage 2
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # Classifier head
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=64 * 5 * 5, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x, return_features=False):
        """Run the network on a batch.

        Args:
            x: Input batch fed to ``conv1``.
            return_features: When true, return the ReLU-activated output of
                ``fc1`` instead of the ``fc2`` scores.

        Returns:
            The ``fc1`` activations if ``return_features`` is true,
            otherwise the output of ``fc2``.
        """
        stage1 = self.pool1(F.relu(self.conv1(x)))
        stage2 = self.pool2(F.relu(self.conv2(stage1)))
        hidden = F.relu(self.fc1(self.flatten(stage2)))
        return hidden if return_features else self.fc2(hidden)

# 1. Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=10).to(device)
# torch.load unpickles the file. weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state dict needs, so a
# tampered checkpoint cannot execute arbitrary code. map_location keeps the
# load working on machines without CUDA; load_state_dict then copies the
# values into the parameters already placed on `device`.
state_dict = torch.load(
    "best_cnn_model.pt",
    map_location=torch.device("cpu"),
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()

# 2. 特徵提取函數
def extract_features(model, dataloader, device=None):
    """Collect the model's feature outputs and the labels for every batch.

    The model is switched to eval mode and called as
    ``model(images, return_features=True)`` under ``torch.no_grad()``.

    Args:
        model: Module whose ``forward`` accepts ``return_features=True``.
        dataloader: Iterable yielding ``(images, targets)`` tensor pairs.
        device: Device the images are moved to before the forward pass.
            When omitted, the device of the model's first parameter is used
            (CPU if the model has no parameters), so the function does not
            rely on a notebook-global ``device`` variable.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays: per-batch feature
        arrays stacked with ``np.vstack`` and per-batch targets joined with
        ``np.hstack``, both in iteration order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    features = []
    labels = []
    with torch.no_grad():
        for images, targets in dataloader:
            feats = model(images.to(device), return_features=True)
            features.append(feats.cpu().numpy())
            # .cpu() is a no-op for CPU tensors and keeps this working if a
            # loader ever yields targets that already live on an accelerator.
            labels.append(targets.cpu().numpy())

    if not features:
        # np.vstack([]) would raise a much less helpful ValueError.
        raise ValueError("extract_features: dataloader yielded no batches")
    return np.vstack(features), np.hstack(labels)

# 3. Run feature extraction on both splits. The labels come back from the
# same loader pass as the features, so each pair stays aligned even though
# train_loader shuffles.
X_train_feats, y_train = extract_features(model, train_loader)
X_test_feats, y_test = extract_features(model, test_loader)

# LDA can project to at most n_classes - 1 dimensions (9 here).
lda_components = 9

# Fit LDA on the training features, then project both splits.
lda = LinearDiscriminantAnalysis(n_components=lda_components)
lda.fit(X_train_feats, y_train)
x_train_lda = lda.transform(X_train_feats)
x_test_lda = lda.transform(X_test_feats)

The size of train set: 27222
The size of t10k set: 3000


  model.load_state_dict(torch.load("best_cnn_model.pt", map_location=torch.device("cpu")))


In [3]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

# LDA + K-means clustering

# Number of clusters (Oracle MNIST has 10 classes).
n_clusters = 10

# Fit K-means on the LDA-projected test features and read the assignment
# of each sample from the fitted estimator.
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
kmeans.fit(x_test_lda)
cluster_labels = kmeans.labels_

# Score the clustering against the true labels.
ari = adjusted_rand_score(labels_true=y_test, labels_pred=cluster_labels)
print(f"Adjusted Rand Index (ARI): {ari:.4f}")

Adjusted Rand Index (ARI): 0.7246




In [4]:
from sklearn.mixture import GaussianMixture
from sklearn.metrics import adjusted_rand_score

# LDA + EM algorithm (Gaussian mixture model)

# Number of mixture components (still 10).
n_components = 10

# Build the GMM and fit it on the test data only: this is unsupervised
# clustering, so no labels are passed.
gmm = GaussianMixture(
    n_components=n_components,
    covariance_type='full',
    random_state=42,
).fit(x_test_lda)

# Hard assignment: index of the most probable component for each sample.
cluster_labels = gmm.predict(x_test_lda)

# Adjusted Rand Index (ARI) against the true labels.
ari = adjusted_rand_score(labels_true=y_test, labels_pred=cluster_labels)
print(f"EM Clustering ARI: {ari:.4f}")



EM Clustering ARI: 0.7612
