In [None]:
import numpy as np
import matplotlib.pyplot as pl
import torch
import torch.nn as nn
import torch.optim as optim

from functools import partial


from masking import *
from tuning import Q10RBFNet
import json

import time
import datetime
import sys

from rbf import RBFNet

In [None]:
# Run-time switches for this fit.
E0_distributed = True    # exchange E0 and its gradients with the workers
I0_distributed = True    # exchange the I0 RBF weights and their gradients with the workers
write_results = True     # save the fitted Q10 curve at the end of the notebook
expe_name = '1-22'       # experiment id, used to build the results folder name

### Target functions

In [None]:
def target_func(f, beta=0.7, Q_0=2., f0=1000.):
    """Synthetic target: log10(Q10) as a power law of frequency (no level dependence).

        log10 Q10(f) = log10(Q_0) + beta * (log10(f) - log10(f0))
    i.e. Q10(f) = Q_0 * (f / f0)**beta.

    Args:
        f: frequencies (Hz). A torch tensor, or anything `torch.as_tensor`
            accepts (float, list, NumPy array).
        beta: power-law exponent.
        Q_0: value of Q10 at the reference frequency f0.
        f0: reference frequency (Hz); defaults to the previously hard-coded 1000.

    Returns:
        torch.Tensor of log10(Q10) values with the same shape as f.
    """
    # as_tensor returns tensors unchanged (autograd history preserved) and
    # lets plain floats / arrays go through torch.log10.
    f = torch.as_tensor(f)
    return np.log10(Q_0) + beta*(torch.log10(f) - np.log10(f0))


### RBF NeuralNet for Q10

In [None]:
# Q10 RBF net restored from its saved JSON parameters.
# To start from a fresh net instead: net = Q10RBFNet(6, sig=0.3)
net = Q10RBFNet.create_from_jsonfile('RBF_params.json')

In [None]:

def plot_gauss(x, f, c, weight, sig, log=True, mult_factor=1.):
    """Draw one weighted Gaussian RBF component as a dashed curve.

    Args:
        x: normalized coordinates at which the bump is evaluated.
        f: abscissa used for plotting (real frequencies matching x).
        c: center of the bump (in the same coordinates as x).
        weight: output weight of this RBF unit.
        sig: width parameter of the bump.
        log: if True, plot 10**bump (the net works in log10 units); else the bump itself.
        mult_factor: extra scale applied to the bump before plotting.
    """
    # NOTE(review): the denominator is (2*sig)**2 == 4*sig**2, not the usual
    # 2*sig**2 -- confirm it matches the kernel used inside the RBF nets.
    bump = mult_factor*weight*torch.exp(-(x - c)**2/(2*sig)**2)
    curve = 10**bump if log else bump
    pl.plot(f, curve, '--')

def plot_Q10(label='', plot_target=False, plot_rbfs=False):
    """Plot the Q10 curve predicted by the global RBF net `net` against frequency.

    The net outputs log10(Q10), so the curve is drawn as 10**output over the
    frequencies obtained by mapping [0, 1] through `net.real_coord`.

    Args:
        label: legend label for the predicted curve.
        plot_target: also draw 10**target_func(f) for comparison.
        plot_rbfs: also draw every weighted RBF component (dashed).
    """
    m = 100
    x = torch.linspace(0, 1, m)
    # Plotting only: no autograd graph is needed, so avoid building one
    # (replaces the former `.data` workaround).
    with torch.no_grad():
        f = net.real_coord(x)
        out = net.forward(f)
        pl.plot(f.numpy(), 10**out.numpy()[:, 0], label=label)
        if plot_target:
            target = target_func(f)
            pl.plot(f.numpy(), 10**target, label="target")
        if plot_rbfs:
            for i in range(net.n_centers):
                plot_gauss(x, f, net.centers[i], net.l2.weight[0, i], net.sig)
    pl.xlabel('f')
    pl.xlim([800, 10000])


plot_Q10(plot_rbfs=True)

In [None]:
# Inspect the RBF center positions of the Q10 net (shown as the cell output).
net.centers

### RBF NeuralNet for I0

In [None]:
# I0 RBF net restored from its saved JSON parameters.
# To start from a fresh net instead: net_I0 = RBFNet(6, sig=0.3)
net_I0 = RBFNet.create_from_jsonfile('RBF_I0_params.json')

