In [None]:
from google.colab import drive
# Mount Google Drive at this path; later cells build the project folder under it.
root = "/content/drive/"
drive.mount(root)

In [None]:
import json, os, sys, re, json

path = os.path.join(root, "My Drive/Colab Notebooks/COSE474")
os.makedirs(path, exist_ok=True)

od_path = os.path.join(path, "Rust_Code_Generation")
os.makedirs(od_path, exist_ok=True)

%cd "{od_path}"

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub. The access token is read from
# token_file.txt so that it is not hardcoded in the notebook.
with open("token_file.txt", "r") as file:
    token = file.read()
token = token.strip()

login(token=token)

In [None]:
# Checkpoint under evaluation. A larger alternative that can be swapped in:
#   "meta-llama/CodeLlama-13b-Instruct-hf"
model_name = "meta-llama/Llama-3.2-1B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = model.to(device)

In [None]:
# Confirm where the weights actually live by inspecting the first parameter.
first_param_device = next(iter(model.parameters())).device
device = first_param_device
print(f"The model is loaded on: {device}")

In [None]:
from transformers import StoppingCriteria, StoppingCriteriaList, AutoModelForCausalLM, AutoTokenizer

class StopOnRUSTEND(StoppingCriteria):
    """Stop generation once ``stop_sequence`` has appeared often enough.

    At every step the first sequence of the batch is decoded in full (prompt
    included, special tokens skipped). Generation stops when the marker occurs
    at least ``max_occurrences`` times. Copies of the marker that are already in
    the prompt are counted too, so ``max_occurrences`` must include them.

    Returns a plain ``bool``.
    """

    def __init__(self, stop_sequence: str, tokenizer, max_occurrences: int = 4):
        self.tokenizer = tokenizer
        self.stop_sequence = stop_sequence
        self.max_occurrences = max_occurrences
        # Occurrence count observed at the most recent call.
        self.current_count = 0

    def _count_markers(self, token_ids) -> int:
        """Decode ``token_ids`` and count how often the stop marker appears."""
        text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        return text.count(self.stop_sequence)

    def __call__(self, input_ids, scores, **kwargs):
        # Only the first sequence is inspected.
        self.current_count = self._count_markers(input_ids[0])
        should_stop = self.current_count >= self.max_occurrences
        return should_stop

