In [None]:
import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

from AlzheimerCNN import AlzheimerCNN

# Device configuration: use the current CUDA device when one is available,
# otherwise fall back to the CPU. The choice is printed for the notebook log.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Input resolution and normalization shared by the training and test
# pipelines. Keeping them in one place prevents the two pipelines from
# silently drifting apart (the test images must be preprocessed exactly like
# the training images apart from augmentation).
IMAGE_SIZE = (128, 128)
# ImageNet channel statistics; adjust if necessary.
# NOTE(review): the normalization used when the pretrained checkpoint was
# trained is not visible here — confirm these values match it.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Fine-tuning transforms at 128x128: resize, random augmentation (horizontal
# flip, rotation up to 15 degrees, colour jitter, random scaling), then
# conversion to a tensor and normalization.
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomAffine(degrees=0, scale=(0.8, 1.2)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Test transforms: deterministic resize + tensor conversion + normalization,
# with no augmentation.
test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Load the training split. ImageFolder expects one sub-directory per class
# under the root and labels every image by the folder it sits in.
train_dataset = datasets.ImageFolder(
    root='datasets_fine_tuning/train_1',
    transform=train_transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

# Show which integer label ImageFolder assigned to each class folder.
class_to_idx = train_dataset.class_to_idx
print("Mapeamento de Classes:")
for class_name, class_index in class_to_idx.items():
    line = f"{class_index}: {class_name}"
    print(line)

# Build the network with one output per class found in the training folder
# and load the pretrained weights. map_location=device remaps the stored
# tensors onto the current device, so a checkpoint saved on a GPU can still be
# loaded on a CPU-only machine.
model = AlzheimerCNN(num_classes=len(class_to_idx)).to(device)
model.load_state_dict(
    torch.load('models/best_model_resolution_invariant.pth', map_location=device)
)

# Freeze every parameter, then unfreeze only fc1 and fc2 so fine-tuning
# updates just those two layers.
model.requires_grad_(False)
model.fc1.requires_grad_(True)
model.fc2.requires_grad_(True)
# NOTE(review): requires_grad only blocks gradient updates. If the frozen part
# of AlzheimerCNN contains BatchNorm layers, model.train() during fine-tuning
# will still update their running statistics — confirm whether that is intended.

# Hyperparameters: cross-entropy loss, Adam over only the parameters that are
# still trainable (requires_grad=True), and the epoch budget.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=1e-4,
)
num_epochs = 100

# Training function
def train_model(model, train_loader, criterion, optimizer, num_epochs=20,
                save_path='models/fine_tuned_alzheimer.pth'):
    """Fine-tune ``model`` on ``train_loader`` and checkpoint the best epoch.

    Every epoch makes one pass over ``train_loader`` in training mode, steps
    ``optimizer`` once per batch, and records the epoch's mean loss and
    accuracy. Whenever the epoch's training accuracy beats the best seen so
    far, ``model.state_dict()`` is written to ``save_path``.

    NOTE(review): the checkpoint is selected on *training* accuracy; no
    validation data is used here, so the "best" epoch may be an overfit one.

    Args:
        model: network to train; it is moved to the module-level ``device``.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            return the batch mean (the ``nn.CrossEntropyLoss`` default).
        optimizer: optimizer over the trainable parameters of ``model``.
        num_epochs: number of passes over ``train_loader``.
        save_path: file the best checkpoint is written to. Its parent
            directory is created if it does not exist.

    Returns:
        ``(train_losses, train_accuracies)``: per-epoch mean loss and
        accuracy (in percent).

    Raises:
        ValueError: if an epoch over ``train_loader`` yields no samples.
    """
    model = model.to(device)
    # Create the checkpoint directory up front so torch.save cannot fail only
    # after a whole epoch of training.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    train_losses = []
    train_accuracies = []
    # -inf (not 0.0) so the first epoch always writes a checkpoint, even when
    # its accuracy is 0%; otherwise a later load of save_path could fail.
    best_accuracy = float('-inf')

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        train_loop = tqdm(train_loader, desc=f'Época {epoch+1}/{num_epochs} [Treino]')
        for images, labels in train_loop:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            batch_loss = loss.item()
            # Weight the batch-mean loss by batch size so a short final batch
            # does not skew the epoch average.
            running_loss += batch_loss * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

            train_loop.set_postfix(loss=batch_loss, acc=100.*correct/total)

        if total == 0:
            raise ValueError("train_loader produced no samples")

        epoch_loss = running_loss / total
        epoch_accuracy = 100 * correct / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)

        if epoch_accuracy > best_accuracy:
            best_accuracy = epoch_accuracy
            torch.save(model.state_dict(), save_path)

        print(f'Época [{epoch+1}/{num_epochs}], '
              f'Perda Treino: {epoch_loss:.4f}, Acurácia Treino: {epoch_accuracy:.2f}%')

    return train_losses, train_accuracies

# Evaluation function
def evaluate_model(model, test_loader, class_names=None):
    """Run ``model`` over ``test_loader`` and report per-class results.

    Prints an sklearn classification report and shows a confusion-matrix
    heatmap whose axes are labelled with the class names.

    Args:
        model: trained network; it is moved to the module-level ``device``
            and switched to eval mode.
        test_loader: iterable of ``(images, labels)`` batches.
        class_names: class names ordered by label index. Defaults to the
            names of the module-level ``class_to_idx`` mapping, sorted by
            their index.

    Returns:
        ``(all_preds, all_labels)``: lists of predicted and true label
        indices, in loader order.
    """
    if class_names is None:
        class_names = [name for name, _ in sorted(class_to_idx.items(), key=lambda item: item[1])]
    # Explicit label ids keep report rows and matrix rows/columns aligned with
    # class_names even when a class is absent from this split or from the
    # predictions.
    label_ids = list(range(len(class_names)))

    model = model.to(device)
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Avaliando'):
            images = images.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    print(classification_report(all_labels, all_preds, labels=label_ids,
                                target_names=class_names))
    print("Matriz de Confusão:")
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Matriz de Confusão')
    plt.xlabel('Predito')
    plt.ylabel('Verdadeiro')
    plt.show()

    return all_preds, all_labels

# Run fine-tuning. train_model writes the weights of its best epoch (by
# training accuracy) to 'models/fine_tuned_alzheimer.pth'.
train_losses, train_accuracies = train_model(model, train_loader, criterion, optimizer, num_epochs)

# The weights in memory are those of the last epoch, which need not be the
# best one, so restore the checkpoint train_model saved before evaluating.
model.load_state_dict(torch.load('models/fine_tuned_alzheimer.pth', map_location=device))

# Evaluate the model on the test subset
print("Avaliando o modelo no subconjunto de teste...")
# TODO: test_loader is not defined in the visible code. Build one (e.g. an
# ImageFolder over the test split with test_transform, shuffle=False) before
# enabling this call.
# predictions, true_labels = evaluate_model(model, test_loader)


cuda
Mapeamento de Classes:
0: Mild Dementia
1: Moderate Dementia
2: Non Demented
3: Very mild Dementia


Época 1/100 [Treino]: 100%|██████████| 75/75 [00:12<00:00,  6.08it/s, acc=49.9, loss=1.88] 


KeyboardInterrupt: 