# In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
from skimage.measure import regionprops, label
from skimage.color import rgb2gray
import mahotas as mh
from skimage.feature import graycomatrix, graycoprops, local_binary_pattern
import networkx as nx
from scipy.stats import pearsonr
import torch
from torch_geometric.data import Data, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, BatchNorm
from torch_geometric.data import DataLoader
from sklearn.preprocessing import MinMaxScaler

# Load and parse a JSON file
def load_json(json_path):
    """Read the file at *json_path* and return the parsed JSON document.

    The file is decoded as UTF-8 explicitly (the encoding required for JSON
    exchanged between systems, RFC 8259) so that annotation text with
    accented characters does not depend on the platform's locale encoding.
    """
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)

# Compute a square crop window centred on a nucleus, clipped to the image bounds
def get_bounding_box(nucleus_x, nucleus_y, box_size=100, image_height=1020, image_width=1376):
    """Return ``(min_row, max_row, min_col, max_col)`` for a crop around a nucleus.

    Args:
        nucleus_x: Column (x) coordinate of the nucleus centre, in pixels.
        nucleus_y: Row (y) coordinate of the nucleus centre, in pixels.
        box_size: Side length of the square window; the half-width is
            ``box_size // 2``.
        image_height: Number of image rows; ``max_row`` is clipped to it.
            Defaults to 1020, the value that used to be hard-coded.
        image_width: Number of image columns; ``max_col`` is clipped to it.
            Defaults to 1376, the value that used to be hard-coded.

    Returns:
        Integer bounds usable directly as
        ``image[min_row:max_row, min_col:max_col]``.
    """
    half_size = box_size // 2
    # Coordinates may arrive as floats; numpy slice indices must be integers.
    center_row = int(round(nucleus_y))
    center_col = int(round(nucleus_x))

    min_row = max(center_row - half_size, 0)
    max_row = min(center_row + half_size, image_height)
    min_col = max(center_col - half_size, 0)
    max_col = min(center_col + half_size, image_width)
    return (min_row, max_row, min_col, max_col)

# Crop one square patch per annotated cell
def extract_regions(image, json_data, box_size=100):
    """Return one image crop per cell entry in *json_data*, in the same order.

    Each entry must provide ``'nucleus_x'`` and ``'nucleus_y'``; the crop
    window is obtained from ``get_bounding_box`` with the given *box_size*.
    """
    def crop_around(cell):
        top, bottom, left, right = get_bounding_box(
            cell['nucleus_x'], cell['nucleus_y'], box_size
        )
        return image[top:bottom, left:right]

    return [crop_around(cell) for cell in json_data]

