In [1]:
%%writefile ecc.py

from unittest import TestCase


class FieldElement:

    def __init__(self, num, prime):
        if num >= prime or num < 0:
            error = 'Num {} not in field range 0 to {}'.format(
                num, prime - 1)
            raise ValueError(error)
        self.num = num
        self.prime = prime

    def __repr__(self):
        return 'FieldElement_{}({})'.format(self.prime, self.num)

    def __eq__(self, other):
        if other is None:
            return False
        return self.num == other.num and self.prime == other.prime

    def __ne__(self, other):
        # this should be the inverse of the == operator
        return not (self == other)

    def __add__(self, other):
        if self.prime != other.prime:
            raise TypeError('Cannot add two numbers in different Fields')
        # self.num and other.num are the actual values
        # self.prime is what we need to mod against
        num = (self.num + other.num) % self.prime
        # We return an element of the same class
        return self.__class__(num, self.prime)

    def __sub__(self, other):
        if self.prime != other.prime:
            raise TypeError('Cannot subtract two numbers in different Fields')
        # self.num and other.num are the actual values
        # self.prime is what we need to mod against
        num = (self.num - other.num) % self.prime
        # We return an element of the same class
        return self.__class__(num, self.prime)

    def __mul__(self, other):
        if self.prime != other.prime:
            raise TypeError('Cannot multiply two numbers in different Fields')
        # self.num and other.num are the actual values
        # self.prime is what we need to mod against
        num = (self.num * other.num) % self.prime
        # We return an element of the same class
        return self.__class__(num, self.prime)

    def __pow__(self, exponent):
        n = exponent % (self.prime - 1)
        num = pow(self.num, n, self.prime)
        return self.__class__(num, self.prime)

    def __truediv__(self, other):
        if self.prime != other.prime:
            raise TypeError('Cannot divide two numbers in different Fields')
        # self.num and other.num are the actual values
        # self.prime is what we need to mod against
        # use fermat's little theorem:
        # self.num**(p-1) % p == 1
        # this means:
        # 1/n == pow(n, p-2, p)
        num = (self.num * pow(other.num, self.prime - 2, self.prime)) % self.prime
        # We return an element of the same class
        return self.__class__(num, self.prime)


class FieldElementTest(TestCase):

    def test_ne(self):
        a = FieldElement(2, 31)
        b = FieldElement(2, 31)
        c = FieldElement(15, 31)
        self.assertEqual(a, b)
        self.assertTrue(a != c)
        self.assertFalse(a != b)

    def test_add(self):
        a = FieldElement(2, 31)
        b = FieldElement(15, 31)
        self.assertEqual(a + b, FieldElement(17, 31))
        a = FieldElement(17, 31)
        b = FieldElement(21, 31)
        self.assertEqual(a + b, FieldElement(7, 31))

    def test_sub(self):
        a = FieldElement(29, 31)
        b = FieldElement(4, 31)
        self.assertEqual(a - b, FieldElement(25, 31))
        a = FieldElement(15, 31)
        b = FieldElement(30, 31)
        self.assertEqual(a - b, FieldElement(16, 31))

    def test_mul(self):
        a = FieldElement(24, 31)
        b = FieldElement(19, 31)
        self.assertEqual(a * b, FieldElement(22, 31))

    def test_pow(self):
        a = FieldElement(17, 31)
        self.assertEqual(a**3, FieldElement(15, 31))
        a = FieldElement(5, 31)
        b = FieldElement(18, 31)
        self.assertEqual(a**5 * b, FieldElement(16, 31))

    def test_div(self):
        a = FieldElement(3, 31)
        b = FieldElement(24, 31)
        self.assertEqual(a / b, FieldElement(4, 31))
        a = FieldElement(17, 31)
        self.assertEqual(a**-3, FieldElement(29, 31))
        a = FieldElement(4, 31)
        b = FieldElement(11, 31)
        self.assertEqual(a**-4 * b, FieldElement(13, 31))


# tag::source1[]
class Point:

    def __init__(self, x, y, a, b):
        self.a = a
        self.b = b
        self.x = x
        self.y = y
        # end::source1[]
        # tag::source2[]
        if self.x is None and self.y is None:  # <1>
            return
        # end::source2[]
        # tag::source1[]
        if self.y**2 != self.x**3 + a * x + b:  # <1>
            raise ValueError('({}, {}) is not on the curve'.format(x, y))

    def __eq__(self, other):  # <2>
        return self.x == other.x and self.y == other.y \
            and self.a == other.a and self.b == other.b
    # end::source1[]

    def __ne__(self, other):
        # this should be the inverse of the == operator
        raise NotImplementedError

    def __repr__(self):
        if self.x is None:
            return 'Point(infinity)'
        else:
            return 'Point({},{})_{}_{}'.format(self.x, self.y, self.a, self.b)

    # tag::source3[]
    def __add__(self, other):  # <2>
        if self.a != other.a or self.b != other.b:
            raise TypeError('Points {}, {} are not on the same curve'.format
            (self, other))

        if self.x is None:  # <3>
            return other
        if other.x is None:  # <4>
            return self
        # end::source3[]

        # Case 1: self.x == other.x, self.y != other.y
        # Result is point at infinity

        # Case 2: self.x != other.x
        # Formula (x3,y3)==(x1,y1)+(x2,y2)
        # s=(y2-y1)/(x2-x1)
        # x3=s**2-x1-x2
        # y3=s*(x1-x3)-y1

        # Case 3: self == other
        # Formula (x3,y3)=(x1,y1)+(x1,y1)
        # s=(3*x1**2+a)/(2*y1)
        # x3=s**2-2*x1
        # y3=s*(x1-x3)-y1

        raise NotImplementedError


