### Imports

In [None]:
import os
from pathlib import Path
import re
import random
from sklearn.model_selection import KFold
from PIL import Image
from torchvision import transforms
from torch.utils.data import random_split, Dataset, DataLoader
import json
from tqdm import tqdm
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from models.unet import UNet
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import copy
import segmentation_models_pytorch as smp
from torch.autograd import Function
from torch.nn import functional as F
from sklearn.metrics import jaccard_score as iou


### Modelos

In [2]:
def get_model(model_name, encoder, classes=3):
    """Build a segmentation_models_pytorch model by short name.

    Args:
        model_name: One of ``"unet"``, ``"linknet"``, ``"pan"`` or ``"deeplab"``.
        encoder: Backbone name passed to smp as ``encoder_name``; ImageNet
            weights are requested for it.
        classes: Number of output channels (segmentation classes). Defaults to 3.

    Returns:
        The smp model, configured for 3-channel input.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names
            (instead of silently returning ``None``).
    """
    if model_name == "unet":
        architecture = smp.Unet
    elif model_name == "linknet":
        architecture = smp.Linknet
    elif model_name == "pan":
        architecture = smp.PAN
    elif model_name == "deeplab":
        architecture = smp.DeepLabV3
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of "
            "'unet', 'linknet', 'pan', 'deeplab'."
        )
    return architecture(encoder_name=encoder, encoder_weights="imagenet", in_channels=3, classes=classes)

### Tratamento do dataset

In [3]:
def draw_contours(contours, shape=(256, 256)):
    """Rasterize polygon contours into a filled binary mask.

    Args:
        contours: Sequence of contours. Each one is converted with
            ``np.array(..., dtype=np.int32)`` and passed to ``cv2.drawContours``
            (presumably a list of ``[x, y]`` points — TODO confirm against the
            mask JSON files).
        shape: ``(height, width)`` of the output mask. Defaults to 256x256,
            the size used throughout this notebook.

    Returns:
        ``np.uint8`` array of ``shape`` with 1 inside every contour and 0
        elsewhere. An empty ``contours`` yields an all-zero mask.
    """
    binary_mask = np.zeros(shape, dtype=np.uint8)

    # The original special-cased 0, 1 and many contours, but every branch did the same
    # thing: fill each contour with its own drawContours call.
    for contour in contours:
        points = np.array(contour, dtype=np.int32)
        cv2.drawContours(binary_mask, [points], -1, (1, 1, 1), thickness=cv2.FILLED)

    return binary_mask
    

def _last_entry_contours(masks_by_entry, camera, frame, source_key):
    """Return ``entry['subimage_<camera>'][frame]`` from the last entry of ``masks_by_entry``."""
    entries = list(masks_by_entry.values())
    if not entries:
        raise ValueError(f"No mask entries found under key {source_key!r}")
    contours = None
    for entry in entries:
        # Every entry is still indexed, so a malformed one raises as before. Only the
        # last entry's contours are kept, matching the original loop.
        contours = entry[f'subimage_{camera}'][frame]
    return contours


def organize_masks(dataset_masks, data, camera, frame):
    """Rasterize the human and robot masks for one camera/frame of a recording.

    Args:
        dataset_masks: Dict as built by ``load_dataset_masks``.
            ``dataset_masks[data]`` holds the human masks and
            ``dataset_masks[f"{data}_robot"]`` the robot masks. Each of those maps
            some key (not used here) to a dict keyed ``'subimage_<camera>'``,
            whose value is indexed by ``frame``.
        data: Recording key, e.g. ``"<subject>_<activity>_<routine>"``.
        camera: Camera index used to build the ``'subimage_<camera>'`` key.
        frame: Frame index into that camera's contours.

    Returns:
        ``(human_mask, robot_mask)``, each produced by ``draw_contours``.

    Raises:
        KeyError: If ``data`` or ``f"{data}_robot"`` is missing.
        ValueError: If either dict has no entries (previously an opaque
            ``UnboundLocalError``).

    NOTE(review): when a dict holds several entries, only the last one is used and
    earlier ones are discarded, as in the original loop. If the entries are separate
    people/robots this drops masks — confirm the JSON layout.
    """
    human_key = f'{data}'
    robot_key = f'{data}_robot'
    dict_human = dataset_masks[human_key]
    dict_robot = dataset_masks[robot_key]

    contours_human = _last_entry_contours(dict_human, camera, frame, human_key)
    contours_robot = _last_entry_contours(dict_robot, camera, frame, robot_key)

    return draw_contours(contours_human), draw_contours(contours_robot)
        
        
