In [2]:
import torch
from tree_sitter import Language, Parser
from tree_sitter_languages import get_language, get_parser
import os
import random
from random import choices, randint
'''
This is the part from LanguageParser
-------------------------------------------------------------------------------------------------------------------------------------------------
'''


def getParser(language):
    """Create a tree-sitter parser for ``language``.

    The parser is obtained from ``get_parser`` and then explicitly bound to
    the grammar object that ``get_language`` returns for the same name.

    Args:
        language: Language name understood by ``tree_sitter_languages``.

    Returns:
        The configured parser object.
    """
    ts_parser = get_parser(language)
    grammar = get_language(language)
    ts_parser.set_language(grammar)
    return ts_parser

def getLanguage(lang):
    """Return the tree-sitter grammar object that ``get_language`` provides for ``lang``."""
    grammar = get_language(lang)
    return grammar

'''
This is the part from TreeQuery, only for tree-sitter
'''
def getQuery(name, lang):
    """Build the query object for the query type ``name`` in language ``lang``.

    Only the tree-sitter backed query types are served here; any other name
    is rejected.

    Args:
        name: Query type, e.g. ``"identifiers"`` or ``"function_call"``.
        lang: Language name forwarded to ``TreeSitterQuery``.

    Returns:
        A ``TreeSitterQuery`` for ``name`` and ``lang``.

    Raises:
        ValueError: If ``name`` is not one of the tree-sitter query types.
    """
    tree_sitter_queries = {
        "identifiers",
        "string_literals",
        "boolean_literals",
        "numeric_literals",
        "function_call",
        "function_name",
    }
    # Regex-based query types are currently disabled:
    #regex = {"closing_bracket", "stop", "eol", "keywords", "mathematical_operators", "boolean_operators", "assignment_operators"}
    if name not in tree_sitter_queries:
        raise ValueError("Query type not known " + name)

    return TreeSitterQuery(name, lang)


def getQueryString(lang, name):
    """Return the query source string for query type ``name`` in ``lang``.

    Args:
        lang: Language name; only ``'java'`` is implemented.
        name: Query type. The special name ``'random'`` needs no query and
            yields an empty string for every language.

    Returns:
        ``""`` for ``'random'``, otherwise whatever ``getJavaQuery(name)``
        returns.

    Raises:
        ValueError: If ``lang`` is not ``'java'`` (and ``name`` is not
            ``'random'``), or if ``getJavaQuery`` rejects ``name``.
    """
    # 'random' is checked first so it works regardless of the language.
    if name == 'random':
        return ""
    if lang == 'java':
        return getJavaQuery(name)
    raise ValueError("Language not implemented")

def getJavaQuery(name):
    """Return the Java query source for the query type ``name``.

    The first six names return tree-sitter query strings (S-expressions with
    ``@capture`` names). The remaining names return regular-expression
    pattern strings.

    Args:
        name: Query type name.

    Returns:
        The query or pattern string for ``name``.

    Raises:
        ValueError: If ``name`` is not a known query type.
    """
    if name == 'identifiers':
        return """
                (identifier) @id
               """
    elif name == 'string_literals':
        return """
                (string_literal) @String_literal
                (character_literal) @String_literal
               """
    elif name =="boolean_literals":
        return """
               (true) @boolean
               (false) @boolean
               """
    elif name == "numeric_literals":
        return """
               (decimal_integer_literal) @number
               (decimal_floating_point_literal) @number
               (hex_integer_literal) @number
               (binary_integer_literal) @number

               """
    elif name == "function_call":
        return """
                (method_invocation
                    name: (identifier) @func_call
                )
               """
    elif name == "function_name":
        return """
                   (method_declaration
                       name: (identifier) @func_name
                   )

               """
    # Regex patterns below are raw strings: sequences such as "\)" or "\+" are
    # invalid escapes in a normal string literal and trigger a warning.
    # NOTE(review): `re` tries alternatives left to right, so with search-style
    # matching a short operator listed first (e.g. `>` or `!`) wins over a
    # longer one that starts the same way (`>=`, `!=`) - confirm the intended
    # matching mode before relying on these patterns.
    elif name == "closing_bracket":
        return r"}|\)|]"
    elif name == "eol":
        # Deliberately not raw: "\n" must be a real newline character.
        return ";\n"
    elif name == "keywords":
        return "abstract|assert|break|case|catch|continue|default|do|else|enum|exports|extends|final|finally|for|if|implements|import|instanceof|interface|module|native|new|package|private|protected|public|requires|return|static|super|switch|synchronized|this|throws|throw|transient|try|void|volatile|while"
    elif name == "mathematical_operators":
        return r"\+|-|\*|/|>|<|>=|<=|%|\+\+|--"
    elif name == "boolean_operators":
        # A separator was missing between `\|\|` and `==`, which made the
        # pattern match the literal text "||==" instead of "||" or "==".
        return r"!|&&|\|\||==|!="
    elif name =="assignment_operators":
        return r"\+=|-=|\*=|/=|%=|&=|\|=|\^=|>>=|<<="
    elif name == "stop":
        return r"\."
    else:
        raise ValueError("Query not implemented: " + str(name))

