# In [9]:
from skimage.feature import graycomatrix, graycoprops, local_binary_pattern
from skimage.io import imread
import numpy as np
import os
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import pandas as pd
import torch
from PIL import Image


class CervicalCellDataset(Dataset):
    """Pair cell image files on disk with per-cell attribute vectors from a CSV.

    A file found in one of ``rgb_dirs`` is kept only when its name equals
    ``"{image_name}_celula_{cell_id}.png"`` for some row of the CSV.  The
    attribute columns are ``df.columns[2:-1]``; presumably the first two
    columns are ``image_name``/``cell_id`` and the last one is ``label`` --
    TODO confirm against the CSV files actually used.

    Attributes are standardised column-wise as ``(x - mean) / std``.  When
    ``mean`` or ``std`` is not supplied, both are computed from the samples of
    this dataset (``std`` gets ``+ 1e-8`` to avoid division by zero) and are
    exposed as ``self.mean`` / ``self.std`` so they can be reused elsewhere.

    Each item is a 5-tuple ``(img, feat, lab, image_name, cell_id)``:
        img        : result of ``transform`` applied to the RGB PIL image
                     (the original notes mention a [3, 70, 70] tensor; that
                     size is not enforced here)
        feat       : float32 tensor of shape [n_attr]
        lab        : int64 scalar tensor
        image_name : file name of the PNG
        cell_id    : value of the CSV ``cell_id`` column for that sample

    Raises:
        ValueError: if no file in ``rgb_dirs`` matches a CSV row.
    """

    def __init__(self, rgb_dirs, features_csv, transform=None, mean=None, std=None):
        super().__init__()

        if transform is not None:
            self.transform = transform
        else:
            # Default: tensor conversion followed by per-channel normalisation
            # (these constants match the commonly used ImageNet statistics).
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

        # ---- 1. read the attribute CSV
        df = pd.read_csv(features_csv)
        self.attr_names = df.columns[2:-1]
        # Convert the whole attribute block once.  Iterating columns instead of
        # df.iterrows() keeps each column's dtype, so integer ids are not
        # upcast to floats (which would produce keys like "7.0_celula_1.0.png").
        attr_matrix = df[self.attr_names].to_numpy(dtype=np.float32)
        feature_map = {
            f"{image_name}_celula_{cell_id}.png": (attr_vec, int(label), cell_id)
            for image_name, cell_id, label, attr_vec in zip(
                df["image_name"], df["cell_id"], df["label"], attr_matrix
            )
        }

        # ---- 2. index the images that have a matching CSV row
        self.image_paths, self.features, self.labels = [], [], []
        self.image_names, self.cell_ids = [], []
        for directory in rgb_dirs:
            # sorted(): os.listdir returns entries in arbitrary order; sorting
            # makes the sample order reproducible across runs and machines.
            for fname in sorted(os.listdir(directory)):
                # Every key of feature_map ends with ".png", so membership
                # alone also filters out non-PNG files.
                if fname not in feature_map:
                    continue
                feat_vec, lab, cid = feature_map[fname]
                self.image_paths.append(os.path.join(directory, fname))
                self.features.append(feat_vec)
                self.labels.append(lab)
                self.image_names.append(fname)
                self.cell_ids.append(cid)

        if not self.image_paths:
            raise ValueError(
                f"No image in {list(rgb_dirs)!r} matches a row of {features_csv!r}"
            )

        # ---- 3. standardise the attributes
        feats = np.stack(self.features)
        if mean is None or std is None:
            self.mean, self.std = feats.mean(0), feats.std(0) + 1e-8
        else:
            self.mean, self.std = mean, std
        self.features = (feats - self.mean) / self.std

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, feat, lab, image_name, cell_id)``."""
        # convert() yields an independent in-memory image, so the file can be
        # closed deterministically as soon as the conversion is done.
        with Image.open(self.image_paths[idx]) as handle:
            img = handle.convert("RGB")
        img = self.transform(img)
        feat = torch.tensor(self.features[idx], dtype=torch.float32)
        lab = torch.tensor(self.labels[idx], dtype=torch.long)
        image_name = self.image_names[idx]
        cell_id = self.cell_ids[idx]
        return img, feat, lab, image_name, cell_id


def get_dataset_from_dirs(image_dirs, features_csv, transform=None, mean=None, std=None):
    """Build a ``CervicalCellDataset`` from image directories and an attribute CSV.

    Args:
        image_dirs: iterable of directories containing the cell PNG files.
        features_csv: path of the CSV holding the per-cell attributes.
        transform: optional image transform; ``None`` selects the dataset's
            default transform.
        mean, std: optional attribute statistics forwarded to the dataset.
            When either is ``None`` the dataset computes both from its own
            samples.

    Returns:
        The constructed ``CervicalCellDataset``.
    """
    return CervicalCellDataset(
        rgb_dirs=image_dirs,
        features_csv=features_csv,
        transform=transform,
        mean=mean,
        std=std,
    )


def get_train_dataset(transform=None, *, image_dirs=None, features_csv=None):
    """Build the training dataset and return it with its attribute statistics.

    Args:
        transform: optional image transform passed to the dataset.
        image_dirs: optional list of image directories; defaults to the
            hard-coded training directories below.
        features_csv: optional attribute CSV path; defaults to the hard-coded
            training CSV below.

    Returns:
        ``(dataset, dataset.mean, dataset.std)`` -- the statistics are the
        ones the dataset computed from its own samples.
    """
    # NOTE(review): the default image paths use an "E:/" drive while the CSV
    # default is a "/Users/..." path -- confirm both resolve on the target
    # machine, or pass image_dirs / features_csv explicitly.
    if image_dirs is None:
        image_dirs = [
            "E:/datasets/imagens/treino/treino/2classes/treino-dir-positivo-rgb",
            "E:/datasets/imagens/treino/treino/2classes/treino-dir-negativo-rgb",
        ]
    if features_csv is None:
        features_csv = "/Users/xr4good/Desktop/Ingrid/DIFF/train_2classes.csv"
    dataset = get_dataset_from_dirs(image_dirs, features_csv, transform=transform)
    return dataset, dataset.mean, dataset.std


def get_val_dataset(mean, std, transform=None, *, image_dirs=None, features_csv=None):
    """Build the validation dataset, standardising attributes with ``mean``/``std``.

    Args:
        mean, std: attribute statistics forwarded to the dataset (if either is
            ``None`` the dataset computes both from the validation samples).
        transform: optional image transform passed to the dataset.
        image_dirs: optional list of image directories; defaults to the
            hard-coded validation directories below.
        features_csv: optional attribute CSV path; defaults to the hard-coded
            validation CSV below.

    Returns:
        The constructed ``CervicalCellDataset``.
    """
    # NOTE(review): the default image paths use an "E:/" drive while the CSV
    # default is a "/Users/..." path -- confirm both resolve on the target
    # machine, or pass image_dirs / features_csv explicitly.
    if image_dirs is None:
        image_dirs = [
            "E:/datasets/imagens/validacao/validacao/2classes/validacao-dir-positivo-rgb",
            "E:/datasets/imagens/validacao/validacao/2classes/validacao-dir-negativo-rgb",
        ]
    if features_csv is None:
        features_csv = "/Users/xr4good/Desktop/Ingrid/DIFF/val_2classes.csv"
    return CervicalCellDataset(image_dirs, features_csv, transform=transform, mean=mean, std=std)


def get_test_dataset(mean, std, transform=None, *, image_dirs=None, features_csv=None):
    """Build the test dataset, standardising attributes with ``mean``/``std``.

    Args:
        mean, std: attribute statistics forwarded to the dataset (if either is
            ``None`` the dataset computes both from the test samples).
        transform: optional image transform passed to the dataset.
        image_dirs: optional list of image directories; defaults to the
            hard-coded test directories below.
        features_csv: optional attribute CSV path; defaults to the hard-coded
            test CSV below.

    Returns:
        The constructed ``CervicalCellDataset``.
    """
    # NOTE(review): the default image paths use an "E:/" drive while the CSV
    # default is a "/Users/..." path -- confirm both resolve on the target
    # machine, or pass image_dirs / features_csv explicitly.
    if image_dirs is None:
        image_dirs = [
            "E:/datasets/imagens/teste/teste/2classes/teste-dir-positivo-rgb",
            "E:/datasets/imagens/teste/teste/2classes/teste-dir-negativo-rgb",
        ]
    if features_csv is None:
        features_csv = "/Users/xr4good/Desktop/Ingrid/DIFF/test_2classes.csv"
    return CervicalCellDataset(image_dirs, features_csv, transform=transform, mean=mean, std=std)


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each ``k`` in ``topk``.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: 1-D tensor of class indices, one per sample.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one 1-element float tensor per entry of ``topk``, holding
        the percentage of samples whose target is among the ``k`` highest
        scoring classes.  A ``k`` larger than the number of classes is
        evaluated over all classes instead of raising.

    Raises:
        ZeroDivisionError: if ``target`` is empty.
    """
    with torch.no_grad():
        # Clamp to the class count: Tensor.topk raises when k exceeds the size
        # of the dimension (e.g. topk=(1, 5) on a 2-class problem).
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        # pred: [maxk, batch] class indices, best prediction in row 0.
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Slicing past maxk simply takes every available row.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def load_checkpoint(model, ckpt_path, map_location="cpu"):
    """Load weights from a checkpoint file into ``model`` in place.

    The file may hold either a state dict directly or a mapping with a
    ``'state_dict'`` entry.  A leading ``'module.'`` on any key (typically
    present when the model was saved while wrapped in a data-parallel
    container) is stripped before loading.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        ckpt_path: path of the checkpoint file.
        map_location: forwarded to ``torch.load``.  Defaults to ``"cpu"`` so a
            checkpoint saved on a GPU can be read on a machine without one;
            ``load_state_dict`` then copies the values into the model's own
            parameters.
    """
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=map_location)
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    prefix = 'module.'
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    model.load_state_dict(state)