# Segment the object to measure inside a grayscale patch
def _segment_object(gray_region):
    """Return the ``regionprops`` entry of the object to measure, or None.

    The patch is binarised with Otsu's threshold, keeping the darker pixels as
    foreground. The connected component under the patch centre is chosen,
    because patches from ``extract_regions`` are centred on the nucleus
    coordinate (except where clipped at the image border). If the centre
    falls on background, the largest component is used instead.
    """
    from skimage.filters import threshold_otsu

    if gray_region.min() == gray_region.max():
        # Otsu is undefined on a constant patch: keep the whole patch as one
        # object unless it is entirely 0 (background for ``label``).
        mask = gray_region > 0
    else:
        # NOTE(review): assumes the stained object is darker than its
        # background, as is usual in bright-field cytology — confirm on data.
        mask = gray_region <= threshold_otsu(gray_region)

    labeled = label(mask)
    props = regionprops(labeled)
    if not props:
        return None

    center_label = labeled[gray_region.shape[0] // 2, gray_region.shape[1] // 2]
    for prop in props:
        if prop.label == center_label:
            return prop
    return max(props, key=lambda prop: prop.area)


# Extract a fixed-length feature vector (shape, intensity, texture) from one image region
def extract_features_from_region(region):
    """Compute shape, intensity and texture descriptors for an RGB image patch.

    The patch is converted to 8-bit grayscale. Shape descriptors are measured
    on the object segmented by ``_segment_object``. Intensity and texture
    descriptors are computed over the whole patch.

    Returns:
        A list with, in order:
        - area, perimeter, eccentricity, solidity, equivalent diameter,
          convex area, extent, major and minor axis lengths;
        - mean and standard deviation of intensity;
        - GLCM entropy and contrast;
        - skewness and excess kurtosis;
        - mahotas' Haralick features averaged across directions;
        - a 10-bin uniform LBP histogram (P=8, R=1).

        Returns None if the patch is empty or entirely black, or if no
        object is found.
    """
    gray_region = rgb2gray(region)

    if np.sum(gray_region) == 0:
        print("Erro: Região em branco ou inválida.")
        return None

    gray_region = (gray_region * 255).astype(np.uint8)

    # Note: labelling the raw grayscale image would group neighbouring pixels
    # of *identical intensity*, so a binary segmentation is used instead.
    props = _segment_object(gray_region)
    if props is None:
        return None

    # Basic shape features of the segmented object
    area = props.area
    perimeter = props.perimeter
    eccentricity = props.eccentricity
    solidity = props.solidity
    equivalent_diameter = props.equivalent_diameter
    convex_area = props.convex_area
    extent = props.extent
    major_axis_length = props.major_axis_length
    minor_axis_length = props.minor_axis_length

    # Intensity features over the whole patch
    mean_intensity = np.mean(gray_region)
    std_intensity = np.std(gray_region)

    # GLCM texture features (horizontal neighbours at distance 1)
    glcm = graycomatrix(gray_region, distances=[1], angles=[0], levels=256, symmetric=True, normed=True)
    # eps avoids log(0) for empty co-occurrence cells
    glcm_entropy = -np.sum(glcm * np.log(glcm + np.finfo(float).eps))
    contrast = graycoprops(glcm, 'contrast')[0, 0]

    # Haralick features, averaged over the rows mahotas returns
    haralick_features = mh.features.haralick(gray_region).mean(axis=0)

    # Uniform LBP with P=8 yields codes 0..9, hence 10 bins
    lbp = local_binary_pattern(gray_region, P=8, R=1, method='uniform')
    lbp_hist, _ = np.histogram(lbp, bins=np.arange(0, 11), range=(0, 10))

    # Higher-order intensity statistics (0 for a constant patch)
    centered = gray_region.astype(np.float64) - mean_intensity
    skewness = np.mean(centered**3) / (std_intensity**3) if std_intensity > 0 else 0
    kurtosis = np.mean(centered**4) / (std_intensity**4) - 3 if std_intensity > 0 else 0

    features = [
        area, perimeter, eccentricity, solidity, equivalent_diameter, convex_area,
        extent, major_axis_length, minor_axis_length, mean_intensity,
        std_intensity, glcm_entropy, contrast, skewness, kurtosis
    ] + list(haralick_features) + list(lbp_hist)

    return features

# Map Bethesda classifications to binary labels
def map_labels(classifications):
    """Flatten grouped cell classifications into a 1-D label tensor.

    Args:
        classifications: Iterable of groups, each an iterable of cell dicts
            carrying a ``'bethesda_system'`` key.

    Returns:
        ``torch.long`` tensor holding 1 for ``'POSITIVE'`` cells and 0 for
        every other value, in iteration order.
    """
    flat_labels = [
        int(cell['bethesda_system'] == 'POSITIVE')
        for group in classifications
        for cell in group
    ]
    return torch.tensor(flat_labels, dtype=torch.long)

# Build one attribute graph per cell; edges link attributes correlated across cells
def build_graph_across_cells(features, threshold=0.8):
    """Build a NetworkX graph for every cell's feature vector.

    Each attribute index becomes a node carrying that cell's value in its
    ``feature`` attribute. Attributes ``i < j`` are joined by an edge with
    ``weight`` = Pearson r when their values, taken across *all* cells in
    *features*, correlate strictly above *threshold*. That correlation does
    not depend on the cell being built, so the edge set is computed once and
    shared by every graph.

    Args:
        features: Sequence of per-cell feature vectors of equal length.
        threshold: Exclusive lower bound on r. Only positive correlations
            above it create edges; strong negative correlations do not.

    Returns:
        One graph per cell, in input order. Graphs with no nodes or no edges
        are dropped, so the list is empty when no attribute pair passes the
        threshold; callers should compare its length to ``len(features)``.
    """
    if not features:
        return []

    num_attributes = len(features[0])
    # columns[i] holds attribute i for every cell.
    columns = [[cell[i] for cell in features] for i in range(num_attributes)]
    # Pearson r is undefined for constant inputs, so such attributes get no edges.
    has_variance = [np.std(column) > 0 for column in columns]

    correlated_edges = []
    for i in range(num_attributes):
        if not has_variance[i]:
            continue
        for j in range(i + 1, num_attributes):
            if not has_variance[j]:
                continue
            corr, _ = pearsonr(columns[i], columns[j])
            if corr > threshold:
                correlated_edges.append((i, j, corr))

    graphs = []
    for feature_set in features:
        G = nx.Graph()
        for i, feature in enumerate(feature_set):
            G.add_node(i, feature=feature)
        # Stores each correlation under the 'weight' edge attribute.
        G.add_weighted_edges_from(correlated_edges)

        if G.number_of_nodes() > 0 and G.number_of_edges() > 0:
            graphs.append(G)

    return graphs

def create_pyg_data_with_batching(features, graphs, labels):
    data_list = []
    for i in range(len(features)):
        num_features_por_nó = len(features[i]) if features[i] is not None else 0
        
        if num_features_por_nó == 0:
            print(f"Erro: Não há features para o índice {i}.")
            continue

        x = torch.tensor(features[i], dtype=torch.float).view(-1, num_features_por_nó)

        # Verifique se graphs[i] é realmente um grafo com o método edges
        if not hasattr(graphs[i], 'edges'):
            print(f"Erro: O item graphs[{i}] não é um grafo válido. Tipo encontrado: {type(graphs[i])}")
            continue

        edge_index = torch.tensor(list(graphs[i].edges), dtype=torch.long).t().contiguous()
        
        if edge_index.size(0) != 2:
            print(f"Erro: edge_index para o índice {i} não tem formato correto.")
            continue
        
        if edge_index.max().item() >= x.size(0):
            print(f"Erro: edge_index para o índice {i} tem valores fora do intervalo. Max index: {edge_index.max().item()} para num_nodes: {x.size(0)}")
            continue

        label = labels[i] if i < len(labels) else -1
        data = Data(x=x, edge_index=edge_index, y=torch.tensor([label], dtype=torch.long))
        
        print(f"Data {i} - x shape: {data.x.shape}, edge_index shape: {data.edge_index.shape}, y: {data.y}")
        data_list.append(data)

    return data_list


# GCN model for graph-level classification with global mean pooling
class GCNGraphLevel(nn.Module):
    """Three-layer GCN that maps each input graph to ``output_dim`` logits.

    The forward pass runs these stages in order:
    1. Node features are zero-padded or truncated to ``max_features`` columns.
    2. A linear layer and batch norm are applied.
    3. Three ``GCNConv`` layers follow, with batch norm, ReLU and dropout
       after the first two.
    4. Node embeddings are averaged per graph with ``global_mean_pool``.

    Args:
        input_dim: Kept for API compatibility; unused, because the input width
            is fixed to ``max_features`` by padding/truncation.
        hidden_dim: Width of the hidden GCN layers.
        output_dim: Number of logits per graph (number of classes).
        max_features: Node-feature width after padding/truncation.
        dropout: Dropout probability used while ``self.training`` is True.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, max_features=136, dropout=0.5):
        super().__init__()
        self.max_features = max_features
        self.fc = nn.Linear(max_features, max_features)
        self.batch_norm1 = BatchNorm(max_features)
        self.conv1 = GCNConv(max_features, hidden_dim)
        self.batch_norm2 = BatchNorm(hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.batch_norm3 = BatchNorm(hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        self.dropout = dropout

    def _fit_width(self, x):
        """Zero-pad or truncate node features to exactly ``max_features`` columns."""
        width = x.size(1)
        if width < self.max_features:
            # new_zeros keeps x's dtype and device.
            padding = x.new_zeros(x.size(0), self.max_features - width)
            return torch.cat([x, padding], dim=1)
        return x[:, :self.max_features]

    def forward(self, data):
        """Return per-graph logits of shape ``(num_graphs, output_dim)``."""
        x = self._fit_width(data.x)

        # Dense projection and normalisation
        x = self.fc(x)
        x = self.batch_norm1(x)

        edge_index = data.edge_index
        # Optional per-edge weights; None means unit weights.
        # NOTE(review): GCNConv expects one scalar per edge — confirm
        # edge_attr is 1-D whenever it is present.
        edge_weight = data.edge_attr

        # GCNConv adds self-loops and applies symmetric degree normalisation
        # itself (normalize=True by default). The raw graph is therefore given
        # to every layer; normalising beforehand would apply it twice.
        x = self.conv1(x, edge_index, edge_weight)
        x = self.batch_norm2(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv2(x, edge_index, edge_weight)
        x = self.batch_norm3(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv3(x, edge_index, edge_weight)

        # Average node embeddings per graph in the batch
        return global_mean_pool(x, data.batch)


# Training loop with periodic validation
def train_with_validation(model, train_loader, val_loader, optimizer, criterion, epochs=1000, log_interval=10):
    """Train *model* and evaluate it on *val_loader* every *log_interval* epochs.

    Each batch performs the following steps, using
    ``criterion(model(batch), batch.y)`` as the loss:
    1. ``optimizer.zero_grad()``
    2. forward pass
    3. loss computation
    4. ``backward()``
    5. ``optimizer.step()``

    Returns:
        ``(train_losses, val_accuracies)``, with one entry per evaluation point
        (epochs ``log_interval``, ``2 * log_interval``, ...). Each entry holds
        that epoch's mean batch loss and the accuracy returned by
        ``test(model, val_loader)``.
    """
    train_losses = []
    val_accuracies = []

    for epoch in range(epochs):
        # ``test`` switches the model to eval mode. Re-enable training mode
        # (dropout, BatchNorm batch statistics) at the start of every epoch.
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1

        if (epoch + 1) % log_interval == 0:
            # max() guards against an empty loader.
            avg_loss = epoch_loss / max(num_batches, 1)
            train_losses.append(avg_loss)

            val_accuracy = test(model, val_loader)
            val_accuracies.append(val_accuracy)
            print(f'Epoch {epoch + 1}, Avg Loss: {avg_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

    return train_losses, val_accuracies

# Função para testar o modelo
def test(model, loader):
    model.eval()
    correct = 0
    total = 0
    for batch in loader:  # Iterar sobre os batches no DataLoader
        out = model(batch)
        pred = out.argmax(dim=1)
        correct += (pred == batch.y).sum().item()
        total += len(batch.y)
    return correct / total


# Dataset locations: one image directory per split, each holding its annotation JSON
_DATASET_ROOT = "/Users/xr4good/Documents/Ingrid/datasets/imagens"

train_image_dir = f"{_DATASET_ROOT}/treino"
train_json_path = f"{train_image_dir}/treino.json"
val_image_dir = f"{_DATASET_ROOT}/validacao"
val_json_path = f"{val_image_dir}/validacao.json"
test_image_dir = f"{_DATASET_ROOT}/testeboundingbox"
test_json_path = f"{test_image_dir}/teste.json"

# Directory where the plots are written
output_dir = "/Users/xr4good/Documents/Ingrid/"

# Load each split's annotations
train_data = load_json(train_json_path)
val_data = load_json(val_json_path)
test_data = load_json(test_json_path)


# Build per-cell features, attribute graphs and labels for one dataset split
def process_data(data, image_dir, image_range):
    """Extract features, graphs and labels for the images of one split.

    Args:
        data: Annotation entries, each with ``'image_id'`` and a
            ``'classifications'`` list of cells (``'nucleus_x'``,
            ``'nucleus_y'``, ``'bethesda_system'``).
        image_dir: Directory containing ``cric_<image_id>.png`` files.
        image_range: Image ids to include; other entries are skipped.

    Returns:
        ``(features_list, graphs_list, labels_list)``:
        - ``features_list`` holds one list of per-cell feature vectors per
          kept image.
        - ``graphs_list`` and ``labels_list`` are flat, per-cell and aligned:
          label 1 for ``"POSITIVE"``, else 0.

        Images where some cell did not get a graph are skipped entirely, so
        the two flat lists never drift apart.
    """
    allowed_ids = set(image_range)  # O(1) membership tests
    features_list = []
    graphs_list = []
    labels_list = []

    for entry in data:
        image_id = entry['image_id']
        if image_id not in allowed_ids:
            continue

        image_path = os.path.join(image_dir, f"cric_{image_id}.png")
        image = io.imread(image_path)
        if image.ndim == 2:
            # Grayscale file: replicate the channel so rgb2gray receives RGB.
            image = np.stack([image] * 3, axis=-1)
        elif image.shape[2] == 4:
            # Drop the alpha channel.
            image = image[:, :, :3]

        classifications = entry['classifications']

        # Extract one feature vector (and label) per cell whose patch is usable
        features = []
        labels = []
        regions = extract_regions(image, classifications)
        for cell, region in zip(classifications, regions):
            feature_vector = extract_features_from_region(region)
            if feature_vector is None:
                continue
            features.append(feature_vector)
            labels.append(1 if cell['bethesda_system'] == "POSITIVE" else 0)

        if not features:
            continue

        cell_graphs = build_graph_across_cells(features)
        # build_graph_across_cells drops graphs without edges. Keep an image
        # only when every cell got a graph, so graphs_list stays aligned with
        # labels_list.
        if len(cell_graphs) != len(features):
            print(f"Aviso: imagem {image_id} ignorada: {len(cell_graphs)} grafos para {len(features)} células.")
            continue

        features_list.append(features)
        graphs_list.extend(cell_graphs)
        labels_list.extend(labels)

    return features_list, graphs_list, labels_list

# train_data, val_data and test_data were already loaded from the JSON paths
# defined earlier in the script, so the files are not re-read here.

# Image-id ranges per split (images are named cric_<id>.png)
train_image_range = list(range(1, 160)) + list(range(162, 281))  # cric_1 to cric_280, except 160 and 161
val_image_range = list(range(282, 343))  # cric_282 to cric_342
test_image_range = list(range(343, 400))  # cric_343 to cric_399
# NOTE(review): cric_281 belongs to no split — confirm this is intentional.

# Extract features, per-cell graphs and labels for every split
train_features, train_graphs, train_labels = process_data(train_data, train_image_dir, train_image_range)
val_features, val_graphs, val_labels = process_data(val_data, val_image_dir, val_image_range)
test_features, test_graphs, test_labels = process_data(test_data, test_image_dir, test_image_range)

# Convert to PyTorch Geometric Data objects
train_data_list = create_pyg_data_with_batching(train_features, train_graphs, train_labels)
val_data_list = create_pyg_data_with_batching(val_features, val_graphs, val_labels)
test_data_list = create_pyg_data_with_batching(test_features, test_graphs, test_labels)

# DataLoaders: one graph per batch; only the training set is shuffled
train_loader = DataLoader(train_data_list, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data_list, batch_size=1, shuffle=False)
test_loader = DataLoader(test_data_list, batch_size=1, shuffle=False)

# Model, loss and optimiser
# Attribute count of the first training cell; GCNGraphLevel accepts it as
# input_dim but does not use it (it pads/truncates to max_features).
num_node_features = len(train_features[0][0])
num_classes = 2  # binary classification
model = GCNGraphLevel(num_node_features, hidden_dim=64, output_dim=num_classes)

criterion = nn.CrossEntropyLoss()
# Adam with L2 weight decay
learning_rate = 0.00005
weight_decay = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Train with periodic validation
epochs = 1000
log_interval = 10
train_losses, val_accuracies = train_with_validation(
    model, train_loader, val_loader, optimizer, criterion, epochs=epochs, log_interval=log_interval
)

# Plot the results: one point per validation step
epochs_range = list(range(log_interval, epochs + 1, log_interval))
os.makedirs(output_dir, exist_ok=True)

# Each curve gets its own figure so each saved file contains exactly one plot.
plt.figure(figsize=(6, 5))
plt.plot(epochs_range, train_losses, label='Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train Loss over Epochs')
plt.legend()
plt.tight_layout()

train_loss_path = os.path.join(output_dir, "train_loss.png")
plt.savefig(train_loss_path)
print(f"Gráfico de Train Loss salvo em {train_loss_path}")

plt.figure(figsize=(6, 5))
plt.plot(epochs_range, val_accuracies, label='Validation Accuracy')
plt.xlabel(f'Epoch (every {log_interval})')
plt.ylabel('Accuracy')
plt.title('Validation Accuracy over Epochs')
plt.legend()
plt.tight_layout()

val_accuracy_path = os.path.join(output_dir, "val_accuracy.png")
plt.savefig(val_accuracy_path)
print(f"Gráfico de Validation Accuracy salvo em {val_accuracy_path}")

plt.show()

# Evaluate on the held-out test set
test_accuracy = test(model, test_loader)
print(f'Test Accuracy: {test_accuracy:.4f}')