def transform_masks(mask_human, mask_robot, mask_mode=None):
    """Combine a human mask and a robot mask into one class-index tensor.

    Pixels > 0 count as foreground in each input mask. Labels in the result are
    0 = background, 1 = human and 2 = robot; robot overrides human where the two
    masks overlap. The ``"entropy"`` path and the default path assign the same labels.

    Returns:
        ``torch.long`` tensor holding the labels.
    """
    is_human = np.where(mask_human > 0, 1, 0)
    is_robot = np.where(mask_robot > 0, 1, 0)

    if mask_mode != "entropy":
        # Paint onto an all-background canvas: humans first, then robots on top.
        labels = np.zeros_like(is_human)
        labels[is_human == 1] = 1
        labels[is_robot == 1] = 2
    else:
        # Robot pixels become 2; everywhere else keeps the 0/1 human label.
        labels = np.where(is_robot == 1, 2, is_human)

    return torch.tensor(labels, dtype=torch.long)


# Define a custom dataset class
class CustomImageDataset(Dataset):
    """Dataset pairing each image file with its human/robot segmentation target.

    Image file names are expected to look like
    ``<subject>_<activity>_<routine>_<frame>_<camera>.<ext>`` (see
    ``__get_mask_info__``). The first three fields select the mask entry in
    ``masks``; frame and camera select the contours inside it (via ``organize_masks``).

    Each item is a dict with:
        ``'image'``: float32 tensor ``(C, 256, 256)``; 8-bit images are scaled to [0, 1].
        ``'mask'``: float32 tensor ``(3, H, W)`` whose channels are, in order, an
            all-zero placeholder, the robot mask and the human mask. The training
            code converts it with ``torch.argmax`` over dim 1, giving 0 = background,
            1 = robot, 2 = human. Where both masks overlap, argmax picks the first
            maximal index, so robot wins.

    ``transform`` and ``mask_mode`` are stored but not used by ``__getitem__``.
    """

    def __init__(self, image_paths, masks, transform=None, mask_mode=None):
        self.image_paths = image_paths
        self.masks = masks
        self.transform = transform
        self.mask_mode = mask_mode

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        file_base_name = os.path.split(img_path)[1].split(".")[0]
        data, frame, camera = self.__get_mask_info__(file_base_name)

        mask_human, mask_robot = organize_masks(self.masks, data, int(camera), int(frame))

        # Channel order (placeholder, robot, human) defines the class indices after argmax.
        placeholder = np.zeros_like(mask_human)
        mask_array = np.stack((placeholder, mask_robot, mask_human), axis=0)

        # Context manager releases the file handle as soon as the pixels are read.
        with Image.open(img_path) as image:
            img_nd = np.array(image.resize((256, 256)))

        if img_nd.ndim == 2:
            img_nd = np.expand_dims(img_nd, axis=2)

        img_trans = img_nd.transpose((2, 0, 1))
        # Always scale 8-bit images. The old "max() > 1" test alone left very dark
        # frames (every pixel 0 or 1) unscaled, unlike every other input.
        if img_trans.dtype == np.uint8 or img_trans.max() > 1:
            img_trans = img_trans / 255

        return {
            'image': torch.from_numpy(img_trans).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask_array).type(torch.FloatTensor),
        }

    def __get_mask_info__(self, strx):
        """Split ``'<sub>_<act>_<rout>_<frame>_<camera>'`` into ``('<sub>_<act>_<rout>', frame, camera)``.

        Frame and camera are returned as strings. Raises ``ValueError`` unless the
        name has exactly five ``_``-separated fields.
        """
        sub, act, rout, frame, camera = strx.split("_")
        return f"{sub}_{act}_{rout}", frame, camera

### Dataloaders, K-fold, etc.

In [4]:
def extrair_numero_regex(texto):
    """Return the first run of decimal digits in ``texto`` as a string, or ``None`` if there is none."""
    first_match = re.search(r'\d+', texto)
    return first_match.group(0) if first_match else None


def process_json_namefile(strx):
    """Normalize a mask JSON file name to the ``"<sub>_<act>_<rout>[_robot]"`` key used by the dataset.

    Each of the first three ``_``-separated fields contributes the integer value
    of its first digit run, so leading zeros are dropped. ``"_robot"`` is
    appended when ``"robot"`` occurs anywhere in the name.
    """
    parts = strx.split("_")
    key = "_".join(str(int(extrair_numero_regex(parts[position]))) for position in (0, 1, 2))
    return key + "_robot" if "robot" in strx else key

    
