In [21]:
import torch
import re
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteria
import numpy as np
import random
from torch.utils.data import DataLoader
from tokenizer import *
from generation_processing import *

class MySantaCoder(nn.Module):
    """Wrapper around the "bigcode/santacoder" causal language model and its tokenizer.

    Args:
        list_of_bad_words: strings that are tokenized and handed to the
            generation config as `bad_words_ids`.
        max_tokens: value used for `max_new_tokens` during generation.
        num_sol: used both as `num_beams` and as `num_return_sequences`.
    """

    def __init__(self, list_of_bad_words = ['#'], max_tokens = 128, num_sol = 1):
        super().__init__()
        self.checkpoint = "bigcode/santacoder"
        self.model = AutoModelForCausalLM.from_pretrained(self.checkpoint, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
        self.max_new_tokens = max_tokens
        # Token ids of the words the model should not generate.
        self.bad_words = self.get_input_ids_as_list(list_of_bad_words)
        # Keyword-based stopping is currently disabled. Previous attempt:
        # self.stop_criteria= self.get_input_ids_as_list(stop_words)
        # self.stop_criteria.append(self.tokenizer.eos_token_id)
        # self.stopping_criteria = KeyWordsStoppingCriteria(self.stop_criteria)

        self.generation_config = GenerationConfig(
                bad_words_ids = self.bad_words,
                num_beams = num_sol,
                num_return_sequences = num_sol,
                max_new_tokens = self.max_new_tokens,
                # NOTE: stopping criteria are an argument of `generate`
                # (`stopping_criteria=`), not a GenerationConfig field.
                eos_token_id=self.model.generation_config.eos_token_id,
                bos_token_id=self.model.generation_config.bos_token_id
                )

    def get_input_ids_as_list(self, list_of_bad_words):
        """Tokenize every string of `list_of_bad_words`.

        Returns:
            A list with one `tokenizer.encode(...)` result per input string.
        """
        return [self.tokenizer.encode(element) for element in list_of_bad_words]

    def forward(self, input_ids):
        """Run `generate` on `input_ids` with this wrapper's generation config.

        Returns:
            Whatever `self.model.generate` returns for the given inputs.
        """
        # Keyword form makes the binding explicit instead of relying on the
        # positional order of `generate`'s parameters.
        outputs = self.model.generate(input_ids, generation_config=self.generation_config)
        return outputs

    def decode_output(self, encoded_output):
        """Decode one sequence of token ids back to text with the tokenizer."""
        output = self.tokenizer.decode(encoded_output)
        return output

    def post_generation_processing(self,code):
        """Keep only the first `def`/`class`/`assert`/`print` block of `code`.

        The text is split on the delimiters 'def ', 'class ', 'assert ' and
        'print '. The result is everything before the first delimiter, a
        newline, the delimiter that actually matched, and the text up to the
        next delimiter.

        If `code` contains none of the delimiters it is returned unchanged
        (there is nothing to cut), instead of raising IndexError.
        """
        # The capturing group makes re.split return the matched delimiters, so
        # the real keyword can be restored rather than guessed. maxsplit=2
        # yields [prefix, delimiter, first_block, (delimiter, rest)].
        parts = re.split('(def |class |assert |print )', code, maxsplit=2)
        if len(parts) < 3:
            return code
        prefix, keyword, first_block = parts[0], parts[1], parts[2]
        return prefix + '\n' + keyword + first_block

class KeyWordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the generated ids end with a keyword.

    Args:
        keywords_ids: iterable of keywords, each given either as a sequence of
            token ids or as a single integer token id.
    """

    def __init__(self, keywords_ids):
        super().__init__()
        self.keywords_ids = keywords_ids

    def __call__(self, input_ids, scores=None, **kwargs):
        """Return True when every sequence in `input_ids` ends with a keyword.

        Args:
            input_ids: a tensor (anything exposing `.tolist()`) or a plain list,
                holding either one sequence of token ids or a batch of them.
            scores: accepted because `generate` passes it to stopping criteria;
                unused here.
            **kwargs: accepted for the same reason; unused.

        Returns:
            bool. False when there are no (non-empty) keywords.
            NOTE(review): recent transformers releases describe a per-sequence
            boolean tensor as return value - confirm against the installed version.
        """
        if hasattr(input_ids, 'tolist'):
            input_ids = input_ids.tolist()
        input_ids = list(input_ids)

        # Accept both a single flat sequence and a batch of sequences.
        if input_ids and isinstance(input_ids[0], (list, tuple)):
            sequences = [list(sequence) for sequence in input_ids]
        else:
            sequences = [input_ids]

        # Normalize keywords to lists of ids; drop empty ones, which would
        # otherwise match every sequence.
        keywords = []
        for keyword in self.keywords_ids:
            ids = [keyword] if isinstance(keyword, int) else list(keyword)
            if ids:
                keywords.append(ids)
        if not keywords:
            return False

        def ends_with_keyword(sequence):
            return any(
                len(ids) <= len(sequence) and sequence[-len(ids):] == ids
                for ids in keywords
            )

        return all(ends_with_keyword(sequence) for sequence in sequences)

In [2]:
from get_data import * 

# Load the converted and the original MTBP JSONL files.
# `read_json_line_format` is not defined in this notebook; presumably it comes
# from the `get_data` star import - confirm.
converted_mtbp = read_json_line_format('data/MTBP/converted_mtpb.jsonl')
mtbp = read_json_line_format('data/MTBP/mtpb.jsonl')

In [3]:
import re
import random
import string

def generate_random_name(signature):
    """Return `signature` with its function name replaced by random letters.

    The signature must start with ``def <name>(<params>):``. The new name has
    the same number of characters as the original and consists of random
    lowercase ASCII letters; the parameter text is copied as is. Anything
    following the matched ``):`` is not part of the returned string.

    When the signature does not match, it is printed and returned unchanged.
    """
    parsed = re.match(r'def ([\w\-_%.]+)\((.*)\):', signature)
    if parsed is None:
        # Not a recognised signature: show it and hand it back untouched.
        print(signature)
        # raise ValueError("Invalid function signature")
        return signature

    original_name, parameters = parsed.groups()
    # One random letter per character of the original name.
    letters = [random.choice(string.ascii_lowercase) for _ in original_name]
    random_name = ''.join(letters)
    return f'def {random_name}({parameters}):'

def custom_dataset_context_investigation(mtbp_converted, mtbp):
    """Build the frame used for the context / no-context experiment.

    Keeps the 'prompts' column of `mtbp` and the 'text', 'signature' and
    'test_list' columns of `mtbp_converted`, places them side by side
    (column-wise concat), and adds a 'random_signatures' column holding
    `generate_random_name(signature)` for every row.

    Returns:
        The combined DataFrame.
    """
    # Restrict both inputs to the columns of interest.
    converted_subset = mtbp_converted[['text', 'signature', 'test_list']]
    prompts_subset = mtbp[['prompts']]

    data = pd.concat([prompts_subset, converted_subset], axis=1)

    # One randomized signature per row, in row order.
    data['random_signatures'] = [
        generate_random_name(data.iloc[row]['signature'])
        for row in range(len(data))
    ]

    return data

In [4]:
# Keywords passed as `stop_words` to `generation_cut_off` by
# `context_and_contexless_generation`.
STOP_WORDS = ['def', 'if', 'for', 'while']

def context_and_contexless_generation(data, model, early_stopping = None):
    """Generate code for each prompt of each problem, with and without context.

    Two generation chains are run per row of `data`:
        1. with context: starts from the row's 'signature'; after each prompt
           the cut-off generation becomes the prefix for the next prompt.
        2. without context: starts from the row's 'random_signatures' entry;
           after each prompt the generation is cut off (always with
           `index_prompt=0`) and passed through `remove_context` before it is
           reused as prefix.
    Besides the cut-off results, each decoded generation is also kept after
    `model.post_generation_processing` (the coarse cut at the next block).

    Args:
        data: frame with 'signature', 'random_signatures' and 'prompts' columns.
        model: object exposing `tokenizer`, `forward`, `decode_output` and
            `post_generation_processing`.
        early_stopping: if not None, the loop stops at the first row index
            greater than this value (rows 0..early_stopping are processed).

    Returns:
        Four lists with one inner list per processed row, in this order:
        cut-off codes with context, coarse-cut generations with context,
        cut-off codes without context, coarse-cut generations without context.
    """
    all_cut_context, all_raw_context = [], []
    all_cut_random, all_raw_random = [], []

    for j in range(len(data)):
        if early_stopping is not None and j > early_stopping:
            break

        row = data.iloc[j]

        cut_context, cut_random = [], []
        raw_random, raw_context = [], []

        # Running prefixes: real signature vs. randomized signature.
        code = row['signature']
        code_random = row['random_signatures']

        for i, prompt in enumerate(row['prompts']):
            # The prompt is appended as an indented comment under the prefix.
            text_context = code + '\n\t' + '#' + prompt
            text_random = code_random + '\n\t' + '#' + prompt

            ids_context = model.tokenizer.encode(text_context, return_tensors='pt')
            ids_random = model.tokenizer.encode(text_random, return_tensors='pt')

            generated_context = model.forward(ids_context)
            generated_random = model.forward(ids_random)

            # Only the first returned sequence is decoded.
            decoded_context = model.decode_output(generated_context[0])
            decoded_random = model.decode_output(generated_random[0])

            # Structural cut-off; these become the prefixes of the next step.
            code = generation_cut_off(gen_code = decoded_context, stop_words=STOP_WORDS, index_prompt=i)
            code_random = remove_context(
                generation_cut_off(gen_code = decoded_random, stop_words=STOP_WORDS, index_prompt=0)
            )

            # Coarse cut-off of the same generations.
            decoded_context = model.post_generation_processing(decoded_context)
            decoded_random = model.post_generation_processing(decoded_random)

            cut_context.append(code)
            cut_random.append(code_random)
            raw_random.append(decoded_random)
            raw_context.append(decoded_context)

        all_cut_context.append(cut_context)
        all_cut_random.append(cut_random)
        all_raw_context.append(raw_context)
        all_raw_random.append(raw_random)

    return all_cut_context, all_raw_context, all_cut_random, all_raw_random

In [5]:
from get_data import *
# Load the converted and the original MTBP JSONL files.
mtbp_converted = read_json_line_format('data/MTBP/converted_mtpb.jsonl')
mtbp = read_json_line_format('data/MTBP/mtpb.jsonl')

# Combine both into the experiment frame (adds the 'random_signatures' column).
dataset = custom_dataset_context_investigation(mtbp_converted, mtbp)

In [34]:
# Row of `dataset` and position inside that row's 'prompts' list used for the
# manual experiment below.
index = 2
prompt = 1
# Docstring-style input: randomized signature, then the prompt wrapped in
# triple quotes on a tab-indented line.
instruction = dataset.iloc[index]['random_signatures'] + '\n\t"""' + dataset.iloc[index]['prompts'][prompt] + '"""'
# Comment-style input: only the prompt text prefixed with '#', no signature.
instruction_comments = '#' + dataset.iloc[index]['prompts'][prompt]

In [35]:
# Show the comment-style input that will be fed to the model.
print(instruction_comments)

#Write a function that takes an integer hours and converts it to seconds.


In [23]:
# Instantiate the wrapper with its default arguments; the constructor loads
# the "bigcode/santacoder" model and tokenizer.
model = MySantaCoder()

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


In [36]:
# Tokenize the comment-style input, generate, and decode the first returned
# sequence back to text.
input_ids = model.tokenizer.encode(instruction_comments, return_tensors='pt')
output_ids = model.forward(input_ids)
output_text = model.decode_output(output_ids[0])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:49152 for open-end generation.


In [39]:
# Append an extra dummy function after the generation (presumably to try the
# cut-off logic on text that contains a second 'def').
text = output_text + '\n\ndef function(test):\n\treturn bite'

In [43]:
# Show the generation together with the appended dummy function.
print(text)

#Write a function that takes an integer hours and converts it to seconds.

def convert_hours_to_seconds(hours):
    return hours * 3600

print(convert_hours_to_seconds(1))
print(convert_hours_to_seconds(2))
print(convert_hours_to_seconds(3))
print(convert_hours_to_seconds(4))
print(convert_hours_to_seconds(5))
print(convert_hours_to_seconds(6))
print(convert_hours_to_seconds(7))
print(convert_hours_to_seconds(8))


def function(test):
	return bite


In [52]:
def find_all_indices(text, substring):
    """Return the start index of every occurrence of `substring` in `text`.

    The search resumes one character after each hit, so overlapping
    occurrences are all reported. The result is in ascending order and empty
    when there is no occurrence.
    """
    positions = []
    search_from = 0
    while True:
        hit = text.find(substring, search_from)
        if hit == -1:
            return positions
        positions.append(hit)
        search_from = hit + 1

def starts_with_def(text):
    """Tell whether the first non-blank line after the first comment starts with 'def'.

    Lines are compared after stripping surrounding whitespace. Everything up
    to and including the first line starting with '#' is skipped; the next
    non-empty line decides the result. Returns False when the text has no
    comment line or nothing but blank lines after it.
    """
    stripped_lines = (line.strip() for line in text.split('\n'))

    # Consume lines up to and including the first comment line.
    for current in stripped_lines:
        if current.startswith('#'):
            break
    else:
        return False

    # The same iterator continues right after that comment.
    for current in stripped_lines:
        if current:
            return current.startswith('def')

    return False

In [67]:
import numpy as np

def cut_off_generated_text(text):
    """Truncate generated text at the first trailing `print(` or extra `def`.

    Candidate cut points are the first '\\n\\nprint(' and the first '\\n\\ndef'.
    When `starts_with_def(text)` is true, the first '\\n\\ndef' match is
    skipped so that the next one, if any, is used instead.

    That skip is applied whether or not a '\\n\\nprint(' is present. Applying
    it only when both patterns were found cut the text at its own leading
    `def` whenever no print followed, leaving just the prompt comment.

    Returns:
        `text` up to (excluding) the earliest cut point, or `text` unchanged
        when there is no cut point.
    """
    index_print = text.find('\n\nprint(')
    indexes_def = find_all_indices(text, '\n\ndef')

    if starts_with_def(text):
        # The first match is skipped; only a later def is a cut point.
        indexes_def = indexes_def[1:]

    cut_points = []
    if index_print != -1:
        cut_points.append(index_print)
    if indexes_def:
        cut_points.append(indexes_def[0])

    if not cut_points:
        return text
    return text[:min(cut_points)]

In [57]:
# Inspect the raw generation without its first 73 characters; per the output
# below, what remains starts at the blank line before the generated 'def'.
output_text[73:]

'\n\ndef convert_hours_to_seconds(hours):\n    return hours * 3600\n\nprint(convert_hours_to_seconds(1))\nprint(convert_hours_to_seconds(2))\nprint(convert_hours_to_seconds(3))\nprint(convert_hours_to_seconds(4))\nprint(convert_hours_to_seconds(5))\nprint(convert_hours_to_seconds(6))\nprint(convert_hours_to_seconds(7))\nprint(convert_hours_to_seconds(8))\n'

In [69]:
# Apply the cut-off to the raw generation and show what is kept.
print(cut_off_generated_text(output_text))

#Write a function that takes an integer hours and converts it to seconds.

def convert_hours_to_seconds(hours):
    return hours * 3600


In [None]:

model = MySantaCoder()
# The targets follow the function's return order: cut-off codes with context,
# coarse-cut generations with context, cut-off codes without context,
# coarse-cut generations without context. The second target used to repeat
# `codes_without_context`, so `raw_generations_context` was never bound.
# With early_stopping=2 the rows with index 0, 1 and 2 are processed.
codes_with_context, raw_generations_context, codes_without_context, raw_generations_no_context = context_and_contexless_generation(data=dataset, model=model, early_stopping=2)

In [58]:
# Cut-off generations of the second processed problem (index 1), with context.
codes_with_context[1]

['def normalize_integer_list(numbers):\n\t#Define a list of integers named "numbers" with the values {numbers}.\n\tnumbers = [int(x) for x in numbers]',
 'def normalize_integer_list(numbers):\n\t#Calculate the sum of the elements in variable "numbers" and store the result to variable "total".\n\ttotal = sum(numbers)',
 'def normalize_integer_list(numbers):\n\t#Divide each element of the list by the total and multiply by 100, store the result to variable "normalized".\n\tnormalized = []',
 'def normalize_integer_list(numbers):\n\t#Convert each element in variable "normalized" into a formatted string with single decimal point and store the result into "formatted".\n\tformatted = []',
 'def normalize_integer_list(numbers):\n\t#Print the variable "formatted".\n\tformatted = []']

In [59]:
# Cut-off generations of the second processed problem (index 1), without context.
codes_without_context[1]

['def wicupvsftvndphmdqgwdyw(numbers):\n\tnumbers = [int(x) for x in numbers]',
 'def wicupvsftvndphmdqgwdyw(numbers):\n\ttotal = 0',
 'def wicupvsftvndphmdqgwdyw(numbers):\n\tnormalized = []',
 'def wicupvsftvndphmdqgwdyw(numbers):\n\tformatted = ""',
 'def wicupvsftvndphmdqgwdyw(numbers):\n\tformatted = ""']