Skip to content

Commit

Permalink
Add NE instruction.
Browse files Browse the repository at this point in the history
  • Loading branch information
thealmarty committed Jan 3, 2024
1 parent 911ed0f commit 8678667
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 1 deletion.
1 change: 1 addition & 0 deletions alu_u32/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod add;
pub mod bitwise;
pub mod div;
pub mod lt;
pub mod ne;
pub mod mul;
pub mod shift;
pub mod sub;
29 changes: 29 additions & 0 deletions alu_u32/src/ne/columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use core::borrow::{Borrow, BorrowMut};
use core::mem::{size_of, transmute};
use valida_derive::AlignedBorrow;
use valida_machine::Word;
use valida_util::indices_arr;

#[derive(AlignedBorrow, Default)]
pub struct Ne32Cols<T> {
pub input_1: Word<T>,
pub input_2: Word<T>,

/// Boolean flags indicating which byte pair differs
pub byte_flag: [T; 3],

/// Bit decomposition of 256 + input_1 - input_2
pub bits: [T; 10],

pub output: T,

pub multiplicity: T,
}

pub const NUM_NE_COLS: usize = size_of::<Ne32Cols<u8>>();
pub const NE_COL_MAP: Ne32Cols<usize> = make_col_map();

const fn make_col_map() -> Ne32Cols<usize> {
let indices_arr = indices_arr::<NUM_NE_COLS>();
unsafe { transmute::<[usize; NUM_NE_COLS], Ne32Cols<usize>>(indices_arr) }
}
155 changes: 155 additions & 0 deletions alu_u32/src/ne/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
extern crate alloc;

use alloc::vec;
use alloc::vec::Vec;
use columns::{Ne32Cols, NE_COL_MAP, NUM_NE_COLS};
use core::iter;
use core::mem::transmute;
use valida_bus::MachineWithGeneralBus;
use valida_cpu::MachineWithCpuChip;
use valida_machine::{
instructions, Chip, Instruction, Interaction, Operands, Word, MEMORY_CELL_BYTES,
};
use valida_opcodes::NE32;

use p3_air::VirtualPairCol;
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::*;
use valida_util::pad_to_power_of_two;

pub mod columns;
pub mod stark;

#[derive(Clone)]
pub enum Operation {
Ne32(Word<u8>, Word<u8>, Word<u8>), // (dst, src1, src2)
}

#[derive(Default)]
pub struct Ne32Chip {
pub operations: Vec<Operation>,
}

impl<F, M> Chip<M> for Ne32Chip
where
F: PrimeField,
M: MachineWithGeneralBus<F = F>,
{
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<M::F> {
let rows = self
.operations
.par_iter()
.map(|op| self.op_to_row(op))
.collect::<Vec<_>>();

let mut trace =
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_NE_COLS);

pad_to_power_of_two::<NUM_NE_COLS, F>(&mut trace.values);

trace
}

fn global_receives(&self, machine: &M) -> Vec<Interaction<M::F>> {
let opcode = VirtualPairCol::constant(M::F::from_canonical_u32(NE32));
let input_1 = NE_COL_MAP.input_1.0.map(VirtualPairCol::single_main);
let input_2 = NE_COL_MAP.input_2.0.map(VirtualPairCol::single_main);
let output = (0..MEMORY_CELL_BYTES - 1)
.map(|_| VirtualPairCol::constant(M::F::zero()))
.chain(iter::once(VirtualPairCol::single_main(NE_COL_MAP.output)));

let mut fields = vec![opcode];
fields.extend(input_1);
fields.extend(input_2);
fields.extend(output);

let receive = Interaction {
fields,
count: VirtualPairCol::single_main(NE_COL_MAP.multiplicity),
argument_index: machine.general_bus(),
};
vec![receive]
}
}

impl Ne32Chip {
fn op_to_row<F>(&self, op: &Operation) -> [F; NUM_NE_COLS]
where
F: PrimeField,
{
let mut row = [F::zero(); NUM_NE_COLS];
let cols: &mut Ne32Cols<F> = unsafe { transmute(&mut row) };

match op {
Operation::Ne32(dst, src1, src2) => {
if let Some(n) = src1
.into_iter()
.zip(src2.into_iter())
.enumerate()
.find_map(|(n, (x, y))| if x == y { Some(n) } else { None })
{
let z = 256u16 + src1[n] as u16 - src2[n] as u16;
for i in 0..10 {
cols.bits[i] = F::from_canonical_u16(z >> i & 1);
}
if n < 3 {
cols.byte_flag[n] = F::one();
}
}
cols.input_1 = src1.transform(F::from_canonical_u8);
cols.input_2 = src2.transform(F::from_canonical_u8);
cols.output = F::from_canonical_u8(dst[3]);
cols.multiplicity = F::one();
}
}
row
}
}

pub trait MachineWithNe32Chip: MachineWithCpuChip {
fn ne_u32(&self) -> &Ne32Chip;
fn ne_u32_mut(&mut self) -> &mut Ne32Chip;
}

instructions!(Ne32Instruction);

impl<M> Instruction<M> for Ne32Instruction
where
M: MachineWithNe32Chip,
{
const OPCODE: u32 = NE32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1 = state.mem_mut().read(clk, read_addr_1, true, pc, opcode, 0, "");
let src2 = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state.mem_mut().read(clk, read_addr_2, true, pc, opcode, 1, "")
};

