In [2]:
import os

import torch as t
t.set_num_threads(8)
import pandas as pd
from tqdm import tqdm

from models import Transformer, AoT
from train import train
from utils import generate_data, power_unif_law

In [2]:
""" Experiment 1. Scaling laws on H with fixed d=d_head.

Sweeps `para` (the H of the title) over a fixed list of values, trains
`repetition` AoT models per value, and saves the mean accuracy per value to
'Scaling laws/Data_exp_1_7_dim.csv'.
"""
t.manual_seed(2222)

# Model parameters.
N = 200
nb_layers = 1
nb_head = 1
n_gram = 3
context_window = n_gram

# Distribution parameters.
alphas = [1, 1, 1]
nb_tokens = [100, 100, 1]
pi = power_unif_law(alphas, nb_tokens, N)

# Training parameters.
batch_size = 2**10
num_batch = 1000
lr = 1e-3
epochs = 10
repetition = 2
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)

# Scaling parameters
d = 50
d_head = d

# Create the output directory up front so that a missing folder fails (or is
# fixed) before the long training sweep rather than after it.
output_dir = 'Scaling laws'
os.makedirs(output_dir, exist_ok=True)

mean_accuracy = []
para_list = []
N_list = []
d_list = []
d_head_list = []

for para in tqdm([1, 3, 5, 7, 9, 11]):
    accuracy = 0

    for _ in range(repetition):
        model = AoT(d, N, nb_layers, para, d_head, nb_head, context_window, pi)

        # `history` (not `dict`) so the builtin is not shadowed for later cells.
        history = train(model, Data, epochs, lr=lr, next_token=True)
        # Average of the 100 entries of history['Acc'] that precede the last one.
        # NOTE(review): the slice excludes the final entry -- confirm this is intended.
        acc = sum(history['Acc'][-101:-1])/100

        accuracy += acc

    mean_accuracy.append(accuracy/repetition)
    N_list.append(N)
    d_list.append(d)
    d_head_list.append(d_head)
    para_list.append(para)

results = {
    'acc': mean_accuracy,
    'para': para_list,
    'N': N_list,
    'd': d_list,
    'd_head': d_head_list,
}

# We save the results as a dataframe.
# The '7' in the file name is the tag that the Experiment 5 cell pairs with d=50.
data = pd.DataFrame(results)
data.to_csv(f'{output_dir}/Data_exp_1_7_dim.csv', index=False)

100%|██████████| 10/10 [01:01<00:00,  6.14s/it]
100%|██████████| 10/10 [01:00<00:00,  6.09s/it]
100%|██████████| 10/10 [01:41<00:00, 10.12s/it]
100%|██████████| 10/10 [01:41<00:00, 10.15s/it]
100%|██████████| 10/10 [02:20<00:00, 14.03s/it]
100%|██████████| 10/10 [02:18<00:00, 13.85s/it]
100%|██████████| 10/10 [02:51<00:00, 17.19s/it]
100%|██████████| 10/10 [02:50<00:00, 17.09s/it]
100%|██████████| 10/10 [03:27<00:00, 20.74s/it]
100%|██████████| 10/10 [03:30<00:00, 21.00s/it]
100%|██████████| 10/10 [04:05<00:00, 24.53s/it]
100%|██████████| 10/10 [04:06<00:00, 24.68s/it]
100%|██████████| 6/6 [30:56<00:00, 309.44s/it]


In [3]:
""" Experiment 2. Scaling laws on d_head, with d!=d_head and H (=para) fixed.

Sweeps `d_head` over a fixed list of values with `d` and `para` held constant,
trains `repetition` AoT models per value, and saves the mean accuracy per
value to 'Scaling laws/Data_exp_2_dim.csv'.
"""
t.manual_seed(2222)

# Model parameters.
N = 200
d = 50
para = 8
nb_layers = 1
nb_head = 1
n_gram = 3
context_window = n_gram

# Distribution parameters.
alphas = [1, 1, 1]
nb_tokens = [100, 100, 1]
pi = power_unif_law(alphas, nb_tokens, N)

# Training parameters.
batch_size = 2**10
num_batch = 1000
lr = 1e-3
epochs = 10
repetition = 2
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)

# Create the output directory up front so that a missing folder fails (or is
# fixed) before the long training sweep rather than after it.
output_dir = 'Scaling laws'
os.makedirs(output_dir, exist_ok=True)

# Scaling parameters

