# KRAIT GPU Executor - Updated with Compilation Support

This notebook monitors GitHub for kernel files and handles both compilation and execution requests.


In [None]:
# Install required dependencies
# `%pip` is an IPython magic, so this cell must be run inside a notebook kernel.
# nvidia-ml-py3 / pynvml are installed here but not imported by the cells shown below.
%pip install nvidia-ml-py3 pynvml
# gitpython provides the `git` module used to clone and pull the repository.
%pip install gitpython
# requests is used for every GitHub REST API call.
%pip install requests
# torch is used to report CUDA availability.
%pip install torch
%pip install numpy
# triton is installed for Triton kernels; the cells below only inspect kernel text.
%pip install triton


In [None]:
# Set up environment variables
# Run this cell to set your GitHub token
# Replace 'your_github_token_here' with your actual token
# NOTE(review): a token pasted here is stored in the notebook source, so do not
# commit or share the notebook with a real token filled in.
import os
os.environ['GITHUB_TOKEN'] = 'your_github_token_here'  # Replace with your actual token
print("✅ Environment variable set. You can now run the next cell.")


In [None]:
import subprocess
import json
import time
import os
import git
import requests
from pathlib import Path
import torch
import numpy as np
from datetime import datetime
import re
import base64

# Report what PyTorch can see; this is informational output only.
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    # The device name is only queried when CUDA is reported as available.
    print(f"Current GPU: {torch.cuda.get_device_name(0)}")

# GitHub API configuration
# These three values are interpolated into every REST URL built by the
# connection test and the upload helpers.
GITHUB_OWNER = "kcharvi"
GITHUB_REPO = "KRAIT"
GITHUB_API_BASE = "https://api.github.com"


In [None]:
# Set up GitHub token from environment variable
import os
GITHUB_TOKEN = os.getenv('GITHUB_TOKEN', 'YOUR_ACTUAL_GITHUB_TOKEN_HERE')

# Treat an empty value and both known placeholder strings as "no token":
# the fallback used above and the sample value written by the
# environment-setup cell when it is run without being edited.
if GITHUB_TOKEN in ('', 'YOUR_ACTUAL_GITHUB_TOKEN_HERE', 'your_github_token_here'):
    print("⚠️ WARNING: GITHUB_TOKEN environment variable not set!")
    print("Please set your GitHub token in the environment or replace the placeholder above.")
    # A `!export ...` shell command runs in a separate subshell and cannot
    # change this kernel's environment, so point the user at os.environ.
    print("You can set it by running: os.environ['GITHUB_TOKEN'] = 'your_token_here'")
else:
    # Never echo any part of the secret: notebook output is saved and shared.
    print("✅ GitHub token loaded from environment")

# Test GitHub API connection
def test_github_connection():
    """Check that the configured token can reach the repository via the GitHub API.

    Issues a single GET for the repository resource and prints the outcome.

    Returns:
        bool: True when the API answers with HTTP 200; False for any other
        status code or when the request raises (including a timeout).
    """
    try:
        url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}"
        headers = {
            "Authorization": f"token {GITHUB_TOKEN}",
            "Accept": "application/vnd.github.v3+json"
        }
        # Without a timeout a stalled connection would block the cell forever.
        response = requests.get(url, headers=headers, timeout=30)
        if response.status_code == 200:
            print("✅ GitHub API connection successful")
            return True
        else:
            print(f"❌ GitHub API connection failed: {response.status_code}")
            return False
    except Exception as e:
        # Best-effort probe: report the problem instead of aborting the cell.
        print(f"❌ GitHub API connection error: {e}")
        return False

# Test connection
# Runs once when the cell executes; the boolean result is not assigned to anything.
test_github_connection()


In [None]:
# Repository configuration
REPO_URL = "https://github.com/kcharvi/KRAIT.git"  # Replace with your actual repo URL
# Local working copy of the repository.
REPO_DIR = "/content/krait"
# Kernel sources are picked up from KERNELS_DIR and result JSON files are
# written to RESULTS_DIR, both inside the working copy.
KERNELS_DIR = f"{REPO_DIR}/gpu-executor/kernels"
RESULTS_DIR = f"{REPO_DIR}/gpu-executor/results"

# Clone or update repository
# `repo` stays at module level because the monitoring loop pulls through it.
if not os.path.exists(REPO_DIR):
    print(f"Cloning repository from {REPO_URL}")
    repo = git.Repo.clone_from(REPO_URL, REPO_DIR)
