In [1]:
import torch
import torch.nn as nn

In [2]:
def splice_future_indices(target_tokens, padding_token, mult_factor, max_iter):
    """Build sliding windows of *future* tokens at geometrically growing lengths.

    For a window length ``L`` the produced tensor ``m`` satisfies
    ``m[b, i] == target_tokens[b, i+1 : i+1+L]``, right-padded with
    ``padding_token`` wherever the slice runs past the end of the sequence.
    Window lengths are ``mult_factor, mult_factor**2, ...``. Generation stops
    after ``max_iter`` matrices, or once ``1 + sum(lengths so far)`` would
    exceed the sequence length.

    Args:
        target_tokens: 2-D tensor of shape ``(batch_size, max_seq_len)``.
        padding_token: integer id written into positions past the sequence end.
        mult_factor: integer growth factor for the window length; must be >= 1.
        max_iter: maximum number of window matrices to produce.

    Returns:
        List of tensors. The k-th (0-based) has shape
        ``(batch_size, max_seq_len, mult_factor ** (k + 1))`` and lives on the
        same device, with the same dtype, as ``target_tokens``.

    Raises:
        ValueError: if ``mult_factor < 1``.
    """
    if mult_factor < 1:
        raise ValueError(f"mult_factor must be >= 1, got {mult_factor}")

    batch_size, max_seq_len = target_tokens.size()
    matrices = []

    length = mult_factor
    tot_length = 1 + length
    j = 0

    while tot_length <= max_seq_len and j < max_iter:
        # Append `length` pad tokens so that every window starting at i+1 is full-length.
        # Build them on target_tokens' device/dtype; a CPU/long tensor would make
        # torch.cat fail when the tokens live on an accelerator.
        padding = torch.full(
            (batch_size, length),
            padding_token,
            dtype=target_tokens.dtype,
            device=target_tokens.device,
        )
        padded = torch.cat([target_tokens, padding], dim=1)
        # Skip position 0 (windows look strictly ahead), then take one window of
        # size `length` at every step. That yields exactly max_seq_len windows.
        windows = padded[:, 1:].unfold(1, length, 1)
        # unfold returns overlapping views; copy so each matrix owns its storage.
        matrices.append(windows.contiguous())

        length *= mult_factor
        tot_length += length
        j += 1

    return matrices

