In [1]:
import torch
import math
import os


if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")

verbose = False

GPU: NVIDIA RTX A6000 is available.


## MeMo Tokenizer and input

In [2]:
from MeMoPyTorch.modelling_memo_tokenizer import MeMoTokenizer

In [3]:
max_length = 12 
tokenizer = MeMoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", 
                                          truncation_side = 'left',
                                          padding_side='left', max_length=max_length, head_number=4)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPTNeoXTokenizer'. 
The class this function is called from is 'MeMoTokenizer'.


Setting pad token and pad token id = <|endoftext|>, 0


In [4]:
with open("testo_di_prova.txt") as my_first_text_f:
    my_first_text = my_first_text_f.read()

token_ids = tokenizer.encode(my_first_text)#, return_tensors='pt')
print(token_ids) # return max len + 1 

(tensor([[18886,   256, 36144,  4164,  1809,    80,  1448,   295,   532,  1584,
            13, 50190]]), tensor([[  256, 36144,  4164,  1809,    80,  1448,   295,   532,  1584,    13,
         50190,    15]]))


In [5]:
memo_input = tokenizer.get_text_batch_encoding([my_first_text, my_first_text[0:10]])
memo_input.keys(), memo_input['input_ids'].shape

(dict_keys(['input_ids', 'labels']), torch.Size([52, 12]))

In [6]:
for i in range(3):
    print(tokenizer.decode(memo_input['input_ids'][i]))
    print(tokenizer.decode(memo_input['labels'][i]))
    print()

<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Cosimo di
<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Cosimo di Giovanni

 de' Medici detto il Vecchio o Pater
' Medici detto il Vecchio o Pater patri

æ (Firenze, 27 settembre 1389
 (Firenze, 27 settembre 1389 –



## MeMo Embedding layer

In [7]:
from MeMoPyTorch.modelling_memo_embedding import MeMoEmbedding

In [8]:
d,h,l = 1024, 4, 3

In [9]:
embedding = MeMoEmbedding(
    num_embeddings=tokenizer.vocab_size,
    embedding_dim=d,
    padding_idx=tokenizer.pad_token_id, #0
    _freeze=True
)

MeMo embedding initilialization


In [10]:
input_tokens_ids = tokenizer(['Test', 'Un altro Test'])['input_ids']
print(input_tokens_ids)

input_embeddings = embedding.forward(input_tokens_ids)
input_embeddings

tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
         5089],
        [   0,    0,    0,    0,    0,    0,    0,    0,    0, 2447, 6945,  287,
         6004]])


tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0582,  0.0155, -0.0178,  ..., -0.0274,  0.0375, -0.0045]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0050, -0.0133,  0.0261,  ...,  0.0378, -0.0316,  0.0230],
         [ 0.0376,  0.0051, -0.0028,  ..., -0.0063,  0.0614, -0.0016],
         [ 0.0601,  0.0036,  0.0164,  ...,  0.0109, -0.0369, -0.0159]]])

In [11]:
memo_input = tokenizer.get_text_batch_encoding([my_first_text, my_first_text[10:30]])

memo_input['input_ids'].shape

torch.Size([52, 12])

In [12]:
memo_input['input_ids'][0:10]

tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0, 38577,
         17622,  1073],
        [  372,     8,  9718,    74,   843,   936,  4164, 43876, 41380,   258,
           367,   727],
        [ 5507,   313, 15723,   445,  2721,    13,  3435,  3414,   358,  3381,
         15410,    26],
        [ 9776,  1266,    74,    13,   337, 11703,   639, 39337,  1638,  1540,
            10, 12187],
        [  440,  2314,  4173,   299,  8913,  2942,   250,   352,  6770,    80,
            13,  2248],
        [  861,   410,   372, 32924,  1073, 33813,   445,  2721,   299,  2248,
            80,  1484],
        [ 1073,   659,  4611,  1073,   391,   300,   466,  5711, 14804,  1431,
           304, 19702],
        [   74,    15, 14929,  1327,  1323, 10081, 24843, 15438,   412, 16406,
         38055,  9821],
        [ 3737,  1073,   391,   300,   466,  5711, 39814,   260,   770,  5991,
           313,  1962],
        [18006, 22217, 42722, 10863,   262,  7958,  1593, 12704,  5940,  

In [13]:
input_embeddings = embedding.encode(memo_input['input_ids'])
output_symbols = embedding.encode(memo_input['labels'])

input_embeddings.shape, output_symbols.shape

(torch.Size([52, 12, 1024]), torch.Size([52, 12, 1024]))

In [14]:
decoded, _ = embedding.decode(input_embeddings)
print(decoded.shape)

decoded[0:10]

torch.Size([52, 12])


tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0, 38577,
         17622,  1073],
        [  372,     8,  9718,    74,   843,   936,  4164, 43876, 41380,   258,
           367,   727],
        [ 5507,   313, 15723,   445,  2721,    13,  3435,  3414,   358,  3381,
         15410,    26],
        [ 9776,  1266,    74,    13,   337, 11703,   639, 39337,  1638,  1540,
            10, 12187],
        [  440,  2314,  4173,   299,  8913,  2942,   250,   352,  6770,    80,
            13,  2248],
        [  861,   410,   372, 32924,  1073, 33813,   445,  2721,   299,  2248,
            80,  1484],
        [ 1073,   659,  4611,  1073,   391,   300,   466,  5711, 14804,  1431,
           304, 19702],
        [   74,    15, 14929,  1327,  1323, 10081, 24843, 15438,   412, 16406,
         38055,  9821],
        [ 3737,  1073,   391,   300,   466,  5711, 39814,   260,   770,  5991,
           313,  1962],
        [18006, 22217, 42722, 10863,   262,  7958,  1593, 12704,  5940,  

In [15]:
decoded[0:10] == memo_input['input_ids'][:10]

tensor([[True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True]])

In [16]:
sims = embedding.weight @ embedding.weight.T
display(sims)
diag_sum = torch.sum(sims[1:, 1: ].diag()) # almost 1 in each entry
print(diag_sum) # obs vs expected
print(torch.sum(sims[1:, 1:]) - diag_sum) #almost 0... more or less

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  1.0116, -0.0136,  ..., -0.0119,  0.0133,  0.0027],
        [ 0.0000, -0.0136,  0.9932,  ..., -0.0567, -0.0084, -0.0524],
        ...,
        [ 0.0000, -0.0119, -0.0567,  ...,  1.0161, -0.0105,  0.0283],
        [ 0.0000,  0.0133, -0.0084,  ..., -0.0105,  1.0409,  0.0496],
        [ 0.0000,  0.0027, -0.0524,  ...,  0.0283,  0.0496,  0.9741]])

tensor(50247.5547)
tensor(-437.1758)


## Test layer and MeMo CMM

In [17]:
from MeMoPyTorch.modelling_memo_layer import MeMoLayer, ProjectionSequence, ProjectionTokens, CorrelationMatrixMemory

### Check initialization of each matrix

In [18]:
d,h,l = 1024, 4, 3
proj = ProjectionSequence(d, d*h)
print(proj.weight.shape, proj.extra_repr())
(proj.weight.T @ proj.weight).diag(), (proj.weight @ proj.weight.T).diag()

torch.Size([4096, 1024]) (trasposed wrt saved one) in_features=4096, out_features=1024


(tensor([1.0034, 1.0135, 0.9981,  ..., 0.9602, 1.0391, 1.0285],
        grad_fn=<DiagonalBackward0_copy>),
 tensor([0.2335, 0.2499, 0.2359,  ..., 0.2417, 0.2590, 0.2396],
        grad_fn=<DiagonalBackward0_copy>))

In [19]:
Prj = torch.normal(0, 1/math.sqrt(d*h), size=(d,d*h))
Prj = torch.transpose(Prj, 0, 1)
(Prj.T @ Prj).diag(), (Prj @ Prj.T).diag()