class Query():
    """Base class for a named query over source code in one language.

    Subclasses implement ``get_span`` to pick a span out of a piece of
    content.
    """

    def __init__(self, query_name, query_string, lang):
        """Store the query's name, its source string and the language."""
        self.query_name = query_name
        self.query_string = query_string
        self.lang = lang
        #self.tokenizer = tokenizer

    def get_span(self, content):
        """Select a span from ``content``; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # NotImplementedError is an Exception subclass, so callers that
        # caught the previous generic Exception keep working.
        raise NotImplementedError("Not implemented")

    # def tokenize(self, content, label):
    #     input = self.tokenizer(content, return_tensors = 'pt')
    #     label = self.tokenizer(label, return_tensors = 'pt')
    #
    #     return {"input": input, "label": label}
    def untokenize(self, content, label):
        """Package ``content`` and ``label`` unchanged as an input/label dict."""
        return {"input": content, "label": label}

class TreeSitterQuery(Query):
    """Query that selects a random tree-sitter capture from source code."""

    def __init__(self, query_name, lang):
        """Compile the tree-sitter query for ``query_name`` and build a parser.

        Args:
            query_name: Query type passed to ``getQueryString``.
            lang: Language name used for the grammar, query and parser.
        """
        language = getLanguage(lang)
        query_string = getQueryString(lang, query_name)
        # `super(TreeSitterQuery).__init__()` only initialised an unbound
        # super object; call the base initialiser properly so query_name,
        # query_string and lang are all set on the instance.
        super().__init__(query_name, query_string, lang)
        self.query = language.query(query_string)
        self.parser = getParser(self.lang)

    def get_span(self, content):
        """Mask one randomly chosen capture in ``content``.

        The content is parsed as UTF-8 bytes, the query is run against the
        tree root, and one capture is drawn at random. The captured byte
        range is cut out and replaced by the marker ``<fim_suffix>``.

        Args:
            content: Source code as a string.

        Returns:
            A dict with ``"input"`` (the content with the span replaced by
            the marker) and ``"label"`` (the text that was cut out).

        Raises:
            ValueError: If the query yields no captures for ``content``.
        """
        source = bytes(content, "UTF-8")
        tree = self.parser.parse(source)
        captures = self.query.captures(tree.root_node)

        try:
            # random.sample raises ValueError on an empty capture list.
            capture = random.sample(captures, 1)[0]
        except ValueError:
            raise ValueError("No matches detected in sample")

        # The first element of a capture is the node carrying byte offsets.
        node = capture[0]
        start = node.start_byte
        finish = node.end_byte

        # Slice on bytes, not characters, because the offsets are byte offsets.
        target = source[start:finish]
        context = source[:start] + b"<fim_suffix>" + source[finish:]

        return self.untokenize(context.decode("UTF-8"), target.decode("UTF-8"))


'''
This is the part for IterableQueryLoader
-------------------------------------------------------------------------------------------------------------------------------------------------
'''
class IterableQueryLoader(torch.utils.data.IterableDataset):
    """Iterable dataset yielding ``(sample, query_name)`` pairs.

    With ``query_name == 'noise'`` the samples are random token tensors.
    Otherwise the wrapped dataset is cycled until ``max_samples`` samples
    have been produced.
    """

    def __init__(self, hf_dataset, query_name, max_samples, max_length, lang, model):
        """Store the configuration and build the query unless in noise mode.

        Args:
            hf_dataset: Iterable of records; each record is indexed with
                ``['content']`` when processed.
            query_name: Query type, or ``'noise'`` for random samples.
            max_samples: Number of samples one iteration yields.
            max_length: Length of the tensors generated in noise mode.
            lang: Language name passed to ``getQuery``.
            model: Model name; ``process`` requires it to contain ``"gpt"``
                (case-insensitive) outside noise mode.
        """
        # The previous `super(IterableQueryLoader).__init__()` never reached
        # the parent initialiser.
        super().__init__()
        self.hf_dataset = hf_dataset
        self.model = model
        self.lang = lang
        self.max_length = max_length
        self.query_name = query_name
        self.max_samples = max_samples

        if query_name != 'noise':
            self.query = getQuery(query_name, self.lang)

    def __iter__(self):
        """Yield ``max_samples`` pairs of ``(processed sample, query_name)``.

        Records whose processing raises ``ValueError`` are skipped, and the
        dataset is restarted whenever it is exhausted.

        Raises:
            ValueError: If a complete pass over the dataset produces no
                sample at all (empty dataset, or every record rejected).
                Previously this case looped forever or surfaced as a
                ``RuntimeError`` from a leaked ``StopIteration``.
        """
        produced = 0
        if self.query_name == 'noise':
            while produced < self.max_samples:
                returnable = self.process(None)
                produced += 1
                yield returnable, self.query_name
            return

        while produced < self.max_samples:
            produced_this_pass = 0
            for file in self.hf_dataset:
                try:
                    returnable = self.process(file)
                except ValueError:
                    # Best effort: skip records that cannot be processed.
                    continue
                produced += 1
                produced_this_pass += 1
                yield returnable, self.query_name
                if produced >= self.max_samples:
                    return
            if produced_this_pass == 0:
                # Another pass would fail in exactly the same way.
                raise ValueError(
                    "No sample in the dataset could be processed "
                    "(model=%r, query_name=%r)" % (self.model, self.query_name)
                )

    def __len__(self):
        """Return the length of the wrapped dataset.

        NOTE(review): iteration yields ``max_samples`` items, which may
        differ from this value - confirm which one consumers expect.
        """
        return len(self.hf_dataset)

    def process(self, sample):
        """Turn one dataset record into a model input.

        Raises:
            ValueError: If not in noise mode and the model name does not
                contain ``"gpt"``.
        """
        if self.query_name == 'noise':
            return self.gen_noise()
        elif "gpt" in self.model.lower():
            return self.gen_subsample_gpt(sample['content'])
        else:
            raise ValueError("Unsupported model: " + str(self.model))

    def gen_noise(self):
        """Return a sample of ``max_length`` random token ids in ``[0, 100)`` with an all-ones mask."""
        noise = torch.randint(0, 100, (self.max_length,))
        sample = {'input': {'input_ids': noise, 'attention_mask': torch.ones_like(noise)}}
        return sample

    def gen_subsample_gpt(self, content):
        """Wrap ``content`` unchanged in a ``{"content": ...}`` dict."""
        return {"content": content}


class IterableScenarioLoader(torch.utils.data.IterableDataset):
    """Iterable dataset yielding ``({"content": ...}, query_name)`` pairs.

    The wrapped dataset is cycled until ``max_samples`` samples have been
    produced.
    """

    def __init__(self, hf_dataset, query_name, max_samples, max_length, lang, model, min_length=16):
        """Store the configuration and optionally compile a tree-sitter query.

        Args:
            hf_dataset: Iterable of records; each record is indexed with
                ``['content']`` when processed.
            query_name: Query type; ``"random"`` needs no query.
            max_samples: Number of samples one iteration yields.
            max_length: Stored on the instance; not used in this class.
            lang: Language name for the grammar, query and parser.
            model: Model name; ``process`` requires it to contain
                ``"starcoder"`` or ``"gpt"`` (case-insensitive).
            min_length: Stored on the instance; not used in this class.
        """
        # The previous `super(IterableScenarioLoader).__init__()` never
        # reached the parent initialiser.
        super().__init__()
        self.hf_dataset = hf_dataset
        self.model = model
        self.lang = lang
        self.max_length = max_length
        self.min_length = min_length
        self.query_name = query_name
        self.max_samples = max_samples

        # NOTE(review): this check is case-sensitive, whereas process()
        # lower-cases the model name first - confirm which is intended.
        if query_name != "random" and 'starcoder' in self.model:
            self.query = getLanguage(self.lang).query(getQueryString(self.lang, query_name))
            self.parser = getParser(self.lang)

    def __iter__(self):
        """Yield ``max_samples`` pairs of ``(processed sample, query_name)``.

        Records whose processing raises ``ValueError`` are skipped, and the
        dataset is restarted whenever it is exhausted.

        Raises:
            ValueError: If a complete pass over the dataset produces no
                sample at all (empty dataset, or every record rejected).
                Previously this case looped forever or surfaced as a
                ``RuntimeError`` from a leaked ``StopIteration``.
        """
        produced = 0
        while produced < self.max_samples:
            produced_this_pass = 0
            for file in self.hf_dataset:
                try:
                    returnable = self.process(file)
                except ValueError:
                    # Best effort: skip records that cannot be processed.
                    continue
                produced += 1
                produced_this_pass += 1
                yield returnable, self.query_name
                if produced >= self.max_samples:
                    return
            if produced_this_pass == 0:
                # Another pass would fail in exactly the same way.
                raise ValueError(
                    "No sample in the dataset could be processed "
                    "(model=%r, query_name=%r)" % (self.model, self.query_name)
                )

    def __len__(self):
        """Return the length of the wrapped dataset.

        NOTE(review): iteration yields ``max_samples`` items, which may
        differ from this value - confirm which one consumers expect.
        """
        return len(self.hf_dataset)

    def process(self, sample):
        """Wrap the record's ``'content'`` field in a ``{"content": ...}`` dict.

        Raises:
            ValueError: If the model name contains neither ``"starcoder"``
                nor ``"gpt"``.
        """
        model_name = self.model.lower()
        # Both supported model families currently receive the same payload.
        if "starcoder" in model_name or "gpt" in model_name:
            return {"content": sample['content']}
        raise ValueError("Unsupported model: " + str(self.model))


  return "}|\)|]"
  return "\+|-|\*|/|>|<|>=|<=|%|\+\+|--"
  return "!|&&|\|\|==|!="
  return "\+=|-=|\*=|/=|%=|&=|\|=|\^=|>>=|<<="
  return "\."


In [3]:
#from Clustering.extracted_iterable_loader.iterable_query_loader import IterableScenarioLoader
import torch

from datasets import load_dataset

# Load rows 0-999 of the test split.
hf_dataset = load_dataset(
    'LaughingLogits/Stackless_Java_V2',
    'Stackless_Java_V2',
    split="test[0:1000]",
)

# "random" needs no tree-sitter query; ten samples are drawn from the dataset.
query_loader = IterableScenarioLoader(
    hf_dataset,
    "random",
    max_samples=10,
    max_length=256,
    lang="java",
    model="bigcode/starcoder2-3b",
)

# Each item is a (sample, query_name) pair; show only the sample.
for item in query_loader:
    print(item[0])

Resolving data files:   0%|          | 0/119 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/74 [00:00<?, ?it/s]

{'content': '/*\n * Copyright (c) 2005 - Ben Mazur (bmazur@sev.org)\n * Copyright (c) 2022 - The MegaMek Team. All Rights Reserved.\n *\n * This file is part of MegaMek.\n *\n * MegaMek is free software: you can redistribute it and/or modify\n * it under the terms of the GNU General Public License as published by\n * the Free Software Foundation, either version 3 of the License, or\n * (at your option) any later version.\n *\n * MegaMek is distributed in the hope that it will be useful,\n * but WITHOUT ANY WARRANTY; without even the implied warranty of\n * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n * GNU General Public License for more details.\n *\n * You should have received a copy of the GNU General Public License\n * along with MegaMek. If not, see <http://www.gnu.org/licenses/>.\n */\npackage megamek.common.weapons.lrms;\n\nimport megamek.common.SimpleTechLevel;\n\n/**\n * @author Sebastian Brocks\n */\npublic class CLStreakLRM7OS extends StreakLRMWeapon {\n\n 