In [2]:
import torch as t
t.set_num_threads(8)
import pandas as pd
from train import train
from models import Transformer, AoT
from utils import generate_data, power_unif_law
from tqdm import tqdm

In [None]:
""" Experiment 1. Scaling laws on H with fixed d=d_head. """
t.manual_seed(2222)

# Model parameters.
N = 50
nb_layers = 5 #Depth of the network
nb_head = 1
n_gram = 3
context_window = n_gram

# Distribution parameters.
alphas = [1, 1, 1]
nb_tokens=[100, 100, 1]
pi = power_unif_law(alphas, nb_tokens, N)

# Training parameters.
batch_size=2**10
num_batch=1000
lr=5e-4
epochs=10
repetition = 2
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)

# Scaling parameters
d = 10

d_head=d 

mean_accuracy = []
para_list = []
N_list = []
d_list = []
d_head_list = []

for para in tqdm([1, 6, 11, 16, 21]):
    accuracy = 0

    for _ in range(repetition):
        model = AoT(d, N, nb_layers, para, d_head, nb_head, context_window, pi)

        dict = train(model, Data, epochs, lr=lr, next_token=True)
        acc = sum(dict['Acc'][-101:-1])/100
            
        accuracy += acc

    mean_accuracy.append(accuracy/repetition)
    N_list.append(N)
    d_list.append(d)
    d_head_list.append(d_head)
    para_list.append(para)

results = {
    'acc': mean_accuracy,
    'para': para_list,
    'N': N_list,
    'd': d_list,
    'd_head': d_head_list,
}

# We save the results as a dataframe.
data = pd.DataFrame(results)
data.to_csv(f'Scaling laws/Data_exp_1_{7}_depth.csv', index=False)

100%|██████████| 10/10 [01:26<00:00,  8.61s/it]
100%|██████████| 10/10 [01:35<00:00,  9.58s/it]
100%|██████████| 10/10 [06:36<00:00, 39.66s/it]
100%|██████████| 10/10 [06:42<00:00, 40.29s/it]
100%|██████████| 10/10 [10:31<00:00, 63.10s/it]
100%|██████████| 10/10 [10:30<00:00, 63.03s/it]
100%|██████████| 10/10 [15:16<00:00, 91.65s/it]
100%|██████████| 10/10 [14:59<00:00, 89.93s/it]
100%|██████████| 10/10 [16:31<00:00, 99.13s/it]t]
100%|██████████| 10/10 [18:12<00:00, 109.24s/it]
100%|██████████| 5/5 [1:42:22<00:00, 1228.50s/it]


: 

In [None]:
""" Experiment 2. Scaling laws on d_head, with d!=d_head and H (=para) fixed. """
t.manual_seed(2222)

# Model parameters.
N = 50
d = 10
para = 21
nb_layers = 5 # Depth of the network
nb_head = 1
n_gram = 3
context_window = n_gram

# Distribution parameters.
alphas = [1, 1, 1]
nb_tokens=[100, 100, 1]
pi = power_unif_law(alphas, nb_tokens, N)

# Training parameters.
batch_size=2**10
num_batch=1000
lr=5e-4
epochs=10
repetition = 2
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)

# Scaling parameters
mean_accuracy = []
para_list = []
N_list = []
d_list = []
d_head_list = []
for d_head in tqdm([1, 3, 5, 7, 10]):
    accuracy = 0

    for _ in range(repetition):
        model = AoT(d, N, nb_layers, para, d_head, nb_head, context_window, pi)

        dict = train(model, Data, epochs, lr=lr, next_token=True)
        acc = sum(dict['Acc'][-101:-1])/100
        
        accuracy += acc

    mean_accuracy.append(accuracy/repetition)
    N_list.append(N)
    d_list.append(d)
    d_head_list.append(d_head)
    para_list.append(para)

results = {
    'acc': mean_accuracy,
    'para': para_list,
    'N': N_list,
    'd': d_list,
    'd_head': d_head_list,
}

# We save the results as a dataframe.
data = pd.DataFrame(results)
data.to_csv(f'Scaling laws/Data_exp_2_depth.csv', index=False)

100%|██████████| 10/10 [09:07<00:00, 54.71s/it]
100%|██████████| 10/10 [09:13<00:00, 55.38s/it]
100%|██████████| 10/10 [14:15<00:00, 85.56s/it]t]
100%|██████████| 10/10 [13:44<00:00, 82.43s/it]
100%|██████████| 10/10 [14:14<00:00, 85.42s/it]t]
100%|██████████| 10/10 [14:48<00:00, 88.84s/it]
100%|██████████| 10/10 [17:27<00:00, 104.76s/it]]
100%|██████████| 10/10 [18:01<00:00, 108.17s/it]
100%|██████████| 10/10 [24:58<00:00, 149.89s/it]]
100%|██████████| 10/10 [26:20<00:00, 158.02s/it]
100%|██████████| 5/5 [2:42:11<00:00, 1946.38s/it]


