# Classificação de Estágios da Doença de Alzheimer

## CNN (ResNet)

# In[ ]:
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

# Preprocessing pipeline applied to every image:
#   1. resize to a fixed 128x128 resolution,
#   2. convert the PIL image to a float tensor,
#   3. standardize each of the three channels with the per-channel
#      mean/std values commonly used with ImageNet-pretrained
#      torchvision models.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Custom map-style dataset that loads images from disk.
class CustomDataset(Dataset):
    """Dataset over image files on disk, optionally paired with labels.

    Args:
        image_paths: Sequence of paths to image files.
        transform: Optional callable applied to each loaded RGB PIL image
            (e.g. a torchvision transform pipeline).
        labels: Optional sequence of labels aligned index-by-index with
            ``image_paths``. When given, ``__getitem__`` returns
            ``(image, label)``; when omitted it returns only the image,
            exactly as before.

    Raises:
        ValueError: if ``labels`` is given and its length differs from
            ``image_paths``.
    """

    def __init__(self, image_paths, transform=None, labels=None):
        if labels is not None and len(labels) != len(image_paths):
            raise ValueError(
                f"labels has {len(labels)} entries but image_paths has "
                f"{len(image_paths)}"
            )
        self.image_paths = image_paths
        self.transform = transform
        self.labels = labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Use a context manager so the underlying file handle is released
        # right away; convert() loads the pixel data and returns a new
        # image, so `image` remains valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        if self.labels is None:
            return image
        return image, self.labels[idx]


# Collect image paths and their class labels. Each entry of `folders` is a
# sub-folder of `dataset_folder` holding one class; an image's label is the
# index of its folder in `folders`.
dataset_folder = "./data/Alzheimer_MRI_4_classes_dataset"
folders = ["MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented"]
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")
image_paths = []
image_labels = []

for label, folder in enumerate(folders):
    folder_path = os.path.join(dataset_folder, folder)
    if not os.path.isdir(folder_path):
        # Keep the best-effort skip of missing class folders, but make it
        # visible instead of silent.
        print(f"Warning: class folder not found, skipping: {folder_path}")
        continue
    # Sort the listing: os.listdir order is arbitrary, and a stable order is
    # needed for the seeded split below to be reproducible. Non-image files
    # (e.g. hidden OS metadata files) are skipped so Image.open never sees them.
    for img_name in sorted(os.listdir(folder_path)):
        if img_name.lower().endswith(IMAGE_EXTENSIONS):
            image_paths.append(os.path.join(folder_path, img_name))
            image_labels.append(label)

if not image_paths:
    raise FileNotFoundError(f"No images found under {dataset_folder!r}")

# Split the data into train / validation / test (60% / 20% / 20%).
# Stratifying on the label keeps the class proportions the same in every
# split, so no class ends up under-represented by chance.
train_paths, test_paths, train_labels, test_labels = train_test_split(
    image_paths, image_labels, test_size=0.2, random_state=42, stratify=image_labels
)
train_paths, val_paths, train_labels, val_labels = train_test_split(
    train_paths, train_labels, test_size=0.25, random_state=42, stratify=train_labels
)  # 0.25 * 0.8 = 0.2

# Build the datasets; each sample is an (image_tensor, class_index) pair.
train_dataset = CustomDataset(train_paths, transform=transform, labels=train_labels)
val_dataset = CustomDataset(val_paths, transform=transform, labels=val_labels)
test_dataset = CustomDataset(test_paths, transform=transform, labels=test_labels)

# Build the dataloaders; only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# train_loader, val_loader and test_loader now yield batches of images
# together with their class indices, ready for training and evaluation.