In [None]:
# I0 targets: one fitted Weibull-CDF I/O function per CF (Hz) is loaded from the
# results folder, and its I0 value is kept.
results_folder=f'./results/fit{expe_name}-distrib/'
# (A former duplicate, immediately overwritten assignment of CFs was removed.)
CFs=[3000, 4000, 5000, 6000, 8000]

I0s=[]
for CF in CFs:
    wb_cdf=WeibullCDF_IOFunc.load_from_npz(f'{results_folder}/wbcfdIO_{CF}.npz')
    I0s.append(wb_cdf.I0)


def target_func_I0(f):
    """Piecewise-linear interpolation of the fitted I0 values at frequencies f.

    np.interp holds the end values constant outside [CFs[0], CFs[-1]].

    Args:
        f: frequencies (Hz), as a torch tensor or anything np.interp accepts.

    Returns:
        NumPy array of interpolated I0 values with the same shape as f.
    """
    if torch.is_tensor(f):
        # np.interp cannot convert tensors that require grad; detach first.
        f = f.detach().cpu().numpy()
    return np.interp(f, CFs, I0s)
    

In [None]:

def plot_I0(label='', plot_target=False, plot_rbfs=False):
    """Plot the I0 curve predicted by the global RBF net `net_I0` against frequency.

    The frequencies are obtained by mapping [0, 1] through `net_I0.real_coord`.
    The net output is plotted directly (linear units, unlike plot_Q10).

    Args:
        label: legend label for the predicted curve.
        plot_target: also draw target_func_I0(f) for comparison.
        plot_rbfs: also draw every weighted RBF component (dashed), scaled by
            net_I0.mult_factor.
    """
    m = 100
    x = torch.linspace(0, 1, m)
    # Plotting only: no autograd graph is needed, so avoid building one.
    with torch.no_grad():
        f = net_I0.real_coord(x)
        out = net_I0.forward(f)
        pl.plot(f.numpy(), out.numpy()[:, 0], label=label)
        if plot_target:
            target = target_func_I0(f)
            pl.plot(f.numpy(), target, label="target")
        if plot_rbfs:
            for i in range(net_I0.n_centers):
                plot_gauss(x, f, net_I0.centers[i], net_I0.l2.weight[0, i], net_I0.sig,
                           log=False, mult_factor=net_I0.mult_factor)
    pl.xlabel('f')
    pl.xlim([800, 10000])


plot_I0(plot_rbfs=True, plot_target=True)

### Learning

Q10: pre-train the Q10 RBF net on the synthetic target function

In [None]:
# SGD with momentum for the Q10 net. The RBF centers get their own parameter
# group so their learning rate can be set separately (0 keeps them fixed).
# NOTE(review): torch raises if a tensor appears in two groups, so this assumes
# net.centers is not also returned by net.parameters().
lr = 2e-2
lr_centers = 0
optimizer = optim.SGD(
    [
        {'params': net.parameters()},
        {'params': [net.centers], 'lr': lr_centers},  # centers
    ],
    lr=lr,
    momentum=0.9,
)


In [None]:
n_steps = 100
batch_size = 8
test_batch_size = 256
criterion = nn.MSELoss()
verbose = True
step_test = 5  # evaluate the test loss every step_test steps
losses = []

# How training frequencies are drawn: 'random' (uniform in [f_min, f_max])
# or 'fixed' (sampled from f_arr).
mode = 'fixed'

f_min = 800.
f_max = 15000.

targetfunc = partial(target_func, beta=0.4, Q_0=1.5)

f_arr = torch.tensor([1500., 2200., 3000., 4000., 5000., 6000., 8000.])
for i in range(n_steps):
    optimizer.zero_grad()
    if mode == 'random':
        f = f_min + (f_max - f_min)*torch.rand((batch_size, 1), requires_grad=False)
    else:
        ind = torch.randint(len(f_arr), (batch_size, 1))
        f = f_arr[ind]
    target = targetfunc(f)
    target.unsqueeze_(-1)
    out = net.forward(f, verbose=(i % step_test == 0))
    loss = criterion(target, out)
    loss.backward()
    optimizer.step()
    if verbose and i % step_test == 0:
        # Test on frequencies drawn uniformly over the net's range; evaluation
        # only, so no autograd graph is built.
        with torch.no_grad():
            random_values = torch.rand(test_batch_size, 1, requires_grad=False)
            f = net.real_coord(random_values)
            out = net.forward(f)
            target = targetfunc(f)
            target.unsqueeze_(-1)
            # MSELoss already averages over the batch (reduction='mean'):
            # no further division by test_batch_size.
            loss = criterion(target, out)
        grad_norm = net.l2.weight.grad.norm()
        # Store a plain float so the list can be plotted and no graph is kept alive.
        losses.append(loss.item())
        print("step : {}, loss: {:.5f}, grad norm: {:.3f}".format(i, loss.item(), grad_norm.item()))

