# In [None]:
import os
import shutil
import itertools
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from efficientnet_pytorch import EfficientNet
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np

# Dataset locations (edit these paths directly for your environment)
train_dir = "/content/drive/MyDrive/data/태림산업 이미지셋/Processed_Data_TUBE/iteration_1/train"
test_dir = "/content/drive/MyDrive/data/태림산업 이미지셋/Processed_Data_TUBE/iteration_1/test"

# Per-class sub-folders inside the training directory
pass_dir = os.path.join(train_dir, 'pass')
balanced_fail_dir = os.path.join(train_dir, 'balanced_fail')

# Preprocessing pipeline applied to every image
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),  # resize every image to 128x128
        transforms.ToTensor(),  # convert the PIL image to a tensor
        # Per-channel normalization.
        # NOTE(review): these look like the usual ImageNet statistics — confirm.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class ImageDataset(Dataset):
    """Binary fail/pass image dataset backed by two directories.

    Every file in ``fail_dir`` whose name ends with one of ``extensions``
    (compared case-insensitively) is labelled 0, and every such file in
    ``pass_dir`` is labelled 1. Only the file paths are collected up front;
    the images themselves are opened lazily in ``__getitem__``.

    Args:
        fail_dir: Directory holding the "fail" images (label 0).
        pass_dir: Directory holding the "pass" images (label 1).
        transform: Optional callable applied to each RGB PIL image.
        extensions: File-name suffix (or iterable of suffixes) to accept.
            Defaults to BMP only, matching the previously hard-coded filter.

    Raises:
        OSError: Propagated from ``os.listdir`` when a directory cannot be
            read (e.g. it does not exist).
    """

    FAIL_LABEL = 0
    PASS_LABEL = 1

    def __init__(self, fail_dir, pass_dir, transform=None, extensions=('.bmp',)):
        self.fail_dir = fail_dir
        self.pass_dir = pass_dir
        self.transform = transform

        # Accept a single suffix string as well as an iterable of suffixes.
        if isinstance(extensions, str):
            extensions = (extensions,)
        self.extensions = tuple(ext.lower() for ext in extensions)

        # (path, label) pairs: fail samples first, then pass samples.
        self.image_paths = (
            self._collect(fail_dir, self.FAIL_LABEL)
            + self._collect(pass_dir, self.PASS_LABEL)
        )

    def _collect(self, directory, label):
        """Return sorted ``(path, label)`` pairs for matching files in *directory*."""
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        return [
            (os.path.join(directory, filename), label)
            for filename in sorted(os.listdir(directory))
            if filename.lower().endswith(self.extensions)
        ]

    def __len__(self):
        """Return the number of collected images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at *idx* as RGB and return ``(image, label)``."""
        img_path, label = self.image_paths[idx]

        # Release the file handle deterministically; convert() yields an
        # independent, fully loaded image that outlives the `with` block.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Training dataset and data loader
train_dataset = ImageDataset(balanced_fail_dir, pass_dir, transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Test dataset and data loader.
# Built from test_dir so that evaluation runs on held-out images rather than
# on the very folders the model was trained on.
# NOTE(review): assumes test_dir mirrors the train layout ('balanced_fail' and
# 'pass' sub-folders) — adjust the sub-folder names if the test split differs.
test_fail_dir = os.path.join(test_dir, 'balanced_fail')
test_pass_dir = os.path.join(test_dir, 'pass')
test_dataset = ImageDataset(test_fail_dir, test_pass_dir, transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load EfficientNet-B0 with pretrained weights
model = EfficientNet.from_pretrained('efficientnet-b0')
model._fc = nn.Linear(model._fc.in_features, 2)  # replace the final layer with a 2-class head

# Move the model to the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()  # cross-entropy over the two class logits
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Model training function
def train_model(model, train_loader, criterion, optimizer, num_epochs=5, device=None):
    """Train *model* for ``num_epochs`` passes over ``train_loader``.

    For each epoch the mean of the per-batch losses and the classification
    accuracy over all seen samples are printed and recorded.

    Args:
        model: Module mapping an image batch to per-class scores.
        train_loader: Iterable yielding ``(images, labels)`` tensor batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to. When ``None`` the device of
            the model's first parameter is used (CPU if it has none), so the
            function no longer depends on a module-level ``device`` variable.

    Returns:
        A list with one ``(epoch_loss, epoch_accuracy)`` tuple per epoch.

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    history = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_preds = 0
        total_preds = 0
        num_batches = 0  # counted manually so loaders without __len__ also work

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass, backward pass, parameter update
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Running statistics
            running_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)
            num_batches += 1

        if total_preds == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError.
            raise ValueError("train_loader yielded no samples; cannot compute epoch statistics")

        epoch_loss = running_loss / num_batches
        epoch_acc = correct_preds / total_preds
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")
        history.append((epoch_loss, epoch_acc))

    return history

# Model evaluation function
def evaluate_model(model, test_loader, device=None):
    """Evaluate *model* on ``test_loader``; print and plot the confusion matrix.

    Predictions are the arg-max over the model outputs. The confusion matrix
    is always computed for class ids 0 and 1 (labelled 'Fail' and 'Pass' on
    the plot), so it stays 2x2 even when one class is absent from the data.

    Args:
        model: Module mapping an image batch to per-class scores.
        test_loader: Iterable yielding ``(images, labels)`` tensor batches.
        device: Device the image batches are moved to. When ``None`` the
            device of the model's first parameter is used (CPU if it has
            none), so the function no longer depends on a module-level
            ``device`` variable.

    Returns:
        Tuple ``(cm, accuracy)``: the 2x2 confusion matrix and the fraction
        of correctly classified samples.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            # Only the images need to be on the model's device; labels are
            # just compared on the CPU afterwards.
            preds = model(images.to(device)).argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        raise ValueError("test_loader yielded no samples; nothing to evaluate")

    # Fix the class ids so the matrix shape always matches the tick labels.
    class_ids = [0, 1]
    class_names = ['Fail', 'Pass']
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    print("Confusion Matrix:")
    print(cm)

    # Confusion matrix heatmap
    plt.figure(figsize=(6, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    plt.show()

    # Overall accuracy
    accuracy = float(np.mean(np.array(all_preds) == np.array(all_labels)))
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

    return cm, accuracy

# Train the model
train_model(model, train_loader, criterion, optimizer, num_epochs=5)

# Evaluate the model (prints the confusion matrix and the accuracy)
evaluate_model(model, test_loader)

# Persist the trained weights to disk
trained_weights = model.state_dict()
torch.save(trained_weights, 'model.pth')
print("모델이 저장되었습니다.")
