In [None]:
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install accelerate transformers einops datasets

In [None]:
import re
import csv
import cmath
import json
import builtins
import math
import ast
import torch
import pandas as pd
import numpy as np
import sympy as sp
from math import pi, exp
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sympy import (
    sympify, lambdify, Function, Eq, Derivative, integrate, Sum, sqrt, exp, 
    log, sin, cos, tan, cot, sec, csc, pi, symbols, diff, Rational, factorial, 
    expand, factor, trigsimp, I, parse_expr, latex)

In [None]:
# Cache directory handed to transformers when loading the checkpoint
cache_dir = '/workspace/models'
# The fine-tuned model and its tokenizer are both read from the same checkpoint directory
checkpoint_dir = "/workspace/saad/project/try3/codet5_finetuned_latex_python_amdu2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, cache_dir=cache_dir)
# The model is moved to the GPU right after loading
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir, cache_dir=cache_dir).to('cuda')

In [None]:
# Parse the JSON test file into `test_data`
test_data_path = '/workspace/SKYLABS PROJECT_dataset/private_test_new_no_sol_no_out.json'
with open(test_data_path) as test_file:
    test_data = json.load(test_file)

In [None]:
# Attach a short alias `ds` for `dsolve` onto the sympy module object itself
sp.ds = sp.dsolve
# Rebind the module-level name `pi` to a plain float; this replaces the `pi`
# objects imported from `math` and `sympy` in the import cell
pi = 3.14159265358979
# Module-level shortcuts to sympy's Fresnel functions
fresnels = sp.fresnels
fresnelc = sp.fresnelc

# Bind every lowercase letter a..z to a sympy Symbol of the same name.
# NOTE(review): presumably these globals exist so that generated code executed
# later can reference bare single-letter names — confirm. They also shadow any
# earlier module-level use of the same letters.
a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z = sp.symbols('a:z')

# Define the additional symbols
x_symb, y_symb, z_symb, d1, d2 = sp.symbols('x_symb y_symb z_symb d1 d2')

def extract_function_name_and_params(code_str):
    """Pull the first function name and its raw parameter list out of source text.

    Args:
        code_str: Python source text expected to contain a ``def name(...):`` header.

    Returns:
        A ``(function_name, params)`` tuple. ``params`` holds the comma-separated
        fragments of the signature exactly as written (so a default stays attached,
        e.g. ``"n=2"``), followed by any exponent name found in a
        ``sum(<base>**<name> for <var> in range`` expression that is not already
        listed. When no ``def`` header is found the result is ``(None, [])``.
    """
    header = re.search(r'def\s+(\w+)\s*\((.*?)\):', code_str)
    if header is None:
        return None, []

    name, raw_signature = header.groups()

    # Keep only the non-empty, whitespace-trimmed signature fragments
    params = []
    for fragment in raw_signature.split(','):
        fragment = fragment.strip()
        if fragment:
            params.append(fragment)

    # Exponent names used inside sum(... for ... in range ...) are treated as parameters too
    exponent_names = re.findall(r'sum\(.+?\*\*(\w+)\s+for\s+\w+\s+in\s+range', code_str)
    for exponent in exponent_names:
        if exponent not in params:
            params.append(exponent)

    return name, params
    

def evaluate_sum_expressions(code_str):
    """Replace ``Sum(<base>**<var>, (<var>, <start>, <end>))`` with its expanded value.

    Only sums whose bounds are integer literals are matched. Each match is
    evaluated term by term with the summation variable bound to the current
    index, and the whole ``Sum(...)`` text is replaced by the parenthesised
    string form of the total.

    Args:
        code_str: Source text that may contain ``Sum(...)`` expressions.

    Returns:
        ``code_str`` with every matching ``Sum(...)`` replaced; text without a
        match is returned unchanged.

    Raises:
        Whatever evaluating a term raises (e.g. ``NameError`` for an unknown
        name in the base); callers are expected to handle it.
    """
    sum_pattern = r'Sum\((.*?)\*\*(\w+),\s*\(\2,\s*(\d+),\s*(\d+)\)\)'

    def evaluate_sum(match):
        base = match.group(1)
        var = match.group(2)
        start = int(match.group(3))
        end = int(match.group(4))

        # Evaluate the term as written and bind the summation variable explicitly.
        # This covers bases that mention the variable themselves (k*2**k) and keeps
        # the index from shadowing an unrelated module-level name used in the base.
        term = f"{base}**{var}"
        scope = globals()
        # SECURITY: eval runs text taken from generated code; only use on trusted input.
        total = sum(eval(term, scope, {var: index}) for index in range(start, end + 1))

        # Parenthesise so a multi-term or negative total keeps its meaning when the
        # Sum was an operand of a larger expression (e.g. 2*Sum(...)).
        return f"({total})"

    return re.sub(sum_pattern, evaluate_sum, code_str)

