In [4]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

In [8]:
# Ops charged per multiply-accumulate (1 => a MAC counts as a single op).
multiply_adds = 1


def count_conv2d(shapex, shapey, in_channels=1, out_channels=1, kernel_size=(3, 3)):
    """Count the operations performed by a single 2-D convolution.

    Each output element costs ``multiply_adds * kh * kw * in_channels`` ops
    (one multiply-accumulate per kernel tap), and there are
    ``batch_size * out_channels * out_h * out_w`` output elements.

    Args:
        shapex: input shape; only ``shapex[0]`` (batch size) is read.
        shapey: output shape; ``shapey[2]`` and ``shapey[3]`` are read as the
            spatial output dims (assumes PyTorch's (N, C, H, W) layout — only
            their product matters).
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: ``(kh, kw)`` tuple, or a single int for a square kernel.

    Returns:
        A 1-element float ``torch.Tensor`` holding the op count (kept as a
        tensor because callers use tensor arithmetic and ``.item()``).
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    kh, kw = kernel_size

    batch_size = shapex[0]
    out_h, out_w = shapey[2], shapey[3]

    ops_per_element = multiply_adds * kh * kw * in_channels
    output_elements = batch_size * out_channels * out_h * out_w
    total_ops = output_elements * ops_per_element

    # Return the count for ONE application of the conv. Accumulating over
    # repeated uses is the caller's job; adding the count to itself here would
    # just double it.
    return torch.Tensor([int(total_ops)])

In [21]:
def compare_flops(n, n_iter_conv, n_iter_jac):
    """Ratio of estimated ops: learned (conv-based) solver vs. plain Jacobi.

    Args:
        n: side length of the square ``n x n`` grid both solvers operate on.
        n_iter_conv: number of iterations run by the learned (conv) solver.
        n_iter_jac: number of iterations run by the Jacobi solver.

    Returns:
        ``flop_conv / flop_jac``. Since ``count_conv2d`` returns a 1-element
        tensor, so does this; call ``.item()`` for a Python float.
    """
    # One 3x3 conv over the whole n x n grid: 9 ops per output pixel and
    # n**2 output pixels, so the cost scales with the grid size.
    conv_ops = count_conv2d((1, 1, n, n), (1, 1, n, n), kernel_size=(3, 3))
    # Extra per-iteration work of the learned step on top of the conv.
    # NOTE(review): the original comment described this as "6 operations for
    # the jacobi iteration step", which does not match 2*n**2 + 4 — confirm.
    flop_conv = (conv_ops + 2 * n**2 + 4) * n_iter_conv
    # Jacobi: for each u_ij, add its 4 neighbours + forcing term + reset the
    # boundaries (7 ops per grid point).
    flop_jac = (7 * n**2) * n_iter_jac
    return flop_conv / flop_jac

In [61]:
def _iterations_to_converge(step, u_init, target, tol, max_extra_iters):
    """Apply ``step`` until the MSE to ``target`` drops below ``tol``.

    ``step`` is always applied once; after that, at most ``max_extra_iters``
    further steps are taken. Returns the total number of ``step`` calls.
    """
    u = step(u_init)
    count = 1
    while F.mse_loss(target, u) >= tol and count <= max_extra_iters:
        u = step(u)
        count += 1
    return count


def test_model(net, n_tests, grid_size, loss_to_be_achieved=1e-6, max_nb_iters=100000):
    """Compare iterations-to-convergence of Jacobi vs. the learned solver.

    For each of ``n_tests`` Dirichlet problems on a ``grid_size`` square grid,
    both solvers start from an all-ones guess with zero forcing. Each solver is
    iterated until its MSE to the problem's ground truth is below
    ``loss_to_be_achieved``, or until ``max_nb_iters`` extra iterations.

    NOTE(review): ``im`` and ``DirichletProblem`` are not imported in any cell
    shown here; they must be brought into scope before this function runs.

    Args:
        net: network passed to ``im.H_method`` (the learned solver).
        n_tests: number of problem instances to evaluate.
        grid_size: side length N of the grid.
        loss_to_be_achieved: MSE convergence threshold.
        max_nb_iters: cap on iterations after the first one, per solver.

    Yields:
        ``(count_jac, count_h, flops_ratio)`` per problem, where
        ``flops_ratio`` is learned-solver ops / Jacobi ops as a float.
    """
    f = torch.zeros(1, 1, grid_size, grid_size)
    u_0 = torch.ones(1, 1, grid_size, grid_size)

    for _ in range(n_tests):
        problem = DirichletProblem(N=grid_size, k=max_nb_iters * 10)
        gtt = problem.ground_truth
        B_idx, B = problem.B_idx, problem.B

        # Inference only: without no_grad each learned-solver step would chain
        # onto the autograd graph of the previous iterate, so memory would grow
        # with the iteration count.
        with torch.no_grad():
            # Jacobi method (known solver).
            count_jac = _iterations_to_converge(
                lambda u: im.jacobi_method(B_idx, B, f, u, k=1),
                u_0, gtt, loss_to_be_achieved, max_nb_iters,
            )
            # Learned solver.
            count_h = _iterations_to_converge(
                lambda u: im.H_method(net, B_idx, B, f, u, k=1),
                u_0, gtt, loss_to_be_achieved, max_nb_iters,
            )

        # compare_flops(n, n_iter_conv, n_iter_jac): the learned solver is the
        # conv-based one, so its count goes in the n_iter_conv slot.
        yield count_jac, count_h, compare_flops(grid_size, count_h, count_jac).item()

In [None]:
N_TESTS = 1     # number of Dirichlet problems to evaluate
GRID_SIZE = 50  # side length of the square grid

# NOTE(review): `model` is not defined in any cell shown here; it must be
# created (trained or loaded) before this cell runs.
test_results = pd.DataFrame(
    list(test_model(model.net, N_TESTS, GRID_SIZE)),
    columns=['count_jacoby', 'count_learned', 'flops_ratio'],
)

test_results