else:
    print(f"Updating existing repository")
    repo = git.Repo(REPO_DIR)
    # An error raised by this pull is not handled here and will abort the cell.
    repo.remotes.origin.pull()

# Create directories if they don't exist
os.makedirs(KERNELS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

print(f"Repository setup complete")
print(f"Kernels directory: {KERNELS_DIR}")
print(f"Results directory: {RESULTS_DIR}")


In [None]:
def detect_code_type(kernel_code):
    """Classify kernel source text as ``"triton"`` or ``"cuda"``.

    The source counts as Triton when it contains either the ``@triton.jit``
    decorator or an ``import triton`` statement. Anything else is reported as
    CUDA, whether or not it contains CUDA markers, because CUDA is the
    fallback for ``.cu`` files.
    """
    triton_markers = ("@triton.jit", "import triton")
    if any(marker in kernel_code for marker in triton_markers):
        return "triton"
    # No Triton marker found: CUDA is both the explicit match and the default.
    return "cuda"

def parse_metadata(kernel_content):
    """Extract request metadata from the comment header of a kernel file.

    Only the first 10 lines are inspected. Recognised markers are
    ``// Hardware:``, ``// Backend:``, ``// Timestamp:`` and ``// Type:``;
    the value is the text following the marker, stripped of whitespace.

    Args:
        kernel_content: Full text of the kernel file.

    Returns:
        dict with keys ``hardware``, ``backend``, ``timestamp`` (int) and
        ``type`` (``"execute"`` or ``"compile_only"``). Fields without a
        marker keep their defaults: "NVIDIA T4", "CUDA", the current Unix
        time and "execute".
    """
    metadata = {
        "hardware": "NVIDIA T4",
        "backend": "CUDA",
        "timestamp": int(time.time()),
        "type": "execute"  # "execute" or "compile_only"
    }

    # Checked in this order; at most one marker is applied per line.
    markers = (
        ("// Hardware:", "hardware"),
        ("// Backend:", "backend"),
        ("// Timestamp:", "timestamp"),
        ("// Type:", "type"),
    )

    for line in kernel_content.split('\n')[:10]:  # Check first 10 lines for metadata
        for marker, key in markers:
            if marker not in line:
                continue
            # Split on the marker itself so an earlier ':' on the same line
            # cannot shift where the value starts.
            value = line.split(marker, 1)[1].strip()
            if key == "timestamp":
                try:
                    metadata[key] = int(value)
                except ValueError:
                    # Malformed timestamp: keep the default generated above.
                    pass
            else:
                metadata[key] = value
            break

    return metadata

def clean_kernel_code(kernel_content):
    """Strip the request header from a kernel file and hoist ``#define`` lines.

    A line equal to ``// COMPILATION REQUEST`` or ``// EXECUTION REQUEST``
    opens a header. That line, and every following blank or ``//`` comment
    line, is dropped until the first line of real code. Of the lines that
    are kept, all ``#define`` lines are moved to the top in their original
    order, followed by everything else. The joined text is returned with
    surrounding whitespace stripped.
    """
    request_markers = ("// COMPILATION REQUEST", "// EXECUTION REQUEST")
    hoisted_defines = []
    remaining = []
    in_header = False

    for raw_line in kernel_content.split('\n'):
        text = raw_line.strip()

        if text in request_markers:
            in_header = True
            continue

        if in_header:
            if not text or text.startswith("//"):
                # Still inside the metadata header: discard the line.
                continue
            # First real code line ends the header and is itself kept.
            in_header = False

        target = hoisted_defines if text.startswith("#define") else remaining
        target.append(raw_line)

    return '\n'.join(hoisted_defines + remaining).strip()


In [None]:
def compile_cuda_kernel(kernel_file_path, kernel_content):
    """Compile a CUDA source file with ``nvcc`` and report the outcome.

    The compiler is invoked with an argument list (no shell), so a path
    containing spaces or shell metacharacters is passed to ``nvcc`` verbatim
    instead of being interpreted by a shell. The binary is written to
    ``kernel_test`` in the current working directory.

    Args:
        kernel_file_path: Path of the ``.cu`` file to compile.
        kernel_content: Kernel source text. Unused here; kept so the
            signature stays compatible with existing callers.

    Returns:
        dict with ``success``, ``provider`` ("colab") and ``timestamp``. On
        success it also carries ``message`` and ``warnings`` (nvcc's stderr
        text, or an empty list when there is none). On failure it carries
        ``error``. A timeout (60 s) or any other exception, such as ``nvcc``
        not being installed, is reported as a failure rather than raised.
    """
    def failure(error_msg):
        # Common shape of every unsuccessful result.
        return {
            "success": False,
            "error": error_msg,
            "provider": "colab",
            "timestamp": time.time()
        }

    try:
        print(f"Compiling CUDA kernel: {kernel_file_path}")

        # Compile kernel
        compile_cmd = [
            "nvcc",
            "-o", "kernel_test",
            str(kernel_file_path),
            "-lnvToolsExt",
            "--ptxas-options=-v",
        ]
        print(f"Compilation command: {' '.join(compile_cmd)}")

        result = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=60)

        if result.returncode == 0:
            print("✅ CUDA compilation successful")
            return {
                "success": True,
                "message": "CUDA compilation successful",
                "warnings": result.stderr if result.stderr else [],
                "provider": "colab",
                "timestamp": time.time()
            }

        print(f"❌ CUDA compilation failed: {result.stderr}")
        return failure(f"CUDA compilation failed: {result.stderr}")

    except subprocess.TimeoutExpired:
        error_msg = "CUDA compilation timeout (60s)"
        print(f"❌ {error_msg}")
        return failure(error_msg)
    except Exception as e:
        # Includes FileNotFoundError when nvcc is not on PATH.
        error_msg = f"CUDA compilation error: {str(e)}"
        print(f"❌ {error_msg}")
        return failure(error_msg)

