In [None]:
import numpy as np
from matplotlib import pyplot as plt
from time import time

In these experiments, we use circular regression to solve the following problem:
given pairs (a, b = a*s mod p) for a fixed secret s, recover s. (In the experiments below, b additionally carries small additive noise.)

First, we visualize the loss and gradient. 

In [None]:
def likelihood(pred, a, b, p):
    """Return the circular-regression loss and its derivative w.r.t. ``pred``.

    Every sample contributes -cos(2*pi/p * (b - a*pred)); that term reaches
    its minimum of -1 exactly when a*pred agrees with b modulo p.

    Returns:
        (loss, grad): the summed loss and d(loss)/d(pred).
    """
    omega = 2 * np.pi / p
    phase = omega * (b - a * pred)
    loss = -np.sum(np.cos(phase))
    # Chain rule: d/dpred[-cos(phase)] = sin(phase) * (-omega * a).
    grad = -omega * np.sum(a * np.sin(phase))
    return loss, grad

def plot_examples(p, s):
    """Plot the circular-regression loss, its gradient, and 1 / gradient.

    Samples are the noise-free pairs (a, a*s mod p) for every a in [0, p).
    Loss and gradient are drawn on a fine grid (spacing 0.03) covering one
    period that starts at s - p//2; the gradient and its reciprocal are also
    evaluated at every integer in [s - p//2, s + p//2] except s itself.
    """
    a = np.arange(p)
    b = (a * s) % p

    # Fine grid over one period of length p.
    step = 0.03
    start = s - p // 2
    grid = [start + step * i for i in range(int(p / step))]
    fine = np.array([likelihood(pt, a, b, p) for pt in grid])
    losses = fine[:, 0]
    grads = fine[:, 1]

    # Integer points on both sides of s, skipping s itself (its gradient is
    # ~0 there, so 1 / gradient would blow up).
    ints = list(range(start, s)) + list(range(s + 1, s + p // 2 + 1))
    grads_int = np.array([likelihood(pt, a, b, p) for pt in ints])[:, 1]

    _, axs = plt.subplots(1, 3, figsize=(20, 5))
    loss_ax, grad_ax, recip_ax = axs

    loss_ax.plot(grid, losses)
    loss_ax.set_ylabel('loss')
    loss_ax.set_title(f'Circular Regression Loss, p={p}, s={s}')

    grad_ax.plot(grid, grads)
    grad_ax.scatter(ints, grads_int, s=7, c='r')
    grad_ax.set_ylabel('gradient')
    grad_ax.set_title(f'Circular Regression Gradient, p={p}, s={s}')

    recip_ax.scatter(ints, 1 / grads_int, s=5)
    recip_ax.set_ylabel('1 / gradient')
    recip_ax.set_title(f'grad_r = 1 / gradient, p={p}, s={s}')

    # All three panels share the same x-axis quantity.
    for ax in axs:
        ax.set_xlabel('prediction')

In [None]:
# Visualize one small example: modulus p = 41 with secret s = 3.
# Other settings tried: several random secrets (np.random.choice(p, 3)) and
# p in [23, 41, 71, 113, 251, 367, 967, 1471].
p = 41
s = 3
plot_examples(p, s)

From the plots above, we see that the loss is lowest at the correct answer and much closer to 0 everywhere else. The loss is periodic with period p, so for simplicity we only show one interval of length p. Although there is a local minimum in every interval of length p, each of these local minima is a global minimum (they are the points s + kp). At integer points, the sign of the gradient always indicates the direction of the closest correct answer.

However, the gradient's magnitude gives the opposite of the information we need for choosing a step size: it is extremely large close to the answer, exactly where we want small steps. So, instead of multiplying the gradient by a fixed learning rate, we step by (a multiple of) the reciprocal of the gradient, as implemented below.

In [None]:
def circ_reg(p, A, B, lr, bs):
    """Search for the secret s by circular regression with reciprocal-gradient steps.

    Intended for samples with B approximately A*s mod p. A mini-batch is drawn
    from (A, B). Starting from a random integer guess, the prediction is
    moved by ``lr * batch_size / grad``, where grad is the derivative of the
    circular loss from ``likelihood``. The reciprocal is used because the raw
    gradient is largest close to the answer.

    Args:
        p: the modulus.
        A, B: 1-D numpy arrays of equal length holding the samples.
        lr: step-size multiplier.
        bs: requested batch size (capped at ``len(A)``).

    Returns:
        (best_result, preds, t):
          * best_result: the first prediction accepted by ``verify``, or else
            the prediction with the lowest loss seen; reduced mod p and
            rounded to 5 decimals.
          * preds: all predictions visited, starting with the initial guess.
          * t: the number of update steps completed without acceptance.
    """
    batch_size = min(len(A), bs)
    # Batch indices are drawn with replacement (np.random.choice default).
    indices = np.random.choice(len(A), size=batch_size)
    a, b = A[indices], B[indices]
    pred = np.random.choice(p)  # random integer initial guess in [0, p)
    ll, grad = likelihood(pred, a, b, p)
    t, preds = 0, [pred]
    best_result, min_loss = pred, ll
    while t < p:
        if grad == 0:
            # Stationary point: the reciprocal step is undefined and would
            # turn pred into NaN for every remaining iteration, so stop.
            break
        # Step by the reciprocal of the gradient, scaled by the batch size.
        pred -= lr * batch_size / grad
        pred %= p
        ll, grad = likelihood(pred, a, b, p)
        preds.append(pred % p)
        if ll < min_loss:
            min_loss, best_result = ll, pred
        # Pass this call's modulus explicitly instead of relying on whatever
        # module-level `p` happens to exist.
        if verify(a, b, pred, p):
            best_result = pred
            break
        t += 1
    return np.round(best_result % p, 5), preds, t


def verify(a, b, pred, modulus=None, n_check=20, max_std=6):
    """Heuristically check whether round(pred) explains the first samples.

    The residuals (a * round(pred) - b) mod ``modulus`` of the first
    ``n_check`` samples are mapped to the centered range (values above
    modulus // 2 have ``modulus`` subtracted). The guess is accepted when
    their standard deviation is below ``max_std``. A wrong guess spreads the
    residuals across the whole range, so their spread is large.

    Args:
        a, b: sample arrays.
        pred: candidate secret (rounded to the nearest integer before use).
        modulus: the modulus. Defaults to the module-level ``p`` for
            backward compatibility with 3-argument calls.
        n_check: number of leading samples to test (default 20).
        max_std: acceptance threshold on the residual standard deviation
            (default 6). NOTE(review): presumably tuned for the N(0, 3) noise
            used in the experiments; confirm before reusing elsewhere.

    Returns:
        True if the residual spread is below ``max_std``, else False.
    """
    if modulus is None:
        modulus = p  # legacy behaviour: fall back to the module-level global
    err = (a[:n_check] * np.round(pred) - b[:n_check]) % modulus
    # Center residuals so a small negative error stays small, not ~modulus.
    err[err > modulus // 2] -= modulus
    return bool(np.std(err) < max_std)

# Success counts (out of 20 random secrets) for every combination of batch
# size, learning rate and modulus p. One line is printed per batch size,
# listing the counts in (lr, p) loop order. `success` is also appended to by
# a later cell, so the module-level names are kept.
np.random.seed(0)
for batch_size in [64, 128, 256, 512]:
    success = []
    for lr in [0.5, 1, 2]:
        for p in [251, 1471, 11197]:  # optionally also 130769
            size = p - 1
            a = np.random.choice(np.arange(1, p), size=size, replace=False)
            steps = []
            for s in np.random.choice(np.arange(1, p), 20, replace=False):
                noise = np.random.normal(0, 3, size=size).astype(int)
                b = (a * s) % p + noise
                prediction, preds, t = circ_reg(p, a, b, lr, batch_size)
                # Count a success when the prediction lands within 0.5 of s.
                if np.abs(prediction - s) < 0.5:
                    steps.append(t)
            success.append(len(steps))
            if lr == 2 and batch_size == 256:
                print(p, len(steps), sorted(steps))
    print('/20 & '.join(map(str, success)))

In [None]:
# Scale-up run: fixed batch size 256 and lr 2 on larger moduli p, with 20
# random secrets per modulus. Prints p, lr, batch size, the number of
# successes, and the sorted step counts of the successful runs.
np.random.seed(0)
batch_size = 256
lr = 2
# `success` is normally created by the previous experiment cell. Create it if
# this cell runs on its own so that the append below cannot raise NameError.
if 'success' not in globals():
    success = []
for p in [20663, 42899, 115301, 222553]:
    size = p - 1
    # np.arange draws the same random stream as the equivalent Python list,
    # without building a list of up to ~2e5 ints on every call.
    a = np.random.choice(np.arange(1, p), size=size, replace=False)
    steps = []
    for s in np.random.choice(np.arange(1, p), 20, replace=False):
        b = (a * s) % p + np.random.normal(0, 3, size=size).astype(int)
        prediction, preds, t = circ_reg(p, a, b, lr, batch_size)
        # Count a success when the prediction lands within 0.5 of s.
        if np.abs(prediction - s) < 0.5:
            steps.append(t)
    success.append(len(steps))
    print(p, lr, batch_size, len(steps), sorted(steps))