In [1]:
import json
import random
import torch
from torch import Tensor
from torch.utils.data import Dataset

from transformers import AutoTokenizer
!pip install trl==0.11.4
from trl import AutoModelForCausalLMWithValueHead


class GeneratorDataset(Dataset):
    """
    Custom Dataset class (inherits from torch.utils.data.Dataset) for the generator.

    Stores the 'prompts', 'code' and 'test_list' fields of the mbpp dataset as token-id
    tensors built with ``torch.Tensor`` (i.e. the default float dtype), one row per example,
    each padded to the length given for that field in ``padding_length``.
    """

    def __init__(self, data: dict, tokenizer, padding_length):
        """
        :param data: dict with lists of strings under 'prompts', 'code' and 'test_list'.
        :param tokenizer: callable tokenizer returning a mapping with 'input_ids'.
        :param padding_length: dict with the max_length for 'prompts', 'code' and 'test_list'.
        """
        self.tokenizer = tokenizer
        self.padding_length = padding_length
        self.prompts = torch.Tensor(
            tokenizer(data['prompts'], padding='max_length', max_length=padding_length['prompts'])['input_ids'])
        self.code = torch.Tensor(
            tokenizer(data['code'], padding='max_length', max_length=padding_length['code'])['input_ids'])
        self.test_list = torch.Tensor(
            tokenizer(data['test_list'], padding='max_length', max_length=padding_length['test_list'])['input_ids'])

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return the token-id rows of example ``idx`` for all three fields."""
        return {'prompts': self.prompts[idx],
                'code': self.code[idx],
                'test_list': self.test_list[idx]}

    def select_for_discriminator(self, indices, tokenizer, padding_length, ground_truth):
        """
        Build a DiscriminatorDataset from the 'code' of the examples at ``indices``.

        The selected rows are decoded back to text (special tokens skipped) and handed to
        DiscriminatorDataset, which re-tokenizes them; every example gets the label
        ``ground_truth``.

        Decoding works on the selected rows only — ``self.code`` is never replaced, so
        ``__getitem__`` keeps returning tensors and the method can be called repeatedly.

        :raises TypeError: if ``self.code`` is neither a tensor nor a list of strings.
        """
        indices = list(indices)  # any iterable is accepted; it is traversed twice below
        if isinstance(self.code, torch.Tensor):
            selected_rows = [self.code[i].int() for i in indices]
            code = tokenizer.batch_decode(selected_rows, skip_special_tokens=True)
        elif isinstance(self.code, list):
            # Already-decoded text (e.g. assigned from outside); use it as-is.
            code = [self.code[i] for i in indices]
        else:
            raise TypeError("Error with the structure of the code: "
                            f"unexpected type {type(self.code).__name__}.")
        return DiscriminatorDataset({"code": code,
                                     "ground_truth": [ground_truth for _ in indices]},
                                    tokenizer=tokenizer, padding_length=padding_length)


class DiscriminatorDataset(Dataset):
    """
    Custom Dataset class (inherits from torch.utils.data.Dataset) for the discriminator.

    Consists of 'code' and a 'ground_truth' label ("1" for real data, "0" for generated
    data). Both are stored as token-id tensors built with ``torch.Tensor`` and padded to
    ``padding_length['code']`` — the label is tokenized like text, not stored as a scalar.
    """

    def __init__(self, data: dict, tokenizer, padding_length):
        """
        :param data: dict with equally long lists of strings under 'code' and 'ground_truth'.
        :param tokenizer: callable tokenizer returning a mapping with 'input_ids'.
        :param padding_length: dict; only the 'code' entry is used here.
        """
        self.tokenizer = tokenizer
        self.padding_length = padding_length
        self.code = torch.Tensor(tokenizer(data['code'], padding='max_length',
                                           max_length=padding_length['code'])['input_ids'])
        self.ground_truth = torch.Tensor(tokenizer(data['ground_truth'], padding='max_length',
                                                   max_length=padding_length['code'])['input_ids'])

    def __len__(self):
        return len(self.code)

    def __getitem__(self, idx):
        return {'code': self.code[idx],
                'ground_truth': self.ground_truth[idx]}

    @staticmethod
    def _as_text(tokenizer, token_rows):
        """Decode a 2-D tensor of token ids to one string per row (lists pass through)."""
        if isinstance(token_rows, torch.Tensor):
            return tokenizer.batch_decode(list(torch.unbind(token_rows.int(), dim=0)))
        return list(token_rows)

    def __add__(self, other):
        """
        Concatenate two datasets (``self`` first) into a new one, re-tokenized with
        ``self.tokenizer`` / ``self.padding_length``. Neither operand is modified, so both
        stay usable (and can be added again).
        """
        code = self._as_text(self.tokenizer, self.code) + self._as_text(other.tokenizer, other.code)
        ground_truth = (self._as_text(self.tokenizer, self.ground_truth)
                        + self._as_text(other.tokenizer, other.ground_truth))
        return DiscriminatorDataset({"code": code, "ground_truth": ground_truth},
                                    self.tokenizer, self.padding_length)

    @staticmethod
    def interleave(dataset_generated, dataset_groundtruth, num_samples, tokenizer, padding_length):
        """
        Draw ``num_samples`` random examples (without replacement) from the union of an
        equally sized generated and ground-truth dataset and return them as a new dataset.

        :raises ValueError: if the two datasets differ in length, or (from ``random.sample``)
            if ``num_samples`` exceeds their combined size.
        """
        if len(dataset_generated) != len(dataset_groundtruth):
            raise ValueError(
                f"interleave needs equally sized datasets, got {len(dataset_generated)} generated "
                f"and {len(dataset_groundtruth)} ground-truth samples.")
        sample_ids = random.sample(range(len(dataset_groundtruth) + len(dataset_generated)), num_samples)
        combined = dataset_generated + dataset_groundtruth

        code = DiscriminatorDataset._as_text(tokenizer, combined.code)
        ground_truth = DiscriminatorDataset._as_text(tokenizer, combined.ground_truth)

        return DiscriminatorDataset({"code": [code[i] for i in sample_ids],
                                     "ground_truth": [ground_truth[i] for i in sample_ids]},
                                    tokenizer=tokenizer, padding_length=padding_length)


def get_index_length_datasets(train: bool, code_parrot: bool, mbpp: bool) -> dict:
    """
    Return the per-field ``max_length`` used to pad a tokenized dataset.

    The numbers depend on the dataset (mbpp, or humaneval when ``mbpp`` is False), the
    split (``train``) and the tokenizer (codeparrot/codeparrot-small when ``code_parrot``,
    otherwise Qwen/Qwen2.5-0.5B). Presumably they are the longest tokenized entries of
    each field — TODO confirm before changing datasets or tokenizers.

    :return: a new dict with keys 'prompts', 'code' and 'test_list'.
    :raises ValueError: for a combination without known lengths (humaneval is only
        configured for evaluation with the codeparrot tokenizer).
    """
    lengths = {
        # key: (mbpp, code_parrot, train)
        (True, True, True): {'prompts': 49, 'code': 252, 'test_list': 302},   # mbpp, codeparrot/codeparrot-small
        (True, True, False): {'prompts': 51, 'code': 0, 'test_list': 2248},
        (True, False, True): {'prompts': 50, 'code': 289, 'test_list': 350},  # mbpp, Qwen/Qwen2.5-0.5B
        (True, False, False): {'prompts': 47, 'code': 0, 'test_list': 3670},
        (False, True, False): {'prompts': 391, 'code': 0, 'test_list': 235},  # humaneval, codeparrot/codeparrot-small
    }
    key = (bool(mbpp), bool(code_parrot), bool(train))
    if key not in lengths:
        raise ValueError(f"No padding lengths defined for train={train}, "
                         f"code_parrot={code_parrot}, mbpp={mbpp}.")
    # Copy so callers may modify the result without affecting later calls.
    return dict(lengths[key])


def prepare_data(generator: AutoModelForCausalLMWithValueHead,
                 ground_dataset: GeneratorDataset, num_samples: int,
                 gen_args: dict, tokenizer: AutoTokenizer, train: bool,
                 code_parrot: bool) -> DiscriminatorDataset:
    """
    Build a fresh DiscriminatorDataset for one discriminator epoch.

    Generates samples with the current generator (label "0"), draws ``num_samples``
    examples from ``ground_dataset`` (label "1") and randomly picks ``2 * num_samples``
    of the union via ``DiscriminatorDataset.interleave``. This assumes ``generate`` returns
    exactly ``num_samples`` sequences; otherwise ``interleave`` rejects the length mismatch.

    :param gen_args: keyword arguments for ``generator.generate``. A copy is used with
        ``num_return_sequences`` set to ``num_samples``. The caller's dict is not modified,
        so kwargs shared with other generation calls keep their own settings.
    """
    generation_args = {**gen_args, 'num_return_sequences': num_samples}

    padding_length = get_index_length_datasets(train, code_parrot, mbpp=True)

    generated_samples = sample_from_generation(generator, generation_args, tokenizer, padding_length)
    ground_truth_samples = sample_from_dataset(ground_dataset, num_samples, tokenizer, padding_length)
    return DiscriminatorDataset.interleave(generated_samples, ground_truth_samples, num_samples * 2,
                                           tokenizer, padding_length)


def sample_from_generation(generator: AutoModelForCausalLMWithValueHead, gen_args: dict, tokenizer: AutoTokenizer,
                           padding_length: dict) -> DiscriminatorDataset:
    """
    Produce fake samples (label "0") with the current generator.

    ``generator.generate(**gen_args)`` is called without gradient tracking. How many
    sequences come back is controlled by ``gen_args`` (e.g. ``num_return_sequences``).
    The outputs are decoded with special tokens removed.
    """
    with torch.no_grad():
        output_ids = generator.generate(**gen_args)

    texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    labelled = {"code": texts, "ground_truth": ["0"] * len(texts)}
    return DiscriminatorDataset(labelled, tokenizer=tokenizer, padding_length=padding_length)


def sample_from_dataset(dataset: GeneratorDataset, num_samples: int, tokenizer: AutoTokenizer,
                        padding_length: dict) -> DiscriminatorDataset:
    """
    Draw ``num_samples`` distinct random examples from ``dataset`` as real samples.

    :return: the dataset built by ``dataset.select_for_discriminator`` with every example
        labelled "1".
    :raises ValueError: (from ``random.sample``) if ``num_samples`` exceeds ``len(dataset)``.
    """
    chosen_rows = random.sample(range(len(dataset)), num_samples)
    return dataset.select_for_discriminator(chosen_rows, tokenizer, padding_length, ground_truth="1")


def load_gen_data(file_path: str, tokenizer: AutoTokenizer, train: bool, code_parrot: bool,
                  mbpp: bool) -> GeneratorDataset:
    """
    Load a UTF-8 JSON file and turn it into a GeneratorDataset.

    The JSON object must provide 'test_list' (and 'prompts', read by GeneratorDataset).
    Every 'test_list' entry is converted to its ``str`` form. 'code' is set to one empty
    string per example.

    NOTE(review): any 'code' already present in the file is discarded. If this dataset
    is used as the ground-truth source for the discriminator, its "real" samples are
    therefore empty strings — confirm this is intended.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        raw = json.load(handle)

    n_examples = len(raw['test_list'])
    raw['code'] = [""] * n_examples
    raw['test_list'] = list(map(str, raw['test_list']))

    lengths = get_index_length_datasets(train, code_parrot=code_parrot, mbpp=mbpp)
    return GeneratorDataset(raw, tokenizer=tokenizer, padding_length=lengths)



