In [None]:
import openai

In [6]:
# PolyBench/C 2mm kernel source, relative to the notebook's working directory.
file_path = "../linear-algebra/kernels/2mm/2mm.c"

# Slurp the whole C file into a single string for display and later use.
with open(file_path, mode="r") as file:
    file_content = file.read()

print(file_content)

/**
 * This version is stamped on May 10, 2016
 *
 * Contact:
 *   Louis-Noel Pouchet <pouchet.ohio-state.edu>
 *   Tomofumi Yuki <tomofumi.yuki.fr>
 *
 * Web address: http://polybench.sourceforge.net
 */
/* 2mm.c: this file is part of PolyBench/C */

#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <math.h>

/* Include polybench common header. */
#include <polybench.h>

/* Include benchmark-specific header. */
#include "2mm.h"


/* Array initialization. */
static
void init_array(int ni, int nj, int nk, int nl,
		DATA_TYPE *alpha,
		DATA_TYPE *beta,
		DATA_TYPE POLYBENCH_2D(A,NI,NK,ni,nk),
		DATA_TYPE POLYBENCH_2D(B,NK,NJ,nk,nj),
		DATA_TYPE POLYBENCH_2D(C,NJ,NL,nj,nl),
		DATA_TYPE POLYBENCH_2D(D,NI,NL,ni,nl))
{
  int i, j;

  *alpha = 1.5;
  *beta = 1.2;
  for (i = 0; i < ni; i++)
    for (j = 0; j < nk; j++)
      A[i][j] = (DATA_TYPE) ((i*j+1) % ni) / ni;
  for (i = 0; i < nk; i++)
    for (j = 0; j < nj; j++)
      B[i][j] = (DATA_TYPE) (i*(j+1) % nj) / nj;
  for

In [None]:
def get_ps_sols(
    n: int,
    source_code: str,
    incorrect_ps_sols: set[str],
    benchmark_name: str = "2mm",
    engine: str = "gemini",
    max_tokens: int = 300,
    complete: "Callable[[str], str] | None" = None,
) -> list[str]:
    """Sample up to ``n`` candidate program summaries (PS) from an LLM.

    The prompt asks the model to rewrite the C benchmark ``source_code`` into
    Python/PyTorch using only a small whitelist of torch functions. The model is
    queried ``n`` times. Each answer is whitespace-stripped and kept unless it is
    empty or already listed in ``incorrect_ps_sols``.

    Args:
        n: Number of LLM requests to make (at most ``n`` candidates returned).
        source_code: C source embedded verbatim in the prompt.
        incorrect_ps_sols: Candidates already known to be wrong; these are dropped.
        benchmark_name: Benchmark name quoted in the prompt (default ``"2mm"``).
        engine: Engine name passed to the default OpenAI backend.
        max_tokens: Completion length limit for the default OpenAI backend.
        complete: Optional ``prompt -> text`` callable replacing the default
            OpenAI backend (e.g. another LLM service, or a stub for testing).

    Returns:
        The kept candidates in generation order. Duplicates are possible; the
        caller is expected to deduplicate.
    """
    # Define the prompt structure as mentioned in the paper.
    prompt_template = """
    Your task is to rewrite the given '{benchmark_name}' C benchmark into Python using PyTorch. You need to use only the
    provided functions and constants to achieve this. The rewritten program should be semantically equivalent to the
    '{benchmark_name}' benchmark.

    Use only the following functions:
    # torch.matmul
    # torch.add
    # torch.mul
    # torch.zeros
    # Avoid using loops. Simplify the code as much as possible. Use functional constructs instead of for-loops.
    # Rewrite the above C benchmark in Python using the provided functions. Your Python code should use PyTorch for matrix operations.

    Here is the C function:
    {source_code}

    """

    # Substituted values are not re-parsed, so braces in the C code are safe here.
    prompt = prompt_template.format(
        benchmark_name=benchmark_name, source_code=source_code
    )

    def _openai_complete(prompt_text: str) -> str:
        # Legacy (pre-1.0) openai Completion API, as used by the original notebook.
        # NOTE(review): "gemini" is not an OpenAI engine; confirm which backend or
        # proxy serves this name.
        response = openai.Completion.create(
            engine=engine,
            prompt=prompt_text,
            max_tokens=max_tokens,
            n=1,
            # Deliberately no `stop` sequence: stop=["\n"] truncated every answer
            # after its first line, so multi-line programs were never returned.
        )
        return response.choices[0].text

    if complete is None:
        complete = _openai_complete

    ps_sols: list[str] = []
    for _ in range(n):
        try:
            ps_sol = complete(prompt).strip()
        except Exception as e:
            # Best-effort sampling: one failed request should not abort the batch.
            print(f"Error generating PS: {e}")
            continue

        # Drop empty answers and candidates already proven incorrect.
        if ps_sol and ps_sol not in incorrect_ps_sols:
            ps_sols.append(ps_sol)

    return ps_sols


In [None]:
def transpile_code(source_code: str, num_iters: int, n: int) -> "str | None":
    """Search for an LLM-generated program summary (PS) that passes verification.

    In each of up to ``num_iters`` rounds, ``get_ps_sols`` is asked for ``n``
    candidate PS translations of ``source_code``. Every previously rejected PS is
    passed along so it can be filtered out. Each new candidate that passes
    ``parse`` gets candidate invariants from ``get_inv_sols_for_ps(n, ps_sol)``.
    The first invariant that parses and satisfies ``verify(inv_sol, ps_sol)``
    makes that PS the result.

    Args:
        source_code: Source program handed to ``get_ps_sols``.
        num_iters: Maximum number of sampling rounds.
        n: Count passed to both ``get_ps_sols`` and ``get_inv_sols_for_ps``.

    Returns:
        The first verified PS, or ``None`` if no round produced one.
    """
    # Kept across rounds so rejections are fed back to the sampler and repeats
    # are skipped. Re-creating these sets inside the loop discarded that feedback.
    incorrect_ps_sols: set[str] = set()
    seen_ps_sols: set[str] = set()

    for _ in range(num_iters):
        # Get Program Summaries (PS) using the LLM to translate to PyTorch.
        ps_sols: list[str] = get_ps_sols(n, source_code, incorrect_ps_sols)

        for ps_sol in ps_sols:
            # Skip duplicates within this round and candidates from earlier rounds.
            if ps_sol in seen_ps_sols:
                continue
            seen_ps_sols.add(ps_sol)

            # parse() acts as a validity gate; a PS failing it cannot be verified
            # here, so it is recorded as incorrect right away.
            if not parse(ps_sol):
                incorrect_ps_sols.add(ps_sol)
                continue

            seen_inv_sols_for_ps: set[str] = set()
            for inv_sol in get_inv_sols_for_ps(n, ps_sol):
                if inv_sol in seen_inv_sols_for_ps:
                    continue
                seen_inv_sols_for_ps.add(inv_sol)

                if parse(inv_sol) and verify(inv_sol, ps_sol):
                    return ps_sol

            # No invariant verified this PS, so mark it as incorrect.
            incorrect_ps_sols.add(ps_sol)

    # No valid solution was found.
    return None