In [47]:
import torch
from torch_geometric.utils import to_scipy_sparse_matrix, unbatch_edge_index
from scipy.sparse.csgraph import connected_components
import torch_geometric as pyg 
import matplotlib.pyplot as plt 
from scipy.optimize import linear_sum_assignment
from src.models.hiclnet import HICLNet
from src.models.lightmodel import LightGlueMOT
import yaml


In [48]:
savedModel = torch.load('outputs/experiments/mot20_private_train_02-29_11:56:27.332701/models/hiclnet_epoch_7_iteration749.pth', map_location='cpu')

In [49]:
with open(r'configs/mymodel_cfg.yaml') as file:
    mymodel_params = yaml.load(file, Loader=yaml.FullLoader)

model = HICLNet(submodel_type=LightGlueMOT, submodel_params=mymodel_params,
                hicl_depth=7, use_motion=[False, True, True, True, True, True, True],
                use_reid_edge=[True]*7, use_pos_edge=[True]*7,
                share_weights='all', edge_level_embed=False,
                node_level_embed=True
                )

# model = LightGlueMOT(mymodel_params)

In [None]:
new_state_dict = {}

# Copy the pretrained weights for encoder and joint_enc modules
for key, value in savedModel.items():
    if 'enc_layer' in key or 'encoder' in key or 'joint_enc' in key:
        new_state_dict[key] = value

model.load_state_dict(new_state_dict, strict=False)

for ix in range(7):  # Assuming there are 7 layers
    i = 0
    for param in model.layers[ix].encoder.parameters():
        i+=1
        param.requires_grad = False

    for param in model.layers[ix].joint_enc.parameters():
        param.requires_grad = False


In [62]:
for name, param in model.named_parameters():
    print(name,param.requires_grad)

layers.0.edge_enc.fc_layers.0.weight True
layers.0.edge_enc.fc_layers.0.bias True
layers.0.edge_enc.fc_layers.3.weight True
layers.0.edge_enc.fc_layers.3.bias True
layers.0.init_enc.weight True
layers.0.init_enc.bias True
layers.0.enc_layer.self_attn.in_proj_weight True
layers.0.enc_layer.self_attn.in_proj_bias True
layers.0.enc_layer.self_attn.out_proj.weight True
layers.0.enc_layer.self_attn.out_proj.bias True
layers.0.enc_layer.linear1.weight True
layers.0.enc_layer.linear1.bias True
layers.0.enc_layer.linear2.weight True
layers.0.enc_layer.linear2.bias True
layers.0.enc_layer.norm1.weight True
layers.0.enc_layer.norm1.bias True
layers.0.enc_layer.norm2.weight True
layers.0.enc_layer.norm2.bias True
layers.0.encoder.layers.0.self_attn.in_proj_weight False
layers.0.encoder.layers.0.self_attn.in_proj_bias False
layers.0.encoder.layers.0.self_attn.out_proj.weight False
layers.0.encoder.layers.0.self_attn.out_proj.bias False
layers.0.encoder.layers.0.linear1.weight False
layers.0.encode

In [57]:
model.layers[ix].encoder