In [2]:
import ast
import importlib.util
import re
from io import StringIO
import sys
from typing import List
import torch
from torch import Tensor
from torch.nn.functional import softmax
from transformers import AutoTokenizer
!pip install pycodestyle
from pycodestyle import Checker


def get_rewards(generation_tensor: List[torch.Tensor],
                generation_text: List[str],
                device,
                discriminator: torch.nn.Module,
                avg_rewards_during_batches: list,
                tokenizer: AutoTokenizer,
                padding_length: dict,
                disc_weight: float,
                temp_file_path: List[str],
                test_list: List[str]) -> tuple[torch.Tensor, List[float]]:
    """
    Collect the generator reward: a mix of discriminator reward and objective reward.

    ``rewards = disc_weight * disc_reward + (1 - disc_weight) * objective_reward``.
    ``disc_reward`` is the discriminator's softmax probability for class 1 ("real") mapped
    from [0, 1] to [-1, 1]; the objective reward comes from ``collect_rewards``. With
    ``disc_weight == 1`` the objective reward is not computed; with ``disc_weight == 0``
    the discriminator is not run.

    The discriminator forward pass runs under ``torch.no_grad()``. Rewards are training
    signals for PPO and must not carry an autograd graph back into the discriminator.

    :param device: device the discriminator and its input are moved to.
    :param padding_length: needs key 'code' (token length the texts are padded to).
    :param disc_weight: mixing weight (1 = discriminator only, 0 = objective only).
    :return: (rewards, avg_rewards). NOTE(review): for ``disc_weight == 1``, ``avg_rewards``
        is the list of per-sample discriminator rewards. Otherwise it is
        ``avg_rewards_during_batches`` as returned by ``collect_rewards``.
    """
    discriminator.to(device)

    disc_reward = None
    if disc_weight != 0:
        token_ids = tokenizer(generation_text, padding='max_length',
                              max_length=padding_length['code'])['input_ids']
        samples = torch.tensor(token_ids, dtype=torch.int32, device=device)
        with torch.no_grad():
            pred = discriminator.forward(samples)
        # The more certain the discriminator is that a sample is real (class 1), the higher
        # the reward; mapped from [0, 1] to [-1, 1].
        disc_reward = softmax(pred, dim=-1)[:, 1] * 2 - 1
        if disc_weight == 1:
            return disc_reward, disc_reward.tolist()

    obj_rewards, avg_rewards = collect_rewards(generation_tensor,
                                               generation_text,
                                               discount=1,
                                               avg_rewards_during_batches=avg_rewards_during_batches,
                                               temp_file_path=temp_file_path,
                                               test_list=test_list)
    if disc_reward is None:
        return obj_rewards.to(device), avg_rewards

    obj_rewards = obj_rewards.to(disc_reward.device)
    return disc_weight * disc_reward + (1 - disc_weight) * obj_rewards, avg_rewards


