In [1]:
import torch
import math
import os


# Report which device PyTorch can see; device 0 is only queried when CUDA is present.
print(
    f"GPU: {torch.cuda.get_device_name(0)} is available."
    if torch.cuda.is_available()
    else "No GPU available. Training will run on CPU."
)

# Notebook-level verbosity flag (assigned here; not read in this cell).
verbose = False

GPU: NVIDIA RTX A6000 is available.


In [2]:
!nvidia-smi

Mon Jan  6 14:53:33 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A6000               Off | 00000000:3B:00.0 Off |                  Off |
| 30%   32C    P8              27W / 300W |     25MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000               Off | 00000000:5E:00.0 Off |  

## MeMo Tokenizer and input

In [3]:
from MeMoPyTorch.modelling_memo_tokenizer import MeMoTokenizer

In [4]:
# Maximum sequence length handed to the tokenizer (the batches shown below are 12 tokens wide).
max_length = 12 
# Reuse the GPT-NeoX-20B vocabulary; both truncation and padding act on the left side.
# head_number is 4 here, the same value later assigned to h.
tokenizer = MeMoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", 
                                          truncation_side = 'left',
                                          padding_side='left', max_length=max_length, head_number=4)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPTNeoXTokenizer'. 
The class this function is called from is 'MeMoTokenizer'.


Setting pad token and pad token id = <|endoftext|>, 0


In [5]:
# Read the sample text that the rest of the notebook memorizes.
with open("testo_di_prova.txt") as my_first_text_f:
    my_first_text = my_first_text_f.read()

# encode is called without return_tensors; per the printed output it yields a pair of
# tensors, where the second is the first shifted one token ahead.
token_ids = tokenizer.encode(my_first_text)
print(token_ids) # NOTE(review): original remark was "return max len + 1" - presumably max_length + 1 tokens are consumed to build the shifted pair; confirm

(tensor([[18886,   256, 36144,  4164,  1809,    80,  1448,   295,   532,  1584,
            13, 50190]]), tensor([[  256, 36144,  4164,  1809,    80,  1448,   295,   532,  1584,    13,
         50190,    15]]))


In [6]:
memo_input = tokenizer.get_text_batch_encoding([my_first_text, my_first_text[0:10]])
memo_input.keys()

dict_keys([0, 1, 2, 3])

In [7]:
memo_input[0]['input_ids'].shape

torch.Size([52, 12])

## MeMo Embedding layer

In [8]:
from MeMoPyTorch.modelling_memo_embedding import MeMoEmbedding

In [9]:
d,h,l = 1024, 4, 3

In [10]:
embedding = MeMoEmbedding(
    num_embeddings=tokenizer.vocab_size,
    embedding_dim=d,
    padding_idx=tokenizer.pad_token_id, #0
    _freeze=True
)

MeMo embedding initilialization


In [11]:
# Tokenize two short strings; per the output the batch is left-padded with id 0.
input_tokens_ids = tokenizer(['Test', 'Un altro Test'])['input_ids']
print(input_tokens_ids)

# Padded positions come out as all-zero vectors in the displayed result.
# NOTE(review): calling embedding(...) instead of .forward(...) is the usual nn.Module convention; confirm MeMoEmbedding does not rely on the direct call.
input_embeddings = embedding.forward(input_tokens_ids)
input_embeddings

tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
         5089],
        [   0,    0,    0,    0,    0,    0,    0,    0,    0, 2447, 6945,  287,
         6004]])


tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0835, -0.0631, -0.0027,  ..., -0.0271, -0.0230,  0.0581]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [-0.0217,  0.0062, -0.0081,  ..., -0.0301, -0.0055, -0.0123],
         [ 0.0048,  0.0056, -0.0090,  ...,  0.0085, -0.0227, -0.0020],
         [-0.0105,  0.0004, -0.0299,  ...,  0.0073,  0.0339,  0.0490]]])

In [12]:
memo_input = tokenizer.get_text_batch_encoding([my_first_text, my_first_text[10:30]])

memo_input[0]['input_ids'].shape

torch.Size([52, 12])

In [13]:
memo_input[0]['input_ids'][0:10]

tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0, 38577,
         17622,  1073],
        [  372,     8,  9718,    74,   843,   936,  4164, 43876, 41380,   258,
           367,   727],
        [ 5507,   313, 15723,   445,  2721,    13,  3435,  3414,   358,  3381,
         15410,    26],
        [ 9776,  1266,    74,    13,   337, 11703,   639, 39337,  1638,  1540,
            10, 12187],
        [  440,  2314,  4173,   299,  8913,  2942,   250,   352,  6770,    80,
            13,  2248],
        [  861,   410,   372, 32924,  1073, 33813,   445,  2721,   299,  2248,
            80,  1484],
        [ 1073,   659,  4611,  1073,   391,   300,   466,  5711, 14804,  1431,
           304, 19702],
        [   74,    15, 14929,  1327,  1323, 10081, 24843, 15438,   412, 16406,
         38055,  9821],
        [ 3737,  1073,   391,   300,   466,  5711, 39814,   260,   770,  5991,
           313,  1962],
        [18006, 22217, 42722, 10863,   262,  7958,  1593, 12704,  5940,  

In [14]:
input_embeddings = embedding.encode(memo_input[0]['input_ids'])
output_symbols = embedding.encode(memo_input[0]['labels'])

input_embeddings.shape, output_symbols.shape

(torch.Size([52, 12, 1024]), torch.Size([52, 12, 1024]))

In [15]:
decoded, _ = embedding.decode(input_embeddings)
print(decoded.shape)

decoded[0:10]

torch.Size([52, 12])


tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0, 38577,
         17622,  1073],
        [  372,     8,  9718,    74,   843,   936,  4164, 43876, 41380,   258,
           367,   727],
        [ 5507,   313, 15723,   445,  2721,    13,  3435,  3414,   358,  3381,
         15410,    26],
        [ 9776,  1266,    74,    13,   337, 11703,   639, 39337,  1638,  1540,
            10, 12187],
        [  440,  2314,  4173,   299,  8913,  2942,   250,   352,  6770,    80,
            13,  2248],
        [  861,   410,   372, 32924,  1073, 33813,   445,  2721,   299,  2248,
            80,  1484],
        [ 1073,   659,  4611,  1073,   391,   300,   466,  5711, 14804,  1431,
           304, 19702],
        [   74,    15, 14929,  1327,  1323, 10081, 24843, 15438,   412, 16406,
         38055,  9821],
        [ 3737,  1073,   391,   300,   466,  5711, 39814,   260,   770,  5991,
           313,  1962],
        [18006, 22217, 42722, 10863,   262,  7958,  1593, 12704,  5940,  

In [16]:
decoded[0:10] == memo_input[0]['input_ids'][:10]

tensor([[True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True]])

In [17]:
# Gram matrix of the embedding table: entry (i, j) is the dot product of rows i and j.
sims = embedding.weight @ embedding.weight.T
display(sims)
# Index 0 is skipped below: it is the padding index passed to MeMoEmbedding and is all zeros in the output.
diag_sum = torch.sum(sims[1:, 1: ].diag()) # squared norms of the rows; each is close to 1
print(diag_sum) # observed total, to compare against the number of non-padding rows
print(torch.sum(sims[1:, 1:]) - diag_sum) # sum of all off-diagonal dot products; small relative to the number of pairs

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  9.7935e-01,  1.2794e-02,  ...,  6.0318e-02,
          2.1276e-02,  2.9399e-02],
        [ 0.0000e+00,  1.2794e-02,  9.6680e-01,  ...,  3.8886e-02,
         -1.7371e-02,  5.5490e-02],
        ...,
        [ 0.0000e+00,  6.0318e-02,  3.8886e-02,  ...,  9.6188e-01,
          6.3878e-02, -6.1192e-04],
        [ 0.0000e+00,  2.1276e-02, -1.7371e-02,  ...,  6.3878e-02,
          9.8433e-01, -4.7447e-02],
        [ 0.0000e+00,  2.9399e-02,  5.5490e-02,  ..., -6.1192e-04,
         -4.7447e-02,  9.7612e-01]])

tensor(50250.5938)
tensor(-1252.6641)


## Test layer and MeMo CMM

In [18]:
from MeMoPyTorch.modelling_memo_layer import MeMoLayer, ProjectionSequence, ProjectionTokens, CorrelationMatrixMemory

### Check initialization of each matrix

In [19]:
# Shared dimensions: d is passed as inner_dim, h as num_of_heads and l as num_of_layers further down.
d,h,l = 1024, 4, 3
proj = ProjectionSequence(d, d*h)
print(proj.weight.shape, proj.extra_repr())
# Diagonals of W^T W and W W^T; per the output the first is close to 1 and the second close to 1/h.
(proj.weight.T @ proj.weight).diag(), (proj.weight @ proj.weight.T).diag()

