In [None]:
import dataclasses
import multiprocessing
import re
import time
import os
import sys
import math
from pathlib import Path
from typing import Any, Optional
import yaml

import torch.cuda

from utils import set_seed
try:
    from task import TestSpec
except ImportError:
    TestSpec = dict

from reference import check_implementation, generate_input


@dataclasses.dataclass
class TestCase:
    """
    One test or benchmark case, as built by `get_test_cases` from the task YAML file.
    """
    # Arguments of the case; `_run_single_test` expands them into `generate_input(**args)`.
    args: dict
    # String form of the case's data, printed as the test name by the runner cells.
    spec: str
    # Set by `get_test_cases` to (m*n + k*n + k*m) * 2 / 1024 / 1024
    # (presumably MiB at 2 bytes per element -- confirm against the kernel's dtypes).
    memory_usage: Optional[float] = None
    # Set by `get_test_cases` to 2 * m * n * k; used by the runner cells to report TFLOPS.
    FLOPs: Optional[int] = None

def _combine(a: int, b: int) -> int:
    # combine two integers into one:
    # we need this to generate a secret seed based on the test-level seed and
    # the global secret seed.
    # the test-level seeds are public knowledge, and typically relatively small numbers,
    # so we need to make sure they don't provide any useful info for the full seed.
    # This Cantor construction ensures that if the secret seed is a large number,
    # then so is the overall seed.
    return int(a + (a+b)*(a+b+1)//2)


def get_test_cases(file_name: str, data_type: str, seed: Optional[int]) -> list[TestCase]:
    """
    Load one category of test cases from a YAML file.

    Every entry listed under `data_type` must be a mapping that provides at
    least `m`, `n` and `k`; these are used to derive the `memory_usage` and
    `FLOPs` fields of the resulting TestCase. The entry itself becomes the
    case's `args`, and its string form becomes the case's `spec`.

    @param file_name: Path of the YAML file to read.
    @param data_type: Top-level key of the YAML document whose cases are loaded.
    @param seed: Optional secret seed. When given, the `seed` argument of every
                 case that has one is replaced by `_combine(case_seed, seed)`.
    @return: The list of TestCase objects, in file order.

    Prints a message to stderr and exits the process with status 113 when the
    file cannot be read or parsed, or when it holds no data for `data_type`.
    A case that lacks `m`, `n` or `k` raises KeyError.
    """
    try:
        with open(file_name, 'r') as file:
            document = yaml.safe_load(file)
    except Exception as E:
        print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr)
        # sys.exit() rather than the site-provided exit(): the latter is meant for
        # interactive use and is not guaranteed to exist or to stop execution.
        sys.exit(113)

    # An empty file loads as None and a scalar/list document has no .get();
    # treat both like a missing section instead of crashing with AttributeError.
    tests_data = document.get(data_type) if isinstance(document, dict) else None
    if tests_data is None:
        print(f"Could not find test data for type `{data_type}` in file `{file_name}`", file=sys.stderr)
        sys.exit(113)

    tests = []
    for case in tests_data:
        m, n, k = case['m'], case['n'], case['k']
        memory_usage = (m * n + k * n + k * m) * 2 / 1024 / 1024
        FLOPs = (m * n * k) * 2
        # spec is captured here, before the secret seed is mixed in below, so it
        # keeps showing the public per-test seed.
        tests.append(TestCase(spec=str(case), args=case, memory_usage=memory_usage, FLOPs=FLOPs))

    if seed is not None:
        for test in tests:
            if "seed" in test.args:
                test.args["seed"] = _combine(test.args["seed"], seed)

    return tests



@dataclasses.dataclass
class Stats:
    """
    Summary statistics of a series of duration measurements, as produced by
    `calculate_stats`. All duration fields are in the unit of the input values.
    """
    # Number of measurements.
    runs: int
    # Arithmetic mean of the measurements.
    mean: float
    # Sample standard deviation (n - 1 denominator); 0.0 for a single measurement.
    std: float
    # Standard error of the mean: std / sqrt(runs).
    err: float
    # Smallest measurement.
    best: float
    # Largest measurement.
    worst: float

