In [None]:
!pip install datasets evaluate transformers
!pip install accelerate

# Training

### 1. Load the Training Dataset

In [1]:
# @title 1. Load the dataset

from datasets import load_dataset
import json
dataset = load_dataset("json", data_files = "/workspace/Project/modified_v2.3.json")

Generating train split: 0 examples [00:00, ? examples/s]

### 2. Tokenize the Dataset

In [2]:
# @title 2. Tokenize the dataset
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, RobertaTokenizer, T5ForConditionalGeneration

torch.cuda.empty_cache()

# Load the tokenizer
# model_name = "Salesforce/codet5-large"        # for codet5large
# tokenizer = AutoTokenizer.from_pretrained(model_name)     # for codet5large

model_name = "Salesforce/codet5-base"         # for codet5base
tokenizer = RobertaTokenizer.from_pretrained(model_name)   # for codet5base

model = T5ForConditionalGeneration.from_pretrained(model_name)



In [3]:
# Define a tokenization function
def tokenize_function(batch):
    inputs = tokenizer(batch['latex_expression'], padding='max_length', truncation=True, max_length=256, return_tensors='pt')
    labels = tokenizer(batch['solution'], padding='max_length', truncation=True, max_length=256, return_tensors='pt')


    batch['input_ids'] = inputs['input_ids']
    batch['attention_mask'] = inputs['attention_mask']
    batch['labels'] = labels['input_ids']

    return batch

# Apply the tokenization function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Set the format for PyTorch tensors
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
print(f"Tokenized Datasets: {tokenized_datasets}")

Map:   0%|          | 0/36346 [00:00<?, ? examples/s]

Tokenized Datasets: DatasetDict({
    train: Dataset({
        features: ['task_id', 'sympy_exp', 'latex_expression', 'solution', 'simplified_solution', 'synthetic', 'domain', 'test_cases', 'complexity', 'equation_type', 'output_type', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 36346
    })
})


### 3. Setup the Training Arguments

In [5]:
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

# # Split dataset indices
# train_indices, val_indices = train_test_split(list(range(len(tokenized_datasets['train']))), test_size=0.1, random_state=42)
# # Create subsets
# train_dataset = Subset(tokenized_datasets['train'], train_indices)
# val_dataset = Subset(tokenized_datasets['train'], val_indices)

train_dataset = tokenized_datasets['train']

# Initialize the model
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Define TrainingArguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=35,
    per_device_train_batch_size=32,
    warmup_steps=300,
    weight_decay=0.01,
    logging_dir='./logs',
    # save_strategy = "epoch",
    logging_steps=1000,
    fp16 = True,  # Enable mixed precision training
    gradient_accumulation_steps=4
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    optimizers = (torch.optim.AdamW(model.parameters(), lr=5e-5), None)
)

print(f"Ready for training!")

# # When resuming training, use the path to the latest checkpoint
# latest_checkpoint_path = './checkpoints/checkpoint-latest'  # Replace with your latest checkpoint path

# # Resume training
# trainer.train(resume_from_checkpoint=latest_checkpoint_path)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Ready for training!


### 4. Train the Model

