In [1]:
import torch
import torch.nn as nn

In [2]:
from modules.loss import splice_future_indices, create_multi_hot_vector

In [3]:
# Example usage
batch_size = 2
max_seq_len = 15
mult_factor = 2
max_iter = 8
padding_token = 0  # or 'v' if you have a specific padding token
target_tokens = torch.tensor([
    [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    [17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
])

matrices = splice_future_indices(target_tokens, padding_token, mult_factor, max_iter)

# Print the matrices to verify
for idx, matrix in enumerate(matrices):
    print(f"Matrix {idx+1}:\n{matrix}\n")

Matrix 1:
tensor([[[ 3,  4],
         [ 4,  5],
         [ 5,  6],
         [ 6,  7],
         [ 7,  8],
         [ 8,  9],
         [ 9, 10],
         [10, 11],
         [11, 12],
         [12, 13],
         [13, 14],
         [14, 15],
         [15, 16],
         [16,  0],
         [ 0,  0]],

        [[18, 19],
         [19, 20],
         [20, 21],
         [21, 22],
         [22, 23],
         [23, 24],
         [24, 25],
         [25, 26],
         [26, 27],
         [27, 28],
         [28, 29],
         [29, 30],
         [30, 31],
         [31,  0],
         [ 0,  0]]])

Matrix 2:
tensor([[[ 3,  4,  5,  6],
         [ 4,  5,  6,  7],
         [ 5,  6,  7,  8],
         [ 6,  7,  8,  9],
         [ 7,  8,  9, 10],
         [ 8,  9, 10, 11],
         [ 9, 10, 11, 12],
         [10, 11, 12, 13],
         [11, 12, 13, 14],
         [12, 13, 14, 15],
         [13, 14, 15, 16],
         [14, 15, 16,  0],
         [15, 16,  0,  0],
         [16,  0,  0,  0],
         [ 0,  0,  0,  0]],

In [4]:
# Example usage
batch_size = 32
max_seq_len = 128
mult_factor = 4
max_iter = 3
vocab_size = 2048  
embedding_dim = 64
target_tokens = torch.randint(vocab_size, (batch_size, max_seq_len))

matrices = splice_future_indices(target_tokens, padding_token, mult_factor, max_iter)

# Print the matrices to verify
for idx, matrix in enumerate(matrices):
    print(f"Matrix {idx+1}: {matrix.shape}")

Matrix 1: torch.Size([32, 128, 4])
Matrix 2: torch.Size([32, 128, 16])
Matrix 3: torch.Size([32, 128, 64])


In [5]:
# Initialize the embedding layer
embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_token)

# Convert matrices to embeddings
embedded_matrices = []
for matrix in matrices:
    embedded_matrix = embedding(matrix)  # Shape: (batch_size, max_seq_len, subseq_length, embedding_dim)
    embedded_matrices.append(embedded_matrix)

# Print the shapes of the embedded matrices to verify
for idx, embedded_matrix in enumerate(embedded_matrices):
    print(f"Embedded Matrix {idx+1} shape: {embedded_matrix.shape}")

Embedded Matrix 1 shape: torch.Size([32, 128, 4, 64])
Embedded Matrix 2 shape: torch.Size([32, 128, 16, 64])
Embedded Matrix 3 shape: torch.Size([32, 128, 64, 64])


In [7]:
import math

from modules.pool_mech import CompressionSchedule

# Example usage:
scheddy = CompressionSchedule(compress_freq='linear', compress_freq_n=1)

# Forward function
for i in range(max_iter):
    output = scheddy(i)
    print(f"Output for i={i}: {output}")


Output for i=0: 1
Output for i=1: 2
Output for i=2: 3


In [9]:
from modules.pool_ops import *
from modules.norm import Norm

In [10]:
input_norm = Norm(dim=embedding_dim)
module = MaxPooling()
output_norm = Norm(dim=embedding_dim)

# Initially, logging is disabled by default
module.enable_logging()
### Optionally disabling printing for sub-functions
#module.disable_function_logging('')

c = torch.concat(
    [output_norm(
        module(
            input_norm(
                embedded_matrix
            )
        )
    ).unsqueeze(2) for embedded_matrix in embedded_matrices], 
    dim=2
)

print(c.shape)
del input_norm, module, output_norm, c


Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])
torch.Size([32, 128, 3, 1, 64])


In [11]:
input_norm = Norm(dim=embedding_dim)
module = SumPooling()
output_norm = Norm(dim=embedding_dim)

c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))).unsqueeze(2) for embedded_matrix in embedded_matrices], 
    dim=2
)

print(c.shape)
del input_norm, module, output_norm, c

torch.Size([32, 128, 3, 1, 64])


In [12]:
input_norm = Norm(dim=embedding_dim)
module = ParametricSumPooling(embedding_dim, output_seq_len=1, use_output_linear=True)
output_norm = Norm(dim=embedding_dim)

# let's take a look
print(sum(p.numel() for p in module.parameters())/1e3, 'K parameters')
print(module)

# Initially, logging is disabled by default
module.enable_logging()
### Optionally disabling printing for sub-functions
#module.disable_function_logging('')

c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))) for embedded_matrix in embedded_matrices], 
    dim=2
)

print(c.shape)
del input_norm, module, output_norm, c

8.192 K parameters
ParametricSumPooling(
  (linears): ModuleList(
    (0): Linear(in_features=64, out_features=64, bias=False)
  )
  (out): Linear(in_features=64, out_features=64, bias=False)
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])
torch.Size([32, 128, 3, 64])


