In [42]:
### Imports
import pandas as pd
import numpy as np
import re2 as re

import string

from typing import List

from libsvm.svmutil import *

In [43]:
PART_2_LANGS = [
  "TGLANG_LANGUAGE_C",
  "TGLANG_LANGUAGE_CPLUSPLUS",
  "TGLANG_LANGUAGE_CSHARP",
  "TGLANG_LANGUAGE_CSS",
  "TGLANG_LANGUAGE_DART",
  "TGLANG_LANGUAGE_DOCKER",
  "TGLANG_LANGUAGE_FUNC",
  "TGLANG_LANGUAGE_GO",
  "TGLANG_LANGUAGE_HTML",
  "TGLANG_LANGUAGE_JAVA",
  "TGLANG_LANGUAGE_JAVASCRIPT",
  "TGLANG_LANGUAGE_JSON",
  "TGLANG_LANGUAGE_KOTLIN",
  "TGLANG_LANGUAGE_LUA",
  "TGLANG_LANGUAGE_NGINX",
  "TGLANG_LANGUAGE_OBJECTIVE_C",
  "TGLANG_LANGUAGE_PHP",
  "TGLANG_LANGUAGE_POWERSHELL",
  "TGLANG_LANGUAGE_PYTHON",
  "TGLANG_LANGUAGE_RUBY",
  "TGLANG_LANGUAGE_RUST",
  "TGLANG_LANGUAGE_SHELL",
  "TGLANG_LANGUAGE_SOLIDITY",
  "TGLANG_LANGUAGE_SQL",
  "TGLANG_LANGUAGE_SWIFT",
  "TGLANG_LANGUAGE_TL",
  "TGLANG_LANGUAGE_TYPESCRIPT",
  "TGLANG_LANGUAGE_XML"
]

part_2_mappings = {value: idx for idx, value in enumerate(PART_2_LANGS)}
reverse_part_2_mappings = {idx: value for idx, value in enumerate(PART_2_LANGS)}

In [44]:
### Helper functions and regexes
SPECIAL_SYMBOLS_REGEX = r"([.,;:\\\/{}\[\]\|!\"#\$%&\'\(\)\*\+\-\<\=\>\?@\^\`\)])"
SPECIAL_SYMBOLS_REGEX_2 = r"(\b\w+\b|[.,;:\\\/{}\[\]\|!\"#\$%&\'\(\)\*\+\-\<\=\>\?@\^\`\~)])"

INT_REGEX = "-?\d+"
FLOAT_REGEX = "-?\d*[.,]\d+"
HEX_REGEX = "0[xX]([0-9a-fA-F])+"
OCTAL_REGEX = "0[oO]([0-7])+"
BINARY_REGEX = "0[bB]([01])+"
EXP_REGEX = "-?\d+[eE]-?\d+"

config = [
    {'regex': BINARY_REGEX, 'change_to': '<num_binary>'},
    {'regex': OCTAL_REGEX, 'change_to': '<num_octal>'},
    {'regex': HEX_REGEX, 'change_to': '<num_hex>'},
    {'regex': EXP_REGEX, 'change_to': '<num_exp>'},
    {'regex': FLOAT_REGEX, 'change_to': '<num_float>'},
    {'regex': INT_REGEX, 'change_to': '<num_int>'},
]



def add_spaces(text: str) -> str:
    return re.sub(SPECIAL_SYMBOLS_REGEX, r' \1 ', text)


def tokenize(text: str) -> List[str]:
    results = re.findall(SPECIAL_SYMBOLS_REGEX_2, text)
    return results

leave_only_ascii = lambda text: "".join([symbol for symbol in text if symbol in string.printable]).strip()


def change_nums_to_tokens(config: List, text: str) -> str:

    for config_record in config:
        text = re.sub(config_record['regex'], config_record['change_to'], text)

    return text


def preprocess_text_to_ascii(text: str) -> List[str]:

    # assert all([column in df.columns for column in ["text"]])
    text =  leave_only_ascii(text)
    text = add_spaces(text)
    text = change_nums_to_tokens(config, text)
    
    string_arr = tokenize(text)
    
    return string_arr

from typing import Dict, List, Tuple
def generate_vector(words: List[str], tfidf_mapping: Dict[str, Tuple[float, int]]) -> List[float]:

    n_len_vector = len(tfidf_mapping)
    vector = np.zeros(n_len_vector) # = [0 for _ in range(n_len_vector)]
    for word in words:
        tuple_to_unpack = tfidf_mapping.get(word)
        if tuple_to_unpack:
            tf_idf_val, idx = tuple_to_unpack
            if vector[idx] == 0:
                vector[idx] = tf_idf_val
            else:
                vector[idx] += tf_idf_val

    return vector

def make_tfidf_mapping(tfidf_trained):
    return {
        token: (value, idx) for idx, (token, value) in enumerate(zip(tfidf_trained.get_feature_names_out(), tfidf_trained.idf_))
    }


In [45]:
preprocess_text_to_ascii("<num_int>")

['<', 'num_int', '>']

In [47]:
### load tfidf and svm
import json

TFIDF_PATH = "tfifd_mapping_ascii_p2.json"
MODEL_PATH = "model_libsvm_ascii_2.model"


# reverse_class_mapping = {value: idx for idx, value in class_mapping.items()}


with open(TFIDF_PATH) as f:
    mapping = json.load(f)


model = svm_load_model(MODEL_PATH)

In [48]:
# inference
inference_string = """
public static void main(String[] args) {
    System.out.println("Hello, World!");
}
"""

prepared_tokens = preprocess_text_to_ascii(inference_string)
vector = [generate_vector(prepared_tokens, mapping)]
label, _, _ = svm_predict(y=[1], x=vector, m=model)
label_pred = int(label[0])
label_lang = reverse_class_mapping.get(label_pred)

print(label_lang)

Accuracy = 0% (0/1) (classification)
TGLANG_LANGUAGE_ADA
