# In [None]:
import copy
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms, models

from cyclegan_model import CycleGAN

class EyeDiseaseDataset(Dataset):
    """Image dataset with three classes: 0 = normal, 1 = epiphora, 2 = keratitis.

    Keratitis samples are every ``.tif`` file found in ``blue_light_dir`` and
    ``white_light_dir``. Normal and epiphora samples come from ``csv_file``:
    the first column is read as a file name inside ``png_dir`` and the second
    column as a grade ("normal", "mild", "moderate" or "severe"); every
    non-normal grade is labelled epiphora. CSV rows whose file does not exist
    in ``png_dir`` are ignored.

    Args:
        blue_light_dir: Optional directory of ``.tif`` keratitis images.
        white_light_dir: Optional directory of ``.tif`` keratitis images.
        png_dir: Directory holding the images listed in ``csv_file``.
        csv_file: CSV file with file names (column 0) and grades (column 1).
        transform: Optional callable applied to each image in ``__getitem__``.
        balance_data: When true, randomly subsample the normal and epiphora
            classes to at most the number of keratitis images. The sampling
            uses the global ``random`` state, so it is not seeded here.
        use_cyclegan: When true, each image is passed through
            ``CycleGAN.translate`` before ``transform``.

    Raises:
        KeyError: If a CSV row for an existing file has an unknown grade.
    """

    # Grade (lower-cased, stripped) -> class index.
    _LABEL_MAPPING = {"normal": 0, "mild": 1, "moderate": 1, "severe": 1}

    def __init__(self, blue_light_dir=None, white_light_dir=None, png_dir=None, csv_file=None, transform=None, balance_data=True, use_cyclegan=True):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.use_cyclegan = use_cyclegan

        keratitis_images = []
        epiphora_images = []
        normal_images = []

        # Both light-source directories contribute to the same keratitis class.
        for tif_dir in (blue_light_dir, white_light_dir):
            if tif_dir and os.path.exists(tif_dir):
                keratitis_images += [
                    (os.path.join(tif_dir, f), 2)
                    for f in os.listdir(tif_dir)
                    if f.lower().endswith('.tif')
                ]

        if csv_file and png_dir:
            # dropna() discards any row that has a missing value in any column.
            data = pd.read_csv(csv_file).dropna()
            # Convert to str on a copy instead of assigning back through
            # iloc, which may not change the column dtype.
            filenames = data.iloc[:, 0].astype(str)
            grades = data.iloc[:, 1]

            for filename, grade in zip(filenames, grades):
                path = os.path.join(png_dir, filename)
                if not os.path.exists(path):
                    continue
                # Tolerate case/whitespace differences such as "Normal " in the CSV.
                key = str(grade).strip().lower()
                if key not in self._LABEL_MAPPING:
                    raise KeyError(
                        f"Unknown grade {grade!r} for file {filename!r}; "
                        f"expected one of {sorted(self._LABEL_MAPPING)}"
                    )
                label = self._LABEL_MAPPING[key]
                if label == 0:
                    normal_images.append((path, label))
                else:
                    epiphora_images.append((path, label))

        if balance_data:
            num_keratitis = len(keratitis_images)
            # With no keratitis images there is nothing to balance against;
            # capping to zero would silently throw away the whole dataset.
            if num_keratitis:
                normal_images = random.sample(normal_images, min(num_keratitis, len(normal_images)))
                epiphora_images = random.sample(epiphora_images, min(num_keratitis, len(epiphora_images)))

        samples = keratitis_images + epiphora_images + normal_images
        if samples:
            self.image_paths, self.labels = zip(*samples)
        else:
            # zip(*[]) yields nothing to unpack, so handle the empty case explicitly.
            self.image_paths, self.labels = (), ()

        print(f"Balanced Dataset - Normal: {len(normal_images)}, Epiphora: {len(epiphora_images)}, Keratitis: {len(keratitis_images)}")

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened as RGB, optionally passed through
        ``CycleGAN.translate`` and then through ``self.transform``.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # convert() returns a fully loaded copy, so the file handle can be
        # released immediately instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.use_cyclegan:
            # NOTE(review): translate is called on the CycleGAN class itself;
            # its behaviour lives in cyclegan_model and is not visible here.
            image = CycleGAN.translate(image)

        if self.transform:
            image = self.transform(image)

        return image, label

# Training-time preprocessing: resize to 224x224, apply light random
# augmentation (flip, small rotation, colour jitter), convert to a tensor and
# normalise per channel.
# NOTE(review): the mean/std values look like the ImageNet statistics used
# with torchvision's pretrained models - confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Evaluation-time preprocessing: same resize and normalisation but no random
# augmentation, so validation accuracy is deterministic and measured on
# unperturbed images.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = EyeDiseaseDataset(
    blue_light_dir=r"C:\Users\Ludovic\Pytorch\ExpoSciences\Datasets\Blue_Light",
    white_light_dir= r"C:\Users\Ludovic\Pytorch\ExpoSciences\Datasets\White_Light",
    png_dir=r"C:\Users\Ludovic\Pytorch\ExpoSciences\Datasets\26172919\train\train",
    csv_file= r"C:\Users\Ludovic\Pytorch\ExpoSciences\Datasets\26172919\SLID_E_information.csv",
    transform=transform,
    balance_data=True
)

