<a href="https://colab.research.google.com/github/hatef-hosseinpour/dental/blob/main/dental_classify_binary_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
import numpy as np
from PIL import Image

In [3]:
IMAGE_SIZE = 100
BATCH_SIZE = 32
LEARNING_RATE = 0.001
EPOCHS = 10
NORMALIZATION_MEAN = [0.5, 0.5, 0.5]
NORMALIZATION_STD = [0.5, 0.5, 0.5]
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZATION_MEAN, NORMALIZATION_STD)
])

In [5]:
class DentalDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

In [6]:
def load_data(data_dir, label_mapping):
    images = []
    labels = []
    for folder, label in label_mapping.items():
        folder_path = os.path.join(data_dir, folder)
        for file_name in os.listdir(folder_path):
            file_path = os.path.join(folder_path, file_name)
            image = cv2.imread(file_path, cv2.IMREAD_COLOR)
            if image is not None:
                image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
                image = Image.fromarray(image)
                images.append(image)
                labels.append(label)
    return images, labels

In [7]:
def save_best_model(model, val_accuracy, best_accuracy, model_name):
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), f"/content/drive/MyDrive/Dentisrty/data/{model_name}_best.pth")
    return best_accuracy

In [8]:
def train_model(model, train_loader, val_loader, test_loader, criterion, optimizer, model_name):
    best_val_accuracy = 0.0
    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

        print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {running_loss/len(train_loader.dataset):.4f}')

        # After each epoch, evaluate the model on the validation set
        val_accuracy = evaluate_model(model, val_loader)
        best_val_accuracy = save_best_model(model, val_accuracy, best_val_accuracy, model_name)
        print(f"Validation Accuracy after Epoch {epoch+1}: {val_accuracy:.4f}")

    print(f"Best model for {model_name} saved with validation accuracy: {best_val_accuracy:.4f}")

    # After training, evaluate the model on the test set
    test_accuracy = evaluate_model(model, test_loader)
    print(f"Test Accuracy for {model_name}: {test_accuracy:.4f}")

