In [1]:
import torch
from meu_dataset import MeuDataset,avaliar_descritor,calcular_matching
from teste_util import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
path_dataset = "./data/datasets/features_path_flowers_dataset.pt"
# Carregar o dataset do arquivo "meu_dataset.pt"
meu_dataset = MeuDataset.load_from_file(path_dataset)
#verificar se o objeto meu dataset está retornando o tensor correto
assert isinstance(meu_dataset,MeuDataset), 'o tipo de retorno não é MeuDataset'
assert isinstance(meu_dataset[0][0],torch.Tensor), 'o tipo de retorno não é torch.Tensor'


In [2]:
import gc
from torch.optim.lr_scheduler import ExponentialLR
gc.collect()
torch.cuda.empty_cache()
batch_size_siam = 50
from torch.utils.data import random_split, DataLoader


train_dataset, val_dataset, test_dataset = random_split(meu_dataset, [0.4,0.4,0.2])

# Crie uma instância do DataLoader usando seu conjunto de dados personalizado
dataloader_train = DataLoader(train_dataset, batch_size=batch_size_siam, shuffle=False)
dataloader_val = DataLoader(val_dataset, batch_size=batch_size_siam, shuffle=False)
dataloader_test = DataLoader(test_dataset, batch_size=batch_size_siam, shuffle=False)

In [3]:
from tqdm import tqdm
import random

# Cosine similarity function
def cosine_similarity(a, b):
    a_norm = torch.nn.functional.normalize(a, dim=-1)
    b_norm = torch.nn.functional.normalize(b, dim=-1)
    return torch.mm(a_norm, b_norm.T)

def my_similarity(a, b):
    a_norm = torch.nn.functional.normalize(a, dim=-1)
    b_norm = torch.nn.functional.normalize(b, dim=-1)
    return torch.cdist(a_norm, b_norm, p=2)

# Triplet loss function
def triplet_loss(anchor, positive, negative, margin=0.2):
    similarities = cosine_similarity(anchor, positive)
    # Calcular a média da diagonal principal (âncoras vs. seus respectivos positivos)
    mean_diagonal = torch.mean(torch.diagonal(similarities))
    # Calcular a média dos outros elementos (âncoras vs. positivos não correspondentes)
    mean_other = torch.mean(similarities[~torch.eye(similarities.shape[0], dtype=torch.bool)])
    
    # losses = torch.relu(mean_other - mean_diagonal + margin)# losses considerando similaridade de cosseno
    losses = torch.relu((1 - mean_diagonal) + mean_other)# losses considerando similaridade de cosseno
    # losses = torch.relu(mean_diagonal - mean_other + margin)# my_similarity
    # print(losses,mean_diagonal,mean_other)
    return losses,mean_other,mean_diagonal


def train_one_epoch(model, data_loader, optimizer, loss_fn, device='cpu', is_training=True):
    model.train(is_training)
    total_loss = 0.

    progress_bar = tqdm(data_loader)
    for idx, data in enumerate(progress_bar):
        # Extract the anchor and positive batches
        anchor_batch, positive_batch = (
            data[0].to(device),
            data[1].to(device),
        )

        # Calculate descriptors for the anchor and positive images
        descs_anchor = model(anchor_batch)
        descs_pos = model(positive_batch)

        # Calculate the triplet loss
        loss,negativo,positivo = loss_fn(descs_anchor, descs_pos, None, margin=0.7)

        if is_training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        progress_bar.set_description(f'Loss: {loss.item()}-{negativo.item()}-{positivo.item()} - Total Loss: {total_loss/len(data_loader)}')

    return total_loss/len(data_loader)


In [45]:
from typing import List, Optional
import torch
from e2cnn import gspaces
from e2cnn import nn as enn    #the equivariant layer we need to build the model
import torch.nn.functional as F
from torch import nn
from typing_extensions import TypedDict

from kornia.core import Module, Tensor, concatenate
from kornia.filters import SpatialGradient
from kornia.geometry.transform import pyrdown
from kornia.utils.helpers import map_location_to_cpu

from kornia.feature.scale_space_detector import get_default_detector_config, MultiResolutionDetector,Detector_config


class KeyNet_conf(TypedDict):
    num_filters: int
    num_levels: int
    kernel_size: int
    Detector_conf: Detector_config


class _FeatureExtractor(Module):
    def __init__(self) -> None:
        super().__init__()

        self.hc_block = _HandcraftedBlock()
        self.lb_block = _LearnableBlock()

    def forward(self, x: Tensor) -> Tensor:
        x_hc = self.hc_block(x)
        x_lb = self.lb_block(x_hc)
        return x_lb


class _HandcraftedBlock(Module):
    def __init__(self) -> None:
        super().__init__()
        self.spatial_gradient = SpatialGradient('sobel', 1)

    def forward(self, x: Tensor) -> Tensor:
        sobel = self.spatial_gradient(x)
        dx, dy = sobel[:, :, 0, :, :], sobel[:, :, 1, :, :]

        sobel_dx = self.spatial_gradient(dx)
        dxx, dxy = sobel_dx[:, :, 0, :, :], sobel_dx[:, :, 1, :, :]

        sobel_dy = self.spatial_gradient(dy)
        dyy = sobel_dy[:, :, 1, :, :]

        hc_feats = concatenate([dx, dy, dx**2.0, dy**2.0, dx * dy, dxy, dxy**2.0, dxx, dyy, dxx * dyy], 1)

        return hc_feats


def _KeyNetConvBlock(
    feat_type_in,
    feat_type_out,
    r2_act,
    kernel_size: int = 5,
    stride: int = 1,
    padding: int = 2,
    dilation: int = 1,
) -> nn.Sequential:
    return enn.SequentialModule(
            enn.R2Conv(feat_type_in, feat_type_out, kernel_size=kernel_size, padding=padding, bias=False),
            enn.InnerBatchNorm(feat_type_out),
            enn.ReLU(feat_type_out, inplace=True),
        )


