# In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import pandas as pd
import cv2
import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split


# Simple CNN for classification and localization
class DentalNet(nn.Module):
    """Small CNN that jointly predicts a class score and a bounding box.

    Three conv / ReLU / max-pool stages feed a two-layer fully connected
    head with five outputs: index 0 is the class score and indices 1-4 are
    the bounding-box coordinates.

    Args:
        input_size: ``(height, width)`` of the single-channel input images.
            The default ``(640, 640)`` yields the original ``128 * 80 * 80``
            input width for the head, so the default layer shapes (and the
            state-dict keys) are unchanged.

    Raises:
        ValueError: If either dimension is smaller than 8, i.e. nothing
            would be left after the three 2x poolings.
    """

    def __init__(self, input_size=(640, 640)):
        super().__init__()

        height, width = input_size
        # The 3x3 convolutions use padding=1 and keep the spatial size; each
        # of the three MaxPool2d(2) layers halves it (rounding down).
        pooled_h, pooled_w = height // 8, width // 8
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f'input_size must be at least (8, 8), got {tuple(input_size)}'
            )

        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(128 * pooled_h * pooled_w, 512),
            nn.ReLU(),
            nn.Linear(512, 5)  # 1 for class + 4 for bbox coordinates
        )

    def forward(self, x):
        """Map a ``(batch, 1, height, width)`` tensor to ``(batch, 5)`` outputs."""
        x = self.features(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.size(0), -1)
        return self.classifier(x)

# Dataset class
class DentalDataset(Dataset):
    """Images paired with a ``[class, xmin, ymin, xmax, ymax]`` target.

    The annotation CSV is read positionally: column 0 holds the image file
    name (relative to ``img_dir``), column 3 the class label and columns 4-7
    the bounding box as xmin, ymin, xmax, ymax.

    Args:
        csv_file: Path of the annotation CSV.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the opened PIL image. When
            omitted, the image is converted to grayscale and then to a
            tensor.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        if transform is None:
            transform = transforms.Compose([
                transforms.Grayscale(),
                transforms.ToTensor()
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.annotations)

    def _load_target(self, idx):
        """Build the float32 target ``[class, xmin, ymin, xmax, ymax]`` for row ``idx``."""
        label = self.annotations.iloc[idx, 3]  # class
        bbox = self.annotations.iloc[idx, 4:8].values.astype(float)  # xmin, ymin, xmax, ymax
        target = np.concatenate(([label], bbox))
        return torch.as_tensor(target, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for annotation row ``idx``."""
        img_path = os.path.join(self.img_dir, self.annotations.iloc[idx, 0])
        # The context manager closes the underlying file handle as soon as
        # the transform has produced its own copy of the pixel data.
        with Image.open(img_path) as pil_image:
            image = self.transform(pil_image)

        return image, self._load_target(idx)