let dst = if src1 != src2 {
Word::from(1)
} else {
Word::from(0)
};
state.mem_mut().write(clk, write_addr, dst, true);

state
.ne_u32_mut()
.operations
.push(Operation::Ne32(dst, src1, src2));
state
.cpu_mut()
.push_bus_op(imm, opcode, ops);
}
}
72 changes: 72 additions & 0 deletions alu_u32/src/ne/stark.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use super::columns::Ne32Cols;
use super::Ne32Chip;
use core::borrow::Borrow;

use crate::ne::columns::NUM_NE_COLS;
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::AbstractField;
use p3_matrix::MatrixRowSlices;

impl<F> BaseAir<F> for Ne32Chip {
fn width(&self) -> usize {
NUM_NE_COLS
}
}

impl<F, AB> Air<AB> for Ne32Chip
where
F: AbstractField,
AB: AirBuilder<F = F>,
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local: &Ne32Cols<AB::Var> = main.row_slice(0).borrow();

let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512].map(AB::Expr::from_canonical_u32);

let bit_comp: AB::Expr = local
.bits
.into_iter()
.zip(base_2.iter().cloned())
.map(|(bit, base)| bit * base)
.sum();

// Check bit decomposition of z = 256 + input_1[n] - input_2[n], where
// n is the most significant byte that differs between inputs
for i in 0..3 {
builder
.when_ne(local.byte_flag[i], AB::Expr::one())
.assert_eq(local.input_1[i], local.input_2[i]);

builder.when(local.byte_flag[i]).assert_eq(
AB::Expr::from_canonical_u32(256) + local.input_1[i] - local.input_2[i],
bit_comp.clone(),
);

builder.assert_bool(local.byte_flag[i]);
}

// Check final byte (if no other byte flags were set)
let flag_sum = local.byte_flag[0] + local.byte_flag[1] + local.byte_flag[2];
builder.assert_bool(flag_sum.clone());
builder
.when_ne(local.multiplicity, AB::Expr::zero())
.when_ne(flag_sum, AB::Expr::one())
.assert_eq(
AB::Expr::from_canonical_u32(256) + local.input_1[3] - local.input_2[3],
bit_comp.clone(),
);

// Output constraints
builder.when(local.bits[8]).assert_zero(local.output);
builder
.when_ne(local.multiplicity, AB::Expr::zero())
.when_ne(local.bits[8], AB::Expr::one())
.assert_one(local.output);

// Check bit decomposition
for bit in local.bits.into_iter() {
builder.assert_bool(bit);
}
}
}
17 changes: 17 additions & 0 deletions basic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use valida_alu_u32::{
},
div::{Div32Chip, Div32Instruction, SDiv32Instruction, MachineWithDiv32Chip},
lt::{Lt32Chip, Lt32Instruction, MachineWithLt32Chip},
ne::{Ne32Chip, Ne32Instruction, MachineWithNe32Chip},
mul::{MachineWithMul32Chip, Mul32Chip, Mul32Instruction},
shift::{MachineWithShift32Chip, Shift32Chip, Shl32Instruction, Shr32Instruction},
sub::{MachineWithSub32Chip, Sub32Chip, Sub32Instruction},
Expand Down Expand Up @@ -74,6 +75,8 @@ pub struct BasicMachine<F: PrimeField64 + TwoAdicField, EF: ExtensionField<F>> {
shr32: Shr32Instruction,
#[instruction(lt_u32)]
lt32: Lt32Instruction,
#[instruction(ne_u32)]
ne32: Ne32Instruction,
#[instruction(bitwise_u32)]
and32: And32Instruction,
#[instruction(bitwise_u32)]
Expand Down Expand Up @@ -106,6 +109,8 @@ pub struct BasicMachine<F: PrimeField64 + TwoAdicField, EF: ExtensionField<F>> {
#[chip]
lt_u32: Lt32Chip,
#[chip]
ne_u32: Ne32Chip,
#[chip]
bitwise_u32: Bitwise32Chip,
#[chip]
output: OutputChip,
Expand Down Expand Up @@ -256,6 +261,18 @@ impl<F: PrimeField64 + TwoAdicField, EF: ExtensionField<F>> MachineWithLt32Chip
}
}

impl<F: PrimeField64 + TwoAdicField, EF: ExtensionField<F>> MachineWithNe32Chip
for BasicMachine<F, EF>
{
fn ne_u32(&self) -> &Ne32Chip {
&self.ne_u32
}

fn ne_u32_mut(&mut self) -> &mut Ne32Chip {
&mut self.ne_u32
}
}

impl<F: PrimeField64 + TwoAdicField, EF: ExtensionField<F>> MachineWithShift32Chip
for BasicMachine<F, EF>
{
Expand Down
2 changes: 1 addition & 1 deletion opcodes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub const SHR32: u32 = 106;
pub const AND32: u32 = 107;
pub const OR32: u32 = 108;
pub const XOR32: u32 = 109;
pub const NE: u32 = 111; //TODO
pub const NE32: u32 = 111; //TODO
pub const MULHU32 : u32 = 112; //TODO
pub const SRA32 : u32 = 113; //TODO
pub const MULHS32 : u32 =114; //TODO
Expand Down

0 comments on commit 8678667

Please sign in to comment.