class _LearnableBlock(nn.Sequential):
    def __init__(self, in_channels: int = 10, out_channels: int = 8, group_size=8) -> None:
        super().__init__()
        r2_act = gspaces.Rot2dOnR2(N=group_size)

        feat_type_in = enn.FieldType(r2_act, in_channels * [r2_act.trivial_repr])
        self.in_type = feat_type_in
        feat_type_out = enn.FieldType(r2_act, out_channels * [r2_act.regular_repr])
        self.block0 = _KeyNetConvBlock(feat_type_in, feat_type_out, r2_act)

        feat_type_out = enn.FieldType(r2_act, out_channels * [r2_act.regular_repr])
        self.block1 = _KeyNetConvBlock(self.block0.out_type, feat_type_out, r2_act)

        feat_type_out = enn.FieldType(r2_act, out_channels * [r2_act.regular_repr])
        self.block2 = _KeyNetConvBlock(self.block1.out_type, feat_type_out, r2_act)
        self.gpool = enn.GroupPooling(self.block2.out_type)

    def forward(self, x: Tensor) -> Tensor:
        x = enn.GeometricTensor(x, self.in_type)
        x = self.block0(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.gpool(x)
        # print("after gpool", x.shape)
        return x.tensor

class SoftHistogramLayer(nn.Module):
    def __init__(self, n_bin=5):
        super(SoftHistogramLayer, self).__init__()
        self.n_bin = n_bin

    def forward(self, x):
        batch_size, num_channels, height, width = x.shape
        out = x.view(batch_size, num_channels, -1)  # Redimensiona 'x' para um tensor 3D
        
        bin_edges = torch.linspace(0, 10, self.n_bin + 1, device=x.device)
        hist = []
        
        for i in range(batch_size):  # Itera sobre as instâncias no lote
            channel_hists = []
            
            for c in range(num_channels):  # Itera sobre os canais
                bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
                
                # Calcula o histograma suave para o canal 'c' da instância atual 'i'
                soft_hist = torch.exp(-(out[i, c, :, None] - bin_centers[None, :])**2 / 0.1).sum(dim=0)
                channel_hists.append(soft_hist)
            
            channel_hist_tensor = torch.stack(channel_hists)  # Empilha os histogramas suaves dos canais
            hist.append(channel_hist_tensor)
        
        hist = torch.stack(hist)  # Empilha os histogramas suaves das instâncias
        return hist.reshape(batch_size, -1)  # Redimensiona para um tensor 2D
    

class KeyNet(Module):
    def __init__(self, pretrained: bool = False, keynet_conf: KeyNet_conf = keynet_default_config) -> None:
        super().__init__()

        num_filters = keynet_conf['num_filters']
        self.num_levels = keynet_conf['num_levels']
        kernel_size = keynet_conf['kernel_size']
        padding = kernel_size // 2
        print("KeyNet config: ", keynet_conf)
        self.feature_extractor = _FeatureExtractor()
        
        self.last_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=num_filters * self.num_levels, out_channels=self.num_levels, kernel_size=kernel_size, padding=padding
            ),
            nn.ReLU(inplace=True),
        )
        self.softHistogramLayer = SoftHistogramLayer()
        if pretrained:
            print("KeyNet loaded")


    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
        x= x.reshape(B*C, 1, H, W)
        shape_im = x.shape
        feats: List[Tensor] = [self.feature_extractor(x)]
        
        for i in range(1, self.num_levels):
            x = pyrdown(x, factor=1.2)
            feats_i = self.feature_extractor(x)
            feats_i = F.interpolate(feats_i, size=(shape_im[2], shape_im[3]), mode='bilinear')
            feats.append(feats_i)
        scores = self.last_conv(concatenate(feats, 1))
        # print("after pyramid", scores.shape)
        scores= scores.reshape(B,C*scores.shape[1],scores.shape[2],scores.shape[3])
        hist = self.softHistogramLayer(scores)
        # print("after hist", hist.shape)
        return hist
    
keynet_default_config: KeyNet_conf = {
    # Key.Net Model
    'num_filters': 8,
    'num_levels': 3,
    'kernel_size': 5,
    # Extraction Parameters
    'Detector_conf': {'nms_size': 5, 'pyramid_levels': 2, 'up_levels': 1, 'scale_factor_levels': 1.3, 's_mult': 20.0},
}

model =KeyNet().to(device)

KeyNet config:  {'num_filters': 8, 'num_levels': 3, 'kernel_size': 5, 'Detector_conf': {'nms_size': 5, 'pyramid_levels': 2, 'up_levels': 1, 'scale_factor_levels': 1.3, 's_mult': 20.0}}


In [23]:
gc.collect()
torch.cuda.empty_cache()

def save_checkpoint(model, optimizer, epoch, loss, path):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
        # Adicione outras informações que você deseja salvar, como hiperparâmetros, configurações, etc.
    }
    torch.save(checkpoint, path)
    
def load_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    # Outras informações que você salvou no dicionário de checkpoint podem ser acessadas aqui
    return model, optimizer, epoch, loss


In [50]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epoch_i=-1
model_keynet = KeyNet().to(device)
optimizer = optim.Adam(model_keynet.parameters(), lr=0.001, weight_decay=0.0001)
PATH_MODEL = './data/models/feature_flowers_sp.pt'
model_keynet, optimizer, epoch_i, loss =load_checkpoint(model_keynet, optimizer,PATH_MODEL)
print("epoch_i ",epoch_i,"loss ",loss)

gc.collect()
torch.cuda.empty_cache()
epochs=300
i_epoch = 0
loss = 0

scheduler = ExponentialLR(optimizer, gamma=0.75)

KeyNet config:  {'num_filters': 8, 'num_levels': 3, 'kernel_size': 5, 'Detector_conf': {'nms_size': 5, 'pyramid_levels': 2, 'up_levels': 1, 'scale_factor_levels': 1.3, 's_mult': 20.0}}
epoch_i  14 loss  0.5196015781164169