In [9]:
def evaluate_model(model, test_loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    print(f'Accuracy: {accuracy:.4f}')
    return accuracy

In [10]:
def build_pretrained_model(model_name, num_classes):
    if model_name == 'resnet18':
        model = models.resnet18(pretrained=True)
    #elif model_name == 'vgg16':
    #   model = models.vgg16(pretrained=True)

    #   model = models.mobilenet_v2(pretrained=True)
    # Add more models here as needed

    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(DEVICE)

In [11]:
# Predict function for unseen images
def preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
    image = data_transforms(image).unsqueeze(0)
    return image

In [12]:
def load_model(model_path, model_class):
    model = model_class()
    model.load_state_dict(torch.load(model_path))
    model.eval()
    return model

In [13]:
def predict_image(image_path, model_path1, model_path2, model_class):
    image_tensor = preprocess_image(image_path).to(DEVICE)

    # Model 1: Caries detection
    model1 = load_model(model_path1, model_class).to(DEVICE)

    with torch.no_grad():
        output1 = model1(image_tensor)
        prob1 = torch.sigmoid(output1).item()

        if prob1 < 0.5:  # Predicts Caries (class 0)
            print("Caries")
        else:
            # Load and use model2 (to check if it's Amalgam or Normal)
            model2 = load_model(model_path2, model_class).to(DEVICE)
            output2 = model2(image_tensor)
            prob2 = torch.sigmoid(output2).item()

            # Model 2: Amalgam vs Normal
            if prob2 < 0.5:
                print("Amalgam")
            else:
                print("Normal")

In [15]:
if __name__ == '__main__':
    data_dir = '/content/drive/MyDrive/Dentisrty/data'  # Update this path

    # Step 1: Define class mappings
    folder_to_label_mapping_caries = {'Caries': 0, 'Amalgam': 1, 'Normal': 1}
    folder_to_label_mapping_non_caries = {'Amalgam': 0, 'Normal': 1}

    # Step 2: Load data for both models
    print("Loading data...")
    images_caries, labels_caries = load_data(data_dir, folder_to_label_mapping_caries)
    images_non_caries, labels_non_caries = load_data(data_dir, folder_to_label_mapping_non_caries)

    # Step 3: Split data into train (80%) and test (20%) sets
    print("Splitting data...")
    X_train_caries, X_test_caries, y_train_caries, y_test_caries = train_test_split(images_caries, labels_caries, test_size=0.2, random_state=42)
    X_train_non_caries, X_test_non_caries, y_train_non_caries, y_test_non_caries = train_test_split(images_non_caries, labels_non_caries, test_size=0.2, random_state=42)

    # Step 4: Split train data into train (80%) and validation (20%)
    X_train_caries, X_val_caries, y_train_caries, y_val_caries = train_test_split(X_train_caries, y_train_caries, test_size=0.2, random_state=42)
    X_train_non_caries, X_val_non_caries, y_train_non_caries, y_val_non_caries = train_test_split(X_train_non_caries, y_train_non_caries, test_size=0.2, random_state=42)

    # Step 5: Prepare dataloaders for both models (train, val, and test)
    print("Preparing dataloaders...")
    train_dataset_caries = DentalDataset(X_train_caries, y_train_caries, transform=data_transforms)
    val_dataset_caries = DentalDataset(X_val_caries, y_val_caries, transform=data_transforms)
    test_dataset_caries = DentalDataset(X_test_caries, y_test_caries, transform=data_transforms)

    train_loader_caries = DataLoader(train_dataset_caries, batch_size=BATCH_SIZE, shuffle=True)
    val_loader_caries = DataLoader(val_dataset_caries, batch_size=BATCH_SIZE, shuffle=False)
    test_loader_caries = DataLoader(test_dataset_caries, batch_size=BATCH_SIZE, shuffle=False)

    train_dataset_non_caries = DentalDataset(X_train_non_caries, y_train_non_caries, transform=data_transforms)
    val_dataset_non_caries = DentalDataset(X_val_non_caries, y_val_non_caries, transform=data_transforms)
    test_dataset_non_caries = DentalDataset(X_test_non_caries, y_test_non_caries, transform=data_transforms)

    train_loader_non_caries = DataLoader(train_dataset_non_caries, batch_size=BATCH_SIZE, shuffle=True)
    val_loader_non_caries = DataLoader(val_dataset_non_caries, batch_size=BATCH_SIZE, shuffle=False)
    test_loader_non_caries = DataLoader(test_dataset_non_caries, batch_size=BATCH_SIZE, shuffle=False)

    # Step 6: Train multiple models for both Caries and Non-Caries detection
    models_to_test = ['resnet18']  # You can add more models like 'vgg16', 'densenet121', 'mobilenet_v2'

    for model_name in models_to_test:
        # Caries Detection Model (2 classes: Caries and Non-Caries)
        model_caries = build_pretrained_model(model_name, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model_caries.parameters(), lr=LEARNING_RATE)
        print(f"Training model: {model_name} for Caries Detection")
        train_model(model_caries, train_loader_caries, val_loader_caries, test_loader_caries, criterion, optimizer, model_name + "_caries")

        # Non-Caries Detection Model (2 classes: Amalgam and Normal)
        model_non_caries = build_pretrained_model(model_name, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model_non_caries.parameters(), lr=LEARNING_RATE)
        print(f"Training model: {model_name} for Non-Caries Detection")
        train_model(model_non_caries, train_loader_non_caries, val_loader_non_caries, test_loader_non_caries, criterion, optimizer, model_name + "_non_caries")

Loading data...
Splitting data...
Preparing dataloaders...
Training model: resnet18 for Caries Detection
Epoch 1/10, Loss: 0.9048
Accuracy: 0.2800
Validation Accuracy after Epoch 1: 0.2800
Epoch 2/10, Loss: 0.4708
Accuracy: 0.6800
Validation Accuracy after Epoch 2: 0.6800
Epoch 3/10, Loss: 0.4631
Accuracy: 0.7600
Validation Accuracy after Epoch 3: 0.7600
Epoch 4/10, Loss: 0.0284
Accuracy: 0.6400
Validation Accuracy after Epoch 4: 0.6400
Epoch 5/10, Loss: 0.0468
Accuracy: 0.6000
Validation Accuracy after Epoch 5: 0.6000
Epoch 6/10, Loss: 0.0867
Accuracy: 0.5600
Validation Accuracy after Epoch 6: 0.5600
Epoch 7/10, Loss: 0.0790
Accuracy: 0.6000
Validation Accuracy after Epoch 7: 0.6000
Epoch 8/10, Loss: 0.0101
Accuracy: 0.7200
Validation Accuracy after Epoch 8: 0.7200
Epoch 9/10, Loss: 0.0338
Accuracy: 0.7600
Validation Accuracy after Epoch 9: 0.7600
Epoch 10/10, Loss: 0.0612
Accuracy: 0.7600
Validation Accuracy after Epoch 10: 0.7600
Best model for resnet18_caries saved with validation 



Training model: resnet18 for Non-Caries Detection
Epoch 1/10, Loss: 0.7869
Accuracy: 0.3333
Validation Accuracy after Epoch 1: 0.3333
Epoch 2/10, Loss: 0.4751
Accuracy: 0.8095
Validation Accuracy after Epoch 2: 0.8095
Epoch 3/10, Loss: 0.3119
Accuracy: 0.3810
Validation Accuracy after Epoch 3: 0.3810
Epoch 4/10, Loss: 0.1155
Accuracy: 0.3333
Validation Accuracy after Epoch 4: 0.3333
Epoch 5/10, Loss: 0.1525
Accuracy: 0.6190
Validation Accuracy after Epoch 5: 0.6190
Epoch 6/10, Loss: 0.1069
Accuracy: 0.6667
Validation Accuracy after Epoch 6: 0.6667
Epoch 7/10, Loss: 0.0115
Accuracy: 0.6667
Validation Accuracy after Epoch 7: 0.6667
Epoch 8/10, Loss: 0.0249
Accuracy: 0.7143
Validation Accuracy after Epoch 8: 0.7143
Epoch 9/10, Loss: 0.0210
Accuracy: 0.7143
Validation Accuracy after Epoch 9: 0.7143
Epoch 10/10, Loss: 0.0049
Accuracy: 0.9048
Validation Accuracy after Epoch 10: 0.9048
Best model for resnet18_non_caries saved with validation accuracy: 0.9048
Accuracy: 0.9259
Test Accuracy for