def load_dataset_masks(pasta):
    """Load every ``*.json`` file directly inside ``pasta`` into one dict.

    Keys are the file names normalized by ``process_json_namefile``; values are
    the parsed JSON content. Files are visited in ``os.listdir`` order, so if two
    names normalize to the same key, the later one wins.
    """
    json_names = [name for name in os.listdir(pasta) if name.endswith('.json')]

    masks_by_key = {}
    for json_name in tqdm(json_names):
        full_path = os.path.join(pasta, json_name)
        with open(full_path, 'r', encoding='utf-8') as handle:
            content = json.load(handle)
        masks_by_key[process_json_namefile(os.path.basename(json_name))] = content
    return masks_by_key


def load_image_paths(directory):
    """Return the paths, as strings, of all ``*.jpg`` files directly inside ``directory`` (non-recursive)."""
    return [str(jpg_path) for jpg_path in Path(directory).glob('*.jpg')]


def group_images_by_prefix(image_paths):
    """Group image paths by their ``<subject>_<activity>_<routine>`` prefix.

    A path is kept only if its file name ends in
    ``<subject>_<activity>_<routine>_<n>_<m>.jpg`` with all five fields numeric;
    other paths are ignored. Groups come back as lists in first-seen order, each
    keeping the input order of its paths. ``create_dataloaders`` splits folds over
    these groups.
    """
    # Dot escaped and end anchored: the old pattern's bare "." also accepted
    # names such as "1_2_3_4_5xjpg".
    pattern = re.compile(r'(\d+)_(\d+)_(\d+)_\d+_\d+\.jpg$')
    grouped = {}
    for img_path in image_paths:
        match = pattern.search(os.path.basename(img_path))
        if match:
            prefix = "_".join(match.groups())
            grouped.setdefault(prefix, []).append(img_path)
    return list(grouped.values())


def create_dataloaders(image_groups, dataset_masks, n_splits=5, batch_size=32, transform=None, mask_mode=None):
    """Build one ``(train_loader, val_loader)`` pair per K-fold split.

    Folds are split over ``image_groups`` rather than individual images, so every
    image of a group lands on the same side of a split. KFold shuffles with the
    fixed ``random_state=42``. Training loaders shuffle; validation loaders do not.

    Returns:
        List of ``n_splits`` ``(train_loader, val_loader)`` tuples.
    """
    splitter = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    def _flatten(group_indices):
        # Expand the selected group indices into their image paths, in order.
        return [path for group_idx in group_indices for path in image_groups[group_idx]]

    fold_loaders = []
    for train_groups, val_groups in splitter.split(image_groups):
        train_set = CustomImageDataset(_flatten(train_groups), dataset_masks, transform, mask_mode)
        val_set = CustomImageDataset(_flatten(val_groups), dataset_masks, transform, mask_mode)
        fold_loaders.append((
            DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(val_set, batch_size=batch_size, shuffle=False),
        ))
    return fold_loaders

def get_sampled_loader(data_loader, sample_percentage):
    """Return a DataLoader over a random subset of ``data_loader``'s dataset.

    Args:
        data_loader: Source loader; its ``dataset`` and ``batch_size`` are reused.
        sample_percentage: Fraction in ``(0, 1]`` of the dataset to keep. ``None``
            means "no sampling" and returns ``data_loader`` unchanged, matching
            how ``train`` treats ``sampling=None``.

    Returns:
        A new DataLoader over a ``random_split`` subset holding at least one sample
        when the dataset is non-empty. It shuffles only if the source loader did, so
        a sampled validation loader stays unshuffled.

    Raises:
        ValueError: If ``sample_percentage`` is outside ``(0, 1]``.
    """
    from torch.utils.data import RandomSampler

    if sample_percentage is None:
        return data_loader
    if not 0 < sample_percentage <= 1:
        raise ValueError(f"sample_percentage must be in (0, 1], got {sample_percentage!r}")

    total = len(data_loader.dataset)
    # Keep at least one sample so small datasets never yield an empty loader.
    sample_size = min(total, max(1, int(total * sample_percentage)))
    sample_dataset, _ = random_split(data_loader.dataset, [sample_size, total - sample_size])

    shuffle = isinstance(data_loader.sampler, RandomSampler)
    return DataLoader(sample_dataset, batch_size=data_loader.batch_size, shuffle=shuffle)

### Métricas e avaliação

In [5]:
def transform(tensor):
    """Move ``tensor`` to host memory and return it as a flattened 1-D NumPy array (a copy)."""
    host_array = tensor.cpu().numpy()
    return host_array.flatten()