In [51]:
gc.collect()
torch.cuda.empty_cache()

def train_with_early_stopping(model, trainloader, testloader, criterion_d, optimizer, scheduler, device, epochs=100, patience=20):
    best_loss = float('inf')
    best_model = None
    epochs_without_improvement = 0

    for epoch in range(epoch_i+1,epochs):
        # Atualizar a taxa de aprendizado
        if (epoch % 16 == 0) and (epoch != 0):
            scheduler.step()
            
        running_loss = train_one_epoch(model, data_loader=trainloader, loss_fn=criterion_d,  optimizer=optimizer, device=device,is_training=True)

        with torch.no_grad():
            loss_test = train_one_epoch(model, data_loader=testloader, loss_fn=criterion_d,  optimizer=None, device=device,is_training=False)

        # Verificar se a perda melhorou
        if loss_test < best_loss:
            best_loss = loss_test
            epochs_without_improvement = 0
            best_model = model.state_dict()            
            save_checkpoint(model=model, epoch=epoch, optimizer=optimizer, loss=loss_test, path=PATH_MODEL)
            print("salvou no colab")
        else:
            epochs_without_improvement += 1

        # Verificar a condição de parada
        if epochs_without_improvement == patience:
            print(f"No improvement in loss for {epochs_without_improvement} epochs. Training stopped.")
            break

        print(f"Epoch [{epoch}/{epochs}] - Running Loss: {running_loss:.4f}, Test Loss: {loss_test:.4f}, Initial LR: {optimizer.param_groups[0]['initial_lr']:.6f}, Current LR: {optimizer.param_groups[0]['lr']:.6f}, Epochs without Improvement: {epochs_without_improvement}")

    # Carregar a melhor configuração do modelo
    model.load_state_dict(best_model)
    print(f'Epoch: {epoch}, Best Loss: {best_loss:.4f}')

train_with_early_stopping(model_keynet.to(device), dataloader_train, dataloader_val, triplet_loss, optimizer, scheduler, device, epochs=epochs, patience=100)

Loss: 0.5744235515594482-0.40661218762397766-0.8321886658668518 - Total Loss: 0.5069975882768631: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it]  
Loss: 0.5607937574386597-0.32170647382736206-0.7609127163887024 - Total Loss: 0.5123343616724014: 100%|██████████| 50/50 [00:18<00:00,  2.64it/s]  


salvou no colab
Epoch [15/300] - Running Loss: 0.5070, Test Loss: 0.5123, Initial LR: 0.001000, Current LR: 0.001000, Epochs without Improvement: 0


Loss: 0.43879708647727966-0.3023764193058014-0.8635793328285217 - Total Loss: 0.4882033705711365: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]  
Loss: 0.5288327932357788-0.33927690982818604-0.8104441165924072 - Total Loss: 0.5172278141975403: 100%|██████████| 50/50 [00:18<00:00,  2.63it/s]  


Epoch [16/300] - Running Loss: 0.4882, Test Loss: 0.5172, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 1


Loss: 0.4607101380825043-0.30224713683128357-0.8415369987487793 - Total Loss: 0.47472964465618134: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it] 
Loss: 0.528831958770752-0.3325652778148651-0.8037333488464355 - Total Loss: 0.5134063965082168: 100%|██████████| 50/50 [00:19<00:00,  2.59it/s]    


Epoch [17/300] - Running Loss: 0.4747, Test Loss: 0.5134, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 2


Loss: 0.5172823071479797-0.3374578356742859-0.8201755285263062 - Total Loss: 0.47157160341739657: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it]  
Loss: 0.5286083221435547-0.33535006642341614-0.8067417144775391 - Total Loss: 0.5062222272157669: 100%|██████████| 50/50 [00:19<00:00,  2.59it/s]  


salvou no colab
Epoch [18/300] - Running Loss: 0.4716, Test Loss: 0.5062, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 0


Loss: 0.5390344858169556-0.3560433089733124-0.8170087933540344 - Total Loss: 0.45675212383270264: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]  
Loss: 0.5273086428642273-0.3243125081062317-0.7970038652420044 - Total Loss: 0.5040867656469346: 100%|██████████| 50/50 [00:18<00:00,  2.65it/s]   


salvou no colab
Epoch [19/300] - Running Loss: 0.4568, Test Loss: 0.5041, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 0


Loss: 0.5426614284515381-0.3277384042739868-0.7850769758224487 - Total Loss: 0.46327581226825715: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it]  
Loss: 0.5178435444831848-0.3194006681442261-0.8015571236610413 - Total Loss: 0.5008743500709534: 100%|██████████| 50/50 [00:19<00:00,  2.63it/s]   


salvou no colab
Epoch [20/300] - Running Loss: 0.4633, Test Loss: 0.5009, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 0


Loss: 0.5273089408874512-0.39751091599464417-0.8702019453048706 - Total Loss: 0.45774599254131315: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it] 
Loss: 0.5787725448608398-0.3308822810649872-0.752109706401825 - Total Loss: 0.5151174640655518: 100%|██████████| 50/50 [00:18<00:00,  2.65it/s]   


Epoch [21/300] - Running Loss: 0.4577, Test Loss: 0.5151, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 1


Loss: 0.570215106010437-0.3254562020301819-0.7552410960197449 - Total Loss: 0.47116585314273834: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it]   
Loss: 0.5387915968894958-0.3906957507133484-0.8519041538238525 - Total Loss: 0.5347198641300202: 100%|██████████| 50/50 [00:18<00:00,  2.64it/s]   


Epoch [22/300] - Running Loss: 0.4712, Test Loss: 0.5347, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 2