In [3]:
# Example usage: a small worked example with two sequences of 15 consecutive token ids
batch_size = 2
max_seq_len = 15
mult_factor = 2
max_iter = 8
# Must be an integer token id: it is written into an integer tensor via torch.full
# and reused as nn.Embedding's padding_idx. The example ids below start at 2, so
# they never collide with it.
padding_token = 0
target_tokens = torch.tensor([
    [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    [17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
])
# splice_future_indices reads the shape from the tensor itself, so make sure the
# descriptive constants above actually agree with the data.
assert target_tokens.shape == (batch_size, max_seq_len)

matrices = splice_future_indices(target_tokens, padding_token, mult_factor, max_iter)

# Print the matrices to verify
for idx, matrix in enumerate(matrices):
    print(f"Matrix {idx+1}:\n{matrix}\n")

Matrix 1:
tensor([[[ 3,  4],
         [ 4,  5],
         [ 5,  6],
         [ 6,  7],
         [ 7,  8],
         [ 8,  9],
         [ 9, 10],
         [10, 11],
         [11, 12],
         [12, 13],
         [13, 14],
         [14, 15],
         [15, 16],
         [16,  0],
         [ 0,  0]],

        [[18, 19],
         [19, 20],
         [20, 21],
         [21, 22],
         [22, 23],
         [23, 24],
         [24, 25],
         [25, 26],
         [26, 27],
         [27, 28],
         [28, 29],
         [29, 30],
         [30, 31],
         [31,  0],
         [ 0,  0]]])

Matrix 2:
tensor([[[ 3,  4,  5,  6],
         [ 4,  5,  6,  7],
         [ 5,  6,  7,  8],
         [ 6,  7,  8,  9],
         [ 7,  8,  9, 10],
         [ 8,  9, 10, 11],
         [ 9, 10, 11, 12],
         [10, 11, 12, 13],
         [11, 12, 13, 14],
         [12, 13, 14, 15],
         [13, 14, 15, 16],
         [14, 15, 16,  0],
         [15, 16,  0,  0],
         [16,  0,  0,  0],
         [ 0,  0,  0,  0]],

In [4]:
# Example usage at realistic scale: random token ids
batch_size = 32
max_seq_len = 512
mult_factor = 2
max_iter = 6
vocab_size = 8192
embedding_dim = 128

# Seed so the random tokens (and subsequent random draws) are reproducible on Run All.
torch.manual_seed(0)
# Sample ids strictly above padding_token so no real token is mistaken for padding:
# nn.Embedding(padding_idx=padding_token) in the next cell treats that id as padding
# (zero-initialised vector that receives no gradient).
target_tokens = torch.randint(padding_token + 1, vocab_size, (batch_size, max_seq_len))
assert not (target_tokens == padding_token).any()

matrices = splice_future_indices(target_tokens, padding_token, mult_factor, max_iter)

# Print the matrices to verify
for idx, matrix in enumerate(matrices):
    print(f"Matrix {idx+1}: {matrix.shape}")

Matrix 1: torch.Size([32, 512, 2])
Matrix 2: torch.Size([32, 512, 4])
Matrix 3: torch.Size([32, 512, 8])
Matrix 4: torch.Size([32, 512, 16])
Matrix 5: torch.Size([32, 512, 32])
Matrix 6: torch.Size([32, 512, 64])


In [5]:
# A single shared embedding table. Rows for padding_token are zero-initialised and
# receive no gradient (nn.Embedding padding_idx semantics).
embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_token)

# Embed every window matrix: (batch_size, max_seq_len, window_len) -> (..., embedding_dim)
embedded_matrices = [embedding(window_ids) for window_ids in matrices]

# Report each embedded tensor's shape
for position, emb in enumerate(embedded_matrices, start=1):
    print(f"Embedded Matrix {position} shape: {emb.shape}")

Embedded Matrix 1 shape: torch.Size([32, 512, 2, 128])
Embedded Matrix 2 shape: torch.Size([32, 512, 4, 128])
Embedded Matrix 3 shape: torch.Size([32, 512, 8, 128])
Embedded Matrix 4 shape: torch.Size([32, 512, 16, 128])
Embedded Matrix 5 shape: torch.Size([32, 512, 32, 128])
Embedded Matrix 6 shape: torch.Size([32, 512, 64, 128])


In [6]:
import math

class CompressionSchedule:
    """Map a 0-based level index ``i`` to an integer output length.

    With ``n = compress_freq_n`` and ``x = i + 1``:
        constant: floor(n)
        linear:   floor(n * i + 1)
        root:     floor(x ** (1 / n))   (the integer n-th root of x)
        log:      floor(log_n(x)) + 1
        poly:     floor(x ** n)

    ``root`` and ``log`` correct the floating-point estimate so exact powers are
    not floored one too low (in IEEE doubles ``64 ** (1/3) == 3.9999999999999996``
    and ``math.log(1000, 10) == 2.9999999999999996``).

    Args:
        compress_freq: one of 'constant', 'linear', 'root', 'log', 'poly'.
        compress_freq_n: schedule parameter; must be >= 1 (>= 2 for 'log').

    Raises:
        ValueError: if compress_freq_n is out of range or compress_freq is unknown.
    """

    _SCHEDULES = ('constant', 'linear', 'root', 'log', 'poly')

    def __init__(self, compress_freq, compress_freq_n):
        self.compress_freq = compress_freq
        self.compress_freq_n = compress_freq_n

        if compress_freq_n < 1:
            raise ValueError("compress_freq_n must be >= 1")
        if compress_freq == 'log' and compress_freq_n < 2:
            raise ValueError("if using compress_freq=='log' then compress_freq_n must be >= 2")
        # Fail fast at construction instead of on the first call.
        if compress_freq not in self._SCHEDULES:
            raise ValueError(f"Invalid compression frequency type. {compress_freq} is unknown")

    def __call__(self, i: int) -> int:
        """Return the output length for level ``i`` under the configured schedule."""
        if self.compress_freq not in self._SCHEDULES:
            # Still guarded here in case compress_freq is reassigned after construction.
            raise ValueError(f"Invalid compression frequency type. {self.compress_freq} is unknown")
        return getattr(self, self.compress_freq)(i)

    def constant(self, i: int) -> int:
        return math.floor(self.compress_freq_n)

    def linear(self, i: int) -> int:
        return math.floor(self.compress_freq_n * i + 1)

    def root(self, i: int) -> int:
        x = i + 1
        n = self.compress_freq_n
        r = math.floor(x ** (1 / n))
        # The float estimate can be off by one near exact powers; enforce
        # r ** n <= x < (r + 1) ** n.
        while (r + 1) ** n <= x:
            r += 1
        while r > 1 and r ** n > x:
            r -= 1
        return r

    def log(self, i: int) -> int:
        x = i + 1
        base = self.compress_freq_n
        k = math.floor(math.log(x, base))
        # Same correction for the logarithm: enforce base ** k <= x < base ** (k + 1).
        while base ** (k + 1) <= x:
            k += 1
        while k > 0 and base ** k > x:
            k -= 1
        return k + 1

    def poly(self, i: int) -> int:
        return math.floor((i + 1) ** self.compress_freq_n)

# Example usage: a linear schedule with slope 1 maps level i -> i + 1
scheddy = CompressionSchedule(compress_freq='linear', compress_freq_n=1)

# Evaluate the schedule on the first 16 levels
for level in range(16):
    print(f"Output for i={level}: {scheddy(level)}")


Output for i=0: 1
Output for i=1: 2
Output for i=2: 3
Output for i=3: 4
Output for i=4: 5
Output for i=5: 6
Output for i=6: 7
Output for i=7: 8
Output for i=8: 9
Output for i=9: 10
Output for i=10: 11
Output for i=11: 12
Output for i=12: 13
Output for i=13: 14
Output for i=14: 15
Output for i=15: 16


In [7]:
from modules.pool_operations import *
from modules.norm import Norm

In [8]:
input_norm = Norm(dim=embedding_dim)
module = MaxPooling()
output_norm = Norm(dim=embedding_dim)

# Logging is off by default; turn it on to trace input/output shapes.
# (Individual sub-functions can be silenced via module.disable_function_logging(...).)
module.enable_logging()

# Normalise -> max-pool each level to one vector -> normalise, then stack the levels on dim 2
c = torch.concat(
    [
        output_norm(module(input_norm(level))).unsqueeze(2)
        for level in embedded_matrices
    ],
    dim=2,
)

print(c.shape)
del input_norm, module, output_norm, c


Inputs:
Tensor 'x' shape: torch.Size([32, 512, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 8, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 64, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 128])
torch.Size([32, 512, 6, 128])


In [9]:
input_norm = Norm(dim=embedding_dim)
module = SumPooling()
output_norm = Norm(dim=embedding_dim)

# Normalise -> sum-pool each level to one vector -> normalise, then stack the levels on dim 2
c = torch.concat(
    [
        output_norm(module(input_norm(level))).unsqueeze(2)
        for level in embedded_matrices
    ],
    dim=2,
)

print(c.shape)
del input_norm, module, output_norm, c

torch.Size([32, 512, 6, 128])


In [10]:
input_norm = Norm(dim=embedding_dim)
module = ParametricSumPooling(embedding_dim, output_seq_len=1, use_output_linear=True)
output_norm = Norm(dim=embedding_dim)

# Inspect size and structure of the learned pooling head
n_params = sum(param.numel() for param in module.parameters())
print(n_params / 1e3, 'K parameters')
print(module)

# Logging is off by default; turn it on to trace input/output shapes.
# (Individual sub-functions can be silenced via module.disable_function_logging(...).)
module.enable_logging()

# The module already keeps a length-1 pooled axis, so levels are concatenated directly on dim 2
c = torch.concat(
    [output_norm(module(input_norm(level))) for level in embedded_matrices],
    dim=2,
)

print(c.shape)
del input_norm, module, output_norm, c, n_params

32.768 K parameters
ParametricSumPooling(
  (linears): ModuleList(
    (0): Linear(in_features=128, out_features=128, bias=False)
  )
  (out): Linear(in_features=128, out_features=128, bias=False)
)

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 8, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 64, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])
torch.Size([32, 512, 6, 128])


In [11]:
# define our compression rate scheduler
scheddy = CompressionSchedule(compress_freq='linear', compress_freq_n=1)

input_norm = Norm(dim=embedding_dim)
output_norm = Norm(dim=embedding_dim)

# Create a list of ParametricMaxPooling modules with different output_seq_len values
modules = [ParametricMaxPooling(embedding_dim, output_seq_len=scheddy(i), use_output_linear=False) for i in range(len(embedded_matrices))]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")
print(modules[-1])

# Enable logging for each module
for module in modules:
    module.enable_logging()

# Concatenate the outputs from each module. Apply the same pre/post normalisation
# as the other pooling cells so results are comparable; otherwise the norms above
# are constructed and discarded without ever being used.
c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))) for module, embedded_matrix in zip(modules, embedded_matrices)], 
    dim=2
)

