# In [None]:
# Import necessary libraries
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import train_test_split
from torch import nn, optim

# Define utility functions
def create_dir(path):
    """Create directory ``path``, including missing parents, if it is absent.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True makes the call idempotent without the race between a
    # separate os.path.exists() check and the creation itself.
    os.makedirs(path, exist_ok=True)

# Data Preprocessor
class DataPreprocessor:
    """Resize images to a fixed size, convert them to grayscale and mirror the
    input directory tree under an output directory.

    Args:
        input_dir: Root directory that is walked recursively for input files.
        output_dir: Root directory receiving the processed files; created on
            construction.
        size: Target size passed to ``cv2.resize``.
    """

    def __init__(self, input_dir, output_dir, size=(256, 256)):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.size = size
        create_dir(self.output_dir)

    def preprocess_image(self, image_path, output_path):
        """Read one image, resize it, convert it to grayscale and write it.

        Files that cannot be read or written are reported on stdout and
        skipped; nothing is raised for them.
        """
        image = cv2.imread(image_path)
        if image is None:
            print(f"Error: Unable to read image {image_path}")
            return
        image = cv2.resize(image, self.size)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # cv2.imwrite signals failure through its return value, so only report
        # success when the file was actually written.
        if not cv2.imwrite(output_path, image):
            print(f"Error: Unable to write image {output_path}")
            return
        print(f"Processed {output_path}")

    def preprocess(self):
        """Process every file found under ``input_dir`` into ``output_dir``."""
        if not os.path.isdir(self.input_dir):
            # os.walk would silently yield nothing for a missing directory.
            print(f"Error: Input directory not found {self.input_dir}")
            return

        for subdir, _, files in os.walk(self.input_dir):
            if not files:
                continue
            # Map the sub-directory by its path relative to the input root.
            # A plain str.replace() would also rewrite later occurrences of
            # the input path inside nested directory names.
            relative = os.path.relpath(subdir, self.input_dir)
            if relative == os.curdir:
                output_subdir = self.output_dir
            else:
                output_subdir = os.path.join(self.output_dir, relative)
            create_dir(output_subdir)
            for file in files:
                input_path = os.path.join(subdir, file)
                output_path = os.path.join(output_subdir, file)
                self.preprocess_image(input_path, output_path)

# Data Augmentor
class DataAugmentor:
    """Generate augmented copies of the images of each class directory.

    Args:
        input_dir: Directory containing the ``def_front`` and ``ok_front``
            class sub-directories.
        output_dir: Directory receiving the augmented images, using the same
            class sub-directory names; created on construction.
        augmentations: List of transforms, combined with
            ``transforms.Compose`` and applied to each opened PIL image.
        num_samples: Number of augmented images written per class.
    """

    def __init__(self, input_dir, output_dir, augmentations, num_samples=1000):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.augmentations = transforms.Compose(augmentations)
        self.num_samples = num_samples
        create_dir(self.output_dir)

    def augment_data(self):
        """Write ``num_samples`` augmented images per class.

        Source images are used round-robin, so each one is augmented several
        times when ``num_samples`` exceeds the number of images in the class.
        """
        for class_dir in ['def_front', 'ok_front']:
            class_input_dir = os.path.join(self.input_dir, class_dir)
            class_output_dir = os.path.join(self.output_dir, class_dir)
            create_dir(class_output_dir)
            # Sorted so that the image used for sample i does not depend on
            # the arbitrary order returned by os.listdir.
            images = sorted(
                f for f in os.listdir(class_input_dir)
                if os.path.isfile(os.path.join(class_input_dir, f))
            )

            if not images:
                print(f"No images found in {class_input_dir}")
                continue

            for i in range(self.num_samples):
                img_name = images[i % len(images)]
                output_path = os.path.join(class_output_dir, f"augmented_{i}_{img_name}")
                # The context manager closes the source file handle; without
                # it, handles stay open until garbage collection.
                with Image.open(os.path.join(class_input_dir, img_name)) as image:
                    augmented_image = self.augmentations(image)
                    augmented_image.save(output_path)
                print(f"Saved augmented image: {output_path}")

# Dataset class
class MedicalImageDataset(Dataset):
    """Two-class image dataset read from a directory.

    ``image_dir`` must contain the sub-directories ``def_front`` (label 0)
    and ``ok_front`` (label 1); labels follow the order of ``self.classes``.

    Args:
        image_dir: Directory containing the class sub-directories.
        transform: Callable applied to each RGB PIL image. Defaults to
            ``transforms.ToTensor()`` when omitted.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.classes = ['def_front', 'ok_front']
        self.transform = transform if transform else transforms.ToTensor()
        self.image_paths = []
        self.labels = []

        for label, class_dir in enumerate(self.classes):
            class_path = os.path.join(self.image_dir, class_dir)
            # Sorted for a reproducible sample order; os.listdir order is
            # arbitrary.
            for image_name in sorted(os.listdir(class_path)):
                image_path = os.path.join(class_path, image_name)
                # Skip sub-directories and other non-file entries, which
                # would otherwise fail when opened as images.
                if not os.path.isfile(image_path):
                    continue
                self.image_paths.append(image_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for sample ``idx``."""
        # convert() returns a new in-memory image, so the file handle can be
        # closed as soon as the conversion is done.
        with Image.open(self.image_paths[idx]) as source:
            image = source.convert('RGB')
        # self.transform is never None here: __init__ installs a default.
        return self.transform(image), self.labels[idx]

# Load image paths and labels
def load_image_paths_and_labels(image_dir):
    """Collect file paths and integer labels for both classes.

    Args:
        image_dir: Directory containing the ``def_front`` and ``ok_front``
            sub-directories.

    Returns:
        A ``(image_paths, labels)`` tuple of equally long lists, where label
        0 is ``def_front`` and label 1 is ``ok_front``.

    Raises:
        FileNotFoundError: If a class sub-directory does not exist.
    """
    classes = ['def_front', 'ok_front']
    image_paths = []
    labels = []

    for label, class_dir in enumerate(classes):
        class_path = os.path.join(image_dir, class_dir)
        # Sorted so that a seeded train/validation split of the result is
        # reproducible; os.listdir order is arbitrary.
        for image_name in sorted(os.listdir(class_path)):
            image_path = os.path.join(class_path, image_name)
            # Only regular files are samples; skip nested directories.
            if os.path.isfile(image_path):
                image_paths.append(image_path)
                labels.append(label)

    return image_paths, labels

# CNN Model
class CNNModel(nn.Module):
    """ResNet-18 with pretrained weights and a final layer sized to
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the replacement ``fc`` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(weights='IMAGENET1K_V1')
        # Replace the pretrained classification head with a fresh linear
        # layer that keeps the same input width.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        logits = self.model(x)
        return logits

# Training and evaluation functions
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    """Train ``model`` and print train/validation metrics after every epoch.

    Args:
        model: Module to optimise; updated in place.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader passed to ``evaluate_model`` after each epoch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to before the forward pass.
    """
    for epoch in range(num_epochs):
        # Re-enter training mode at the start of every epoch. The validation
        # pass below puts the model in eval mode, and calling train() only
        # once before the loop would leave every epoch after the first
        # running BatchNorm/Dropout in inference mode.
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a
            # per-sample average even when the last batch is smaller.
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)

        # Evaluate on validation set
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        print(f'Epoch {epoch + 1}/{num_epochs}, '
              f'Train Loss: {epoch_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

def evaluate_model(model, dataloader, criterion, device):
    """Compute the average loss and accuracy of ``model`` over ``dataloader``.

    The model is evaluated in eval mode with gradients disabled, and its
    previous train/eval mode is restored before returning, so calling this
    between training epochs does not leave the model in eval mode.

    Args:
        model: Module to evaluate.
        dataloader: Loader yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A ``(val_loss, val_acc)`` tuple: the loss sum divided by
        ``len(dataloader.dataset)`` and the fraction of correct predictions.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)
                # The predicted class is the index of the largest output.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    val_loss /= len(dataloader.dataset)
    val_acc = correct / total
    return val_loss, val_acc

# Configuration class
class Config:
    """Paths, hyper-parameters and transform factories for the pipeline."""

    # Data locations (relative paths).
    RAW_DATA_DIR = '../data/casting_data'
    PROCESSED_DATA_DIR = '../data/processed_data'
    AUGMENTED_DATA_DIR = '../data/augmentation'
    # Path for trained model weights (not referenced by the code shown here).
    MODEL_PATH = 'models/trained_models/model.pth'

    # Training hyper-parameters.
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    NUM_EPOCHS = 25
    # Size handed to transforms.Resize by both transform factories.
    IMAGE_SIZE = (256, 256)

    @staticmethod
    def train_transforms():
        """Return the training pipeline: resize, random horizontal flip,
        tensor conversion."""
        steps = [
            transforms.Resize(Config.IMAGE_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
        return transforms.Compose(steps)

    @staticmethod
    def test_transforms():
        """Return the evaluation pipeline: resize and tensor conversion."""
        steps = [
            transforms.Resize(Config.IMAGE_SIZE),
            transforms.ToTensor(),
        ]
        return transforms.Compose(steps)

# Main script
if __name__ == '__main__':
    # Local imports: only this script section needs them.
    import copy
    from torch.utils.data import Subset

    # Resolve the data directories relative to this file. __file__ is not
    # defined when the code runs inside a notebook cell, so fall back to the
    # current working directory in that case.
    try:
        base_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        base_dir = os.getcwd()

    config = Config()
    raw_root = os.path.join(base_dir, Config.RAW_DATA_DIR)
    processed_root = os.path.join(base_dir, Config.PROCESSED_DATA_DIR)
    augmented_root = os.path.join(base_dir, Config.AUGMENTED_DATA_DIR)

    # Preprocess data
    for split in ('train', 'test'):
        preprocessor = DataPreprocessor(
            os.path.join(raw_root, split),
            os.path.join(processed_root, split),
            size=Config.IMAGE_SIZE,
        )
        preprocessor.preprocess()

    # Augment data
    # NOTE(review): the augmented images are written to disk but the training
    # below reads the processed (non-augmented) images — confirm this is the
    # intended data source.
    augmentations = [
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(Config.IMAGE_SIZE[0], scale=(0.8, 1.0)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
    ]

    for split in ('train', 'test'):
        augmentor = DataAugmentor(
            os.path.join(processed_root, split),
            os.path.join(augmented_root, split),
            augmentations,
        )
        augmentor.augment_data()

    # Load dataset
    # MedicalImageDataset is built from a directory, not from lists of paths
    # and labels, so the train/validation split is done on sample indices.
    # Two views of the same file list are used so that each subset gets its
    # own transform pipeline.
    train_view = MedicalImageDataset(
        os.path.join(processed_root, 'train'),
        transform=config.train_transforms(),
    )
    val_view = copy.copy(train_view)  # shares the same image_paths/labels
    val_view.transform = config.test_transforms()

    # Split data into training and validation sets
    train_indices, val_indices = train_test_split(
        list(range(len(train_view))), test_size=0.2, random_state=42
    )

    train_dataset = Subset(train_view, train_indices)
    val_dataset = Subset(val_view, val_indices)

    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False)

    # Initialize model
    model = CNNModel(num_classes=2)
    # Prefer CUDA, then Apple MPS, then CPU. The getattr guard covers torch
    # builds that do not ship the MPS backend module.
    mps_backend = getattr(torch.backends, 'mps', None)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    criterion = torch.nn.CrossEntropyLoss()

    # Train model
    train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=config.NUM_EPOCHS, device=device)
