# In [None]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt
import torch.optim as optim
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import ConcatDataset
import torch.nn.functional as F

# Focal Loss Definition
class FocalLoss(nn.Module):
    """Multi-class focal loss built on softmax cross-entropy.

    Each sample's cross-entropy is scaled by ``(1 - p_t) ** gamma``, where
    ``p_t`` is the softmax probability assigned to the true class. Confident,
    correct predictions are therefore down-weighted and training concentrates
    on hard examples. With ``gamma=0`` this is plain cross-entropy.

    Args:
        gamma: Focusing exponent applied to ``(1 - p_t)``. Default 2.
        reduction: ``'mean'`` (default), ``'sum'``, or ``'none'``. The last
            returns the unreduced per-sample losses, mirroring the
            ``reduction`` argument of ``torch.nn.functional.cross_entropy``.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, gamma=2, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: Raw (unnormalised) logits in any layout accepted by
                ``F.cross_entropy``; in this file the model emits (N, 2).
            targets: Integer class indices matching ``inputs``.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, otherwise the
            per-sample loss tensor.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction="none")
        # exp(-CE) recovers p_t, the softmax probability of the true class.
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss

# Custom dataset class for augmented images
class AugmentedDataset(Dataset):
    """Single-class image dataset backed by the files in one directory.

    Every regular, non-hidden file directly inside ``root_dir`` is treated as
    an image; subdirectories are ignored. All samples share one class label.

    Args:
        root_dir: Directory holding the images of a single class.
        transform: Optional callable applied to each RGB PIL image
            (e.g. a torchvision ``Compose`` pipeline).
        target_length: If given and larger than the number of files found,
            the (sorted) file list is repeated cyclically and truncated so the
            dataset has exactly ``target_length`` items. This oversamples a
            minority class; the dataset is never shrunk.
        label: Class index returned for every sample. If ``None``, it is 1
            when the final path component of ``root_dir`` contains ``'pass'``
            and 0 otherwise.

    Raises:
        ValueError: if ``target_length`` asks for oversampling but
            ``root_dir`` contains no files to repeat.
    """

    def __init__(self, root_dir, transform=None, target_length=None, label=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so sample order (and which files get the extra copy when
        # oversampling) is reproducible; os.listdir order is arbitrary.
        # Dotfiles such as .DS_Store are skipped because they are not images.
        self.image_paths = [
            os.path.join(root_dir, fname)
            for fname in sorted(os.listdir(root_dir))
            if not fname.startswith('.') and os.path.isfile(os.path.join(root_dir, fname))
        ]

        if label is None:
            # Inspect only the directory's own name, so 'pass' appearing in a
            # parent folder or a file name cannot flip the label.
            dir_name = os.path.basename(os.path.normpath(root_dir))
            label = 1 if 'pass' in dir_name else 0
        self.label = label

        if target_length and len(self.image_paths) < target_length:
            if not self.image_paths:
                raise ValueError(
                    f"cannot oversample to {target_length} items: no files in {root_dir!r}"
                )
            repeats = target_length // len(self.image_paths) + 1
            self.image_paths = (self.image_paths * repeats)[:target_length]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Plain indexing: an out-of-range idx raises IndexError, which is also
        # what terminates direct iteration over the dataset.
        img_path = self.image_paths[idx]
        # The context manager closes the file; convert() loads the pixels into
        # a new image that stays valid afterwards.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.label

# Dataset locations (adjust as needed). All four class folders sit under one
# shared root; note that the held-out split is 'Test' (capital T) while the
# training split is 'train'.
_dataset_root = '/content/drive/MyDrive/data/태림산업 이미지셋/Processed_Data_SHAFT/iteration_1'
train_fail_dir = f'{_dataset_root}/train/fail'
train_pass_dir = f'{_dataset_root}/train/pass'
test_fail_dir = f'{_dataset_root}/Test/fail'
test_pass_dir = f'{_dataset_root}/Test/pass'

# Preprocessing. The classifier is an ImageNet-pretrained ResNet, whose weights
# expect inputs normalised with the ImageNet per-channel mean/std. Train and
# test pipelines therefore share the same Resize -> ToTensor -> Normalize steps.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: random flip / small rotation / brightness-contrast jitter for
# augmentation, followed by the shared deterministic preprocessing.
augment_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation: deterministic preprocessing only (no augmentation).
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Training data: oversample whichever class is smaller so that both classes
# contribute the same number of samples per epoch. Repeated files are
# re-augmented on every access, because augment_transform is random.
train_fail_count = len(AugmentedDataset(train_fail_dir))
train_pass_count = len(AugmentedDataset(train_pass_dir))
balanced_length = max(train_fail_count, train_pass_count)
train_fail_dataset = AugmentedDataset(train_fail_dir, transform=augment_transform, target_length=balanced_length)
train_pass_dataset = AugmentedDataset(train_pass_dir, transform=augment_transform, target_length=balanced_length)
train_dataset = ConcatDataset([train_fail_dataset, train_pass_dataset])
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

# Test data: no augmentation and no resampling, so the metrics reflect the
# real class mix of the held-out split.
test_fail_dataset = AugmentedDataset(test_fail_dir, transform=test_transform)
test_pass_dataset = AugmentedDataset(test_pass_dir, transform=test_transform)
test_dataset = ConcatDataset([test_fail_dataset, test_pass_dataset])
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

# Model: ImageNet-pretrained ResNet-18 whose final fully connected layer is
# replaced by a fresh 2-way head (class 0 = fail, class 1 = pass).
if hasattr(models, "ResNet18_Weights"):
    # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is the
    # checkpoint that flag loads.
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    # Older torchvision, before the weights enums existed.
    model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)

# Device, loss and optimizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Focal loss (gamma=2) down-weights easy, confidently classified samples so
# training concentrates on the harder ones.
criterion = FocalLoss(gamma=2)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

def train(model, loader, criterion, optimizer, epochs=10):
    """Train ``model`` in place for ``epochs`` passes over ``loader``.

    Batches are moved to the device the model's parameters live on. After
    each epoch the mean per-sample training loss is printed.

    Args:
        model: The ``nn.Module`` to optimise (switched to train mode).
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``. It is assumed
            to average over the batch (``reduction='mean'``); the epoch loss
            re-weights each batch by its size.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of passes over ``loader``.

    Returns:
        list[float]: The mean training loss of each epoch.

    Raises:
        ValueError: if ``loader`` yields no samples in an epoch.
    """
    model_device = next(model.parameters()).device
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        loss_sum, sample_count = 0.0, 0
        for inputs, labels in loader:
            inputs, labels = inputs.to(model_device), labels.to(model_device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch doesn't skew the mean.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            sample_count += batch_size
        if sample_count == 0:
            raise ValueError("train(): loader yielded no samples")
        epoch_loss = loss_sum / sample_count
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss}")
    return epoch_losses

def evaluate(model, loader):
    """Run ``model`` over ``loader`` without gradients and print its accuracy.

    Batches are moved to the device the model's parameters live on, and the
    model is left in eval mode.

    Args:
        model: Classifier returning one score per class along dim 1.
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        tuple[list, list]: ``(all_preds, all_labels)``. Predictions are the
        argmax over dim 1 of the model output; labels are the ground truth.
        Both are NumPy scalars in loader order, and both lists are empty
        when the loader yields no samples.
    """
    model_device = next(model.parameters()).device
    model.eval()
    all_preds, all_labels = [], []
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(model_device), labels.to(model_device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    if total == 0:
        # Avoid dividing by zero; there is nothing to score.
        print("Accuracy: n/a (loader yielded no samples)")
        return all_preds, all_labels
    accuracy = 100 * correct / total
    print(f"Accuracy: {accuracy:.2f}%")
    return all_preds, all_labels

# Train the model.
train(model, train_loader, criterion, optimizer, epochs=10)

# Evaluate on the held-out test split.
preds, labels = evaluate(model, test_loader)

# Confusion matrix. Pin labels=[0, 1] so the matrix is always 2x2 and lines up
# with display_labels, even when a class is missing from both the ground truth
# and the predictions (otherwise sklearn returns a 1x1 matrix).
cm = confusion_matrix(labels, preds, labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Fail', 'Pass'])
disp.plot(cmap=plt.cm.Blues)
plt.show()