def collect_rewards(generation_tensor: List[torch.Tensor],
                    generation_text: List[str],
                    avg_rewards_during_batches: List[float],
                    temp_file_path: List[str],
                    test_list: List[str],
                    discount: float = 1.0) -> tuple[torch.Tensor, List[float]]:
    """
    Collect the objective reward (currently only the pep8 score) for each generated file.

    For the k-th path in ``temp_file_path``:
    - the pep8 reward is computed;
    - the pycodestyle error codes and the reward are appended to the file as comments;
    - ``reward * discount ** k`` is stored.
    Files that raise UnicodeDecodeError get -1 and are left out of the batch average.

    ``generation_tensor`` and ``test_list`` are currently unused. They are kept for the
    disabled length-penalty / test-list rewards.

    :return: (tensor with one entry per item of ``generation_text`` — entries without a
        file stay 0 — and ``avg_rewards_during_batches``, with this batch's mean pep8
        reward appended when at least one file was scored).
    """
    collected = torch.zeros(len(generation_text))
    batch_rewards = []
    for k, sample_path in enumerate(temp_file_path):
        try:
            pep08_reward, error_codes = pep08(sample_path)

            with open(sample_path, "a", encoding="utf-8") as temp_file:
                temp_file.write("\n\n# Errorcodes: " + str(error_codes) + "\n# pep08_reward: " + str(pep08_reward))

            batch_rewards.append(pep08_reward)
            # Plain scalar arithmetic; the value is converted when stored in the float tensor.
            collected[k] = pep08_reward * discount ** k
        except UnicodeDecodeError:
            collected[k] = -1

    if not batch_rewards:
        print("No rewards collected")
    else:
        avg_rewards_during_batches.append(sum(batch_rewards) / len(batch_rewards))
    return collected, avg_rewards_during_batches


def compilable(temp_file_path: str) -> int:
    """
    Can the code be compiled by the standard python compiler.

    If ``temp_file_path`` names an existing file, that file's contents are compiled.
    Otherwise the string itself is treated as source code, so callers passing raw code
    keep working.

    :return: 100 if the source compiles; -5 on a syntax error, on null bytes, or if the
        file cannot be read as UTF-8.
    """
    import os  # local import keeps this cell self-contained

    if os.path.isfile(temp_file_path):
        try:
            with open(temp_file_path, "r", encoding="utf-8") as source_file:
                source = source_file.read()
        except (OSError, UnicodeDecodeError):
            return -5
    else:
        source = temp_file_path

    try:
        compile(source, "<string>", "exec")
    except (SyntaxError, ValueError):
        return -5
    return 100


class Capturing(list):
    """
    Context manager that redirects ``sys.stdout`` while active (used around pycodestyle).

    On exit, stdout is restored and the captured text is stored in ``pep08_output``. Every
    pep8-style error code found in that text (E/W/F followed by three digits, e.g. "E501")
    is appended to the list itself.
    """

    # Standalone codes like E123, W456, F789.
    _ERROR_CODE_RE = re.compile(r'\b[EWF]\d{3}\b')

    def __enter__(self):
        self._stringio = StringIO()
        self._stdout, sys.stdout = sys.stdout, self._stringio
        return self

    def __exit__(self, *args):
        sys.stdout = self._stdout
        captured_text = self._stringio.getvalue()
        del self._stringio
        self.pep08_output = captured_text
        self.extend(self._ERROR_CODE_RE.findall(captured_text))


def pep08(temp_file_path: str) -> tuple[float, List[str]]:
    """
    How much does the code in ``temp_file_path`` adhere to pep8 (checked with pycodestyle).

    Reward by number of reported errors: 0 -> 1, 1-3 -> 0.5, 4-5 -> 0.25, 6-10 -> -0.5,
    11-15 -> -0.75, more -> -1. If pycodestyle raises, the count falls back to 15
    (reward -0.75).

    :return: (reward, error codes such as "E501" captured from pycodestyle's output).
    """
    num_errors = 15  # fallback when pycodestyle itself fails
    with Capturing() as output:
        try:
            checker = Checker(temp_file_path, show_source=False, show_pep8=True)
            num_errors = checker.check_all()
        except Exception as e:
            print(f"An error occurred while checking the code: {e}")

    # Chained comparisons on purpose: `n <= 3 & n > 0` would parse as `n <= (3 & n) > 0`
    # because `&` binds tighter than comparison operators.
    if num_errors == 0:
        reward = 1
    elif 0 < num_errors <= 3:
        reward = 0.5
    elif 3 < num_errors <= 5:
        reward = 0.25
    elif 5 < num_errors <= 10:
        reward = -0.5
    elif 10 < num_errors <= 15:
        reward = -0.75
    else:
        reward = -1

    return reward, output


def try_test_list(temp_file_path: str, test_list: str) -> float:
    """
    Load ``./temp_files/<temp_file_path>`` as a module and run the mbpp ``test_list``
    (the string repr of a list of ``assert`` statements, or the list itself) against it.

    The tests run in the module's own namespace, so they can call the functions the file
    defines.

    SECURITY NOTE(review): loading the module and running the tests both execute generated
    code without a sandbox or timeout.

    :return: 1.0 if all tests pass; 0.2 if an assertion fails; -0.25 if running the tests
        raises anything else (e.g. NameError because the expected function is missing);
        -0.5 if the file defines nothing callable; -1.0 if the file cannot be loaded.
    """
    spec = importlib.util.spec_from_file_location(temp_file_path, "./temp_files/" + temp_file_path)
    if spec is None or spec.loader is None:
        return -1.0  # not loadable as a Python source file
    module = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(module)
    except Exception:
        return -1.0  # Error loading the file

    functions = [name for name in dir(module) if callable(getattr(module, name))]
    if not functions:
        return -0.5  # No functions found in the file.

    tests = ast.literal_eval(test_list) if isinstance(test_list, str) else test_list
    namespace = vars(module)
    try:
        for test in tests:
            exec(test, namespace)
    except AssertionError:
        return 0.2  # A test failed!
    except Exception:
        return -0.25  # Tests could not run, e.g. no matching function found.
    return 1.0  # All tests passed!