: 

In [3]:
""" Experiment 5. Scaling laws on the width of Transformer using MLPs. """
t.manual_seed(3333)

# Model parameters.
N = 50
para = 1
nb_layers = 5 # Depth of the network
nb_head = 1
n_gram = 3
context_window = n_gram

# Distribution parameters.
alphas = [1, 1, 1]
nb_tokens=[100, 100, 1]
pi = power_unif_law(alphas, nb_tokens, N)

# Training parameters.
batch_size=2**11
num_batch=2000
lr=5e-4
epochs=10
repetition = 2
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)

for d, exp_num in zip([7, 10, 13], [4, 7, 10]):
    d_head = d
    min_width = 2*d*(1-1)
    max_width = 2*d*(21-1)
    step = 2*d*5

    mean_accuracy = []
    para_list = []
    N_list = []
    d_list = []
    d_head_list = []
    width_list = []
    for width in tqdm(range(min_width, max_width+1, step)):
        accuracy = 0

        for _ in range(repetition):
            model = Transformer(d, N, nb_layers, width, para, d_head, nb_head, context_window, pi)

            dict = train(model, Data, epochs, lr=lr, next_token=True)
            acc = sum(dict['Acc'][-101:-1])/100
            
            accuracy += acc
            print(accuracy)

        mean_accuracy.append(accuracy/repetition)
        N_list.append(N)
        d_list.append(d)
        d_head_list.append(d_head)
        para_list.append(para)
        width_list.append(width)

    results = {
        'acc': mean_accuracy,
        'para': para_list,
        'N': N_list,
        'd': d_list,
        'd_head': d_head_list,
        'width': width_list,
    }

    # We save the results as a dataframe.
    data = pd.DataFrame(results)
    data.to_csv(f'Scaling laws/Data_exp_5_{exp_num}_depth.csv', index=False)


for d, exp_num in zip([7, 13], [4, 10]):
    d_head = d
    min_para = 1
    max_para = 21
    step = 5

    mean_accuracy = []
    para_list = []
    N_list = []
    d_list = []
    d_head_list = []
    width_list = []
    for para in tqdm(range(min_para, max_para+1, step)):
        accuracy = 0

        for _ in range(repetition):
            model = AoT(d, N, nb_layers, para, d_head, nb_head, context_window, pi)

            dict = train(model, Data, epochs, lr=lr, next_token=True)
            acc = sum(dict['Acc'][-101:-1])/100
            
            accuracy += acc
            print(accuracy)

        mean_accuracy.append(accuracy/repetition)
        N_list.append(N)
        d_list.append(d)
        d_head_list.append(d_head)
        para_list.append(para)
        width_list.append(width)

    results = {
        'acc': mean_accuracy,
        'para': para_list,
        'N': N_list,
        'd': d_list,
        'd_head': d_head_list,
        'width': width_list,
    }

    # We save the results as a dataframe.
    data = pd.DataFrame(results)
    data.to_csv(f'Scaling laws/Data_exp_1_{exp_num}_depth.csv', index=False)

100%|██████████| 10/10 [05:12<00:00, 31.24s/it]


0.0713134765625


100%|██████████| 10/10 [05:46<00:00, 34.66s/it]
 20%|██        | 1/5 [10:59<43:57, 659.41s/it]

0.16455078125


100%|██████████| 10/10 [10:12<00:00, 61.29s/it]


0.182373046875


100%|██████████| 10/10 [13:06<00:00, 78.63s/it]
 40%|████      | 2/5 [34:18<54:43, 1094.57s/it]

0.58119140625


100%|██████████| 10/10 [20:24<00:00, 122.44s/it]


0.435


100%|██████████| 10/10 [18:40<00:00, 112.08s/it]
 60%|██████    | 3/5 [1:13:23<55:31, 1665.65s/it]

0.926962890625


100%|██████████| 10/10 [20:48<00:00, 124.83s/it]


0.3352978515625


100%|██████████| 10/10 [17:52<00:00, 107.23s/it]
 80%|████████  | 4/5 [1:52:04<32:04, 1924.23s/it]

0.831572265625


100%|██████████| 10/10 [20:48<00:00, 124.84s/it]


0.6718212890625


100%|██████████| 10/10 [16:20<00:00, 98.08s/it]
100%|██████████| 5/5 [2:29:13<00:00, 1790.74s/it]


1.171162109375


100%|██████████| 10/10 [04:18<00:00, 25.88s/it]


0.140830078125


100%|██████████| 10/10 [04:05<00:00, 24.55s/it]
 20%|██        | 1/5 [08:24<33:37, 504.27s/it]

0.262548828125


