In [1]:
import torch
import torch.nn as nn

In [2]:
def splice_future_indices(target_tokens, padding_token, mult_factor, max_iter):
    """Build look-ahead index windows of geometrically growing width.

    For every position ``i`` of every sequence, a window contains the tokens
    that follow it, ``target_tokens[:, i+1 : i+1+length]``, right-padded with
    ``padding_token`` wherever the window runs past the end of the sequence.
    The first matrix uses ``length = mult_factor``; each further matrix
    multiplies the width by ``mult_factor`` again. Matrices are produced while
    ``1 + (sum of all widths so far)`` is at most ``max_seq_len`` and fewer
    than ``max_iter`` matrices have been built.

    Args:
        target_tokens: 2-D integer tensor of shape (batch_size, max_seq_len).
        padding_token: Integer id written into out-of-range window slots.
        mult_factor: Width of the first window and growth factor of the
            following ones.
        max_iter: Upper bound on the number of matrices returned.

    Returns:
        A list of tensors. The k-th entry (0-based) has shape
        (batch_size, max_seq_len, mult_factor ** (k + 1)) and the same dtype
        and device as ``target_tokens``. The list is empty when even the
        first window does not fit or when ``max_iter`` is not positive.

    Raises:
        ValueError: If ``target_tokens`` is not 2-D (the size unpacking fails).
    """
    batch_size, max_seq_len = target_tokens.size()
    # Position i never looks at tokens 0..i, so the first token can be dropped once.
    future_tokens = target_tokens[:, 1:]
    matrices = []

    length = mult_factor
    tot_length = 1 + length
    j = 0

    while (tot_length <= max_seq_len) and (j < max_iter):
        # new_full inherits dtype and device from the input, so the concatenation
        # below also works for CUDA inputs and for non-int64 index tensors.
        padding = target_tokens.new_full((batch_size, length), padding_token)
        padded = torch.cat([future_tokens, padding], dim=1)

        # padded has max_seq_len - 1 + length columns, so sliding a window of
        # `length` with step 1 yields exactly max_seq_len windows per sequence.
        # unfold returns overlapping views; contiguous() materialises them so
        # callers get an independent tensor.
        matrices.append(padded.unfold(1, length, 1).contiguous())

        length *= mult_factor
        tot_length += length
        j += 1

    return matrices

In [3]:
# Example usage: a tiny batch of consecutive ids so every window can be checked by eye
batch_size = 2
max_seq_len = 15
mult_factor = 2
max_iter = 8
padding_token = 0  # integer id used to fill window slots past the end of a sequence

# Ids 2..31 laid out row by row: row 0 holds 2..16, row 1 holds 17..31
target_tokens = torch.arange(2, 2 + batch_size * max_seq_len).reshape(batch_size, max_seq_len)

matrices = splice_future_indices(target_tokens, padding_token, mult_factor, max_iter)

# Dump every window matrix in full to verify the slicing and padding
for idx, matrix in enumerate(matrices):
    print(f"Matrix {idx+1}:\n{matrix}\n")

Matrix 1:
tensor([[[ 3,  4],
         [ 4,  5],
         [ 5,  6],
         [ 6,  7],
         [ 7,  8],
         [ 8,  9],
         [ 9, 10],
         [10, 11],
         [11, 12],
         [12, 13],
         [13, 14],
         [14, 15],
         [15, 16],
         [16,  0],
         [ 0,  0]],

        [[18, 19],
         [19, 20],
         [20, 21],
         [21, 22],
         [22, 23],
         [23, 24],
         [24, 25],
         [25, 26],
         [26, 27],
         [27, 28],
         [28, 29],
         [29, 30],
         [30, 31],
         [31,  0],
         [ 0,  0]]])