def compile_triton_kernel(kernel_content):
    """Run a textual sanity check on Triton kernel source.

    Nothing is compiled. The source passes only when it contains both the
    ``@triton.jit`` decorator and an ``import triton`` statement.

    Returns:
        dict with ``success``, then ``message`` (on success) or ``error``
        (on failure), followed by ``provider`` ("colab") and ``timestamp``.
        An exception raised during the check is reported as a failure.
    """
    required_markers = ("@triton.jit", "import triton")

    try:
        print(f"Validating Triton kernel syntax")

        # Basic syntax validation: every marker must be present.
        if not all(marker in kernel_content for marker in required_markers):
            error_msg = "Invalid Triton syntax: missing @triton.jit decorator or import triton"
            print(f"❌ {error_msg}")
            outcome = {"success": False, "error": error_msg}
        else:
            print("✅ Triton syntax validation successful")
            outcome = {"success": True, "message": "Triton syntax validation successful"}

    except Exception as e:
        error_msg = f"Triton validation error: {str(e)}"
        print(f"❌ {error_msg}")
        outcome = {"success": False, "error": error_msg}

    # Fields shared by every result, appended last to keep the key order.
    outcome["provider"] = "colab"
    outcome["timestamp"] = time.time()
    return outcome


In [None]:
def upload_to_github_git_api_fixed(file_path, content, commit_message):
    """Commit one text file to the ``main`` branch through GitHub's Git data API.

    Sequence of calls: read ``refs/heads/main`` to find the head commit, read
    that commit to find its tree, create a blob with the new content, create
    a tree based on the current one that includes the file, create a commit
    whose parent is the head commit, then move the branch ref to it.

    Every request carries a timeout so a stalled connection cannot block the
    caller indefinitely; a timeout is reported like any other error.

    Args:
        file_path: Repository-relative path of the file to write.
        content: File content as text; it is UTF-8 encoded before upload.
        commit_message: Message for the new commit.

    Returns:
        bool: True when every API call returned the expected status code,
        False otherwise (the failing response or exception is printed).
    """
    request_timeout = 30  # seconds, applied to every API call below

    try:
        print(f"🔄 Uploading {file_path} using Git API...")

        headers = {
            "Authorization": f"token {GITHUB_TOKEN}",
            "Accept": "application/vnd.github.v3+json"
        }

        # Get the latest commit from the main branch
        ref_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/refs/heads/main"
        print("Getting latest commit from main branch...")
        ref_response = requests.get(ref_url, headers=headers, timeout=request_timeout)
        if ref_response.status_code != 200:
            print(f"❌ Failed to get branch reference: {ref_response.text}")
            return False

        latest_commit_sha = ref_response.json()['object']['sha']
        print(f"Latest commit SHA: {latest_commit_sha}")

        # Get the commit details
        commit_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/commits/{latest_commit_sha}"
        print("Getting commit details...")
        commit_response = requests.get(commit_url, headers=headers, timeout=request_timeout)
        if commit_response.status_code != 200:
            print(f"❌ Failed to get commit: {commit_response.text}")
            return False

        commit_data = commit_response.json()
        current_tree_sha = commit_data['tree']['sha']
        print(f"Current tree SHA: {current_tree_sha}")

        # Create blob with content
        content_b64 = base64.b64encode(content.encode('utf-8')).decode('utf-8')
        blob_data = {
            "content": content_b64,
            "encoding": "base64"
        }

        blob_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/blobs"
        print("🔍 Creating blob...")
        blob_response = requests.post(blob_url, headers=headers, json=blob_data, timeout=request_timeout)
        if blob_response.status_code != 201:
            print(f"❌ Failed to create blob: {blob_response.text}")
            return False

        blob_sha = blob_response.json()['sha']
        print(f"Blob SHA: {blob_sha}")

        # Get current tree
        tree_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/trees/{current_tree_sha}"
        print("🔍 Getting current tree...")
        tree_response = requests.get(tree_url, headers=headers, timeout=request_timeout)
        if tree_response.status_code != 200:
            print(f"❌ Failed to get tree: {tree_response.text}")
            return False

        tree_data = tree_response.json()
        tree_items = tree_data['tree']

        # Add our new file to the tree
        # NOTE(review): the tree is requested without a recursive flag, so
        # presumably only top-level entries come back and a nested file_path
        # never matches below. The file is then added via the "new file"
        # branch on top of base_tree. Confirm against the GitHub trees API.
        new_tree_items = []
        file_added = False

        for item in tree_items:
            if item['path'] == file_path:
                # Update existing file
                new_tree_items.append({
                    "path": file_path,
                    "mode": "100644",
                    "type": "blob",
                    "sha": blob_sha
                })
                file_added = True
                print(f"📝 Updating existing file: {file_path}")
            else:
                new_tree_items.append(item)

        if not file_added:
            # Add new file
            new_tree_items.append({
                "path": file_path,
                "mode": "100644",
                "type": "blob",
                "sha": blob_sha
            })
            print(f"Adding new file: {file_path}")

        # Create new tree
        new_tree_data = {
            "base_tree": current_tree_sha,
            "tree": new_tree_items
        }

        new_tree_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/trees"
        print("🔍 Creating new tree...")
        new_tree_response = requests.post(new_tree_url, headers=headers, json=new_tree_data, timeout=request_timeout)
        if new_tree_response.status_code != 201:
            print(f"❌ Failed to create tree: {new_tree_response.text}")
            return False

        new_tree_sha = new_tree_response.json()['sha']
        print(f"New tree SHA: {new_tree_sha}")

        # Create new commit
        new_commit_data = {
            "message": commit_message,
            "tree": new_tree_sha,
            "parents": [latest_commit_sha]
        }

        new_commit_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/commits"
        print("🔍 Creating new commit...")
        new_commit_response = requests.post(new_commit_url, headers=headers, json=new_commit_data, timeout=request_timeout)
        if new_commit_response.status_code != 201:
            print(f"❌ Failed to create commit: {new_commit_response.text}")
            return False

        new_commit_sha = new_commit_response.json()['sha']
        print(f"New commit SHA: {new_commit_sha}")

        # Update branch reference with force update
        # NOTE(review): with "force" set, a commit pushed to main between the
        # ref lookup above and this PATCH would presumably be dropped from the
        # branch. Confirm whether a non-forced update plus retry is wanted.
        ref_data = {
            "sha": new_commit_sha,
            "force": True
        }

        print("🔍 Updating branch reference...")
        ref_response = requests.patch(ref_url, headers=headers, json=ref_data, timeout=request_timeout)
        if ref_response.status_code != 200:
            print(f"❌ Failed to update branch: {ref_response.text}")
            return False

        print(f"✅ Successfully uploaded to GitHub: {file_path}")
        return True

    except Exception as e:
        # Network errors, timeouts and unexpected response shapes all end here.
        print(f"❌ Error uploading to GitHub: {e}")
        return False

