
This is a worksheet prepared by Kai Wang and Dominic Fluet for Lecture 13 of the reading course "Introduction to Quantum Computer Programming" (AMATH 900/ AMATH 495/ QIC 895) at the University of Waterloo.

Course Webpage: https://sites.google.com/view/quantum-computer-programming

Text followed in the course: [Quantum Computing, An Applied Approach](https://www.springer.com/gp/book/9783030239213) by Jack D. Hidary (2019)


# Shor's algorithm
Shor's algorithm takes an integer $N$ as input and outputs its factors by finding the period $r$ of the function $f(x) = a^x \bmod N$ for a randomly chosen $a$. The general procedure is:



1.   Choose a random integer $a$ with $1 < a < N$.
2.   Compute $\gcd(a, N)$ using Euclid's algorithm.
3.   If $\gcd(a, N) \neq 1$, then $a$ shares a nontrivial factor with $N$.
4.   **Otherwise, find the period of $f(x) = a^x \bmod N$ and call it $r$.**
5.   If $r$ is odd, or $a^{r/2} \equiv -1 \pmod N$, choose a new $a$ and start over.
6.   Otherwise, $\gcd(a^{r/2}+1, N)$ and $\gcd(a^{r/2}-1, N)$ are both nontrivial factors of $N$.

The quantum Fourier transform lets us find the period efficiently, so we use a quantum computer for step 4. All other steps can be done efficiently on a classical computer.
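As a sanity check on steps 1 to 6, here is a minimal, purely classical sketch. It replaces the quantum step 4 with brute-force order finding, which takes time exponential in the number of bits of $N$. The helper names `brute_force_order` and `classical_shor` are hypothetical and are not part of the worksheet. The sketch assumes $N$ is odd and not a prime power, which is the case the procedure is designed for.

```python
import math
import random


def brute_force_order(a, N):
    """Smallest r > 0 with a**r % N == 1 (exponential time; stands in for step 4).

    Assumes gcd(a, N) == 1, otherwise the loop never terminates.
    """
    value, r = a % N, 1
    while value != 1:
        value = (value * a) % N
        r += 1
    return r


def classical_shor(N, max_tries=50):
    """Steps 1-6 of the procedure, with step 4 done classically by brute force."""
    for _ in range(max_tries):
        a = random.randrange(2, N)            # step 1: random 1 < a < N
        d = math.gcd(a, N)                    # step 2
        if d != 1:
            return d, N // d                  # step 3: a shares a factor with N
        r = brute_force_order(a, N)           # step 4 (the quantum step in Shor's algorithm)
        if r % 2 == 1 or pow(a, r // 2, N) == N - 1:
            continue                          # step 5: choose a new a
        b = pow(a, r // 2, N)
        return math.gcd(b + 1, N), math.gcd(b - 1, N)  # step 6
    return None


print(classical_shor(15))  # a pair of factors of 15; the order depends on the random a
```

Because $r$ is the order of $a$, $a^{r/2} \not\equiv 1$, and step 5 rules out $-1$. As a result, $a^{r/2}$ is a nontrivial square root of 1 modulo $N$, which is why both gcds in step 6 are proper factors.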




# Implementation

# Setup

In [None]:
# Install Cirq (not required by the simulation below, which uses only the standard library)
!pip install cirq 
import math
import random


def printInfo(message):
  print(message)




# Quantum Components
This time we will not manipulate individual qubits. Recall that a quantum computer with $n$ qubits has $2^n$ basis states, so its state is described by $2^n$ complex amplitudes, one per basis state. We simulate a register by storing these amplitudes directly.
This is a diagram of the classes. 

![alt text](https://drive.google.com/uc?id=1gqJu_ruHUODjkWpDmGSVfECDLMPN1ks4)
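The following stand-alone illustration is not the worksheet's classes. It shows the representation on its own: a register of $n$ qubits stored as a list of $2^n$ complex amplitudes, here prepared in the uniform superposition that Hadamard gates produce from $|0\dots0\rangle$. The function names `uniform_superposition` and `probabilities` are hypothetical.

```python
import math


def uniform_superposition(num_qubits):
    """Amplitudes of H^n |0...0>: one complex amplitude per basis state."""
    num_states = 1 << num_qubits             # 2**n basis states, as in QubitRegister
    amp = complex(1 / math.sqrt(num_states))
    return [amp] * num_states


def probabilities(amplitudes):
    """Born rule: basis state x is observed with probability |amplitude_x|**2."""
    return [(a * a.conjugate()).real for a in amplitudes]


state = uniform_superposition(3)
print(len(state))                             # 8
print(round(sum(probabilities(state)), 12))   # 1.0
```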



In [None]:
class Mapping:
  def __init__(self, state, amplitude):
    self.state = state
    self.amplitude = amplitude

class QuantumState:
  def __init__(self, amplitude, register):
    self.amplitude = amplitude
    self.register = register
    self.entangled = {}  # {} is a Python dictionary, a data structure that stores
                         # key-value pairs: here, register -> list of Mapping entries

  # Add entanglement information, stored in Mapping pair.
  def entangle(self, fromState, amplitude):
    register = fromState.register
    entanglement = Mapping(fromState, amplitude)
    try:
      self.entangled[register].append(entanglement)
    except KeyError:
      # if entangled[register] is not initialized, KeyError will be thrown
      self.entangled[register] = [entanglement]
  
  # Return the number of entanglement entries for one register, or in total if register is None.
  def entangles(self, register = None):
    entangles = 0
    if register is None:
      for states in self.entangled.values():
        entangles += len(states)
    else:
      entangles = len(self.entangled[register])

    return entangles

class QubitRegister:
  def __init__(self, numBits):
    self.numBits = numBits
    self.numStates = 1 << numBits #shift 1 to left by numBits, same as 2^numBits
                                  #QC with N qubits has 2^N  basis states
    self.entangled = [] # contains a list of registers that are entangled with current register
    self.states = [QuantumState(complex(0.0), self) for x in range(self.numStates)]
    self.states[0].amplitude = complex(1.0)

  # Update the amplitudes of states according to the stored entanglement information.
  def propagate(self, fromRegister = None):
    if fromRegister is not None:
      for state in self.states:
        amplitude = complex(0.0)

        try:
          entangles = state.entangled[fromRegister]
          for entangle in entangles:
            amplitude += entangle.state.amplitude * entangle.amplitude
        except KeyError:
          pass  # not entangled with fromRegister: the amplitude stays 0

        state.amplitude = amplitude
    for register in self.entangled:
      if register is fromRegister:
        continue

      register.propagate(self)

  # map applies a unitary operator (transform) given by `mapping` and entangles this
  # register with toRegister. mapping(x) returns a list of Mapping(y, v) entries;
  # each amplitude v should satisfy v * v.conjugate() = 1 before map normalizes them.
  def map(self, toRegister, mapping, propagate = True):
    self.entangled.append(toRegister)
    toRegister.entangled.append(self)
    # Create the covariant/contravariant representations
    mapTensorX = {}
    mapTensorY = {}
    for x in range(self.numStates):
      mapTensorX[x] = {}
      codomain = mapping(x)
      for element in codomain:
        y = element.state
        mapTensorX[x][y] = element

        try:
          mapTensorY[y][x] = element
        except KeyError:
          mapTensorY[y] = { x: element }

    # Normalize the mapping:
    def normalize(tensor, p = False):
      lSqrt = math.sqrt
      for vectors in tensor.values():
        sumProb = 0.0
        for element in vectors.values():
          amplitude = element.amplitude
          sumProb += (amplitude * amplitude.conjugate()).real

        normalized = lSqrt(sumProb)
        for element in vectors.values():
          element.amplitude = element.amplitude / normalized

    normalize(mapTensorX)
    normalize(mapTensorY, True)

    # Entangle the registers
    for x, yStates in mapTensorX.items():
      for y, element in yStates.items():
        amplitude = element.amplitude
        toState = toRegister.states[y]
        fromState = self.states[x]
        toState.entangle(fromState, amplitude)
        fromState.entangle(toState, amplitude.conjugate())

    if propagate:
      toRegister.propagate(self)
    
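  # Measure the register: sample a basis state x with probability |amplitude_x|^2,
  # collapse onto it, and propagate the collapse to entangled registers.
  # Returns None if floating-point round-off prevents any state from being picked.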
  def measure(self):
    measure = random.random()
    sumProb = 0.0

    # Pick a state
    finalX = None
    finalState = None
    for x, state in enumerate(self.states):
      amplitude = state.amplitude
      sumProb += (amplitude * amplitude.conjugate()).real

      if sumProb > measure:
        finalState = state
        finalX = x
        break

    # If state was found, update the system
    if finalState is not None:
      for state in self.states:
        state.amplitude = complex(0.0)

      finalState.amplitude = complex(1.0)
      self.propagate()
    return finalX

  # Return the total number of entanglement entries over all states in this register
  def entangles(self, register = None):
    entangles = 0
    for state in self.states:
      entangles += state.entangles(None)

    return entangles

  # returns all the amplitudes of all states in this register
  def amplitudes(self):
    amplitudes = []
    for state in self.states:
      amplitudes.append(state.amplitude)

    return amplitudes



Now we have a quantum register that can:


1.   initialize an arbitrary number of qubits,
2.   entangle itself with other registers,
3.   apply unitary operations and perform measurements.
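The sampling rule inside `measure` can be isolated on a plain list of amplitudes. This is a minimal sketch, and `sample_basis_state` and `collapse` are hypothetical helpers. Unlike `measure`, it does not propagate the collapse to entangled registers. If round-off leaves the cumulative sum just below the random threshold, it also falls back to the last state with nonzero probability rather than returning `None`.

```python
import math
import random


def sample_basis_state(amplitudes, rng=random):
    """Pick index x with probability |amplitudes[x]|**2 by cumulative-sum sampling."""
    threshold = rng.random()
    cumulative = 0.0
    last_nonzero = None
    for x, amp in enumerate(amplitudes):
        p = (amp * amp.conjugate()).real
        if p > 0:
            last_nonzero = x
        cumulative += p
        if cumulative > threshold:
            return x
    return last_nonzero  # only reached through floating-point round-off


def collapse(amplitudes, x):
    """Post-measurement state: all amplitude on basis state x."""
    return [complex(1.0) if i == x else complex(0.0) for i in range(len(amplitudes))]


bell_like = [complex(1 / math.sqrt(2)), 0j, 0j, complex(1 / math.sqrt(2))]
x = sample_basis_state(bell_like)
print(x, collapse(bell_like, x))  # x is 0 or 3; the collapsed list has a single 1
```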




# Quantum operations 

In [None]:
def printEntangles(register):
  printInfo("Entangles: " + str(register.entangles()))


def printAmplitudes(register):
  amplitudes = register.amplitudes()
  for x, amplitude in enumerate(amplitudes):
    printInfo('State #' + str(x) + '\'s amplitude: ' +str(amplitude))

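# Walsh-Hadamard transform: H|x> = sum_y (-1)^popcount(x & y) |y>.
# The amplitudes are left unnormalized; map() normalizes them.
def hadamard(x, Q):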
  codomain = []
  for y in range(Q):
    amplitude = complex(pow(-1.0, bitCount(x & y) & 1))
    codomain.append(Mapping(y, amplitude))

  return codomain

# Quantum modular exponentiation: |x> -> |a^x mod N>, a single basis state with amplitude 1
def qModExp(a, exp, mod):
  state = modExp(a, exp, mod)
  amplitude = complex(1.0)
  return [Mapping(state, amplitude)]

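# Quantum Fourier transform: |x> -> sum_y exp(-2*pi*i*x*y/Q) |y>, unnormalized.
# Note the minus sign, the opposite of the usual textbook convention.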
def qft(x, Q):
  fQ = float(Q)
  k = -2.0 * math.pi
  codomain = []

  for y in range(Q):
    theta = (k * float((x * y) % Q)) / fQ
    amplitude = complex(math.cos(theta), math.sin(theta))
    codomain.append(Mapping(y, amplitude))

  return codomain
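After `map()` normalizes them, the Hadamard and QFT amplitudes above become these formulas divided by $\sqrt{Q}$. The quick stand-alone check below confirms that both normalized matrices are unitary. The helper names are hypothetical, and `bin(...).count("1")` stands in for the worksheet's `bitCount`. On the sign of the QFT exponent: replacing $e^{+2\pi i xy/Q}$ by $e^{-2\pi i xy/Q}$ maps each outcome $y$ to $Q - y \pmod Q$. The peaks near multiples $sQ/r$ are symmetric under that reflection, so the period can be recovered either way.

```python
import cmath
import math


def hadamard_matrix(Q):
    """Normalized Walsh-Hadamard on log2(Q) qubits: (-1)**popcount(x & y) / sqrt(Q)."""
    s = 1 / math.sqrt(Q)
    return [[s * (-1) ** bin(x & y).count("1") for y in range(Q)] for x in range(Q)]


def qft_matrix(Q):
    """Normalized QFT with the worksheet's sign: exp(-2*pi*i*x*y/Q) / sqrt(Q)."""
    s = 1 / math.sqrt(Q)
    return [[s * cmath.exp(-2j * math.pi * ((x * y) % Q) / Q) for y in range(Q)]
            for x in range(Q)]


def is_unitary(M, tol=1e-9):
    """Check that the rows of the square matrix M are orthonormal (M M^dagger = I)."""
    n = len(M)
    for i in range(n):
        for j in range(n):
            inner = sum(M[i][k] * M[j][k].conjugate() for k in range(n))
            if abs(inner - (1 if i == j else 0)) > tol:
                return False
    return True


print(is_unitary(hadamard_matrix(8)), is_unitary(qft_matrix(8)))  # True True
```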

# Find the period

Procedure:


1.   Initialize all input qubits to a uniform superposition using Hadamard gates.
2.   Apply the unitary transform that implements $f(x) = a^x \bmod N$, storing the result in the output register.
3.   Perform the quantum Fourier transform on the input register.
4.   With high probability, measurement gives a value $c$ such that $\frac{c}{Q} \approx \frac{s}{r}$ for some integer $s$ with $0 \leq s \leq r-1$, where $Q$ is the number of basis states of the input register.
5.   Use the continued-fraction algorithm to find the period $r$ from $\frac{c}{Q}$.
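The peaks in step 4 can be checked with a small classical simulation of the measurement statistics. This is not a quantum circuit, and the function name `period_finding_distribution` is hypothetical. The parameters match Attempt #2 in the run at the end of this notebook: $N = 15$, $a = 7$, $Q = 256$, with the output register measured as $y = 1$. Because the period of $7^x \bmod 15$ is $r = 4$, all probability should sit at $c = sQ/r$ for $s = 0, 1, 2, 3$. That includes the observed $x = 128$, which gives $128/256 = 1/2$.

```python
import cmath
import math


def period_finding_distribution(a, N, Q, y0):
    """Probability of each QFT outcome c, given that the output register measured y0.

    After that measurement the input register is a uniform superposition of all
    x in [0, Q) with a**x mod N == y0; we return |QFT amplitude|**2 for each c,
    using the worksheet's exp(-2*pi*i*x*c/Q) sign convention.
    """
    xs = [x for x in range(Q) if pow(a, x, N) == y0]
    norm = 1 / math.sqrt(len(xs) * Q)
    probs = []
    for c in range(Q):
        amp = sum(cmath.exp(-2j * math.pi * x * c / Q) for x in xs) * norm
        probs.append(abs(amp) ** 2)
    return probs


probs = period_finding_distribution(a=7, N=15, Q=256, y0=1)
peaks = [c for c, p in enumerate(probs) if p > 1e-9]
print(peaks)  # [0, 64, 128, 192], i.e. c/Q = s/4
```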



The circuit:

![](https://drive.google.com/uc?id=1JK2_ztrDdRo1Nori9bo6YHLplfcyT-Ay)


In [None]:
def findPeriod(a, N):
  nNumBits = N.bit_length()
  inputNumBits = (2 * nNumBits) - 1
  inputNumBits += 1 if ((1 << inputNumBits) < (N * N)) else 0
  Q = 1 << inputNumBits

  printInfo("Finding the period...")
  printInfo("Q = " + str(Q) + "\ta = " + str(a))

  inputRegister = QubitRegister(inputNumBits)
  hmdInputRegister = QubitRegister(inputNumBits)
  qftInputRegister = QubitRegister(inputNumBits)
  outputRegister = QubitRegister(inputNumBits)

  printInfo("Registers generated")
  ##############################################################################
  #Initialize all input qubits to superposition using Hadamard gates.
  printInfo("Performing Hadamard on input register")

  inputRegister.map(hmdInputRegister, lambda x: hadamard(x, Q),False)
  
  # inputRegister.hadamard(False)
  printInfo("Hadamard complete")

  ##############################################################################
  # Apply the unitary transform that implements f(x) = a^x mod N
  printInfo("Mapping input register to output register, where f(x)is a^x mod N")
  hmdInputRegister.map(outputRegister, lambda x: qModExp(a, x, N),False)

  printInfo("Modular exponentiation complete")

  ##############################################################################
  #Perform the quantum Fourier transform on the input register.
  printInfo("Performing quantum Fourier transform on outputregister")

  hmdInputRegister.map(qftInputRegister, lambda x: qft(x, Q), False)
  inputRegister.propagate()

  printInfo("Quantum Fourier transform complete")

  ##############################################################################
  # Measure the output register first, then the QFT (periodicity) register, which gives the value c
  printInfo("Performing a measurement on the output register")

  y = outputRegister.measure()
  printInfo("Output register measured\ty = " + str(y))

  # To inspect the intermediate states, uncomment any of the following lines:
  #printAmplitudes(inputRegister)
  #printAmplitudes(qftInputRegister)
  #printAmplitudes(outputRegister)
  #printEntangles(inputRegister)
  printInfo("Performing a measurement on the periodicity register")

  x = qftInputRegister.measure()

  printInfo("QFT register measured\tx = " + str(x))

  if x is None:
    return None

  ##############################################################################
  # Use the continued-fraction algorithm to find the period r given c
  printInfo("Finding the period via continued fractions")

  r = cf(x, Q, N)  # cf is defined in the next section
  printInfo("Candidate period\tr = " + str(r))
  return r
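The first lines of `findPeriod` size the input register so that $Q = 2^m \geq N^2$. The reason is that $r < N$. For a measured $c$ near a peak, $|c/Q - s/r| \leq 1/(2Q) \leq 1/(2N^2) < 1/(2r^2)$, and a classical theorem on continued fractions then guarantees that $s/r$ (in lowest terms) appears among the convergents of $c/Q$. The sketch below repeats the sizing rule so it can be checked on its own. The name `input_register_size` is hypothetical. For $N$ that is not a power of two, the rule gives $N^2 \leq Q < 2N^2$.

```python
def input_register_size(N):
    """Number of input qubits m and Q = 2**m, computed as in findPeriod."""
    num_bits = 2 * N.bit_length() - 1
    if (1 << num_bits) < N * N:
        num_bits += 1
    return num_bits, 1 << num_bits


for N in (15, 21, 35):
    m, Q = input_register_size(N)
    print(N, m, Q, N * N <= Q < 2 * N * N)
# 15 8 256 True
# 21 9 512 True
# 35 11 2048 True
```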



# Some classical algorithms

In [None]:
BIT_LIMIT = 12

def bitCount(x):
  sumBits = 0
  while x > 0:
    sumBits += x & 1
    x >>= 1
  return sumBits

# Greatest Common Divisor
def gcd(a, b):
  while b != 0:
    tA = a % b
    a = b
    b = tA
  return a

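# Continued-fraction expansion of a/b: the partial quotients produced by the
# Euclidean algorithm (despite the name, not the extended Euclidean algorithm)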
def extendedGCD(a, b):
  fractions = []
  while b != 0:
    fractions.append(a // b)
    tA = a % b
    a = b
    b = tA
  return fractions

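# Continued fractions: return the denominator of the last convergent of y/Q whose
# denominator is below N; this is the candidate period r (0 when y = 0)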
def cf(y, Q, N):
  fractions = extendedGCD(y, Q)
  depth = 2

  def partial(fractions, depth):
    c = 0
    r = 1

    for i in reversed(range(depth)):
      tR = fractions[i] * r + c
      c = r
      r = tR
    return c
  
  r = 0
  for d in range(depth, len(fractions) + 1):
    tR = partial(fractions, d)
    if tR == r or tR >= N:
      return r
    r = tR

  return r

# Modular Exponentiation
def modExp(a, exp, mod):
  fx = 1
  while exp > 0:
    if (exp & 1) == 1:
      fx = fx * a % mod
    a = (a * a) % mod
    exp = exp >> 1

  return fx

# Return a random integer in the range [0, N - 1].
def pick(N):
  a = math.floor((random.random() * (N - 1)) + 0.5)
  return a

# checkCandidates tests r, small multiples of r, and values near r, and returns the
# first t > 0 with f(a) == f(a + t), i.e. a^t = 1 (mod N), or None if none works
def checkCandidates(a, r, N, neighborhood):
  if r is None:
    return None
  
  # Check multiples
  for k in range(1, neighborhood + 2):
    tR = k * r
    if modExp(a, a, N) == modExp(a, a + tR, N):
      return tR

  # Check lower neighborhood
  for tR in range(max(1, r - neighborhood), r):  # t <= 0 would trivially pass the check
    if modExp(a, a, N) == modExp(a, a + tR, N):
      return tR

  # Check upper neighborhood
  for tR in range(r + 1, r + neighborhood + 1):
    if modExp(a, a, N) == modExp(a, a + tR, N):
      return tR
      
  return None
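Python's standard library offers a compact alternative to `cf`: `fractions.Fraction.limit_denominator(max_denominator)` returns the fraction closest to $c/Q$ whose denominator does not exceed the bound. This is a swapped-in technique, not the worksheet's `cf`, and `period_candidate` is a hypothetical name. Using the bound $N - 1$ reflects that the period satisfies $r < N$. The two can differ at edge cases; for example, at $c = 0$ this returns 1 while `cf` returns 0.

```python
from fractions import Fraction


def period_candidate(c, Q, N):
    """Denominator of the closest fraction s/r to c/Q with r <= N - 1."""
    return Fraction(c, Q).limit_denominator(N - 1).denominator


print(period_candidate(128, 256, 15))  # 2  (128/256 = 1/2)
print(period_candidate(64, 256, 15))   # 4
print(period_candidate(192, 256, 15))  # 4
```

An outcome such as $c = 128$ yields only a divisor of the true period $r = 4$, because $s = 2$ shares a factor with $r$.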

# The algorithm

In [None]:
def shors(N, attempts = 30, neighborhood = 0.01, numPeriods = 2):

  if(N.bit_length() > BIT_LIMIT or N < 3):
    return False

  periods = []
  neighborhood = math.floor(N * neighborhood) + 1

  printInfo("N = " + str(N))
  printInfo("Neighborhood = " + str(neighborhood))
  printInfo("Number of periods = " + str(numPeriods))

  for attempt in range(attempts):
    printInfo("\nAttempt #" + str(attempt))
    ######################################################################
    # Choose a random number 1 < a < N
    a = pick(N)
    while a < 2:
      a = pick(N)

    ######################################################################
    # Compute gcd(a, N) using Euclid's algorithm
    d = gcd(a, N)
    if d > 1:
      # d is already a nontrivial factor of N; this demo discards it and
      # re-attempts so that the quantum period-finding step is exercised
      printInfo("Found factors classically, re-attempt")
      continue
      
    ######################################################################
    # Otherwise, find the period of f(x) = a^x mod N and call it r
    r = findPeriod(a, N)
    printInfo("Checking candidate period, nearby values, and multiples")

    r = checkCandidates(a, r, N, neighborhood)

    ######################################################################
    # If no period was found, pick a new a and start over
    if r is None:
      printInfo("Period was not found, re-attempt")
      continue

    ######################################################################
    # If r is odd or trivial (r = 0 or a^(r/2) = -1 mod N), pick a new a and start over
    if (r % 2) > 0:
      printInfo("Period was odd, re-attempt")
      continue

    d = modExp(a, (r // 2), N)

    if r == 0 or d == (N - 1):
      printInfo("Period was trivial, re-attempt")
      continue

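    ######################################################################
    # Otherwise gcd(a^(r/2)+1, N) and gcd(a^(r/2)-1, N) are nontrivial factors of N.
    # This implementation first collects numPeriods candidate periods and combines
    # them with their least common multiple before computing the factors.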

    periods.append(r)
    if(len(periods) < numPeriods):
      continue
    
    printInfo("Period found\tr = " + str(r))
    r = 1
    for period in periods:
      d = gcd(period, r)
      r = (r * period) // d

    b = modExp(a, (r // 2), N)
    f1 = gcd(N, b + 1)
    f2 = gcd(N, b - 1)

    if f1 == 1 or f1 == N:
      continue  # trivial factor, try another a

    printInfo("The factors are: ")
    return [f1, f2]

  return None
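A single measurement can return only a divisor of the period. For $r = 4$, the outcome $c = 128$ gives $1/2$. This is why `shors` gathers several candidate periods and combines them with their least common multiple before applying step 6. The stand-alone sketch below shows both pieces. `combine_periods` and `factors_from_period` are hypothetical helpers, and unlike `shors`, `factors_from_period` also rejects $a^{r/2} \equiv 1$ explicitly.

```python
import math
from functools import reduce


def combine_periods(periods):
    """Least common multiple of candidate periods (the same loop shors() runs)."""
    return reduce(lambda r, p: r * p // math.gcd(r, p), periods, 1)


def factors_from_period(a, r, N):
    """(gcd(a^(r/2)+1, N), gcd(a^(r/2)-1, N)), or None if r is odd or a^(r/2) = +/-1 mod N."""
    if r % 2 == 1:
        return None
    b = pow(a, r // 2, N)
    if b in (1, N - 1):
        return None
    return math.gcd(b + 1, N), math.gcd(b - 1, N)


r = combine_periods([2, 4])              # c = 128 gave 1/2, c = 64 gave 1/4
print(r, factors_from_period(7, r, 15))  # 4 (5, 3)
```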


# Run

In [None]:
shors(15)


N = 15
Neighborhood = 1
Number of periods = 2

Attempt #0
Found factors classically, re-attempt

Attempt #1
Found factors classically, re-attempt

Attempt #2
Finding the period...
Q = 256	a = 7
Registers generated
Performing Hadamard on input register
Hadamard complete
Mapping input register to output register, where f(x)is a^x mod N
Modular exponentiation complete
Performing quantum Fourier transform on outputregister
Quantum Fourier transform complete
Performing a measurement on the output register
Output register measured	y = 1
Performing a measurement on the periodicity register
QFT register measured	x = 128
Finding the period via continued fractions
Candidate period	r = 2
Checking candidate period, nearby values, and multiples

Attempt #3
Finding the period...
Q = 256	a = 2
Registers generated
Performing Hadamard on input register
Hadamard complete
Mapping input register to output register, where f(x)is a^x mod N
Modular exponentiation complete
Performing quantum Fourier transform

[5, 3]


