In [None]:
import math

import torch
from torch.nn.functional import normalize

import numpy as np
from sklearn.metrics import accuracy_score
from pandas import DataFrame
import pandas as pd

def _retrieval_accuracy(keys, memory, out_symbol_embed, targets):
    """Return the fraction of keys whose retrieved symbol matches its target.

    Args:
        keys: tensor of shape (embed_size, samples), one key per column.
        memory: key-value matrix of shape (embed_size, embed_out_size).
        out_symbol_embed: output embedding table of shape (vocab_size, embed_out_size).
        targets: 1-D integer numpy array with the expected symbol id per sample.

    Returns:
        float in [0, 1]: share of samples whose highest-scoring vocabulary
        entry equals the target id.
    """
    # samples x vocab_size: score of every vocabulary symbol for every key
    scores = torch.matmul(torch.matmul(torch.transpose(keys, 0, 1), memory),
                          torch.transpose(out_symbol_embed, 0, 1))
    predicted = torch.argmax(scores, dim=1).cpu().detach().numpy()
    return float(np.mean(predicted == targets))


def memorization_experiement(mem_rate = 0.0001, num_of_heads=4,embed_size=2048,embed_out_size=2048,vocab_size=2000,batch_size=80,num_of_iterations=1000, different_tokens=True,equal_tokens=0):
    """Store random (sequence -> symbol) pairs in a key-value matrix until recall degrades.

    Each step draws `batch_size` random token sequences of length `num_of_heads`
    and one random output symbol per sequence, adds `mem_rate * keys @ values`
    to the matrix, and measures retrieval accuracy on the current batch and on
    the very first batch. The loop stops as soon as the mean of those two
    accuracies is no longer above 0.9.

    Args:
        mem_rate: scale applied to each batch's outer-product update.
        num_of_heads: number of tokens per input sequence; the key is the
            concatenation of the per-token embeddings.
        embed_size: key dimension; must be divisible by `num_of_heads`.
        embed_out_size: dimension of the output-symbol embeddings.
        vocab_size: number of distinct token / symbol ids.
        batch_size: number of pairs stored per step; must be positive.
        num_of_iterations: unused (the loop is accuracy-driven); kept so
            existing callers that pass it keep working.
        different_tokens: if True every position of every sequence is random.
            If False all sequences in a batch share their last `equal_tokens`
            tokens and only the first `num_of_heads - equal_tokens` are random.
        equal_tokens: number of trailing shared tokens when
            `different_tokens` is False.

    Returns:
        dict with keys "stored_pairs", "num_of_heads", "embed_size",
        "embed_OUT_size", "vocab_size" and "batch_size". "stored_pairs" counts
        every pair written, including the batch that pushed accuracy to 0.9
        or below.

    Raises:
        ValueError: if `embed_size` is not a positive multiple of
            `num_of_heads`, or `batch_size` is not positive.
    """
    if num_of_heads <= 0 or embed_size % num_of_heads != 0:
        raise ValueError(
            f"embed_size ({embed_size}) must be divisible by num_of_heads ({num_of_heads})")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Per-token embedding width; num_of_heads of these are concatenated into one key
    embed_seq_size = embed_size // num_of_heads

    # NOTE(review): the torch draws below are not seeded (only numpy is), so
    # the embeddings differ between runs.
    seq_symbol_embed = torch.normal(0, 1/math.sqrt(embed_seq_size), size=(vocab_size,embed_seq_size))

    # Output embeddings have width embed_out_size but their std is derived from
    # embed_size. NOTE(review): confirm this scaling is intended.
    out_symbol_embed = torch.normal(0, 1/math.sqrt(embed_size), size=(vocab_size,embed_out_size))

    # Key-value matrix, starts empty
    W0 = torch.zeros(size=(embed_size,embed_out_size), requires_grad=False)

    np.random.seed(22)
    stored_pairs = 0
    avg_a = 1.0
    step = 0
    while avg_a > 0.9:
        samples = batch_size
        if different_tokens:
            input_seq = np.random.randint(0, vocab_size, size=(samples, num_of_heads))
        else:
            # One shared sequence repeated for the whole batch. Built from the
            # integer randint output so the array stays a valid index dtype
            # (a float array cannot index a torch tensor).
            shared_sequence = np.random.randint(0, vocab_size, size=(1,num_of_heads))
            input_seq = np.repeat(shared_sequence, samples, axis=0)
            varying = num_of_heads - equal_tokens
            # Overwrite the leading positions with per-sample random tokens
            input_seq[:, 0:varying] = np.random.randint(0, vocab_size, size=(samples, varying))

        output_symb = np.random.randint(0,vocab_size, size=(samples))

        # embed_size x samples: concatenate the token embeddings of each sequence
        token_index = torch.as_tensor(input_seq, dtype=torch.long)
        keys = torch.transpose(seq_symbol_embed[token_index].reshape(samples, embed_size), 0, 1)
        # samples x embed_out_size
        values = out_symbol_embed[torch.as_tensor(output_symb, dtype=torch.long)]

        W0 = W0 + mem_rate*torch.matmul(keys,values)

        # Rescale when the first entry grows very large (normalize with its
        # default arguments: L2 norm along dim 1)
        if W0[0][0] > 10000.0:
            W0 = normalize(W0)
            print("normalized")

        a = _retrieval_accuracy(keys, W0, out_symbol_embed, output_symb)
        stored_pairs += samples
        if stored_pairs%(10*samples) == 0: print(f".{stored_pairs}", end="")
        if step == 0:
            # Remember the first batch to track how well old pairs survive
            output_symb_0 = output_symb
            keys_0 = keys
            avg_a = a
        else:
            a_0 = _retrieval_accuracy(keys_0, W0, out_symbol_embed, output_symb_0)
            avg_a = (a + a_0)/2

        step += 1
    return {"stored_pairs":stored_pairs, "num_of_heads":num_of_heads,"embed_size":embed_size,"embed_OUT_size":embed_out_size,"vocab_size":vocab_size,
            "batch_size":batch_size}



