In [5]:
import torch
from torch.nn import functional as F

In [6]:
# Toy input: a sequence of L positions, each carrying a d-dimensional feature vector.
L, d = 5, 8
z = torch.randn(L, d)
z.shape

torch.Size([5, 8])

In [12]:
def LT(x, target_length, sigma=1.0):
    """Resample a sequence to ``target_length`` positions using Gaussian position weights.

    Output row ``j`` is a softmax-weighted average of the input rows ``x[k]``,
    with logits ``-(k - j)**2 / (2 * sigma**2)``. The weights depend only on
    the positions ``j`` and ``k``, not on the feature values of ``x``.

    Args:
        x: tensor whose first dimension is the sequence length, e.g.
            ``(length, dim)``. Any trailing dimensions are carried through
            unchanged.
        target_length: number of output positions.
        sigma: width of the Gaussian over position offsets. It must be non-zero.

    Returns:
        Tensor of shape ``(target_length, *x.shape[1:])`` on ``x``'s device.
        Non-floating input is promoted to the default float dtype.
    """
    if not x.is_floating_point():
        x = x.to(torch.get_default_dtype())
    length = x.size(0)

    # Build the logits in at least float32 so position indices stay exact even
    # for half-precision input. Create them on x's device so CUDA input works.
    calc_dtype = torch.promote_types(x.dtype, torch.float32)
    tgt_pos = torch.arange(target_length, device=x.device, dtype=calc_dtype).unsqueeze(1)
    src_pos = torch.arange(length, device=x.device, dtype=calc_dtype).unsqueeze(0)

    # Every output row gets its own logits. The previous version left rows
    # j >= length as uninitialised torch.empty memory, and raised IndexError
    # when target_length < length.
    # NOTE(review): target positions are not rescaled by length / target_length.
    # Output rows beyond the input length therefore all concentrate on the last
    # input rows; confirm this is intended.
    logits = -((src_pos - tgt_pos) ** 2) / (2 * sigma**2)  # (target_length, length)
    weights = F.softmax(logits, dim=1).to(x.dtype)

    # Contract the length axis: (target_length, length) x (length, ...) -> (target_length, ...)
    return torch.tensordot(weights, x, dims=1)

# Demo: resample a random length-5 sequence of 8-dim vectors to 7 positions.
torch.manual_seed(0)  # reproducible demo input
L = 5
d = 8
z = torch.randn((L, d))

z_bar = LT(z, target_length=7)
assert z_bar.shape == (7, d)
z_bar.shape

torch.Size([7, 8])

In [3]:
import torch
import torch.nn.functional as F

def LT(x, target_length, sigma=1.0):
    """Resample a sequence to ``target_length`` positions using Gaussian position weights.

    Output position ``t`` is a softmax-weighted average over input positions
    ``k``, with logits ``-(k - t)**2 / (2 * sigma**2)``. The same weights are
    shared by every batch element and feature, because they depend only on
    positions.

    Args:
        x: tensor whose first dimension is the sequence length. The primary
            layout is ``(length, batch, dim)``, but any trailing dimensions
            (including a plain ``(length, dim)``) are carried through unchanged.
        target_length: number of output positions.
        sigma: width of the Gaussian over position offsets. It must be non-zero.

    Returns:
        Tensor of shape ``(target_length, *x.shape[1:])``, e.g.
        ``(target_length, batch, dim)``, on ``x``'s device.
    """
    if not x.is_floating_point():
        x = x.to(torch.get_default_dtype())
    length = x.size(0)

    # Index tensors live on x's device so CUDA input does not hit a device mismatch.
    # Logits are built in at least float32 so position indices stay exact.
    calc_dtype = torch.promote_types(x.dtype, torch.float32)
    tgt_pos = torch.arange(target_length, device=x.device, dtype=calc_dtype).unsqueeze(1)
    src_pos = torch.arange(length, device=x.device, dtype=calc_dtype).unsqueeze(0)
    logits = -((src_pos - tgt_pos) ** 2) / (2 * sigma**2)  # (target_length, length)

    # Softmax over source positions, cast to x's dtype so float64/half input works.
    weights = F.softmax(logits, dim=1).to(x.dtype)

    # A single contraction over the length axis replaces the old per-batch
    # repeat + permute + bmm, which materialised a (target_length, length,
    # batch) copy of identical weights.
    return torch.tensordot(weights, x, dims=1)

# Example usage
# Example usage: resample a (length, batch, dim) batch to a new length.
L = 35
d = 512
b = 128  # batch size
target_length = 10  # number of output positions

torch.manual_seed(0)  # reproducible demo input
# Random tensor with (length, batch, dim)
x = torch.randn((L, b, d))

# Call the function
z_bar = LT(x, target_length=target_length)
assert z_bar.shape == (target_length, b, d)
z_bar.shape  # Expected shape: (10, b, d)

torch.Size([10, 128, 512])
