In [2]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
from tensorflow.keras.utils import to_categorical

In [6]:
# Load data: whitespace-separated, no header row; last column is the class label.
# The path is configurable: set the SEEDS_DATA_PATH environment variable to run on
# another machine; the default is the original author's local path.
DATA_PATH = Path(os.environ.get("SEEDS_DATA_PATH", "C:\\Users\\LG\\Downloads\\seeds\\seeds_dataset.txt"))
NUM_CLASSES = 3

data = pd.read_csv(DATA_PATH, sep=r'\s+', header=None)
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values.astype(int) - 1  # shift labels to 0..NUM_CLASSES-1
# Guard the label shift: a negative label would be silently one-hot encoded into the
# last column by to_categorical, and a too-large one would fail with an opaque error.
assert y.min() >= 0 and y.max() < NUM_CLASSES, f"unexpected labels after shift: {np.unique(y)}"
y_encoded = to_categorical(y, num_classes=NUM_CLASSES)

In [8]:
# Shuffle, then make stratified splits: TEST_SIZE of all rows for test, and VAL_SIZE
# of the remaining training rows for validation.
# The shuffled arrays get new names instead of overwriting X / y_encoded, so re-running
# this cell does not re-shuffle already-shuffled data (which changed the split per run).
SPLIT_SEED = 42
TEST_SIZE = 0.2
VAL_SIZE = 0.1  # fraction of the training portion, not of the full dataset

X_shuffled, y_shuffled = shuffle(X, y_encoded, random_state=SPLIT_SEED)
y_raw = np.argmax(y_shuffled, axis=1)  # integer labels, used only for stratification

X_train, X_test, y_train, y_test = train_test_split(
    X_shuffled, y_shuffled, test_size=TEST_SIZE, stratify=y_raw, random_state=SPLIT_SEED
)
y_train_raw = np.argmax(y_train, axis=1)

X_train_final, X_val, y_train_final, y_val = train_test_split(
    X_train, y_train, test_size=VAL_SIZE, stratify=y_train_raw, random_state=SPLIT_SEED
)

In [10]:
# Standardization: statistics come from the final training split only, then the same
# transform is applied to the validation and test splits (no leakage from held-out data).
# NOTE: this rebinds the split arrays to their scaled versions. When re-running, re-run the
# split cell first; otherwise `scaler` is refit on already-standardized data.
scaler = StandardScaler().fit(X_train_final)
X_train_final, X_val, X_test = [scaler.transform(part) for part in (X_train_final, X_val, X_test)]

In [12]:
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Dropout

# Subclass of keras Model
class SeedsClassifier(Model):
    """Small feed-forward classifier: Dense(ReLU) -> Dropout -> Dense(ReLU) -> Dense(softmax).

    Args:
        num_classes: Number of softmax output units. Defaults to 3.
        hidden_units: Widths of the two hidden Dense layers, as a 2-element sequence.
            Defaults to (32, 16).
        dropout_rate: Rate of the Dropout layer after the first hidden layer. Defaults to 0.3.
        **kwargs: Forwarded to the keras ``Model`` constructor (e.g. ``name``).
    """

    def __init__(self, num_classes=3, hidden_units=(32, 16), dropout_rate=0.3, **kwargs):
        super().__init__(**kwargs)
        first_units, second_units = hidden_units
        self.dense1 = Dense(first_units, activation='relu')
        self.dropout = Dropout(dropout_rate)
        self.dense2 = Dense(second_units, activation='relu')
        # Output layer; the attribute name must stay `prediction` (per the original note).
        self.prediction = Dense(num_classes, activation='softmax')

    def call(self, inputs, training=False):
        """Forward pass; `training` toggles dropout."""
        x = self.dense1(inputs)
        # Dropout returns its input unchanged when training is False, so the flag is
        # passed straight through instead of branching on it in Python.
        x = self.dropout(x, training=training)
        x = self.dense2(x)
        return self.prediction(x)

In [14]:
# Compile and train with early stopping on validation loss.
# Seed Python/NumPy/TensorFlow RNGs before building the model. Weight initialization,
# dropout masks and per-epoch batch shuffling are otherwise random, so runs differ.
TRAIN_SEED = 42
EPOCHS = 100
BATCH_SIZE = 8
PATIENCE = 10  # epochs without val_loss improvement before stopping

tf.keras.utils.set_random_seed(TRAIN_SEED)

model = SeedsClassifier()

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# restore_best_weights=True: roll back to the weights of the best val_loss epoch when training stops.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True)

history = model.fit(
    X_train_final, y_train_final,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=[early_stop],
)

Epoch 1/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 55ms/step - accuracy: 0.4492 - loss: 1.0098 - val_accuracy: 0.8824 - val_loss: 0.8427
Epoch 2/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.7124 - loss: 0.8537 - val_accuracy: 0.8824 - val_loss: 0.6980
Epoch 3/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 22ms/step - accuracy: 0.8104 - loss: 0.7328 - val_accuracy: 0.8824 - val_loss: 0.5642
Epoch 4/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8865 - loss: 0.5777 - val_accuracy: 0.8824 - val_loss: 0.4555
Epoch 5/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.8694 - loss: 0.5343 - val_accuracy: 0.8824 - val_loss: 0.3812
Epoch 6/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.9010 - loss: 0.4489 - val_accuracy: 0.8824 - val_loss: 0.3402
Epoch 7/100
[1m19/19[0m [

In [16]:
# Test-set evaluation: overall accuracy, then a confusion matrix and per-class report.
from sklearn.metrics import confusion_matrix, classification_report

loss, accuracy = model.evaluate(X_test, y_test)
print(f"테스트 정확도: {accuracy:.4f}")

# Collapse one-hot targets and softmax outputs to integer class ids.
true_classes = y_test.argmax(axis=1)
test_probabilities = model.predict(X_test)
pred_classes = test_probabilities.argmax(axis=1)

print(confusion_matrix(true_classes, pred_classes))
print(classification_report(true_classes, pred_classes, digits=4))

[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - accuracy: 0.8948 - loss: 0.3541
테스트 정확도: 0.9048
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 165ms/step
[[12  1  1]
 [ 1 13  0]
 [ 1  0 13]]
              precision    recall  f1-score   support

           0     0.8571    0.8571    0.8571        14
           1     0.9286    0.9286    0.9286        14
           2     0.9286    0.9286    0.9286        14

    accuracy                         0.9048        42
   macro avg     0.9048    0.9048    0.9048        42
weighted avg     0.9048    0.9048    0.9048        42