In [6]:
# Start Training
# torch.cuda.empty_cache()
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33msuhaibfarooqui2000[0m ([33msuhaibfarooqui2000-nust[0m). Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
1000,0.1981
2000,0.0069
3000,0.0045
4000,0.0032
5000,0.0024
6000,0.0018
7000,0.0014
8000,0.0011
9000,0.0009


TrainOutput(global_step=9940, training_loss=0.02223771441150719, metrics={'train_runtime': 15794.4967, 'train_samples_per_second': 80.541, 'train_steps_per_second': 0.629, 'total_flos': 3.873307110801408e+17, 'train_loss': 0.02223771441150719, 'epoch': 35.0})

### 5. Save the Trained Model

In [7]:
model.save_pretrained('./CodeT5B_35ep_modified_v2.3')
tokenizer.save_pretrained('./CodeT5B_35ep_modified_v2.3')
print("Model is saved")

Model is saved


In [None]:
import shutil
shutil.make_archive('CodeT5B_13ep_modified_train_SS', 'zip', './CodeT5B_13ep_modified_train_SS')

# Inference

### 1. Load the Test Dataset

In [8]:
import json
import csv

# Load the test JSON file
with open('/workspace/Project/public_test_new_no_sol_no_out.json', 'r') as f:
    test_data = json.load(f)

### 2. Load the Trained Model

In [33]:
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration, RobertaTokenizer

# Load the tokenizer and model from the trained directory
# model_name = "Salesforce/codet5-large"     # for codet5large
# tokenizer = AutoTokenizer.from_pretrained(model_name)     # for codet5large

model_name = "Salesforce/codet5-base"     # for codet5base
tokenizer = RobertaTokenizer.from_pretrained(model_name)     # for codet5base

model_path = '/workspace/Project/CodeT5B_30ep_modified_v2.1(notest)'
model = T5ForConditionalGeneration.from_pretrained(model_path)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(device)

cuda


In [None]:
def preprocess(text):
    import re
    text = re.sub(r'return(?!\s)', r'return ', text)
    text = text.replace('cdot*', '').replace('cdot', '')
    text = text.replace('from math import sqrt', '')
    text = text.rstrip() + ")" * (text.count("(") - text.count(")"))
    balance, result = 0, []
    for char in text:
        if char == '(': balance += 1
        elif char == ')' and balance == 0: continue
        elif char == ')': balance -= 1
        result.append(char)
    text = ''.join(result) + ')' * balance
    # Ensure proper 'try-except' exception handling
    if "try:" in text and "except" not in text:
        text = text.replace("try:", "")
    return text

### 3. Define Inference Function

In [39]:
# Function to perform inference on a single latex expression
def infer_latex_expression(latex_expression):
    import re
    inputs = tokenizer(latex_expression, return_tensors="pt", padding=True, truncation=True).to(device)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Generate outputs for the batch
    with torch.no_grad():
      outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=256)

    generated_text = [preprocess(tokenizer.decode(output, skip_special_tokens=True)) for output in outputs]
        
    return generated_text

# Sample inference
generated_code = infer_latex_expression("\\frac{6 v^{3}}{5 z + 3} + \\frac{7 w^{2}}{9 w + 5} + \\frac{5 x^{2}}{6 x + 6} + \\frac{4 y^{2}}{3 z + 8} + \\frac{8 z^{2}}{y + 7} - 2.5")
# generated_code = infer_latex_expression("86.6059059698858 x + 86.6059059698858 y")
print(f"Generated Code: {generated_code[0]}")

Generated Code: 

def rational_function(x, y, z, w, v):

    try:
        return 6*v**3/(5*z + 3) + 7*w**2/(9*w + 5) + 5*x**2/(6*x + 6) + 4*y**2/(3*z + 8) + 8*z**2/(y + 7) - 2.5
    except ZeroDivisionError:
        return float('inf')  # or another appropriate value or error handling


### 4. Define Evaluate Function

In [40]:
# @title evaluate_code function
import json
import re
import math
import csv
import torch

def evaluate_code(generated_code: str, test_cases: list):
    import re
    output_list = []
    error_list = [str(float('inf'))]*len(test_cases)
    is_solution_valid = True
    
    # Step 1: Extract and execute import statements
    # import_pattern = re.compile(r'^\s*(import\s+\w+(?:\s+as\s+\w+)?|from\s+\w+\s+import\s+\w+)', re.MULTILINE)
    import_pattern = re.compile(r'^\s*(import\s+\w+(?:\s+as\s+\w+)?|from\s+\w+\s+import\s+\*|from\s+\w+\s+import\s+\w+(?:,\s*\w+)*)', re.MULTILINE)
    imports = import_pattern.findall(generated_code)
    
    for import_statement in imports:
        try:
            exec(import_statement, globals())
        except ImportError as e:
            print(f"ImportError in executing import statement: {import_statement}. Error: {e}")
        except SyntaxError as e:
            print(f"SyntaxError in import statement: {import_statement}. Error: {e}")

    # Remove import statements from generated_code to isolate the function definition
    generated_code = import_pattern.sub('', generated_code)

    # Step 2: Extract the function definition
    function_pattern = re.compile(r'(def\s+\w+\([^)]*\):(?:\n\s+.+)+)')
    match = function_pattern.search(generated_code)
    if match:
        function_def = match.group(1)
    else:
        print(f"No valid function definition found!")
        is_solution_valid = False
        return error_list, is_solution_valid

    # Step 3: Extract the function parameters
    param_pattern = re.compile(r'def\s+(\w+)\(([^)]*)\)')
    param_match = param_pattern.search(function_def)
    if param_match:
        function_name = param_match.group(1)
        parameters = param_match.group(2).split(', ')
        # print(f"Parameters: {parameters}")
    else:
        print(f"Couldn't extract function parameters!")
        is_solution_valid = False
        return error_list, is_solution_valid

    # Step 4: Execute the function definition
    local_scope = {}
    try:
        exec(function_def, globals(), local_scope)
    except SyntaxError as e:
        print(f"SyntaxError in function definition!")

    # Ensure the function is defined in local_scope
    if function_name not in local_scope:
        print(f"Function '{function_name}' not found after exec!")
        is_solution_valid = False
        return error_list, is_solution_valid

    # Step 5: Test each case in the test_cases
    is_result_valid_list = []
    for test_case in test_cases:
        is_result_valid = True
        input_values = test_case["input"]
        try:
            function_args = {param: input_values.get(param, 0) for param in parameters}
        except Exception as e:
            is_result_valid = False
            print(f"Error finding arguments: {e}")
            result = float('inf')
            output_list.append(str(result))
            is_result_valid_list.append(is_result_valid)
            continue
        function_to_call = local_scope[function_name]

        try:
            result = function_to_call(**function_args)
        except Exception as e:
            # print(f"Function arguments: {function_args}")
            if str(e) != "math range error":
              print(f"An error occurred while calculating result: {e}")
              is_result_valid = False
            else:
              is_result_valid = True
            result = float('inf')
            output_list.append(str(result))
            is_result_valid_list.append(is_result_valid)
            continue

        # Convert result to a Python number if it's a tensor
        if isinstance(result, torch.Tensor):
            result = result.item()
        elif isinstance(result, (int, float)):
            result = float(result)
        elif hasattr(result, 'evalf'):
            result = result.evalf()

        # Format the complex number to the desired string format
        if result != float('inf'):
            result = complex(result)
            if result.imag:
                if result.imag > 0:
                    result = f'{result.real:.6f}+{result.imag:.6f}j'
                else:
                    result = f'{result.real:.6f}-{abs(result.imag):.6f}j'
            else:
                result = f'{result.real:.6f}'

        # Append the result to the output_list
        output_list.append(str(result))
        is_result_valid_list.append(is_result_valid)

    return output_list, all(is_result_valid_list)

#### 4.1 Sample Testing

In [41]:
import re
import numpy as np
generated_code = '''import numpy as np
import sympy as sp
def geometric(x, y):
    return x*y+100*np.pi
'''
test_cases = [{'input': {'x': 7.440682794714104, 'y': 5.095726493684404}},
              {'input': {'x': 4.504088231490985, 'y': 2.242145429757670}},
              {'input': {'x': 9.433652200314123, 'y': 6.074211889279868}},
              {'input': {'x': 4.672236683712683, 'y': 2.163058324414243}},
              {'input': {'x': 5.735382662132545, 'y': 2.272275556889203}}
             ]

evaluate_code(generated_code, test_cases)

(['352.074950', '324.258086', '371.461268', '324.265586', '327.191635'], True)

### 5. Perform Actual Inference to get the Output File

In [42]:
# @title Actual Inferencing
from tqdm import tqdm
from collections import defaultdict
import csv

error_list = [str(float('inf'))]*5

# Prepare to write to CSV
csv_file_path = 'CodeT5B_30ep_modified_v2.1(notest).csv'
with open(csv_file_path, 'w', newline='') as csvfile:
    csv_writer = csv.writer(csvfile)
    csv_writer.writerow(['id', 'outputs'])  # Write the header row
    
    batch_size = 64  # Adjust based on your GPU memory capacity
    numEE = 0

    # Dictionary to track evaluation errors per equation_type
    evaluation_errors_by_type = defaultdict(int)
    # trunc_test_data = test_data[:100]
    # Iterate over the test data with progress tracking
    for i in tqdm(range(0, len(test_data), batch_size), desc="Processing tasks"):
        batch_tasks = test_data[i:i+batch_size]
        latex_expressions = [task['latex_expression'] for task in batch_tasks]
        generated_codes = infer_latex_expression(latex_expressions)
        for task, generated_code in zip(batch_tasks, generated_codes):
            task_id = task['task_id']
            equation_type = task['equation_type']  # Extract the equation type
            latex_expression = task['latex_expression']
            test_cases = task['test_cases']
            # print(f"Test Cases: {test_cases}\n")
         # try: 
            # Evaluate the generated code for each test case
            task_outputs, valid = evaluate_code(generated_code, test_cases)
            if not valid:
                numEE += 1
                evaluation_errors_by_type[equation_type] += 1
                # print(f"____________ERROR LOG____________")
                # print(f"Task ID: {task_id}")
                # print(f"LaTeX Expression:\n{latex_expression}")
                # print(f"Generated Code:\n{generated_code}\n")
                
            csv_writer.writerow([task_id, task_outputs])

# Display the evaluation errors summary
print(f"\nOutput has been written to {csv_file_path}")
print(f"\nNumber of evaluation errors: {numEE}")

# Print evaluation errors for each equation_type
print("\nEvaluation Errors by Equation Type:")
for equation_type, error_count in evaluation_errors_by_type.items():
    print(f"{equation_type}: {error_count} errors")


Processing tasks:  25%|██▌       | 4/16 [00:13<00:40,  3.33s/it]

An error occurred while calculating result: module 'sympy' has no attribute 'f'
An error occurred while calculating result: module 'sympy' has no attribute 'f'
An error occurred while calculating result: module 'sympy' has no attribute 'f'
An error occurred while calculating result: module 'sympy' has no attribute 'f'
An error occurred while calculating result: module 'sympy' has no attribute 'f'
____________ERROR LOG____________
Task ID: 9dfe1cf7
LaTeX Expression:
\mathtt{\text{Integral(4*x**4 + 10*x**3 + 3*x**2 + 9*x + sqrt(a*x + b + x**2) + sin(sqrt(x)/2) + cos(a + x**2) + 5, x)}}
Generated Code:

import numpy as np
import sympy as sp

def integral_simplified(x, a, b, c):
    return -4*np.sqrt(x)*np.cos(np.sqrt(x)/2) + 4*x**5/5 + 5*x**4/2 + x**3 + 9*x**2/2 + 5*x + (a/4 + x/2)*np.sqrt(a*x + b + x**2) + (-a**2/8 + b/2)*sp.Piecewise((np.log(a + 2*x + 2*np.sqrt(a*x + b + x**2)), sp.Ne(a**2/4 - b, 0)), ((a/2 + x)*np.log(a/2 + x)/np.sqrt((a/2 + x)**2), True)) + np.sqrt(2)*np.sqrt(np.pi)*(

Processing tasks:  62%|██████▎   | 10/16 [00:34<00:22,  3.80s/it]

SyntaxError in function definition!
Function 'rational_function' not found after exec!
____________ERROR LOG____________
Task ID: 20696a62
LaTeX Expression:
\frac{8 x + 2}{9^{x} + 9 y^{2} - 9604}
Generated Code:


def rational_function(x, y):

    try:
        return 8*x + 2/(9**x + 9*y**2 - 9604
    except ZeroDivisionError:
        return float('inf')  # or another appropriate value or error handling)

SyntaxError in function definition!
Function 'augmented_function' not found after exec!
____________ERROR LOG____________
Task ID: 80b80a95
LaTeX Expression:
\frac{5^{x} + 1035.67769989538 \sqrt{2 h^{2} + r^{2} + w^{2}} - 1936}{6 \log{\left(5 x \right)} + 1}
Generated Code:
import numpy as np
from sympy import *
def augmented_function(x, r, h, w, 2):
    return (5**x + 1035.67769989538*sqrt(2*h**2 + r**2 + w**2) - 1936)/(6*log(5*x) + 1)

SyntaxError in function definition!
Function 'augmented_function' not found after exec!
____________ERROR LOG____________
Task ID: 5b219b33
LaTeX Expr

Processing tasks:  69%|██████▉   | 11/16 [00:38<00:19,  3.92s/it]

An error occurred while calculating result: name 'x' is not defined
An error occurred while calculating result: name 'x' is not defined
An error occurred while calculating result: name 'x' is not defined
An error occurred while calculating result: name 'x' is not defined
An error occurred while calculating result: name 'x' is not defined
____________ERROR LOG____________
Task ID: 2f5e18d1
LaTeX Expression:
- 1890.26007655406 h l w + 9 \cos{\left(4 x \right)} + 9
Generated Code:
from sympy import cos

def trigonometric_function(w, l, h):
    return -1890.26007655406*h*l*w + 9*cos(4*x) + 9



Processing tasks:  75%|███████▌  | 12/16 [00:43<00:16,  4.07s/it]

An error occurred while calculating result: name 'B' is not defined
An error occurred while calculating result: name 'B' is not defined
An error occurred while calculating result: name 'B' is not defined
An error occurred while calculating result: name 'B' is not defined
An error occurred while calculating result: name 'B' is not defined
____________ERROR LOG____________
Task ID: 03205d8a
LaTeX Expression:
\frac{0.0565698253396804 \left(- 7 x + 9 y - 29\right) \left(1.02442536802193 x^{2} + 0.829681920170998 y^{2} + 9.91717244288725 e^{- 0.257807598554065 z}\right)}{B h}
Generated Code:
import numpy as np
from sympy import *
def augmented_function(x, y, b, z, h):
    return 0.0565698253396804*(-7*x + 9*y - 29)*(1.02442536802193*x**2 + 0.829681920170998*y**2 + 9.91717244288725*exp(-0.257807598554065*z))/(B*h)



Processing tasks:  81%|████████▏ | 13/16 [00:48<00:13,  4.56s/it]

An error occurred while calculating result: name 'R' is not defined
An error occurred while calculating result: name 'R' is not defined
An error occurred while calculating result: name 'R' is not defined
An error occurred while calculating result: name 'R' is not defined
An error occurred while calculating result: name 'R' is not defined
____________ERROR LOG____________
Task ID: 6ee1ca18
LaTeX Expression:
5605.99552408209 \pi^{2} R r + 0.600143809449589 y^{2} + 9.4013513898678 e^{- 1.97302771956569 x}
Generated Code:
from sympy import exp

def exponential_decay_function(x, r, y):
    return 5605.99552408209*pi**2*R*r + 0.600143809449589*y**2 + 9.4013513898678*exp(-1.97302771956569*x)



Processing tasks: 100%|██████████| 16/16 [01:02<00:00,  3.90s/it]

An error occurred while calculating result: name 'x' is not defined
An error occurred while calculating result: name 'x' is not defined
An error occurred while calculating result: name 'x' is not defined
An error occurred while calculating result: name 'x' is not defined
An error occurred while calculating result: name 'x' is not defined
____________ERROR LOG____________
Task ID: 93913ab8
LaTeX Expression:
710.232009116335 b h + 10 \tan{\left(7 x \right)} + 5
Generated Code:
from sympy import tan

def trigonometric_function(b, h):
    return 710.232009116335*b*h + 10*tan(7*x) + 5


Output has been written to CodeT5B_35ep_modified_v2.3.csv

Number of evaluation errors: 8

Evaluation Errors by Equation Type:
integration: 1 errors
augmented_equation: 7 errors