def calculate_dice_coefficient(outputs, targets, smooth=1e-6, num_classes=None):
    """Per-class and mean Dice score over a whole batch.

    Args:
        outputs: Scores/logits with the class axis at dim 1 (e.g. ``(N, C, H, W)``);
            predictions are ``outputs.argmax(dim=1)``.
        targets: Class-index tensor with ``outputs``' shape minus the class axis.
        smooth: Added to numerator and denominator, so a class absent from both
            prediction and target scores 1.0.
        num_classes: Number of classes to score. Defaults to ``outputs.shape[1]``
            (previously hard-coded to 3, the channel count of the models here).

    Returns:
        ``(mean_dice, dice_per_class)`` as Python floats.
    """
    if num_classes is None:
        num_classes = outputs.shape[1]
    predictions = outputs.argmax(dim=1)

    dice_per_class = []
    for class_idx in range(num_classes):
        pred_mask = predictions == class_idx
        target_mask = targets == class_idx
        intersection = (pred_mask & target_mask).float().sum()
        total = pred_mask.float().sum() + target_mask.float().sum()
        dice_per_class.append(((2.0 * intersection + smooth) / (total + smooth)).item())

    mean_dice = sum(dice_per_class) / num_classes
    return mean_dice, dice_per_class

def calculate_iou(outputs, targets, smooth=1e-6, num_classes=None):
    """Per-class and mean IoU (Jaccard index) over a whole batch.

    Args:
        outputs: Scores/logits with the class axis at dim 1 (e.g. ``(N, C, H, W)``);
            predictions are ``outputs.argmax(dim=1)``.
        targets: Class-index tensor with ``outputs``' shape minus the class axis.
        smooth: Added to numerator and denominator, so a class absent from both
            prediction and target scores 1.0.
        num_classes: Number of classes to score. Defaults to ``outputs.shape[1]``
            (previously hard-coded to 3, the channel count of the models here).

    Returns:
        ``(mean_iou, iou_per_class)`` as Python floats.
    """
    if num_classes is None:
        num_classes = outputs.shape[1]
    predictions = outputs.argmax(dim=1)

    iou_per_class = []
    for class_idx in range(num_classes):
        pred_mask = predictions == class_idx
        target_mask = targets == class_idx
        intersection = (pred_mask & target_mask).float().sum()
        union = (pred_mask | target_mask).float().sum()
        # Named class_iou so it does not shadow the module-level `iou` import.
        class_iou = (intersection + smooth) / (union + smooth)
        iou_per_class.append(class_iou.item())

    mean_iou = sum(iou_per_class) / num_classes
    return mean_iou, iou_per_class

def calculate_metrics(outputs, targets):
    """Compute Dice and IoU for one batch.

    Returns:
        ``(mean_dice, mean_iou, dice_per_class, iou_per_class)``.
    """
    dice_summary = calculate_dice_coefficient(outputs, targets)
    iou_summary = calculate_iou(outputs, targets)
    return dice_summary[0], iou_summary[0], dice_summary[1], iou_summary[1]

def evaluation_v2(net, loader, device, pbar):
    """Evaluate ``net`` on ``loader``: mean cross-entropy plus per-batch Dice/IoU.

    Args:
        net: Segmentation model returning per-class logits.
        loader: Yields dicts with ``'image'`` and channel-wise ``'mask'`` tensors;
            masks are turned into class indices with ``argmax`` over dim 1.
        device: Device the batches are moved to.
        pbar: Outer tqdm bar whose postfix shows validation progress.

    Returns:
        ``(mean_loss, metrics)``. ``mean_loss`` is the average per-batch
        ``F.cross_entropy``, independent of the training criterion. ``metrics``
        holds one ``(mean_dice, mean_iou, dice_per_class, iou_per_class)`` tuple
        per batch.

    Raises:
        ValueError: If ``loader`` has no batches (the mean would divide by zero).
    """
    n_val = len(loader)  # number of batches
    if n_val == 0:
        raise ValueError("Validation loader is empty; cannot compute a mean loss.")

    was_training = net.training
    net.eval()
    total_loss = 0.0
    metrics = []

    try:
        with torch.no_grad():
            val_pbar = tqdm(loader, desc="Validation...", leave=False)
            for batch_idx, batch in enumerate(val_pbar):
                imgs = batch['image'].to(device=device, dtype=torch.float32)
                true_masks = batch['mask'].to(device=device, dtype=torch.long)
                # Channel-wise masks -> class-index targets expected by cross_entropy.
                true_masks = torch.argmax(true_masks, 1)

                mask_pred = net(imgs)
                total_loss += F.cross_entropy(mask_pred, true_masks).item()
                metrics.append(calculate_metrics(mask_pred, true_masks))

                val_pbar.set_postfix({"loss": total_loss / (batch_idx + 1)})
                pbar.set_postfix(**{'Evaluating...': f'batch {batch_idx}/{n_val}'})
    finally:
        # Restore the caller's mode instead of unconditionally calling train().
        net.train(was_training)

    return total_loss / n_val, metrics