Loss: 0.4750228524208069-0.31447750329971313-0.8394546508789062 - Total Loss: 0.4579558080434799: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]  
Loss: 0.527934193611145-0.3388371765613556-0.810903012752533 - Total Loss: 0.5095388036966324: 100%|██████████| 50/50 [00:19<00:00,  2.61it/s]     


Epoch [23/300] - Running Loss: 0.4580, Test Loss: 0.5095, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 3


Loss: 0.543773889541626-0.3219277560710907-0.7781538963317871 - Total Loss: 0.4491182243824005: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]    
Loss: 0.5483936071395874-0.3968830406665802-0.8484894633293152 - Total Loss: 0.5370275682210922: 100%|██████████| 50/50 [00:18<00:00,  2.64it/s]   


Epoch [24/300] - Running Loss: 0.4491, Test Loss: 0.5370, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 4


Loss: 0.4651140868663788-0.3039349615573883-0.8388208746910095 - Total Loss: 0.4488437509536743: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it]   
Loss: 0.5274291038513184-0.342195600271225-0.8147664666175842 - Total Loss: 0.513924548625946: 100%|██████████| 50/50 [00:19<00:00,  2.62it/s]     


Epoch [25/300] - Running Loss: 0.4488, Test Loss: 0.5139, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 5


Loss: 0.5348994731903076-0.302739679813385-0.7678402066230774 - Total Loss: 0.4417275655269623: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]    
Loss: 0.5515405535697937-0.36713701486587524-0.8155964612960815 - Total Loss: 0.526501151919365: 100%|██████████| 50/50 [00:18<00:00,  2.66it/s]  


Epoch [26/300] - Running Loss: 0.4417, Test Loss: 0.5265, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 6


Loss: 0.5222067832946777-0.30622193217277527-0.7840151786804199 - Total Loss: 0.4345950037240982: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it]  
Loss: 0.5615212917327881-0.3777438998222351-0.816222608089447 - Total Loss: 0.5361395877599716: 100%|██████████| 50/50 [00:19<00:00,  2.61it/s]    


Epoch [27/300] - Running Loss: 0.4346, Test Loss: 0.5361, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 7


Loss: 0.4732568562030792-0.2956559956073761-0.8223991394042969 - Total Loss: 0.44121892273426055: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]  
Loss: 0.5156468749046326-0.3607562780380249-0.8451094031333923 - Total Loss: 0.5295720797777176: 100%|██████████| 50/50 [00:19<00:00,  2.61it/s]   


Epoch [28/300] - Running Loss: 0.4412, Test Loss: 0.5296, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 8


Loss: 0.4156001806259155-0.28417789936065674-0.8685777187347412 - Total Loss: 0.4523766976594925: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it]  
Loss: 0.5022276043891907-0.324989914894104-0.8227623105049133 - Total Loss: 0.517915860414505: 100%|██████████| 50/50 [00:18<00:00,  2.65it/s]     


Epoch [29/300] - Running Loss: 0.4524, Test Loss: 0.5179, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 9


Loss: 0.4398280382156372-0.3017047643661499-0.8618767261505127 - Total Loss: 0.43865441262722016: 100%|██████████| 50/50 [01:00<00:00,  1.20s/it]  
Loss: 0.5273064970970154-0.3301599621772766-0.8028534650802612 - Total Loss: 0.5162342518568039: 100%|██████████| 50/50 [00:18<00:00,  2.66it/s]   


Epoch [30/300] - Running Loss: 0.4387, Test Loss: 0.5162, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 10


Loss: 0.44526243209838867-0.3045138120651245-0.8592513799667358 - Total Loss: 0.44651614427566527: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it] 
Loss: 0.5374518036842346-0.320218563079834-0.7827667593955994 - Total Loss: 0.5134093445539475: 100%|██████████| 50/50 [00:19<00:00,  2.62it/s]   


Epoch [31/300] - Running Loss: 0.4465, Test Loss: 0.5134, Initial LR: 0.001000, Current LR: 0.000750, Epochs without Improvement: 11


Loss: 0.4850585162639618-0.30933281779289246-0.8242743015289307 - Total Loss: 0.45617332935333255: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it] 
Loss: 0.5732824802398682-0.3691391944885254-0.7958567142486572 - Total Loss: 0.5294667112827302: 100%|██████████| 50/50 [00:18<00:00,  2.64it/s]   


Epoch [32/300] - Running Loss: 0.4562, Test Loss: 0.5295, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 12


Loss: 0.46553024649620056-0.2910479009151459-0.8255176544189453 - Total Loss: 0.4414481008052826: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it]  
Loss: 0.5405417680740356-0.35242971777915955-0.8118879795074463 - Total Loss: 0.5276267087459564: 100%|██████████| 50/50 [00:19<00:00,  2.61it/s] 


Epoch [33/300] - Running Loss: 0.4414, Test Loss: 0.5276, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 13


Loss: 0.46700355410575867-0.2953462302684784-0.8283426761627197 - Total Loss: 0.4262623596191406: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]  
Loss: 0.5754730701446533-0.36426806449890137-0.788794994354248 - Total Loss: 0.5348997306823731: 100%|██████████| 50/50 [00:18<00:00,  2.65it/s]   


Epoch [34/300] - Running Loss: 0.4263, Test Loss: 0.5349, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 14


Loss: 0.46339818835258484-0.3002620041370392-0.8368638157844543 - Total Loss: 0.42753719449043276: 100%|██████████| 50/50 [01:00<00:00,  1.20s/it] 
Loss: 0.5809815526008606-0.36512309312820435-0.7841415405273438 - Total Loss: 0.5373654174804687: 100%|██████████| 50/50 [00:18<00:00,  2.67it/s]  


Epoch [35/300] - Running Loss: 0.4275, Test Loss: 0.5374, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 15


Loss: 0.44899606704711914-0.2944301962852478-0.8454341292381287 - Total Loss: 0.41979512870311736: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it] 
Loss: 0.5801488161087036-0.369377464056015-0.7892286777496338 - Total Loss: 0.540206510424614: 100%|██████████| 50/50 [00:18<00:00,  2.64it/s]     


