In [1]:
import torch
import networkx as nx
from models import fourier_nn
from utils import graph_generation
import copy

In [2]:
shape = [1, 3, 3, 1]
model = fourier_nn.FourierNet(shape, scale=0.05)
p = torch.nn.utils.parameters_to_vector(model.parameters())
models = {i: copy.deepcopy(model) for i in range(3)}
G = nx.wheel_graph(3)
W = graph_generation.get_metropolis(G)

In [3]:
def params_to_gradv(params):
    vec = []
    for p in params:
        vec.append(p.grad.flatten())
    
    return torch.cat(vec)

        

In [4]:
x = torch.linspace(-1, 1, 100).reshape(-1, 1)
y = torch.sin(x).reshape(-1, 1)

xlists = {}
ylists = {}
glists = {}

plists_base = list(models[0].parameters())

num_params = len(plists_base)
zero_plist = [torch.zeros_like(plists_base[i], requires_grad=False) for i in range(num_params)] 

for i in range(3):
    neighs = list(G.neighbors(i))
    neighs.append(i)
    xlists[i] = list(models[i].parameters())
    ylists[i] = copy.deepcopy(zero_plist)
    glists[i] = copy.deepcopy(zero_plist)

alph0 = 0.1

for k in range(10):
    alph = alph0 / (k + 1)
    for i in range(3):
        neighs = list(G.neighbors(i))
        with torch.no_grad():
            for p in range(num_params):
                xlists[i][p].multiply_(W[i, i])
                for j in neighs:
                    xlists[i][p].add_(W[i, j] * xlists[j][p])
                
        yh = models[i].forward(x)
        loss = torch.nn.MSELoss()(yh, y)
        loss.backward()
        print("Node {}: {}".format(i, loss.item()))
        
        with torch.no_grad():
            for p in range(num_params):
                xlists[i][p].add_(-alph * xlists[i][p].grad)
                xlists[i][p].grad.zero_()
        
    

Node 0: 0.5885430838316501
Node 1: 0.585681712742351
Node 2: 0.5818978540607133
Node 0: 0.5769032858524885
Node 1: 0.5744962589315452
Node 2: 0.5722258960322492
Node 0: 0.5704343849350564
Node 1: 0.5687576422108621
Node 2: 0.5673173526065557
Node 0: 0.5661468395576545
Node 1: 0.5649529976117065
Node 2: 0.5639171518283979
Node 0: 0.5630137183586907
Node 1: 0.5621085240219825
Node 2: 0.5612981882405922
Node 0: 0.5605619547641535
Node 1: 0.5598364716103827
Node 2: 0.559170195132305
Node 0: 0.5585513702492523
Node 1: 0.5579461605389912
Node 2: 0.557380642062202
Node 0: 0.556848369647432
Node 1: 0.5563290955356567
Node 2: 0.555838216142638
Node 0: 0.5553719965649142
Node 1: 0.5549172836588754
Node 2: 0.5544839236271466
Node 0: 0.5540695863154581
Node 1: 0.5536652211228937
Node 2: 0.5532775207871665


In [5]:
opt = torch.optim.SGD(models[0].parameters(), lr=0.01)

for k in range(3):
    opt.zero_grad()
    yh = models[0].forward(x)
    loss = torch.nn.MSELoss()(yh, y)
    loss.backward()
    print(loss)
    opt.step()

tensor(0.5533, grad_fn=<MseLossBackward>)
tensor(0.5525, grad_fn=<MseLossBackward>)
tensor(0.5518, grad_fn=<MseLossBackward>)


In [6]:
G2 = nx.erdos_renyi_graph(5, 0.8)

print(nx.is_connected(G2))

A = torch.randn((10, 10)).multiply_(5.0)
import math
math.sqrt(2.0)

True


1.4142135623730951