class DiceCoeff(Function):
    """Dice coefficient for a single example (prediction tensor vs. target tensor).

    NOTE(review): this is written in the legacy, non-static ``autograd.Function``
    style. In this file it is only called as ``DiceCoeff().forward(...)`` (from
    ``dice_coeff``), i.e. as a plain method, so autograd never invokes ``backward``.
    ``backward`` also reads ``self.saved_variables``, which appears to be gone from
    current PyTorch (``saved_tensors`` is the modern name). Confirm before routing
    this through ``DiceCoeff.apply``.
    """

    def forward(self, input, target):
        """Return ``(2 * <input, target> + eps) / (sum(input) + sum(target) + eps)``.

        Both tensors are flattened for the dot product. ``inter`` and ``union``
        are stored on ``self`` for use in ``backward``.
        """
        self.save_for_backward(input, target)
        eps = 0.0001
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Gradient of ``2 * inter / union`` w.r.t. ``input``; ``target`` gets no gradient.

        The ``eps`` added to the numerator in ``forward`` is not reflected here.
        """

        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                         / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input, target):
    """Average the per-example ``DiceCoeff`` score over a batch.

    Args:
        input: Batch of predictions; examples are taken along dim 0.
        target: Batch of targets, paired with ``input`` example by example.

    Returns:
        1-element float32 tensor on ``input``'s device.

    Raises:
        ValueError: If the batch is empty (the original failed with an opaque
            ``UnboundLocalError`` on the unset loop index).
    """
    # Allocate the accumulator directly on the input's device. The legacy
    # FloatTensor(1).cuda() always used the *current* CUDA device.
    total = torch.zeros(1, dtype=torch.float32, device=input.device)
    count = 0
    for example, example_target in zip(input, target):
        total = total + DiceCoeff().forward(example, example_target)
        count += 1

    if count == 0:
        raise ValueError("dice_coeff requires a non-empty batch")
    return total / count


### Novo código de treinamento

In [6]:
def get_last_checkpoint_info(checkpoint_dir):
    """Find where to resume cross-validation training from ``checkpoint_dir``.

    Looks for sub-directories named ``fold<N>`` (1-based, the layout written by
    ``save_checkpoint``) and inspects the highest-numbered one. Creates
    ``checkpoint_dir`` if it does not exist.

    Returns:
        ``(start_epoch, start_fold, best_val_loss)`` with ``start_fold`` 0-based:
        - no fold directories: ``(0, 0, inf)``;
        - the last fold has ``last_training_loss.pth``: the epoch after the saved
          one, that fold, and the stored best loss (``inf`` if absent);
        - otherwise: epoch 0 of the last fold, ``inf``.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Accept only names that are exactly "fold" + digits. The old
    # int(name.replace('fold', '')) crashed on entries such as "folder_notes".
    fold_pattern = re.compile(r'fold(\d+)')
    folds = []
    for entry in os.listdir(checkpoint_dir):
        match = fold_pattern.fullmatch(entry)
        if match and os.path.isdir(os.path.join(checkpoint_dir, entry)):
            folds.append(int(match.group(1)))

    if not folds:
        return 0, 0, float('inf')  # nothing saved yet: start from scratch

    last_fold = max(folds)
    last_checkpoint_path = os.path.join(checkpoint_dir, f"fold{last_fold}", "last_training_loss.pth")
    if os.path.isfile(last_checkpoint_path):
        epoch, _, best_val_loss = load_checkpoint(None, None, last_checkpoint_path, return_model=False)
        return epoch + 1, last_fold - 1, best_val_loss
    return 0, last_fold - 1, float('inf')  # fold dir without a checkpoint: restart that fold

def load_checkpoint(model, optimizer, checkpoint_path, return_model=True):
    """Load a checkpoint written by ``save_checkpoint``.

    Args:
        model: Module to restore ``model_state_dict`` into, or ``None`` to skip it.
        optimizer: Optimizer to restore ``optimizer_state_dict`` into, or ``None`` to skip it.
        checkpoint_path: Path of the ``.pth`` file.
        return_model: When False, only the bookkeeping values are read; ``model``
            and ``optimizer`` are not touched.

    Returns:
        ``(epoch, fold, best_val_loss)``, where ``best_val_loss`` is ``inf`` if the
        checkpoint does not store it. Returns ``(0, 0, inf)`` when the file does not exist.
    """
    if not os.path.isfile(checkpoint_path):
        print(f"No checkpoint found at: {checkpoint_path}")
        return 0, 0, float('inf')

    print(f"Loading checkpoint: {checkpoint_path}")
    # map_location='cpu' lets a checkpoint saved on a GPU be opened on a CPU-only
    # machine; load_state_dict copies the values into the model's own tensors.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if return_model:
        # Restore each object independently. Previously the model was silently
        # skipped whenever no optimizer was passed.
        if model is not None:
            model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['fold'], checkpoint.get('best_val_loss', float('inf'))

