In [None]:
import os
import time
from tqdm.auto import tqdm
from chipwhisperer.capture.targets.CW305 import CW305
from low_level.jarvis import Jarvis
from low_level.picoscope import PicoScope
from ciphers import Aes
from io_dat.output_writer import OutputWriter

## Load bitstream

In [None]:
# Board configuration used when the bitstream is loaded.
# PLL_NUM is the PLL output index passed as `outnum` to the cw.pll calls.
PLL_NUM = 1
FREQUENCY_CLK_PLL_1 = 100  # MHz (multiplied by 1,000,000 before being sent to the PLL)
VCC = 1.0  # Volts (value passed to cw.vccint_set and checked against cw.vccint_get)

# Placeholder paths: replace both before running; their existence is asserted before use.
BISTREAM_PATH = "</path/to/bistream.bit>"
BINARY_PATH = "</path/to/binary.vmem>"

In [None]:
# Stop immediately if the bitstream path is wrong (e.g. the placeholder was not edited).
assert os.path.isfile(BISTREAM_PATH), "Bitstream file not found"

# Connect to the CW305 board and program it with the bitstream.
cw = CW305()
cw.con(bsfile=BISTREAM_PATH, force=True)

# --- Voltage configuration: set VCC-INT and verify the value read back ---
cw.vccint_set(VCC)
assert cw.vccint_get() == VCC
print("Vcc =", VCC)

# --- PLL configuration ---
# The clock must come out of PLL_1 because the xdc uses pin N13 as the clock pin.
cw.pll.cdce906init()
frequency = FREQUENCY_CLK_PLL_1 * 1_000_000  # MHz -> Hz
cw.pll.pll_outfreq_set(freq=frequency, outnum=PLL_NUM)
pll_frequency_readback = cw.pll.pll_outfreq_get(PLL_NUM)
print("PLL {} - frequency: {:,}".format(PLL_NUM, pll_frequency_readback))

# Enable the output and report the state the board reads back.
cw.pll.pll_outenable_set(enabled=True, outnum=PLL_NUM)
pll_state = "enabled" if cw.pll.pll_outenable_get(PLL_NUM) else "disabled"
print("PLL {} - {}".format(PLL_NUM, pll_state))

In [None]:
# Connect to the SoC through Jarvis on /dev/ttyUSB0 and stop if the connection failed.
# NOTE(review): the device path is presumably a USB-serial adapter -- adjust for your setup.
jarvis = Jarvis("/dev/ttyUSB0")
assert jarvis.isConnected(), "Cannot connect to SoC"

In [None]:
# Check that the binary exists before asking Jarvis to load it.
assert os.path.isfile(BINARY_PATH), "Binary file not found"
jarvis.loadBinary(BINARY_PATH)

## Configure addresses

In [None]:
# 128-bit test vectors written to the SoC memory by the cells below.
# PLAINTEXT is sixteen 0x41 bytes; KEY matches the AES-128 example key of FIPS-197.
PLAINTEXT = 0x41414141414141414141414141414141
KEY       = 0x2b7e151628aed2a6abf7158809cf4f3c

# Placeholder addresses (all 0x00): replace them with the values for the loaded binary.
# breakpoints[1] is the PC at which the capture loop expects the CPU to be halted.
breakpoints   = [0x00, 0x00]
triggerpoints = [0x00, 0x00]
state_addr    = 0x00  # where the plaintext/state is written and read back
key_addr      = 0x00  # where the key is written and read back

# Size of the state (and of the key) in words; the byte count assumes 4-byte words.
state_size_words = 4
state_size_bytes = 4*state_size_words

## Set Jarvis

In [None]:
# Program the breakpoints and trigger points into the CPU, then read each list
# back and stop if the configuration did not take effect.
jarvis.configureCpu(Jarvis.cpuId, breakpoints, triggerpoints)

configured_breakpoints = jarvis.getBreakPoints(Jarvis.cpuId)
assert configured_breakpoints == breakpoints

configured_triggerpoints = jarvis.getTriggerPoints(Jarvis.cpuId)
assert configured_triggerpoints == triggerpoints

In [None]:
# Write the plaintext to the state address and read it back to verify the memory write.
jarvis.memoryWriteVariable(state_addr, state_size_words, PLAINTEXT)
readState = jarvis.memoryReadVariable(state_addr, state_size_words)
assert readState == PLAINTEXT

In [None]:
# Write the key to the key address and read it back to verify the memory write.
jarvis.memoryWriteVariable(key_addr, state_size_words, KEY)
readKey = jarvis.memoryReadVariable(key_addr, state_size_words)
assert readKey == KEY

In [None]:
# Restart the CPU now that breakpoints, plaintext and key are in place.
# NOTE(review): presumably this leaves the CPU ready to be resumed by the capture loop -- confirm.
jarvis.restartCpu(Jarvis.cpuId)

## Set PicoScope

In [None]:
# BATCH_SIZE is handed to pico.setBatchSize; BATCH_NUMBER is the number of
# capture iterations run per key byte in the capture loop.
BATCH_SIZE = 2048
BATCH_NUMBER = 1