# Test the fixed Git API upload function
# NOTE: this executes whenever the cell runs and attempts a real upload of a
# small JSON file to gpu-executor/results/ in the configured repository.
print("🔧 Testing Fixed Git API upload...")
test_content = '{"test": "fixed git api upload", "success": true}'
test_result = upload_to_github_git_api_fixed("gpu-executor/results/test_fixed_git_api.json", test_content, "Test Fixed Git API upload")
print(f"Test upload result: {test_result}")


In [None]:
def upload_to_github_git_api(file_path, content, commit_message):
    """Upload a text file to ``main`` using the Git data API instead of the Contents API.

    The parent commit is resolved from ``refs/heads/main`` at call time. A
    fixed commit SHA would make every new commit a non-fast-forward update
    as soon as the branch has moved past that commit. The branch ref is
    updated without ``force``, so a concurrent push makes the final step
    fail (and this function return False) rather than overwrite that push.

    Every request carries a timeout so a stalled connection cannot block the
    caller indefinitely.

    Args:
        file_path: Repository-relative path of the file to write.
        content: File content as text; it is UTF-8 encoded before upload.
        commit_message: Message for the new commit.

    Returns:
        bool: True when every API call returned the expected status code,
        False otherwise (the failing response or exception is printed).
    """
    request_timeout = 30  # seconds, applied to every API call below

    try:
        print(f"🔄 Uploading {file_path} using Git API...")

        headers = {
            "Authorization": f"token {GITHUB_TOKEN}",
            "Accept": "application/vnd.github.v3+json"
        }

        # Resolve the current head of main; it becomes the parent commit.
        ref_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/refs/heads/main"
        print("Getting current commit...")
        head_response = requests.get(ref_url, headers=headers, timeout=request_timeout)
        if head_response.status_code != 200:
            print(f"❌ Failed to get branch reference: {head_response.text}")
            return False

        head_commit_sha = head_response.json()['object']['sha']

        # Get current commit
        commit_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/commits/{head_commit_sha}"
        response = requests.get(commit_url, headers=headers, timeout=request_timeout)
        if response.status_code != 200:
            print(f"❌ Failed to get commit: {response.text}")
            return False

        commit_data = response.json()
        current_tree_sha = commit_data['tree']['sha']
        print(f"Current tree SHA: {current_tree_sha}")

        # Create blob with content
        content_b64 = base64.b64encode(content.encode('utf-8')).decode('utf-8')
        blob_data = {
            "content": content_b64,
            "encoding": "base64"
        }

        blob_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/blobs"
        print("🔍 Creating blob...")
        blob_response = requests.post(blob_url, headers=headers, json=blob_data, timeout=request_timeout)
        if blob_response.status_code != 201:
            print(f"❌ Failed to create blob: {blob_response.text}")
            return False

        blob_sha = blob_response.json()['sha']
        print(f"Blob SHA: {blob_sha}")

        # Get current tree
        tree_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/trees/{current_tree_sha}"
        print("🔍 Getting current tree...")
        tree_response = requests.get(tree_url, headers=headers, timeout=request_timeout)
        if tree_response.status_code != 200:
            print(f"❌ Failed to get tree: {tree_response.text}")
            return False

        tree_data = tree_response.json()
        tree_items = tree_data['tree']

        # Add our new file to the tree
        new_tree_items = []
        file_added = False

        for item in tree_items:
            if item['path'] == file_path:
                # Update existing file
                new_tree_items.append({
                    "path": file_path,
                    "mode": "100644",
                    "type": "blob",
                    "sha": blob_sha
                })
                file_added = True
                print(f"📝 Updating existing file: {file_path}")
            else:
                new_tree_items.append(item)

        if not file_added:
            # Add new file
            new_tree_items.append({
                "path": file_path,
                "mode": "100644",
                "type": "blob",
                "sha": blob_sha
            })
            print(f"Adding new file: {file_path}")

        # Create new tree
        new_tree_data = {
            "base_tree": current_tree_sha,
            "tree": new_tree_items
        }

        new_tree_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/trees"
        print("🔍 Creating new tree...")
        new_tree_response = requests.post(new_tree_url, headers=headers, json=new_tree_data, timeout=request_timeout)
        if new_tree_response.status_code != 201:
            print(f"❌ Failed to create tree: {new_tree_response.text}")
            return False

        new_tree_sha = new_tree_response.json()['sha']
        print(f"New tree SHA: {new_tree_sha}")

        # Create new commit
        new_commit_data = {
            "message": commit_message,
            "tree": new_tree_sha,
            "parents": [commit_data['sha']]
        }

        new_commit_url = f"{GITHUB_API_BASE}/repos/{GITHUB_OWNER}/{GITHUB_REPO}/git/commits"
        print("🔍 Creating new commit...")
        new_commit_response = requests.post(new_commit_url, headers=headers, json=new_commit_data, timeout=request_timeout)
        if new_commit_response.status_code != 201:
            print(f"❌ Failed to create commit: {new_commit_response.text}")
            return False

        new_commit_sha = new_commit_response.json()['sha']
        print(f"New commit SHA: {new_commit_sha}")

        # Update branch reference
        ref_data = {
            "sha": new_commit_sha
        }

        print("🔍 Updating branch reference...")
        ref_response = requests.patch(ref_url, headers=headers, json=ref_data, timeout=request_timeout)
        if ref_response.status_code != 200:
            print(f"❌ Failed to update branch: {ref_response.text}")
            return False

        print(f"✅ Successfully uploaded to GitHub: {file_path}")
        return True

    except Exception as e:
        # Network errors, timeouts and unexpected response shapes all end here.
        print(f"❌ Error uploading to GitHub: {e}")
        return False