Epoch [36/300] - Running Loss: 0.4198, Test Loss: 0.5402, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 16


Loss: 0.4452815055847168-0.2888615131378174-0.8435800075531006 - Total Loss: 0.43264119029045106: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]  
Loss: 0.5650601387023926-0.36477598547935486-0.7997158169746399 - Total Loss: 0.5349340808391571: 100%|██████████| 50/50 [00:19<00:00,  2.63it/s] 


Epoch [37/300] - Running Loss: 0.4326, Test Loss: 0.5349, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 17


Loss: 0.42725104093551636-0.29101186990737915-0.8637608289718628 - Total Loss: 0.43142612397670743: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]
Loss: 0.5410904884338379-0.3486175835132599-0.8075270652770996 - Total Loss: 0.5265320730209351: 100%|██████████| 50/50 [00:19<00:00,  2.62it/s]   


Epoch [38/300] - Running Loss: 0.4314, Test Loss: 0.5265, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 18


Loss: 0.44110047817230225-0.2948395013809204-0.8537390232086182 - Total Loss: 0.43935995995998384: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it] 
Loss: 0.5315758585929871-0.3505457639694214-0.8189699053764343 - Total Loss: 0.5250807619094848: 100%|██████████| 50/50 [00:19<00:00,  2.62it/s]  


Epoch [39/300] - Running Loss: 0.4394, Test Loss: 0.5251, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 19


Loss: 0.45460522174835205-0.3028622269630432-0.8482570052146912 - Total Loss: 0.42199732184410094: 100%|██████████| 50/50 [01:00<00:00,  1.20s/it] 
Loss: 0.5424999594688416-0.3579944968223572-0.8154945373535156 - Total Loss: 0.5321207666397094: 100%|██████████| 50/50 [00:18<00:00,  2.64it/s]   


Epoch [40/300] - Running Loss: 0.4220, Test Loss: 0.5321, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 20


Loss: 0.43722954392433167-0.2947940528392792-0.8575645089149475 - Total Loss: 0.41580452501773835: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it] 
Loss: 0.537058413028717-0.3550056219100952-0.8179472088813782 - Total Loss: 0.5343366062641144: 100%|██████████| 50/50 [00:18<00:00,  2.67it/s]    


Epoch [41/300] - Running Loss: 0.4158, Test Loss: 0.5343, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 21


Loss: 0.4430925250053406-0.30929088592529297-0.8661983609199524 - Total Loss: 0.41152568757534025: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it] 
Loss: 0.5528616905212402-0.36054661870002747-0.8076849579811096 - Total Loss: 0.5363821685314178: 100%|██████████| 50/50 [00:18<00:00,  2.65it/s] 


Epoch [42/300] - Running Loss: 0.4115, Test Loss: 0.5364, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 22


Loss: 0.4167897701263428-0.29788893461227417-0.8810991644859314 - Total Loss: 0.41217404723167417: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it] 
Loss: 0.5580042600631714-0.36985746026039124-0.8118531703948975 - Total Loss: 0.5383462864160538: 100%|██████████| 50/50 [00:18<00:00,  2.66it/s]  


Epoch [43/300] - Running Loss: 0.4122, Test Loss: 0.5383, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 23


Loss: 0.43316006660461426-0.2878735065460205-0.8547134399414062 - Total Loss: 0.4225130134820938: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it]  
Loss: 0.5147054195404053-0.3376592993736267-0.8229538798332214 - Total Loss: 0.5213553595542908: 100%|██████████| 50/50 [00:18<00:00,  2.68it/s]   


Epoch [44/300] - Running Loss: 0.4225, Test Loss: 0.5214, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 24


Loss: 0.42277106642723083-0.2852703034877777-0.8624992370605469 - Total Loss: 0.43202839910984037: 100%|██████████| 50/50 [01:00<00:00,  1.20s/it] 
Loss: 0.49908027052879333-0.3330977261066437-0.8340174555778503 - Total Loss: 0.5124785780906678: 100%|██████████| 50/50 [00:18<00:00,  2.69it/s]  


Epoch [45/300] - Running Loss: 0.4320, Test Loss: 0.5125, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 25


Loss: 0.42305055260658264-0.29923608899116516-0.8761855363845825 - Total Loss: 0.4138657820224762: 100%|██████████| 50/50 [01:00<00:00,  1.20s/it] 
Loss: 0.4869813621044159-0.3323942720890045-0.8454129099845886 - Total Loss: 0.5129827708005905: 100%|██████████| 50/50 [00:18<00:00,  2.67it/s]   


Epoch [46/300] - Running Loss: 0.4139, Test Loss: 0.5130, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 26


Loss: 0.4133896827697754-0.28977417945861816-0.8763844966888428 - Total Loss: 0.40569970846176145: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it] 
Loss: 0.5156035423278809-0.3298164904117584-0.8142129778862 - Total Loss: 0.5203154748678207: 100%|██████████| 50/50 [00:18<00:00,  2.65it/s]      


Epoch [47/300] - Running Loss: 0.4057, Test Loss: 0.5203, Initial LR: 0.001000, Current LR: 0.000563, Epochs without Improvement: 27


Loss: 0.40853604674339294-0.2864358127117157-0.8778997659683228 - Total Loss: 0.40016846120357513: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it] 
Loss: 0.48376205563545227-0.32618871331214905-0.8424266576766968 - Total Loss: 0.5153662109375: 100%|██████████| 50/50 [00:18<00:00,  2.66it/s]    


Epoch [48/300] - Running Loss: 0.4002, Test Loss: 0.5154, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 28


Loss: 0.406627893447876-0.289706289768219-0.883078396320343 - Total Loss: 0.39282083928585054: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]     
Loss: 0.5007081031799316-0.3237544000148773-0.8230462670326233 - Total Loss: 0.5136004132032395: 100%|██████████| 50/50 [00:19<00:00,  2.62it/s]   


