In [1]:
import numpy as np

In [485]:
def loss(w, data_, labels_):
    """Quartic loss: sum over samples j of (labels_[j]**2 - <w, x_j>**2)**2.

    In this notebook ``w`` has shape (p,), ``data_`` has shape (p, n) with one
    sample per column, and ``labels_`` has shape (n,).  The loss is zero exactly
    when |<w, x_j>| == labels_[j] for every j, so w and -w give the same value.
    """
    projections = w.dot(data_)
    residuals = labels_ ** 2 - projections ** 2
    return np.sum(residuals ** 2)
def loss_grad(w, data_, labels_):
    """Descent direction for ``loss``; equal to 1/4 of its exact gradient.

    With z = w @ data_, the exact gradient of sum_j (y_j**2 - z_j**2)**2 is
    -4 * data_ @ ((y**2 - z**2) * z).  The constant 4 is dropped here, so every
    learning rate used with this function acts like lr/4 on the true gradient.
    NOTE(review): multiply by 4 before comparing with finite differences of ``loss``.
    """
    projections = w.dot(data_)
    weighted_residuals = (labels_ ** 2 - projections ** 2) * projections
    return -data_.dot(weighted_residuals)

In [525]:
# Sanity check of loss / loss_grad on a synthetic problem: `data` is (p, n)
# with one sample per column, and labels are |<weights, x_j>| for each column.
p = 512
n = 4096
data = np.random.normal(scale=1/p, size=(p, n))
# Ground-truth weights that generate the labels.
weights = np.random.normal(scale=1, size=p)
labels = np.abs(weights.dot(data))
# Random starting point on a much larger scale (std sqrt(p)) than `weights`.
cur_weights = np.random.normal(scale=np.sqrt(p), size=p)
print(loss(cur_weights, data, labels))
# NOTE(review): prints the whole length-p gradient; a summary (e.g. its norm) would be easier to read.
print(loss_grad(cur_weights, data, labels))

10913.005863613136

array([ 3.60118433e-03,  3.16865122e-01, -1.02492253e+00, -1.34071807e+00,
       -6.42842908e-01,  1.32912782e+00, -1.26918248e+00,  9.52485271e-02,
       -9.00526828e-01, -7.52997585e-01, -1.22219649e+00,  1.56131139e+00,
        1.36211413e+00,  1.84050231e-01,  4.70597572e-01,  4.59676368e-01,
       -1.55749304e+00, -1.01619847e+00, -2.28919984e-01, -1.34661464e+00,
       -1.67352355e+00, -5.35010695e-01,  4.85459234e-01,  2.82429997e-01,
       -3.33798895e-01, -1.24132324e-01, -1.55852218e-02, -1.31687220e+00,
       -2.49295539e+00, -1.98912894e-01, -1.65849231e+00, -6.40891047e-01,
        8.15448718e-01, -9.65882052e-02,  6.49531989e-01,  4.49280404e-01,
        2.04548529e-01, -9.70369559e-01, -8.16288046e-01, -4.33495286e-01,
        7.15191227e-01,  7.07487743e-01,  1.01572171e+00, -5.51088650e-01,
        7.11825008e-01,  9.26231242e-01,  4.15390628e-01,  1.04767671e+00,
       -1.44995308e+00, -9.09607366e-02,  4.27177910e-01, -2.01190645e+00,
       -4.90361780e-01,  

In [426]:
def e_i(i, val, size):
    """Return a float vector of length ``size``: zero everywhere except ``val`` at index ``i``.

    Used to build the coordinate perturbation ``val * e_i`` for finite differences.
    """
    basis_vec = np.zeros(size)
    basis_vec[i] = val
    return basis_vec

In [433]:
# Central finite-difference estimate of d(loss)/d(w_4) at cur_weights.
# NOTE(review): loss_grad drops the constant factor 4 of the exact derivative
# (d/dw sum((y^2 - z^2)^2) = -4 * X @ ((y^2 - z^2) * z)), so this value should
# match 4 * loss_grad(cur_weights, data, labels)[4], not loss_grad(...)[4].
eps = 1e-6
step = e_i(4, eps, p)
(loss(cur_weights + step, data, labels) - loss(cur_weights - step, data, labels)) / (2 * eps)

-0.46855997104522373

In [497]:
# Problem dimensions. `p` is also read as a global by later cells (e.g. to size random weights).
p = 512
n = 4096

In [397]:
# Scratch: gradient at a fresh random point, displayed next to the ground-truth weights.
loss_grad(np.random.normal(size=p), data, labels), weights

(array([ 0.00176857, -0.01133505, -0.00070198, -0.00749542, -0.00608518,
         0.02123821,  0.01477514, -0.00155244, -0.01660302,  0.00544844,
         0.00962066,  0.01645802, -0.00385004,  0.00903544,  0.00226575,
        -0.00707754,  0.00458118, -0.00153326, -0.00496154,  0.0002014 ,
         0.01579167,  0.00715456,  0.00313445,  0.00520388, -0.00527272,
         0.00315297, -0.00430841, -0.00576472,  0.02328908, -0.00132865,
        -0.0118212 ,  0.00917392]),
 array([ 0.65060471, -1.40874358,  2.22791365,  0.3520362 ,  0.87856342,
         0.7096469 ,  0.10349277,  1.66323156, -1.06512056, -0.38180196,
         0.39062964, -0.00824127, -0.04520208, -0.19572806,  0.2482545 ,
        -0.67351836, -1.09678748,  0.34363184,  1.19522063, -0.63432228,
         0.44267345, -1.55656164,  0.57683619, -1.4250233 , -0.33449485,
        -0.20941073,  0.18635833, -0.21757797, -0.78460516, -0.70494459,
        -0.52336127, -0.96374793]))

In [349]:
# Regenerate the synthetic problem (columns NOT normalized here, unlike the later setup cells).
data = np.random.normal(scale=1 / p, size=(p, n))
weights = np.random.normal(size=p)
labels = np.abs(weights.dot(data))

In [350]:
# Sanity check of the (p, n) layout: one sample per column.
data.shape

(512, 4096)

In [362]:
# NOTE(review): these globals are not passed to check_success_sgd, which uses its own keyword defaults.
n_epoch = 2500
batch_size = 1

In [359]:
# NOTE(review): not passed to check_success_sgd; its calls below pass lr explicitly or use the defaults.
lr = 20.0
momentum_gamma = 0.9

In [380]:
# Scratch: check how np.array_split behaves when asked for a single section.
a = np.arange(10)

In [383]:
# One section returns the whole array as a single chunk (the full-batch case in check_success_sgd).
np.array_split(a, 1)

[array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]

In [571]:
def check_success_sgd(data, labels, _type='sgd', n_epoch=2500, batch_size=64, lr=20.0, momentum_gamma=0.9,
                      tol=1e-6, check_every=10):
    """Train from a random start with mini-batch SGD (optionally with momentum) and report success.

    Parameters
    ----------
    data : ndarray, shape (p, n)
        Samples stored as columns; batches are taken as ``data[:, batch_indices]``.
    labels : ndarray, shape (n,)
        Targets passed to ``loss`` / ``loss_grad``.
    _type : {'sgd', 'momentum', 'nesterov'}
        Update rule. 'momentum' is heavy-ball: v <- gamma * v + lr * g(w).
        'nesterov' evaluates the gradient at the look-ahead point w - gamma * v.
    n_epoch : int
        Maximum number of passes over the data.
    batch_size : int
        Approximate mini-batch size. Each epoch's shuffled indices are split into
        ``max(1, n // batch_size)`` near-equal chunks, so ``batch_size >= n``
        means full-batch gradient descent.
    lr : float
        Step size applied to ``loss_grad``.
    momentum_gamma : float
        Momentum decay factor (unused for 'sgd').
    tol : float
        Success threshold on the full-data loss.
    check_every : int
        Evaluate and record the full-data loss every this many epochs.

    Returns
    -------
    (success, losses) : (int, list)
        ``success`` is 1 if the full-data loss fell below ``tol``. The loss is
        checked every ``check_every`` epochs and once after the last epoch.
        Otherwise ``success`` is 0. ``losses`` holds every full-data loss that
        was evaluated.

    Raises
    ------
    ValueError
        If ``_type`` is not a supported update rule.
    """
    if _type not in ('sgd', 'momentum', 'nesterov'):
        raise ValueError('Bad type')
    n_samples = data.shape[1]
    # np.array_split rejects 0 sections, which used to crash when batch_size > n.
    n_batches = max(1, n_samples // batch_size)
    losses = []
    # Size the iterate from the data itself, not from the notebook-global `p`.
    cur_weights = np.random.normal(size=data.shape[0])
    momentum = np.zeros_like(cur_weights)
    for epoch in range(n_epoch):
        # permutation(n) == arange(n) + shuffle with the legacy global RNG.
        for batch_indices in np.array_split(np.random.permutation(n_samples), n_batches):
            batch_data = data[:, batch_indices]
            batch_labels = labels[batch_indices]
            if _type == 'sgd':
                momentum = lr * loss_grad(cur_weights, batch_data, batch_labels)
            elif _type == 'momentum':
                momentum *= momentum_gamma
                momentum += lr * loss_grad(cur_weights, batch_data, batch_labels)
            else:  # 'nesterov': momentum already holds gamma * v, so this is the look-ahead point
                momentum *= momentum_gamma
                momentum += lr * loss_grad(cur_weights - momentum, batch_data, batch_labels)
            cur_weights -= momentum
        if epoch % check_every == 0:
            cur_loss = loss(cur_weights, data, labels)
            losses.append(cur_loss)
            if cur_loss < tol:
                return 1, losses
    final_loss = loss(cur_weights, data, labels)
    losses.append(final_loss)
    return int(final_loss < tol), losses

In [550]:
# Normalized synthetic problem: every column of `data` is rescaled to unit L2 norm
# (so the initial scale=1/p has no effect), and `weights` is rescaled to norm sqrt(p).
p = 512
n = 4096
data = np.random.normal(scale=1/p, size=(p, n))
data /= np.linalg.norm(data, axis=0)
weights = np.random.normal(scale=1, size=p)
weights /= np.linalg.norm(weights) / np.sqrt(p)
labels = np.abs(weights.dot(data))
# Starting point for the printed sanity check only; the check_* functions draw their own.
cur_weights = np.random.normal(scale=np.sqrt(p), size=p)
print(loss(cur_weights, data, labels))
# NOTE(review): prints the whole length-p gradient; consider printing its norm instead.
print(loss_grad(cur_weights, data, labels))

3193781967.9915333
[-1.54952864e+05  3.29253176e+05 -4.75765321e+05 -2.08650495e+05
 -3.01817316e+05 -4.98240598e+05  6.83609430e+04  1.62825533e+05
 -4.86477069e+04 -1.89431150e+05 -9.92106868e+04 -5.13738345e+05
 -5.47222993e+04  6.25156359e+04 -2.86302584e+05  6.82440014e+03
  5.58036198e+04 -1.62422333e+05  2.09369308e+05 -2.33261098e+05
 -1.40979714e+05  2.57345254e+04 -5.23030605e+04 -3.94382404e+05
  2.37387171e+04  4.24097145e+05 -5.55578739e+02  1.85526898e+04
 -1.67205227e+05 -2.13443348e+05 -7.57158062e+04  3.52912764e+05
 -9.61477885e+03 -2.00331280e+05 -1.35998694e+05 -1.55922917e+05
 -7.93442153e+05 -5.67349408e+03 -9.11156679e+05 -7.38162934e+04
  3.09412327e+05 -2.18289555e+05 -6.47734884e+04  1.68520589e+05
  2.21805275e+05 -2.46765896e+05 -3.58058697e+04 -3.35771437e+05
 -2.72757834e+05 -1.53404501e+05 -1.00633049e+05  1.31192758e+05
  6.07776582e+04 -4.51686258e+05  3.82234799e+05 -5.20892803e+05
 -3.70365721e+05 -1.43129934e+05 -5.61707138e+05 -1.47518292e+05
 -2.27

In [564]:
# Plain SGD, batch size 1, on the normalized problem (this run was interrupted manually).
check_success_sgd(data, labels, batch_size=1, lr=0.006)

13106.231114056061
9409.677882587404
8074.587050813709
6747.836259952226
5456.6416908288575
4068.274858529893
2480.2693325159285
1055.8743808850963
385.00719029019217
149.5450378823127
63.4609294573813
28.957908425677708
13.96585061594898
7.026917153707767
3.654555854340486
1.9518100668988816
1.0653181743667126
0.5920302770171303
0.3340238494499081


KeyboardInterrupt: 

In [567]:
# Nesterov momentum (gamma defaults to 0.9), batch size 1, same learning rate as the plain-SGD run.
check_success_sgd(data, labels, _type='nesterov', batch_size=1, lr=0.006)

9270.504281523825
5914.485715663894
5481.42187859892
5191.500982241072
2050.5230322432703
0.4608069335048589
0.0012713659074349674
4.549010828764521e-06
1.8385174398541986e-08


(1,
 [9270.504281523825,
  5914.485715663894,
  5481.42187859892,
  5191.500982241072,
  2050.5230322432703,
  0.4608069335048589,
  0.0012713659074349674,
  4.549010828764521e-06,
  1.8385174398541986e-08])

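## Experiments
1. GD (full batch)
2. SGD, bs = 64
3. SGD w/ momentum, bs = 64
4. SGD w/ Nesterov, bs = 64
5. SGD, bs = 1
6. SGD w/ momentum, bs = 1
7. SGD w/ Nesterov, bs = 1

_Note: the sweep further below uses batch sizes 1, 32 and n (full batch) rather than 64._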

In [369]:
def check_success(data, labels, lr=5.0, n_iter=15000, check_every=600, tol=1e-6):
    """Run full-batch gradient descent from a random start and report convergence.

    Parameters
    ----------
    data : ndarray, shape (p, n)
        Samples stored as columns.
    labels : ndarray, shape (n,)
        Targets passed to ``loss`` / ``loss_grad``.
    lr : float
        Step size applied to ``loss_grad``.
    n_iter : int
        Maximum number of gradient steps.
    check_every : int
        Evaluate and record the full loss every this many steps. The run stops
        early once a recorded loss is below ``tol``.
    tol : float
        Success threshold on the loss.

    Returns
    -------
    (success, losses) : (int, list)
        ``success`` is 1 if the final loss is below ``tol``, else 0. ``losses``
        holds every evaluated loss, with the final loss appended last. After an
        early stop, the last two entries are therefore equal.
    """
    # Size the iterate from the data itself, not from the notebook-global `p`.
    cur_weights = np.random.normal(size=data.shape[0])
    losses = []
    for i in range(n_iter):
        if i % check_every == 0:
            cur_loss = loss(cur_weights, data, labels)
            losses.append(cur_loss)
            if cur_loss < tol:
                break
        cur_weights -= lr * loss_grad(cur_weights, data, labels)
    final_loss = loss(cur_weights, data, labels)
    losses.append(final_loss)
    return int(final_loss < tol), losses

In [370]:
from copy import copy

In [371]:
# Accumulator for results of the full-batch GD sweep (filled per p by the sweep cell).
total_result = {}

In [338]:
# total_result

{32: [2, 18, 20, 20, 20],
 64: [0, 14, 17, 20, 20],
 128: [0, 0, 15, 20, 20],
 256: [0, 0, 0, 5, 18]}

In [400]:
# total_result = {}

In [404]:
# Inspect the accumulated results.
total_result

{(32, 2, 'sgd', 1): (0,
  [[0.07014768924979527,
    0.008422011553136897,
    0.005371954445234556,
    0.004473954935031572,
    0.0039884589800114586,
    0.003793092611439311,
    0.0036868635074657915,
    0.003621502743820686,
    0.003581470183908154,
    0.0035574653522631883,
    0.003543303959027879,
    0.0035349856574618164,
    0.0035300956829544365,
    0.0035272064583485134,
    0.0035255229372989673,
    0.0035244919685541148,
    0.0035238997464632827,
    0.003523521844636995,
    0.003523293638799958,
    0.0035231753297279133,
    0.0035231116535179006,
    0.003523063661864218,
    0.003523028335689503,
    0.003523017873833359,
    0.003523004711325029,
    0.0035230067497274203],
   [0.05982385072055854,
    0.007604534647830516,
    0.005053423847940008,
    0.004098087511508898,
    0.002905094071383364,
    0.002567384719089552,
    0.002470469177852463,
    0.0024281213597538375,
    0.0024077069953517705,
    0.00239686156967362,
    0.002390527839067523,
  

In [572]:
# Accumulator for the optimizer / batch-size sweep:
# key (p, alpha, _type, batch_size) -> (n_successes, list of per-restart loss histories).
new_result = {}

In [574]:
# Configurations already computed (the sweep cell skips these on re-run).
new_result.keys()

dict_keys([(32, 2, 'sgd', 1), (32, 2, 'sgd', 32), (32, 2, 'sgd', 64), (32, 2, 'momentum', 1), (32, 2, 'momentum', 32), (32, 2, 'momentum', 64), (32, 2, 'nesterov', 1), (32, 2, 'nesterov', 32), (32, 2, 'nesterov', 64), (32, 4, 'sgd', 1), (32, 4, 'sgd', 32), (32, 4, 'sgd', 128), (32, 4, 'momentum', 1), (32, 4, 'momentum', 32), (32, 4, 'momentum', 128), (32, 4, 'nesterov', 1), (32, 4, 'nesterov', 32), (32, 4, 'nesterov', 128), (32, 6, 'sgd', 1), (32, 6, 'sgd', 32), (32, 6, 'sgd', 192), (32, 6, 'momentum', 1), (32, 6, 'momentum', 32), (32, 6, 'momentum', 192), (32, 6, 'nesterov', 1), (32, 6, 'nesterov', 32), (32, 6, 'nesterov', 192), (32, 8, 'sgd', 1), (32, 8, 'sgd', 32), (32, 8, 'sgd', 256), (32, 8, 'momentum', 1), (32, 8, 'momentum', 32), (32, 8, 'momentum', 256), (32, 8, 'nesterov', 1), (32, 8, 'nesterov', 32), (32, 8, 'nesterov', 256), (64, 2, 'sgd', 1), (64, 2, 'sgd', 32), (64, 2, 'sgd', 128), (64, 2, 'momentum', 1), (64, 2, 'momentum', 32), (64, 2, 'momentum', 128), (64, 2, 'nesterov

In [None]:
# Sweep: for each dimension p and sample ratio alpha = n / p, build a fresh normalized
# synthetic problem. Then count how many of 40 random restarts of each optimizer
# (sgd / momentum / nesterov) at each batch size (1, 32, full batch n) reach loss < 1e-6.
# Configurations already in `new_result` are skipped, so re-running the cell resumes
# after an interruption.
# NOTE(review): the data for a (p, alpha) pair is redrawn even when some of its
# configurations are cached, so resumed entries may come from a different dataset.
# NOTE(review): rebinds the globals p, n, data, weights and labels used by other cells.
for p in [32, 64, 128, 256, 512, 1024]:
#     if p in total_result:
#         continue
    result = []  # NOTE(review): unused in this cell (its only use is commented out)
    for alpha in [4, 6, 8, 10]:
        # alpha = 10 only for p >= 100 (i.e. p = 128 and up in this grid)
        if alpha == 10 and p < 100:
            continue
        n = p * alpha
        print(p, alpha)
        # Columns normalized to unit norm; weights rescaled to norm sqrt(p).
        data = np.random.normal(scale=1 / p, size=(p, n))
        data /= np.linalg.norm(data, axis=0)
        weights = np.random.normal(size=p) 
        weights /= np.linalg.norm(weights) / np.sqrt(p)
        labels = np.abs(weights.dot(data))
        for _type in ['sgd', 'momentum', 'nesterov']:
            for batch_size in [1, 32, n]:
                if (p, alpha, _type, batch_size) in new_result:
                    continue
#                 if alpha > 2 and total_result[(p, alpha - 2, _type, batch_size)][0] == 40:
#                     total_result[(p, alpha, _type, batch_size)] = (40, None)
#                     continue
                print(_type, batch_size)
                n_successes = 0
                losses = []
                # 40 independent restarts; check_success_sgd draws its own random start each time.
                for _ in range(40):
                    cur_success, cur_losses = check_success_sgd(data, 
                                                                labels, 
                                                                _type=_type, 
                                                                lr=0.006,
                                                                batch_size=batch_size)
                    losses.append(cur_losses)
                    n_successes += cur_success
                print(n_successes)
                new_result[(p, alpha, _type, batch_size)] = (n_successes, losses)
#     print(result)

32 4
32 6
32 8
64 4
64 6
64 8
128 4
128 6
128 8
128 10
256 4
256 6
sgd 1
29
sgd 32


In [236]:
# Full-batch GD sweep for large p: for each sample ratio alpha = n / p, count how many
# of 20 random restarts of check_success reach loss < 1e-6 on one synthetic dataset.
# NOTE(review): unlike the later sweeps, `data` columns are not normalized here.
for p in [512, 1024, 2048]:
    if p in total_result:
        continue
    result = []
    for alpha in [6, 8, 10, 12, 14]:
        n = p * alpha
        print(p, alpha)
        data = np.random.normal(scale=1 / p, size=(p, n))
        weights = np.random.normal(size=p)
        labels = np.abs(weights.dot(data))
        n_successes = 0
        for _ in range(20):
            # check_success needs (data, labels) and returns (success_flag, losses);
            # the old `check_success()` call raised TypeError on both counts.
            success, _run_losses = check_success(data, labels)
            n_successes += success
        result.append((alpha, n_successes))
    total_result[p] = copy(result)
    print(result)

512 6
512 8


KeyboardInterrupt: 

In [213]:
# NOTE(review): resets total_result and fills it with hand-entered values; provenance
# is not recorded here (TODO confirm). These are plain lists of success counts, whereas
# the GD sweep cell stores (alpha, n_successes) tuples, and the alpha grid is not stated.
total_result = {}
total_result[512] = [0, 5, 5, 10, 10]
total_result[1024] = [0, 0, 2, 8, 10]
total_result[2048] = [0, 0, 0, 6, 10]

In [217]:
# Inspect the combined results.
total_result

{512: [0, 5, 5, 10, 10],
 1024: [0, 0, 2, 8, 10],
 2048: [0, 0, 0, 6, 10],
 32: [4, 10, 0, 0, 0],
 64: [1, 8, 4, 0, 0],
 128: [0, 9, 10, 0, 9],
 256: [0, 7, 9, 10, 10]}

In [207]:
# Scratch: append a new inner list, then append into the most recent one.
a = []
a.append([])
a[-1].append(5)

In [208]:
# Scratch: display the nested list.
a

[[5]]

In [75]:
# Ground-truth targets |<weights, x_j>|.
labels

array([0.01677878, 0.08457702, 0.20521354, 0.03434182, 0.15281656,
       0.12892193, 0.18662645, 0.09923389, 0.0969255 , 0.05328051,
       0.17229276, 0.02789723, 0.12236121, 0.05719339, 0.10668909,
       0.04613221, 0.00195335, 0.0480027 , 0.08887515, 0.06619979,
       0.03381431, 0.07764184, 0.10488952, 0.11959433, 0.08841942,
       0.16383346, 0.02195313, 0.02018177, 0.07840474, 0.07863012,
       0.12649795, 0.07916297, 0.05778262, 0.03162335, 0.23578297,
       0.0665193 , 0.09954116, 0.24490331, 0.04492546, 0.05108796,
       0.17106192, 0.06945352, 0.1703306 , 0.05123332, 0.02600311,
       0.10965685, 0.06380515, 0.07008547, 0.00683129, 0.02807649,
       0.18006602, 0.17939731, 0.01581155, 0.09852935, 0.00350503,
       0.06684381, 0.01473957, 0.11387163, 0.1241995 , 0.06818919,
       0.03449555, 0.14791524, 0.0765093 , 0.01313494, 0.16157505,
       0.06633702, 0.00336668, 0.11781031, 0.03796668, 0.07041825,
       0.09500857, 0.04464702, 0.16741173, 0.05814827, 0.05122

In [74]:
# Predictions |<cur_weights, x_j>|; the loss is zero exactly when these equal `labels`.
np.abs(cur_weights.dot(data))

array([2.52717663e-02, 7.05611011e-02, 2.08836951e-01, 6.19046135e-03,
       1.55450406e-01, 1.23759492e-01, 2.00692708e-01, 6.47001652e-02,
       7.51685781e-02, 6.55485089e-02, 1.72712249e-01, 1.29659012e-03,
       1.09657710e-01, 6.19178369e-02, 9.13531332e-02, 4.22690921e-02,
       5.58749473e-03, 5.96762247e-02, 9.82875180e-02, 5.50250233e-02,
       4.43569279e-02, 8.51653053e-02, 8.44956828e-02, 1.19149182e-01,
       5.31903349e-02, 1.25871059e-01, 2.27109171e-02, 5.72449182e-03,
       7.27546819e-02, 5.74879792e-02, 1.08196092e-01, 7.92510477e-02,
       7.24727836e-02, 4.80022253e-02, 2.41905367e-01, 7.05530873e-02,
       1.01930565e-01, 2.50274240e-01, 1.63729746e-02, 6.97003706e-02,
       1.71881115e-01, 5.82978162e-02, 1.47809538e-01, 5.39727482e-02,
       6.09269978e-05, 1.09286854e-01, 5.16191091e-02, 7.75067461e-02,
       6.76709382e-03, 5.01492504e-02, 1.86291349e-01, 1.72776180e-01,
       9.29955663e-03, 1.09623933e-01, 1.35735493e-02, 6.78164708e-02,
      

In [73]:
# Ground-truth weights.
weights

array([-1.25707266, -0.92741478, -0.23652388,  0.71013505, -0.10184017,
        1.94400742,  1.10687753, -0.40480706, -1.08183323,  1.24653129,
        1.02223361,  0.44140892,  1.65734218, -0.82202994, -2.0220724 ,
       -1.80748581,  0.4671527 , -1.20390023,  0.11114245, -0.24902347,
       -0.5360634 ,  0.5744897 , -0.76417533,  0.71679965,  2.49108828,
       -0.28495541,  0.81928736, -0.15401804,  1.06486637, -0.54958599,
       -0.09039982, -0.01090677,  1.27712931, -0.11650647,  0.54913775,
       -0.6008138 , -1.90239667,  1.41390573,  0.12719018, -0.46388095,
       -0.38301101,  1.09553378,  0.45568261,  0.74475355,  0.71348645,
       -0.59382527,  2.04932034, -0.99184526,  0.2962404 , -0.45870736,
       -1.30403219,  1.18957194, -0.6306315 , -1.00965924, -0.85456369,
       -1.36560597,  0.50679704,  0.44527572, -1.00766439,  1.51620152,
        1.89610206,  0.22731783, -1.24310188,  0.94508382, -0.84850357,
        2.14118278, -0.80240669,  0.05536301, -0.27481162,  0.52

In [72]:
# Current iterate. The loss is invariant under w -> -w, so recovery of `weights` is only up to sign.
cur_weights

array([ 1.4153121 ,  0.79897771,  0.27085228, -0.81408592, -0.05455871,
       -1.85034081, -1.16900235,  0.50394768,  0.83369663, -1.02471107,
       -0.89930012, -0.45205982, -1.78263957,  0.805075  ,  1.91682158,
        1.54217518, -0.5405915 ,  1.32330883, -0.15465093,  0.268061  ,
        0.95128283, -0.42356193,  0.94923663, -0.60200423, -2.71449715,
        0.2744569 , -0.84567935,  0.24867107, -0.7722769 ,  0.48490552,
        0.05064379, -0.14503728, -1.10214968,  0.01060762, -0.36071113,
        0.50157718,  2.16256642, -1.60127709,  0.02864447,  0.5751773 ,
        0.33977996, -1.25774758, -0.14023223, -0.8215842 , -1.09091268,
        0.21285054, -1.91622302,  0.95192484, -0.36102435,  0.18017129,
        1.23327731, -1.16426137,  0.68273263,  1.28869946,  0.93547824,
        1.20422896, -0.51124704, -0.2992268 ,  1.17620095, -1.43219331,
       -1.59601805, -0.27369566,  1.39037345, -0.90924912,  0.51409184,
       -1.8994751 ,  0.51943223, -0.12665832,  0.339321  , -0.50