In [None]:
def process_kernel_file_final(kernel_file):
    """Process one kernel request file end to end.

    Steps: read the file, parse its metadata header, strip the header,
    overwrite the file with the cleaned code, compile (CUDA) or validate
    (Triton) it, write the result JSON into RESULTS_DIR, upload the result
    (and, on success, the cleaned kernel) to GitHub, then delete the local
    kernel file.

    Args:
        kernel_file: ``pathlib.Path`` of the kernel file to process.

    Returns:
        bool: True when the pipeline ran to completion, even if compilation
        or an upload failed. False when an exception interrupted it; in
        that case an error result JSON is written locally on a best-effort
        basis and the kernel file is left in place.
    """
    # Initialised up front so the error handler can tell whether the file
    # was read before the failure happened.
    kernel_content = None

    try:
        print(f"\n--- Processing kernel: {kernel_file.name} ---")
        print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

        # Read kernel content
        with open(kernel_file, 'r', encoding='utf-8') as f:
            kernel_content = f.read()

        # Parse metadata
        metadata = parse_metadata(kernel_content)
        print(f"Metadata: {metadata}")

        # Clean kernel code
        clean_code = clean_kernel_code(kernel_content)

        # Detect code type
        code_type = detect_code_type(clean_code)
        print(f"Detected code type: {code_type}")

        # Write cleaned code to file for compilation
        with open(kernel_file, 'w', encoding='utf-8') as f:
            f.write(clean_code)

        # Process based on request type
        compile_only = metadata["type"] == "compile_only"
        if compile_only:
            print("Processing compilation request...")
        else:
            # For execution, we'll just do compilation for now
            print("Processing execution request...")

        if code_type == "cuda":
            result = compile_cuda_kernel(str(kernel_file), clean_code)
        else:
            result = compile_triton_kernel(clean_code)

        # Add corrected code to result
        succeeded = result.get("success", False)
        if succeeded:
            result["corrected_code"] = clean_code

        # Determine result filename based on request type
        prefix = "compile" if compile_only else "kernel"
        result_file = f"{RESULTS_DIR}/{prefix}_{metadata['timestamp']}_result.json"

        # Serialise once and reuse the text for the local file, the log
        # output and the upload.
        result_content = json.dumps(result, indent=2)

        # Save result locally
        with open(result_file, 'w') as f:
            f.write(result_content)

        print(f"Result saved locally to: {result_file}")
        print(f"Result: {result_content}")

        # Upload result to GitHub using Git API function
        result_path = f"gpu-executor/results/{os.path.basename(result_file)}"
        upload_success = upload_to_github_git_api_fixed(result_path, result_content, f"Result {metadata['timestamp']}")

        # If compilation was successful, also save the corrected kernel code to GitHub
        if succeeded:
            corrected_kernel_path = f"gpu-executor/kernels/corrected_{metadata['timestamp']}.cu"
            upload_success = upload_to_github_git_api_fixed(corrected_kernel_path, clean_code, f"Corrected kernel {metadata['timestamp']}") and upload_success

        if upload_success:
            print(f"✅ All uploads successful")
        else:
            print(f"⚠️ Some uploads failed, but processing complete")

        # Wait a bit to ensure backend can fetch the result
        time.sleep(5)

        # Remove processed kernel file locally
        kernel_file.unlink()
        print(f"Kernel file removed locally: {kernel_file.name}")

        # Note: Backend handles GitHub cleanup automatically
        print(f"ℹ️ Backend will handle GitHub cleanup automatically")

        print(f"--- Processing complete ---\n")

        return True

    except Exception as e:
        error_msg = f"Error processing {kernel_file.name}: {str(e)}"
        print(f"❌ {error_msg}")

        # Save error result
        try:
            if kernel_content is not None:
                metadata = parse_metadata(kernel_content)
            else:
                # The file was never read, so no request timestamp is known.
                metadata = {"timestamp": int(time.time())}
            error_result = {
                "success": False,
                "error": error_msg,
                "provider": "colab",
                "timestamp": time.time()
            }

            result_file = f"{RESULTS_DIR}/kernel_{metadata['timestamp']}_result.json"
            with open(result_file, 'w') as f:
                json.dump(error_result, f, indent=2)

            print(f"Error result saved to: {result_file}")
        except Exception as save_error:
            # Best effort only, but do not swallow KeyboardInterrupt/SystemExit
            # and do not hide why the save failed.
            print(f"Failed to save error result: {save_error}")

        return False