mean_accuracy = []
para_list = []
N_list = []
d_list = []
d_head_list = []
for d_head in tqdm([1, 10, 20, 30, 40, 50]):
    accuracy = 0

    for _ in range(repetition):
        model = AoT(d, N, nb_layers, para, d_head, nb_head, context_window, pi)

        # `history` (not `dict`) so the builtin is not shadowed for later cells.
        history = train(model, Data, epochs, lr=lr, next_token=True)
        # Average of the 100 entries of history['Acc'] that precede the last one.
        # NOTE(review): the slice excludes the final entry -- confirm this is intended.
        acc = sum(history['Acc'][-101:-1])/100

        accuracy += acc

    mean_accuracy.append(accuracy/repetition)
    N_list.append(N)
    d_list.append(d)
    d_head_list.append(d_head)
    para_list.append(para)

results = {
    'acc': mean_accuracy,
    'para': para_list,
    'N': N_list,
    'd': d_list,
    'd_head': d_head_list,
}

# We save the results as a dataframe.
data = pd.DataFrame(results)
data.to_csv(f'{output_dir}/Data_exp_2_dim.csv', index=False)

100%|██████████| 10/10 [02:16<00:00, 13.66s/it]
100%|██████████| 10/10 [02:21<00:00, 14.16s/it]
100%|██████████| 10/10 [03:54<00:00, 23.48s/it]
100%|██████████| 10/10 [03:59<00:00, 23.97s/it]
100%|██████████| 10/10 [04:56<00:00, 29.68s/it]
100%|██████████| 10/10 [04:46<00:00, 28.70s/it]
100%|██████████| 10/10 [08:29<00:00, 50.93s/it]
100%|██████████| 10/10 [08:56<00:00, 53.64s/it]
100%|██████████| 10/10 [10:33<00:00, 63.31s/it]
100%|██████████| 10/10 [10:17<00:00, 61.71s/it]
100%|██████████| 10/10 [10:23<00:00, 62.31s/it]]
100%|██████████| 10/10 [11:41<00:00, 70.16s/it]
100%|██████████| 6/6 [1:22:38<00:00, 826.34s/it] 


In [4]:
""" Experiment 5. Scaling laws on the width of Transformer using MLPs.

This cell runs two sweeps on the same generated data:

1. For d in (40, 50, 60): sweep the `width` argument of `Transformer` with
   para=1 and d_head=d, saving to 'Scaling laws/Data_exp_5_{exp_num}_dim.csv'.
2. For d in (40, 60): sweep `para` of `AoT` with d_head=d (the Experiment 1
   protocol for the remaining dimensions), saving to
   'Scaling laws/Data_exp_1_{exp_num}_dim.csv'.

`exp_num` is the file-name tag paired with each d: 40 -> 4, 50 -> 7, 60 -> 10.
"""
t.manual_seed(3333)

# Model parameters.
N = 200
nb_layers = 1
nb_head = 1
n_gram = 3
context_window = n_gram

# Distribution parameters.
alphas = [1, 1, 1]
nb_tokens = [100, 100, 1]
pi = power_unif_law(alphas, nb_tokens, N)

# Training parameters.
batch_size = 2**10
num_batch = 1000
lr = 1e-3
epochs = 10
repetition = 2
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)

# Create the output directory up front so that a missing folder fails (or is
# fixed) before the long training sweeps rather than after them.
output_dir = 'Scaling laws'
os.makedirs(output_dir, exist_ok=True)

# --- Sweep 1: Transformer width, one CSV per embedding dimension d. ---
for d, exp_num in zip([40, 50, 60], [4, 7, 10]):
    para = 1
    d_head = d
    # Widths are 2*d*(p-1) for p = 1, 3, ..., 11, i.e. 0, 4d, 8d, ..., 20d.
    min_width = 2*d*(1-1)
    max_width = 2*d*(11-1)
    step = 2*d*2

    mean_accuracy = []
    para_list = []
    N_list = []
    d_list = []
    d_head_list = []
    width_list = []
    for width in tqdm(range(min_width, max_width+1, step)):
        accuracy = 0

        for _ in range(repetition):
            model = Transformer(d, N, nb_layers, width, para, d_head, nb_head, context_window, pi)

            # `history` (not `dict`) so the builtin is not shadowed for later cells.
            history = train(model, Data, epochs, lr=lr, next_token=True)
            # Average of the 100 entries of history['Acc'] that precede the last one.
            # NOTE(review): the slice excludes the final entry -- confirm this is intended.
            acc = sum(history['Acc'][-101:-1])/100

            accuracy += acc
            # Prints the running sum over repetitions, not the mean.
            print(accuracy)

        mean_accuracy.append(accuracy/repetition)
        N_list.append(N)
        d_list.append(d)
        d_head_list.append(d_head)
        para_list.append(para)
        width_list.append(width)

    results = {
        'acc': mean_accuracy,
        'para': para_list,
        'N': N_list,
        'd': d_list,
        'd_head': d_head_list,
        'width': width_list,
    }

    # We save the results as a dataframe.
    data = pd.DataFrame(results)
    data.to_csv(f'{output_dir}/Data_exp_5_{exp_num}_dim.csv', index=False)