print(c.shape)
# Also drop the loop variable `module`, which would otherwise keep the last module alive.
del scheddy, input_norm, modules, module, output_norm, c

Module 0: 16.384 K parameters
Module 1: 32.768 K parameters
Module 2: 49.152 K parameters
Module 3: 65.536 K parameters
Module 4: 81.92 K parameters
Module 5: 98.304 K parameters
ParametricMaxPooling(
  (linears): ModuleList(
    (0-5): 6 x Linear(in_features=128, out_features=128, bias=False)
  )
)

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 2, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 8, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 3, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 4, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 5, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 64, 128])

Outputs:
Tensor 'output' shape: torch.

In [12]:
# define our compression rate scheduler
scheddy = CompressionSchedule(compress_freq='constant', compress_freq_n=1)

input_norm = Norm(dim=embedding_dim)
output_norm = Norm(dim=embedding_dim)

# One FlattenProjectionPooling per level, sized for that level's window length.
# With the constant schedule every level is pooled to the same output_seq_len (scheddy(i) == 1).
modules = [FlattenProjectionPooling(
    to_be_pooled_seq_len = embedded_matrices[i].shape[2],
    dim = embedding_dim, 
    output_seq_len=scheddy(i)
) for i in range(len(embedded_matrices))]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Enable logging for each module
for module in modules:
    module.enable_logging()