In [None]:
# Sweep over per-head embedding width, output width and number of heads,
# recording one summary dict per configuration.
c_vocab_size = 100000

# Collect plain dicts and build the DataFrame once at the end; concatenating
# inside the loop copies the frame every iteration and leaves every row with
# index 0.
run_summaries = []
for c_in_h_embed_size in [512//32,1024//32,2048//32,4096//32,8192//32]:
    for c_out_size in [512,1024,2048,4096,8192]:
        for c_head in [2,4,8,16,32]:
            # Total key width = per-head width * number of heads
            c_embed_size = c_in_h_embed_size*c_head
            print(c_embed_size)
            out = memorization_experiement(mem_rate = 1.0, num_of_iterations=100, vocab_size=c_vocab_size,
                                           embed_size=c_embed_size, embed_out_size=c_out_size,
                                           num_of_heads=c_head, batch_size=1000)
            run_summaries.append(out)
            print(out)

results = DataFrame(run_summaries)
results.to_excel("MemorizationResults.xlsx")

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Stored pairs vs. number of parameters of the key-value matrix, with a
# least-squares linear fit. Reads everything from the saved results file.
df = pd.read_excel("MemorizationResults.xlsx")
size = 100000
# Parameter count of the key-value matrix: embed_size x embed_OUT_size
df["num_of_params"] = df["embed_size"]*df["embed_OUT_size"]

# Use the same vocabulary-size selection for the scatter and for the fit,
# sorted so the fitted line is drawn left to right.
selected = (
    df.loc[df["vocab_size"] == size, ["num_of_params", "stored_pairs"]]
    .sort_values("num_of_params")
)
x = selected["num_of_params"]
y = selected["stored_pairs"]

# Explicit figure so this cell never draws on a figure left over from
# another cell.
fig, ax = plt.subplots()
ax.plot(x, y, marker="*", linewidth=0)

# Degree-1 polynomial fit of stored pairs against parameter count
z = np.polyfit(x, y, 1)
p = np.poly1d(z)
ax.plot(x, p(x), linestyle="solid", linewidth=0.5)

ax.set_ylabel("Sequences")
ax.set_xlabel("Parameters")
fig.savefig("CMM_storing_capacity_vs_NoParameters.png", dpi=300)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Stored pairs vs. number of heads, one curve per dimension d, restricted to
# square configurations (embed_size == embed_OUT_size == d).
df = pd.read_excel("MemorizationResults.xlsx")
size = 100000

# Explicit figure so this cell never draws on a figure left over from
# another cell.
fig, ax = plt.subplots()
for c_out_size in [512,1024,2048,4096,8192]:
    is_selected = (
        (df["vocab_size"] == size)
        & (df["embed_OUT_size"] == c_out_size)
        & (df["embed_OUT_size"] == df["embed_size"])
    )
    # Sort by x so the connecting line runs monotonically across the axis
    curve = df.loc[is_selected, ["num_of_heads", "stored_pairs"]].sort_values("num_of_heads")
    ax.plot(curve["num_of_heads"], curve["stored_pairs"],
            label=f"$d$ = {c_out_size}", marker="o", linewidth=1)

ax.set_ylabel("Sequences")
# The x values are head counts, not parameter counts
ax.set_xlabel("Number of heads")
ax.legend()
fig.savefig("CMM_storing_capacity_vs_dim.png", dpi=300)