In [345]:
# setup: hide the notebook input/output prompts and load ../rise.css
from IPython.display import display,HTML
display(HTML('<style>.prompt{width: 0px; min-width: 0px; visibility: collapse}</style>'))
# Read the stylesheet inside a context manager so the file handle is closed
# immediately instead of whenever the garbage collector reclaims it.
with open('../rise.css', encoding='utf-8') as css_file:
    display(HTML(css_file.read()))

# imports
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import time
import math

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
# Global seaborn theme for every later plot: white grid background, fonts
# scaled by 1.5, and a default figure size of (12, 6).
sns.set(style="whitegrid", font_scale=1.5, rc={'figure.figsize':(12, 6)})

<h1>Divide and Conquer: Numerical Examples</h1>

<h3>Last Time</h3>

We studied Binary Search as a first look at divide and conquer type algorithms. We saw that these algorithms are naturally expressed using recursion, and that they can be easily parallelized.

<h3>This Time</h3>

We will look at two more divide-and-conquer algorithms: the [Karatsuba-Ofman algorithm](https://people.cs.uchicago.edu/~laci/HANDOUTS/karatsuba.pdf) for integer multiplication (see TAOCP 4.3.3) and Strassen's algorithm for matrix multiplication. We will see that these problems are nice in the sense that there is a good test (the BLR test) to check that the algorithms work properly without needing to inspect how they work internally.

<h3>Multiplication problem setup</h3>

We assume that we have two natural numbers, $a,b\in\mathbb{N}$, written in binary. Their digits are assumed to be stored in RAM. The size of the input is the number of digits of $a$ plus the number of digits of $b$: $n\approx\log_2(a)+\log_2(b)$. This is typical for numerical problems. At this level of abstraction, we assume that in any single timestep we can only work with individual bits of $a$ and $b$.

The methods that we describe will work for any base, not just $2$. We'll use the word "digit" to be intentionally ambiguous about the base. We'll call the base $base$. The code is implemented for $base=2$.

We assume that multiplying by the base requires constant time because it can be achieved by a shift operation: adding a $0$ as the least significant digit.

Here are some Python functions that will be useful for converting from Python integers to lists of their digits.

In [346]:
def convert_to_base(int_value,base=2):
    '''Return the digits of a non-negative integer, least significant digit first.

    Parameters:
        int_value ... the integer to convert; must be >= 0.
        base ........ the base of the representation; must be >= 2.

    Returns a list of ints in range(base). Zero is represented as [0].
    For example, convert_to_base(22) == [0, 1, 1, 0, 1] because 22 is 10110 in binary.

    Raises ValueError for base < 2 and AssertionError for a negative int_value.
    '''
    if base < 2:
        raise ValueError("base must be an integer >= 2")
    assert int_value >= 0
    if base == 2:
        # Fast path: Python's binary formatting is linear-time even for huge ints.
        # See https://stackoverflow.com/questions/3528146/convert-decimal-to-binary-in-python
        return [int(x) for x in "{0:b}".format(int_value)[::-1]]
    if int_value == 0:
        return [0]
    digits = []
    remaining = int_value
    while remaining:
        # divmod peels off the least significant digit first, matching our storage order.
        remaining, digit = divmod(remaining, base)
        digits.append(digit)
    return digits

def convert_from_base(digit_list,base=2):
    '''Return the integer whose digits in `base` are digit_list (least significant first).

    This inverts convert_to_base: the result is sum(digit_list[i] * base**i).
    An empty digit list represents 0. Any iterable of ints is accepted.
    '''
    total = 0
    for index, digit in enumerate(digit_list):
        # Use the requested base; previously this was hard-coded to 2.
        total += digit * base**index
    return total

def test_convert():
    '''Round-trip every integer in range(20) through the binary digit representation.'''
    for value in range(20):
        digits = convert_to_base(value)
        assert convert_from_base(digits) == value
test_convert()

class low_level_int():
    '''Sign-magnitude integer stored as a list of digits, least significant first.

    This class models how we are storing our numbers.
    Contains: self.value......the integer that self represents
              self.base.......the base of our representation (the notebook uses 2).
              self.digits.....a list of the digits in the representation of the absolute value of self.value.
              self.sign.......1 if self.value>=0, -1 otherwise.
    '''
    def __init__(self,initialization_data,base=2):
        '''initialization_data is either
            1. value (an integer), whose digits are computed in `base`; or
            2. a list [digits, sign, base]. The digit list is copied, and the
               list's own base is used (the `base` keyword is ignored).
        '''
        if isinstance(initialization_data, int):
            self.value = initialization_data
            self.base = base
            # The digits describe the magnitude; the sign is stored separately.
            # Pass the requested base through instead of hard-coding 2.
            self.digits = convert_to_base(abs(initialization_data), base=base)
            self.sign = 1 if initialization_data >= 0 else -1
        else:
            self.base = initialization_data[2]
            # Copy the list so that padding or other in-place changes to the
            # caller's list (or to ours) cannot leak between objects.
            self.digits = list(initialization_data[0])
            self.sign = initialization_data[1]
            assert all(isinstance(s, int) for s in self.digits)
            self.value = convert_from_base(self.digits, base=self.base) * self.sign

    def __int__(self):
        '''Return the represented value, so int(x) works.'''
        return self.value

    def __repr__(self):
        '''Show the digits most significant first (the reverse of storage order).

        A leading '-' is added for negative values; previously the sign was silently dropped.
        '''
        magnitude = ''.join(str(digit) for digit in reversed(self.digits))
        return ('-' if self.value < 0 else '') + magnitude



<h3>Addition</h3>

Let's implement the method for addition that is taught in many elementary schools. This method requires $O(n)$ operations. 

In [347]:
def digit_sum(a,b,base=2):
    '''Add two numbers given as digit lists using the grade-school method.

    Parameters:
        a, b ... lists of valid digits for `base`, least significant first, i.e.
                 a = sum_i a[i]*base**i and b = sum_i b[i]*base**i.
        base ... the base of both representations.

    Returns a new list of valid digits representing a+b. The result has
    max(len(a), len(b)) digits, plus one more if there is a final carry.
    The input lists are NOT modified; the previous version zero-padded them
    in place, which changed the callers' digit lists.
    This is a single pass over the digits, so it takes O(n) time.
    '''
    assert isinstance(a, list)
    assert all(isinstance(item, int) for item in a)
    assert isinstance(b, list)
    assert all(isinstance(item, int) for item in b)

    max_length = max(len(a), len(b))
    # Pad copies to a common length so we can walk both lists in lockstep.
    a_padded = a + [0] * (max_length - len(a))
    b_padded = b + [0] * (max_length - len(b))

    carry = 0
    to_return = []
    for digit_a, digit_b in zip(a_padded, b_padded):
        carry, new_digit = divmod(digit_a + digit_b + carry, base)
        to_return.append(new_digit)
    if carry != 0:
        # For valid digits the final carry is a single digit (< base).
        to_return.append(carry)
    return to_return

def test_digit_sum():
    '''Check digit_sum on one pair of small integers against Python's +.'''
    x, y = 22, 45
    total_digits = digit_sum(convert_to_base(x), convert_to_base(y))
    assert convert_from_base(total_digits) == x + y
test_digit_sum()


Now we create a low-level integer class endowed with the extra operation of addition. Inheritance makes this easy. We give it the ability to subtract, which could be implemented with the grade-school method. Later, we will implement multiplication for this class.

In [348]:
class low_level_int_with_addition(low_level_int): #We use inheritance
    '''A low_level_int that also supports +, unary -, binary - and digit shifts.

    Same-sign addition uses the linear-time digit_sum. Mixed-sign addition
    (i.e. subtraction) falls back to Python integers. The grade-school
    subtraction algorithm would also be linear-time.
    '''

    def __init__(self,value,base=2):
        '''Accept the same initialization data as low_level_int.

        `base` defaults to 2 as before. It is only used when value is an int.
        '''
        super().__init__(value, base=base)

    def __add__(self,other):
        '''Return self + other as a new low_level_int_with_addition.'''
        assert isinstance(other, low_level_int_with_addition)
        assert self.base==other.base
        if self.sign == other.sign:
            # Pass copies because digit_sum may pad its arguments in place.
            new_digit_sum = digit_sum(list(self.digits), list(other.digits), base=self.base)
            return low_level_int_with_addition([new_digit_sum, self.sign, self.base])
        # We didn't implement subtraction, so we cheat with Python ints.
        # The elementary algorithm for subtraction is linear-time.
        return low_level_int_with_addition(self.value + other.value, base=self.base)

    def __neg__(self):
        '''Return -self. The digit list is copied so the two objects share no state.'''
        return low_level_int_with_addition([list(self.digits), -self.sign, self.base])

    def __sub__(self,other):
        '''Subtraction is adding the negative.'''
        return self + (-other)

    def shift(self,num_to_shift_by):
        '''Multiply by base**num_to_shift_by by prepending zero digits.

        Raises ValueError for a negative shift, which used to be a silent no-op.
        '''
        if num_to_shift_by < 0:
            raise ValueError("shift amount must be non-negative")
        return low_level_int_with_addition([[0]*num_to_shift_by + self.digits, self.sign, self.base])
    
def test_low_level_int_with_addition():
    '''Exercise +, -, unary -, and shift. The original test built values but asserted nothing.'''
    x=low_level_int_with_addition(30)
    y=low_level_int_with_addition(25)
    assert int(x + y) == 55
    assert int(x - y) == 5
    assert int(y - x) == -5
    assert int(-x + -y) == -55
    assert int(x.shift(2)) == 120
    # The operations must not change their operands.
    assert int(x) == 30 and int(y) == 25
test_low_level_int_with_addition()

<h3>Analysis of Addition</h3>

Addition only requires a single loop through the digits of $a$ and $b$. Each step in the loop performs an addition of digits, and so requires $O(1)$ time. Therefore, addition requires linear time, $O(n)$. To add two numbers, we need to at least examine all of the digits of both numbers, so we need linear time to perform addition. We do not expect to be able to achieve an asymptotic improvement for addition.

<h3>Naive Multiplication</h3>

In the naive method for multiplication that we were taught in elementary school, we put $a$ above $b$. We loop through the digits of $b$ from least significant to most significant. For each digit $d_b$ of $b$, we loop through the digits $d_a$ of $a$. We keep track of a variable "carry" and repeatedly calculate $d_a d_b+carry$. We write down the least significant digit of the result and set carry to the remaining (higher-order) part. Finally, we add up the shifted partial products.

In [349]:
def naive_multiply(a,b,base=2):
    '''Grade-school multiplication of two digit lists.

    Parameters:
        a, b ... lists of valid digits for `base`, least significant first, i.e.
                 a = sum_i a[i]*base**i and b = sum_i b[i]*base**i.
        base ... the base of both representations.

    Returns a list of valid digits representing a*b. If b is empty, the result is [].
    For each digit of b, a shifted partial product of a is formed and added to
    a running total. This is two nested loops plus one linear-time addition per
    digit of b, so O(n^2) overall.
    '''
    total = []
    for index_b, digit_b in enumerate(b):
        carry = 0
        # The partial product for digit index_b is shifted left by index_b places.
        partial = [0] * index_b
        for digit_a in a:
            carry, new_digit = divmod(digit_b * digit_a + carry, base)
            partial.append(new_digit)
        if carry != 0:
            partial.append(carry)
        # digit_sum pads to a common length itself. It needs the real base;
        # the previous version silently summed in base 2.
        total = digit_sum(total, partial, base=base)
    return total


def test_naive_multiply():
    '''Check naive_multiply on one pair of small integers against Python's *.'''
    x, y = 22, 45
    product_digits = naive_multiply(convert_to_base(x), convert_to_base(y))
    assert convert_from_base(product_digits) == x * y
test_naive_multiply()

<h3>Analysis of Naive Multiplication</h3>

Naive multiplication requires $O(n^2)$ time in the worst case. One way to see this is that the algorithm uses a pair of nested loops. The lengths of the loops are the number of digits of $a$ and the number of digits of $b$, respectively. In the worst case, there are $\frac{n}{2}$ digits of both $a$ and $b$, so the multiplication requires $\frac{n^2}{4}\in O(n^2)$ operations.

Actually, when the base is $2$, we can form all of the partial products in linear time, because each inner loop can be done in constant time. Multiplying $a$ by a digit $1$ or $0$ of $b$ requires no arithmetic, since the result is either $a$ or $0$.

After performing the multiplications, we have to add the numbers in results. The number of results to add is the number of digits of $b$. Each addition requires linear time. Therefore, the addition step requires $O(n^2)$ time.

<h3>Karatsuba-Ofman Fast Multiplication</h3>

There is a divide-and-conquer approach to multiplication that can improve performance when the numbers to multiply are large enough. This is the Karatsuba-Ofman algorithm. 

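Digit lists still represent the same numbers when we append additional zeros to their ends (the most significant positions). Therefore, we may assume that $a$ and $b$ have the same number of digits and that this number is a power of $2$. At most, this will require doubling the number of digits of $a$ and $b$.

In this section we let $n$ denote this common, padded number of digits of each of $a$ and $b$, rather than the total input size. This changes the size by at most a constant factor, so it does not affect the asymptotic analysis. In particular, $\frac{n}{2}$ is an integer and is a power of $2$.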

Write $a=(base)^{\frac{n}{2}}a_R + a_L$. Here $a_L$ consists of the left half of the digit list of $a$; these are the less significant digits, since digits are stored least significant first. $a_R$ consists of the right half, the more significant digits. This splitting can be performed in constant time, though it takes linear time to implement it using Python's list slicing.

Similarly, we write $b=(base)^{\frac{n}{2}}b_R + b_L$.

Then

\begin{align*}
a*b &= (base)^n(a_R b_R) + (base)^{\frac{n}{2}}(a_Rb_L + a_Lb_R) + a_Lb_L\\
&=(base)^n(a_R b_R) + (base)^{\frac{n}{2}}(a_R b_R + a_Lb_L -(a_R-a_L)(b_R-b_L)) + a_Lb_L.
\end{align*}

The first expression seems to need four multiplications, while the second expression reveals that we only need three: $U=a_R b_R$, $V=a_Lb_L$, and $W=(a_R-a_L)(b_R-b_L)$. We need more additions, but additions are fast, so this doesn't impact the runtime much when the numbers have many digits.

In [350]:
def karatsuba_algorithm(a,b,base=2,power_of_2=None):
    '''Multiply two low_level_int_with_addition values with Karatsuba-Ofman.

    Parameters:
        a, b ........ low_level_int_with_addition operands.
        base ........ forwarded to the recursive calls. The digits themselves
                      carry their base in a.base / b.base.
        power_of_2 .. optional padded digit length. If omitted, or smaller than
                      needed, the smallest power of two >= the longer operand is used.

    Returns a low_level_int_with_addition equal to a*b.

    Each call splits both operands at `middle` digits and makes three recursive
    multiplications:
        U = a_R*b_R, V = a_L*b_L, W = (a_R - a_L)*(b_R - b_L),
    where ab = base^(2*middle)*U + base^middle*(U + V - W) + V.

    Fixes over the previous version:
      * It used math.log (natural log) instead of log2. The "power of two" was then
        smaller than the operand length, so the split was badly unbalanced
        and the recursion blew up on long inputs.
      * It zero-padded a.digits / b.digits in place, mutating the caller's objects.
    '''
    assert isinstance(a, low_level_int_with_addition)
    assert isinstance(b, low_level_int_with_addition)
    # Work on copies: padding must not leak into the caller's objects.
    a_digits = list(a.digits)
    b_digits = list(b.digits)
    longest = max(len(a_digits), len(b_digits))
    if longest == 0:
        return low_level_int_with_addition(0)
    # Smallest power of two >= longest (1, 2, 4, 8, ...).
    needed = 1 << (longest - 1).bit_length()
    if power_of_2 is None or power_of_2 < needed:
        power_of_2 = needed
    # Pad with zeros so that both operands have exactly power_of_2 digits.
    a_digits += [0] * (power_of_2 - len(a_digits))
    b_digits += [0] * (power_of_2 - len(b_digits))

    if power_of_2 == 1:  # base case: single-digit operands
        return low_level_int_with_addition(a_digits[0]*b_digits[0]*a.sign*b.sign)

    middle = power_of_2 // 2
    # Each half keeps the operand's sign, so a = base^middle * a_R + a_L holds with signs.
    # Slicing costs linear time here; conceptually the split is free.
    a_L = low_level_int_with_addition([a_digits[:middle], a.sign, a.base])
    a_R = low_level_int_with_addition([a_digits[middle:], a.sign, a.base])
    b_L = low_level_int_with_addition([b_digits[:middle], b.sign, b.base])
    b_R = low_level_int_with_addition([b_digits[middle:], b.sign, b.base])
    # Three independent recursive calls (these could run in parallel).
    U = karatsuba_algorithm(a_R, b_R, base=base)
    V = karatsuba_algorithm(a_L, b_L, base=base)
    W = karatsuba_algorithm(a_R - a_L, b_R - b_L, base=base)
    # U + V - W == a_R*b_L + a_L*b_R, the middle coefficient.
    return U.shift(2 * middle) + (U + V - W).shift(middle) + V
def test_karatsuba():
    '''Compare karatsuba_algorithm with Python's * on small integers, including negatives and zero.'''
    a,b=low_level_int_with_addition(22), low_level_int_with_addition(13)
    c= karatsuba_algorithm(a,b)
    assert 22*13 == int(c)
    # The original test covered a single positive pair; also exercise signs and zero.
    for x in (-37, -1, 0, 1, 22, 255):
        for y in (-13, 0, 5, 64):
            product = karatsuba_algorithm(low_level_int_with_addition(x), low_level_int_with_addition(y))
            assert int(product) == x * y, (x, y)
test_karatsuba()

<h3>Analysis of Karatsuba's Algorithm</h3>

Karatsuba's Algorithm follows the recursive divide-and-conquer paradigm. We can use the recursive structure to obtain an equation on the runtime. Let $K(n)$ denote the runtime on an input of length $n$.

The recursive structure of the algorithm allows us to express $K(n)$ in terms of smaller values. The top-level call to Karatsuba's algorithm involves three multiplications and some additions. Each multiplication involves numbers with only half as many digits as the original problem. This shows that

$K(n) = 3K(\frac{n}{2}) + O(n)$, where $O(n)$ accounts for the additions.

We will see on Monday that we can solve this recurrence to obtain the runtime:

$K(n) \in O(n^{\log_2(3)})\approx O(n^{1.58})$

<h3>Demystifying Karatsuba's Algorithm</h3>

The naive algorithm for multiplication seems very natural, so it is surprising that we can do better. On the other hand, Karatsuba's algorithm seems very strange. How could someone come up with it?

Knuth claims that the idea is very natural because it is a special case of a more general idea. Rather than split the numbers into $2$ roughly equal parts, we can split them into $r+1$ roughly equal parts. When $r=1$, we recover Karatsuba's algorithm.

We can think of $a=(base)^\frac{n}{2}a_R + a_L$ and $b=(base)^\frac{n}{2}b_R+b_L$ as the polynomials $a(x)=a_R x +a_L$ and $b(x)=b_Rx+b_L$ evaluated at $x=(base)^\frac{n}{2}$. To evaluate $ab$, we perform the harder task of computing the polynomial $a(x)b(x)$. Since $a(x)b(x)$ is a polynomial of degree $2$, it is determined by its values at three points,

- $a(0)b(0) = a_Lb_L$
- $a(1)b(1) = (a_R +a_L)(b_R +b_L)$
- $a(2)b(2) = (2a_R+a_L)(2b_R+b_L)$

So if $a(x)b(x)=c(x)=c_2x^2 +c_1x +c_0$, then the coefficients $c_2,c_1,c_0$ must satisfy

$\begin{pmatrix}
0 & 0 & 1\\
1 & 1 & 1\\
4 & 2 & 1
\end{pmatrix}\begin{pmatrix}
c_2\\
c_1\\
c_0
\end{pmatrix}=\begin{pmatrix}
a_L b_L\\
(a_R+a_L)(b_R+b_L)\\
(2a_R+a_L)(2b_R+b_L)
\end{pmatrix}$

We calculate that $\begin{pmatrix}
0 & 0 & 1\\
1 & 1 & 1\\
4 & 2 & 1
\end{pmatrix}^{-1}=\begin{pmatrix}
0.5 & -1 & 0.5\\
-1.5& 2 & -0.5\\
1 & 0 & 0
\end{pmatrix}$ with python below.

In [351]:
# Numerically invert the evaluation matrix for the points x = 0, 1, 2
# (rows are [x^2, x, 1]); its inverse gives the interpolation weights.
# https://www.geeksforgeeks.org/how-to-inverse-a-matrix-using-numpy/
A = np.array([
    [0, 0, 1],
    [1, 1, 1],
    [4, 2, 1],
])
print(np.linalg.inv(A))

[[ 0.5 -1.   0.5]
 [-1.5  2.  -0.5]
 [ 1.   0.   0. ]]



$\begin{pmatrix}
c_2\\
c_1\\
c_0
\end{pmatrix}=\begin{pmatrix}
0.5 & -1 & 0.5\\
-1.5& 2 & -0.5\\
1 & 0 & 0
\end{pmatrix}\begin{pmatrix}
a_L b_L\\
(a_R+a_L)(b_R+b_L)\\
(2a_R+a_L)(2b_R+b_L)
\end{pmatrix}$.

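This shows that we can calculate the coefficients of the polynomial $c(x)$ as a linear combination of $3$ products of numbers that each have about half as many digits as the original ones. Multiplication by $x$ becomes a shift by $\frac{n}{2}$ digits, so we can calculate $c(base^{\frac{n}{2}})=ab$.

A different choice of evaluation points recovers Karatsuba's algorithm exactly. Use $x=0$, $x=-1$, and "$x=\infty$", where evaluating at $\infty$ means reading off the leading coefficient $c_2$. Then:

- $c(0)=a_Lb_L=V$,
- $c(-1)=(a_L-a_R)(b_L-b_R)=W$,
- $c_2=a_Rb_R=U$.

Solving $c(-1)=c_2-c_1+c_0$ for the middle coefficient gives $c_1=U+V-W$.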

Knuth's generalization of Karatsuba's method splits $a$ and $b$ into $r+1$ parts each. We assume these parts are equally sized. He views $a$ and $b$ as degree-$r$ polynomials $a(x)$ and $b(x)$ and calculates the $2r+1$ numbers $a(0)b(0),a(1)b(1),\dots,a(2r)b(2r)$. Then he calculates an appropriate linear combination of these $2r+1$ numbers to recover the coefficients of the degree-$2r$ polynomial $c(x)=a(x)b(x)$. Finally, he uses shifts to evaluate $c(base^{\frac{n}{r+1}})=ab$.

The linear combination of the $2r+1$ numbers comes from inverting a matrix. Inverting the matrix is slow, but the matrix depends only on $r$, so it can be computed before $a$ and $b$ are known.

The generalized Karatsuba's method is a divide and conquer method.

$K_r(n) = (2r+1) K_r(\frac{n}{r+1}) + \Theta(n)$

On Monday, we will see that this recursion can be solved as

$K_r(n) = \Theta(n^{\log_{r+1}(2r+1)})$. As $r$ increases, we get faster algorithms that are nearly linear. 

$\lim_{r\to \infty}\log_{r+1}(2r+1)=\lim_{r\to\infty}\frac{\ln(2r+1)}{\ln(r+1)}=\lim_{r\to\infty} \frac{\frac{2}{2r+1}}{\frac{1}{r+1}}=\lim_{r\to\infty}\frac{2(r+1)}{2r+1}=1$.

We cannot conclude that it is possible to multiply in linear time. The argument only shows that for every $\epsilon>0$, there is a multiplication algorithm whose runtime is $O(n^{1+\epsilon})$. The issue is that the constants hidden by the big-Oh notation are sensitive to $\epsilon$.

The state-of-the-art algorithm for integer multiplication (Harvey and van der Hoeven) runs in $O(n\log(n))$ time, which is conjectured to be optimal.

<h3>Matrix Multiplication</h3>

Matrices are $2$-dimensional arrays of real numbers that encode linear transformations. Matrix multiplication encodes composition of the associated linear transformations. We will consider our input to the matrix multiplication problem to be two $n\times n$ matrices. In any time step, we can examine an entry of a matrix. We assume we can add, subtract, multiply or divide any pair of real numbers in constant time.

The naive method for matrix multiplication uses $n$ multiplications and $n-1$ additions for every pair consisting of a row of the first matrix and a column of the second matrix. There are $n^2$ such pairs, so this is $(2n-1)n^2$ operations. Thus, the naive algorithm for matrix multiplication is $O(n^3)$.

In [359]:
#courtesy of ChatGPT3.5
def naive_matrix_multiply(A, B):
    '''Multiply an m x k matrix A by a k x p matrix B with the schoolbook triple loop.

    A and B may be nested lists or 2-D NumPy arrays. Returns a new m x p nested list.
    For n x n inputs this performs n**3 scalar multiplications, i.e. O(n^3) work.
    Square inputs behave exactly as before. Rectangular inputs are now supported.

    Raises ValueError if a row of A does not have len(B) entries. Previously,
    mismatched sizes could silently produce a wrong answer.
    '''
    rows = len(A)
    inner = len(B)
    if rows == 0:
        return []
    if any(len(row) != inner for row in A):
        raise ValueError("inner dimensions do not match: each row of A needs len(B) entries")
    cols = len(B[0]) if inner else 0

    # Initialize the result matrix C with zeros.
    C = [[0 for _ in range(cols)] for _ in range(rows)]

    # C[i][j] is the dot product of row i of A with column j of B.
    for i in range(rows):
        for j in range(cols):
            for k in range(inner):
                C[i][j] += A[i][k] * B[k][j]

    return C

def test_naive_matrix_multiply():
    '''Compare naive_matrix_multiply with NumPy's @ on a 3 x 3 example.

    The previous version called the undefined name `matrix_multiply` and
    therefore raised NameError.
    '''
    A = [[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]

    B = [[9, 8, 7],
        [6, 5, 4],
        [3, 2, 1]]
    # Test this matrix multiply against numpy.
    assert np.array_equal(naive_matrix_multiply(A, B), np.array(A) @ np.array(B))
test_naive_matrix_multiply()

<h3>Strassen's Algorithm</h3>

Strassen's algorithm is similar to the Karatsuba algorithm, but for matrices. Recall that matrix multiplication can be performed in blocks:

If
$A = \left(\begin{array}{c|c}
A_{11} & A_{12} \\
\hline
A_{21} & A_{22}
\end{array}\right)$
and
$B = \left(\begin{array}{c|c}
B_{11} & B_{12} \\
\hline
B_{21} & B_{22}
\end{array}\right)$

then $AB= \left(\begin{array}{c|c}
A_{11}B_{11} + A_{12}B_{21} & A_{11}B_{12} + A_{12}B_{22} \\
\hline
A_{21}B_{11} + A_{22}B_{21} & A_{21}B_{12} + A_{22}B_{22}
\end{array}\right)
$.


Strassen's algorithm uses the clever identity that
$AB= \begin{pmatrix}
        M_1+M_4-M_5+M_7 &M_3+M_5\\
        M_2+M_4 & M_1 -M_2 +M_3 +M_6
    \end{pmatrix}$
where
- $M_1=(A_{11}+A_{22})(B_{11}+B_{22})$
- $M_2=(A_{21}+A_{22})B_{11}$
- $M_3=A_{11}(B_{12}-B_{22})$
- $M_4 =A_{22}(B_{21}-B_{11})$
- $M_5=(A_{11}+A_{12})B_{22}$
- $M_6 = (A_{21}-A_{11})(B_{11}+B_{12})$
- $M_7=(A_{12}-A_{22})(B_{21}+B_{22})$

In [362]:
#Courtesy of ChatGPT3.5
def add_matrix(A, B):
    '''Return the entrywise sum of A and B as a NumPy array (np.add broadcasting rules apply).'''
    summed = np.add(A, B)
    return summed

def subtract_matrix(A, B):
    '''Return the entrywise difference A - B as a NumPy array (np.subtract broadcasting rules apply).'''
    difference = np.subtract(A, B)
    return difference

def _strassen_power_of_two(A, B):
    '''Strassen recursion on two n x n NumPy arrays where n is a power of two.'''
    n = A.shape[0]
    # Base case: 1x1 matrix.
    if n == 1:
        return A * B

    # Split both matrices into quarters.
    mid = n // 2
    A11, A12 = A[:mid, :mid], A[:mid, mid:]
    A21, A22 = A[mid:, :mid], A[mid:, mid:]
    B11, B12 = B[:mid, :mid], B[:mid, mid:]
    B21, B22 = B[mid:, :mid], B[mid:, mid:]

    # Strassen's seven half-size products (instead of the naive eight).
    M1 = _strassen_power_of_two(A11 + A22, B11 + B22)
    M2 = _strassen_power_of_two(A21 + A22, B11)
    M3 = _strassen_power_of_two(A11, B12 - B22)
    M4 = _strassen_power_of_two(A22, B21 - B11)
    M5 = _strassen_power_of_two(A11 + A12, B22)
    M6 = _strassen_power_of_two(A21 - A11, B11 + B12)
    M7 = _strassen_power_of_two(A12 - A22, B21 + B22)

    # Reassemble the quadrants. np.block keeps the input dtype, so integer
    # inputs give an exact integer result instead of a float array.
    return np.block([
        [M1 + M4 - M5 + M7, M3 + M5],
        [M2 + M4, M1 - M2 + M3 + M6],
    ])


def strassen(A, B):
    '''Multiply two n x n matrices with Strassen's algorithm.

    A and B may be NumPy arrays or nested lists. Returns an n x n NumPy array
    equal to A @ B. The runtime satisfies T(n) = 7 T(n/2) + O(n^2), i.e.
    O(n^{log2 7}).

    Sizes that are not a power of two are zero-padded up to the next power of
    two and the result is cropped back. The previous version split such
    matrices into unequal quadrants and, through NumPy broadcasting, silently
    returned wrong answers. An empty 0 x 0 input returns an empty result
    instead of recursing forever.

    Raises ValueError unless A and B are square 2-D matrices of the same size.
    '''
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or A.shape[0] != A.shape[1] or A.shape != B.shape:
        raise ValueError("strassen expects two square matrices of the same size")
    n = A.shape[0]
    if n == 0:
        return np.zeros((0, 0), dtype=np.result_type(A, B))
    size = 1 << (n - 1).bit_length()  # smallest power of two >= n
    if size != n:
        padding = ((0, size - n), (0, size - n))
        A = np.pad(A, padding)
        B = np.pad(B, padding)
    return _strassen_power_of_two(A, B)[:n, :n]

# Example usage
def test_strassen():
    '''Compare strassen with NumPy's @ on power-of-two examples.

    The original cell defined this test but never called it.
    '''
    A = np.array([[1, 2, 3, 4],
                [5, 6, 7, 8],
                [9, 10, 11, 12],
                [13, 14, 15, 16]])

    B = np.array([[16, 15, 14, 13],
                [12, 11, 10, 9],
                [8, 7, 6, 5],
                [4, 3, 2, 1]])

    assert np.array_equal(strassen(A, B), A @ B)
    # A second, deeper recursion: an 8 x 8 case.
    P = np.arange(64).reshape(8, 8)
    assert np.array_equal(strassen(P, P.T), P @ P.T)
test_strassen()


<h3>Analysis of Strassen's Algorithm</h3>

Like Karatsuba's algorithm, Strassen's algorithm uses the divide-and-conquer technique. This allows us to derive a recursive equation for the runtime of the algorithm. Each step of the algorithm calls itself $7$ times on instances of half the original size. Each step also performs some matrix additions, each of which runs in $O(n^2)$ time.

Thus, letting $Strassen(n)$ be the runtime of Strassen's algorithm on an input of size $n$, we find

$Strassen(n)= 7Strassen(\frac{n}{2})+O(n^2)$, where the $O(n^2)$ term accounts for the matrix additions.

We will see that we can solve this recursive equation to conclude that

$Strassen(n)\in O(n^{\log_2(7)})\approx O(n^{2.81})$.

The state-of-the-art matrix multiplication algorithms run in $O(n^{2.37})$ time. Finding the optimal algorithm is still an open question.

<h3>Self-Testing/Self Correcting</h3>

Integer and matrix multiplication both have the nice property of linearity. Whether $x$ and $y$ are matrices or integers, whenever we have $x=x_1+x_2$ and $y=y_1+y_2$, we have $xy = x_1y_1 +x_1y_2 +x_2y_1 +x_2y_2$.

This suggests that we can check a multiplication function by using the function itself to evaluate both the left- and right-hand sides of the equation above. This is the concept of self-testing.

In [366]:
def usual_multiplication(x,y):
    '''Reference multiplication: Python's built-in * operator.'''
    product = x * y
    return product

def wrong_multiplication(x,y):
    '''Deliberately buggy multiplication: x is first rounded down to an even number.'''
    even_x = (x // 2) * 2
    return even_x * y
def self_test_multiplication(multiplication_function, num_checks=20, bound=1000, rng=random):
    '''Self-test a multiplication function using bilinearity.

    For random x, y and random splits x = x_1 + x_2 and y = y_1 + y_2, check that
        f(x, y) == f(x_1, y_1) + f(x_1, y_2) + f(x_2, y_1) + f(x_2, y_2).

    Parameters:
        multiplication_function ... the function f(x, y) under test.
        num_checks ................ number of random trials (default 20, as before).
        bound ..................... x, y, x_1, y_1 are drawn uniformly from
                                    [-bound, bound] (default 1000, as before).
        rng ....................... anything with a randint method. It defaults to
                                    the `random` module; pass random.Random(seed)
                                    for reproducible runs.

    Raises AssertionError, naming the failing inputs, on the first failed check.
    '''
    for _ in range(num_checks):
        # Draw in the same order as before: x, y, x_1, y_1.
        x = rng.randint(-bound, bound)
        y = rng.randint(-bound, bound)

        x_1 = rng.randint(-bound, bound)
        x_2 = x - x_1
        y_1 = rng.randint(-bound, bound)
        y_2 = y - y_1

        answer = multiplication_function(x, y)
        checked_answer = sum(
            multiplication_function(xi, yi)
            for xi in (x_1, x_2)
            for yi in (y_1, y_2)
        )

        assert answer == checked_answer, (
            f"self-test failed for x={x}, y={y} "
            f"(x_1={x_1}, y_1={y_1}): {answer} != {checked_answer}"
        )
# Run the bilinearity self-test on Python's built-in multiplication; it passes.
self_test_multiplication(usual_multiplication)
#self_test_multiplication(wrong_multiplication) This fails, as it should.

A detailed statistical analysis reveals that if the self testing passes with high probability, then the function is correct on most inputs.

The example of integer multiplication is trivial, because we can easily check our answers against the standard multiplication algorithm. Self-testing really becomes helpful for floating point multiplication, where the answers are only approximately correct due to rounding. The same idea works for matrix multiplication. It can be very tricky to ascertain that your matrix multiplication algorithm works up to the desired error. The BLR test provides a simple way to test this probabilistically.