def save_checkpoint(model, optimizer, epoch, fold, is_best, checkpoint_dir, best_val_loss=None):
    """Save training state for ``fold`` under ``<checkpoint_dir>/fold<fold+1>/``.

    Always writes ``last_training_loss.pth``. When ``is_best`` is true, the same
    payload is also written to ``best_val_loss.pth``.

    Args:
        model: Module whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        epoch: Epoch index stored in the checkpoint.
        fold: 0-based fold index (the directory name uses ``fold + 1``).
        is_best: Whether to also write ``best_val_loss.pth``.
        checkpoint_dir: Root checkpoint directory (created if missing).
        best_val_loss: Optional best validation loss so far. When given, it is stored
            under ``'best_val_loss'`` so ``load_checkpoint`` can restore it on resume;
            otherwise it reads back as ``inf``.
    """
    fold_dir = os.path.join(checkpoint_dir, f"fold{fold+1}")
    os.makedirs(fold_dir, exist_ok=True)

    payload = {
        'epoch': epoch,
        'fold': fold,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    if best_val_loss is not None:
        payload['best_val_loss'] = best_val_loss

    torch.save(payload, os.path.join(fold_dir, "last_training_loss.pth"))
    if is_best:
        torch.save(payload, os.path.join(fold_dir, "best_val_loss.pth"))
        

In [7]:
def save_metrics_to_csv(loss_train, loss_val, metrics_val, checkpoint_dir, fold, epoch):
    """Append one epoch's summary row to ``<checkpoint_dir>/fold<fold+1>_report.csv``.

    Args:
        loss_train: Training loss for the epoch.
        loss_val: Validation loss for the epoch.
        metrics_val: Per-batch tuples ``(mean_dice, mean_iou, dice_per_class,
            iou_per_class)``, as returned by ``evaluation_v2``. The number of
            classes is taken from the first tuple's ``dice_per_class``.
        checkpoint_dir: Output directory (created if missing).
        fold: 0-based fold index.
        epoch: Epoch index written into the row.

    The CSV header is written only when the file does not exist yet.

    Raises:
        ValueError: If ``metrics_val`` is empty.
    """
    if not metrics_val:
        raise ValueError("metrics_val is empty; nothing to summarize.")
    os.makedirs(checkpoint_dir, exist_ok=True)

    mean_dice_val, mean_iou_val, dice_per_class_val, iou_per_class_val = zip(*metrics_val)

    row = {
        'epoch': epoch,
        'loss_train': loss_train,
        'loss_val': loss_val,
        'avg_dice_val': np.mean(mean_dice_val),
        'avg_iou_val': np.mean(mean_iou_val),
    }

    # Per-class averages over batches. Class indices follow the model's output
    # channels; with CustomImageDataset's (placeholder, robot, human) mask channels
    # that is 0 = background, 1 = robot, 2 = human.
    # NOTE(review): the previous comment said 1 = human, 2 = robot — confirm.
    num_classes = len(dice_per_class_val[0])
    for class_index in range(num_classes):
        row[f'dice_class_{class_index}_val'] = np.mean([d[class_index] for d in dice_per_class_val])
        row[f'iou_class_{class_index}_val'] = np.mean([i[class_index] for i in iou_per_class_val])

    csv_path = os.path.join(checkpoint_dir, f"fold{fold+1}_report.csv")
    write_header = not os.path.isfile(csv_path)
    pd.DataFrame([row]).to_csv(csv_path, mode='w' if write_header else 'a', header=write_header, index=False)

In [8]:
def train(model, device, dataloaders, epochs, criterion, learning_rate, dir_checkpoint, sampling):
    """K-fold training loop with per-epoch validation, checkpointing and CSV reports.

    Each fold trains a fresh ``copy.deepcopy`` of ``model`` with RMSprop, so
    ``model`` itself is never modified and serves as the common initialization.
    Training resumes from the last checkpoint found in ``dir_checkpoint``.

    Args:
        model: Template network.
        device: Training device.
        dataloaders: List of ``(train_loader, val_loader)`` pairs, one per fold.
        epochs: Total number of epochs per fold.
        criterion: Training loss taking ``(logits, class_index_targets)``.
        learning_rate: RMSprop learning rate.
        dir_checkpoint: Directory for checkpoints and ``fold<k>_report.csv`` files.
        sampling: Fraction of each loader to use (via ``get_sampled_loader``), or
            ``None`` to use the full loaders.
    """
    start_epoch, start_fold, _ = get_last_checkpoint_info(dir_checkpoint)

    for fold in range(start_fold, len(dataloaders)):
        net = copy.deepcopy(model).to(device)
        optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)

        train_loader, val_loader = dataloaders[fold]
        if sampling is not None:
            train_loader = get_sampled_loader(train_loader, sampling)
            val_loader = get_sampled_loader(val_loader, sampling)

        first_epoch = 0
        best_loss = float('inf')
        if fold == start_fold and start_epoch > 0:
            # Restore into `net` (the copy actually trained), not the template
            # `model`, and continue *after* the saved epoch instead of repeating it.
            # best_loss reads back as inf unless the checkpoint stores 'best_val_loss'.
            checkpoint_path = os.path.join(dir_checkpoint, f"fold{fold+1}", "last_training_loss.pth")
            saved_epoch, _, best_loss = load_checkpoint(net, optimizer, checkpoint_path)
            first_epoch = saved_epoch + 1

        for epoch in range(first_epoch, epochs):
            net.train()
            epoch_loss = 0.0
            n_batches = 0
            with tqdm(total=len(train_loader.dataset), desc=f'Fold {fold+1}/{len(dataloaders)} | Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
                for batch in train_loader:
                    imgs = batch['image'].to(device=device, dtype=torch.float32)
                    # Channel-wise masks -> class indices for the criterion.
                    true_masks = torch.argmax(batch['mask'].to(device=device, dtype=torch.long), 1)

                    masks_pred = net(imgs)
                    loss = criterion(masks_pred, true_masks)
                    epoch_loss += loss.item()
                    n_batches += 1
                    pbar.set_postfix(**{'loss (batch)': loss.item()})

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(net.parameters(), 0.1)
                    optimizer.step()

                    pbar.update(imgs.shape[0])

                val_lss, val_metrics = evaluation_v2(net, val_loader, device, pbar)

            # Save the trained `net`; the original saved the untouched template and
            # flagged every epoch as "best".
            is_best = val_lss < best_loss
            if is_best:
                best_loss = val_lss
            save_checkpoint(net, optimizer, epoch, fold, is_best, dir_checkpoint)

            # Mean per-batch loss, averaged the same way as the validation loss.
            # Dividing by the image count understated it by roughly the batch size.
            train_loss = epoch_loss / max(n_batches, 1)
            save_metrics_to_csv(train_loss, val_lss, val_metrics, dir_checkpoint, fold, epoch)
        
        

### Main

In [9]:
############## WHICH PC AM I USING?
# 'home' selects the Windows paths and a small batch size; any other value
# selects the Linux paths and a large batch size.
PC = 'home'

if PC != 'home':
    batch_size, image_directory, mask_directory = (
        256,
        "/home/iago/PhD/data-definer/out/",
        "/home/iago/PhD/segment-humans-sam/out/",
    )
else:
    batch_size, image_directory, mask_directory = (
        4,
        "C:/Users/iagor/Documents/git/data-definer/out/",
        "C:/Users/iagor/Documents/git/human-segmentation-sam/out/",
    )

In [10]:
############## OTHER SETTINGS
# Cross-validation / optimization hyper-parameters.
n_splits = 5
num_epochs = 50
learning_rate = 1e-4
sampling = None  # fraction of each loader to use; None = no sampling
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the Python, NumPy and PyTorch (CPU + current CUDA device) RNGs.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)