# Concatenate the outputs from each module
c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))) for module, embedded_matrix in zip(modules, embedded_matrices)], 
    dim=2
)

print(c.shape)
# Delete the `modules` list too: `module` is only the last loop item, so keeping
# `modules` would keep every pooling module (and its weights) alive.
del scheddy, input_norm, output_norm, modules, module, c

Module 0: 32.768 K parameters
Module 1: 65.536 K parameters
Module 2: 131.072 K parameters
Module 3: 262.144 K parameters
Module 4: 524.288 K parameters
Module 5: 1048.576 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 8, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 64, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])
torch.Size([32, 512, 6, 128])


In [13]:
# define our compression rate scheduler
scheddy = CompressionSchedule(compress_freq='root', compress_freq_n=2)

input_norm = Norm(dim=embedding_dim)
output_norm = Norm(dim=embedding_dim)

# Create a list of ConvPooling modules with different output_seq_len values
modules = [ConvPooling(
    to_be_pooled_seq_len = embedded_matrices[i].shape[2],
    dim = embedding_dim, 
    output_seq_len = scheddy(i), 
    use_output_linear=False
) for i in range(len(embedded_matrices))]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")
print(modules[-1])

# Enable logging for each module
for module in modules:
    module.enable_logging()

# Concatenate the outputs from each module
c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))) for module, embedded_matrix in zip(modules, embedded_matrices)], 
    dim=2
)