Matrix 2:
tensor([[[ 3,  4,  5,  6],
         [ 4,  5,  6,  7],
         [ 5,  6,  7,  8],
         [ 6,  7,  8,  9],
         [ 7,  8,  9, 10],
         [ 8,  9, 10, 11],
         [ 9, 10, 11, 12],
         [10, 11, 12, 13],
         [11, 12, 13, 14],
         [12, 13, 14, 15],
         [13, 14, 15, 16],
         [14, 15, 16,  0],
         [15, 16,  0,  0],
         [16,  0,  0,  0],
         [ 0,  0,  0,  0]],

In [4]:
# Example usage at a realistic size: a random batch of token ids
batch_size = 32
max_seq_len = 128
mult_factor = 2
max_iter = 8
vocab_size = 8192
embedding_dim = 192

# Fixed seed so the random batch (and anything initialised after it) is the same
# on every Restart & Run All.
torch.manual_seed(0)

# padding_token (defined with the small example above) must never appear as a real
# token, otherwise real positions would later be embedded as padding. Sample from a
# range one id smaller than the vocabulary and shift every id at or above
# padding_token up by one, which skips exactly that id.
sampled_tokens = torch.randint(vocab_size - 1, (batch_size, max_seq_len))
target_tokens = sampled_tokens + (sampled_tokens >= padding_token)
assert not (target_tokens == padding_token).any(), "random batch must not contain padding_token"

matrices = splice_future_indices(target_tokens, padding_token, mult_factor, max_iter)

# Only the shapes are of interest here; the tensors are too large to print
for idx, matrix in enumerate(matrices):
    print(f"Matrix {idx+1}: {matrix.shape}")

Matrix 1: torch.Size([32, 128, 2])
Matrix 2: torch.Size([32, 128, 4])
Matrix 3: torch.Size([32, 128, 8])
Matrix 4: torch.Size([32, 128, 16])
Matrix 5: torch.Size([32, 128, 32])
Matrix 6: torch.Size([32, 128, 64])


In [5]:
# One embedding table shared by all window widths; padding_token is registered as padding_idx
embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_token)

# Look up every index matrix:
# (batch_size, max_seq_len, subseq_length) -> (batch_size, max_seq_len, subseq_length, embedding_dim)
embedded_matrices = [embedding(matrix) for matrix in matrices]

# Report the resulting shapes as a sanity check
for idx, embedded_matrix in enumerate(embedded_matrices):
    print(f"Embedded Matrix {idx+1} shape: {embedded_matrix.shape}")

Embedded Matrix 1 shape: torch.Size([32, 128, 2, 192])
Embedded Matrix 2 shape: torch.Size([32, 128, 4, 192])
Embedded Matrix 3 shape: torch.Size([32, 128, 8, 192])
Embedded Matrix 4 shape: torch.Size([32, 128, 16, 192])
Embedded Matrix 5 shape: torch.Size([32, 128, 32, 192])
Embedded Matrix 6 shape: torch.Size([32, 128, 64, 192])


In [6]:
from modules.pooling import *
from modules.norm import Norm

In [7]:
# MaxPooling followed by Norm, applied to every window width
module = MaxPooling()
norm = Norm(dim=embedding_dim)

# Logging is off by default; switch it on so each call prints its input/output shapes
module.enable_logging()
### Optionally disabling printing for sub-functions
#module.disable_function_logging('')

# Pool and normalise each embedded matrix, then stack the results along a new axis 2
# so the pooled vectors of all widths sit side by side.
c = torch.stack(
    [norm(module(embedded_matrix)) for embedded_matrix in embedded_matrices],
    dim=2,
)

print(c.shape)
# Release the demo objects before the next pooling variant
del module, norm, c


Inputs:
Tensor 'x' shape: torch.Size([32, 128, 2, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 8, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 192])
torch.Size([32, 128, 6, 192])


In [8]:
# SumPooling followed by a non-affine Norm, applied to every window width
module = SumPooling()
# NOTE(review): constructing Norm with affine=False prints
# "cannot have both affine==False and bias==True. Skipping bias" (see the cell output).
# Presumably Norm also takes a bias flag that could be set to False to silence it;
# confirm against modules.norm before changing the call.
norm = Norm(dim=embedding_dim, affine=False)

# Pool and normalise each embedded matrix, then stack the results along a new axis 2
c = torch.stack(
    [norm(module(embedded_matrix)) for embedded_matrix in embedded_matrices],
    dim=2,
)

print(c.shape)
# Release the demo objects before the next pooling variant
del module, norm, c

cannot have both affine==False and bias==True. Skipping bias
torch.Size([32, 128, 6, 192])


In [9]:
# A single ParametricSumPooling instance is reused for every window width
module = ParametricSumPooling(embedding_dim, output_seq_len=1, use_output_layer=True)

# Inspect the module: parameter count (in thousands) and layer structure
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Logging is off by default; switch it on so each call prints its input/output shapes
module.enable_logging()
### Optionally disabling printing for sub-functions
#module.disable_function_logging('')

# Each call already returns a tensor with a pooled axis at position 2 (see the logged
# output shapes), so the per-width results are joined along that existing axis.
pooled_per_width = [module(embedded_matrix) for embedded_matrix in embedded_matrices]
c = torch.cat(pooled_per_width, dim=2)
del pooled_per_width

print(c.shape)
# Release the demo objects before the next pooling variant
del module, c

74.112 K parameters
ParametricSumPooling(
  (linears): ModuleList(
    (0): Linear(in_features=192, out_features=192, bias=True)
  )
  (out): Linear(in_features=192, out_features=192, bias=True)
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 2, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 8, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])
torch.Size([32, 128, 6, 192])


In [10]:
# One ParametricMaxPooling module per window width; module i emits i + 1 pooled vectors
modules = [
    ParametricMaxPooling(embedding_dim, output_seq_len=i+1, use_output_layer=True)
    for i in range(len(embedded_matrices))
]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Enable logging for each module so every call prints its input/output shapes
for module in modules:
    module.enable_logging()

# Run each module on the matrix of its own width and join the outputs along axis 2
c = torch.concat(
    [module(embedded_matrix) for module, embedded_matrix in zip(modules, embedded_matrices)],
    dim=2
)

print(c.shape)
# Drop the list as well as the loop variable: deleting only `module` would leave
# every pooling module (and its parameters) alive through `modules`.
del modules, module, c

