In [1]:
def mod_mul_standard(a, b, p):
    result = (a * b) % p
    if result > (p>>1):
        result -= p
    return result

In [2]:
def Mu(b,q):
    mu = 2*int(((b << 31)/(2*q))+0.5)
    return mu

In [3]:
def Barrett_mul(a, b, p, mu_b):
    # Perform the multiplication
    z = (a * b) % (1 << 32)
    if z >= (1 << 31):
        z -= (1 << 32)  # Convert to signed 32-bit integer

    # Perform the multiplication for mu_b and get the high 32 bits
    temp = (a * (mu_b << 1)) % (1 << 64)
    t_high = (temp >> 32) % (1 << 32)
    if t_high >= (1 << 31):
        t_high -= (1 << 32)  # Convert to signed 32-bit integer
    
    result = z - t_high * p
    # Simulate 32-bit signed integer overflow for the result
    result %= (1 << 32)
    if result >= (1 << 31):
        result -= (1 << 32)  # Convert to signed 32-bit integer

    return result

In [4]:
a = 312222
b = 213222
p = 33550337
mu_b = Mu(b,p)
print(mu_b)

print(Barrett_mul(a,b,p,mu_b))
print(mod_mul_standard(a,b,p))

13647874
8730676
8730676


In [5]:
%%writefile Barrett_mul.c

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h> // 提供 atoi 函數

// Barrett Reduction Multiplication
int32_t Barrett_mul(int32_t a, int32_t b, int32_t mod, int32_t mu_b) {
    int32_t z = a * b;
    int32_t t_high = (int32_t)(((int64_t)a * (int64_t)(mu_b << 1)) >> 32);
    return (z - t_high * mod);
}

int main(int argc, char *argv[]) {
    if (argc != 5) { // 檢查參數數量
        printf("Usage: ./Barrett_mul <a> <b> <mod> <mu_b>\n");
        return 1; // 錯誤返回值
    }

    // 將命令列參數轉為整數
    int32_t a = atoi(argv[1]);
    int32_t b = atoi(argv[2]);
    int32_t mod = atoi(argv[3]);
    int32_t mu_b = atoi(argv[4]);

    // 計算結果
    int32_t result = Barrett_mul(a, b, mod, mu_b);
    printf("Result: %d\n", result);

    return 0;
}

Overwriting Barrett_mul.c


In [6]:
!gcc -o Barrett_mul Barrett_mul.c

In [7]:
import subprocess

# 定義參數
a = 312222
b = 213222
mod = 33550337
mu_b = 13647874

# 執行 C 程式
process = subprocess.run(
    ["./Barrett_mul", str(a), str(b), str(mod), str(mu_b)],
    capture_output=True,
    text=True
)

# 處理輸出
if process.returncode == 0:
    # 擷取並轉換輸出結果
    result = int(process.stdout.split(":")[1].strip())
    print(f"Captured Result: {result}")
else:
    print("Error:", process.stderr)

Captured Result: 8730676


In [None]:
result == Barrett_mul(a,b,p,mu_b)

True

In [9]:
# correctness test

import subprocess

# 固定的模數
p = 33550337



# 遍歷 b
for b in range(-p // 2, p // 2 + 1):
    mu_b = Mu(b, p)  # 計算對應的 mu 值

    # 遍歷 a
    for a in range(-p // 2, p // 2 + 1):
        # 執行 C 程式
        process = subprocess.run(
            ["./Barrett_mul", str(a), str(b), str(p), str(mu_b)],
            capture_output=True,
            text=True
        )

        # 確保執行成功
        if process.returncode != 0:
            print(f"Error for a={a}, b={b}, mu_b={mu_b}: {process.stderr}")
            continue

        # 獲取結果
        c_result = int(process.stdout.split(":")[1].strip())  # C 程式輸出的結果
        python_result = mod_mul_standard(a, b, p)  # Python 標準結果

        # 驗證
        if c_result != python_result:
            print(f"Discrepancy found: a={a}, b={b}, C_result={c_result}, Python_result={python_result}")


Discrepancy found: a=-16775169, b=-16775169, C_result=83876866, Python_result=-8387584
Discrepancy found: a=-16775168, b=-16775169, C_result=33551360, Python_result=8387584
Discrepancy found: a=-16775167, b=-16775169, C_result=16776191, Python_result=-8387585
Discrepancy found: a=-16775166, b=-16775169, C_result=-33549315, Python_result=8387583
Discrepancy found: a=-16775165, b=-16775169, C_result=-50324484, Python_result=-8387586
Discrepancy found: a=-16775164, b=-16775169, C_result=-100649990, Python_result=8387582
Discrepancy found: a=-16775163, b=-16775169, C_result=-117425159, Python_result=-8387587
Discrepancy found: a=-16775162, b=-16775169, C_result=-167750665, Python_result=8387581
Discrepancy found: a=-16775161, b=-16775169, C_result=-184525834, Python_result=-8387588
Discrepancy found: a=-16775160, b=-16775169, C_result=-234851340, Python_result=8387580
Discrepancy found: a=-16775159, b=-16775169, C_result=-251626509, Python_result=-8387589
Discrepancy found: a=-16775158, b=

KeyboardInterrupt: 