print(c.shape)
# Delete the `modules` list too: `module` is only the last loop item, so keeping
# `modules` would keep every pooling module (and its weights) alive.
del scheddy, input_norm, output_norm, modules, module, c

Module 0: 32.896 K parameters
Module 1: 65.664 K parameters
Module 2: 131.2 K parameters
Module 3: 524.544 K parameters
Module 4: 1048.832 K parameters
Module 5: 2097.408 K parameters
ConvPooling(
  (convs): ModuleList(
    (0-1): 2 x Conv1d(128, 128, kernel_size=(64,), stride=(1,))
  )
)

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 8, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 2, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 2, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 64, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 5

In [21]:
# define our compression rate scheduler
scheddy = CompressionSchedule(compress_freq='poly', compress_freq_n=1.25)

input_norm = Norm(dim=embedding_dim)
output_norm = Norm(dim=embedding_dim)

# Create a list of AttentionPooling modules with different output_seq_len values
modules = [AttentionPooling(
    dim = embedding_dim, 
    output_seq_len = scheddy(i), 
    use_output_linear=False
) for i in range(len(embedded_matrices))]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Enable logging for each module
for module in modules:
    module.enable_logging()

# Concatenate the outputs from each module
c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))) for module, embedded_matrix in zip(modules, embedded_matrices)], 
    dim=2
)

print(c.shape)
# Delete the `modules` list too: `module` is only the last loop item, so keeping
# `modules` would keep every pooling module (and its weights) alive.
del scheddy, input_norm, output_norm, modules, module, c

Module 0: 0.128 K parameters
Module 1: 0.256 K parameters
Module 2: 0.384 K parameters
Module 3: 0.64 K parameters
Module 4: 0.896 K parameters
Module 5: 1.152 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 2, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 8, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 3, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 5, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 7, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 64, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 9, 128])
torch.Size([32, 512, 27, 128])


In [19]:
# define our compression rate scheduler
scheddy = CompressionSchedule(compress_freq='log', compress_freq_n=2)

input_norm = Norm(dim=embedding_dim)
output_norm = Norm(dim=embedding_dim)

# Create a list of SelfAttentionPooling modules with different output_seq_len values
modules = [SelfAttentionPooling(
    dim = embedding_dim, 
    output_seq_len = scheddy(i)
) for i in range(len(embedded_matrices))]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Enable logging for each module
for module in modules:
    module.enable_logging()

# Concatenate the outputs from each module
c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))) for module, embedded_matrix in zip(modules, embedded_matrices)], 
    dim=2
)

print(c.shape)
# Delete the `modules` list too: `module` is only the last loop item, so keeping
# `modules` would keep every pooling module (and its weights) alive.
del scheddy, input_norm, output_norm, modules, module, c

Module 0: 32.896 K parameters
Module 1: 33.024 K parameters
Module 2: 33.024 K parameters
Module 3: 33.152 K parameters
Module 4: 33.152 K parameters
Module 5: 33.152 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 2, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 1, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 4, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 2, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 8, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 2, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 16, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 3, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 32, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 3, 128])

Inputs:
Tensor 'x' shape: torch.Size([32, 512, 64, 128])

Outputs:
Tensor 'output' shape: torch.Size([32, 512, 3, 128])
torch.Size([32, 512, 14, 128])
