In [41]:
!pip install Levenshtein




In [42]:
!pip install rouge




In [43]:
import os
import json
import math
import glob
import argparse
from pathlib import Path
from typing import List, Dict, Any
import pandas as pd
import numpy as np
import tree_sitter_java
from tree_sitter import Language, Parser, Query
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList, StoppingCriteria
import Levenshtein
from tqdm import tqdm
import nltk
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge import Rouge
from tqdm.auto import tqdm
tqdm.pandas()
import re
import random
from typing import List, Tuple
random.seed(42)
import gc
from collections import Counter

nltk.download('wordnet')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [45]:
from google.colab import drive


In [46]:
# Mount Google Drive (Colab) so the dataset and the result CSVs under /content/drive are accessible.
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [47]:
# Directory of the Hugging Face dataset that is read with `load_from_disk` in the next cells.
dataset_path = "/content/drive/MyDrive/Personal_RP/heap_Java_sampled_non_english_FIM"


In [48]:
from datasets import load_from_disk, Dataset


In [49]:
# Rows are later read through their "content" (source code) and "language_detected" fields.
ds = load_from_disk(dataset_path)


In [50]:
# tree-sitter Java grammar + parser used by parse_code(). The query captures only
# `/* ... */` block comments (Javadoc `/** ... */` included) as "comment.block".
JAVA_LANGUAGE = Language(tree_sitter_java.language())
parser = Parser(JAVA_LANGUAGE)
COMMENT_JAVA_QUERY = Query(
    JAVA_LANGUAGE,
    """
    (block_comment)       @comment.block
    """,
)

In [51]:
import random

In [52]:
def parse_code(input_code):
    """
    Find every block comment in a Java source string.

    Parses ``input_code`` with the module-level tree-sitter ``parser`` and runs
    ``COMMENT_JAVA_QUERY`` on the resulting tree. tree-sitter reports UTF-8
    *byte* offsets; they are converted to *character* offsets so they can index
    ``input_code`` (and any per-character annotations aligned with it).

    Returns:
        list of ``(capture_name, c_start, c_end, snippet)`` tuples, where
        ``snippet == input_code[c_start:c_end]``, in the order the query
        returns its captures.
    """
    encoded_code = input_code.encode('utf-8')
    tree = parser.parse(encoded_code)
    captures: dict[str, list] = COMMENT_JAVA_QUERY.captures(tree.root_node)

    results = []
    for cap_name, nodes in captures.items():
        for node in nodes:
            start, end = node.start_byte, node.end_byte
            # Node boundaries are expected on character boundaries, so decoding
            # the byte prefix yields the exact character offset.
            c_start = len(encoded_code[:start].decode('utf-8', errors='replace'))
            # Decode only the node's own bytes rather than the whole prefix again.
            c_end = c_start + len(encoded_code[start:end].decode('utf-8', errors='replace'))
            results.append((cap_name, c_start, c_end, input_code[c_start:c_end]))

    return results

In [53]:
# NOTE(review): `transformers` is already imported earlier in this notebook, so an
# upgrade installed here only takes effect after a kernel restart; consider moving
# this install above the imports.
!pip install -qU "transformers>=4.41.0" accelerate bitsandbytes sentencepiece safetensors

In [54]:
class Config:
    """Model registry: HF repo id, context window and (when supported) FIM marker tokens per model key."""

    # Shared FIM marker strings; only used while building MODELS, then removed.
    _FIM = {
        'fim_prefix': '<fim_prefix>',
        'fim_suffix': '<fim_suffix>',
        'fim_middle': '<fim_middle>',
    }

    MODELS = {
        'mellum_base_4b': {'path': 'JetBrains/Mellum-4b-base', 'context_length': 2048, **_FIM},
        # No FIM markers: ModelWrapper leaves fim_tool as None for this model.
        'smol_lm_135m': {'path': 'HuggingFaceTB/SmolLM-135M', 'context_length': 2048},
        'starcoder2_3b': {'path': 'bigcode/starcoder2-3b', 'context_length': 2048, **_FIM},
    }

    del _FIM