Module 0: 74.112 K parameters
Module 1: 111.168 K parameters
Module 2: 148.224 K parameters
Module 3: 185.28 K parameters
Module 4: 222.336 K parameters
Module 5: 259.392 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 2, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 2, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 8, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 3, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 4, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 5, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 6, 192])
torch.Size([32, 128, 21, 192])


In [11]:
# One FlattenProjectionPooling module per window width. Each module is given the size of
# axis 2 of the matrix it will pool, and module i emits 2*i + 1 vectors (1, 3, 5, ...).
modules = [
    FlattenProjectionPooling(
        to_be_pooled_seq_len=embedded_matrix.shape[2],
        embed_dim=embedding_dim,
        output_seq_len=2*i+1,
        use_output_layer=False,
    )
    for i, embedded_matrix in enumerate(embedded_matrices)
]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Enable logging for each module so every call prints its input/output shapes
for module in modules:
    module.enable_logging()

# Run each module on the matrix of its own width and join the outputs along axis 2
c = torch.concat(
    [module(embedded_matrix) for module, embedded_matrix in zip(modules, embedded_matrices)],
    dim=2
)

print(c.shape)
# Drop the list as well as the loop variable: deleting only `module` would leave
# every pooling module (tens of millions of parameters here) alive through `modules`.
del modules, module, c

Module 0: 73.92 K parameters
Module 1: 442.944 K parameters
Module 2: 1475.52 K parameters
Module 3: 4130.112 K parameters
Module 4: 10618.56 K parameters
Module 5: 25954.368 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 2, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 3, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 8, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 5, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 7, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 9, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 11, 192])
torch.Size([32, 128, 36, 192])


In [12]:
import math

# One ConvPooling module per window width. Each module is given the size of axis 2 of
# the matrix it will pool, and module i emits floor(sqrt(i + 1)) vectors (1, 1, 1, 2, ...).
modules = [
    ConvPooling(
        to_be_pooled_seq_len=embedded_matrix.shape[2],
        embed_dim=embedding_dim,
        output_seq_len=math.floor(math.sqrt(i + 1)),
        use_output_layer=False,
    )
    for i, embedded_matrix in enumerate(embedded_matrices)
]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Enable logging for each module so every call prints its input/output shapes
for module in modules:
    module.enable_logging()

# Run each module on the matrix of its own width and join the outputs along axis 2
c = torch.concat(
    [module(embedded_matrix) for module, embedded_matrix in zip(modules, embedded_matrices)],
    dim=2
)

print(c.shape)
# Drop the list as well as the loop variable: deleting only `module` would leave
# every pooling module (and its parameters) alive through `modules`.
del modules, module, c

Module 0: 73.92 K parameters
Module 1: 147.648 K parameters
Module 2: 295.104 K parameters
Module 3: 1180.032 K parameters
Module 4: 2359.68 K parameters
Module 5: 4718.976 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 2, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 8, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 2, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 2, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 2, 192])
torch.Size([32, 128, 9, 192])


In [13]:
# One AttentionPooling module per window width; module i emits i**2 + 1 pooled vectors
# (1, 2, 5, 10, ...).
modules = [
    AttentionPooling(
        embed_dim=embedding_dim,
        output_seq_len=i**2+1,
        use_output_layer=False,
    )
    for i in range(len(embedded_matrices))
]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Enable logging for each module so every call prints its input/output shapes
for module in modules:
    module.enable_logging()

# Run each module on the matrix of its own width and join the outputs along axis 2
c = torch.concat(
    [module(embedded_matrix) for module, embedded_matrix in zip(modules, embedded_matrices)],
    dim=2
)

print(c.shape)
# Drop the list as well as the loop variable: deleting only `module` would leave
# every pooling module (and its parameters) alive through `modules`.
del modules, module, c

Module 0: 0.193 K parameters
Module 1: 0.386 K parameters
Module 2: 0.965 K parameters
Module 3: 1.93 K parameters
Module 4: 3.281 K parameters
Module 5: 5.018 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 2, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 2, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 8, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 5, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 10, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 17, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 26, 192])
torch.Size([32, 128, 61, 192])


In [14]:
# One SelfAttentionPooling module per window width; every module uses output_seq_len = 1,
# so each contributes a single pooled vector.
modules = [
    SelfAttentionPooling(embed_dim=embedding_dim, output_seq_len=1)
    for _ in embedded_matrices
]

# Report the parameter count (in thousands) of every module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Turn on logging so every call prints its input/output shapes
for module in modules:
    module.enable_logging()

# Run each module on the matrix of its own width and join the outputs along axis 2
c = torch.cat(
    [pool(embedded) for pool, embedded in zip(modules, embedded_matrices)],
    dim=2,
)

print(c.shape)
# NOTE(review): this removes only the loop variable and the result; the `modules`
# list itself stays alive. Confirm whether a later cell still needs it.
del module, c

Module 0: 111.36 K parameters
Module 1: 111.36 K parameters
Module 2: 111.36 K parameters
Module 3: 111.36 K parameters
Module 4: 111.36 K parameters
Module 5: 111.36 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 2, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 8, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 32, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 192])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 192])
torch.Size([32, 128, 6, 192])