Epoch [49/300] - Running Loss: 0.3928, Test Loss: 0.5136, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 29


Loss: 0.3948902487754822-0.2907898426055908-0.8958995938301086 - Total Loss: 0.3901545476913452: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]    
Loss: 0.49263137578964233-0.3194350004196167-0.8268036246299744 - Total Loss: 0.5126889914274215: 100%|██████████| 50/50 [00:18<00:00,  2.68it/s]  


Epoch [50/300] - Running Loss: 0.3902, Test Loss: 0.5127, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 30


Loss: 0.412434458732605-0.2992972731590271-0.8868628144264221 - Total Loss: 0.3907855886220932: 100%|██████████| 50/50 [01:00<00:00,  1.20s/it]    
Loss: 0.4932484030723572-0.3077065348625183-0.8144581317901611 - Total Loss: 0.5023564672470093: 100%|██████████| 50/50 [00:18<00:00,  2.66it/s]   


Epoch [51/300] - Running Loss: 0.3908, Test Loss: 0.5024, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 31


Loss: 0.3971066474914551-0.2946329712867737-0.8975263237953186 - Total Loss: 0.3865723127126694: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it]   
Loss: 0.5055215358734131-0.30323734879493713-0.7977158427238464 - Total Loss: 0.5036238813400269: 100%|██████████| 50/50 [00:18<00:00,  2.69it/s]  


Epoch [52/300] - Running Loss: 0.3866, Test Loss: 0.5036, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 32


Loss: 0.3780144155025482-0.28855469822883606-0.9105402827262878 - Total Loss: 0.3978740870952606: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]  
Loss: 0.5013300180435181-0.3113080561161041-0.8099780082702637 - Total Loss: 0.508625196814537: 100%|██████████| 50/50 [00:18<00:00,  2.68it/s]   


Epoch [53/300] - Running Loss: 0.3979, Test Loss: 0.5086, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 33


Loss: 0.39469626545906067-0.30212220549583435-0.9074259400367737 - Total Loss: 0.40946728229522705: 100%|██████████| 50/50 [01:00<00:00,  1.20s/it]
Loss: 0.5094557404518127-0.29113513231277466-0.7816793918609619 - Total Loss: 0.5029063200950623: 100%|██████████| 50/50 [00:18<00:00,  2.64it/s]  


Epoch [54/300] - Running Loss: 0.4095, Test Loss: 0.5029, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 34


Loss: 0.39761021733283997-0.30811575055122375-0.9105055332183838 - Total Loss: 0.41098669052124026: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]
Loss: 0.4920472204685211-0.29274752736091614-0.800700306892395 - Total Loss: 0.5002692884206772: 100%|██████████| 50/50 [00:18<00:00,  2.65it/s]  


salvou no colab
Epoch [55/300] - Running Loss: 0.4110, Test Loss: 0.5003, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 0


Loss: 0.3851945996284485-0.2873743176460266-0.9021797180175781 - Total Loss: 0.407495197057724: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]    
Loss: 0.495612233877182-0.2954559028148651-0.7998436689376831 - Total Loss: 0.5029095578193664: 100%|██████████| 50/50 [00:18<00:00,  2.67it/s]    


Epoch [56/300] - Running Loss: 0.4075, Test Loss: 0.5029, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 1


Loss: 0.35945382714271545-0.2832449972629547-0.9237911701202393 - Total Loss: 0.3888034576177597: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]  
Loss: 0.48415130376815796-0.2897549867630005-0.8056036829948425 - Total Loss: 0.5023036462068557: 100%|██████████| 50/50 [00:19<00:00,  2.61it/s]  


Epoch [57/300] - Running Loss: 0.3888, Test Loss: 0.5023, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 2


Loss: 0.36640527844429016-0.29173508286476135-0.9253298044204712 - Total Loss: 0.39133192479610446: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]
Loss: 0.49408766627311707-0.29667654633522034-0.8025888800621033 - Total Loss: 0.49907003462314603: 100%|██████████| 50/50 [00:18<00:00,  2.67it/s]


salvou no colab
Epoch [58/300] - Running Loss: 0.3913, Test Loss: 0.4991, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 0


Loss: 0.3959291875362396-0.2900020182132721-0.8940728306770325 - Total Loss: 0.38566898345947265: 100%|██████████| 50/50 [01:00<00:00,  1.21s/it]  
Loss: 0.4845961630344391-0.31653156876564026-0.8319354057312012 - Total Loss: 0.5016384840011596: 100%|██████████| 50/50 [00:19<00:00,  2.60it/s]  


Epoch [59/300] - Running Loss: 0.3857, Test Loss: 0.5016, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 1


Loss: 0.3984008729457855-0.2856440842151642-0.8872432112693787 - Total Loss: 0.3789657068252563: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it]   
Loss: 0.4958278238773346-0.31601837277412415-0.8201905488967896 - Total Loss: 0.5081949275732041: 100%|██████████| 50/50 [00:18<00:00,  2.66it/s]  


Epoch [60/300] - Running Loss: 0.3790, Test Loss: 0.5082, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 2


Loss: 0.38778558373451233-0.28417566418647766-0.8963900804519653 - Total Loss: 0.3750283795595169: 100%|██████████| 50/50 [01:00<00:00,  1.20s/it] 
Loss: 0.49797871708869934-0.31833717226982117-0.8203584551811218 - Total Loss: 0.5095633524656296: 100%|██████████| 50/50 [00:19<00:00,  2.63it/s] 


Epoch [61/300] - Running Loss: 0.3750, Test Loss: 0.5096, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 3


Loss: 0.41864514350891113-0.2885339856147766-0.8698888421058655 - Total Loss: 0.37177977919578553: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it] 
Loss: 0.4887706935405731-0.32579347491264343-0.8370227813720703 - Total Loss: 0.5174092662334442: 100%|██████████| 50/50 [00:18<00:00,  2.63it/s] 