torch.Size([4096, 1024]) (trasposed wrt saved one) in_features=4096, out_features=1024


(tensor([1.0144, 1.0094, 1.0184,  ..., 0.9896, 0.9906, 0.9925],
        grad_fn=<DiagonalBackward0_copy>),
 tensor([0.2668, 0.2575, 0.2495,  ..., 0.2568, 0.2474, 0.2405],
        grad_fn=<DiagonalBackward0_copy>))

In [20]:
# Hand-built reference: a (d, d*h) matrix with N(0, 1/sqrt(d*h)) entries, transposed to (d*h, d)
# so that it has the same shape as proj.weight printed in the previous cell.
Prj = torch.normal(0, 1/math.sqrt(d*h), size=(d,d*h))
Prj = torch.transpose(Prj, 0, 1)
# Same two Gram diagonals as computed for proj.weight, for side-by-side comparison.
(Prj.T @ Prj).diag(), (Prj @ Prj.T).diag()

(tensor([1.0223, 0.9927, 1.0043,  ..., 1.0023, 1.0144, 1.0078]),
 tensor([0.2474, 0.2642, 0.2357,  ..., 0.2390, 0.2545, 0.2311]))

In [21]:
print(d)
print(h)

# Per-head width: the embedding size split evenly across the h heads.
d_k = d // h

W_v = ProjectionTokens(d, d_k)

print(W_v.weight.shape, W_v.extra_repr())


# The weight is printed with shape (d_k, d); the author notes it is always applied transposed,
# so both Gram diagonals (W^T W and W W^T) are shown.
(W_v.weight.T @ W_v.weight).diag(), (W_v.weight @ W_v.weight.T).diag()

1024
4
torch.Size([256, 1024]) in_features=1024, out_features=256