def extract_treated_teeth(model, img_dir, csv_file, output_dir, dataloader):
    """Crop and save the predicted box of every sample classified as treated.

    Each sample yielded by ``dataloader`` is run through ``model``. When
    output 0 exceeds 0.5, outputs 1-4 are read as xmin, ymin, xmax, ymax,
    clamped to the image bounds, and the crop of the original image is
    written to ``output_dir`` as ``treated_<image name>``.

    Args:
        model: Network returning one 5-value row per input image.
        img_dir: Directory containing the original images.
        csv_file: Annotation CSV whose first column holds the image names.
        output_dir: Directory for the crops; created if missing.
        dataloader: Loader yielding ``(images, targets)`` batches. It must
            iterate its dataset in order (``shuffle=False``); a loader over
            a ``Subset`` (e.g. from ``random_split``) is mapped back to the
            matching annotation rows.

    Returns:
        List of the file paths that were written successfully.
    """
    # Function-scope import: Subset is not among the file-level imports.
    from torch.utils.data import Subset

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Only the annotation table is used here, to recover image file names.
    dataset = DentalDataset(csv_file, img_dir)

    # Translate loader positions into annotation rows. A Subset stores the
    # rows it selects in `.indices`; unwrap nested subsets from the outside in.
    row_indices = None
    source = dataloader.dataset
    while isinstance(source, Subset):
        picked = list(source.indices)
        if row_indices is None:
            row_indices = picked
        else:
            row_indices = [picked[j] for j in row_indices]
        source = source.dataset

    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    saved_paths = []
    position = 0  # running sample position across all batches
    with torch.no_grad():
        for images, _targets in dataloader:
            batch_outputs = model(images.to(device)).cpu().numpy()

            for prediction in batch_outputs:
                row = position if row_indices is None else row_indices[position]
                position += 1

                # Threshold the class score at 0.5 (NaN counts as negative).
                if not (prediction[0] > 0.5):
                    continue

                img_name = dataset.annotations.iloc[row, 0]

                # cv2.imread signals failure by returning None, not raising.
                orig_img = cv2.imread(os.path.join(img_dir, img_name))
                if orig_img is None:
                    print(f'Skipping {img_name}: image could not be read')
                    continue

                pred_bbox = prediction[1:]
                if not np.all(np.isfinite(pred_bbox)):
                    print(f'Skipping {img_name}: predicted box is not finite')
                    continue

                # Clamp to the image so negative values cannot wrap around
                # and oversized boxes cannot run past the border.
                height, width = orig_img.shape[:2]
                xmin, ymin, xmax, ymax = map(int, pred_bbox)
                xmin, xmax = max(0, xmin), min(width, xmax)
                ymin, ymax = max(0, ymin), min(height, ymax)
                if xmin >= xmax or ymin >= ymax:
                    print(f'Skipping {img_name}: predicted box is empty')
                    continue

                cropped_tooth = orig_img[ymin:ymax, xmin:xmax]

                output_path = os.path.join(output_dir, f'treated_{img_name}')
                if cv2.imwrite(output_path, cropped_tooth):
                    saved_paths.append(output_path)
                else:
                    print(f'Failed to write {output_path}')

    return saved_paths

# Training function
def train_model(model, train_loader, criterion, optimizer, num_epochs=50):
    """Train ``model`` and print the mean batch loss after every epoch.

    The loss of a batch is ``criterion`` applied to output column 0 against
    target column 0 (class) plus ``criterion`` applied to the remaining
    columns (bounding box).

    Args:
        model: Network producing one row per sample, class score first.
        train_loader: Iterable of ``(images, targets)`` batches.
        criterion: Loss callable taking ``(outputs, targets)``.
        optimizer: Optimizer bound to the parameters of ``model``.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        List with the mean batch loss of each epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        # Count batches while iterating rather than calling len(): this also
        # works for loaders without __len__ and avoids dividing by zero.
        num_batches = 0

        for images, targets in train_loader:
            images, targets = images.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            # Calculate loss
            class_loss = criterion(outputs[:, 0], targets[:, 0])
            bbox_loss = criterion(outputs[:, 1:], targets[:, 1:])
            loss = class_loss + bbox_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError('train_loader produced no batches; nothing to train on')

        epoch_loss = running_loss / num_batches
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')

    return epoch_losses

# Usage
def main():
    """Train a DentalNet on an 80/20 split, save its weights and export crops."""
    annotations_csv = 'annotations/filtered_annotations.csv'
    images_root = 'filtered_images'

    net = DentalNet()

    # 80 % of the samples are used for training, the remainder for extraction.
    full_set = DentalDataset(annotations_csv, images_root)
    n_train = int(0.8 * len(full_set))
    n_test = len(full_set) - n_train
    train_part, test_part = random_split(full_set, [n_train, n_test])

    train_batches = DataLoader(train_part, batch_size=4, shuffle=True)
    test_batches = DataLoader(test_part, batch_size=1, shuffle=False)

    # The same MSE criterion is applied to both the class and the box outputs.
    mse = nn.MSELoss()
    adam = torch.optim.Adam(net.parameters(), lr=0.001)
    train_model(net, train_batches, mse, adam)

    torch.save(net.state_dict(), 'dental_model.pth')

    # Run the trained network over the held-out loader and write the crops.
    extract_treated_teeth(
        net,
        images_root,
        annotations_csv,
        'output_treated_teeth',
        test_batches,
    )


if __name__ == '__main__':
    main()