In [141]:
%matplotlib widget

In [165]:
import numpy
import scipy
import scipy.optimize

from matplotlib import pyplot as plot

In [166]:
# Problem sizes for the synthetic low-rank regression experiment.
d_in = d_out = 20   # input / output dimensionality of the linear map
n = 10              # number of rows (samples) drawn for X
rank = 2            # rank of the ground-truth map A_true

In [167]:
''' generate data '''
# Inputs are standard normal; the ground-truth map is the product of two
# random factors, so its rank is at most `rank`.
X = numpy.random.randn(n, d_in)
A_true = numpy.random.randn(d_in, rank).dot(numpy.random.randn(d_out, rank).T)
# Noise-free targets.
Y = X.dot(A_true)

In [168]:
''' noise '''
# Unit-variance Gaussian noise, one value per (sample, output) entry of Y.
epsilon = numpy.random.standard_normal(size=(n, d_out))
Y_noise = Y + epsilon

In [169]:
def loss_low(W, rank, X, Y):
    ''' Squared error of the low-rank linear model X (A B^T) against Y.

    W    : parameters holding A stacked on top of B. It is reshaped to
           (d_in + d_out, rank); the first d_in rows are A, the rest are B.
    rank : number of columns of A and B.
    X    : 2-D input array; d_in is taken from X.shape[1].
    Y    : 2-D target array; d_out is taken from Y.shape[1].

    Returns the squared residuals summed over axis 1 (outputs) and averaged
    over the remaining axis (rows of X).
    Raises ValueError (from the reshape) when W does not contain exactly
    (d_in + d_out) * rank entries.
    '''
    # Read the dimensions from the data instead of notebook-level globals, so
    # the split of W into A and B always matches the X / Y actually passed in.
    n_in, n_out = X.shape[1], Y.shape[1]
    W = numpy.reshape(W, (n_in + n_out, rank))
    A, B = W[:n_in, :], W[n_in:, :]
    residual = numpy.dot(X, numpy.dot(A, B.T)) - Y
    return (residual ** 2).sum(1).mean()

In [170]:
# Ranks of the two fitted factorisations (both larger than the true `rank`).
small_rank, large_rank = 4, 8

In [171]:
n_trials = 50

# One row per trial: [small, large, distil] loss measured on the clean Y.
errors = []


def _fit_low_rank(fit_rank, targets, maxiter=100, tol=1e-2):
    ''' Fit a rank-`fit_rank` model to `targets` with CG from a random start. '''
    # minimize() works on a flat parameter vector and newer SciPy releases
    # require x0 to be one-dimensional, so flatten the random
    # (d_in+d_out, fit_rank) initialisation explicitly (the old `[:]` did not).
    W0 = numpy.random.randn(d_in+d_out, fit_rank).ravel()
    return scipy.optimize.minimize(loss_low, W0, args=(fit_rank, X, targets), method='CG',
                                   options={'disp': False,
                                            'maxiter': maxiter},
                                   tol=tol
                                  )


for ti in range(n_trials):

    print('{}-th trial...'.format(ti+1),end='')

    ''' fit a small model '''
    small_res = _fit_low_rank(small_rank, Y_noise)
    small_error_clean = loss_low(small_res.x, small_rank, X, Y)

    ''' fit a large model '''
    large_res = _fit_low_rank(large_rank, Y_noise)
    large_error_clean = loss_low(large_res.x, large_rank, X, Y)

    ''' distillation '''
    # The small model is refit to the large model's predictions on X instead
    # of the noisy targets.
    W_large = large_res.x.reshape([-1, large_rank])
    A_large, B_large = W_large[:d_in,:], W_large[d_in:,:]
    Y_distil = numpy.dot(X, numpy.dot(A_large, B_large.T))
    distil_res = _fit_low_rank(small_rank, Y_distil)
    distil_error_clean = loss_low(distil_res.x, small_rank, X, Y)

    errors.append([small_error_clean, large_error_clean, distil_error_clean])

    print('Done')

1-th trial...Done
2-th trial...Done
3-th trial...Done
4-th trial...Done
5-th trial...Done
6-th trial...Done
7-th trial...Done
8-th trial...Done
9-th trial...Done
10-th trial...Done
11-th trial...Done
12-th trial...Done
13-th trial...Done
14-th trial...Done
15-th trial...Done
16-th trial...Done
17-th trial...Done
18-th trial...Done
19-th trial...Done
20-th trial...Done
21-th trial...Done
22-th trial...Done
23-th trial...Done
24-th trial...Done
25-th trial...Done
26-th trial...Done
27-th trial...Done
28-th trial...Done
29-th trial...Done
30-th trial...Done
31-th trial...Done
32-th trial...Done
33-th trial...Done
34-th trial...Done
35-th trial...Done
36-th trial...Done
37-th trial...Done
38-th trial...Done
39-th trial...Done
40-th trial...Done
41-th trial...Done
42-th trial...Done
43-th trial...Done
44-th trial...Done
45-th trial...Done
46-th trial...Done
47-th trial...Done
48-th trial...Done
49-th trial...Done
50-th trial...Done


In [175]:
# Average over trials (axis 0); columns are [small, large, distil].
numpy.mean(errors, axis=0)

array([13.40841242, 18.64377965, 13.41818219])

In [176]:
# Spread over trials (axis 0); columns are [small, large, distil].
numpy.std(errors, axis=0)

array([0.25825559, 0.55703543, 0.25060436])

In [178]:
# Per-trial scatter: directly fitted small model (x) vs. distilled small model (y).
fig, ax = plot.subplots()

errors = numpy.array(errors)
ax.plot(errors[:, 0], errors[:, 2], 'x')
# ax.plot(errors[:, 1], errors[:, 2], 'o')

ax.set_xlabel('small')
ax.set_ylabel('distil')
# Give the y axis the same limits as the x axis.
ax.set_ylim(ax.get_xlim())
ax.grid(True)
fig.tight_layout()
plot.show()

FigureCanvasNbAgg()