In [55]:
def find_comment_body(
    c_start: int,
    c_end:   int,
    comment: str,
    keep_closing_marker: bool = True,
) -> Tuple[int,int]:
    """
    Locate the body of a block comment as absolute offsets into the full code.

    ``comment`` must equal ``code[c_start:c_end]``. The opening marker (``/**``
    for Javadoc, otherwise ``/*``) is always excluded from the returned span.
    The closing ``*/`` is kept in the span by default; pass
    ``keep_closing_marker=False`` to exclude it too.

    Args:
        c_start: absolute start offset of ``comment`` in the code.
        c_end: absolute end offset (exclusive) of ``comment`` in the code.
        comment: the comment text itself.
        keep_closing_marker: whether the returned span ends after ``*/``.

    Returns:
        ``(body_start, body_end)``, or ``(c_end, c_end)`` when ``comment`` is
        not a closed block comment or the resulting body would be empty.
    """
    # "/*/" starts with "/*" and ends with "*/" but is not closed: require at
    # least the four characters of "/**/".
    if len(comment) < 4 or not (comment.startswith("/*") and comment.endswith("*/")):
        return c_end, c_end

    # "/**/" is an empty ordinary comment, not a Javadoc opener followed by "/".
    lead = 3 if comment.startswith("/**") and len(comment) > 4 else 2
    trail = 0 if keep_closing_marker else 2

    body_start = c_start + lead
    body_end   = c_end   - trail

    if body_start >= body_end:
        return c_end, c_end

    return body_start, body_end

In [56]:
from typing import Dict, Tuple