pl.figure()
pl.title("MSE loss")
pl.plot(range(0, n_steps, step_test), losses)
pl.show()

I0: pre-train the I0 RBF net on the interpolated I0 values

In [None]:
# SGD with momentum for the I0 net. The RBF centers get their own parameter
# group so their learning rate can be set separately (0 keeps them fixed).
# NOTE(review): torch raises if a tensor appears in two groups, so this assumes
# net_I0.centers is not also returned by net_I0.parameters().
lr = 1e-3
lr_centers = 0
optimizer_I0 = optim.SGD(
    [
        {'params': net_I0.parameters()},
        {'params': [net_I0.centers], 'lr': lr_centers},  # centers
    ],
    lr=lr,
    momentum=0.9,
)


In [None]:
n_steps = 200
batch_size = 8
test_batch_size = 256
criterion = nn.MSELoss()
verbose = True
step_test = 5  # evaluate the test loss every step_test steps
losses = []

# How training frequencies are drawn: 'random' (uniform in [f_min, f_max])
# or 'fixed' (sampled from f_arr).
mode = 'random'

f_min = 800.
f_max = 10000.

targetfunc = target_func_I0

f_arr = torch.tensor([1500., 2200., 3000., 4000., 5000., 6000., 8000.])
for i in range(n_steps):
    optimizer_I0.zero_grad()
    if mode == 'random':
        f = f_min + (f_max - f_min)*torch.rand((batch_size, 1), requires_grad=False)
    else:
        ind = torch.randint(len(f_arr), (batch_size, 1))
        f = f_arr[ind]
    # target_func_I0 returns a NumPy array: convert it to a float32 tensor.
    target = torch.as_tensor(targetfunc(f), dtype=torch.float)
    target.unsqueeze_(-1)
    out = net_I0.forward(f, verbose=(i % step_test == 0))
    loss = criterion(target, out)
    loss.backward()
    optimizer_I0.step()
    if verbose and i % step_test == 0:
        # Test on frequencies drawn uniformly over the net's range; evaluation
        # only, so no autograd graph is built.
        with torch.no_grad():
            random_values = torch.rand(test_batch_size, 1, requires_grad=False)
            f = net_I0.real_coord(random_values)
            out = net_I0.forward(f)
            target = torch.as_tensor(targetfunc(f), dtype=torch.float)
            target.unsqueeze_(-1)
            # MSELoss already averages over the batch (reduction='mean'):
            # no further division by test_batch_size.
            loss = criterion(target, out)
        grad_norm = net_I0.l2.weight.grad.norm()
        # Store a plain float so the list can be plotted and no graph is kept alive.
        losses.append(loss.item())
        print("step : {}, loss: {:.5f}, grad norm: {:.3f}".format(i, loss.item(), grad_norm.item()))

pl.figure()
pl.title("MSE loss")
pl.plot(range(0, n_steps, step_test), losses)
pl.show()

In [None]:

    
# I0 fit after pre-training, with the target and the individual RBF components.
plot_I0(plot_target=True, plot_rbfs=True)

### Distributed learning

In [None]:

import torch.distributed as dist

from datetime import timedelta

In [None]:
# Process group: this notebook is rank 0, and ranks 1..n_workers-1 are the workers.
n_workers = 4
backend = dist.Backend('gloo')

In [None]:
# Join the process group as rank 0 via TCP rendezvous on localhost.
# NOTE(review): address/port are hard-coded; the worker processes must use the
# same init_method and world size.
dist.init_process_group(
    backend,
    init_method='tcp://127.0.0.1:1234',
    world_size=n_workers,
    rank=0,
    timeout=datetime.timedelta(seconds=20),
)

In [None]:
if E0_distributed:
    # Frequency grid on which E0 is defined, read from the JSON config.
    with open('E0_params.json') as f:
        params = json.load(f)
    f_min = float(params['f_min'])
    f_max = float(params['f_max'])
    m = int(params['m'])

    # Initial raw excitation: 0.5 at each of the m grid points (float64).
    E0 = torch.full((m,), 0.5, dtype=torch.float64)