def _result_to_str(result):
    """Turn a raw function result into the string form used by the pipeline."""
    if isinstance(result, sp.Expr):
        result = result.evalf()

    # Plain numbers are normalised to a complex value rounded to 6 decimals
    if isinstance(result, (int, float, complex, sp.Number)):
        as_complex = complex(result)
        result = complex(round(as_complex.real, 6), round(as_complex.imag, 6))

    # sympy prints the imaginary unit as 'I'; the pipeline uses 'j'
    return str(result).replace('I', 'j')


def _finalize_result_str(result_str):
    """Drop a trailing '+ 0j' / '- 0j' and return a float when the text parses as one."""
    if result_str.endswith('+ 0j') or result_str.endswith('- 0j'):
        result_str = result_str[:-4]
    try:
        return float(result_str)
    except ValueError:
        return result_str


def _symbolic_x_fallback(func, filtered_input, x_val):
    """Re-run ``func`` with a symbolic ``x`` and substitute the numeric value afterwards."""
    x = sp.Symbol('x')
    symbolic_args = {k: (x if k == 'x' else v) for k, v in filtered_input.items()}
    result = func(**symbolic_args).subs(x, x_val).evalf()
    result_str = _result_to_str(result)

    # If placeholder symbols are still present, substitute any that match an input name
    if any(marker in result_str for marker in ('_symb', 'Symbol')):
        expr = sp.sympify(result_str)
        subs_dict = {}
        for symbol in expr.free_symbols:
            symbol_name = str(symbol).replace('_symb', '')
            if symbol_name in filtered_input:
                subs_dict[symbol] = filtered_input[symbol_name]
        if subs_dict:
            result_str = str(expr.subs(subs_dict).evalf())

    return _finalize_result_str(result_str)


