In [None]:
# test segment-level recurrence
import torch
import torch.nn as nn
from transformers import TransfoXLTokenizer, TransfoXLPreTrainedModel

# mem_len decide the memory size
class Transformer_XL(nn.Module):
    def __init__(self, channels: int=256, n_layers: int=6, mem_len: int=200):
        super().__init__()
        self.channels = channels
        self.n_layers = n_layers
        self.mem_len = mem_len
        self.memory = None # save hidden_layer_param

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=channels, nhead=8),
            num_layers=n_layers
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.memory is not None:
            print(f"memory shape: {self.memory.shape}\n")
            x = torch.cat([self.memory, x], dim=0)

        output = self.encoder(x)
        print(f"output shape: {output.shape}\n")
        self.memory = output.detach()[-self.mem_len:].clone() # detach prevent tensor from contributing to the gradient calculations
        return output

model = Transformer_XL()
seq_1 = torch.rand(20, 10, 256)
seq_2 = torch.rand(20, 10, 256)
seq_3 = torch.rand(20, 10, 256)
input_seq = [seq_1, seq_2, seq_3]
for seq in input_seq:
    output = model(seq)

print(output.shape)

output shape: torch.Size([20, 10, 256])

memory shape: torch.Size([20, 10, 256])

output shape: torch.Size([40, 10, 256])

memory shape: torch.Size([40, 10, 256])

output shape: torch.Size([60, 10, 256])

torch.Size([60, 10, 256])