def calculate_stats(durations: list[int]):
    """
    Calculate statistical data from a list of durations.

    @param durations: A non-empty list of durations in nanoseconds.
    @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations.
    @raise ValueError: If `durations` is empty.
    """
    runs = len(durations)
    if runs == 0:
        raise ValueError("calculate_stats() requires at least one duration")

    avg = sum(durations) / runs

    if runs > 1:
        squared_deviation = sum((x - avg) ** 2 for x in durations)
        std = math.sqrt(squared_deviation / (runs - 1))
    else:
        # The sample variance is undefined for one measurement (division by
        # runs - 1 == 0); report no spread instead of raising ZeroDivisionError.
        std = 0.0
    err = std / math.sqrt(runs)

    return Stats(runs=runs, mean=avg, std=std, err=err,
                 best=float(min(durations)), worst=float(max(durations)))

def _clone_data(data):
    """
    Recursively goes through data and clones all tensors.
    """
    if isinstance(data, tuple):
        return tuple(_clone_data(x) for x in data)
    elif isinstance(data, list):
        return [_clone_data(x) for x in data]
    elif isinstance(data, dict):
        return {k: _clone_data(v) for k, v in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.clone()
    else:
        return data
        
def wrap_check_implementation(data, submission_output):
    """
    Call `check_implementation` and normalise its result to `(bool, str)`.

    Newer problem definitions already return a `(bool, str)` tuple, which is
    passed through untouched. Older ones return a single string; for those the
    check counts as passed exactly when that value is falsy.
    """
    outcome = check_implementation(data, submission_output)
    if not isinstance(outcome, tuple):
        # Legacy format: an empty result means success.
        return not outcome, outcome
    return outcome


In [None]:
# Embed the kernel source file into the Python submission file as a raw-string
# `CUDA_SRC = r'''...'''` assignment, creating the submission file if needed.
cuda_file_path = Path('./csrc_kernel/gemm_fp8_v1.hpp')
submission_file_path = Path('./hip_submission.py') # Assuming hip_submission.py is in the current directory


def _splice_cuda_src(submission_content: str, new_src_assignment: str) -> str:
    """
    Return `submission_content` with its CUDA_SRC assignment replaced.

    Tried in order:
    1. the first triple-quoted (optionally raw) `CUDA_SRC = ...` block is replaced;
    2. otherwise the first line starting with `CUDA_SRC =` is replaced;
    3. otherwise the new assignment is prepended to the content.

    @param submission_content: Current text of the submission file.
    @param new_src_assignment: Complete replacement assignment statement.
    @return: The updated file text.
    """
    src_pattern = re.compile(r"^\s*CUDA_SRC\s*=\s*r?['\"]{3}.*?['\"]{3}\s*$", re.DOTALL | re.MULTILINE)
    if src_pattern.search(submission_content):
        # A callable replacement is inserted verbatim. Passing the string itself
        # would make re process its backslashes as escapes/group references,
        # corrupting C++ such as "\n" or raising re.error on sequences like "\d".
        return src_pattern.sub(lambda _match: new_src_assignment, submission_content, count=1)

    lines = submission_content.splitlines(keepends=True)
    for index, line in enumerate(lines):
        if line.strip().startswith("CUDA_SRC ="):
            lines[index] = new_src_assignment + "\n"
            return "".join(lines)

    # No CUDA_SRC assignment at all: put the new one at the top of the file.
    return new_src_assignment + "\n" + submission_content


# Check if the CUDA source file exists
if not cuda_file_path.exists():
    print(f"Error: CUDA source file not found at {cuda_file_path}", file=sys.stderr)
else:
    # Read the content of the CUDA source file
    try:
        with open(cuda_file_path, 'r') as f:
            cuda_code = f.read()
    except Exception as e:
        print(f"Error reading CUDA source file {cuda_file_path}: {e}", file=sys.stderr)
        cuda_code = None

    if cuda_code is not None:
        # Represent the multiline CUDA code as a raw string literal in Python.
        # NOTE(review): a kernel source that itself contains ''' would end the
        # generated literal early; that case is not handled here.
        new_src_assignment = f"CUDA_SRC = r'''\n{cuda_code}\n'''"

        if not submission_file_path.exists():
            print(f"Warning: Submission file {submission_file_path} not found. Creating it with the CUDA_SRC variable.", file=sys.stderr)
            try:
                with open(submission_file_path, 'w') as f:
                    f.write(new_src_assignment + "\n")
                print(f"Successfully created {submission_file_path} and wrote CUDA_SRC.")
            except Exception as e:
                print(f"Error creating or writing to {submission_file_path}: {e}", file=sys.stderr)
        else:
            # Read the submission file once; the whole text is all the splice needs.
            try:
                with open(submission_file_path, 'r') as f:
                    submission_content = f.read()
            except Exception as e:
                print(f"Error reading submission file {submission_file_path}: {e}", file=sys.stderr)
                submission_content = None

            if submission_content is not None:
                try:
                    updated_submission_content = _splice_cuda_src(submission_content, new_src_assignment)
                    with open(submission_file_path, 'w') as f:
                        f.write(updated_submission_content)
                    print(f"Successfully updated CUDA_SRC in {submission_file_path}")
                except Exception as e:
                    print(f"Error processing or writing to submission file {submission_file_path}: {e}", file=sys.stderr)

In [None]:


def _run_single_test(test: TestCase):
    """
    Runs a single test case: generates its input, times one call of the
    submission's `custom_kernel`, and checks the output.

    @param test: The test case; `test.args` is expanded into `generate_input`.
    @return: A tuple `(good, message, duration)`: the verdict and message from
             `wrap_check_implementation`, and the wall-clock time of the kernel
             call in milliseconds.
    """
    from hip_submission import custom_kernel

    data = generate_input(**test.args)
    # Clone outside the timed region so the copy is not billed to the kernel;
    # the untouched original is kept for the correctness check below.
    kernel_input = _clone_data(data)

    # GPU work is queued asynchronously: drain whatever input generation and
    # cloning left pending *before* starting the clock, otherwise it would be
    # counted as kernel time.
    torch.cuda.synchronize()
    start_time = time.perf_counter()  # monotonic, high-resolution clock
    submission_output = custom_kernel(kernel_input)
    torch.cuda.synchronize()
    end_time = time.perf_counter()

    duration = float((end_time - start_time) * 1e3)  # seconds -> milliseconds
    good, message = wrap_check_implementation(data, submission_output)
    return good, message, duration

def run_single_test(pool: multiprocessing.Pool, test: TestCase):
    """
    Execute `_run_single_test` for `test` in a worker process of `pool`,
    blocking until the worker's result is available, and return that result.
    """
    outcome = pool.apply(_run_single_test, args=(test,))
    return outcome
    


In [None]:

# Run every case of the 'tests' section of the task file in this process and
# print a per-case report followed by the accumulated duration.
seed = 42
set_seed(seed or 42)
tests_data = get_test_cases('./task.yml', 'tests', seed)

# Overall status flag; cleared as soon as one case fails.
passed = True
total_duration = 0.0
for idx, test in enumerate(tests_data):
    print(f"test.{idx}.name", test.spec)
    good, message, duration = _run_single_test(test)
    total_duration += duration
    if good:
        print(f"test.{idx}.status", "pass")
        print(f"test.{idx}.duration {duration:.4f}ms")
        print(f"test.{idx}.TFLOPS {test.FLOPs/duration*1e-9:.4f}")
        if message:
            print(f"test.{idx}.message", f"{message}")
        continue
    # Failure path: report the checker's message and remember the failure.
    print(f"test.{idx}.status", "fail")
    print(f"test.{idx}.error", message)
    passed = False
print(f"test.total_duration {total_duration:.4f}ms")


In [None]:
# Run every case of the 'benchmarks' section of the task file in this process
# and print a per-case report followed by the accumulated duration.
seed = 42
set_seed(seed or 42)
tests_data = get_test_cases('./task.yml', 'benchmarks', seed)

# Initialise the status flag here so this cell neither depends on nor inherits
# the outcome of a previously executed cell.
passed = True
total_duration = 0.0
for idx, test in enumerate(tests_data):
    print(f"test.{idx}.name", test.spec)
    good, message, duration = _run_single_test(test)
    total_duration += duration
    if not good:
        print(f"test.{idx}.status", "fail")
        print(f"test.{idx}.error", message)
        passed = False
    else:
        print(f"test.{idx}.status", "pass")
        print(f"test.{idx}.duration {duration:.4f}ms")
        # duration is in milliseconds: FLOPs / ms * 1e-9 == TFLOP/s.
        print(f"test.{idx}.TFLOPS {test.FLOPs/duration*1e-9:.4f}")
        if message:
            print(f"test.{idx}.message", f"{message}")
print(f"test.total_duration {total_duration:.4f}ms")

In [None]:

import numpy as np

# Small scratch example: a 2x4 matrix plus column- and row-shaped offset
# vectors of the kind used for broadcasting index arithmetic.
A = np.arange(8).reshape(2, 4)

offs_k = np.arange(4)
offs_n = np.arange(2)

# k has shape (4, 1) and n has shape (1, 2).
k = offs_k[:, np.newaxis]
n = offs_n[np.newaxis, :]