# ==================== Code Metrics =====================
# (from https://radon.readthedocs.io/en/latest/intro.html)




In [3]:
import ast
import os
import re
import torch
from pycodestyle import Checker
from transformers import AutoTokenizer
from trl import PPOTrainer
from collections import Counter


def log_error(message, verbose: bool) -> None:
    """
    Print ``message`` to stdout, but only when ``verbose`` is truthy; otherwise do nothing.
    """
    if not verbose:
        return
    print(message)


def loading_bar(current, total, width=50) -> None:
    """
    Draw an in-place text progress bar on stdout (used while generating test files).

    :param current: number of finished items.
    :param total: total number of items; ``total <= 0`` (nothing to do) is drawn as a full bar.
    :param width: number of characters between the brackets.
    """
    progress = current / total if total > 0 else 1.0
    # Clamp so an over- or under-shooting counter cannot draw a bar wider than ``width``.
    progress = min(max(progress, 0.0), 1.0)
    filled = int(progress * width)
    progress_bar = "[" + "=" * filled + " " * (width - filled) + "]"
    percent = progress * 100
    print(f"\r{progress_bar} {percent:.1f}%", end="", flush=True)


def mbpp_sample_generieren(idx: int, device: str, dataset_test: GeneratorDataset,
                           generator: PPOTrainer, tokenizer: AutoTokenizer,
                           generation_kwargs: dict, destination_dir_path: str) -> None:
    """
    Generate a completion for mbpp example ``idx``.

    The completion is written to ``<destination_dir_path>/sample<idx>.py`` (UTF-8). The
    prompt is cast to int64 before generation. Only the response is decoded
    (``return_prompt=False``), special tokens are skipped, and the first decoded text is
    saved. The destination directory is created if it does not exist.
    """
    query = dataset_test[idx]['prompts'].to(torch.int64).to(device)
    response_ids = generator.generate(query_tensor=query, return_prompt=False, **generation_kwargs)
    texts = tokenizer.batch_decode(response_ids, skip_special_tokens=True)

    if not os.path.exists(destination_dir_path):
        os.makedirs(destination_dir_path)

    target_path = f"{destination_dir_path}/sample{idx}.py"
    with open(target_path, "w", encoding='utf-8') as out_file:
        out_file.write(texts[0])


def eval_pep8(directory: str) -> tuple[float, str, float]:
    """
    Check all regular files in ``directory`` for pep8 errors with pycodestyle.

    Prints three things:
    - the most common error code;
    - the compliance rate (share of checked files with zero errors);
    - the average number of errors per checked file.
    Subdirectories are skipped. Files on which pycodestyle raises are reported and left
    out of all statistics.

    :return: (average errors per file, most common error code or "" if there were no
        errors, compliance rate in percent). Returns (-1.0, "", 0.0) if ``directory`` is
        not a folder or no file could be checked.
    """
    if not os.path.isdir(directory):
        print(f"Der Pfad {directory} ist kein gültiger Ordner.")
        return -1.0, "", 0.0

    num_error_list = []
    error_type_list = []
    for file_name in os.listdir(directory):
        file_path = os.path.join(directory, file_name)
        if not os.path.isfile(file_path):
            continue  # pycodestyle can only check files
        num_errors = None
        failure = None
        with Capturing() as output:
            try:
                checker = Checker(file_path, show_source=False, show_pep8=True)
                num_errors = checker.check_all()
            except Exception as e:
                failure = e
        if failure is not None:
            # Reported outside Capturing so the message reaches the console instead of
            # being swallowed (and regex-scanned) as pycodestyle output.
            print(f"An error occurred while checking the code: {failure}")
            continue
        num_error_list.append(num_errors)
        error_type_list.append(output)

    count = len(num_error_list)
    if count == 0:
        print(f"No files in {directory} could be checked.")
        return -1.0, "", 0.0

    zero_errors = sum(1 for n in num_error_list if n == 0)
    error_counts = Counter(error for file_errors in error_type_list for error in file_errors)
    if error_counts:
        most_common_error, frequency = error_counts.most_common(1)[0]
        print(f"The most common error is '{most_common_error}' with {frequency} occurrences.")
    else:
        most_common_error = ""
        print("No pep8 errors found.")

    compliance_rate = round((zero_errors / count) * 100, 2)
    average_errors = round(sum(num_error_list) / count, 2)
    print(f"The compliance rate (0 errors) is {compliance_rate}%.")
    print("Average pep8 errors found: ", average_errors)
    return average_errors, most_common_error, compliance_rate


def extract_functions_from_code(code, verbose=False):
    """
    Return the names of all ``def name(...):`` definitions in ``code``, in source order.

    Regex based, so it also works on code that does not parse. On error (e.g. ``code`` is
    not a string), the problem is logged via ``log_error`` and an empty list is returned.
    """
    try:
        # A single capture group, so findall yields just the function names.
        return re.findall(r"def\s+(\w+)\s*\([^)]*\)\s*:", code)
    except Exception as e:
        log_error(f"Error extracting functions from code: {e}", verbose)
        return []


def extract_function_code(code, function_name, verbose=False):
    """
    Extract the source of function ``function_name`` from ``code``.

    The extraction is regex based, so it also works on files that do not parse as a
    whole. The match is the ``def`` line plus every following line that is indented or
    blank; it stops at the first non-indented line. Trailing whitespace is stripped.

    :return: the function source, or None if no such ``def`` is found (or on error).
    """
    try:
        # re.escape: the name comes from a test case and may contain regex metacharacters.
        # No re.DOTALL — with it `.*` would run to the end of the file and pull in
        # everything after the function.
        pattern = (rf"\bdef\s+{re.escape(function_name)}\s*\([^)]*\)\s*:[^\n]*"
                   r"(?:\n(?:[ \t]+[^\n]*|[ \t]*(?=\n)))*")
        match = re.search(pattern, code)
        if match:
            return match.group(0).rstrip()
        return None
    except Exception as e:
        log_error(f"Error extracting function code for {function_name}: {e}", verbose)
        return None


def can_parse_ast(code):
    """
    Return True when ``ast.parse(code)`` succeeds, False otherwise.

    A SyntaxError yields False silently. Any other exception is printed before returning
    False.
    """
    try:
        ast.parse(code)
    except SyntaxError:
        return False
    except Exception as e:
        print(f"Error parsing code into AST: {e}")
        return False
    return True


