In [7]:
def generate_median_network(n):
    def compare_exchange(i, j):
        return f"PIX_SORT(p[{i}], p[{j}]);"

    def bitonic_merge(low, cnt, ascending):
        if cnt > 1:
            k = cnt // 2
            for i in range(low, low + k):
                if ascending == (i % 2 == 0):
                    yield compare_exchange(i, i + k)
            yield from bitonic_merge(low, k, ascending)
            yield from bitonic_merge(low + k, k, ascending)

    def bitonic_sort(low, cnt, ascending):
        if cnt > 1:
            k = cnt // 2
            yield from bitonic_sort(low, k, True)
            yield from bitonic_sort(low + k, k, False)
            yield from bitonic_merge(low, cnt, ascending)

    return list(bitonic_sort(0, n, True))

def print_cuda_median_network(n):
    network = generate_median_network(n)
    print("typedef unsigned char pixelvalue;")
    print("#define PIX_SORT(a,b) { if ((a)>(b)) PIX_SWAP((a),(b)); }")
    print("#define PIX_SWAP(a,b) { pixelvalue temp=(a);(a)=(b);(b)=temp; }")
    print()
    print(f"__device__ pixelvalue opt_med{n}(pixelvalue * p) {{")
    
    # Group comparisons into lines of 3 (or less for the last line)
    for i in range(0, len(network), 3):
        line = " ".join(network[i:i+3])
        print(f"    {line}")
    
    # Add return statement
    if n % 2 == 0:
        print(f"    return (p[{n//2 - 1}] + p[{n//2}]) / 2;")
    else:
        print(f"    return p[{n//2}];")
    
    print("}")

# Example usage
print_cuda_median_network(49)

typedef unsigned char pixelvalue;
#define PIX_SORT(a,b) { if ((a)>(b)) PIX_SWAP((a),(b)); }
#define PIX_SWAP(a,b) { pixelvalue temp=(a);(a)=(b);(b)=temp; }

__device__ pixelvalue opt_med49(pixelvalue * p) {
    PIX_SORT(p[0], p[1]); PIX_SORT(p[3], p[4]); PIX_SORT(p[0], p[3]);
    PIX_SORT(p[2], p[5]); PIX_SORT(p[0], p[1]); PIX_SORT(p[6], p[7]);
    PIX_SORT(p[9], p[10]); PIX_SORT(p[7], p[10]); PIX_SORT(p[9], p[10]);
    PIX_SORT(p[0], p[6]); PIX_SORT(p[2], p[8]); PIX_SORT(p[4], p[10]);
    PIX_SORT(p[0], p[3]); PIX_SORT(p[2], p[5]); PIX_SORT(p[0], p[1]);
    PIX_SORT(p[6], p[9]); PIX_SORT(p[8], p[11]); PIX_SORT(p[6], p[7]);
    PIX_SORT(p[12], p[13]); PIX_SORT(p[15], p[16]); PIX_SORT(p[12], p[15]);
    PIX_SORT(p[14], p[17]); PIX_SORT(p[12], p[13]); PIX_SORT(p[18], p[19]);
    PIX_SORT(p[21], p[22]); PIX_SORT(p[19], p[22]); PIX_SORT(p[21], p[22]);
    PIX_SORT(p[13], p[19]); PIX_SORT(p[15], p[21]); PIX_SORT(p[17], p[23]);
    PIX_SORT(p[13], p[16]); PIX_SORT(p[15], p[16]); PIX_SORT(p[1

In [13]:
import random

def generate_median_network(n):
    def compare_exchange(i, j):
        return f"PIX_SORT(p[{i}], p[{j}]);"

    network = []
    for i in range(n):
        for j in range(i + 1, n):
            network.append((i, j))
    
    return network

def print_cuda_median_network(n):
    network = generate_median_network(n)
    print("typedef unsigned char pixelvalue;")
    print("#define PIX_SORT(a,b) { if ((a)>(b)) PIX_SWAP((a),(b)); }")
    print("#define PIX_SWAP(a,b) { pixelvalue temp=(a);(a)=(b);(b)=temp; }")
    print()
    print(f"__device__ pixelvalue opt_med{n}(pixelvalue * p) {{")
    
    # Group comparisons into lines of 3 (or less for the last line)
    for i in range(0, len(network), 3):
        line = " ".join(f"PIX_SORT(p[{a}], p[{b}]);" for a, b in network[i:i+3])
        print(f"    {line}")
    
    # Add return statement
    if n % 2 == 0:
        print(f"    return (p[{n//2 - 1}] + p[{n//2}]) / 2;")
    else:
        print(f"    return p[{n//2}];")
    
    print("}")

def test_median_network(n, num_tests=1000):
    network = generate_median_network(n)
    
    def apply_network(arr):
        for i, j in network:
            if arr[i] > arr[j]:
                arr[i], arr[j] = arr[j], arr[i]
        return arr[n//2] if n % 2 != 0 else (arr[n//2 - 1] + arr[n//2]) // 2

    for _ in range(num_tests):
        test_array = [random.randint(0, 255) for _ in range(n)]
        network_result = apply_network(test_array.copy())
        true_median = sorted(test_array)[n//2]
        
        if network_result != true_median:
            print(f"Test failed for input: {test_array}")
            print(f"Network result: {network_result}")
            print(f"True median: {true_median}")
            return False

    print(f"All {num_tests} tests passed for n={n}")
    return True

# Test the network generator
test_median_network(81)

# Example usage
print_cuda_median_network(81)

All 1000 tests passed for n=81
typedef unsigned char pixelvalue;
#define PIX_SORT(a,b) { if ((a)>(b)) PIX_SWAP((a),(b)); }
#define PIX_SWAP(a,b) { pixelvalue temp=(a);(a)=(b);(b)=temp; }

__device__ pixelvalue opt_med81(pixelvalue * p) {
    PIX_SORT(p[0], p[1]); PIX_SORT(p[0], p[2]); PIX_SORT(p[0], p[3]);
    PIX_SORT(p[0], p[4]); PIX_SORT(p[0], p[5]); PIX_SORT(p[0], p[6]);
    PIX_SORT(p[0], p[7]); PIX_SORT(p[0], p[8]); PIX_SORT(p[0], p[9]);
    PIX_SORT(p[0], p[10]); PIX_SORT(p[0], p[11]); PIX_SORT(p[0], p[12]);
    PIX_SORT(p[0], p[13]); PIX_SORT(p[0], p[14]); PIX_SORT(p[0], p[15]);
    PIX_SORT(p[0], p[16]); PIX_SORT(p[0], p[17]); PIX_SORT(p[0], p[18]);
    PIX_SORT(p[0], p[19]); PIX_SORT(p[0], p[20]); PIX_SORT(p[0], p[21]);
    PIX_SORT(p[0], p[22]); PIX_SORT(p[0], p[23]); PIX_SORT(p[0], p[24]);
    PIX_SORT(p[0], p[25]); PIX_SORT(p[0], p[26]); PIX_SORT(p[0], p[27]);
    PIX_SORT(p[0], p[28]); PIX_SORT(p[0], p[29]); PIX_SORT(p[0], p[30]);
    PIX_SORT(p[0], p[31]); PIX_SORT(p[0],