In [None]:
def wait_handle(h, timeout=10, interval=0.02, name=''):
    """Block until the asynchronous work handle `h` completes.

    A RuntimeError raised by ``h.wait()`` (presumably a communication timeout)
    is printed together with `name` and then swallowed, so the caller carries on.
    Any other exception propagates.

    Args:
        h: handle returned by ``dist.isend`` / ``dist.irecv``.
        timeout: currently unused; kept for interface compatibility.
        interval: currently unused (polling period of a disabled loop).
        name: label printed when the wait fails.
    """
    # Polling `h.is_completed()` until `timeout` would be the natural approach,
    # but the author found it buggy, so this falls back to a blocking wait.
    try:
        h.wait()
    except RuntimeError as err:
        print(err)
        print(f'handle [{name}] not completed before timeout')
    

def wait_list_handles(l, names=None, timeout=10):
    """Wait on each handle of `l` in order, using wait_handle.

    Args:
        l: iterable of asynchronous work handles.
        names: optional sequence of labels indexed in parallel with `l`;
            when None, every handle is reported with name None.
        timeout: forwarded to wait_handle.
    """
    for idx, work in enumerate(l):
        label = names[idx] if names is not None else None
        wait_handle(work, name=label, timeout=timeout)

Send the initial RBF net weights (Q10) to the workers

In [None]:
# Push the current Q10 RBF weights to every worker (ranks 1..n_workers-1), tag 7.
send_handles = []
handle_names = []
for rank in range(1, n_workers):
    send_handles.append(dist.isend(net.l2.weight, rank, tag=7))
    handle_names.append(f'update weights RBF rank {rank}')

wait_list_handles(send_handles, names=handle_names, timeout=10)

Send the initial RBF net weights (I0) to the workers

In [None]:
# Push the current I0 RBF weights to every worker (ranks 1..n_workers-1), tag 17.
# NOTE(review): unlike the I0 update inside the optimisation loop, this send is
# not guarded by I0_distributed; confirm the workers always expect it.
send_handles = []
handle_names = []
for rank in range(1, n_workers):
    send_handles.append(dist.isend(net_I0.l2.weight, rank, tag=17))
    handle_names.append(f'update weights RBF I0 rank {rank}')

wait_list_handles(send_handles, names=handle_names, timeout=10)

Optimization steps (master loop)

In [None]:
# Master (rank 0) side of the distributed optimisation.
# Each iteration of the `while` loop:
#   1. receives from every still-active worker a step count (tag 16) into
#      nb_steps_arr; the loop stops once every count is 0;
#   2. sends the current E0 (tag 8, if E0_distributed), Q10 RBF weights (tag 7)
#      and I0 RBF weights (tag 17, if I0_distributed) to the active workers;
#   3. for step = 1..count, receives one gradient per worker and step
#      (E0: tag 2000+step, I0 weights: tag 3000+step, Q10 weights: tag 1000+step)
#      and applies each one immediately as a plain gradient step (no averaging);
#   4. every it_step_plot iterations, overlays the current Q10 / I0 curves.
# NOTE(review): wait_handle only prints on RuntimeError, so a failed/timed-out
# irecv still lets the (stale) receive buffer be applied as a gradient.
it_step_plot=10
k_it=0
#TODO in param file?
# Step sizes for the Q10 weights, E0 and the I0 weights respectively.
alpha=1.5
alpha_E0=6
alpha_I0=0.15

# Last step count reported by each worker (rank r -> index r-1). Starting at 1
# makes every worker get polled in the first iteration.
nb_steps_arr=torch.ones((n_workers-1,), dtype=torch.int32)


# Receive buffers, reused for every incoming gradient message.
if E0_distributed:
    grad_E0=torch.zeros_like(E0, dtype=torch.float64)
grad=torch.zeros_like(net.l2.weight)
grad_I0=torch.zeros_like(net_I0.l2.weight)

pl.figure(figsize=(6, 12))