def estimate_target_tokens_comment_body(
    ds: Dataset,
    tokenizer_name: str,
    parse_fn,
    num_std: float = 2.0,
) -> int:
    """
    Estimate a token budget for non-English block-comment bodies.

    Scans ``ds`` once. For every span returned by ``parse_fn(code)``:
    the per-character ``language_detected`` tags of the span (``"-1"`` tags
    dropped) must have a majority label other than ``"en"``, and
    ``find_comment_body`` must locate a body. The body is then tokenized with
    ``tokenizer_name`` and its length recorded.

    Args:
        ds: rows with a ``"content"`` string and optionally ``"language_detected"``
            (assumed to hold one tag per character of ``content`` — TODO confirm).
        tokenizer_name: tokenizer id/path for ``AutoTokenizer.from_pretrained``.
        parse_fn: callable returning ``(name, c_start, c_end, snippet)`` tuples.
        num_std: number of standard deviations added to the mean (default 2).

    Returns:
        ``ceil(mean + num_std * std)`` of the recorded token lengths.

    Raises:
        ValueError: if no qualifying comment body is found.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    lengths = []

    for sample in ds:
        code = sample["content"]
        annotations = sample.get("language_detected", [])
        for _, c_start, c_end, _ in parse_fn(code):
            # per-character language tags of this comment, minus "-1" (undetected)
            langs = [l for l in annotations[c_start:c_end] if l != "-1"]
            if not langs:
                continue

            # majority language of the span; English comments are skipped
            top_lang, _ = Counter(langs).most_common(1)[0]
            if top_lang == "en":
                continue

            comment_body_start, comment_body_end = find_comment_body(c_start, c_end, code[c_start:c_end])
            if comment_body_start == comment_body_end:
                # malformed comment like "/*oops" or missing closing "*/"
                continue

            snippet = code[comment_body_start:comment_body_end]
            ids = tok(snippet, return_attention_mask=False)["input_ids"]
            lengths.append(len(ids))

    if not lengths:
        # np.mean([]) is NaN and int(NaN) fails with an unhelpful message.
        raise ValueError("no non-English block comments found; cannot estimate a target size")

    mean, std = np.mean(lengths), np.std(lengths)
    return int(np.ceil(mean + num_std * std))


def get_context_size(
    *,
    context_length: int,
    target_tokens: int,
    supports_fim: bool
) -> Tuple[int, int]:
    """
    Split the part of a context window not reserved for the target.

    Args:
        context_length: model context window, in tokens.
        target_tokens: tokens reserved for the generated target.
        supports_fim: if true, the remainder is split between prefix and suffix
            (the suffix gets the extra token when the remainder is odd);
            otherwise all of it goes to the prefix.

    Returns:
        ``(prefix_tokens, suffix_tokens)``; ``suffix_tokens`` is 0 when
        ``supports_fim`` is false.

    Raises:
        ValueError: if ``target_tokens`` is negative or larger than
            ``context_length`` (which would yield negative sizes).
    """
    if not 0 <= target_tokens <= context_length:
        raise ValueError(
            f"target_tokens ({target_tokens}) must be between 0 and "
            f"context_length ({context_length})"
        )

    remaining = context_length - target_tokens
    if not supports_fim:
        return remaining, 0

    prefix_tokens = remaining // 2
    return prefix_tokens, remaining - prefix_tokens


In [57]:
# Token budget for a comment body: ceil(mean + 2*std) of the non-English
# comment-body lengths, measured with the Mellum tokenizer.
target_size = estimate_target_tokens_comment_body(ds,Config.MODELS['mellum_base_4b']['path'],parse_fn=parse_code)

In [58]:
# Split what is left of Mellum's context window after reserving `target_size`
# between the FIM prefix and suffix.
# NOTE(review): make_target_context slices the source code with these values,
# i.e. applies them as character counts rather than token counts — confirm.
prefix, suffix = get_context_size(
    context_length = Config.MODELS['mellum_base_4b']['context_length'],
    target_tokens  = target_size,
    supports_fim   = True,
)

In [59]:
class FIMInput:
    """Format a (prefix, suffix, middle) triple as a prefix-suffix-middle (PSM) FIM prompt."""

    def __init__(self, FIM_PREFIX = '<fim_prefix>', FIM_SUFFIX = '<fim_suffix>', FIM_MIDDLE = '<fim_middle>'):
        # Marker strings, stored under the same names as the constructor arguments.
        self.FIM_PREFIX, self.FIM_SUFFIX, self.FIM_MIDDLE = FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE

    def generate(self, query_tuple: Tuple[str, str, str]):
        """
        Build the prompt for ``query_tuple = (prefix, suffix, middle)``.

        Returns:
            ``(FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE, middle)`` —
            the prompt and the expected completion.
        """
        pieces = (
            self.FIM_PREFIX, query_tuple[0],
            self.FIM_SUFFIX, query_tuple[1],
            self.FIM_MIDDLE,
        )
        return "".join(pieces), query_tuple[2]

In [60]:
# Default PSM prompt builder (<fim_prefix>…<fim_suffix>…<fim_middle>), used by make_target_context.
fim_input = FIMInput()

In [61]:
def make_target_context(ds: Dataset, prefix_size: int, suffix_size: int, parse_fn):
    """
    Build FIM ``(prompt, target)`` pairs from the non-English block comments in ``ds``.

    A span from ``parse_fn(content)`` is used when the majority of its
    per-character ``language_detected`` tags (``"-1"`` ignored) is not ``"en"``
    and ``find_comment_body`` locates a body. The target is that body; the
    prompt is produced by the global ``fim_input`` from up to ``prefix_size``
    characters before the body and up to ``suffix_size`` characters after it.
    Pairs whose prefix, suffix or target would be empty are dropped.

    NOTE(review): ``prefix_size`` / ``suffix_size`` slice ``content`` directly,
    i.e. they act as *character* counts, while ``get_context_size`` describes
    them as token counts — confirm this is intended.
    """
    pairs = []
    for row in ds:
        source = row["content"]
        char_langs = row.get("language_detected", [])

        for _cap, start, end, _text in parse_fn(source):
            detected = [tag for tag in char_langs[start:end] if tag != "-1"]
            if not detected:
                continue
            # majority language of the comment; English ones are skipped
            if Counter(detected).most_common(1)[0][0] == "en":
                continue

            body_start, body_end = find_comment_body(start, end, source[start:end])
            if body_start == body_end:
                # malformed comment like "/*oops" or missing closing "*/"
                continue

            left = source[max(0, body_start - prefix_size):body_start]
            right = source[body_end:min(len(source), body_end + suffix_size)]
            middle = source[body_start:body_end]

            if left and right and middle:
                pairs.append(fim_input.generate((left, right, middle)))

    return pairs


In [62]:
# (FIM prompt, target comment body) pairs for every usable non-English block comment.
context_target_pair = make_target_context(ds, prefix, suffix, parse_fn=parse_code)

In [63]:
# Spot-check one (prompt, target) pair.
context_target_pair[6]

('<fim_prefix>package net.cocotea.elysiananime.common.enums;\n\nimport lombok.AllArgsConstructor;\nimport lombok.Getter;\n\n/**<fim_suffix>\n@Getter\n@AllArgsConstructor\npublic enum SexEnum {\n    /**\n     * 未知\n     */\n    UNKNOWN(0, "未知"),\n    /**\n     * 男\n     */\n    MAN(1, "男"),\n    /**\n     * 女\n     */\n    WOMAN(2, "女");\n\n    final Integer code;\n    final String desc;\n}\n<fim_middle>',
 '\n * 性别枚举值\n *\n * @author CoCoTea\n * @version 2.0.0\n */')

In [64]:
class ModelWrapper:
    """
    Holder for a Hugging Face causal LM, its tokenizer and its context window.

    Args:
        config: one entry of ``Config.MODELS``. Reads ``path`` (``path_local``
            when ``local`` is true), ``context_length`` and, only if all three are
            present, ``fim_prefix`` / ``fim_suffix`` / ``fim_middle``.
        device: device the model is moved to (e.g. ``"cuda"``).
        local: load from ``config['path_local']`` with ``local_files_only``.
    """

    def __init__(self, config, device, local=False):
        self.device = device
        source = str(config['path_local' if local else 'path'])
        self.tokenizer = AutoTokenizer.from_pretrained(source, local_files_only=local)
        self.model = AutoModelForCausalLM.from_pretrained(source, local_files_only=local).to(device)
        self.context_length = config['context_length']

        # A FIM prompt builder is created only when the config defines all three markers.
        fim_keys = ('fim_prefix', 'fim_suffix', 'fim_middle')
        if all(key in config for key in fim_keys):
            self.fim_tool = FIMInput(**{key.upper(): config[key] for key in fim_keys})
        else:
            self.fim_tool = None

    def generate(self, input_ids, max_new_tokens: int, attention_mask=None, stopping_criteria=None, pad_token_id=None, eos_token_id=None) -> torch.LongTensor:
        """Run ``model.generate`` without sampling (``do_sample=False``), forwarding every other argument."""
        generation_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            do_sample=False,
        )
        return self.model.generate(**generation_kwargs)


In [65]:
class StopOnLineGenerated(StoppingCriteria):
    """
    Stopping criterion for line completion: stop once the text generated after
    the prompt contains a finished line (non-whitespace text followed by a newline).

    Only the first sequence of the batch is inspected.
    """

    def __init__(self, tokenizer, offset):
        super().__init__()
        self.tokenizer = tokenizer
        self.offset = offset  # prompt length in tokens; generated ids start here

    @staticmethod
    def _has_complete_line(text: str) -> bool:
        """True once ``text`` contains a newline after its first non-whitespace character."""
        # NOTE(review): replaces a call to a `first_line` helper that is not defined
        # anywhere in this notebook (NameError on first use). This rule is an
        # assumption about the intended behaviour — confirm.
        return '\n' in text.lstrip()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        ids = input_ids[0][self.offset:]
        pred_text = self.tokenizer.decode(ids, skip_special_tokens = True)
        return self._has_complete_line(pred_text)

class StopOnCommentGenerated(StoppingCriteria):
    """Stop generation once the text produced after the prompt contains a closing ``*/``."""

    def __init__(self, tokenizer, offset):
        super().__init__()
        self.tokenizer = tokenizer
        self.offset = offset  # prompt length in tokens; only ids after it are decoded

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the first sequence of the batch is checked.
        generated = self.tokenizer.decode(input_ids[0][self.offset:], skip_special_tokens = True)
        return generated.find('*/') != -1


In [66]:
class PipelineTool:
    """
    Run line-completion and fill-in-the-middle (FIM) generation with a model
    wrapper and return ids / tokens / text for both the reference and the
    prediction.

    ``model`` must expose ``tokenizer``, ``device``, ``context_length`` and a
    ``generate(input_ids=..., attention_mask=..., max_new_tokens=...,
    pad_token_id=..., eos_token_id=..., stopping_criteria=...)`` method.
    """

    def __init__(self, model, fin_mask_token: str = None):
        # ``fin_mask_token`` is accepted for backward compatibility; it is unused.
        self.model = model

    def _ids_to_outputs(self, ids: torch.Tensor) -> Tuple[np.ndarray, List[str], str]:
        """Convert a 1-D id tensor to ``(ids as np.ndarray, tokens, text)``; special tokens are dropped from tokens and text only."""
        id_list = ids.detach().cpu().tolist()
        tokenizer = self.model.tokenizer
        tokens = tokenizer.convert_ids_to_tokens(id_list, skip_special_tokens = True)
        text = tokenizer.decode(id_list, skip_special_tokens = True)
        return np.array(id_list), tokens, text

    def _generate(self, input_ids, attention_mask, max_tokens: int, stop_criterion):
        """Call ``model.generate`` with the tokenizer's EOS as both the pad and the stop id."""
        tokenizer = self.model.tokenizer
        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            pad_token_id=tokenizer.eos_token_id,
            # Must be the EOS id (not the pad id), otherwise generation ignores EOS.
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=StoppingCriteriaList([stop_criterion]),
        )

    def task_line_completion(self, context: str, line: str, max_tokens: int = 160
                             ) -> Tuple[np.ndarray, np.ndarray, List[str], List[str], str, str]:
        """
        Generate the line that follows ``context`` and pair it with the reference ``line``.

        The prompt is left-truncated to the last ``model.context_length`` tokens.
        Generation stops after ``max_tokens`` new tokens, on EOS, or when
        ``StopOnLineGenerated`` fires.

        Returns:
            ``(truth_ids, pred_ids, truth_toks, pred_toks, truth_text, pred_text)``.
        """
        tokenizer = self.model.tokenizer
        window = self.model.context_length

        encoded = tokenizer(context, return_tensors='pt').to(self.model.device)
        # keep only the most recent `window` tokens (batch of one)
        context_ids = encoded['input_ids'][:, -window:]
        context_mask = encoded['attention_mask'][:, -window:]
        prompt_len = context_ids.shape[-1]

        gen = self._generate(context_ids, context_mask, max_tokens,
                             StopOnLineGenerated(tokenizer, prompt_len))

        # The reference is a continuation, so it is tokenized without special
        # tokens (e.g. BOS), matching the generated ids it is compared with.
        truth = tokenizer(line, add_special_tokens=False, return_tensors='pt')['input_ids'][0]
        truth_ids, truth_toks, truth_text = self._ids_to_outputs(truth)
        pred_ids, pred_toks, pred_text = self._ids_to_outputs(gen[0][prompt_len:])

        return truth_ids, pred_ids, truth_toks, pred_toks, truth_text, pred_text

    def task_fim(self, context: str, target: str, max_tokens:int = 1500
                 ) -> Tuple[np.ndarray, np.ndarray, List[str], List[str], str, str]:
        """
        Generate the middle for a ready-made FIM prompt and pair it with ``target``.

        ``context`` is used as-is (no truncation). Generation stops after
        ``max_tokens`` new tokens, on EOS, or when ``StopOnCommentGenerated``
        sees ``*/``.

        Returns:
            ``(truth_ids, pred_ids, truth_toks, pred_toks, truth_text, pred_text)``.
        """
        tokenizer = self.model.tokenizer

        encoded = tokenizer(context, return_tensors = 'pt').to(self.model.device)
        context_ids = encoded['input_ids']
        prompt_len = context_ids.shape[-1]

        gen = self._generate(context_ids, encoded['attention_mask'], max_tokens,
                             StopOnCommentGenerated(tokenizer, prompt_len))

        # Reference tokenized without special tokens so ids align with predictions.
        truth = tokenizer(target, add_special_tokens=False, return_tensors='pt')['input_ids'][0]
        truth_ids, truth_toks, truth_text = self._ids_to_outputs(truth)
        pred_ids, pred_toks, pred_text = self._ids_to_outputs(gen[0][prompt_len:])

        # Drop references to device tensors before collecting.
        del gen, encoded, context_ids
        gc.collect()

        return truth_ids, pred_ids, truth_toks, pred_toks, truth_text, pred_text