In [None]:
############## DATASET
# Bound to `image_transform`, not `transform`: the latter would shadow the
# metrics helper function `transform()` defined in the metrics cell.
# NOTE(review): CustomImageDataset stores this transform but does not apply it.
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # converts to float in the [0, 1] range
])

dataset_masks = load_dataset_masks(mask_directory)
mask_mode = "entropy"
image_paths = load_image_paths(image_directory)
image_groups = group_images_by_prefix(image_paths)
dataloaders = create_dataloaders(image_groups, dataset_masks, n_splits, batch_size, image_transform, mask_mode)

In [None]:
model_name = "unet"
model = UNet(n_channels=3, n_classes=3, bilinear=True)
model.to(device)

checkpoint_dir = f"checkpoints/{model_name}/"
criterion = nn.CrossEntropyLoss()

# NOTE(review): the sampling fraction is hard-coded to 0.1 here, so the `sampling`
# setting from the configuration cell is ignored — confirm which is intended.
train(model, device, dataloaders, epochs=num_epochs, criterion=criterion,
      learning_rate=learning_rate, dir_checkpoint=checkpoint_dir, sampling=0.1)

### Testes

In [None]:
def plot_images_and_masks(dataloader):
    """Show every image of the first batch of ``dataloader`` next to its mask.

    Batches may be dicts with ``'image'``/``'mask'`` keys (as yielded by
    ``CustomImageDataset``) or ``(images, masks)`` pairs. A ``(1, H, W)`` mask is
    shown as-is; a multi-channel ``(C, H, W)`` mask is shown as its ``argmax``
    class map. The image and mask shapes and the unique mask values are printed
    for each sample.
    """
    batch = next(iter(dataloader))
    # Unpacking a dict batch directly yields its *keys* ('image', 'mask'),
    # so dict batches are indexed explicitly.
    if isinstance(batch, dict):
        images, masks = batch['image'], batch['mask']
    else:
        images, masks = batch

    batch_size = images.size(0)
    # squeeze=False keeps `axs` 2-D even for a single-row figure.
    _, axs = plt.subplots(batch_size, 2, figsize=(10, batch_size * 5), squeeze=False)

    for i in range(batch_size):
        img = images[i].permute(1, 2, 0).cpu().numpy()
        if img.shape[2] == 1:
            img = img[..., 0]  # imshow rejects (H, W, 1)

        mask = masks[i].cpu()
        if mask.dim() == 3:
            # (1, H, W) -> (H, W); channel-wise (C, H, W) -> (H, W) class indices.
            mask = mask[0] if mask.size(0) == 1 else mask.argmax(dim=0)
        mask = mask.numpy()

        print(img.shape)
        print(f"{mask.shape} - {np.unique(mask)}")

        # Stretch to [0, 1] for display; skip the division for a constant image.
        img = img - img.min()
        peak = img.max()
        if peak > 0:
            img = img / peak

        axs[i, 0].imshow(img)
        axs[i, 0].set_title(f"Image {i + 1}")
        axs[i, 0].axis('off')

        axs[i, 1].imshow(mask, cmap='gray')
        axs[i, 1].set_title(f"Mask {i + 1}")
        axs[i, 1].axis('off')

    plt.show()