def evaluate_test_case(test_case, global_namespace, verbose=False):
    """
    Evaluate one test case against ``global_namespace`` and report whether it holds.

    ``test_case`` is normally an ``assert <expr>[, <msg>]`` statement (mbpp style). Only
    its condition ``<expr>`` is evaluated, so an assertion message can never make a
    failing check look truthy. Any other input is evaluated as a plain expression.

    SECURITY NOTE: uses ``eval``; the test case and namespace must come from a trusted source.

    :return: True if the condition is truthy; False if it is falsy or evaluation raised
        (the error is logged via ``log_error``).
    """
    try:
        source = test_case.strip()
        module = ast.parse(source)
        statement = module.body[0] if len(module.body) == 1 else None
        if isinstance(statement, ast.Assert):
            expression = ast.Expression(body=statement.test)
        else:
            expression = ast.parse(source, mode="eval")
        compiled = compile(ast.fix_missing_locations(expression), "<test_case>", "eval")
        return bool(eval(compiled, global_namespace))
    except Exception as e:
        log_error(f"Error evaluating test case {test_case}: {e}", verbose)
        return False


def execute_code_and_check_tests(code_file, test_cases, verbose=False):
    """
    Execute the code in ``code_file`` and check whether all ``test_cases`` pass.

    If the whole file parses, it is executed as-is. Otherwise only the function named in
    the first test case is extracted (regex based) and executed. Each test case is then
    evaluated in the resulting namespace.

    SECURITY NOTE(review): ``exec`` runs model-generated code with this process's full
    privileges and without a timeout. An infinite loop hangs the evaluation, and arbitrary
    file or network access is possible, so run this only inside a sandbox.

    :param code_file: path of the generated code file (read as UTF-8).
    :param test_cases: list of ``assert ...`` strings, or the string repr of such a list.
    :return: True if every test case passes. False if no test cases are given, the file
        is empty or invalid, the expected function is missing, or anything raises.
    """
    try:
        if isinstance(test_cases, str):
            # A stringified list (e.g. "['assert f(1) == 2']") would otherwise be
            # iterated character by character.
            test_cases = ast.literal_eval(test_cases)
        if not test_cases:
            log_error(f"Error: No test cases given for {code_file}.", verbose)
            return False

        with open(code_file, 'r', encoding='utf-8') as file:
            code = file.read()

        # Skip if the file is empty
        if not code.strip():
            log_error(f"Error: {code_file} is empty.", verbose)
            return False

        # Determine the expected function name from the first test case
        expected_function_name = test_cases[0].split("(")[0].replace("assert ", "")

        # Check if the entire file can be parsed into an AST
        if can_parse_ast(code):
            # Use the entire code for testing
            global_namespace = {}
            try:
                exec(code, global_namespace)
            except SyntaxError as e:
                log_error(f"Skipping file due to SyntaxError: {e}", verbose)
                return False
            except Exception as e:
                log_error(f"Error executing code: {e}", verbose)
                return False
        else:
            # Extract all function names from the code
            function_names = extract_functions_from_code(code, verbose)
            if not function_names:
                log_error(f"Error: No functions found in {code_file}.", verbose)
                return False

            # Check if the expected function exists in the code
            if expected_function_name not in function_names:
                log_error(f"Error: Function {expected_function_name} not found in {code_file}.", verbose)
                return False

            # Extract the code for the expected function
            function_code = extract_function_code(code, expected_function_name, verbose)
            if not function_code:
                log_error(f"Error: Could not extract code for function {expected_function_name}.", verbose)
                return False

            # Execute only the extracted function code
            global_namespace = {}
            try:
                exec(function_code, global_namespace)
            except SyntaxError as e:
                log_error(f"Skipping function due to SyntaxError: {e}", verbose)
                return False
            except Exception as e:
                log_error(f"Error executing function {expected_function_name}: {e}", verbose)
                return False

        # Check if all test cases pass with this function
        for test_case in test_cases:
            if not evaluate_test_case(test_case, global_namespace, verbose):
                return False

        return True
    except Exception as e:
        log_error(f"Error executing {code_file}: {e}", verbose)
        return False


def calculate_pass_at_1_rate(test_list, code_directory, verbose=False):
    """
    Compute the pass@1 rate in percent over the generated files
    ``<code_directory>/sample<idx>.py``.

    Each file is checked against ``test_list[idx]``. Indices without a file are ignored.
    Returns 0.0 (without printing) when no file exists; otherwise the rate is printed,
    rounded to two decimals, and returned unrounded.
    """
    outcomes = []
    for idx, test_cases in enumerate(test_list):
        code_file = os.path.join(code_directory, f"sample{idx}.py")
        if not os.path.exists(code_file):
            continue
        outcomes.append(execute_code_and_check_tests(code_file, test_cases, verbose))

    total = len(outcomes)
    if total == 0:
        return 0.0

    passed = sum(1 for ok in outcomes if ok)
    print("pass@1 rate: " + str(round((passed / total) * 100, 2)) + "%")
    return (passed / total) * 100


In [4]:


import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Iterable


class CNNDiscriminator(nn.Module):
    def __init__(self, embed_dim: int, vocab_size: int,
                 filter_sizes: Iterable[int], num_filters: Iterable[int],
                 padding_idx: int, gpu: bool = False,
                 dropout=0.2, dis_init="uniform") -> None:
        super(CNNDiscriminator, self).__init__()
        self.embedding_dim = embed_dim
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx
        self.feature_dim = sum(num_filters)
        self.gpu = gpu

        self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.convs = nn.ModuleList([
            nn.Conv2d(1, n, (f, embed_dim)) for (n, f) in zip(num_filters, filter_sizes)
        ])
        self.highway = nn.Linear(self.feature_dim, self.feature_dim)
        self.feature2out = nn.Linear(self.feature_dim, 2)
        self.dropout = nn.Dropout(dropout)

        self.init_params(dis_init)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Get final predictions of discriminator
        :param inp: batch_size * seq_len
        :return: pred: batch_size * 2
        """
        feature = self.get_feature(inp)
        pred = self.feature2out(self.dropout(feature))

        return pred

    def get_feature(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Get feature vector of given sentences
        :param inp: batch_size * max_seq_len
        :return: batch_size * feature_dim
        """
        emb = self.embeddings(inp).unsqueeze(1)  # batch_size * 1 * max_seq_len * embed_dim
        convs = [F.relu(conv(emb)).squeeze(3) for conv in self.convs]  # [batch_size * num_filter * length]
        pools = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs]  # [batch_size * num_filter]
        pred = torch.cat(pools, 1)  # tensor: batch_size * feature_dim
        highway = self.highway(pred)
        pred = torch.sigmoid(highway) * F.relu(highway) + (1. - torch.sigmoid(highway)) * pred  # highway

        return pred

    def init_params(self, dis_init: str):
        for param in self.parameters():
            if param.requires_grad and len(param.shape) > 0:
                stddev = 1 / math.sqrt(param.shape[0])
                if dis_init == 'uniform':
                    torch.nn.init.uniform_(param, a=-0.05, b=0.05)
                elif dis_init == 'normal':
                    torch.nn.init.normal_(param, std=stddev)
                elif dis_init == 'truncated_normal':
                    truncated_normal_(param, std=stddev)


