In [None]:
import numpy as np
import matplotlib.pyplot as pl
import torch
import torch.nn as nn
import torch.optim as optim

from functools import partial

from tuning import Q10RBFNet
import json

In [None]:
# When True, the excitation pattern E0 is also optimised through the
# distributed workers (see the "Distributed learning" section).
E0_distributed = True

### Target functions

In [None]:
def target_func(f, beta=0.7, Q_0=2., f0=1000.):
    """Target log10(Q) as a function of frequency (no level dependence).

    Implements the power law Q = Q_0 * (f / f0)**beta in log form:
        log10(Q) = log10(Q_0) + beta * (log10(f) - log10(f0))

    Args:
        f: frequencies, as a tensor or anything `torch.as_tensor` accepts
           (same units as `f0`, presumably Hz).
        beta: power-law exponent.
        Q_0: value of Q at the reference frequency f0.
        f0: reference frequency (default 1000, as before).

    Returns:
        Tensor of log10(Q) values with the same shape as `f`.
    """
    # torch.log10 requires a tensor; accept lists/arrays/scalars too.
    f = torch.as_tensor(f)
    return np.log10(Q_0) + beta * (torch.log10(f) - np.log10(f0))


### RBF neural network

In [None]:
# Build the RBF network from the parameters stored in RBF_params.json.
# Alternative (fresh network): net = Q10RBFNet(6, sig=0.3)
net = Q10RBFNet.create_from_jsonfile("RBF_params.json")

In [None]:

def plot_gauss(x, f, c, weight, sig):
    """Plot one weighted RBF bump, passed through 10**, against frequency.

    Args:
        x: points on the normalised axis where the Gaussian is evaluated.
        f: frequencies matching `x`, used as the horizontal coordinate.
        c: centre of the Gaussian on the normalised axis.
        weight: output-layer weight multiplying this Gaussian.
        sig: width parameter of the Gaussian.
    """
    # NOTE(review): the denominator is (2*sig)**2 = 4*sig**2 rather than the
    # usual 2*sig**2 -- confirm it matches the kernel used inside Q10RBFNet.
    squared_dist = (x - c) ** 2
    bump = weight * torch.exp(-squared_dist / (2 * sig) ** 2)
    pl.plot(f, 10 ** bump, '--')

def plot_Q10(label='', plot_target=False, plot_rbfs=False, target_fn=None, n_points=100):
    """Plot the network's Q10 curve, 10**net(f), against frequency on the current axes.

    The curve is sampled on `n_points` evenly spaced points of the normalised
    [0, 1] axis, which `net.real_coord` maps to frequencies.

    Args:
        label: legend label for the network curve.
        plot_target: if True, also plot 10**target_fn(f), labelled "target".
        plot_rbfs: if True, overlay each weighted RBF bump (see plot_gauss).
        target_fn: callable returning log10(Q) for a frequency tensor. Defaults
            to `target_func` with its default parameters; pass the callable used
            for training (e.g. `targetfunc`) so the comparison is meaningful.
        n_points: number of sample points (default 100).
    """
    if target_fn is None:
        target_fn = target_func
    # Plotting only: no autograd graph is needed, and tensors created here can
    # be converted with .numpy() / handed to matplotlib directly.
    with torch.no_grad():
        x = torch.linspace(0, 1, n_points)
        f = net.real_coord(x)
        f_np = f.numpy()
        out = net.forward(f)
        pl.plot(f_np, 10**out.numpy()[:, 0], label=label)
        if plot_target:
            target = target_fn(f)
            pl.plot(f_np, 10**target, label="target")
        if plot_rbfs:
            for i in range(net.n_centers):
                plot_gauss(x, f, net.centers[i], net.l2.weight[0, i], net.sig)
    pl.xlabel('f')
    pl.xlim([800, 10000])


plot_Q10(plot_rbfs=True)

In [None]:
# Inspect the current RBF centres.
net.centers

### Learning

In [None]:
lr = 2e-2
lr_centers = 0  # 0 keeps the RBF centres fixed during SGD
# Keep net.centers out of the default group: if it is a registered parameter
# of net, listing it in both groups makes torch.optim raise ValueError
# ("some parameters appear in more than one parameter group").
optimizer = optim.SGD([
    {'params': [p for p in net.parameters() if p is not net.centers]},
    {'params': [net.centers], 'lr': lr_centers}],  # centers
    lr=lr, momentum=0.9)

In [None]:
n_steps = 100
batch_size = 8
test_batch_size = 256
criterion = nn.MSELoss()  # default reduction='mean': already averaged over elements
verbose = True
step_test = 5  # every step_test steps, estimate the loss on a fresh test batch
losses = []

# Mode for selecting training frequencies:
#   'random': uniform in [f_min, f_max]
#   'fixed' : drawn with replacement from f_arr
mode = 'fixed'

f_min = 800.
f_max = 15000.

# Target used for training: log10(Q10) as a function of f.
targetfunc = partial(target_func, beta=0.4, Q_0=1.5)