TransformerEncoder(
  (layers): ModuleList(
    (0): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
      )
      (linear1): Linear(in_features=32, out_features=128, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=128, out_features=32, bias=True)
      (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
)

In [46]:
len(model.layers)

7

In [36]:
# Load the new state_dict into the model
model.load_state_dict(new_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['layers.0.edge_enc.fc_layers.0.weight', 'layers.0.edge_enc.fc_layers.0.bias', 'layers.0.edge_enc.fc_layers.3.weight', 'layers.0.edge_enc.fc_layers.3.bias', 'layers.0.init_enc.weight', 'layers.0.init_enc.bias', 'layers.0.pos_enc.pe', 'layers.0.cross_gnn.0.att', 'layers.0.cross_gnn.0.bias', 'layers.0.cross_gnn.0.lin_l.weight', 'layers.0.cross_gnn.0.lin_l.bias', 'layers.0.cross_gnn.0.lin_r.weight', 'layers.0.cross_gnn.0.lin_r.bias', 'layers.0.cross_gnn.0.lin_edge.weight', 'layers.0.cross_gnn.0.lin_edge.bias', 'layers.0.cross_gnn.0.last_projector.weight', 'layers.0.cross_gnn.0.last_projector.bias', 'layers.0.cross_gnn.1.att', 'layers.0.cross_gnn.1.bias', 'layers.0.cross_gnn.1.lin_l.weight', 'layers.0.cross_gnn.1.lin_l.bias', 'layers.0.cross_gnn.1.lin_r.weight', 'layers.0.cross_gnn.1.lin_r.bias', 'layers.0.cross_gnn.1.lin_edge.weight', 'layers.0.cross_gnn.1.lin_edge.bias', 'layers.0.cross_gnn.1.last_projector.weight', 'layers.0.cross_gnn.1.last_projector.bias

In [4]:
savedModel.keys()

odict_keys(['layers.0.edge_enc.fc_layers.0.weight', 'layers.0.edge_enc.fc_layers.0.bias', 'layers.0.edge_enc.fc_layers.3.weight', 'layers.0.edge_enc.fc_layers.3.bias', 'layers.0.init_enc.weight', 'layers.0.init_enc.bias', 'layers.0.pos_enc.pe', 'layers.0.enc_layer.self_attn.in_proj_weight', 'layers.0.enc_layer.self_attn.in_proj_bias', 'layers.0.enc_layer.self_attn.out_proj.weight', 'layers.0.enc_layer.self_attn.out_proj.bias', 'layers.0.enc_layer.linear1.weight', 'layers.0.enc_layer.linear1.bias', 'layers.0.enc_layer.linear2.weight', 'layers.0.enc_layer.linear2.bias', 'layers.0.enc_layer.norm1.weight', 'layers.0.enc_layer.norm1.bias', 'layers.0.enc_layer.norm2.weight', 'layers.0.enc_layer.norm2.bias', 'layers.0.encoder.layers.0.self_attn.in_proj_weight', 'layers.0.encoder.layers.0.self_attn.in_proj_bias', 'layers.0.encoder.layers.0.self_attn.out_proj.weight', 'layers.0.encoder.layers.0.self_attn.out_proj.bias', 'layers.0.encoder.layers.0.linear1.weight', 'layers.0.encoder.layers.0.line

In [2]:
edge_mask = torch.ones((100,), dtype=torch.bool)
edge_mask[50:] = False

In [3]:
edge_mask

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [5]:
a = torch.rand((100,))

In [7]:
a[edge_mask]

tensor([0.3861, 0.7096, 0.0070, 0.7422, 0.7205, 0.5581, 0.2839, 0.2298, 0.2274,
        0.7890, 0.5262, 0.6527, 0.2394, 0.0992, 0.7059, 0.1699, 0.8500, 0.1301,
        0.5336, 0.9371, 0.0759, 0.2226, 0.6008, 0.0200, 0.3690, 0.7807, 0.8870,
        0.2434, 0.7951, 0.6591, 0.9667, 0.3636, 0.8168, 0.9855, 0.3253, 0.2942,
        0.8632, 0.2510, 0.5946, 0.2271, 0.7624, 0.5058, 0.0045, 0.5612, 0.8216,
        0.0340, 0.2017, 0.9015, 0.9913, 0.1861])

In [8]:
a

tensor([0.3861, 0.7096, 0.0070, 0.7422, 0.7205, 0.5581, 0.2839, 0.2298, 0.2274,
        0.7890, 0.5262, 0.6527, 0.2394, 0.0992, 0.7059, 0.1699, 0.8500, 0.1301,
        0.5336, 0.9371, 0.0759, 0.2226, 0.6008, 0.0200, 0.3690, 0.7807, 0.8870,
        0.2434, 0.7951, 0.6591, 0.9667, 0.3636, 0.8168, 0.9855, 0.3253, 0.2942,
        0.8632, 0.2510, 0.5946, 0.2271, 0.7624, 0.5058, 0.0045, 0.5612, 0.8216,
        0.0340, 0.2017, 0.9015, 0.9913, 0.1861, 0.6608, 0.8515, 0.6733, 0.4071,
        0.5306, 0.9550, 0.6557, 0.0694, 0.3091, 0.9469, 0.4367, 0.1064, 0.3367,
        0.1164, 0.1254, 0.0726, 0.4923, 0.4731, 0.8930, 0.8700, 0.6599, 0.4969,
        0.6651, 0.0969, 0.4284, 0.3757, 0.9282, 0.7937, 0.0970, 0.6914, 0.5215,
        0.1864, 0.2324, 0.2827, 0.5069, 0.7201, 0.4124, 0.2847, 0.9479, 0.7692,
        0.4856, 0.8629, 0.1590, 0.1691, 0.1428, 0.6082, 0.1445, 0.5168, 0.8153,
        0.3004])

In [3]:
token.shape

torch.Size([32])

In [11]:
expended = token.unsqueeze(0).expand(100, -1).unsqueeze(1)

In [12]:
a = torch.rand((100, 4, 32))

In [14]:
torch.cat([a, expended], dim=1).shape

torch.Size([100, 5, 32])

In [3]:
torch.zeros((6,), dtype=torch.float, device = torch.device('cpu'))

tensor([0., 0., 0., 0., 0., 0.])

In [2]:
import torch.nn as nn

class LearnableFourierFeatures(nn.Module):
    def __init__(self, 
                 M,  # M: Dimension of the Positions   / Input dim
                 D,  # D: Depth of Positional Encoding /Output dim
                 G,  # G: Number of Groups
                 F,  # F: Fourier Feature Dimension
                 H,  # H: Hidden Layer Dimension
                 gamma
                ):
        super().__init__()
        self.F = F
        self.D = D
        self.gamma=gamma

        self.Wr = nn.Linear(M, F//2, bias=False)

        self.mlp = nn.Sequential(
            nn.Linear(F, H),
            nn.GELU(),
            nn.Linear(H, D//G)
        )
        
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
    
    def forward(self, x):
        '''
        Produce positional encodings from x
        :param x: tensor of shape [N, G, M] that represents N positions where each position is in the shape of [G, M],
                  where G is the positional group and each group has M-dimensional positional values.
                  Positions in different positional groups are independent
        :return: positional encoding for X
        '''
        N, G, M = x.shape

        F = self.Wr(x)
        F = torch.cat([torch.cos(F), torch.sin(F)], dim=2) # /torch.sqrt(torch.tensor(self.F)) 

        Y = self.mlp(F)
        PEx =  Y.reshape((N, self.D))
        return PEx 
    
class PositionalEncoder(nn.Module):
    def __init__(self):
        super(PositionalEncoder, self).__init__()

        self.point_enc = LearnableFourierFeatures(M=2, D=32, G=1, F=8, H=8, gamma=1)
        self.size_enc = LearnableFourierFeatures(M=2, D=32, G=1, F=8, H=8, gamma=1)
        # self.point_enc = MLP(2, [8,32], dropout_p=0.1, use_batchnorm=False, bias=False)
        # self.size_enc  = MLP(2, [8,32], dropout_p=0.1, use_batchnorm=False, bias=False)

    def forward(self, bbox):
        '''
            bbox: Tensor: [N, (L, T, H, W)]   
        '''
        points = bbox[:, :2]
        size = bbox[:, 2:]
        # points = self.point_enc(points)
        # size = self.size_enc(size)
        points = self.point_enc(points.unsqueeze(1))
        size = self.size_enc(size.unsqueeze(1))
        return torch.cat([points, size], dim=-1)     

In [None]:
a = PositionalEncoder()

In [2]:
aa = 1

In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from torch_geometric.utils import to_scipy_sparse_matrix, unbatch_edge_index
from scipy.sparse.csgraph import connected_components

import math
from typing import Union, Optional
from torch_geometric.typing import (Adj, Size, OptTensor, PairTensor)

from copy import deepcopy
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
from torch_geometric.data import Batch
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.nn import GATv2Conv

class GATv2ConvMOT(MessagePassing):
    r"""The GATv2 operator from the `"How Attentive are Graph Attention Networks?"
    <https://arxiv.org/abs/2105.14491>`_ paper, which fixes the static
    attention problem of the standard :class:`~torch_geometric.conv.GATConv`
    layer: since the linear layers in the standard GAT are applied right after
    each other, the ranking of attended nodes is unconditioned on the query
    node. In contrast, in GATv2, every node can attend to any other node.

    https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GATv2Conv.html#torch_geometric.nn.conv.GATv2Conv

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        share_weights (bool, optional): If set to :obj:`True`, the same matrix
            will be applied to the source and the target node of every edge.
            (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(self, in_channels: int,
                 out_channels: int, heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.,
                 bias: bool = True, share_weights: bool = False,
                 **kwargs):
        super(GATv2ConvMOT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.share_weights = share_weights

        self.update_mlp = Linear(2*out_channels, out_channels)


        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        self.att = Parameter(torch.Tensor(1, heads, 2*out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.lin_l.weight)
        glorot(self.lin_r.weight)
        glorot(self.update_mlp.weight)
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                size: Size = None, return_attention_weights: bool = None):
        # type: (Union[Tensor, PairTensor], Tensor, Size, Tensor) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""
        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels

        x_l: OptTensor = None
        x_r: OptTensor = None
    
        assert x.dim() == 2
        x_l = self.lin_l(x).view(-1, H, C)
        if self.share_weights:
            x_r = x_l
        else:
            x_r = self.lin_r(x).view(-1, H, C)

        assert x_l is not None
        assert x_r is not None

        # propagate_type: (x: PairTensor)
        out = self.propagate(edge_index, x=(x_l, x_r), size=size)

        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, x_i: Tensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        # x = x_i + x_j
        x_i = F.leaky_relu(x_i, self.negative_slope)
        x_j = F.leaky_relu(x_j, self.negative_slope)
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def update(self,aggregate_out, x):
        x_l, x_r = x 
        return x_l + self.update_mlp(torch.cat([x_l, aggregate_out], dim=-1))

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)


class MlpBlock(nn.Module):
    """Multilayer perceptron.

    Parameters
    ----------
    dim : int
        Input and output dimension of the entire block. Inside of the mixer
        it will either be equal to `n_patches` or `hidden_dim`.

    mlp_dim : int
        Dimension of the hidden layer.

    Attributes
    ----------
    linear_1, linear_2 : nn.Linear
        Linear layers.

    activation : nn.GELU
        Activation.
    """

    def __init__(self, dim, mlp_dim=None):
        super().__init__()

        mlp_dim = dim if mlp_dim is None else mlp_dim
        self.linear_1 = nn.Linear(dim, mlp_dim)
        self.activation = nn.GELU()
        self.linear_2 = nn.Linear(mlp_dim, dim)

    def forward(self, x):
        """Run the forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape `(n_samples, n_channels, n_patches)` or
            `(n_samples, n_patches, n_channels)`.

        Returns
        -------
        torch.Tensor
            Output tensor that has exactly the same shape as the input `x`.
        """
        x = self.linear_1(x)  # (n_samples, *, mlp_dim)
        x = self.activation(x)  # (n_samples, *, mlp_dim)
        x = self.linear_2(x)  # (n_samples, *, dim)
        return x


class MixerBlock(nn.Module):
    """Mixer block that contains two `MlpBlock`s and two `LayerNorm`s.

    Parameters
    ----------
    n_patches : int
        Number of patches the image is split up into.

    hidden_dim : int
        Dimensionality of patch embeddings.

    tokens_mlp_dim : int
        Hidden dimension for the `MlpBlock` when doing token mixing.

    channels_mlp_dim : int
        Hidden dimension for the `MlpBlock` when doing channel mixing.

    Attributes
    ----------
    norm_1, norm_2 : nn.LayerNorm
        Layer normalization.

    token_mlp_block : MlpBlock
        Token mixing MLP.

    channel_mlp_block : MlpBlock
        Channel mixing MLP.
    """

    def __init__(
        self, *, n_patches, hidden_dim, tokens_mlp_dim, channels_mlp_dim
    ):
        super().__init__()

        self.norm_1 = nn.LayerNorm(hidden_dim)
        self.norm_2 = nn.LayerNorm(hidden_dim)

        self.token_mlp_block = MlpBlock(n_patches, tokens_mlp_dim)
        self.channel_mlp_block = MlpBlock(hidden_dim, channels_mlp_dim)

    def forward(self, x):
        """Run the forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape `(n_samples, n_patches, hidden_dim)`.

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as `x`, i.e.
            `(n_samples, n_patches, hidden_dim)`.
        """
        y = self.norm_1(x)  # (n_samples, n_patches, hidden_dim)
        y = y.permute(0, 2, 1)  # (n_samples, hidden_dim, n_patches)
        y = self.token_mlp_block(y)  # (n_samples, hidden_dim, n_patches)
        y = y.permute(0, 2, 1)  # (n_samples, n_patches, hidden_dim)
        x = x + y  # (n_samples, n_patches, hidden_dim)
        y = self.norm_2(x)  # (n_samples, n_patches, hidden_dim)
        res = x + self.channel_mlp_block(
            y
        )  # (n_samples, n_patches, hidden_dim)
        return res
    


def normalize_positions(bbox_start, bbox_end, image_shape = (1920, 1080)):
    """ Normalize positional features based on image_shape """
    # (left, top, W, H)

    img_shape_tensor = torch.tensor(image_shape).to(bbox_start.device)
    # center_start = center_start/img_shape_tensor
    # center_end   = center_end/img_shape_tensor
    bbox_start   = bbox_start / torch.cat([img_shape_tensor,img_shape_tensor])
    bbox_end     = bbox_end / torch.cat([img_shape_tensor,img_shape_tensor])

    # return torch.cat([center_start, center_end, bbox_start, bbox_end], axis = 1)
    return torch.cat([bbox_start, bbox_end], axis = 1)


def sigmoid_log_double_softmax(
        sim: torch.Tensor, logaritmic: bool) -> torch.Tensor:
    """ create the log assignment matrix from logits and similarity"""
    b, m, n = sim.shape
    scores = sim.new_full((b, m, n), 0)

    if logaritmic:
        scores0 = F.log_softmax(sim, 2)
        scores1 = F.log_softmax(
            sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
        scores[:, :m, :n] = (scores0 + scores1)/2
    else:
        scores0 = F.softmax(sim, 2)
        scores1 = F.softmax(
            sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
        scores[:, :m, :n] = (scores0 + scores1)
    
    return scores

class MatchAssignment(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, logaritmic=True):
        """ build assignment matrix from descriptors """
        n, d = x.shape
        x = x / d**.25
        
        sim = torch.full((n, n), torch.finfo(torch.float).min, device=x.device)
        sim[edge_index[0], edge_index[1]] = torch.einsum('md,nd->mn', x, x)[edge_index[0], edge_index[1]]

        scores = sigmoid_log_double_softmax(
                        sim.unsqueeze(0), 
                        logaritmic=logaritmic
                        )

        if logaritmic:
            return scores.exp().squeeze(), sim
        return scores.squeeze(), sim

class MLP(nn.Module):
    def __init__(self, input_dim, fc_dims, dropout_p=0., use_batchnorm=False, **kwargs):
        super(MLP, self).__init__()

        assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either a list or a tuple, but got {}'.format(
            type(fc_dims))

        layers = []
        for i, dim in enumerate(fc_dims):
            layers.append(nn.Linear(input_dim, dim, **kwargs))
            if use_batchnorm and dim != 1:
                layers.append(nn.BatchNorm1d(dim))

            if i != len(fc_dims) - 1:
                layers.append(nn.ReLU(inplace=True))

            if dropout_p is not None and dim != 1:
                layers.append(nn.Dropout(p=dropout_p))

            input_dim = dim

        self.fc_layers = nn.Sequential(*layers)

    def forward(self, input):
        return self.fc_layers(input)

class PositionalEncoder(nn.Module):
    def __init__(self):
        super(PositionalEncoder, self).__init__()

        self.point_enc = MLP(2, [16,64], dropout_p=0.1, use_batchnorm=False, bias=False)
        self.size_enc  = MLP(2, [16,64], dropout_p=0.1, use_batchnorm=False, bias=False)

    def forward(self, bbox):
        '''
            bbox: Tensor: [N, (L, T, H, W)]   
        '''
        points = bbox[:, :2]
        size = bbox[:, 2:]
        points = self.point_enc(points)
        size = self.size_enc(size)
        return torch.cat([points, size], dim=-1)        

class LightGlueMOT(nn.Module):

    def __init__(self, model_params, node_level_embed=None):
        super(LightGlueMOT, self).__init__()
    
        self.model_params = model_params
        gnn_params = model_params['gnn_params']

        self.node_enc = MLP(**model_params['vis_enc_params'])
        self.pos_enc = PositionalEncoder()
        self.time_enc = pyg_nn.TemporalEncoding(out_channels=128)

        self.mixer = nn.Sequential(
            MixerBlock(**model_params['mixer_block_params']),
            MixerBlock(**model_params['mixer_block_params']),
            MixerBlock(**model_params['mixer_block_params'])
        )

        self.pos_post = MLP(128, [32, 128], dropout_p=0.2)

        self.reattach_initial_nodes = model_params['reattach_initial_nodes']

        node_factor = 2 if self.reattach_initial_nodes else 1
        gnn_in_channels = gnn_params['in_channels'] # + 2*model_params['pos_enc_params']['D'] if self.model_params['use_pos_enc'] else gnn_params['in_channels']
        gnn_in_dim = node_factor * gnn_in_channels

        self.gnn = nn.ModuleList([
            GATv2ConvMOT(
                in_channels=gnn_in_dim,
                out_channels= gnn_params['out_channels'],
                heads=gnn_params['heads'],
                concat=gnn_params['concat'],
                negative_slope = gnn_params['negative_slope'],
                dropout = gnn_params['dropout'],
                bias = gnn_params['bias'], 
                share_weights=gnn_params['share_weights']
            )
            for _ in range(model_params['num_message_passing'])
        ])
        
        self.matcher = MatchAssignment(dim=model_params['out_node_dim'])

        if gnn_params['concat']:
            node_dim_after_mpn = gnn_params['out_channels'] * gnn_params['heads']
        else:
            node_dim_after_mpn = gnn_params['out_channels'] 

        final_proj_factor = 2 if node_level_embed else 1
        self.final_proj = nn.Linear(final_proj_factor*node_dim_after_mpn, model_params['out_node_dim'])


    def forward(self, graph, node_embdeddings = None, alpha=0.75):
        n_nodes = graph.x.shape[0]

        # 0) Obtain Data and Normalize
        x, edge_index, self_edge_index, edge_attr = graph.x, graph.edge_index, graph.self_edge_index, graph.edge_attr
        bbox_start, bbox_end, frames = graph.x_box_start, graph.x_box_end, graph.x_frame
        frame_start, frame_end = frames
        
        pos_feats = normalize_positions(bbox_start, bbox_end)
        bbox_start = pos_feats[:, :4].clone()
        bbox_end = pos_feats[:, 4:].clone()
        in_x = x
        # 1) First obtain visual and positional Encodings
        x = self.node_enc(x)

        p_s = self.pos_enc(bbox_start)
        p_e = self.pos_enc(bbox_end)
        t_start = self.time_enc(torch.tensor(frame_start, dtype=torch.float)) 
        t_end   = self.time_enc(torch.tensor(frame_end, dtype=torch.float))

        # 2) Before GNN use MLP mixed to mix the appearance with the positional & temporal encodings.
        mixed_latent = torch.stack([x, t_start, p_s, t_end, p_e], dim=1) # [num_nodes, 5, 128]
        mixed_latent = self.mixer(mixed_latent)
        mixed_latent = F.avg_pool1d(mixed_latent.permute(0, 2, 1), kernel_size=5).squeeze()
      
        # 3) Apply message passing
        # In each even iteration perform message passing on the same frame, in each odd iteration perform cross message passing
        for idx, gnn in enumerate(self.gnn):

            # Reattach the initially encoded node embeddings before the GNN
            if self.reattach_initial_nodes:
                mixed_latent = torch.cat([mixed_latent, x], dim=1)
            
            # Message passing in the same frame/block
            if idx%2 == 0 and self.model_params['use_self_edges']:
                mixed_latent = gnn(mixed_latent, self_edge_index)
                
            # Message passing in the cross frames/blocks
            else:
                mixed_latent = gnn(mixed_latent, edge_index)
        
        # 4) For the latent variable (mixed_latent) assume it has both positional and visual knowledge.

        # There are two branches one is the mathcing which is mostly appearance guided. Sinkhorn
        # The other branch try to predict the IOU, the mathcability score.
        
        positional_latent = self.pos_post(mixed_latent)
        p_latent_left = positional_latent[edge_index[0]]
        p_latent_rigt = positional_latent[edge_index[1]]

        time_diff_enc = t_start[edge_index[1]] - t_end[edge_index[1]]

        p_latent_left = p_e[edge_index[0]] + (time_diff_enc * p_latent_left)
        p_latent_rigt = p_s[edge_index[0]] + (-time_diff_enc * p_latent_rigt)

        matchability = F.cosine_similarity(p_latent_left, p_latent_rigt, dim=-1)
        # matchability = (p_latent_left * p_latent_rigt).sum(dim=-1) 
        
        # Apply final projection to nodes
        if node_embdeddings is not None:
            # If there is node embedding to be used in the level of hierarchy add them to the nodes
            hicl_node_embed = node_embdeddings.unsqueeze(0).expand(n_nodes, -1)
            mixed_latent = self.final_proj(torch.cat([mixed_latent, hicl_node_embed], dim=1))
        else:
            mixed_latent = self.final_proj(mixed_latent)
        
        # Before performing sinkhorn produce last scores.
        x = x + mixed_latent

        # batch_graphs = graph.to_data_list()
        # for i, b_graph in enumerate(batch_graphs):
        #     feats_mask = graph.batch == i
        #     feats = x[feats_mask]
        #     scores, _ = self.matcher(feats, b_graph.edge_index)
        #     b_graph.edge_preds = scores[b_graph.edge_index[0], b_graph.edge_index[1]]

        scores, _ = self.matcher(x, edge_index)
        graph.edge_preds = scores[edge_index[0], edge_index[1]]
        # graph = Batch.from_data_list(batch_graphs)
        

        graph.edge_preds = graph.edge_preds * matchability
        # Only used to conform with the hicl_tracker
        outputs_dict = {'classified_edges': [graph.edge_preds]}

        return outputs_dict



In [4]:
import yaml 

with open(r'configs/mymodel_cfg.yaml') as file:
    mymodel_params = yaml.load(file, Loader=yaml.FullLoader)

In [5]:
model = LightGlueMOT(mymodel_params['graph_model_params'])

In [10]:
model.time_enc(torch.tensor(1.3))

tensor([[0.2675, 0.4999, 0.6647, 0.7777, 0.8537, 0.9042, 0.9375, 0.9593, 0.9735,
         0.9828, 0.9888, 0.9927, 0.9953, 0.9969, 0.9980, 0.9987, 0.9992, 0.9995,
         0.9996, 0.9998, 0.9999, 0.9999, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0

In [128]:
res= model(graph[0])

  t_start = self.time_enc(torch.tensor(frame_start, dtype=torch.float))
  t_end   = self.time_enc(torch.tensor(frame_end, dtype=torch.float))


In [129]:
(res['classified_edges'][0]>0.5).sum()

tensor(7)

In [130]:
res['classified_edges'][0][:100]

tensor([0.0978, 0.0755, 0.1129, 0.0823, 0.0844, 0.1179, 0.0791, 0.0933, 0.0762,
        0.0882, 0.1185, 0.0895, 0.0864, 0.0966, 0.0813, 0.0833, 0.0844, 0.0830,
        0.1014, 0.0685, 0.0810, 0.0833, 0.0755, 0.0814, 0.0689, 0.0797, 0.0687,
        0.0949, 0.0948, 0.1131, 0.0906, 0.0992, 0.0832, 0.0759, 0.0737, 0.0867,
        0.0975, 0.0765, 0.0863, 0.0903, 0.0857, 0.1015, 0.0805, 0.0845, 0.0793,
        0.1050, 0.0880, 0.0992, 0.1301, 0.0955, 0.0814, 0.0853, 0.0829, 0.0825,
        0.0834, 0.0793, 0.0773, 0.0824, 0.0730, 0.0818, 0.0668, 0.1037, 0.0766,
        0.0706, 0.0750, 0.0851, 0.0815, 0.0826, 0.0644, 0.0634, 0.0584, 0.0846,
        0.0816, 0.0733, 0.0830, 0.0843, 0.0801, 0.0769, 0.0859, 0.0721, 0.0690,
        0.0869, 0.0819, 0.0893, 0.0761, 0.1040, 0.0852, 0.0718, 0.0683, 0.0860,
        0.0552, 0.0594, 0.0568, 0.0662, 0.0555, 0.0560, 0.0521, 0.0680, 0.0594,
        0.0570], grad_fn=<SliceBackward0>)

In [17]:
mixed

tensor([[ 9.6232e-01,  9.3537e-01,  9.2088e-01,  ...,  8.8755e-02,
         -3.4408e-02, -7.7283e-04],
        [ 9.8010e-01,  9.6580e-01,  9.5808e-01,  ...,  6.3104e-02,
         -2.4450e-02, -5.4912e-04],
        [ 9.8961e-01,  9.8212e-01,  9.7807e-01,  ...,  1.1822e-01,
         -4.5871e-02, -1.0304e-03],
        ...,
        [ 9.7074e-01,  9.4976e-01,  9.3847e-01,  ...,  3.0532e-01,
         -1.1986e-01, -2.6981e-03],
        [ 9.8086e-01,  9.6709e-01,  9.5966e-01,  ...,  6.0769e-02,
         -2.3545e-02, -5.2878e-04],
        [ 9.7823e-01,  9.6258e-01,  9.5415e-01,  ...,  3.4293e-01,
         -1.3513e-01, -3.0439e-03]], grad_fn=<ReshapeAliasBackward0>)

In [173]:
torch.std(mixed)

tensor(0.6202)

In [None]:
(graph[0].bipartite_labels%2)[:25]

In [None]:
from lapsolver import solve_dense

In [None]:
np.array([[1.,0.,0.5], [0.2,0.4,0.3], [0.,0.25,0.58]])

In [None]:
import torch 
import torch.nn as nn
from torch_geometric.utils import to_scipy_sparse_matrix, unbatch_edge_index
from scipy.sparse.csgraph import connected_components

from typing import Union, Optional
from torch_geometric.typing import (Adj, Size, OptTensor, PairTensor)

from copy import deepcopy
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
from torch_geometric.data import Batch
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.nn import GATv2Conv

class GATv2ConvMOT(MessagePassing):
    r"""The GATv2 operator from the `"How Attentive are Graph Attention Networks?"
    <https://arxiv.org/abs/2105.14491>`_ paper, which fixes the static
    attention problem of the standard :class:`~torch_geometric.conv.GATConv`
    layer: since the linear layers in the standard GAT are applied right after
    each other, the ranking of attended nodes is unconditioned on the query
    node. In contrast, in GATv2, every node can attend to any other node.

    https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GATv2Conv.html#torch_geometric.nn.conv.GATv2Conv

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        share_weights (bool, optional): If set to :obj:`True`, the same matrix
            will be applied to the source and the target node of every edge.
            (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(self, in_channels: int,
                 out_channels: int, heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.,
                 bias: bool = True, share_weights: bool = False,
                 **kwargs):
        super(GATv2ConvMOT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.share_weights = share_weights

        self.update_mlp = Linear(2*out_channels, out_channels)


        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        self.att = Parameter(torch.Tensor(1, heads, 2*out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.lin_l.weight)
        glorot(self.lin_r.weight)
        glorot(self.update_mlp.weight)
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                size: Size = None, return_attention_weights: bool = None):
        # type: (Union[Tensor, PairTensor], Tensor, Size, Tensor) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""
        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels

        x_l: OptTensor = None
        x_r: OptTensor = None
    
        assert x.dim() == 2
        x_l = self.lin_l(x).view(-1, H, C)
        if self.share_weights:
            x_r = x_l
        else:
            x_r = self.lin_r(x).view(-1, H, C)

        assert x_l is not None
        assert x_r is not None

        # propagate_type: (x: PairTensor)
        out = self.propagate(edge_index, x=(x_l, x_r), size=size)

        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, x_i: Tensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        # x = x_i + x_j
        x_i = F.leaky_relu(x_i, self.negative_slope)
        x_j = F.leaky_relu(x_j, self.negative_slope)
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def update(self,aggregate_out, x):
        x_l, x_r = x 
        return x_l + self.update_mlp(torch.cat([x_l, aggregate_out], dim=-1))

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)

class LearnableFourierFeatures(nn.Module):
    def __init__(self, 
                 M,  # M: Dimension of the Positions   / Input dim
                 D,  # D: Depth of Positional Encoding /Output dim
                 G,  # G: Number of Groups
                 F,  # F: Fourier Feature Dimension
                 H,  # H: Hidden Layer Dimension
                 gamma
                ):
        super().__init__()
        self.F = F
        self.D = D
        self.gamma=gamma

        self.Wr = nn.Linear(M, F//2, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(F, H),
            nn.GELU(),
            nn.Linear(H, D//G)
        )
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
 
    def forward(self, x):
        '''
        Produce positional encodings from x
        :param x: tensor of shape [N, G, M] that represents N positions where each position is in the shape of [G, M],
                  where G is the positional group and each group has M-dimensional positional values.
                  Positions in different positional groups are independent
        :return: positional encoding for X
        '''
        N, G, M = x.shape

        F = self.Wr(x)
        F = torch.cat([torch.cos(F), torch.sin(F)], dim=2) #/torch.sqrt(torch.tensor(self.F))      
        Y = self.mlp(F)
        PEx =  Y.reshape((N, self.D))
        return PEx 

def normalize_positions(bbox_start, bbox_end, image_shape = (1920, 1080)):
    """ Normalize positional features based on image_shape """
    # (left, top, W, H)

    img_shape_tensor = torch.tensor(image_shape).to(bbox_start.device)
    bbox_start   = bbox_start / torch.cat([img_shape_tensor,img_shape_tensor])
    bbox_end     = bbox_end / torch.cat([img_shape_tensor,img_shape_tensor])

    # return torch.cat([center_start, center_end, bbox_start, bbox_end], axis = 1)
    return torch.cat([bbox_start, bbox_end], axis = 1)

def predict_edges_with_ot(matcher, x, edge_index, node_batch_labels, alpha=1.0, iters=30):
    node_batch_labels = node_batch_labels.squeeze()
    n_edges = edge_index.shape[1]
  
    unbatched_edge_idx = unbatch_edge_index(edge_index[:,:n_edges//2], node_batch_labels.to(torch.int64))

    edge_preds = torch.zeros((n_edges,), device = x.device)
    prev_length = 0

    for i_ in range(len(unbatched_edge_idx)):
        node_mask = node_batch_labels == i_
        sub_nodes = x[node_mask]
        bipartite_edges = unbatched_edge_idx[i_]
        if bipartite_edges.nelement() == 0:
            continue
        
        left_node_ixs  = torch.unique(bipartite_edges[0, :])
        left_node_ixs  = torch.arange(0, left_node_ixs.max()+1)
        n_left_nodes = left_node_ixs.size(0)

        right_node_ixs = torch.unique(bipartite_edges[1, :])
        right_node_ixs = torch.arange(n_left_nodes, node_mask.sum())

        left_nodes  = sub_nodes[left_node_ixs] # nxd
        right_nodes = sub_nodes[right_node_ixs] # mxd

        edge_index_tuple = [tuple(bipartite_edges[:,i].tolist()) for i in range(bipartite_edges.shape[1])]
        all_edges_tuple  = [(i,n_left_nodes+j) for i in range(left_nodes.shape[0]) for j in range(right_nodes.shape[0])]
        missing_edges = list(set(all_edges_tuple) - set(edge_index_tuple))

        rows = [missing_edges[k_][0] for k_ in range(len(missing_edges))]
        cols = [missing_edges[k_][1] - n_left_nodes  for k_ in range(len(missing_edges))]
        

        soft_assign = matcher(left_nodes, right_nodes, rows, cols)
        preds_ = soft_assign[bipartite_edges[0], bipartite_edges[1] - n_left_nodes]
        edge_preds[prev_length: prev_length + bipartite_edges.shape[1]] = preds_

        prev_length += bipartite_edges.shape[1]

    edge_preds[n_edges//2:] = edge_preds[:n_edges//2]
    return edge_preds


def sigmoid_log_double_softmax(
        sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
    """ create the log assignment matrix from logits and similarity"""
    b, m, n = sim.shape

    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)

    scores0 = F.log_softmax(sim, 2)
    print("scores0: ", scores0.shape)
    scores1 = F.log_softmax(
        sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
    print("scores1: ", scores0.shape)
    scores = sim.new_full((b, m+1, n+1), 0)
    scores[:, :m, :n] = (scores0 + scores1)
    scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
    scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
    return scores, certainties


class MatchAssignment(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        self.matchability = nn.Linear(dim, 1, bias=True)
        self.final_proj = nn.Linear(dim, dim, bias=True)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        """ build assignment matrix from descriptors """
        n, d = x.shape
        x = x / d**.25
        sim = torch.full((n,n), -torch.inf, device=x.device)
        sim[edge_index[0], edge_index[1]] = torch.einsum('md,nd->mn', x, x)[edge_index[0], edge_index[1]]

        z0 = self.matchability(x.unsqueeze(0))
        scores, certainties = sigmoid_log_double_softmax(sim.unsqueeze(0), z0, z0)
        return scores.exp().squeeze(), sim

    def get_matchability(self, desc: torch.Tensor):
        return torch.sigmoid(self.matchability(desc)).squeeze(-1)

class MLP(nn.Module):
    def __init__(self, input_dim, fc_dims, dropout_p=0.4, use_batchnorm=True):
        super(MLP, self).__init__()

        assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either a list or a tuple, but got {}'.format(
            type(fc_dims))

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            if use_batchnorm and dim != 1:
                layers.append(nn.BatchNorm1d(dim))

            if dim != 1:
                layers.append(nn.ReLU(inplace=True))

            if dropout_p is not None and dim != 1:
                layers.append(nn.Dropout(p=dropout_p))

            input_dim = dim

        self.fc_layers = nn.Sequential(*layers)

    def forward(self, input):
        return self.fc_layers(input)

class LightGlueMOT(nn.Module):

    def __init__(self, model_params, node_level_embed=None):
        super(LightGlueMOT, self).__init__()
    
        self.model_params = model_params
        gnn_params = model_params['gnn_params']

        self.node_enc = MLP(**model_params['vis_enc_params'])
        self.pos_enc  = LearnableFourierFeatures(**model_params['pos_enc_params'])


        self.reattach_initial_nodes = model_params['reattach_initial_nodes']

        node_factor = 2 if self.reattach_initial_nodes else 1
        gnn_in_dim = node_factor * gnn_params['in_channels']

        self.gnn = nn.ModuleList([
            GATv2ConvMOT(
                in_channels=gnn_in_dim,
                out_channels= gnn_params['out_channels'],
                heads=gnn_params['heads'],
                concat=gnn_params['concat'],
                negative_slope = gnn_params['negative_slope'],
                dropout = gnn_params['dropout'],
                bias = gnn_params['bias'], 
                share_weights=gnn_params['share_weights']
            )
            for _ in range(model_params['num_message_passing'])
        ])
        
        self.matcher = MatchAssignment(dim=model_params['out_node_dim'])

        if gnn_params['concat']:
            node_dim_after_mpn = gnn_params['out_channels'] * gnn_params['heads']
        else:
            node_dim_after_mpn = gnn_params['out_channels'] 

        final_proj_factor = 2 if node_level_embed else 1
        self.final_proj = nn.Linear(final_proj_factor*node_dim_after_mpn, model_params['out_node_dim'])



    def forward(self, graph, node_embdeddings = None):
        n_nodes = graph.x.shape[0]
        n_edges = graph.edge_index.shape[1]

        x, edge_index, self_edge_index, edge_attr = graph.x, graph.edge_index, graph.self_edge_index, graph.edge_attr
        bbox_start, bbox_end, frames = graph.x_box_start, graph.x_box_end, graph.x_frame
        frame_start, frame_end = frames
        
        # Normalize positional features of detections
        pos_feats = normalize_positions(bbox_start, bbox_end)
        bbox_start = pos_feats[:, :4].clone()
        bbox_end = pos_feats[:, 4:].clone()

        
        # Encode positional and visual features and sum them.
        x = self.node_enc(x)

        if self.model_params['use_pos_enc']:
            M_ = self.model_params['pos_enc_params']['M']
            p_s = self.pos_enc(bbox_start.view(bbox_start.size(0),-1,M_ ))
            p_e = self.pos_enc(bbox_end.view(bbox_end.size(0),-1,M_ ))

            positional_enc = torch.cat([p_s, p_e], dim=-1)
            # scaling_alpha = torch.mean(x)/torch.mean(positional_enc)
            x = x + 2 * positional_enc

        initial_node_feats = x
    
        # In each even iteration perform message passing on the same frame, in each odd iteration perform cross message passing
        for idx, gnn in enumerate(self.gnn):

            # Reattach the initially encoded node embeddings before the GNN
            if self.reattach_initial_nodes:
                x = torch.cat([x, initial_node_feats], dim=1)
            
            # Message passing in the same frame/block
            if idx%2 == 0 and self.model_params['use_self_edges']:
                x = gnn(x, self_edge_index)
                
            # Message passing in the cross frames/blocks
            else:
                x = gnn(x, edge_index)
        
        # Apply final projection to nodes
        if node_embdeddings is not None:
            # If there is node embedding to be used in the level of hierarchy add them to the nodes
            hicl_node_embed = node_embdeddings.unsqueeze(0).expand(n_nodes, -1)
            x = self.final_proj(torch.cat([x, hicl_node_embed], dim=1))
        else:
            x = self.final_proj(x)

        try:
            batch_graphs = graph.to_data_list()
            for i, b_graph in enumerate(batch_graphs):
                feats_mask = graph.batch == i
                feats = x[feats_mask]
                scores, sim = self.matcher(feats, b_graph.edge_index)
                b_graph.edge_preds = scores[b_graph.edge_index[0], b_graph.edge_index[1]]
            
            graph = Batch.from_data_list(batch_graphs)
        except:
            scores, sim = self.matcher(x, edge_index)
            graph.edge_preds = scores[edge_index[0], edge_index[1]]
        # Only used to conform with the hicl_tracker
        outputs_dict = {'classified_edges': [graph.edge_preds]}

        return outputs_dict



In [None]:
config = {

    'num_message_passing': 6,
    'out_node_dim' : 128,

    'reattach_initial_nodes': True,
    'use_pos_enc': True,
    'use_self_edges': True,

    'vis_enc_params': {
        'input_dim': 2048,
        'fc_dims': [256, 128]
        },

    # pos_enc_params: 
    #     M: 1      # Input Dim [bbox_start; bbox_end] Each of the features will be considered in differeng group
    #     D: 128    # Output Dim must equal to dim of visual features !! D%G = 0 !! 
    #     G: 8      # Number of Groups
    #     F: 128    # Fourier Feature Dim
    #     H: 32     # Hidden Dim 
    #     gamma: 4  # Initializer's std

    # pos_enc_params: 
    #     M: 4      # Input Dim [bbox_start; bbox_end] Each of the features will be considered in differeng group
    #     D: 128    # Output Dim must equal to dim of visual features !! D%G = 0 !! 
    #     G: 2      # Number of Groups
    #     F: 128    # Fourier Feature Dim
    #     H: 32     # Hidden Dim 
    #     gamma: 0.01  # Initializer's std

    'pos_enc_params': 
        {'M': 1 ,     # Input Dim [bbox_start; bbox_end] Each of the features will be considered in differeng group
        'D': 64 ,   # Output Dim must equal to dim of visual features !! D%G = 0 !! 
        'G': 4  ,    # Number of Groups
        'F': 128 ,   # Fourier Feature Dim
        'H': 32 ,    # Hidden Dim 
        'gamma': 1},  # Initializer's std

    'gnn_params':{
        'in_channels': 128,
        'out_channels': 16,
        'heads': 8,
        'concat': True,
        'negative_slope': 0.2,
        'dropout': 0,
        'bias': True,
        'share_weights': False}

}


In [None]:
torch.finfo(torch.float).min

In [None]:
a = torch.tensor([0.2,0.8,torch.finfo(torch.float).min])
res = F.log_softmax(a, dim=0)
res

In [None]:
res.exp()

In [None]:
model= LightGlueMOT(config)

In [None]:
scores  = model(graph)

In [None]:
preds = scores['classified_edges'][0]

In [None]:
preds.max()

In [None]:
torch.isnan(preds).sum()

In [None]:
preds.shape

In [None]:
scores = scores.exp()

In [None]:
scores

In [None]:
import matplotlib.pyplot as plt 

plt.imshow(scores.squeeze().detach().numpy()[825:850, 820:850], cmap='viridis')
plt.colorbar()

In [None]:
import matplotlib.pyplot as plt 
import numpy as np
gt = np.zeros((21,21))
ee = graph[0].edge_index[:, :graph[0].edge_index.shape[1]//2]
left = ee[0, (ee<20)[0,:]]
r = ee[1, (ee<20)[1,:]]
gt[left, r]
plt.imshow(gt, cmap='viridis')
plt.colorbar()

In [None]:
ee[1,:]

In [None]:
gt.shape

In [None]:
ee[:, (ee<20)[0,:]]

In [None]:
ee

In [None]:
plt.plot(sim[:50, :50].flatten().detach().numpy())

In [None]:
res['classified_edges'][0].exp()

In [None]:
sim = torch.tensor([[1,0,0],[0,1,0],[0,0,1]], dtype=torch.float)
sim

In [None]:
F.log_softmax(torch.tensor([-torch.inf,1.]),dim=0)

In [None]:
sim[sim==0] = -torch.inf
sim

In [None]:
sigmoid_log_double_softmax(sim).exp()

In [None]:
res[0]/res[2]

In [None]:
model.pos_enc.Wr.weight

In [None]:
x, edge_index, x_box_start, x_box_end = batch.x, batch.edge_index, batch.x_box_start, batch.x_box_end

In [None]:
time_enc = PositionalEncoding(128, 0., 128)


In [None]:
t1_enc = time_enc(batch.x_frame[0])
t2_enc = time_enc(batch.x_frame[1])

In [None]:
time_diff = t2_enc[edge_index[1]] - t1_enc[edge_index[0]]

bbox_starts = x_box_start[edge_index[1]]
bbox_ends = x_box_end[edge_index[0]]

In [None]:
res = model(time_diff, bbox_starts, bbox_ends)

In [None]:
plt.plot(sorted(res.detach().squeeze().numpy()))

In [None]:
model = simple()

In [None]:
batch = graph[0]

In [None]:
batch.x_frame[0].shape

In [None]:
start = batch.x_frame[0]
end = batch.x_frame[0]

In [None]:
start = start - start.min()

In [None]:
end = end - end.min()

In [None]:
import torch.nn as nn 

class LearnableFourierFeatures(nn.Module):
    def __init__(self, 
                 M,  # M: Dimension of the Positions   / Input dim
                 D,  # D: Depth of Positional Encoding /Output dim
                 G,  # G: Number of Groups
                 F,  # F: Fourier Feature Dimension
                 H,  # H: Hidden Layer Dimension
                 gamma
                ):
        super().__init__()
        self.F = F
        self.D = D
        self.gamma=gamma

        self.Wr = nn.Linear(M, F//2, bias=False)

        self.mlp = nn.Sequential(
            nn.Linear(F, H),
            nn.GELU(),
            nn.Linear(H, D//G)
        )
        
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
    
    def forward(self, x):
        '''
        Produce positional encodings from x
        :param x: tensor of shape [N, G, M] that represents N positions where each position is in the shape of [G, M],
                  where G is the positional group and each group has M-dimensional positional values.
                  Positions in different positional groups are independent
        :return: positional encoding for X
        '''
        N, G, M = x.shape

        F = self.Wr(x)
        F = torch.cat([torch.cos(F), torch.sin(F)], dim=2)/torch.sqrt(torch.tensor(self.F))      
        Y = self.mlp(F)
        PEx =  Y.reshape((N, self.D))
        return PEx 

In [None]:
batch.x_box_start

In [None]:
batch.x_box_end

In [None]:
start

In [None]:
end

In [None]:
pos_enc_params={ 
    'M': 1    ,  
    'D': 128   ,
    'G': 4      ,
    'F': 128    ,
    'H': 32     ,
    'gamma': 0.01,  
}

In [None]:
bbox_enc1 = LearnableFourierFeatures(**pos_enc_params)
bbox_enc2 = LearnableFourierFeatures(**pos_enc_params)
time_enc1 = LearnableFourierFeatures(**pos_enc_params)
time_enc2 = LearnableFourierFeatures(**pos_enc_params)

In [None]:
bb1 = bbox_enc1(batch.x_box_start.unsqueeze(-1))
bb2 = bbox_enc1(batch.x_box_start.unsqueeze(-1))
t1  = time_enc1(batch.x_box_start.unsqueeze(-1))
t2  = time_enc1(batch.x_box_start.unsqueeze(-1))

In [None]:
import torch.nn as nn 

class simple_m(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(384, 128),
            nn.GELU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.mlp(x)

In [None]:
x = torch.cat([t1-t2, bb1, bb2], dim=1)
x.shape

In [None]:
model = simple_m()
model(x)