100%|██████████| 10/10 [09:02<00:00, 54.21s/it]


0.5894189453125


100%|██████████| 10/10 [08:17<00:00, 49.80s/it]
 40%|████      | 2/5 [25:44<40:58, 819.47s/it]

1.160009765625


100%|██████████| 10/10 [11:33<00:00, 69.31s/it]


0.572939453125


100%|██████████| 10/10 [11:41<00:00, 70.19s/it]
 60%|██████    | 3/5 [48:59<36:04, 1082.29s/it]

1.3535839843750002


100%|██████████| 10/10 [15:16<00:00, 91.66s/it]


0.9320654296875


100%|██████████| 10/10 [15:22<00:00, 92.28s/it]
 80%|████████  | 4/5 [1:19:38<23:01, 1381.21s/it]

1.556513671875


100%|██████████| 10/10 [18:29<00:00, 110.98s/it]


0.40263671875


100%|██████████| 10/10 [18:28<00:00, 110.88s/it]
100%|██████████| 5/5 [1:56:37<00:00, 1399.49s/it]


1.30408203125


100%|██████████| 10/10 [04:03<00:00, 24.36s/it]


0.1987353515625


100%|██████████| 10/10 [04:05<00:00, 24.59s/it]
 20%|██        | 1/5 [08:09<32:37, 489.45s/it]

0.3898291015625


100%|██████████| 10/10 [10:02<00:00, 60.23s/it]


0.98287109375


100%|██████████| 10/10 [09:58<00:00, 59.89s/it]
 40%|████      | 2/5 [28:10<45:24, 908.14s/it]

1.9610058593750002


100%|██████████| 10/10 [14:41<00:00, 88.14s/it]


0.9741259765625


100%|██████████| 10/10 [14:15<00:00, 85.59s/it]
 60%|██████    | 3/5 [57:07<42:53, 1286.74s/it]

1.9724560546875


100%|██████████| 10/10 [18:41<00:00, 112.19s/it]


0.9471435546875


100%|██████████| 10/10 [18:42<00:00, 112.22s/it]
 80%|████████  | 4/5 [1:34:32<27:44, 1664.73s/it]

1.9471435546875


100%|██████████| 10/10 [21:13<00:00, 127.33s/it]


1.0


100%|██████████| 10/10 [21:17<00:00, 127.73s/it]
100%|██████████| 5/5 [2:17:02<00:00, 1644.55s/it]


2.0


100%|██████████| 10/10 [03:29<00:00, 20.98s/it]


0.0962841796875


100%|██████████| 10/10 [03:28<00:00, 20.85s/it]
 20%|██        | 1/5 [06:58<27:53, 418.36s/it]

0.199248046875


100%|██████████| 10/10 [13:16<00:00, 79.65s/it]


0.2126171875


100%|██████████| 10/10 [13:04<00:00, 78.40s/it]
 40%|████      | 2/5 [33:18<55:05, 1101.99s/it]

0.45463378906249996


100%|██████████| 10/10 [21:49<00:00, 130.94s/it]


0.2982763671875


100%|██████████| 10/10 [22:03<00:00, 132.33s/it]
 60%|██████    | 3/5 [1:17:11<1:00:01, 1800.98s/it]

0.6292041015625001


100%|██████████| 10/10 [31:33<00:00, 189.40s/it]


0.468037109375


100%|██████████| 10/10 [31:40<00:00, 190.03s/it]
 80%|████████  | 4/5 [2:20:26<43:07, 2587.95s/it]  

0.89357421875


100%|██████████| 10/10 [40:48<00:00, 244.82s/it]


0.370390625


100%|██████████| 10/10 [41:10<00:00, 247.07s/it]
100%|██████████| 5/5 [3:42:25<00:00, 2669.00s/it]


0.9461962890625


100%|██████████| 10/10 [04:01<00:00, 24.20s/it]


0.1841357421875


100%|██████████| 10/10 [04:02<00:00, 24.21s/it]
 20%|██        | 1/5 [08:04<32:16, 484.11s/it]

0.3803662109375


100%|██████████| 10/10 [16:07<00:00, 96.75s/it]


0.593037109375


100%|██████████| 10/10 [16:12<00:00, 97.28s/it]
 40%|████      | 2/5 [40:24<1:07:02, 1340.69s/it]

1.284013671875


100%|██████████| 10/10 [27:27<00:00, 164.79s/it]


0.9460400390625


100%|██████████| 10/10 [27:25<00:00, 164.51s/it]
 60%|██████    | 3/5 [1:35:17<1:14:24, 2232.16s/it]

1.9065380859375


 70%|███████   | 7/10 [29:56<12:50, 256.68s/it]
 60%|██████    | 3/5 [2:05:14<1:23:29, 2504.72s/it]


KeyboardInterrupt: 