In [1]:
def fix_layout(width:int=95):
    from IPython.core.display import display, HTML
    display(HTML('<style>.container { width:' + str(width) + '% !important; }</style>'))
    
fix_layout()

In [2]:
import sys

import numpy as np
import matplotlib.pyplot as plt
from nnpde.utils.logs import enable_logging, logging 
from importlib import reload
import nnpde.functions.iterative_methods as im
from nnpde.functions import geometries, helpers
from nnpde.problems import DirichletProblem 

In [3]:
enable_logging(10)

2018-12-19 11:00:26,989 - root - INFO - logs - logging enabled for level: 10


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
# Define train dimension
N = 16

# Initialize f: we use a zero forcing term for training
f = torch.zeros(1, 1, N, N)

# For each problem instance define number of iteration to perform to obtain the solution
nb_problem_instances = 10
problem_instances = [DirichletProblem(k=k) for k in np.random.randint(1, 20, nb_problem_instances)]

In [None]:
import nnpde.model as M 

# TODO fit would idealy take X, y, (u_0 and u_*)
model = M.JacobyWithConv(max_epochs = 100,batch_size=10, learning_rate = 1e-5).fit(problem_instances)
#losses = model.losses

In [None]:
def count_conv2d(in_shape = (1,1,256,256), out_shape = (1,1,3,3),ks = [3,3],layers = 3):
    
    multiply_adds = 1 # Each addition and multiplication counts as 1 flop

    kw = ks[0]
    kh = ks[1]

    out_w = out_shape[2]
    out_h = out_shape[3]


    kernel_ops = multiply_adds * kh * kw # size of kernel (9 in our case)
     

    output_elements = out_w * out_h # total number of elements
    
    total_ops = output_elements * kernel_ops # for each element we perform kernel_size operations

    total_ops += torch.Tensor([int(total_ops)])
    return layers*total_ops

In [None]:
def compare_flops(n,n_iter_conv,n_iter_jac):
    
    '''
    Number of flops for jacobiwithConv is calculated as the (l*9n^2 + 4 + 2*n^2) where l is the number of layers, 9  because kernel_width * kernel_height = 9, 4 for
    the jacobi iteration method and 2*n^2 for the boundary reset operator
    '''
    
    '''
    Number of flops for jacobi iteration method is calculated as 4*n^2 + 2*n^2 + n^2 where 4*n^2 results from each iteration step where we sum the neighbors
    of each element of the matrix, 2n^2 from the boundary reset operator and n^2 from the forcing term
    '''
    flop_conv = (count_conv2d((1,1,256,256),(1,1,3,3)) + 2*n**2 + 4) * n_iter_conv 
    flop_jac = (7 * n**2) * n_iter_jac 
    return flop_conv/flop_jac

In [None]:
def test_model(net, n_tests, grid_size):
    
    losses = []
    
    loss_to_be_achieved = 1e-6
    max_nb_iters = 100000
    f = torch.zeros(1, 1, grid_size, grid_size)
    u_0 = torch.ones(1, 1, grid_size, grid_size)

    for i in range(n_tests):
        
        problem_instance = DirichletProblem(N=grid_size, k=max_nb_iters * 10, )
        gtt = problem_instance.ground_truth
        
        # jacoby method / known solver
        
        u_jac = im.jacobi_method(problem_instance.B_idx, problem_instance.B, f, u_0, k = 1)
        loss_jac = F.mse_loss(gtt, u_jac) # TODO use loss from metrics
        count_jac = 1
        
        nb_iters = 0
        while loss_jac >= loss_to_be_achieved and nb_iters < max_nb_iters:
            u_jac = im.jacobi_method(problem_instance.B_idx, problem_instance.B, f, u_jac,k = 1)
            loss_jac = F.mse_loss(gtt, u_jac)
            count_jac += 1
            nb_iters += 1
            
        # learned solver
        
        u_h = im.H_method(net,problem_instance.B_idx, problem_instance.B, f, u_0 ,k = 1)
        loss_h = F.mse_loss(gtt, u_h)
        count_h = 1
        
        # old method 
        nb_iters = 0
        while loss_h >= loss_to_be_achieved and nb_iters < max_nb_iters:
            u_h = im.H_method(net,problem_instance.B_idx, problem_instance.B, f, u_h,k = 1)
            loss_h = F.mse_loss(gtt, u_h)
            count_h += 1
            nb_iters += 1
        
        
        yield count_jac, count_h, compare_flops().item()

In [49]:
import pandas as pd
test_results = pd.DataFrame(test_model(model.net, 1, 50), columns=['count_jacoby', 'count_learned', 'flops_ratio'])

test_results

tensor([486.])