class PointTest(TestCase):

    def test_ne(self):
        a = Point(x=3, y=-7, a=5, b=7)
        b = Point(x=18, y=77, a=5, b=7)
        self.assertTrue(a != b)
        self.assertFalse(a != a)

    def test_add0(self):
        a = Point(x=None, y=None, a=5, b=7)
        b = Point(x=2, y=5, a=5, b=7)
        c = Point(x=2, y=-5, a=5, b=7)
        self.assertEqual(a + b, b)
        self.assertEqual(b + a, b)
        self.assertEqual(b + c, a)

    def test_add1(self):
        a = Point(x=3, y=7, a=5, b=7)
        b = Point(x=-1, y=-1, a=5, b=7)
        self.assertEqual(a + b, Point(x=2, y=-5, a=5, b=7))

    def test_add2(self):
        a = Point(x=-1, y=-1, a=5, b=7)
        self.assertEqual(a + a, Point(x=18, y=77, a=5, b=7))

Overwriting ecc.py


In [2]:
############## PLEASE RUN THIS CELL FIRST! ###################

# import everything and define a test runner function

import sys
if sys.version_info.major == 3:
    from importlib import reload
# from helper import run
import ecc
# import helper

from ecc import Point

# From run.py
from unittest import TestSuite, TextTestRunner


def run(test):
    suite = TestSuite()
    suite.addTest(test)
    TextTestRunner().run(suite)

In [3]:
from ecc import Point
p1 = Point(-1, -1, 5, 7)
# p2 = Point(-1, -2, 5, 7)

### Exercise 1

Determine which of these points are on the curve \\(y^{2}\\)=\\(x^{3}\\)+5x+7:

(2,4), (-1,-1), (18,77), (5,7)

In [4]:
# Exercise 1

# (2,4), (-1,-1), (18,77), (5,7)
# equation in python is: y**2 == x**3 + 5*x + 7

### Exercise 2

Write the `__ne__` method for `Point`.

#### Make [this test](/edit/code-ch02/ecc.py) pass: `ecc.py:PointTest:test_ne`

In [5]:
# Exercise 2

reload(ecc)
run(ecc.PointTest("test_ne"))

E
ERROR: test_ne (ecc.PointTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/ubuntu/alien/programmingbitcoin/code-ch02/ecc.py", line 189, in test_ne
    self.assertTrue(a != b)
  File "/home/ubuntu/alien/programmingbitcoin/code-ch02/ecc.py", line 146, in __ne__
    raise NotImplementedError
NotImplementedError

----------------------------------------------------------------------
Ran 1 test in 0.001s

FAILED (errors=1)


In [6]:
from ecc import Point
p1 = Point(-1, -1, 5, 7)
p2 = Point(-1, 1, 5, 7)
inf = Point(None, None, 5, 7)
print(p1 + inf)
print(inf + p2)
print(p1 + p2)

Point(-1,-1)_5_7
Point(-1,1)_5_7


NotImplementedError: 

### Exercise 3

Handle the case where the two points are additive inverses. That is, they have the same `x`, but a different `y`, causing a vertical line. This should return the point at infinity.

#### Make [this test](/edit/code-ch02/ecc.py) pass: `ecc.py:PointTest:test_add0`

In [None]:
# Exercise 3

reload(ecc)
run(ecc.PointTest("test_add0"))

### Exercise 4

For the curve \\(y^{2}\\)=\\(x^{3}\\)+5x+7, what is (2,5) + (-1,-1)?

In [None]:
# Exercise 4

from ecc import Point

a = 5
b = 7
x1, y1 = 2, 5
x2, y2 = -1, -1

# (x1,y1) + (x2,y2)

### Exercise 5

Write the `__add__` method where \\(x_{1}\\)≠\\(x_{2}\\)

#### Make [this test](/edit/code-ch02/ecc.py) pass: `ecc.py:PointTest:test_add1`

In [None]:
# Exercise 5

reload(ecc)
run(ecc.PointTest("test_add1"))

### Exercise 6

For the curve \\(y^{2}\\)=\\(x^{3}\\)+5x+7, what is (-1,-1) + (-1,-1)?

In [None]:
# Exercise 6

from ecc import Point

a = 5
b = 7
x1, y1 = -1, -1
# (-1,-1) + (-1,-1)

### Exercise 7

Write the `__add__` method when \\(P_{1}\\)=\\(P_{2}\\).

#### Make [this test](/edit/code-ch02/ecc.py) pass: `ecc.py:PointTest:test_add2`

In [None]:
# Exercise 7

reload(ecc)
run(ecc.PointTest("test_add2"))