# 80/20 split over sample indices with a fixed seed.
train_indices, val_indices = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
train_dataset = Subset(dataset, train_indices)

# Subset delegates to the wrapped dataset's __getitem__, so the validation
# subset needs its own dataset object to get a different transform. A shallow
# copy shares the sample list (same indices stay valid) but carries its own
# ``transform`` attribute.
eval_dataset = copy.copy(dataset)
eval_dataset.transform = eval_transform
val_dataset = Subset(eval_dataset, val_indices)

def collate_fn(batch):
    """Collate a batch, ignoring samples whose first element is ``None``.

    Returns the result of PyTorch's default collate over the remaining
    samples, or ``None`` when no sample remains.
    """
    usable = list(filter(lambda sample: sample[0] is not None, batch))
    if not usable:
        return None
    return torch.utils.data.dataloader.default_collate(usable)

# Training batches are drawn in a new random order each epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)
# Validation batches keep a fixed order.
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn,
)

class EyeDiseaseModel(nn.Module):
    """ResNet-34 classifier with a frozen pretrained backbone.

    All pretrained parameters have ``requires_grad`` disabled; only the newly
    created final linear layer, which outputs ``num_classes`` logits, is
    trainable.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        backbone = models.resnet34(pretrained=True)
        # Freeze every pretrained weight before the head is swapped, so the
        # replacement layer below is the only part left trainable.
        backbone.requires_grad_(False)
        in_features = backbone.fc.in_features
        backbone.fc = nn.Linear(in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        logits = self.model(x)
        return logits

# File the best model weights are saved to and resumed from.
MODEL_PATH = "eye_disease_model.pth"

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = EyeDiseaseModel(num_classes=3)
model = model.to(device)

# Cross-entropy over the three class logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def _evaluate_accuracy(model, loader):
    """Return the accuracy (in percent) of ``model`` over ``loader``.

    ``None`` batches are skipped. Returns ``None`` when ``loader`` is ``None``
    or yields no samples. The model's previous train/eval mode is restored
    before returning.
    """
    if loader is None:
        return None
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in loader:
                if batch is None:
                    continue
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)
                _, predicted = torch.max(model(inputs), 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
    finally:
        model.train(was_training)
    return correct / total * 100 if total else None


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=50):
    """Train ``model`` and keep the best-scoring weights in ``MODEL_PATH``.

    If ``MODEL_PATH`` already exists its weights are loaded first and scored on
    ``val_loader``, so that a worse epoch cannot overwrite the existing
    checkpoint. After every epoch the model is evaluated on ``val_loader`` and
    saved whenever validation accuracy improves. When ``val_loader`` yields no
    samples, training accuracy is used as the selection metric instead.

    Args:
        model: Network to train; moved data uses the module-level ``device``.
        train_loader: Iterable of ``(inputs, labels)`` batches; ``None``
            batches are skipped.
        val_loader: Iterable of validation batches in the same format.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
    """
    best_accuracy = 0.0
    if os.path.exists(MODEL_PATH):
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        print("Loaded pre-trained model.")
        resumed_accuracy = _evaluate_accuracy(model, val_loader)
        if resumed_accuracy is not None:
            best_accuracy = resumed_accuracy

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_preds = 0
        total_preds = 0

        for batch in train_loader:
            if batch is None:
                continue
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_preds += (predicted == labels).sum().item()
            total_preds += labels.size(0)

        # Guard against an epoch in which every batch was skipped.
        train_accuracy = correct_preds / total_preds * 100 if total_preds else 0.0
        val_accuracy = _evaluate_accuracy(model, val_loader)

        if val_accuracy is None:
            selection_accuracy = train_accuracy
            print(f"Epoch {epoch+1}: Loss {running_loss:.4f}, Train Accuracy {train_accuracy:.2f}%")
        else:
            selection_accuracy = val_accuracy
            print(f"Epoch {epoch+1}: Loss {running_loss:.4f}, Train Accuracy {train_accuracy:.2f}%, Val Accuracy {val_accuracy:.2f}%")

        if selection_accuracy > best_accuracy:
            best_accuracy = selection_accuracy
            torch.save(model.state_dict(), MODEL_PATH)
            print("Saved best model.")

def test_model(model, test_loader):
    """Print the classification accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and gradients are disabled. ``None``
    batches are skipped. If no sample is evaluated, a notice is printed
    instead of an accuracy.

    Args:
        model: Network to evaluate; data is moved to the module-level ``device``.
        test_loader: Iterable of ``(inputs, labels)`` batches.
    """
    model.eval()
    correct_preds = 0
    total_preds = 0
    with torch.no_grad():
        for batch in test_loader:
            if batch is None:
                continue
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct_preds += (predicted == labels).sum().item()
            total_preds += labels.size(0)
    if total_preds == 0:
        # Nothing was evaluated, so there is no accuracy to report.
        print("Test Accuracy: n/a (no samples evaluated)")
        return
    accuracy = correct_preds / total_preds * 100
    print(f"Test Accuracy: {accuracy:.2f}%")


# The name starts with "test_", so pytest would try to collect this evaluation
# helper as a test case; opt out explicitly.
test_model.__test__ = False

# Train with the default number of epochs, then report accuracy on the
# validation split (the only held-out data built in the visible code).
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
)
test_model(model=model, test_loader=val_loader)