(tensor([1.0034, 1.1309, 1.0402,  ..., 0.9823, 0.8685, 1.1207],
        grad_fn=<DiagonalBackward0_copy>),
 tensor([3.7838, 3.8727, 3.8892, 4.1692, 3.7304, 4.0257, 3.6414, 4.1350, 3.9648,
         3.9609, 3.6953, 3.9883, 4.0704, 3.9424, 4.0839, 3.9512, 4.2426, 3.9462,
         3.8508, 4.1929, 3.8313, 4.0518, 4.0457, 4.1839, 4.0315, 4.2982, 3.7961,
         4.0107, 4.0511, 4.3714, 3.7917, 4.0835, 3.8269, 3.9315, 4.2685, 3.9030,
         4.0416, 3.9057, 3.9971, 3.9861, 4.2417, 3.8572, 4.0085, 3.8302, 3.7265,
         4.2387, 3.8386, 3.8492, 3.9907, 4.0602, 4.0078, 3.7434, 4.0288, 4.1997,
         4.1648, 3.9174, 4.1524, 3.6851, 4.0384, 3.9745, 3.9933, 3.9400, 4.2912,
         3.9150, 3.8510, 4.3452, 4.2322, 3.9411, 4.1291, 4.3399, 4.3190, 3.7528,
         4.0245, 3.9443, 3.6496, 4.0652, 3.7083, 4.1614, 3.7873, 3.9109, 4.1132,
         3.5621, 4.1639, 4.1976, 3.6749, 4.1739, 4.1141, 3.7559, 4.3099, 4.0219,
         4.2101, 4.0856, 3.9498, 4.1303, 3.9624, 4.0415, 3.9125, 4.4438, 3.7975,
  

In [22]:
# Hand-built reference: a (d, d_k) matrix with N(0, 1/sqrt(d_k)) entries, to compare with W_v above.
W_v_single_head = torch.normal(0, 1/math.sqrt(d_k), size=(d,d_k))

# Gram diagonals in both orders, mirroring the check done on W_v.weight.
(W_v_single_head.T @ W_v_single_head).diag(), (W_v_single_head @ W_v_single_head.T).diag()

(tensor([3.8580, 3.8931, 4.2426, 3.8728, 4.1576, 4.0706, 3.9007, 3.9701, 3.8986,
         3.7712, 3.9201, 4.0676, 3.6894, 3.9858, 3.9228, 4.1466, 3.7117, 3.7501,
         4.0615, 3.8578, 4.2153, 3.8464, 3.6133, 3.8736, 3.8603, 3.9971, 3.8789,
         4.1942, 4.1517, 3.8603, 3.9019, 3.4583, 3.5920, 3.8812, 3.9152, 3.9505,
         4.1654, 4.3137, 4.2492, 3.8551, 4.0042, 4.0932, 4.0694, 3.7652, 4.2155,
         4.0445, 4.1697, 3.9420, 3.9184, 4.0150, 4.0032, 4.0153, 3.7572, 3.9870,
         3.9958, 3.5553, 3.8700, 4.0539, 4.0136, 3.9899, 3.5696, 4.1975, 4.1856,
         3.6233, 3.7882, 3.9867, 3.8463, 3.9743, 3.9500, 4.0963, 4.0645, 4.3165,
         4.0246, 4.0406, 4.0985, 3.7918, 3.7860, 3.9474, 4.4527, 4.1771, 4.3035,
         4.0906, 4.1758, 3.8684, 3.7808, 4.1073, 4.1180, 3.7858, 4.5228, 3.8587,
         4.0921, 3.8875, 4.1034, 4.1227, 3.9766, 3.7816, 3.7624, 4.1224, 4.0637,
         3.8097, 3.7841, 3.7562, 3.9015, 3.9028, 3.8789, 3.9445, 4.1326, 4.0425,
         3.9117, 4.0736, 4.1

### Check memorization on single layer

In [23]:
d,h,l = 1024, 4, 3

layer = MeMoLayer(d, h)
layer

MeMoLayer(
  (W_v_single_head): ProjectionTokens(in_features=1024, out_features=256)
  (Prj): ProjectionSequence((trasposed wrt saved one) in_features=4096, out_features=1024)
  (CMM): CorrelationMatrixMemory(in_features=1024, out_features=1024)
)

In [24]:
batch_size, current_length, d = input_embeddings.shape
batch_size, current_length, d 

(52, 12, 1024)

In [25]:
output_symbols.shape

torch.Size([52, 12, 1024])

In [26]:
# Number of h-token blocks per sequence (integer division avoids the float round trip).
current_length = input_embeddings.shape[1] // h

# View every sequence as current_length blocks of h token embeddings.
input_sequence = input_embeddings.reshape((batch_size, current_length, h, d))

# Target for each block: the label embedding at the block's last position.
current_output_symbols = output_symbols[:, [(x + 1) * h - 1 for x in range(current_length)]]

# Sanity check: block j of the reshaped tensor must equal the matching h-wide slice of the
# flat sequence. The slice is derived from h rather than a literal 4 so the check stays valid
# if the number of heads changes.
j = 2
print((input_sequence[0][j] == input_embeddings[0][h * j:h * (j + 1)]).sum(), input_sequence[0][j].shape)

(batch_size, blocks, h, d) = input_sequence.shape

tensor(4096) torch.Size([4, 1024])


In [27]:
input_sequence.shape, current_output_symbols.shape

(torch.Size([52, 3, 4, 1024]), torch.Size([52, 3, 1024]))

In [28]:
# New layer instance, so this memorization does not build on the layer created earlier.
layer = MeMoLayer(d, h)
display(layer)
# memorize returns a pair; only the second element (the sequence encoding) is kept and then
# stored through directly_memorize.
_, seq_encoding_for_the_last_layer = layer.memorize(input_sequence, current_output_symbols, is_last=False)
layer.directly_memorize(seq_encoding_for_the_last_layer)

MeMoLayer(
  (W_v_single_head): ProjectionTokens(in_features=1024, out_features=256)
  (Prj): ProjectionSequence((trasposed wrt saved one) in_features=4096, out_features=1024)
  (CMM): CorrelationMatrixMemory(in_features=1024, out_features=1024)
)

In [29]:
input_sequence.shape, input_sequence[3].shape # batch (52 elements of chunks 4*4*1024)

(torch.Size([52, 3, 4, 1024]), torch.Size([3, 4, 1024]))

In [30]:
_, seq_encoding_for_the_last_layer = layer.retrieve(input_sequence)

print(seq_encoding_for_the_last_layer.shape)

torch.Size([52, 1024])


In [31]:
logits = layer.directly_retrieve(seq_encoding_for_the_last_layer)

In [32]:
retreived_output_symbol_vector, m = embedding.decode(logits)
print(retreived_output_symbol_vector, m)

tensor([48505, 20110,  1108, 48019,    80, 19216,  9718,  1113,  4927,    66,
        19216,  1448,  2122, 41530,   187,  4172,   246,   659, 10986, 30975,
           80, 12931,   352, 14134,  2721,   258,  8830,    87,   826,    15,
        17532,   729, 26798, 41070,  6575,   299,   266,  3737, 20889,   287,
          512,   354,   250,   247,    70,   275, 16128,  2680,    74, 13679,
           15,   209]) tensor([0.8087, 0.9320, 0.9849, 1.0984, 1.2244, 1.3176, 1.0163, 1.0889, 0.9557,
        1.8996, 1.2476, 0.9599, 0.9788, 1.0013, 0.9819, 0.9918, 0.9456, 0.9295,
        0.9705, 1.0572, 1.1345, 0.9096, 0.9802, 0.9546, 1.4618, 0.9903, 1.0166,
        1.0116, 1.1751, 1.0554, 0.8882, 1.1188, 1.0335, 0.9970, 1.0631, 0.9999,
        0.9998, 0.9425, 0.9801, 0.9525, 0.9181, 0.9766, 1.1350, 1.0746, 1.0994,
        0.8784, 1.0130, 1.0656, 1.0741, 1.1679, 0.9479, 1.0618],
       grad_fn=<MaxBackward0>)


In [33]:
o = embedding.decode(current_output_symbols[:, -1])[0]
display(o)

print(sum(o == retreived_output_symbol_vector), 'over', retreived_output_symbol_vector.shape)

tensor([48505, 20110,  1108, 48019,    80, 19216,  9718,  1113,  4927,    66,
        19216,  1448,  2122, 41530,   187,  4172,   246,   659, 10986, 30975,
           80, 12931,   352, 14134,  2721,   258,  8830,    87,   826,    15,
        17532,   729, 26798, 41070,  6575,   299,   266,  3737, 20889,   287,
          512,   354,   250,   247,    70,   275, 16128,  2680,    74, 13679,
           15,   209])

tensor(52) over torch.Size([52])


In [34]:
#### test single block

In [35]:
print(input_sequence[3].shape, current_output_symbols[3].shape)
print(input_sequence[3][0].shape, current_output_symbols[3][0].shape)

torch.Size([3, 4, 1024]) torch.Size([3, 1024])
torch.Size([4, 1024]) torch.Size([1024])


In [36]:
total = 0
correct = 0 

# Replay every (block, target) pair through the layer on its own and count exact id matches.
for batch_index in range(len(memo_input[0]['input_ids'])):
    targets_for_row = current_output_symbols[batch_index]

    for i in range(len(targets_for_row)):
        true = embedding.decode(targets_for_row[i])[0].item()

        # A single block is (h, d); two leading singleton axes are added before retrieval.
        single_block = input_sequence[batch_index][i].unsqueeze(0).unsqueeze(0)
        _, seq_encoding_for_the_last_layer = layer.retrieve(single_block)

        logits_for_block = layer.directly_retrieve(seq_encoding_for_the_last_layer)
        retreived_output_symbol_vector, m = embedding.decode(logits_for_block)
        pred = retreived_output_symbol_vector.item()

        total += 1
        if pred == true:
            correct += 1

print(f"{correct}/{total}")

153/156


### Check with batch size of 1 and output probs

In [37]:
# One short sentence gives a batch of size 1; key 0 selects the first entry of the returned dict.
memo_input = tokenizer.get_text_batch_encoding(['this is a test for a very short short sequence of 12 tokens'])[0]
input_ids, labels = memo_input['input_ids'], memo_input['labels']
print(input_ids, labels)

# Embed the inputs and the labels with the same embedding table.
input_embeddings = embedding.encode(input_ids)

output_embeddings = embedding.encode(labels)


# Number of h-token blocks in one max_length window.
current_length = max_length

current_length = int(current_length/h)
input_sequence = input_embeddings.reshape((1, current_length, h, d))

# Keep the label embedding at the last position of every block.
# NOTE(review): the original author asked here whether the output symbol is always the same token - still unresolved.
output_symbols = output_embeddings[:, [(x+1)*h-1 for x in range(0,current_length)]]
print(embedding.decode(output_symbols))
input_sequence.shape, output_symbols.shape

tensor([[2520,  310,  247, 1071,  323,  247, 1077, 2159, 2159, 3425,  273, 1249]]) tensor([[  310,   247,  1071,   323,   247,  1077,  2159,  2159,  3425,   273,
          1249, 21761]])
(tensor([[  323,  2159, 21761]]), tensor([[1.0219, 0.9891, 1.0037]]))


(torch.Size([1, 3, 4, 1024]), torch.Size([1, 3, 1024]))

In [38]:
embedding.decode(input_sequence)[0], embedding.decode(output_symbols)[0], input_ids

(tensor([[[2520,  310,  247, 1071],
          [ 323,  247, 1077, 2159],
          [2159, 3425,  273, 1249]]]),
 tensor([[  323,  2159, 21761]]),
 tensor([[2520,  310,  247, 1071,  323,  247, 1077, 2159, 2159, 3425,  273, 1249]]))

In [39]:
input_sequence.shape

torch.Size([1, 3, 4, 1024])

In [40]:
output_symbols.shape

torch.Size([1, 3, 1024])

In [41]:
# New layer instance: memorize the blocks of the single sentence, then query them one at a time.
layer = MeMoLayer(d, h)
display(layer)

# memorize returns a pair; the second element is stored through directly_memorize.
_, seq_encoding_for_the_last_layer = layer.memorize(input_sequence, output_symbols, is_last=False)
layer.directly_memorize(seq_encoding_for_the_last_layer)

for i in range(0,3):
    # Indexing drops the batch and block axes; the two unsqueeze calls add them back.
    _, seq_encoding_for_the_last_layer = layer.retrieve(input_sequence[0][i].unsqueeze(0).unsqueeze(0))
    print(seq_encoding_for_the_last_layer.shape)

    # decode returns the token id together with a second value that is printed next to it.
    retreived_output_symbol_vector, m = embedding.decode(layer.directly_retrieve(seq_encoding_for_the_last_layer))
    print(retreived_output_symbol_vector, m)

MeMoLayer(
  (W_v_single_head): ProjectionTokens(in_features=1024, out_features=256)
  (Prj): ProjectionSequence((trasposed wrt saved one) in_features=4096, out_features=1024)
  (CMM): CorrelationMatrixMemory(in_features=1024, out_features=1024)
)

torch.Size([1, 1024])
tensor([323]) tensor([1.0180], grad_fn=<MaxBackward0>)
torch.Size([1, 1024])
tensor([2159]) tensor([1.0677], grad_fn=<MaxBackward0>)
torch.Size([1, 1024])
tensor([21761]) tensor([1.1110], grad_fn=<MaxBackward0>)


## Test the entire MeMo model

In [42]:
from MeMoPyTorch.modelling_memo import MeMo

In [43]:
from MeMoPyTorch.modelling_memo_tokenizer import MeMoTokenizer

In [44]:
with open("testo_di_prova.txt") as my_first_text_f:
    my_first_text = my_first_text_f.read()

In [45]:
# Rebuild the tokenizer with a 384-token window. This overwrites the max_length = 12 value and
# the tokenizer used in the earlier sections, so those cells must not be re-run after this one
# without resetting both.
max_length = 384
print(max_length, h)
tokenizer = MeMoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", 
                                          padding_side='left', truncation_side='left', 
                                          max_length=max_length, head_number=h)

384 4


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPTNeoXTokenizer'. 
The class this function is called from is 'MeMoTokenizer'.


Setting pad token and pad token id = <|endoftext|>, 0


In [46]:
memo_input = tokenizer.memo_heads_encode(my_first_text[0:10])
memo_input[0]['labels'].shape

torch.Size([1, 384])

In [47]:
device='cuda:0' #'cpu'

In [48]:
d,h,l

(1024, 4, 3)

In [49]:
# Build the full model from the shared notebook dimensions (d, h, l). The layer count comes
# from l instead of a literal 3 so this model cannot drift from the ones built in later cells.
model = MeMo(inner_dim=d, 
             num_of_heads=h, 
             num_of_layers=l, 
             chunk_length=max_length, 
             num_embeddings=tokenizer.vocab_size, 
             padding_idx=tokenizer.pad_token_id, 
             device=device) #MeMoModel
model

MeMo embedding initilialization


MeMo(
  (encoder): MeMoEmbedding(50254, 1024, padding_idx=0)
  (layers): ModuleList(
    (0-2): 3 x MeMoLayer(
      (W_v_single_head): ProjectionTokens(in_features=1024, out_features=256)
      (Prj): ProjectionSequence((trasposed wrt saved one) in_features=4096, out_features=1024)
      (CMM): CorrelationMatrixMemory(in_features=1024, out_features=1024)
    )
  )
)

In [50]:
import tqdm


class Evaluation:
    """Next-token recall checks for a MeMo model.

    Both entry points slide a fixed-size context window over token ids, ask
    ``model.retrieve`` for the token that follows each window and return the
    fraction of positions where the retrieved id equals the actual next id.
    """

    @staticmethod
    def _resolve_basic_block(model, starting_point):
        """Return the context-window size: ``starting_point`` if given, else ``model.h ** model.l``."""
        if starting_point is None:
            return model.h ** model.l
        return starting_point

    def check_memorization(self, model, tokenizer, text, # device='cpu',
                           starting_point=None):
        """Tokenize ``text`` and measure how well ``model`` recalls it.

        Args:
            model: object exposing ``h``, ``l`` and ``retrieve(token_ids)``.
            tokenizer: callable tokenizer that also provides ``pad`` and ``max_length``.
            text: the text to evaluate; it is tokenized here.
            starting_point: context-window size; defaults to ``model.h ** model.l``.

        Returns:
            The value returned by ``check_pretokenized`` on the padded token ids.
        """
        basic_block = self._resolve_basic_block(model, starting_point)

        # Tokenize the text that was passed in, not a notebook-level global.
        input_ = tokenizer(text, padding='longest', truncation='do_not_truncate', max_length=None)
        input_ = tokenizer.pad(input_, pad_to_multiple_of=basic_block)

        # The sliding-window scoring is shared with the pre-tokenized entry point.
        return self.check_pretokenized(model, tokenizer, input_['input_ids'],
                                       starting_point=basic_block)

    def check_pretokenized(self, model, tokenizer, input_ids,# device='cpu',
                           starting_point=None):
        """Measure next-token recall over already tokenized input.

        Args:
            model: object exposing ``h``, ``l`` and ``retrieve(token_ids)``, which
                returns a pair whose first element holds the retrieved token ids.
            tokenizer: only ``tokenizer.max_length`` is read.
            input_ids: 2-D tensor ``(batch_size, number_of_tokens)`` of token ids.
            starting_point: context-window size; defaults to ``model.h ** model.l``.

        Returns:
            ``correct / count`` as a tensor: the share of (row, position) pairs
            whose retrieved id matches the actual next token.

        Raises:
            ZeroDivisionError: when there is no position to evaluate, i.e.
                ``number_of_tokens - 1 <= basic_block``.
        """
        basic_block = self._resolve_basic_block(model, starting_point)

        count = 0
        correct = 0
        max_length = tokenizer.max_length
        (batch_size, number_of_tokens) = input_ids.shape

        # The final position is excluded by the range, as in the original loop.
        for i in tqdm.tqdm(range(basic_block, number_of_tokens - 1)):
            text_tokens = input_ids[:, i - basic_block:i]
            window_length = text_tokens.shape[1]

            # Left-pad with id 0 up to max_length - 1 tokens; the padding is created on the
            # same device as the input so the concatenation also works off-CPU.
            left_padding = torch.zeros((batch_size, max_length - 1 - window_length),
                                       dtype=torch.int, device=input_ids.device)
            text_tokens = torch.concat((left_padding, text_tokens), dim=1)

            out, max_value = model.retrieve(text_tokens)

            count += batch_size
            # Compare on the input's device (the model may return its result elsewhere).
            correct += torch.sum(out.to(input_ids.device) == input_ids[:, i])

        return correct / count
        

In [51]:
# New model, so the score reflects only what is memorized in this cell.
model = MeMo(inner_dim=d, 
             num_of_heads=h, 
             num_of_layers=l, 
             chunk_length=max_length, 
             num_embeddings=tokenizer.vocab_size, 
             padding_idx=tokenizer.pad_token_id, 
             device=device)

# Memorize the same text repeated 8 times, submitted as a single batch.
memo_input = tokenizer.get_text_batch_encoding([my_first_text]*8)
model.memorize_text(memo_input)

e = Evaluation()
out = e.check_pretokenized(model, tokenizer, memo_input[0]['input_ids'])
# Format the score into the message (a separate print argument leaves "%f" unformatted).
print(f"Degree of memorization: {float(out):f}")

MeMo embedding initilialization


100%|██████████| 319/319 [00:00<00:00, 751.66it/s]

Degree of memorization: %f  tensor(0.9310)





In [52]:
# New model; its correlation matrix memory is displayed before anything is memorized.
model = MeMo(inner_dim=d, 
             num_of_heads=h, 
             num_of_layers=l, 
             chunk_length=max_length, 
             num_embeddings=tokenizer.vocab_size, 
             padding_idx=tokenizer.pad_token_id, 
             device=device)
print("CMM pre learning")
display(model.layers[0].CMM.weight)

memo_input = tokenizer.get_text_batch_encoding([my_first_text]*8)
(bs, ml) = memo_input[0]['input_ids'].shape

# Memorize the same batch one row at a time instead of in a single call.
for b in range(bs):
    # Keep the nested {head: {key: tensor}} layout but cut every tensor down to row b with a
    # batch axis of 1. The loop names avoid reusing h, which holds the number of heads.
    memo_single_input = {
        head: {key: memo_input[head][key][b].unsqueeze(0) for key in memo_input[head]}
        for head in memo_input
    }

    model.memorize_text(memo_single_input)

e = Evaluation()
out = e.check_pretokenized(model, tokenizer, memo_input[0]['input_ids'])
# Format the score into the message (a separate print argument leaves "%f" unformatted).
print(f"Degree of memorization: {float(out):f}")

MeMo embedding initilialization
CMM pre learning


Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', requires_grad=True)

100%|██████████| 319/319 [00:00<00:00, 770.52it/s]

Degree of memorization: %f  tensor(0.8574)





In [53]:
Prj = model.layers[0].Prj.weight.detach().cpu()
CMM = model.layers[0].CMM.weight.detach().cpu()

Prj.T @ Prj

tensor([[ 0.9837,  0.0131,  0.0149,  ...,  0.0116, -0.0091,  0.0254],
        [ 0.0131,  1.0073,  0.0030,  ..., -0.0221, -0.0078, -0.0083],
        [ 0.0149,  0.0030,  1.0084,  ...,  0.0109,  0.0149, -0.0034],
        ...,
        [ 0.0116, -0.0221,  0.0109,  ...,  0.9977, -0.0294, -0.0096],
        [-0.0091, -0.0078,  0.0149,  ..., -0.0294,  1.0146,  0.0030],
        [ 0.0254, -0.0083, -0.0034,  ..., -0.0096,  0.0030,  1.0199]])

In [54]:
CMM

tensor([[ 4.8936e-02, -1.9148e-02,  9.6987e-03,  ...,  3.7883e-02,
         -1.9608e-02, -1.3408e-02],
        [-2.8734e-02,  2.4769e-02,  6.7619e-03,  ...,  5.3196e-02,
         -6.7219e-03, -2.7895e-02],
        [-4.6326e-02, -5.1699e-02,  1.4755e-02,  ..., -5.4009e-02,
          8.8325e-03,  3.6818e-02],
        ...,
        [ 1.3820e-02,  2.2205e-02,  2.1027e-02,  ...,  2.3433e-02,
          3.4943e-02, -1.5351e-02],
        [ 1.2819e-02,  1.3980e-02,  2.4138e-02,  ...,  1.4555e-02,
          9.8114e-05,  1.5104e-02],
        [ 1.2497e-02,  2.0969e-02,  1.2242e-02,  ...,  1.8295e-02,
         -1.6273e-02, -1.4039e-02]])

In [55]:
# Re-run the evaluation on the unchanged model as the baseline for the forgetting step.
out = e.check_pretokenized(model, tokenizer, memo_input[0]['input_ids'])
# Format the score into the message (a separate print argument leaves "%f" unformatted).
print(f"Degree of memorization: {float(out):f}")

100%|██████████| 319/319 [00:00<00:00, 763.36it/s]

Degree of memorization: %f  tensor(0.8574)





In [56]:
model.forget_text(memo_input)

In [57]:
# Same evaluation after model.forget_text, to see how much recall is left.
out = e.check_pretokenized(model, tokenizer, memo_input[0]['input_ids'])
# Format the score into the message (a separate print argument leaves "%f" unformatted).
print(f"Degree of memorization: {float(out):f}")

100%|██████████| 319/319 [00:00<00:00, 742.21it/s]

Degree of memorization: %f  tensor(0.0909)