(tensor([0.9540, 0.9915, 0.9663,  ..., 0.9823, 1.0174, 1.0255]),
 tensor([0.2524, 0.2305, 0.2505,  ..., 0.2412, 0.2558, 0.2710]))

In [20]:
print(d)
print(h)

d_k = d // h

W_v = ProjectionTokens(d, d_k)

print(W_v.weight.shape, W_v.extra_repr())


### always used transposed! so check with .T
(W_v.weight.T @ W_v.weight).diag(), (W_v.weight @ W_v.weight.T).diag()

1024
4
torch.Size([256, 1024]) in_features=1024, out_features=256


(tensor([0.9716, 0.8887, 0.9213,  ..., 0.9931, 1.0560, 1.0470],
        grad_fn=<DiagonalBackward0_copy>),
 tensor([3.8957, 3.9282, 3.8173, 3.9242, 3.7240, 4.0538, 4.2113, 3.8914, 4.1052,
         4.1155, 4.0357, 3.8291, 4.0179, 4.0468, 4.1695, 3.6141, 3.9338, 3.9435,
         4.2316, 3.8034, 4.2810, 3.8362, 3.9516, 4.0251, 3.9569, 4.0409, 4.1122,
         3.8379, 3.9411, 4.0859, 4.1653, 4.1817, 4.2582, 3.8627, 3.9495, 3.9011,
         3.9367, 4.0767, 3.7587, 3.8176, 3.9241, 3.9176, 3.8721, 4.1166, 3.8814,
         4.1041, 4.4408, 3.9817, 4.1150, 4.1030, 4.2362, 4.0798, 3.8293, 3.8891,
         3.9857, 4.4485, 4.1655, 3.6626, 4.0928, 3.8111, 3.8579, 3.7830, 3.9901,
         3.7101, 4.1409, 4.0890, 3.8550, 4.0866, 4.1738, 3.9588, 4.0144, 4.3011,
         4.0719, 3.9803, 3.8962, 4.0984, 4.1461, 4.2218, 3.6977, 4.1242, 3.9939,
         4.0745, 4.1220, 4.1001, 3.8201, 3.8617, 3.8810, 3.8470, 4.1705, 4.1657,
         3.8939, 3.8102, 4.4850, 3.8803, 4.0308, 3.8671, 3.8971, 3.7658, 3.8068,
  

In [21]:
W_v_single_head = torch.normal(0, 1/math.sqrt(d_k), size=(d,d_k))

(W_v_single_head.T @ W_v_single_head).diag(), (W_v_single_head @ W_v_single_head.T).diag()

(tensor([4.0607, 3.9188, 4.1343, 4.3398, 3.9781, 4.0719, 3.9630, 3.9899, 4.1298,
         4.0704, 4.1006, 3.5645, 4.1545, 4.2411, 4.3649, 3.8846, 4.0114, 4.1324,
         3.9862, 3.8759, 3.9989, 3.9211, 3.7758, 3.6554, 4.1563, 3.9463, 3.8554,
         3.7379, 4.1779, 4.1594, 4.2406, 4.0629, 3.8923, 3.6850, 4.0267, 4.1834,
         4.0705, 3.9980, 3.9785, 3.6814, 3.8253, 3.9768, 4.1703, 3.5919, 3.7570,
         4.1289, 4.1667, 3.6757, 3.9860, 4.0630, 4.1107, 3.8269, 4.0566, 3.9629,
         3.8810, 3.9809, 3.8808, 3.9366, 3.6887, 3.7962, 3.9128, 3.6060, 4.1888,
         4.0053, 3.8992, 4.1934, 3.9406, 3.9481, 4.0117, 3.9613, 4.0182, 3.7418,
         3.9632, 3.9029, 4.2775, 3.7637, 3.9496, 3.9138, 3.8087, 4.0188, 4.0456,
         4.1549, 3.9044, 4.1022, 3.7980, 3.9736, 3.9323, 4.2428, 4.1303, 3.8692,
         4.0947, 4.3600, 4.2825, 3.8362, 3.9793, 3.9693, 4.0807, 3.7860, 3.7670,
         4.0065, 3.9023, 3.7851, 4.0609, 3.9677, 4.1964, 3.6961, 4.0435, 3.7618,
         4.1452, 3.9568, 3.9

### Check memorization on single layer

In [22]:
d,h,l = 1024, 4, 3

layer = MeMoLayer(d, h)
layer

MeMoLayer(
  (W_v_single_head): ProjectionTokens(in_features=1024, out_features=256)
  (Prj): ProjectionSequence((trasposed wrt saved one) in_features=4096, out_features=1024)
  (CMM): CorrelationMatrixMemory(in_features=1024, out_features=1024)
)

In [23]:
batch_size, current_length, d = input_embeddings.shape
batch_size, current_length, d 

(52, 12, 1024)

In [24]:
output_symbols.shape

torch.Size([52, 12, 1024])

In [25]:
current_length = int(input_embeddings.shape[1]/ h)

input_sequence = input_embeddings.reshape((batch_size, current_length, h, d))

current_output_symbols = output_symbols[:, [(x+1)*h-1 for x in range(0,current_length)]]
j = 2 
print(sum(sum(input_sequence[0][j] == input_embeddings[0][4*j:4*(j+1)])), input_sequence[0][j].shape)

(batch_size, blocks,h,d) = input_sequence.shape

tensor(4096) torch.Size([4, 1024])


In [26]:
input_sequence.shape, current_output_symbols.shape

(torch.Size([52, 3, 4, 1024]), torch.Size([52, 3, 1024]))

In [27]:
layer = MeMoLayer(d, h)
display(layer)
## update the input sequence for the next layer
_, seq_encoding_for_the_last_layer = layer.memorize(input_sequence, current_output_symbols, is_last=False)
layer.directly_memorize(seq_encoding_for_the_last_layer)

MeMoLayer(
  (W_v_single_head): ProjectionTokens(in_features=1024, out_features=256)
  (Prj): ProjectionSequence((trasposed wrt saved one) in_features=4096, out_features=1024)
  (CMM): CorrelationMatrixMemory(in_features=1024, out_features=1024)
)

In [28]:
input_sequence.shape, input_sequence[3].shape # batch (52 elements of chunks 4*4*1024)

(torch.Size([52, 3, 4, 1024]), torch.Size([3, 4, 1024]))

In [29]:
_, seq_encoding_for_the_last_layer = layer.retrieve(input_sequence)

print(seq_encoding_for_the_last_layer.shape)

torch.Size([52, 1024])


In [30]:
logits = layer.directly_retrieve(seq_encoding_for_the_last_layer)

In [31]:
retreived_output_symbol_vector, m = embedding.decode(logits)
print(retreived_output_symbol_vector, m)

tensor([48505, 20110,  1108, 48019,    80, 19216,  9718,  1113,  4927,    66,
        19216,  1448,  2122, 41530,   187,  4172,   246,   659, 10986, 30975,
           80, 12931,   352, 14134,  2721,   258,  8830,    87,   826,    15,
        17532,   729, 26798, 41070,  6575,   299,   266,  3737, 20889,   287,
          512,   354,   250,   247,    70,   275, 16128,  2680,    74, 13679,
           15,   843]) tensor([0.8819, 1.0156, 0.9961, 0.9514, 1.0413, 1.1477, 1.1566, 0.9515, 1.0465,
        1.9006, 1.2682, 0.9832, 1.1472, 1.0905, 1.0196, 1.1548, 1.0398, 0.9599,
        0.9555, 1.0387, 1.0676, 1.1081, 1.0265, 1.1388, 1.1526, 1.1123, 0.9653,
        1.0188, 1.1298, 0.8249, 1.0822, 1.0452, 0.9774, 0.9335, 0.9042, 0.8819,
        1.0207, 0.9899, 1.0153, 1.1080, 0.8006, 0.8810, 0.8329, 1.1565, 1.2493,
        0.9081, 1.0114, 0.9514, 1.0136, 1.0880, 0.8588, 1.0223],
       grad_fn=<MaxBackward0>)


In [32]:
o = embedding.decode(current_output_symbols[:, -1])[0]
display(o)

print(sum(o == retreived_output_symbol_vector), 'over', retreived_output_symbol_vector.shape)

tensor([48505, 20110,  1108, 48019,    80, 19216,  9718,  1113,  4927,    66,
        19216,  1448,  2122, 41530,   187,  4172,   246,   659, 10986, 30975,
           80, 12931,   352, 14134,  2721,   258,  8830,    87,   826,    15,
        17532,   729, 26798, 41070,  6575,   299,   266,  3737, 20889,   287,
          512,   354,   250,   247,    70,   275, 16128,  2680,    74, 13679,
           15,   209])

tensor(51) over torch.Size([52])


In [33]:
#### test single block

In [34]:
print(input_sequence[3].shape, current_output_symbols[3].shape)
print(input_sequence[3][0].shape, current_output_symbols[3][0].shape)

torch.Size([3, 4, 1024]) torch.Size([3, 1024])
torch.Size([4, 1024]) torch.Size([1024])


In [35]:
total = 0
correct = 0 

for batch_index in range(len(memo_input['input_ids'])):
    #print("input ids", memo_input[0]['input_ids'][batch_index])
    #print()
    
    for i in range(len(current_output_symbols[batch_index])):
        #display(embedding.decode(input_sequence[batch_index][i]), embedding.decode(current_output_symbols[batch_index][i]))
        true = embedding.decode(current_output_symbols[batch_index][i])[0].item()
        
        _, seq_encoding_for_the_last_layer  = layer.retrieve(input_sequence[batch_index][i].unsqueeze(0).unsqueeze(0))
        
        retreived_output_symbol_vector, m = embedding.decode(layer.directly_retrieve(seq_encoding_for_the_last_layer))
        pred = retreived_output_symbol_vector.item()

        total += 1
        correct += pred == true

print(f"{correct}/{total}")

153/156


### Check with batch size of 1 and output probs

In [36]:
memo_input = tokenizer.get_text_batch_encoding(['this is a test for a very short short sequence of 12 tokens'])
input_ids, labels = memo_input['input_ids'], memo_input['labels']
print(input_ids, labels)

input_embeddings = embedding.encode(input_ids)
#print(input_embeddings.shape)

output_embeddings = embedding.encode(labels)
#print(output_embeddings.shape)


current_length = max_length

current_length = int(current_length/h)
input_sequence = input_embeddings.reshape((1, current_length, h, d))

output_symbols = output_embeddings[:, [(x+1)*h-1 for x in range(0,current_length)]] ## the output symbol is always the same tokem?
print(embedding.decode(output_symbols))
input_sequence.shape, output_symbols.shape

tensor([[2520,  310,  247, 1071,  323,  247, 1077, 2159, 2159, 3425,  273, 1249]]) tensor([[  310,   247,  1071,   323,   247,  1077,  2159,  2159,  3425,   273,
          1249, 21761]])
(tensor([[  323,  2159, 21761]]), tensor([[1.0203, 0.9653, 0.9761]]))


(torch.Size([1, 3, 4, 1024]), torch.Size([1, 3, 1024]))

In [37]:
embedding.decode(input_sequence)[0], embedding.decode(output_symbols)[0], input_ids

(tensor([[[2520,  310,  247, 1071],
          [ 323,  247, 1077, 2159],
          [2159, 3425,  273, 1249]]]),
 tensor([[  323,  2159, 21761]]),
 tensor([[2520,  310,  247, 1071,  323,  247, 1077, 2159, 2159, 3425,  273, 1249]]))

In [38]:
input_sequence.shape

torch.Size([1, 3, 4, 1024])

In [39]:
output_symbols.shape

torch.Size([1, 3, 1024])

In [40]:
layer = MeMoLayer(d, h)
display(layer)

## update the input sequence for the next layer
_, seq_encoding_for_the_last_layer = layer.memorize(input_sequence, output_symbols, is_last=False)
layer.directly_memorize(seq_encoding_for_the_last_layer)

for i in range(0,3):
    _, seq_encoding_for_the_last_layer = layer.retrieve(input_sequence[0][i].unsqueeze(0).unsqueeze(0))
    print(seq_encoding_for_the_last_layer.shape)
                                                        
    retreived_output_symbol_vector, m = embedding.decode(layer.directly_retrieve(seq_encoding_for_the_last_layer))
    print(retreived_output_symbol_vector, m)

MeMoLayer(
  (W_v_single_head): ProjectionTokens(in_features=1024, out_features=256)
  (Prj): ProjectionSequence((trasposed wrt saved one) in_features=4096, out_features=1024)
  (CMM): CorrelationMatrixMemory(in_features=1024, out_features=1024)
)

torch.Size([1, 1024])
tensor([323]) tensor([1.0873], grad_fn=<MaxBackward0>)
torch.Size([1, 1024])
tensor([2159]) tensor([0.9083], grad_fn=<MaxBackward0>)
torch.Size([1, 1024])
tensor([21761]) tensor([0.9877], grad_fn=<MaxBackward0>)


## Test the entire MeMo model

In [41]:
from MeMoPyTorch.modelling_memo import MeMo

In [42]:
from MeMoPyTorch.modelling_memo_tokenizer import MeMoTokenizer

In [43]:
with open("testo_di_prova.txt") as my_first_text_f:
    my_first_text = my_first_text_f.read()

In [44]:
max_length = 384
print(max_length, h)
tokenizer = MeMoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", 
                                          padding_side='left', truncation_side='left', 
                                          max_length=max_length, head_number=h)

384 4


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPTNeoXTokenizer'. 
The class this function is called from is 'MeMoTokenizer'.


Setting pad token and pad token id = <|endoftext|>, 0


In [45]:
memo_input = tokenizer.get_text_batch_encoding(my_first_text)
memo_input['labels'].shape

torch.Size([2, 384])

In [46]:
device='cuda:0' #'cpu'

In [47]:
d,h,l

(1024, 4, 3)

In [48]:
model = MeMo(inner_dim=d, 
             num_of_heads=h, 
             num_of_layers=3, 
             chunk_length=max_length, 
             num_embeddings=tokenizer.vocab_size, 
             padding_idx=tokenizer.pad_token_id, 
             device=device) #MeMoModel
model

MeMo embedding initilialization


MeMo(
  (encoder): MeMoEmbedding(50254, 1024, padding_idx=0)
  (layers): ModuleList(
    (0-2): 3 x MeMoLayer(
      (W_v_single_head): ProjectionTokens(in_features=1024, out_features=256)
      (Prj): ProjectionSequence((trasposed wrt saved one) in_features=4096, out_features=1024)
      (CMM): CorrelationMatrixMemory(in_features=1024, out_features=1024)
    )
  )
)

In [49]:
input_sequence =  model.encoder.encode(memo_input['input_ids'])
output_symbols = model.encoder.encode(memo_input['labels'])

(batch_size, current_length, d) = input_sequence.shape
last_layer = model.layers[model.l-1]

current_length = model.chunk_length

input_sequence.shape, output_symbols.shape, current_length

(torch.Size([2, 384, 1024]), torch.Size([2, 384, 1024]), 384)

In [50]:
layer_level = 0
input_index = [[j for j in range(i - model.h ** (layer_level + 1), i, model.h ** ((layer_level + 1) - 1))] 
               for i in range(model.h ** (layer_level + 1), current_length + 1)]
input_index[0:5]

[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]]

In [51]:
original_input_seq = torch.clone(input_sequence)
original_out = torch.clone(output_symbols)

original_input_seq.shape, original_out.shape

(torch.Size([2, 384, 1024]), torch.Size([2, 384, 1024]))

In [52]:
#for layer_level in range(model.l):
layer_level = 0 #1,2
          
print(model.h ** (layer_level + 1) < current_length + 1)
## update the input sequence for the next layer
layer_output_idxs = [i - model.h ** ((layer_level + 1) - 1) for i in range(model.h ** (layer_level + 1), current_length + 1)]
output_symbols = output_symbols[:, layer_output_idxs]
print(output_symbols.shape)

input_index = [[j for j in range(i - model.h ** (layer_level + 1), i, model.h ** ((layer_level + 1) - 1))] 
               for i in range(model.h ** (layer_level + 1), current_length + 1)]


input_sequence = input_sequence[:, input_index]
print(input_sequence.shape)

True
torch.Size([2, 381, 1024])
torch.Size([2, 381, 4, 1024])


In [53]:
model.encoder.decode(output_symbols[0])[0] == model.encoder.decode(torch.stack([original_out[0][i] for i in layer_output_idxs]))[0]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, Tr

In [54]:
import tqdm


class Evaluation:
    def check_memorization(self, model, tokenizer, text, # device='cpu',
                           starting_point=None):
        if starting_point == None:
            basic_block = model.h ** model.l
        else:
            basic_block = starting_point
        
        
        input_ = tokenizer(my_first_text, padding='longest', truncation='do_not_truncate', max_length=None)
        input_ = tokenizer.pad(input_, pad_to_multiple_of=basic_block)
        input_ids = input_['input_ids']
                
        count = 0
        correct = 0
        max_length = tokenizer.max_length
        (batch_size, number_of_tokens) = input_ids.shape

        #print(f"(batch_size, number_of_tokens) = {(batch_size, number_of_tokens)}")
        
        for i in tqdm.tqdm(range(basic_block,  number_of_tokens - 1)):
            text_tokens = input_ids[:, i - basic_block:i]
            
            (batch_size, number_of_tokens) = text_tokens.shape
            
            text_tokens = torch.concat((torch.zeros((batch_size, max_length-1-number_of_tokens), 
                                                    dtype=torch.int), 
                                        text_tokens), axis=1
                                      )
            
            #print(i - basic_block, i)
            out, max_value = model.retrieve(text_tokens)
            #print(out, input_ids[:, i])
            #print(out[0].item())
            
            count += batch_size
            correct += torch.sum(out.to('cpu') == input_ids[:, i])
        
                           
        return correct / count

    def check_pretokenized(self, model, tokenizer, input_ids,# device='cpu',
                           starting_point=None):
        if starting_point == None:
            basic_block = model.h ** model.l
        else:
            basic_block = starting_point
                
        count = 0
        correct = 0
        max_length = tokenizer.max_length
        (batch_size, number_of_tokens) = input_ids.shape

        #print(f"(batch_size, number_of_tokens) = {(batch_size, number_of_tokens)}")
        
        for i in tqdm.tqdm(range(basic_block,  number_of_tokens - 1)):
            text_tokens = input_ids[:, i - basic_block:i]
            
            (batch_size, number_of_tokens) = text_tokens.shape
            
            text_tokens = torch.concat((torch.zeros((batch_size, max_length-1-number_of_tokens), 
                                                    dtype=torch.int), 
                                        text_tokens), axis=1
                                      )
            
            #print(i - basic_block, i)
            out, max_value = model.retrieve(text_tokens)
            #print(out, input_ids[:, i])
            #print(out[0].item())
            
            count += batch_size
            correct += torch.sum(out.to('cpu') == input_ids[:, i])
        
                           
        return correct / count
        

In [55]:
model = MeMo(inner_dim=d, 
             num_of_heads=h, 
             num_of_layers=l, 
             chunk_length=max_length, 
             num_embeddings=tokenizer.vocab_size, 
             padding_idx=tokenizer.pad_token_id, 
             device=device)

memo_input = tokenizer.get_text_batch_encoding([my_first_text]*8)


memo_input['input_ids'].shape

MeMo embedding initilialization


torch.Size([16, 384])

In [56]:
model.memorize_text(memo_input)

In [57]:
e = Evaluation()
out = e.check_pretokenized(model, tokenizer, memo_input['input_ids'])
print("Degree of memorization: %f ", out)

100%|██████████| 319/319 [00:00<00:00, 746.92it/s]

Degree of memorization: %f  tensor(0.8448)





In [58]:
model.forget_text(memo_input)

In [59]:
out = e.check_pretokenized(model, tokenizer, memo_input['input_ids'])
print("Degree of memorization: %f ", out)

100%|██████████| 319/319 [00:00<00:00, 763.57it/s]

Degree of memorization: %f  tensor(0.0815)





In [60]:
model = MeMo(inner_dim=d, 
             num_of_heads=h, 
             num_of_layers=l, 
             chunk_length=max_length, 
             num_embeddings=tokenizer.vocab_size, 
             padding_idx=tokenizer.pad_token_id, 
             device=device)
print("CMM pre learning")
display(model.layers[0].CMM.weight)


bs = 8
for b in range(bs):
    memo_input = tokenizer.get_text_batch_encoding(my_first_text)
    print(memo_input['input_ids'].shape)

    model.memorize_text(memo_input)

e = Evaluation()
out = e.check_pretokenized(model, tokenizer, memo_input['input_ids'])
print("Degree of memorization: %f ", out)

MeMo embedding initilialization
CMM pre learning


Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', requires_grad=True)

torch.Size([2, 384])
torch.Size([2, 384])
torch.Size([2, 384])
torch.Size([2, 384])
torch.Size([2, 384])
torch.Size([2, 384])
torch.Size([2, 384])
torch.Size([2, 384])


100%|██████████| 319/319 [00:00<00:00, 792.50it/s]

Degree of memorization: %f  tensor(0.9624)





In [61]:
Prj = model.layers[0].Prj.weight.detach().cpu()
CMM = model.layers[0].CMM.weight.detach().cpu()

Prj.T @ Prj

tensor([[ 9.4733e-01,  6.2101e-04, -8.6885e-03,  ..., -8.4878e-03,
          1.7844e-02,  1.2130e-02],
        [ 6.2101e-04,  1.0110e+00, -2.2606e-02,  ..., -1.0224e-02,
         -1.9795e-02,  8.5904e-03],
        [-8.6885e-03, -2.2606e-02,  9.6828e-01,  ...,  1.8914e-02,
         -3.0991e-03,  2.6482e-02],
        ...,
        [-8.4878e-03, -1.0224e-02,  1.8914e-02,  ...,  1.0126e+00,
         -2.0454e-02,  2.7226e-03],
        [ 1.7844e-02, -1.9795e-02, -3.0991e-03,  ..., -2.0454e-02,
          1.0351e+00, -6.7074e-03],
        [ 1.2130e-02,  8.5904e-03,  2.6482e-02,  ...,  2.7226e-03,
         -6.7074e-03,  1.0019e+00]])

In [62]:
CMM

tensor([[ 0.0060,  0.0329, -0.0196,  ..., -0.0565,  0.0308, -0.0099],
        [-0.0205, -0.0693, -0.0293,  ...,  0.0095, -0.0060,  0.0002],
        [-0.0252, -0.0445, -0.0202,  ...,  0.0301,  0.0368,  0.0085],
        ...,
        [ 0.0099,  0.0281, -0.0018,  ..., -0.0022,  0.0010,  0.0280],
        [-0.0315, -0.0133, -0.0283,  ...,  0.0207,  0.0575, -0.0010],
        [-0.0069,  0.0160,  0.0095,  ...,  0.0081,  0.0150, -0.0087]])

In [63]:
out = e.check_pretokenized(model, tokenizer, memo_input['input_ids'])
print("Degree of memorization: %f ", out)

100%|██████████| 319/319 [00:00<00:00, 802.99it/s]

Degree of memorization: %f  tensor(0.9624)





In [64]:
model.forget_text(memo_input)

In [65]:
out = e.check_pretokenized(model, tokenizer, memo_input['input_ids'])
print("Degree of memorization: %f ", out)

100%|██████████| 319/319 [00:00<00:00, 799.53it/s]

Degree of memorization: %f  tensor(0.9624)