# Top panel: Q10 curves; bottom panel: I0 curves.
ax1=pl.subplot(2,1,1)
ax2=pl.subplot(2,1,2)
while True:
    
    
    # irecv writes into a view of nb_steps_arr, so each count is updated in place.
    # Workers that already reported 0 are no longer polled.
    optim_done_handles=[]
    handle_names=[]
    for rank in range(1, n_workers):
        if nb_steps_arr[rank-1]>0:
            optim_done_handles.append(dist.irecv(nb_steps_arr[rank-1], rank, tag=16))
            handle_names.append(f'nb steps it {k_it} rank {rank}')
    wait_list_handles(optim_done_handles, names=handle_names)
    
    # Every worker reported 0 steps: optimisation finished.
    if torch.count_nonzero(nb_steps_arr) == 0:
        break
        
        
    #update E0
    if E0_distributed:   
        send_handles=[]
        handle_names=[]
        for rank in range(1, n_workers):  #the other nodes update weights at start of loop
            if nb_steps_arr[rank-1]>0:
                send_handles.append(dist.isend(E0, rank, tag=8))
                handle_names.append(f'update E0 it {k_it} rank {rank}')

        wait_list_handles(send_handles, names=handle_names)
    
    
    
    #update Q10
    send_handles=[]
    handle_names=[]
    for rank in range(1, n_workers):  #the other nodes update weights at start of loop
        if nb_steps_arr[rank-1]>0:
            send_handles.append(dist.isend(net.l2.weight, rank, tag=7))
            handle_names.append(f'update RBF weights it {k_it} rank {rank}')

    wait_list_handles(send_handles, names=handle_names)
    
    
    #update I0
    if I0_distributed:
        send_handles=[]
        handle_names=[]
        for rank in range(1, n_workers):  #the other nodes update weights at start of loop
            if nb_steps_arr[rank-1]>0:
                send_handles.append(dist.isend(net_I0.l2.weight, rank, tag=17))
                handle_names.append(f'update RBF weights (I0) it {k_it} rank {rank}')

        wait_list_handles(send_handles, names=handle_names)


    
    
    # Largest step count among workers: bounds the gradient-receive loops below.
    max_nb_steps=int(torch.amax(nb_steps_arr))
                     
    # E0 gradients: one message per (step, worker), each applied immediately.
    if E0_distributed:
        for step in range(1, max_nb_steps+1):
            for rank in range(1, n_workers): #gradients are forwarded by the other nodes
                if step<=nb_steps_arr[rank-1]:  
                    hand = dist.irecv(grad_E0, src=rank, tag=2000+step)
                    wait_handle(hand, name=f'grad E0 it {k_it} step {step} rank {rank}')
                    E0.data-=alpha_E0*grad_E0
                    
    # I0 RBF weight gradients, same scheme.
    if I0_distributed:
        for step in range(1, max_nb_steps+1):
            for rank in range(1, n_workers): #gradients are forwarded by the other nodes
                if step<=nb_steps_arr[rank-1]:  
                    hand = dist.irecv(grad_I0, src=rank, tag=3000+step)
                    wait_handle(hand, name=f'grad RBF weights (I0) it {k_it} step {step} rank {rank}')
                    net_I0.l2.weight.data-=alpha_I0*grad_I0
                    
    # Q10 RBF weight gradients, same scheme.
    for step in range(1, max_nb_steps+1):
        for rank in range(1, n_workers): #gradients are forwarded by the other nodes
            if step<=nb_steps_arr[rank-1]:  
                hand=dist.irecv(grad, src=rank, tag=1000+step)
                wait_handle(hand, name=f'grad RBF weights it {k_it} step {step} rank {rank}')
                net.l2.weight.data-=alpha*grad

        
    k_it+=1
    
    # Periodically overlay the current curves on the two panels.
    if k_it%it_step_plot==0:
        pl.sca(ax1)
        plot_Q10(label=f'step {k_it}')
        pl.sca(ax2)
        plot_I0(label=f'step {k_it}')
        
pl.sca(ax1) 
pl.legend()

pl.sca(ax2) 
pl.legend()

pl.show()


In [None]:
# Q10 curve after the distributed optimisation.
plot_Q10(label='', plot_target=False, plot_rbfs=False)

In [None]:
# I0 curve after the distributed optimisation, compared with its target.
plot_I0(label='', plot_target=True)

In [None]:
if write_results:
    # Sample the fitted Q10 curve on 100 points of the net's frequency range.
    m = 100
    x = torch.linspace(0, 1, m)
    with torch.no_grad():
        f = net.real_coord(x)
        out = net.forward(f)
    # The net predicts log10(Q10); store Q10 itself.
    Q10_val = 10**out.numpy()[:, 0]

    results_folder = f'./results/fit{expe_name}-distrib/'
    np.savez(f'{results_folder}/Q10.npz', f=f.detach().numpy(), Q10=Q10_val)