In [None]:
def monitor_kernels_final():
    """Poll the local clone for new ``.cu`` kernel files and process them.

    Each cycle pulls from ``origin``, lists ``*.cu`` files in KERNELS_DIR and
    hands every file whose name has not been processed successfully to
    ``process_kernel_file_final``. A file that fails processing is not
    recorded, so it is attempted again on the next cycle. The loop sleeps
    30 seconds between cycles and runs until interrupted with
    KeyboardInterrupt; it returns nothing.
    """
    print(f"🚀 Starting KRAIT GPU Executor - Final Version with Git API")
    print(f"📁 Monitoring for both compilation and execution requests")
    print(f"⚡ Ready to process kernels...")
    print(f"Watching directory: {KERNELS_DIR}")

    processed_files = set()
    git_error_count = 0
    max_git_errors = 5

    while True:
        try:
            # Pull latest changes from GitHub
            try:
                repo.remotes.origin.pull()
                git_error_count = 0  # Reset error count on success
            except Exception as e:
                git_error_count += 1
                # Handle broken pipe and other Git errors gracefully
                if "Broken pipe" in str(e) or "Errno 32" in str(e):
                    print(f"Git connection issue ({git_error_count}/{max_git_errors}): {e}")
                    # Try to reinitialize the connection
                    try:
                        repo.remotes.origin.fetch()
                    except Exception as fetch_error:
                        # Best effort: report and carry on. Catching Exception
                        # (not everything) lets KeyboardInterrupt reach the
                        # outer handler so the loop can still be stopped here.
                        print(f"Git fetch retry failed: {fetch_error}")
                else:
                    print(f"Warning: Failed to pull from GitHub ({git_error_count}/{max_git_errors}): {e}")

                # If too many Git errors, skip this cycle
                # NOTE(review): this only adds a longer pause; the local file
                # scan below still runs afterwards. Confirm whether a real
                # skip (continue) was intended.
                if git_error_count >= max_git_errors:
                    print("Too many Git errors, skipping this cycle...")
                    time.sleep(60)  # Wait longer before retrying
                    git_error_count = 0

            # Note: Backend handles cleanup automatically

            # Check for new kernel files
            kernel_files = list(Path(KERNELS_DIR).glob("*.cu"))

            for kernel_file in kernel_files:
                if kernel_file.name not in processed_files:
                    success = process_kernel_file_final(kernel_file)
                    if success:
                        processed_files.add(kernel_file.name)

            if not kernel_files:
                print(".", end="", flush=True)  # Show activity

            time.sleep(30)  # Check every 30 seconds to reduce Git load

        except KeyboardInterrupt:
            print("\nMonitoring stopped by user")
            break
        except Exception as e:
            print(f"\nError in monitoring loop: {e}")
            time.sleep(30)  # Wait longer on error


# Start final monitoring
# This call loops until the cell is interrupted (KeyboardInterrupt), so any
# code placed after it only runs once monitoring has been stopped.
print("\n🚀 Starting Final Monitoring with Git API Upload...")
monitor_kernels_final()