def truncated_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> torch.Tensor:
    """
    Implemented by @ruotianluo
    See https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
    """
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor


In [5]:
import ast
import codecs
import json
import os.path
import shutil
import gc
import tempfile
from statistics import mean

import pandas as pd
from torch import Tensor
from argparse import Namespace
import torch
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

import warnings

warnings.filterwarnings("ignore")


class GANTrainer:
    """
    Adversarial fine-tuning of a causal code language model against a CNN discriminator.

    The generator (``AutoModelForCausalLMWithValueHead``) is optimised with TRL's
    ``PPOTrainer``. Per-sample rewards come from ``get_rewards``, which receives
    the discriminator, ``cfg.disc_weight``, a temp file holding generated code and
    the MBPP test list. The discriminator is trained with cross-entropy on data
    built by ``prepare_data`` from the latest generator checkpoint.

    Checkpoint layout under ``cfg.gen_dir``:
      * ``<gen_dir>/``                - base model saved by ``__init__`` (unless cfg.load_generator)
      * ``<gen_dir>/generator/``      - generator checkpoint used to sample discriminator data
      * ``<gen_dir>/finished_model/`` - final generator, discriminator and metrics (cfg.save_RL)
    """

    def __init__(self, cfg: Namespace) -> None:
        """
        Build tokenizer, generator, discriminator, PPO trainer and datasets.

        :param cfg: configuration namespace (see ``load_config``). Fields read here:
            gpu, base_model, load_generator, load_generator_path, gen_dir, embed_dim,
            filter_sizes, num_filters, dropout, load_discriminator,
            load_discriminator_file, batch_size, max_new_tokens, disc_lr, data_dir.
        """
        self.cfg = cfg
        self.device = "cuda" if self.cfg.gpu else "cpu"
        print("device: ", self.device)
        print(" >> init tokenizer")
        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.base_model, padding_side="left")
        # Reuse EOS as the padding token (presumably the base tokenizer defines none).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print(" >> init generator")
        if self.cfg.load_generator:
            self.generator = AutoModelForCausalLMWithValueHead.from_pretrained(self.cfg.load_generator_path)
        else:
            self.generator = AutoModelForCausalLMWithValueHead.from_pretrained(self.cfg.base_model)
            self.generator.save_pretrained(self.cfg.gen_dir)
        self.code_parrot = self.cfg.base_model == "codeparrot/codeparrot-small"

        print(" >> init discriminator")
        self.discriminator = CNNDiscriminator(embed_dim=self.cfg.embed_dim, vocab_size=len(self.tokenizer),
                                              filter_sizes=self.cfg.filter_sizes, num_filters=self.cfg.num_filters,
                                              padding_idx=self.tokenizer.pad_token_id, gpu=self.cfg.gpu,
                                              dropout=self.cfg.dropout)
        if self.cfg.load_discriminator:
            # map_location lets a checkpoint written on GPU be restored on a CPU-only run.
            state_dict = torch.load(self.cfg.load_discriminator_file, map_location=self.device)
            self.discriminator.load_state_dict(state_dict)
        self.discriminator = self.discriminator.to(self.device)
        self.ppo_cfg = PPOConfig(batch_size=self.cfg.batch_size,
                                 mini_batch_size=self.cfg.batch_size,
                                 optimize_device_cache=False)
        # PPOTrainer optimises exactly this model object. self.generator must keep
        # referring to it (never be replaced by a reloaded copy) so that saved
        # checkpoints contain the PPO updates.
        self.ppo_trainer = PPOTrainer(config=self.ppo_cfg, model=self.generator, tokenizer=self.tokenizer)
        self.generation_kwargs = {
            "min_length": -1,
            # top_k=0 and top_p=1.0 disable both filters: sample from the full distribution.
            "top_k": 0.0,
            "top_p": 1.0,
            "do_sample": True,
            "pad_token_id": self.tokenizer.eos_token_id,
            "max_new_tokens": self.cfg.max_new_tokens,
            "bos_token_id": 0,
            'num_return_sequences': 1
        }
        self.adv_loss = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr=self.cfg.disc_lr)
        print(" >> read data")
        self.dataset_train = load_gen_data(os.path.join(self.cfg.data_dir, "mbpp_train.json"), self.tokenizer,
                                           train=True, code_parrot=self.code_parrot, mbpp=True)
        self.dataset_test = load_gen_data(os.path.join(self.cfg.data_dir, "mbpp_test.json"), self.tokenizer,
                                          train=False, code_parrot=self.code_parrot, mbpp=True)

    def adversarial_train(self) -> None:
        """
        Discriminator pre-training followed by ``cfg.num_adv_epochs`` rounds of
        (PPO generator epoch, discriminator epoch).

        If ``cfg.save_RL`` is set, the generator, the discriminator and the
        per-epoch metrics (``output.json``) are written to
        ``<gen_dir>/finished_model``. If ``cfg.delete_temp_files`` is set,
        ./temp_files is removed at the end.
        """
        print(" >> prepare Generator")
        # Deliberately no reload from disk here: a reloaded model would be a second
        # object that PPO never updates, yet it would be the one saved and sampled from.
        self.tokenizer.padding_side = "left"
        print(" =========== Discriminator Pretraining ===========")
        disc_loss, disc_acc = self._disc_adv_train(-1, [], [])
        print(" > ", {"loss": disc_loss[0], "acc": disc_acc[0]})
        print(" =========== Start Adversarial Training ===========")
        avg_rewards = []
        for epoch in range(self.cfg.num_adv_epochs):
            print(f" --- Epoch {epoch}: Generator ---")
            avg_rewards = self._gen_adv_train(epoch, avg_rewards)
            print("Average reward: ", avg_rewards[epoch])
            print(f" --- Epoch {epoch}: Discriminator---")
            # Order must match _disc_adv_train(epoch, loss_list, acc_list); swapping the
            # lists interleaves loss and accuracy values in the saved history.
            disc_loss, disc_acc = self._disc_adv_train(epoch, disc_loss, disc_acc)
            print(" > ", {"loss": disc_loss[epoch + 1], "acc": disc_acc[epoch + 1]})
            print(" -----------------------------------")
        print("Average Rewards for the epochs: ", avg_rewards)

        if self.cfg.save_RL:
            print(" >> save RL model")
            model_save_now_path = os.path.join(self.cfg.gen_dir, "finished_model")
            os.makedirs(model_save_now_path, exist_ok=True)
            self.generator.save_pretrained(os.path.join(model_save_now_path, "generator"))
            torch.save(self.discriminator.state_dict(), os.path.join(model_save_now_path, "discriminator.pt"))
            data = {"disc_loss": disc_loss, "disc_acc": disc_acc, "gen_rewards": avg_rewards}
            # Stored inside finished_model so the metrics travel with the saved models.
            with open(os.path.join(model_save_now_path, "output.json"), "w") as file:
                json.dump(data, file, indent=4)

        if self.cfg.delete_temp_files:
            print(" >> delete temporary files")
            if os.path.exists('./temp_files'):
                shutil.rmtree('./temp_files')

    @staticmethod
    def _write_first_sample(current_epoch: int, prompts_txt: list, generated_txts: list) -> list:
        """
        Write the first generation that survives ``unicode_escape`` decoding to a
        .py file in ./temp_files, preceded by its prompt as a comment line.

        :param current_epoch: used in the file-name prefix ("Epoch <n>_").
        :param prompts_txt: decoded prompts, aligned with ``generated_txts``.
        :param generated_txts: decoded generations.
        :return: list with the written file's path, or an empty list if no sample decoded.
        """
        os.makedirs("./temp_files", exist_ok=True)
        temp_file_path = []
        for prompt_txt, sample in zip(prompts_txt, generated_txts):
            try:
                header_prompt = "\n# " + prompt_txt + "\n\n"
                code = codecs.unicode_escape_decode(sample)[0]
                with tempfile.NamedTemporaryFile(delete=False, suffix=".py", dir="./temp_files",
                                                 prefix="Epoch " + str(current_epoch) + "_") as temp_file:
                    temp_file.write(header_prompt.encode('utf-8'))
                    temp_file.write(code.encode('utf-8'))
                    temp_file_path.append(temp_file.name)
                break
            except UnicodeError:
                # Sample contains escape sequences that cannot be decoded/encoded; try the next one.
                continue
        return temp_file_path

    def _gen_adv_train(self, current_epoch: int, avg_rewards: list) -> list:
        """
        One PPO epoch over the MBPP training prompts.

        Per batch: sample completions, write one decodable completion to ./temp_files,
        score the batch with ``get_rewards`` and take a PPO step. Afterwards the
        generator is saved to ``<gen_dir>/generator``, which ``_disc_adv_train`` loads.

        :param current_epoch: epoch index (used in temp-file names).
        :param avg_rewards: per-epoch average rewards so far; this epoch's mean of the
            per-batch averages is appended.
        :return: the same ``avg_rewards`` list.
        :raises statistics.StatisticsError: if the loader yields no batches.
        """
        data = DataLoader(dataset=self.dataset_train, batch_size=self.cfg.batch_size, shuffle=True, drop_last=True)
        padding_length = get_index_length_datasets(train=True, code_parrot=self.code_parrot, mbpp=True)
        avg_rewards_during_batches = []
        for batch in data:
            prompts_tensor = [prompt.to(torch.int64).to(self.device)
                              for prompt in torch.unbind(batch['prompts'], dim=0)]
            prompts_txt = self.tokenizer.batch_decode(prompts_tensor, skip_special_tokens=True)

            generation_tensor = self.ppo_trainer.generate(query_tensor=prompts_tensor, return_prompt=False,
                                                          **self.generation_kwargs)
            generated_txts = self.tokenizer.batch_decode(generation_tensor, skip_special_tokens=True)

            test_list_tensor = [tests.to(torch.int64).to(self.device)
                                for tests in torch.unbind(batch['test_list'], dim=0)]
            test_list_txt = self.tokenizer.batch_decode(test_list_tensor, skip_special_tokens=True)

            temp_file_path = self._write_first_sample(current_epoch, prompts_txt, generated_txts)
            rewards, avg_rewards_during_batches = get_rewards(generation_tensor=generation_tensor,
                                                              generation_text=generated_txts,
                                                              discriminator=self.discriminator,
                                                              avg_rewards_during_batches=avg_rewards_during_batches,
                                                              tokenizer=self.tokenizer, device=self.device,
                                                              padding_length=padding_length,
                                                              disc_weight=self.cfg.disc_weight,
                                                              temp_file_path=temp_file_path,
                                                              test_list=test_list_txt)
            rewards = [reward.detach().to(self.device) for reward in rewards]
            self.ppo_trainer.step(prompts_tensor, generation_tensor, rewards)
            # Drop per-batch tensors eagerly to keep peak memory down.
            del prompts_tensor, prompts_txt, generation_tensor, generated_txts
            del test_list_tensor, test_list_txt, rewards
            gc.collect()
        # This checkpoint is what the next discriminator epoch samples from.
        self.generator.save_pretrained(os.path.join(self.cfg.gen_dir, "generator"))
        avg_rewards_during_epoch = mean(avg_rewards_during_batches)  # mean(avg_reward per batch) = avg_reward(epoch)
        avg_rewards.append(avg_rewards_during_epoch)
        return avg_rewards

    def _disc_adv_train(self, current_epoch: int, disc_loss_list: list, disc_acc_list: list) -> tuple[list, list]:
        """
        One discriminator epoch on data sampled from ``<gen_dir>/generator``.

        If that checkpoint does not exist yet (discriminator pre-training), the
        current generator is saved there first.

        :param current_epoch: epoch index (-1 for pre-training); not used in the body.
        :param disc_loss_list: loss history; this epoch's mean batch loss is appended.
        :param disc_acc_list: accuracy history; this epoch's accuracy is appended.
        :return: ``(disc_loss_list, disc_acc_list)`` — the same list objects.
        :raises ValueError: if ``prepare_data`` yields fewer samples than one batch.
        """
        generator_dir = os.path.join(self.cfg.gen_dir, "generator")
        if not os.path.exists(os.path.join(generator_dir, "config.json")):
            # Snapshot the current generator so pre-training samples from it.
            self.generator.save_pretrained(generator_dir)

        lm_generator = AutoModelForCausalLMWithValueHead.from_pretrained(
            pretrained_model_name_or_path=generator_dir).to(self.device)
        dataset_disc = prepare_data(lm_generator, self.dataset_train,
                                    self.cfg.num_samples, self.generation_kwargs,
                                    self.tokenizer, train=True, code_parrot=self.code_parrot)
        # The sampling copy of the generator is not needed during discriminator training.
        del lm_generator
        gc.collect()

        data = DataLoader(dataset=dataset_disc, batch_size=self.cfg.batch_size, shuffle=True, drop_last=True)
        if len(data) == 0:
            raise ValueError(f"Discriminator dataset has {len(dataset_disc)} samples, fewer than "
                             f"batch_size={self.cfg.batch_size}; nothing to train on.")
        total_loss = 0
        total_acc = 0
        total_num = 0
        for batch in data:
            samples = Tensor.int(batch["code"]).to(self.device)
            # ground_truth appears to arrive as token ids of the label text; decode and
            # convert back to integer class indices (verify against DiscriminatorDataset).
            classes = list(torch.unbind(Tensor.int(batch["ground_truth"]), dim=0))
            classes_decoded_str = self.tokenizer.batch_decode(classes, skip_special_tokens=True)
            classes_decoded_int = torch.tensor([int(item) for item in classes_decoded_str], device=self.device,
                                               dtype=torch.int64)

            classification = self.discriminator.forward(samples)

            loss = self.adv_loss(classification, classes_decoded_int)
            self._optimize_discriminator(loss)

            total_loss += loss.item()
            total_acc += (classification.argmax(dim=-1) == classes_decoded_int).sum().item()
            total_num += len(classes_decoded_int)
            del samples
            del classes_decoded_int
            gc.collect()
        total_loss /= len(data)
        total_acc /= total_num
        torch.save(self.discriminator.state_dict(), os.path.join(self.cfg.disc_dir, "discriminator.pt"))
        disc_acc_list.append(total_acc)
        disc_loss_list.append(total_loss)
        return disc_loss_list, disc_acc_list

    def _optimize_discriminator(self, loss: torch.Tensor) -> None:
        """Backpropagate ``loss`` (a scalar tensor), clip gradients to cfg.clip_norm and step."""
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.cfg.clip_norm)
        self.optimizer.step()

    def _generate_eval_samples(self, dataset, destination_dir_path: str) -> list:
        """
        Generate one sample per item of ``dataset`` into ``destination_dir_path``.

        :return: the dataset's test lists, decoded and parsed with ``ast.literal_eval``,
            in dataset order.
        """
        total = len(dataset)
        encoded_tests = []
        for idx in range(total):
            mbpp_sample_generieren(idx=idx, device=self.device, dataset_test=dataset,
                                   generator=self.ppo_trainer, tokenizer=self.tokenizer,
                                   generation_kwargs=self.generation_kwargs,
                                   destination_dir_path=destination_dir_path)
            loading_bar(idx, total)
            encoded_tests.append(dataset[idx]['test_list'].to(torch.int64))
        decoded_tests = self.tokenizer.batch_decode(encoded_tests, skip_special_tokens=True)
        return [ast.literal_eval(item) for item in decoded_tests]

    def eval(self) -> None:
        """
        Evaluate the generator on the MBPP test split and on HumanEval.

        Samples are written to ./mbpp_samples and ./human_eval_samples. PEP 8
        statistics (``eval_pep8``) and pass@1 (``calculate_pass_at_1_rate``) for
        both datasets are written to results.csv.
        """
        print(">> begin evaluation")
        print("> start generating mbpp samples")
        mbpp_testcases = self._generate_eval_samples(self.dataset_test, "mbpp_samples")
        print("\nAll mbpp test sample generated.")

        humaneval_data = load_gen_data(os.path.join(self.cfg.data_dir, "humaneval.json"), self.tokenizer, train=False,
                                       code_parrot=self.code_parrot, mbpp=False)
        print("> start generating humaneval samples")
        human_eval_testcases = self._generate_eval_samples(humaneval_data, "human_eval_samples")
        print("\nAll human_eval test sample generated.")

        pep8_mbpp, most_common_error_mbpp, compliance_rate_mbpp = eval_pep8("./mbpp_samples")
        pep8_human_eval, most_common_error_humaneval, compliance_rate_humaneval = eval_pep8("./human_eval_samples")
        pass_at_1_mbpp = calculate_pass_at_1_rate(mbpp_testcases, "./mbpp_samples", False)
        pass_at_1_human_eval = calculate_pass_at_1_rate(human_eval_testcases, "./human_eval_samples", False)
        results = {
            'test_dataset': ["mbpp", "human_eval"],
            'pep8_average_error': [pep8_mbpp, pep8_human_eval],
            'most_common_error': [most_common_error_mbpp, most_common_error_humaneval],
            'compliance_rate': [compliance_rate_mbpp, compliance_rate_humaneval],
            'pass_at_1': [pass_at_1_mbpp, pass_at_1_human_eval]
        }
        df = pd.DataFrame(results)
        df.to_csv('results.csv', index=False)