def run_python_code(code_str, test_cases):
    """Execute generated code and evaluate its first function on every test case.

    The function header is rewritten so that every parameter absent from the
    first test case's input gets a default of ``1``; parenthesis counts are
    balanced; ``Sum(...)`` expressions are expanded; and the code is executed
    in a copy of this module's globals so generated definitions cannot
    overwrite notebook state.

    Args:
        code_str: Generated Python source expected to define one function.
        test_cases: Sequence of dicts, each with an ``'input'`` mapping of
            keyword arguments.

    Returns:
        A list with one entry per test case: a ``float`` when the result text
        parses as one, otherwise a string (a formatted value, or
        ``"Error: ..."``). If compilation/execution fails, every entry is the
        same error string; if the function cannot be found, every entry is
        ``"Function not found"``. An empty ``test_cases`` yields ``[]``.
    """
    if not test_cases:
        return []

    function_name, params = extract_function_name_and_params(code_str)

    # Classify parameters by *name*, so fragments such as "n=5" or "x: int"
    # are still recognised as the inputs "n" / "x".
    first_input = test_cases[0].get('input', {})
    param_names = []
    required, defaulted = [], []
    for raw in params:
        name = re.split(r'[:=]', raw, maxsplit=1)[0].strip()
        param_names.append(name)
        if '=' in raw:
            defaulted.append(raw)  # Keep existing default values
        elif name in first_input:
            required.append(raw)  # Supplied by the test input
        else:
            defaulted.append(f"{raw}=1")  # Add default of 1 if not specified

    # Required parameters go first: the function is only ever called with
    # keyword arguments, and a required parameter after a defaulted one would
    # be a SyntaxError.
    new_param_str = ', '.join(required + defaulted)

    if function_name is not None:
        new_header = f'def {function_name}({new_param_str}):'
        # Rewrite only the first header (the one that was parsed); a callable
        # replacement avoids backslash interpretation in the new text.
        code_str = re.sub(r'def\s+\w+\s*\(.*?\):', lambda _match: new_header, code_str, count=1)

    # Check and fix parentheses: append missing ')' or drop surplus ones from the end
    surplus_close = code_str.count(')') - code_str.count('(')
    if surplus_close < 0:
        code_str += ')' * (-surplus_close)
    else:
        for _ in range(surplus_close):
            last_close_idx = code_str.rfind(')')
            code_str = code_str[:last_close_idx] + code_str[last_close_idx + 1:]

    # Execute in a copy of the module globals: generated code can read the
    # shared symbols/helpers but cannot overwrite them for later items.
    namespace = dict(globals())
    try:
        # Evaluate all Sum expressions with powers
        code_str = evaluate_sum_expressions(code_str)
        exec(compile(code_str, '<string>', 'exec'), namespace)
    except Exception as e:  # SyntaxError is an Exception subclass
        return [f"Error: {e}"] * len(test_cases)

    func = namespace.get(function_name)
    if not func:
        return ["Function not found"] * len(test_cases)

    results = []
    for test_case in test_cases:
        filtered_input = {}
        try:
            filtered_input = {k: v for k, v in test_case['input'].items() if k in param_names}
            result = func(**filtered_input)
            results.append(_finalize_result_str(_result_to_str(result)))
        except Exception as e:
            error_message = str(e)

            if "Can't calculate derivative wrt" not in error_message:
                print(f"Error in test case {test_case}: {e}")
                results.append(f"Error: {str(e)}")
                continue

            # Differentiating w.r.t. a number fails; retry with a symbolic x
            x_val = filtered_input.get('x', None)
            if x_val is None:
                results.append(f"Error: {error_message}")
                continue

            try:
                results.append(_symbolic_x_fallback(func, filtered_input, x_val))
            except Exception as fallback_error:
                # Keep the failure local to this test case instead of aborting the item
                print(f"Error in test case {test_case}: {fallback_error}")
                results.append(f"Error: {str(fallback_error)}")

    return results

def process_batch(batch):
    """Generate code for each item, run it on the item's test cases, and format the results.

    Args:
        batch: Iterable of dicts with ``'latex_expression'``, ``'test_cases'``
            and ``'task_id'`` keys.

    Returns:
        A list of ``{'task_id': ..., 'outputs': [...]}`` dicts, one per item, where
        every output is a string. If anything fails for an item, its outputs are
        an ``"Error: ..."`` string repeated once per test case.
    """
    def _format_value(value):
        # Floats and complex numbers are rendered with six decimals; anything else via str()
        if isinstance(value, float):
            return f"{value:.6f}"
        if isinstance(value, complex):
            real = round(value.real, 6)
            imag = round(value.imag, 6)
            if imag == 0:
                return f"{real:.6f}"
            sign = '+' if imag >= 0 else '-'
            return f"{real:.6f}{sign}{abs(imag):.6f}j"
        return str(value)

    collected = []
    for item in batch:
        try:
            prompt = "Convert LaTeX to Python: " + item['latex_expression']
            encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to('cuda')

            # Inference only, so gradient tracking is disabled
            with torch.no_grad():
                generated = model.generate(**encoded, max_length=512)

            generated_code = tokenizer.decode(generated[0], skip_special_tokens=True)
            print("Generated code:")
            print(generated_code)

            raw_results = run_python_code(generated_code, item['test_cases'])
            formatted_results = [_format_value(value) for value in raw_results]

            entry = {
                'task_id': item['task_id'],
                'outputs': formatted_results
            }
        except Exception as e:
            print(f"Error processing item {item['task_id']}: {str(e)}")
            entry = {
                'task_id': item['task_id'],
                'outputs': [f"Error: {str(e)}"] * len(item['test_cases'])
            }
        collected.append(entry)

    return collected

