In [None]:
# =============================================================================
# 1. MIG environment setup -- must run BEFORE TensorFlow is imported
# =============================================================================
import os
import sys

# CUDA reads CUDA_VISIBLE_DEVICES when it initializes, so the changes below
# have no effect once TensorFlow is already loaded in this process (e.g. when
# this cell is re-run in the same kernel). Say so instead of failing silently.
if "tensorflow" in sys.modules:
    print("Warning: TensorFlow is already imported; restart the kernel so "
          "the CUDA environment changes below take effect.")

# Drop every inherited CUDA_/NVIDIA_/TF_ setting. The key list is snapshotted
# first because the environment mapping is mutated inside the loop.
for key in list(os.environ):
    if key.startswith(("CUDA_", "NVIDIA_", "TF_")):
        os.environ.pop(key, None)

# Expose only device index 0 (presumably the single MIG instance assigned to
# this container -- confirm for multi-instance setups).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# =============================================================================
# 2. TensorFlow 라이브러리
# =============================================================================
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.callbacks import TensorBoard
import time
import matplotlib.pyplot as plt

# =============================================================================
# 3. GPU check
# =============================================================================
# Choose the device used for training. A GPU that TensorFlow can see is used
# even if memory growth cannot be configured: set_memory_growth raises
# RuntimeError when the GPU runtime is already initialized (e.g. this cell is
# re-run), but the device itself is still usable, so that must not force a
# CPU fallback.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            # Allocate GPU memory on demand instead of reserving it all up front.
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print("GPU 설정 오류:", e)
    print(f"성공! GPU 잡힘 → {gpus}")
    # Defensive .get: a missing 'device_name' key should only affect this
    # informational print, not the device choice.
    gpu_details = tf.config.experimental.get_device_details(gpus[0])
    print(f"GPU 이름: {gpu_details.get('device_name', 'unknown')}")
    device_name = '/GPU:0'
else:
    print("GPU를 못 찾았습니다. 환경변수를 다시 확인하세요.")
    device_name = '/CPU:0'

# =============================================================================
# 4. Load CIFAR-10 and scale pixel values from 0-255 to 0-1
# =============================================================================
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Cast both image splits to float32 before dividing so the result stays float32.
X_train, X_test = (split.astype('float32') / 255.0 for split in (X_train, X_test))

# =============================================================================
# 5. Model definition
# =============================================================================
# Three conv/pool stages followed by a small dense head. Every layer uses
# padding='same', so each 2x2 pool halves the spatial size:
# 32 -> 16 -> 8 -> 4, and Flatten yields 4 * 4 * 128 = 2048 features.
# The input shape is declared with an Input object rather than the
# `input_shape=` layer argument, which Keras 3 deprecates.
model = Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2), padding='same'),
    Conv2D(64, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2), padding='same'),
    Conv2D(128, (3, 3), activation='relu', padding='same'),
    MaxPooling2D((2, 2), padding='same'),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax'),  # one probability per CIFAR-10 class
])

model.summary()

# =============================================================================
# 6. Compile the model (Keras 3 compatible: string identifiers only)
# =============================================================================
# Sparse categorical cross-entropy takes integer class ids directly, which is
# what cifar10.load_data() provides, so no one-hot encoding is needed.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# =============================================================================
# 7. TensorBoard setup
# =============================================================================
# One log sub-directory per run, named after the local start timestamp;
# histogram_freq=1 records weight histograms every epoch.
run_stamp = time.strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/fit/{run_stamp}"
tensorboard_callback = TensorBoard(histogram_freq=1, log_dir=log_dir)

# =============================================================================
# 8. Training (batch_size 128-256 suggested for a MIG 1g.10gb slice)
# =============================================================================
print(f"\nGPU 학습 시작! ({device_name}) - A100 MIG 1g.10gb")
# perf_counter is monotonic and high-resolution, so the elapsed time is not
# skewed by system clock adjustments the way time.time() can be.
start_time = time.perf_counter()

with tf.device(device_name):
    history = model.fit(
        X_train, y_train,
        epochs=10,
        # Reported as fastest on 1g.10gb (~3-4x faster than 32); not measured here.
        batch_size=256,
        # NOTE: the test split doubles as validation data, so the final
        # val_* metrics are computed on the same data as section 9.
        validation_data=(X_test, y_test),
        callbacks=[tensorboard_callback],
        verbose=1
    )

end_time = time.perf_counter()
print(f"학습 완료! 소요시간: {end_time - start_time:.2f} 초")

# =============================================================================
# 9. Evaluation and result output
# =============================================================================
print("Evaluating model...")
# evaluate() yields [loss, accuracy] because the model was compiled with
# one metric ('accuracy').
evaluation = model.evaluate(X_test, y_test, verbose=2)
test_loss, test_acc = evaluation
accuracy_pct = test_acc * 100
print(f"Test accuracy: {accuracy_pct:.2f}%")

# =============================================================================
# 10. Learning-curve plots
# =============================================================================
# Two side-by-side panels (loss, accuracy), each with a train and val curve.
plt.figure(figsize=(10, 4))
curve_specs = (
    ('loss', 'Loss', 'loss'),
    ('accuracy', 'Accuracy', 'acc'),
)
for panel_idx, (metric_key, panel_title, short_name) in enumerate(curve_specs, start=1):
    plt.subplot(1, 2, panel_idx)
    plt.plot(history.history[metric_key], label=f'train_{short_name}')
    plt.plot(history.history[f'val_{metric_key}'], label=f'val_{short_name}')
    plt.title(panel_title)
    plt.legend()
    plt.grid()
plt.tight_layout()
plt.show()

# =============================================================================
# 11. How to launch TensorBoard for this run
# =============================================================================
# --bind_all makes TensorBoard listen on all network interfaces, so any host
# that can reach this machine can view the logs.
print(
    "\nTensorBoard 실행 명령어:",
    "tensorboard --logdir " + log_dir + " --bind_all",
    sep="\n",
)