In [None]:
# Preview the first batch of the fold-1 training loader.
plot_images_and_masks(dataloader=dataloaders[0][0])

In [None]:
# Grab the first batch of the fold-1 training loader. Batches are dicts, so
# `images, labels = next(dataiter)` would bind the key strings 'image'/'mask'.
dataiter = iter(dataloaders[0][0])
batch = next(dataiter)
images, labels = batch['image'], batch['mask']

In [None]:
# Distinct values present in the first mask of the batch.
labels[0].unique()

In [None]:
# Sample the fold-1 training loader. With `sampling = None`, fall back to the
# full loader: a None fraction crashes inside get_sampled_loader's length arithmetic.
train_loader = dataloaders[0][0] if sampling is None else get_sampled_loader(dataloaders[0][0], sampling)

In [None]:
# Fold-1 training and validation loaders.
[train_loader, val_loader] = dataloaders[0]

In [None]:
dataiter = iter(train_loader)  # replace `train_loader` with your own DataLoader
first_batch = next(dataiter)
# Batches are dicts; select the tensors by key instead of unpacking the dict (which yields its keys).
images, labels = first_batch['image'], first_batch['mask']

In [None]:
# Print the shape and the distinct values of every image in the batch.
for single_image in images:
    print(single_image.shape)
    print(single_image.unique())

In [None]:
# Shape of the mask batch.
labels.size()

In [None]:
def imshou(img, denormalize=False):
    """Display a ``(C, H, W)`` tensor (e.g. a ``make_grid`` output) with matplotlib.

    Args:
        img: Image tensor; it is detached and moved to the CPU before plotting.
        denormalize: Map values from ``[-1, 1]`` back to ``[0, 1]`` via
            ``img / 2 + 0.5``. Off by default because ``CustomImageDataset``
            already yields 8-bit images scaled to ``[0, 1]``; rescaling them
            unconditionally squeezed them into ``[0.5, 1]`` and washed them out.
    """
    if denormalize:
        img = img / 2 + 0.5
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [None]:
import torchvision

# Tile the batch of images into one grid and display it.
image_grid = torchvision.utils.make_grid(images)
imshou(image_grid)

In [None]:
# Tile the batch of masks into one grid and display it.
label_grid = torchvision.utils.make_grid(labels)
imshou(label_grid)

In [None]:
# Shape of the first mask in the batch.
labels[0].size()

In [None]:
# Distinct values present in the first mask of the batch.
labels[0].unique()