# --- Sweep 2: AoT `para`, for the dimensions not covered by Experiment 1. ---
for d, exp_num in zip([40, 60], [4, 10]):
    d_head = d
    min_para = 1
    max_para = 11
    step = 2

    mean_accuracy = []
    para_list = []
    N_list = []
    d_list = []
    d_head_list = []
    for para in tqdm(range(min_para, max_para+1, step)):
        accuracy = 0

        for _ in range(repetition):
            model = AoT(d, N, nb_layers, para, d_head, nb_head, context_window, pi)

            history = train(model, Data, epochs, lr=lr, next_token=True)
            # Same accuracy estimate as in the width sweep above.
            acc = sum(history['Acc'][-101:-1])/100

            accuracy += acc
            # Prints the running sum over repetitions, not the mean.
            print(accuracy)

        mean_accuracy.append(accuracy/repetition)
        N_list.append(N)
        d_list.append(d)
        d_head_list.append(d_head)
        para_list.append(para)

    # No 'width' column here: AoT takes no width argument, and the only
    # `width` in scope would be the stale loop variable left over from the
    # Transformer sweep above.
    results = {
        'acc': mean_accuracy,
        'para': para_list,
        'N': N_list,
        'd': d_list,
        'd_head': d_head_list,
    }

    # We save the results as a dataframe.
    data = pd.DataFrame(results)
    data.to_csv(f'{output_dir}/Data_exp_1_{exp_num}_dim.csv', index=False)

100%|██████████| 10/10 [01:21<00:00,  8.15s/it]


0.085


100%|██████████| 10/10 [01:26<00:00,  8.64s/it]
 17%|█▋        | 1/6 [02:47<13:59, 168.00s/it]

0.14205078125


100%|██████████| 10/10 [01:51<00:00, 11.16s/it]


0.603525390625


100%|██████████| 10/10 [01:53<00:00, 11.38s/it]
 33%|███▎      | 2/6 [06:33<13:27, 201.78s/it]

1.134228515625


100%|██████████| 10/10 [02:14<00:00, 13.47s/it]


0.945595703125


100%|██████████| 10/10 [02:16<00:00, 13.65s/it]
 50%|█████     | 3/6 [11:04<11:40, 233.51s/it]

1.9330175781249999


100%|██████████| 10/10 [02:38<00:00, 15.85s/it]


0.99693359375


100%|██████████| 10/10 [02:40<00:00, 16.02s/it]
 67%|██████▋   | 4/6 [16:23<08:54, 267.15s/it]

1.9940527343749999


100%|██████████| 10/10 [02:55<00:00, 17.57s/it]


1.0


100%|██████████| 10/10 [02:57<00:00, 17.76s/it]
 83%|████████▎ | 5/6 [22:16<04:58, 298.21s/it]

2.0


100%|██████████| 10/10 [03:19<00:00, 19.92s/it]


1.0


100%|██████████| 10/10 [03:25<00:00, 20.51s/it]
100%|██████████| 6/6 [29:00<00:00, 290.16s/it]


2.0


100%|██████████| 10/10 [01:53<00:00, 11.39s/it]


0.070703125


100%|██████████| 10/10 [01:46<00:00, 10.63s/it]
 17%|█▋        | 1/6 [03:40<18:21, 220.32s/it]

0.155166015625


100%|██████████| 10/10 [02:26<00:00, 14.60s/it]


0.855869140625


100%|██████████| 10/10 [02:27<00:00, 14.72s/it]
 33%|███▎      | 2/6 [08:33<17:32, 263.21s/it]

1.7950292968750001


100%|██████████| 10/10 [02:18<00:00, 13.80s/it]


0.999619140625


100%|██████████| 10/10 [02:15<00:00, 13.58s/it]
 50%|█████     | 3/6 [13:07<13:24, 268.09s/it]

1.9996191406250001


100%|██████████| 10/10 [02:50<00:00, 17.06s/it]


1.0


100%|██████████| 10/10 [02:51<00:00, 17.14s/it]
 67%|██████▋   | 4/6 [18:49<09:54, 297.31s/it]

2.0


100%|██████████| 10/10 [03:16<00:00, 19.65s/it]


1.0


100%|██████████| 10/10 [03:26<00:00, 20.66s/it]
 83%|████████▎ | 5/6 [25:32<05:35, 335.50s/it]

2.0


100%|██████████| 10/10 [03:54<00:00, 23.50s/it]


1.0


100%|██████████| 10/10 [03:59<00:00, 23.92s/it]
100%|██████████| 6/6 [33:26<00:00, 334.48s/it]