Epoch [62/300] - Running Loss: 0.3718, Test Loss: 0.5174, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 4


Loss: 0.3929004669189453-0.2859100103378296-0.8930095434188843 - Total Loss: 0.3784379452466965: 100%|██████████| 50/50 [01:00<00:00,  1.22s/it]   
Loss: 0.4951477646827698-0.32580626010894775-0.830658495426178 - Total Loss: 0.5213278865814209: 100%|██████████| 50/50 [00:18<00:00,  2.70it/s]   


Epoch [63/300] - Running Loss: 0.3784, Test Loss: 0.5213, Initial LR: 0.001000, Current LR: 0.000422, Epochs without Improvement: 5


Loss: 0.3872997760772705-0.28444093465805054-0.89714115858078 - Total Loss: 0.37404749989509584: 100%|██████████| 50/50 [00:59<00:00,  1.19s/it]   
Loss: 0.5283774137496948-0.32421448826789856-0.7958371043205261 - Total Loss: 0.08424505293369293:  16%|█▌        | 8/50 [00:02<00:15,  2.75it/s] 

In [None]:
keynet_default_config: KeyNet_conf = {
    # Key.Net Model
    'num_filters': 8,
    'num_levels': 3,
    'kernel_size': 5,
    # Extraction Parameters
    'Detector_conf': {'nms_size': 5, 'pyramid_levels': 2, 'up_levels': 1, 'scale_factor_levels': 1.3, 's_mult': 20.0},
}

model =KeyNet(keynet_conf=keynet_default_config).to(device)
model_keynet, optimizer, epoch_i, loss =  load_checkpoint(model=model, optimizer=None, path=PATH_MODEL)

from meu_dataset import MeuDataset,avaliar_descritor
model =model.eval()
with torch.no_grad():
    total_acertos,total_erros,total_elementos = avaliar_descritor(dataloader_test, model,th=0.51)
print(f'Total de elementos no DataLoader: {total_elementos}')
print(f'Acertei: {total_acertos}  Errei: {total_erros}')

In [None]:
model =model.eval()
with torch.no_grad():
    total_acertos,total_erros,total_elementos = avaliar_descritor(dataloader_test, model,th=0.025)
sub_conjunto = total_elementos//2
print(f'Total de elementos no DataLoader: {total_elementos}')
print(f'Acertei: {total_acertos}/{sub_conjunto} Errei: {total_erros}/{sub_conjunto}')

### Refazer o treinamento para fazer o descritor na imagem original ao inves da feature

In [None]:
path_dataset = "./data/datasets/img_path_flowers_dataset.pt"
meu_dataset2 = MeuDataset.load_from_file(path_dataset)
train_dataset2, val_dataset2, test_dataset2 = random_split(meu_dataset2, [0.5,0.3,0.2])

# Crie uma instância do DataLoader usando seu conjunto de dados personalizado
dataloader_train2 = DataLoader(train_dataset2, batch_size=batch_size_siam, shuffle=True)
dataloader_val2 = DataLoader(val_dataset2, batch_size=batch_size_siam, shuffle=True)
dataloader_test2 = DataLoader(test_dataset2, batch_size=batch_size_siam, shuffle=True)

In [None]:
n_channel =1
model =Feature(n_channel=n_channel).to(device)
PATH_MODEL = './data/models/img_flowers_sp.pt'
model = train(model,dataloader_train2,dataloader_val2)

In [None]:
model =model.eval()
with torch.no_grad():
    total_acertos,total_erros,total_elementos = avaliar_descritor(dataloader_test2, model,th=0.2)
sub_conjunto = total_elementos//2
print(f'Total de elementos no DataLoader: {total_elementos}')
print(f'Acertei: {total_acertos}/{sub_conjunto} Errei: {total_erros}/{sub_conjunto}')
save_model(model, PATH_MODEL)

In [None]:
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import nn
from typing_extensions import TypedDict

from kornia.core import Module, Tensor, concatenate
from kornia.filters import SpatialGradient
from kornia.geometry.transform import pyrdown
from kornia.utils.helpers import map_location_to_cpu

from kornia.feature.scale_space_detector import get_default_detector_config, MultiResolutionDetector,Detector_config


class KeyNet_conf(TypedDict):
    num_filters: int
    num_levels: int
    kernel_size: int
    Detector_conf: Detector_config


keynet_default_config: KeyNet_conf = {
    # Key.Net Model
    'num_filters': 8,
    'num_levels': 3,
    'kernel_size': 5,
    # Extraction Parameters
    'Detector_conf': {'nms_size': 5, 'pyramid_levels': 2, 'up_levels': 1, 'scale_factor_levels': 1.3, 's_mult': 20.0},
}


class _FeatureExtractor(Module):
    def __init__(self) -> None:
        super().__init__()

        self.hc_block = _HandcraftedBlock()
        self.lb_block = _LearnableBlock()

    def forward(self, x: Tensor) -> Tensor:
        x_hc = self.hc_block(x)
        x_lb = self.lb_block(x_hc)
        return x_lb


class _HandcraftedBlock(Module):
    def __init__(self) -> None:
        super().__init__()
        self.spatial_gradient = SpatialGradient('sobel', 1)

    def forward(self, x: Tensor) -> Tensor:
        sobel = self.spatial_gradient(x)
        dx, dy = sobel[:, :, 0, :, :], sobel[:, :, 1, :, :]

        sobel_dx = self.spatial_gradient(dx)
        dxx, dxy = sobel_dx[:, :, 0, :, :], sobel_dx[:, :, 1, :, :]

        sobel_dy = self.spatial_gradient(dy)
        dyy = sobel_dy[:, :, 1, :, :]

        hc_feats = concatenate([dx, dy, dx**2.0, dy**2.0, dx * dy, dxy, dxy**2.0, dxx, dyy, dxx * dyy], 1)

        return hc_feats