# Accumulator for every item's results, plus the slice size used to walk the test data
results_list = []
batch_size = 8  # Adjust based on your GPU memory
# Walk the test data one slice at a time and collect what process_batch returns
for start_idx in range(0, len(test_data), batch_size):
    results_list.extend(process_batch(test_data[start_idx:start_idx + batch_size]))
    print(f"Processed batch starting at index {start_idx}")
    
# Write the first-stage CSV. Note: There are 3 files made, Final_Results.csv is the submission file
csv_filename = 'results1.csv'
with open(csv_filename, 'w', newline='') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=['id', 'outputs'])
    writer.writeheader()
    for entry in results_list:
        # Anything that is not already a string is rendered with six decimals
        rendered = [value if isinstance(value, str) else f"{value:.6f}" for value in entry['outputs']]
        # The whole list is stored as its string representation in one CSV cell
        writer.writerow({'id': entry['task_id'], 'outputs': str(rendered)})

# Read the file that was just written back into a DataFrame
df = pd.read_csv('results1.csv')
# Function to clean and handle errors
def process_output(output):
    """Normalise one 'outputs' cell read back from the results CSV.

    Args:
        output: The cell value; expected to be the string form of a Python list.
            Non-string values are passed through untouched.

    Returns:
        * For a string mentioning "Syntax error" or "Error": a list-like string
          of zeros (``'[0, 0, ...]'``) with as many entries as the original list.
        * For any other string: the same text with all parentheses removed.
        * For a non-string: ``output`` unchanged.

    Raises:
        ValueError / SyntaxError: if an error-bearing cell is not a valid
            Python literal.
    """
    if not isinstance(output, str):
        return output

    if "Syntax error" in output or "Error" in output:
        # literal_eval parses the list text without executing anything in it
        num_elements = len(ast.literal_eval(output))
        return '[' + ', '.join('0' for _ in range(num_elements)) + ']'

    # Remove parentheses if it's a valid entry
    return re.sub(r'[()]', '', output)
# Run process_output over every cell of the 'outputs' column
outputs_processed = df['outputs'].apply(process_output)
df['outputs'] = outputs_processed
# Persist the intermediate result, then load it again from disk
df.to_csv('results2.csv', index=False)

df = pd.read_csv('results2.csv')

# Function to remove a zero imaginary part from the outputs and replace '4*j' with '4j'
def clean_complex_output(output_list):
    """Clean every entry of a stringified output list.

    Args:
        output_list: String form of a Python list (as stored in the CSV cell).

    Returns:
        A list of strings where, for each entry, a zero imaginary suffix
        (``+0j`` or ``-0j``) following a number is removed, a bare zero
        imaginary value (``0j``, ``+0j``, ``-0j``) becomes ``'0'``, and
        ``<digits>*j`` is rewritten as ``<digits>j``.
    """
    # Convert string representation of list to an actual list
    parsed_items = ast.literal_eval(output_list)

    def clean_item(item):
        item_str = str(item)
        # Rounding can leave a negative-zero imaginary part, printed as '-0j',
        # so both signs are stripped. The lookbehind requires a preceding number
        # so a value that is only '-0j' is not erased to an empty string.
        item_str = re.sub(r'(?<=[\d.])[+-]0j\b', '', item_str)
        # A purely imaginary zero is just zero
        if item_str in ('0j', '+0j', '-0j'):
            item_str = '0'
        # Replace '4*j' with '4j'
        item_str = re.sub(r'(\d+)\*j', r'\1j', item_str)
        return item_str

    return [clean_item(item) for item in parsed_items]

# Clean every cell of the 'outputs' column with clean_complex_output
outputs_cleaned = df['outputs'].apply(clean_complex_output)
df['outputs'] = outputs_cleaned

# Write the final file
df.to_csv('Final_Results.csv', index=False)
print("Done! The results have been saved to 'Final_Results.csv'.")