2.0


100%|██████████| 10/10 [02:00<00:00, 12.01s/it]


0.072275390625


100%|██████████| 10/10 [02:04<00:00, 12.41s/it]
 17%|█▋        | 1/6 [04:04<20:21, 244.20s/it]

0.153017578125


100%|██████████| 10/10 [02:42<00:00, 16.30s/it]


0.99376953125


100%|██████████| 10/10 [02:40<00:00, 16.04s/it]
 33%|███▎      | 2/6 [09:27<19:23, 290.80s/it]

1.989619140625


100%|██████████| 10/10 [02:59<00:00, 17.90s/it]


1.0


100%|██████████| 10/10 [03:02<00:00, 18.22s/it]
 50%|█████     | 3/6 [15:28<16:08, 322.97s/it]

2.0


100%|██████████| 10/10 [03:40<00:00, 22.05s/it]


1.0


100%|██████████| 10/10 [03:43<00:00, 22.32s/it]
 67%|██████▋   | 4/6 [22:52<12:21, 370.65s/it]

2.0


100%|██████████| 10/10 [04:12<00:00, 25.23s/it]


1.0


100%|██████████| 10/10 [04:12<00:00, 25.23s/it]
 83%|████████▎ | 5/6 [31:17<06:58, 419.00s/it]

2.0


100%|██████████| 10/10 [03:50<00:00, 23.05s/it]


1.0


100%|██████████| 10/10 [03:35<00:00, 21.59s/it]
100%|██████████| 6/6 [38:43<00:00, 387.31s/it]


2.0


100%|██████████| 10/10 [01:13<00:00,  7.38s/it]


0.057578125


100%|██████████| 10/10 [01:12<00:00,  7.22s/it]
 17%|█▋        | 1/6 [02:25<12:09, 145.97s/it]

0.11748046875000001


100%|██████████| 10/10 [01:52<00:00, 11.25s/it]


0.30662109375


100%|██████████| 10/10 [01:52<00:00, 11.25s/it]
 33%|███▎      | 2/6 [06:10<12:49, 192.47s/it]

0.525712890625


100%|██████████| 10/10 [02:36<00:00, 15.65s/it]


0.43767578125


100%|██████████| 10/10 [02:34<00:00, 15.41s/it]
 50%|█████     | 3/6 [11:21<12:19, 246.44s/it]

0.831826171875


100%|██████████| 10/10 [03:13<00:00, 19.35s/it]


0.659892578125


100%|██████████| 10/10 [03:15<00:00, 19.59s/it]
 67%|██████▋   | 4/6 [17:51<10:05, 302.91s/it]

1.20080078125


100%|██████████| 10/10 [03:54<00:00, 23.48s/it]


0.69677734375


100%|██████████| 10/10 [03:55<00:00, 23.55s/it]
 83%|████████▎ | 5/6 [25:41<06:03, 363.29s/it]

1.515615234375


100%|██████████| 10/10 [04:36<00:00, 27.67s/it]


0.797119140625


100%|██████████| 10/10 [04:36<00:00, 27.60s/it]
100%|██████████| 6/6 [34:54<00:00, 349.05s/it]


1.5873339843749998


100%|██████████| 10/10 [01:22<00:00,  8.24s/it]


0.089501953125


100%|██████████| 10/10 [01:23<00:00,  8.31s/it]
 17%|█▋        | 1/6 [02:45<13:47, 165.54s/it]

0.19283203125


100%|██████████| 10/10 [02:20<00:00, 14.04s/it]


0.42228515625


100%|██████████| 10/10 [02:20<00:00, 14.05s/it]
 33%|███▎      | 2/6 [07:26<15:33, 233.43s/it]

0.89716796875


100%|██████████| 10/10 [03:14<00:00, 19.43s/it]


0.8915625


100%|██████████| 10/10 [03:15<00:00, 19.51s/it]
 50%|█████     | 3/6 [13:55<15:13, 304.67s/it]

1.584775390625


100%|██████████| 10/10 [04:05<00:00, 24.57s/it]


0.998251953125


100%|██████████| 10/10 [04:06<00:00, 24.63s/it]
 67%|██████▋   | 4/6 [22:08<12:37, 378.66s/it]

1.95501953125


100%|██████████| 10/10 [05:00<00:00, 30.06s/it]


0.99845703125


100%|██████████| 10/10 [05:03<00:00, 30.33s/it]
 83%|████████▎ | 5/6 [32:11<07:39, 459.88s/it]

1.99783203125


100%|██████████| 10/10 [05:53<00:00, 35.39s/it]


1.0


100%|██████████| 10/10 [05:54<00:00, 35.41s/it]
100%|██████████| 6/6 [43:59<00:00, 439.99s/it]

2.0