# Used to size the golden-model plaintext list (batch_size_jarvis + 1 entries).
# NOTE(review): presumably must equal BATCH_SIZE and the number of encryptions
# the SoC performs per run -- confirm.
batch_size_jarvis = 2048

In [None]:
# Create the PicoScope object used below to configure the acquisition and retrieve the traces.
pico = PicoScope()

In [None]:
# Configure the acquisition for BATCH_SIZE segments.
# NOTE(review): the meaning of the argument 3 passed to setup() is not visible here -- confirm.
pico.setBatchSize(BATCH_SIZE)
pico.setup(3)

# Consistency check: the per-channel buffer must hold exactly BATCH_SIZE segments.
samples_per_channel = pico.getSamplesPerChannel()
expected_samples = pico.getSamplesPerSegment() * BATCH_SIZE
assert samples_per_channel == expected_samples

# Segment duration in nanoseconds
segment_duration_ns = pico.getTimeBaseNanoseconds() * pico.getSamplesPerSegment()
print("Segment duration:", segment_duration_ns, "ns")

# One sample every <timebase> ns corresponds to 1000/<timebase> MHz.
sampling_frequency_mhz = 1000 / pico.getTimeBaseNanoseconds()
print("Sampling frequency: ", sampling_frequency_mhz, "MHz")

## Capture side-channel traces
One trace file is saved for each value of the first key byte.

In [None]:
# Placeholder: replace with the folder that will receive the trace files.
TRACES_PATH = "</path/to/save/traces/folder>"
# NOTE(review): BATCH_NUMBER is also assigned in the PicoScope settings cell;
# this later assignment is the one the capture loop sees -- keep them in sync or drop one.
BATCH_NUMBER = 1

In [None]:
# Sweep every value of the first key byte. For each value, one .dat file is
# written containing the first `n_traces_to_save` traces of every captured batch
# together with the matching golden-model plaintexts.
first_key_bytes = list(range(256))
n_traces_to_save = 1024

# Stop before starting the sweep if the destination folder is missing,
# mirroring the checks done for the bitstream and binary files.
assert os.path.isdir(TRACES_PATH), "Traces folder not found"

# File name layout: <year>-<month>-<day>_AES_k0x<key byte in hex>_<n traces>.dat
prefix = "%d-%d-%d_AES_k0x" % time.localtime()[0:3]
suffix = f"_{n_traces_to_save}.dat"

# NOTE(review): 50 is presumably a frequency in MHz -- confirm against Jarvis.setFrequency.
jarvis.setFrequency(Jarvis.dfsId, 50)

for kb in tqdm(first_key_bytes):

    # Restore the initial plaintext and overwrite the byte stored at key_addr
    jarvis.memoryWriteVariable(state_addr, state_size_words, PLAINTEXT)
    jarvis.memoryWriteByte(key_addr, int(kb))

    # Read both values back so the golden model uses exactly what is in memory
    curren_key = jarvis.memoryReadVariable(key_addr, state_size_words)
    current_plain = jarvis.memoryReadVariable(state_addr, state_size_words)

    # Configure output file. os.path.join is used so the file lands inside
    # TRACES_PATH whether or not the folder path ends with a separator.
    output_file = os.path.join(TRACES_PATH, prefix + "%.2x" % kb + suffix)
    ow = OutputWriter(output_file, pico)
    ow.writeDatHeader(curren_key)

    jarvis.restartCpu(Jarvis.cpuId)

    for _ in range(BATCH_NUMBER):
        # Arm the scope before letting the CPU run
        pico.run()

        jarvis.resumeCPU(Jarvis.cpuId)

        # Compute golden model plaintexts while the CPU is running. One extra
        # entry is requested: the last one is checked against the state left in
        # memory and becomes the starting plaintext of the next batch.
        cipher = Aes(curren_key, current_plain)
        texts = cipher.computePlaintexts(batch_size_jarvis + 1)

        # NOTE(review): 100000 is presumably a timeout -- confirm its unit.
        jarvis.waitForBreakPoint(Jarvis.cpuId, 100000)
        # Depending on the DU configuration, the CPU may halt one clock cycle after
        # reaching the breakpoint. In this case, the CPU may be halted one instruction
        # later, hence the +4. The PC is read once so that the value reported on
        # failure is the value that was actually checked.
        halted_pc = jarvis.getPc(Jarvis.cpuId)
        assert halted_pc in (breakpoints[1], breakpoints[1] + 4), \
            f"Actual PC is {hex(halted_pc)}"

        # The state in memory must match the last golden-model plaintext
        current_plain = texts[-1]
        assert current_plain == jarvis.memoryReadVariable(state_addr, state_size_words)

        # Retrieve data from PicoScope channel "A" and write it to file
        data = pico.retrieveData("A")
        ow.writeMemorySegments(data[:n_traces_to_save],
                              texts[:n_traces_to_save])

        jarvis.restartCpu(Jarvis.cpuId)