f_arr = torch.tensor([1500., 2200., 3000., 4000., 5000., 6000., 8000.])
for i in range(n_steps):
    optimizer.zero_grad()
    if mode == 'random':
        f = f_min + (f_max - f_min) * torch.rand((batch_size, 1))
    else:
        ind = torch.randint(len(f_arr), (batch_size, 1))
        f = f_arr[ind]
    # NOTE(review): the target gets one extra trailing dim; this assumes
    # net.forward appends a trailing output dim -- confirm in Q10RBFNet.
    target = targetfunc(f).unsqueeze(-1)
    out = net.forward(f, verbose=(i % step_test == 0))
    loss = criterion(out, target)  # (input, target) order; MSE is symmetric
    loss.backward()
    optimizer.step()
    if verbose and i % step_test == 0:
        # Evaluate on a fresh batch spread over the network's whole frequency
        # range (net.real_coord maps [0, 1] to frequencies); no graph needed.
        with torch.no_grad():
            f_test = net.real_coord(torch.rand(test_batch_size, 1))
            out_test = net.forward(f_test)
            target_test = targetfunc(f_test).unsqueeze(-1)
            # MSELoss already averages, so no extra division by the batch
            # size (dividing again under-reported the loss by test_batch_size).
            test_loss = criterion(out_test, target_test).item()
        grad_norm = net.l2.weight.grad.norm().item()
        # Store a plain float: a tensor would keep its graph alive and cannot
        # be converted by matplotlib while it requires grad.
        losses.append(test_loss)
        print("step : {}, loss: {:.5f}, grad norm: {:.3f}".format(i, test_loss, grad_norm))

pl.figure()
pl.title("MSE loss")
# One loss value per evaluated step; slicing keeps x and y the same length
# (both empty when verbose is False).
pl.plot(range(0, n_steps, step_test)[:len(losses)], losses)
pl.xlabel('step')
pl.show()

### Distributed learning

In [None]:

import torch.distributed as dist

from datetime import timedelta

In [None]:
# Communication backend and world size: this process (rank 0) plus the
# worker processes (ranks 1 .. n_workers-1).
backend = dist.Backend('GLOO')
n_workers = 2

In [None]:
# Join the process group as rank 0 via TCP rendezvous on localhost; the
# worker processes must use the same init_method and world_size.
dist.init_process_group(
    backend=backend,
    init_method='tcp://127.0.0.1:1234',
    world_size=n_workers,
    rank=0,
)

In [None]:
# Inspect the output-layer weights that will be sent to the workers.
net.l2.weight

In [None]:
if E0_distributed:
    # Frequency grid of the distributed excitation pattern E0.
    # The handle is named `params_file` (not `f`) so it does not overwrite the
    # frequency variable `f` used elsewhere in this notebook; the file is
    # closed before the values are parsed.
    with open('E0_params.json') as params_file:
        params = json.load(params_file)
    f_min = float(params['f_min'])
    f_max = float(params['f_max'])
    m = int(params['m'])

    # Initial raw excitation: constant 0.5 on all m frequency bins.
    # NOTE(review): later cells use E0 unconditionally, so they assume
    # E0_distributed is True.
    E0 = torch.full((m,), 0.5, dtype=torch.float64)

In [None]:
# Sanity check: E0 is a 1-D tensor of length m.
E0.shape

In [None]:
# Zero float64 buffer shaped like E0, used to receive E0 gradients from the workers.
grad_E0 = E0.new_zeros(E0.shape, dtype=torch.float64)

In [None]:
# Inspect the (all-zero) E0 gradient buffer.
grad_E0

In [None]:
# Send the current output-layer weights to every worker (ranks 1 .. n_workers-1), tag 7.
for worker in range(1, n_workers):
    dist.send(tensor=net.l2.weight, dst=worker, tag=7)

In [None]:
# Coordinator (rank 0) side of the distributed optimisation loop.
# Each outer iteration:
#   1. sends the current parameters to every worker: E0 (tag 8, only when
#      E0_distributed) followed by the output-layer weights (tag 7);
#   2. if E0_distributed, receives nb_steps E0 gradients (tag 2) from each
#      worker and applies each one immediately: E0 -= alpha_E0 * grad_E0;
#   3. receives nb_steps weight gradients (tag 1) from each worker and applies
#      each one immediately: W -= alpha * grad.
# Gradients from different workers are applied one after another (summed, not
# averaged). NOTE(review): the send/recv order and tags must match the worker
# script, which is not visible here -- do not reorder these calls.
n_it=1 # number of outer iterations (earlier value: 100)
nb_steps=1 # gradient messages expected from each worker per iteration (earlier value: 5)
tot_steps=  n_it*nb_steps  # not used in this cell; normally 3x but count only steps for Q10
it_step_plot=10 # plot the Q10 curve every it_step_plot outer iterations

alpha=3 # step size for the output-layer weights
alpha_E0=3 # step size for E0

# Receive buffers for the gradients sent back by the workers.
if E0_distributed:
    grad_E0=torch.zeros_like(E0, dtype=torch.float64)
grad=torch.zeros_like(net.l2.weight)

for k_it in range(n_it):
    for rank in range(1, n_workers):  # the other nodes update their copies at the start of each iteration
        if E0_distributed:   
            dist.send(E0, rank, tag=8)
        dist.send(net.l2.weight, rank, tag=7)
        
    if E0_distributed:
        for step in range(nb_steps):
            for rank in range(1, n_workers): # E0 gradients are forwarded by the other nodes
                dist.recv(grad_E0, src=rank, tag=2)
                # update through .data so the change is not tracked by autograd
                E0.data-=alpha_E0*grad_E0
                            
    for step in range(nb_steps):
        for rank in range(1, n_workers): # weight gradients are forwarded by the other nodes
            dist.recv(grad, src=rank, tag=1)
            # update through .data so the change is not tracked by autograd
            net.l2.weight.data-=alpha*grad
    if k_it%it_step_plot==0:
        plot_Q10(label=f'step {k_it}')
    
        
pl.legend()


In [None]:
# Plot the Q10 curve after distributed training.
plot_Q10()