In [13]:
# define our compression rate scheduler
scheddy = CompressionSchedule(compress_freq='linear', compress_freq_n=1)

input_norm = Norm(dim=embedding_dim)
output_norm = Norm(dim=embedding_dim)

# Create a list of ParametricMaxPooling modules with different output_seq_len values
modules = [ParametricMaxPooling(embedding_dim, output_seq_len=scheddy(i), use_output_linear=False) for i in range(len(embedded_matrices))]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")
print(modules[-1])

# Enable logging for each module
for module in modules:
    module.enable_logging()

# Concatenate the outputs from each module
c = torch.concat(
    [module(embedded_matrix) for module, embedded_matrix in zip(modules, embedded_matrices)], 
    dim=2
)

print(c.shape)
del scheddy, input_norm, modules, output_norm, c

Module 0: 4.096 K parameters
Module 1: 8.192 K parameters
Module 2: 12.288 K parameters
ParametricMaxPooling(
  (linears): ModuleList(
    (0-2): 3 x Linear(in_features=64, out_features=64, bias=False)
  )
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 2, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 3, 64])
torch.Size([32, 128, 6, 64])


In [14]:
# define our compression rate scheduler
scheddy = CompressionSchedule(compress_freq='constant', compress_freq_n=1)

input_norm = Norm(dim=embedding_dim)
output_norm = Norm(dim=embedding_dim)

# Create a list of FlattenProjectionPooling modules with different output_seq_len values
modules = [FlattenProjectionPooling(
    to_be_pooled_seq_len = embedded_matrices[i].shape[2],
    dim = embedding_dim, 
    output_seq_len=scheddy(i)
) for i in range(len(embedded_matrices))]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Enable logging for each module
for module in modules:
    module.enable_logging()

# Concatenate the outputs from each module
c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))) for module, embedded_matrix in zip(modules, embedded_matrices)], 
    dim=2
)

print(c.shape)
del scheddy, input_norm, output_norm, module, c

Module 0: 16.384 K parameters
Module 1: 65.536 K parameters
Module 2: 262.144 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])
torch.Size([32, 128, 3, 64])


In [15]:
# define our compression rate scheduler
scheddy = CompressionSchedule(compress_freq='root', compress_freq_n=2)

input_norm = Norm(dim=embedding_dim)
output_norm = Norm(dim=embedding_dim)

# Create a list of ConvPooling modules with different output_seq_len values
modules = [ConvPooling(
    to_be_pooled_seq_len = embedded_matrices[i].shape[2],
    dim = embedding_dim, 
    output_seq_len = scheddy(i), 
    use_output_linear=False
) for i in range(len(embedded_matrices))]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")
print(modules[-1])

# Enable logging for each module
for module in modules:
    module.enable_logging()

# Concatenate the outputs from each module
c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))) for module, embedded_matrix in zip(modules, embedded_matrices)], 
    dim=2
)

print(c.shape)
del scheddy, input_norm, output_norm, module, c

Module 0: 16.384 K parameters
Module 1: 65.536 K parameters
Module 2: 262.144 K parameters
ConvPooling(
  (convs): ModuleList(
    (0): Conv1d(64, 64, kernel_size=(64,), stride=(1,), bias=False)
  )
)

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])
torch.Size([32, 128, 3, 64])


In [16]:
# define our compression rate scheduler
scheddy = CompressionSchedule(compress_freq='poly', compress_freq_n=1.25)

input_norm = Norm(dim=embedding_dim)
output_norm = Norm(dim=embedding_dim)

# Create a list of AttentionPooling modules with different output_seq_len values
modules = [AttentionPooling(
    dim = embedding_dim, 
    output_seq_len = scheddy(i), 
    use_output_linear=False
) for i in range(len(embedded_matrices))]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Enable logging for each module
for module in modules:
    module.enable_logging()

# Concatenate the outputs from each module
c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))) for module, embedded_matrix in zip(modules, embedded_matrices)], 
    dim=2
)

print(c.shape)
del scheddy, input_norm, output_norm, module, c

Module 0: 0.064 K parameters
Module 1: 0.128 K parameters
Module 2: 0.192 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 2, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 3, 64])
torch.Size([32, 128, 6, 64])


In [17]:
# define our compression rate scheduler
scheddy = CompressionSchedule(compress_freq='log', compress_freq_n=2)

input_norm = Norm(dim=embedding_dim)
output_norm = Norm(dim=embedding_dim)

# Create a list of SelfAttentionPooling modules with different output_seq_len values
modules = [SelfAttentionPooling(
    dim = embedding_dim, 
    output_seq_len = scheddy(i)
) for i in range(len(embedded_matrices))]

# Print the total number of parameters in each module
for idx, module in enumerate(modules):
    print(f"Module {idx}: {sum(p.numel() for p in module.parameters()) / 1e3} K parameters")

# Enable logging for each module
for module in modules:
    module.enable_logging()

# Concatenate the outputs from each module
c = torch.concat(
    [output_norm(module(input_norm(embedded_matrix))) for module, embedded_matrix in zip(modules, embedded_matrices)], 
    dim=2
)

print(c.shape)
del scheddy, input_norm, output_norm, module, c

Module 0: 8.256 K parameters
Module 1: 8.32 K parameters
Module 2: 8.32 K parameters

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 4, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 1, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 16, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 2, 64])

Inputs:
Tensor 'x' shape: torch.Size([32, 128, 64, 64])

Outputs:
Tensor 'output' shape: torch.Size([32, 128, 2, 64])
torch.Size([32, 128, 5, 64])