def _KeyNetConvBlock(
    feat_type_in,
    feat_type_out,
    r2_act,
    kernel_size: int = 5,
    stride: int = 1,
    padding: int = 2,
    dilation: int = 1,
) -> nn.Sequential:
    return enn.SequentialModule(
            enn.R2Conv(feat_type_in, feat_type_out, kernel_size=kernel_size, padding=padding, bias=False),
            enn.InnerBatchNorm(feat_type_out),
            enn.ReLU(feat_type_out, inplace=True),
        )


class _LearnableBlock(nn.Sequential):
    def __init__(self, in_channels: int = 10, out_channels: int = 8, group_size=8) -> None:
        super().__init__()
        r2_act = gspaces.Rot2dOnR2(N=group_size)

        feat_type_in = enn.FieldType(r2_act, in_channels * [r2_act.trivial_repr])
        self.in_type = feat_type_in
        feat_type_out = enn.FieldType(r2_act, out_channels * [r2_act.regular_repr])
        self.block0 = _KeyNetConvBlock(feat_type_in, feat_type_out, r2_act)

        feat_type_out = enn.FieldType(r2_act, out_channels * [r2_act.regular_repr])
        self.block1 = _KeyNetConvBlock(self.block0.out_type, feat_type_out, r2_act)

        feat_type_out = enn.FieldType(r2_act, out_channels * [r2_act.regular_repr])
        self.block2 = _KeyNetConvBlock(self.block1.out_type, feat_type_out, r2_act)
        self.gpool = enn.GroupPooling(self.block2.out_type)

    def forward(self, x: Tensor) -> Tensor:
        x = enn.GeometricTensor(x, self.in_type)
        x = self.block0(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.gpool(x)
        return x.tensor

class SoftHistogramLayer(nn.Module):
    def __init__(self, n_bin=10):
        super(SoftHistogramLayer, self).__init__()
        self.n_bin = n_bin

    def forward(self, x):
        batch_size, num_channels, height, width = x.shape
        out = x.view(batch_size, num_channels, -1)  # Redimensiona 'x' para um tensor 3D
        
        bin_edges = torch.linspace(0, 10, self.n_bin + 1, device=x.device)
        hist = []
        
        for i in range(batch_size):  # Itera sobre as instâncias no lote
            channel_hists = []
            
            for c in range(num_channels):  # Itera sobre os canais
                bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
                
                # Calcula o histograma suave para o canal 'c' da instância atual 'i'
                soft_hist = torch.exp(-(out[i, c, :, None] - bin_centers[None, :])**2 / 0.1).sum(dim=0)
                channel_hists.append(soft_hist)
            
            channel_hist_tensor = torch.stack(channel_hists)  # Empilha os histogramas suaves dos canais
            hist.append(channel_hist_tensor)
        
        hist = torch.stack(hist)  # Empilha os histogramas suaves das instâncias
        return hist.reshape(batch_size, -1)  # Redimensiona para um tensor 2D
    

class KeyNet(Module):
    def __init__(self, pretrained: bool = False, keynet_conf: KeyNet_conf = keynet_default_config) -> None:
        super().__init__()

        num_filters = keynet_conf['num_filters']
        self.num_levels = keynet_conf['num_levels']
        kernel_size = keynet_conf['kernel_size']
        padding = kernel_size // 2

        self.feature_extractor = _FeatureExtractor()
        
        self.last_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=num_filters * self.num_levels, out_channels=num_filters, kernel_size=kernel_size, padding=padding
            ),
            nn.ReLU(inplace=True),
        )
        self.softHistogramLayer = SoftHistogramLayer()
        if pretrained:
            print("KeyNet loaded")


    def forward(self, x: Tensor) -> Tensor:
        shape_im = x.shape
        feats: List[Tensor] = [self.feature_extractor(x)]
        for i in range(1, self.num_levels):
            x = pyrdown(x, factor=1.2)
            feats_i = self.feature_extractor(x)
            feats_i = F.interpolate(feats_i, size=(shape_im[2], shape_im[3]), mode='bilinear')
            feats.append(feats_i)
        scores = self.last_conv(concatenate(feats, 1))
        print("scores ",scores.shape)
        return self.softHistogramLayer(scores)
    

n_channel =8
model =KeyNet().to(device)
image_a =torch.rand(40,n_channel,32,32).to(device)
image_p =torch.rand(40,n_channel,32,32).to(device)
B,C,H,W = image_a.shape
print("image_a ",image_a.shape, image_p.shape)
image_a =image_a.reshape(B*C,1,H,W)
image_p =image_p.reshape(B*C,1,H,W)
print("image_a ",image_a.shape, image_p.shape)
# Calculate descriptors for the anchor and positive images
descs_anchor = model(image_a)
descs_pos = model(image_p)

descs_anchor = descs_anchor.reshape(B,C,-1)
descs_pos = descs_pos.reshape(B,C,-1)
print("out ",descs_anchor.shape, descs_pos.shape)
from matplotlib import pyplot as plt

plt.imshow(image_a[0,0,:,:].cpu().detach().numpy())
plt.show()
print('descs_anchor ',descs_anchor[0])
# Calculate distances/similarities between anchor and all examples in the batch
distances = my_similarity(descs_anchor, descs_pos)  # Broadcasting
print("distances ",distances.shape)
# Choose the hardest negative example for each anchor
hard_negatives = torch.argmin(distances, dim=1)  # Get the index of the minimum similarity for each anchor

def similarity(desc1, desc2):
    desc1_norm = torch.nn.functional.normalize(desc1, dim=-1)
    desc2_norm = torch.nn.functional.normalize(desc2, dim=-1)
    return torch.sum(desc1_norm * desc2_norm)

print(descs_anchor.shape, descs_pos.shape)
similarity(descs_anchor, descs_pos)