In [67]:
class Evaluator:
    """
    Stateless metric helpers that compare one prediction with one reference.

    Each metric catches its own exceptions, prints them and returns NaN, so a
    single bad sample does not abort an evaluation loop. ``get_results`` packs
    all metrics into a 13-element array in this order: exact_match, bleu,
    levenshtein, meteor, rouge-1 p/r/f, rouge-2 p/r/f, rouge-l p/r/f.
    """

    @staticmethod
    def sentence_bleu(pred_text: List[str], ref_text: List[str]) -> float:
        """
        Sentence-level BLEU of ``pred_text`` against the single reference ``ref_text``.

        Both arguments are token lists despite their names. Uses NLTK smoothing
        ``method2``. Returns NaN on error.
        """
        try:
            smoothing = SmoothingFunction().method2
            # Resolves to the module-level nltk ``sentence_bleu``, not this method.
            return sentence_bleu([ref_text], pred_text, smoothing_function=smoothing)
        except Exception as e:
            print(f"Error in sentence_bleu: {e}")
            return np.nan

    @staticmethod
    def lev_total(preds: List[str], refs: List[str]) -> float:
        """
        Normalized position-wise Levenshtein distance between two string lists.

        The shorter list is padded with ``""``; edit distances of aligned pairs
        are summed and divided by the sum of ``max(len(p), len(r))`` over the
        pairs, so the result lies in [0, 1] (0.0 also when there are no
        characters at all). Returns NaN on error.
        """
        try:
            length = max(len(preds), len(refs))
            preds_p = preds + [""] * (length - len(preds))
            refs_p  = refs  + [""] * (length - len(refs))

            total_dist = sum(Levenshtein.distance(p, r) for p, r in zip(preds_p, refs_p))
            total_chars = sum(max(len(p), len(r)) for p, r in zip(preds_p, refs_p))
            return total_dist / total_chars if total_chars > 0 else 0.0
        except Exception as e:
            print(f"Error in lev_total: {e}")
            return np.nan

    @staticmethod
    def meteor_score(pred_tokens: List[str], ref_tokens: List[str]) -> float:
        """NLTK METEOR of pre-tokenized ``pred_tokens`` against one reference; NaN on error."""
        try:
            # Resolves to the module-level nltk ``meteor_score``, not this method.
            return meteor_score([ref_tokens], pred_tokens)
        except Exception as e:
            print(f"Error in meteor_score: {e}")
            return np.nan

    @staticmethod
    def rouge_score(pred_text: str, ref_text: str) -> List[float]:
        """
        ROUGE-1/2/L precision, recall and F1 from the ``rouge`` package.

        Returns ``[r1_p, r1_r, r1_f, r2_p, r2_r, r2_f, rl_p, rl_r, rl_f]``,
        or nine NaNs on error.
        """
        try:
            scores = Rouge().get_scores(pred_text, ref_text)[0]
            return [scores[name][stat]
                    for name in ('rouge-1', 'rouge-2', 'rouge-l')
                    for stat in ('p', 'r', 'f')]
        except Exception as e:
            print(f"Error in rouge_score: {e}")
            return [np.nan] * 9

    @staticmethod
    def exact_match(generated_ids: np.ndarray, ground_truth_ids: np.ndarray) -> float:
        """
        Position-wise token accuracy between two 1-D id arrays.

        Note: despite the name this is not whole-sequence exact match. The
        shorter array is padded with -1 and the fraction of equal positions is
        returned, in [0, 1] (0.0 when both are empty). Returns NaN on error,
        e.g. when an input is not 1-D.
        """
        try:
            if generated_ids.ndim != 1 or ground_truth_ids.ndim != 1:
                raise ValueError("Both inputs must be 1D arrays.")

            len_gt = ground_truth_ids.shape[0]
            len_gen = generated_ids.shape[0]

            if len_gt != len_gen:
                target_len = max(len_gt, len_gen)
                pad_gt = np.full((target_len - len_gt,), -1, dtype=ground_truth_ids.dtype)
                pad_gen = np.full((target_len - len_gen,), -1, dtype=generated_ids.dtype)
                ground_truth_ids = np.concatenate([ground_truth_ids, pad_gt])
                generated_ids = np.concatenate([generated_ids, pad_gen])

            matches = np.sum(ground_truth_ids == generated_ids)
            total = ground_truth_ids.shape[0]

            return matches / total if total > 0 else 0.0
        except Exception as e:
            print(f"Error in exact_match: {e}")
            return np.nan

    @staticmethod
    def get_results(pred_ids: np.ndarray, pred_tokens: List[str], pred_text: str,
                    ref_ids: np.ndarray, ref_tokens: List[str], ref_text: str, device: str) -> np.ndarray:
        """
        Compute every metric for one sample.

        ``*_ids`` are 1-D id arrays, ``*_tokens`` token lists and ``*_text``
        decoded strings. ``device`` is accepted for call-site compatibility but
        is not used.

        Returns:
            ``np.ndarray`` of 13 floats in the order given in the class docstring.
        """
        em = Evaluator.exact_match(pred_ids, ref_ids)
        lev = Evaluator.lev_total(pred_tokens, ref_tokens)
        meteor = Evaluator.meteor_score(pred_tokens, ref_tokens)
        rouge = Evaluator.rouge_score(pred_text, ref_text)
        bleu = Evaluator.sentence_bleu(pred_tokens, ref_tokens)

        return np.array([em, bleu, lev, meteor, *rouge])

    @staticmethod
    def get_default_results() -> np.ndarray:
        """Placeholder metrics row (13 NaNs) for a sample whose generation/evaluation failed."""
        return np.array([np.nan] * 13)