In [None]:
class RustCodeGenerator:
    """Generate Rust solutions for a JSONL task file with a causal LM.

    Each input line is a JSON object; ``task_id``, ``instruction`` and ``test``
    are read from it. For every processed row the model is prompted with a
    one-shot example. The code it writes before ``RUST_END`` is stored together
    with the task metadata in ``output_filename`` as a JSON list. The file is
    rewritten after every row, so progress survives an interrupted run.
    """

    # One-shot demonstration prepended to every prompt. Every piece ends with a
    # newline so the markers and the code lines stay on separate lines.
    ONE_SHOT = (
        "You are an expert Rust programmer. Write a Rust function `add_two_numbers(a: i32, b: i32) -> i32` that returns the sum of two integers.\n"
        "In your response, use RUST_BEGIN and RUST_END to delimit the rust function.\n"
        "RUST_BEGIN\n"
        "fn add_two_numbers(a: i32, b: i32) -> i32 {\n"
        "    a + b\n"
        "}\n"
        "RUST_END\n"
        "\n"
    )
    BEGIN_MARKER = "RUST_BEGIN"
    END_MARKER = "RUST_END"
    # Attempt 0 is greedy; the remaining attempts sample.
    MAX_ATTEMPTS = 10

    def __init__(self, input_filename, output_filename, tokenizer, model, num_samples=164, total_samples=164, start_idx=0, max_new_tokens=300):
        """
        Args:
            input_filename: JSONL file with one task per line.
            output_filename: JSON file the generated records are written to.
            tokenizer: tokenizer used to encode prompts and decode generations.
            model: model whose ``generate`` method and ``device`` are used.
            num_samples: number of rows to process (capped at ``total_samples``).
            total_samples: number of leading input lines that are considered.
            start_idx: index of the first input line to process.
            max_new_tokens: generation budget per attempt.
        """
        self.input_filename = input_filename
        self.output_filename = output_filename
        self.tokenizer = tokenizer
        self.model = model
        self.num_samples = min(num_samples, total_samples)
        self.total_samples = total_samples
        self.start_idx = start_idx
        self.max_new_tokens = max_new_tokens

    def print_progress_bar(self, current, total, bar_length=50):
        """Draw an in-place text progress bar on stdout.

        A non-positive ``total`` is drawn as a full bar instead of raising
        ZeroDivisionError.
        """
        if total > 0:
            percent = (current / total) * 100
            filled_length = int(bar_length * current // total)
        else:
            percent = 100.0
            filled_length = bar_length
        bar = '#' * filled_length + '-' * (bar_length - filled_length)

        # '\r' returns to the start of the line so the bar is redrawn in place.
        sys.stdout.write(f'\rProgress: |{bar}| {percent:.2f}% ({current}/{total})')
        sys.stdout.flush()

    def _build_prompt(self, instruction):
        """Return the one-shot example followed by the task prompt.

        The prompt ends with an opening ``RUST_BEGIN`` line, so the model's
        continuation is expected to be the function body followed by ``RUST_END``.
        """
        return (
            self.ONE_SHOT
            + f"You are an expert Rust programmer. {instruction}\n"
            + "In your response, use RUST_BEGIN and RUST_END to delimit the rust function.\n"
            + f"{self.BEGIN_MARKER}\n"
        )

    @classmethod
    def _extract_code(cls, generated_text):
        """Return the code written before the first ``RUST_END`` in ``generated_text``.

        Returns "" when the end marker never appears (e.g. the token budget ran
        out). If the model re-opened the block, only the text after the last
        ``RUST_BEGIN`` is kept.
        """
        head, sep, _ = generated_text.partition(cls.END_MARKER)
        if not sep:
            return ""
        return head.rsplit(cls.BEGIN_MARKER, 1)[-1].strip()

    def process_line(self, line, row):
        """Generate code for one task row and return the record to save.

        Args:
            line: raw JSON text of the row (unused; kept for interface compatibility).
            row: parsed row; ``task_id``, ``instruction`` and ``test`` are read.

        Returns:
            dict with ``task_id``, ``instruction``, ``generated_code``, ``test`` and
            ``attempt``. ``attempt`` is the 0-based index of the successful attempt,
            or ``MAX_ATTEMPTS`` if every attempt produced no code.
        """
        prompt = self._build_prompt(row["instruction"])

        # Move inputs to wherever the model lives rather than relying on a
        # notebook-global `device` variable.
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        # The criterion counts markers over prompt + completion. Derive the
        # threshold from the prompt so generation stops at the first RUST_END
        # the model itself produces, whatever the prompt wording.
        stopping_criteria = StoppingCriteriaList([
            StopOnRUSTEND(
                self.END_MARKER,
                self.tokenizer,
                max_occurrences=prompt.count(self.END_MARKER) + 1,
            )
        ])

        attempt = 0
        rust_code = ""
        while attempt < self.MAX_ATTEMPTS:
            generation_kwargs = {
                "num_return_sequences": 1,
                "max_new_tokens": self.max_new_tokens,
                "stopping_criteria": stopping_criteria,
                "return_dict_in_generate": True,  # gives access to `.sequences`
            }
            if attempt == 0:
                # First attempt: deterministic decoding; sampling params explicitly cleared.
                generation_kwargs.update(do_sample=False, temperature=None, top_p=None)
            else:
                # Retries: low-temperature nucleus sampling to get a different answer.
                generation_kwargs.update(do_sample=True, temperature=0.3, top_p=0.9)

            outputs = self.model.generate(**inputs, **generation_kwargs)

            # Decode only the newly generated tokens so the markers inside the
            # prompt can never be mistaken for the answer.
            new_text = self.tokenizer.decode(
                outputs.sequences[0, prompt_length:], skip_special_tokens=True
            )
            rust_code = self._extract_code(new_text)

            if rust_code:
                print(f"Attempt {attempt}: Generated code length is {len(rust_code)}")
                break

            print(f"Attempt {attempt}: Generated code length is 0. Retrying...")
            attempt += 1

        return {
            'task_id': row["task_id"],
            'instruction': row["instruction"],
            'generated_code': rust_code if rust_code else "No valid code generated.",
            'test': row["test"],
            'attempt': attempt
        }

    def save_output(self, generated_codes):
        """Overwrite the output file with the full list of records."""
        with open(self.output_filename, "w") as file:
            json.dump(generated_codes, file, indent=0)

    def process(self):
        """Generate code for rows ``start_idx``.. of the input file.

        Rows are processed until ``num_samples`` have been handled, the first
        ``total_samples`` lines are exhausted, or the file ends. Records already
        in the output file are kept and new ones are appended. Blank lines are
        skipped.
        """
        generated_codes = []

        # Resume from an existing, non-empty output file.
        if os.path.exists(self.output_filename) and os.path.getsize(self.output_filename) > 0:
            with open(self.output_filename, "r") as file:
                generated_codes = json.load(file)

        processed = 0
        with open(self.input_filename, "r") as f:
            for i, raw_line in enumerate(f):
                if i >= self.total_samples or processed >= self.num_samples:
                    break
                if i < self.start_idx:
                    continue
                line = raw_line.strip()
                if not line:
                    continue

                self.print_progress_bar(processed, self.num_samples)
                row = json.loads(line)
                generated_codes.append(self.process_line(line, row))

                # Save after every row so an interrupted run loses at most one row.
                self.save_output(generated_codes)
                processed += 1

        self.print_progress_bar(processed, self.num_samples)
        print()  # move off the in-place progress line

    def print_example_code(self, idx=0):
        """Print the task id, instruction and generated code of record ``idx``.

        A missing or empty output file, an empty record list, or an
        out-of-range index is reported with a message instead of raising.
        """
        if not (os.path.exists(self.output_filename) and os.path.getsize(self.output_filename) > 0):
            print("Output file does not exist or is empty.")
            return

        with open(self.output_filename, "r") as file:
            generated_codes = json.load(file)

        if not generated_codes:
            print("No data available in the output file.")
            return

        try:
            generated_code_data = generated_codes[idx]
        except IndexError:
            print("Index out of range.")
            return

        task_id = generated_code_data.get("task_id", "No task available")
        prompt = generated_code_data.get("instruction", "No prompt available")
        code = generated_code_data.get("generated_code", "No code available")

        # f-string so a non-string task id does not raise TypeError.
        print(f"Task {task_id}:")
        print("Example Prompt:")
        print(prompt)
        print("\nGenerated Code:")
        print(code)

In [None]:
input_filename = "humanevalpack.jsonl"
output_filename = "output.json"

# process() appends to an existing output file, so start from a clean slate.
if os.path.exists(output_filename):
    os.remove(output_filename)

generator = RustCodeGenerator(
    input_filename=input_filename,
    output_filename=output_filename,
    tokenizer=tokenizer,
    model=model,
    max_new_tokens=1024,  # pass num_samples=... / start_idx=... to run only a slice
)
generator.process()

In [None]:
# Spot-check the first five generated records.
for i in range(5):
    generator.print_example_code(idx=i)
    print("======================= NEXT ===========================")

In [None]:
generator.print_example_code(idx=89)  # inspect one specific record by index