In [15]:
import torch
import networkx as nx
from models import fourier_nn
from utils import graph_generation
import copy

In [16]:
shape = [1, 3, 3, 1]
model = fourier_nn.FourierNet(shape, scale=0.05)
p = torch.nn.utils.parameters_to_vector(model.parameters())
models = {i: copy.deepcopy(model) for i in range(3)}
G = nx.wheel_graph(3)
W = graph_generation.get_metropolis(G)

In [17]:
def params_to_gradv(params):
    vec = []
    for p in params:
        vec.append(p.grad.flatten())
    
    return torch.cat(vec)

        

In [32]:
x = torch.linspace(-1, 1, 100).reshape(-1, 1)
y = torch.sin(x).reshape(-1, 1)

xlists = {}
ylists = {}
glists = {}

plists_base = list(models[0].parameters())

num_params = len(plists_base)
zero_plist = [torch.zeros_like(plists_base[i], requires_grad=False) for i in range(num_params)] 

for i in range(3):
    neighs = list(G.neighbors(i))
    neighs.append(i)
    xlists[i] = list(models[i].parameters())
    ylists[i] = copy.deepcopy(zero_plist)
    glists[i] = copy.deepcopy(zero_plist)

alph0 = 0.1

for k in range(10):
    alph = alph0 / (k + 1)
    for i in range(3):
        neighs = list(G.neighbors(i))
        with torch.no_grad():
            for p in range(num_params):
                xlists[i][p].multiply_(W[i, i])
                for j in neighs:
                    xlists[i][p].add_(W[i, j] * xlists[j][p])
                
        yh = models[i].forward(x)
        loss = torch.nn.MSELoss()(yh, y)
        loss.backward()
        print("Node {}: {}".format(i, loss.item()))
        
        with torch.no_grad():
            for p in range(num_params):
                xlists[i][p].add_(alph * xlists[i][p].grad)
                xlists[i][p].grad.zero_()
        
    

Node 0: 0.528134688462565
Node 1: 0.5292426990216434
Node 2: 0.5292426990216434
Node 0: 0.5319328674138187
Node 1: 0.5320676861379974
Node 2: 0.5320676861379974
Node 0: 0.5319399675897495
Node 1: 0.5317290797551719
Node 2: 0.5317290797551719
Node 0: 0.5311687804564645
Node 1: 0.5309082288115815
Node 2: 0.5309082288115815
Node 0: 0.5303675746038345
Node 1: 0.5301436928157564
Node 2: 0.5301436928157564
Node 0: 0.529710815311864
Node 1: 0.5295371469763355
Node 2: 0.5295371469763355
Node 0: 0.5292100234045826
Node 1: 0.5290797076746281
Node 2: 0.5290797076746281
Node 0: 0.5288364921774252
Node 1: 0.5287392109835193
Node 2: 0.5287392109835193
Node 0: 0.5285579074782476
Node 1: 0.5284846598992329
Node 2: 0.5284846598992329
Node 0: 0.5283478060377503
Node 1: 0.5282917793301989
Node 2: 0.5282917793301989


In [20]:
opt = torch.optim.SGD(models[0].parameters(), lr=0.01)

for k in range(3):
    opt.zero_grad()
    yh = models[0].forward(x)
    loss = torch.nn.MSELoss()(yh, y)
    loss.backward()
    print(loss)
    opt.step()

tensor(0.4937, grad_fn=<MseLossBackward>)
tensor(0.4930, grad_fn=<MseLossBackward>)
tensor(0.4924, grad_fn=<MseLossBackward>)


In [33]:
G2 = nx.erdos_renyi_graph(5, 0.8)

print(nx.is_connected(G2))

A = torch.randn((10, 10)).multiply_(5.0)
import math
math.sqrt(2.0)

True


TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not float