In [2]:
import torch
import networkx as nx
from models import fourier_nn
from utils import graph_generation
import copy

In [3]:
shape = [1, 3, 3, 1]
model = fourier_nn.FourierNet(shape, scale=0.05)
p = torch.nn.utils.parameters_to_vector(model.parameters())
models = {i: copy.deepcopy(model) for i in range(3)}
G = nx.wheel_graph(3)
W = graph_generation.get_metropolis(G)

In [4]:
def params_to_gradv(params):
    vec = []
    for p in params:
        vec.append(p.grad.flatten())
    
    return torch.cat(vec)

        

In [7]:
x = torch.linspace(-1, 1, 100).reshape(-1, 1)
y = torch.sin(x).reshape(-1, 1)

xlists = {}
glists = {}
ylists = {}

plists_base = list(models[0].parameters())

num_params = len(plists_base)
zero_plist = [torch.zeros_like(plists_base[i], requires_grad=False) for i in range(num_params)] 

for i in range(3):
    neighs = list(G.neighbors(i))
    neighs.append(i)
    xlists[i] = list(models[i].parameters())
    glists[i] = copy.deepcopy(zero_plist)
    ylists[i] = copy.deepcopy(zero_plist)

alph0 = 0.1

# initialize y and glists
for i in range(3):
    yh = models[i].forward(x)
    loss = torch.nn.MSELoss()(yh, y)
    loss.backward()
    
    with torch.no_grad():
        for p in range(num_params):
            ylists[i][p] = xlists[i][p].grad.detach().clone()
            glists[i][p] = xlists[i][p].grad.detach().clone()
            xlists[i][p].grad.zero_()


for k in range(10):
    alph = alph0 / (k + 1)
    for i in range(3):
        print("Node {}: {}".format(i, loss.item()))
        
        neighs = list(G.neighbors(i))
        with torch.no_grad():
            for p in range(num_params):
                xlists[i][p].multiply_(W[i, i])
                xlists[i][p].add_(-alph0 * ylists[i][p])
                for j in neighs:
                    xlists[i][p].add_(W[i, j] * xlists[j][p])

        yh = models[i].forward(x)
        loss = torch.nn.MSELoss()(yh, y)
        loss.backward()

        with torch.no_grad():
            for p in range(num_params):
                ylists[i][p].multiply_(W[i, i])
                for j in neighs:
                    ylists[i][p].add_(W[i, j] * ylists[j][p])
                
                ylists[i][p].add_(xlists[i][p].grad)
                ylists[i][p].add_(-1.0 * glists[i][p])
                
                glists[i][p] = xlists[i][p].grad.detach().clone()
                # xlists[i][p].add_(-alph * xlists[i][p].grad)
                xlists[i][p].grad.zero_()
        
    

Node 0: 0.6551424789768653
Node 1: 0.6551424789768653
Node 2: 0.6483537453043123
Node 0: 0.6393294941033273
Node 1: 0.6273542061154496
Node 2: 0.6232066877403077
Node 0: 0.6199204884786901
Node 1: 0.6185139004954654
Node 2: 0.6161325031473621
Node 0: 0.6143346797061234
Node 1: 0.6130290684270214
Node 2: 0.611481079445719
Node 0: 0.6102091461224641
Node 1: 0.6091114056615163
Node 2: 0.6079727273417317
Node 0: 0.6069697875714927
Node 1: 0.6060560341068988
Node 2: 0.6051486028585554
Node 0: 0.6043179136334939
Node 1: 0.6035443779335654
Node 2: 0.6027862189909508
Node 0: 0.6020776231016285
Node 1: 0.6014095617169142
Node 2: 0.6007570999502707
Node 0: 0.600139747086175
Node 1: 0.5995526999740806
Node 2: 0.5989796452809412
Node 0: 0.5984329754774256
Node 1: 0.5979097907640373
Node 2: 0.5973988019447506


In [None]:
opt = torch.optim.SGD(models[0].parameters(), lr=0.01)

for k in range(3):
    opt.zero_grad()
    yh = models[0].forward(x)
    loss = torch.nn.MSELoss()(yh, y)
    loss.backward()
    print(loss)
    opt.step()

tensor(0.5533, grad_fn=<MseLossBackward>)
tensor(0.5525, grad_fn=<MseLossBackward>)
tensor(0.5518, grad_fn=<MseLossBackward>)


In [None]:
G2 = nx.erdos_renyi_graph(5, 0.8)

print(nx.is_connected(G2))

A = torch.randn((10, 10)).multiply_(5.0)
import math
math.sqrt(2.0)

True


1.4142135623730951