In [None]:
import os.path
import time
import yaml
from argparse import Namespace


def load_config(config_path: str = "config.yaml") -> Namespace:
    """
    Load a YAML configuration file and expose its top-level keys as attributes.

    :param config_path: path to the YAML file (default: ``config.yaml`` in the CWD).
    :return: ``argparse.Namespace`` with one attribute per top-level key.
    :raises ValueError: if the file is empty or its top level is not a mapping.
    """
    # Explicit encoding so parsing does not depend on the platform's locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)

    # safe_load returns None for an empty file and a list/scalar for other
    # top-level shapes; reject those with a clear message instead of failing inside Namespace(**...).
    if not isinstance(config_dict, dict):
        raise ValueError(f"Config file {config_path!r} must contain a mapping at the top level, "
                         f"got {type(config_dict).__name__}.")

    return Namespace(**config_dict)


if __name__ == "__main__":

    config = load_config()

    # Both checkpoint directories must exist before the trainer writes into them.
    for required_dir in (config.gen_dir, config.disc_dir):
        if not os.path.exists(required_dir):
            print(" > makedirs", required_dir)
            os.makedirs(required_dir, exist_ok=True)

    trainer = GANTrainer(config)
    start_time = time.time()

    # Each phase is toggled independently from the config file.
    if config.adversarial:
        trainer.adversarial_train()
    if config.eval:
        trainer.eval()

    end_time = time.time()
    elapsed_time = (end_time - start_time) / 3600  # seconds -> hours
    print(f"Computing time of training: {elapsed_time:.2f} hours")

In [None]:
# Colab only: archive the generated evaluation samples and download them
# together with results.csv (paths assume the Colab working dir /content).
!zip -r ./human_eval_samples.zip ./human_eval_samples/
!zip -r ./mbpp_samples.zip ./mbpp_samples/
# Optional: also archive the per-batch training samples written to ./temp_files.
#!zip -r ./temp_files.zip ./temp_files/

from google.colab import files
files.download("/content/human_eval_samples.zip")
files.download("/content/mbpp_samples.zip")
#files.download("/content/temp_files.zip")
files.download("/content/results.csv")

In [None]:
!zip -r ./generator.zip /content/save/huggan/gen/finished_model
files.download("/content/generator.zip")