In [None]:
# Load Mellum-4B onto the GPU and build the generation and scoring helpers.
model = ModelWrapper(Config.MODELS["mellum_base_4b"], "cuda")
tool = PipelineTool(model)
evaluator = Evaluator()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
results = []
evaluation_results = []
ctx_len = model.context_length
tok     = model.tokenizer
# FIM markers the prompts were built with (the model's own markers when configured).
fim_tokens = model.fim_tool if model.fim_tool is not None else fim_input


def fit_fim_prompt(context: str, budget: int):
    """
    Shorten a PSM FIM prompt (``PREFIX code SUFFIX code MIDDLE``) towards
    ``budget`` tokens by dropping tokens from the start of the *code prefix*
    only, so all FIM markers and the suffix are preserved.

    Returns the shortened prompt, or ``None`` when ``context`` is not in that
    format or the markers plus suffix alone already exceed ``budget``.
    The caller must re-count tokens: re-tokenizing the joined text can differ
    slightly from the sum of the parts.
    """
    head, sep = fim_tokens.FIM_PREFIX, fim_tokens.FIM_SUFFIX
    if not context.startswith(head) or sep not in context[len(head):]:
        return None
    code_prefix, _, tail = context[len(head):].partition(sep)
    fixed_len = len(tok(head + sep + tail, add_special_tokens=False)["input_ids"])
    room = budget - fixed_len
    if room <= 0:
        return None
    prefix_ids = tok(code_prefix, add_special_tokens=False)["input_ids"][-room:]
    # skip_special_tokens=False so no marker-like text is silently removed.
    kept = tok.decode(prefix_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
    return head + kept + sep + tail


for i, (context, target) in enumerate(tqdm(context_target_pair, desc="FIM eval")):
    if i % 200 == 0:
        torch.cuda.empty_cache()

    used       = len(tok(context, add_special_tokens=False)["input_ids"])
    target_len = len(tok(target, return_attention_mask=False)["input_ids"])

    if used >= ctx_len:
        # Shorten only the code prefix so the FIM markers and suffix survive.
        context = fit_fim_prompt(context, ctx_len - 1)
        if context is None:
            continue
        used = len(tok(context, add_special_tokens=False)["input_ids"])

    # Never request more new tokens than the reference has or the window allows.
    max_new = min(ctx_len - used, target_len)
    if max_new <= 0:
        continue  # nothing to generate

    try:
        truth_ids, pred_ids, truth_tokens, pred_tokens, truth_text, pred_text = tool.task_fim(
            context=context,
            target=target,
            max_tokens=max_new,
        )
        eval_res = evaluator.get_results(
            pred_ids, pred_tokens, pred_text,
            truth_ids, truth_tokens, truth_text,
            tool.model.device,
        )
        evaluation_results.append(eval_res)
        results.append({
            "context":    context,
            "truth":      truth_text,
            "prediction": pred_text,
        })
    except Exception as e:
        tqdm.write(f"[{i}] generation/eval failed: {e}")
        evaluation_results.append(Evaluator.get_default_results())
        # Same keys as the success path so the results DataFrame stays rectangular.
        results.append({
            "context":    context,
            "truth":      None,
            "prediction": None,
        })


In [None]:
# Preview the first 40 samples; failed samples (no prediction) are skipped.
for r in results[:40]:
    if r.get("prediction") is None:
        continue
    print("─── context ───")
    print(r["context"][-200:])       # show the tail of your context
    print("─── truth ───")
    print(r["truth"].rstrip())
    print("─ prediction ─")
    print(r["prediction"].rstrip())
    print()

In [None]:
import pandas as pd


In [None]:
# Column names follow the order of Evaluator.get_results rows.
metric_names = [
    "exact_match", "bleu", "levenshtein", "meteor",
    *(f"{family}_{stat}" for family in ("rouge1", "rouge2", "rougeL") for stat in "prf"),
]
metrics_rows = list(map(lambda row: row.tolist(), evaluation_results))

df_main    = pd.DataFrame(results)
df_metrics = pd.DataFrame(metrics_rows, columns=metric_names)

text_path   = "/content/drive/MyDrive/mellum_non_en_fim_with_results.csv"
metric_path = "/content/drive/MyDrive/mellum_non_en_fim_with_metrics.csv"

# Rows of the two CSVs correspond one-to-one (same sample order).
outputs = ((df_main, text_path, "text"), (df_metrics, metric_path, "metrics"))
for frame, out_path, label in outputs:
    frame.to_csv(out_path, index=False, encoding="utf8")
for frame, out_path, label in outputs:
    print(f"Saved {len(frame)} {label} results")