In [1]:
from datasets import load_dataset

ds = load_dataset("google/code_x_glue_ct_code_to_text", "java", cache_dir='saved_data')

In [2]:
# Import necessary libraries
import nltk
from nltk import bigrams, trigrams, ngrams
from collections import defaultdict
from transformers import AutoTokenizer

from tqdm import tqdm
import math

class NGramLM:
    def __init__(self, corpus, model='codellama/CodeLlama-7b-hf'):
        self.corpus = corpus
        self.tokenizer = AutoTokenizer.from_pretrained(model, cache_dir='./tokenizers')
        self.unigram_model, self.bigram_model, self.trigram_model, self.fourgram_model = self.build_ngrams()
        
        
    def build_ngrams(self):
        # Initialize models for n-grams
        unigram_model =  {}
        bigram_model = defaultdict(lambda: defaultdict(lambda: 0))
        trigram_model = defaultdict(lambda: defaultdict(lambda: 0))
        fourgram_model = defaultdict(lambda: defaultdict(lambda: 0))

        # Build n-gram models from the corpus
        for words in tqdm(self.corpus):
            if isinstance(words, str):
                words = [token.replace('▁', '') for token in self.tokenizer.tokenize(words)]
            
            for word in words:
                if word not in unigram_model:
                    unigram_model[word] =0
                else:
                    unigram_model[word] += 1
                
            # Build bigram model
            bi_grams = list(bigrams(words))
            for w1, w2 in bi_grams:
                bigram_model[w1][w2] += 1

            # Build trigram model
            tri_grams = list(trigrams(words))
            for w1, w2, w3 in tri_grams:
                trigram_model[(w1, w2)][w3] += 1

            # Build four-gram model
            four_grams = list(ngrams(words, 4))
            for w1, w2, w3, w4 in four_grams:
                fourgram_model[(w1, w2, w3)][w4] += 1

        # Normalize all n-gram models
        self.normalize_ngrams(bigram_model, trigram_model, fourgram_model)
    
        total_words = sum(list(unigram_model.values()))
        threshold = (max(unigram_model.values()) + min(unigram_model.values())) // 2
        unigram_model = {i:(j/total_words + 1e-10) for i,j in unigram_model.items() if j > threshold}
        return unigram_model, bigram_model, trigram_model, fourgram_model

    # Normalization function for all n-gram models
    def normalize_ngrams(self, *ngram_models):
        for model in ngram_models:
            for context in model:
                total_count = float(sum(model[context].values()))
                for word in model[context]:
                    model[context][word] /= total_count
                    model[context][word] += 1e-10

    # Function to predict the next word using five-gram, four-gram, trigram, or bigram
        # Function to predict the next word
    def predict_next_word(self, w1, w2, w3):
        next_word = self.fourgram_model[(w1, w2, w3)]
        if next_word:
            predicted_word = max(next_word, key=next_word.get)
            return predicted_word
        else:
            next_word = self.trigram_model[(w2, w3)]
            if next_word:
                predicted_word = max(next_word, key=next_word.get)
                return predicted_word
            else:
                next_word = self.bigram_model[w3]
                if next_word:
                    predicted_word = max(next_word, key=next_word.get)
                    return predicted_word
                else:
                    return max(self.unigram_model, key=self.unigram_model.get)
        return "UNK"


    def compute_perplexity(self, test_data, tokenized=True):
        if isinstance(test_data, str):
            if tokenized:
                test_data = test_data.split()
            else:
                test_data = [token.replace('▁', '') for token in self.tokenizer.tokenize(test_data)]
        test_fourgrams = list(ngrams(test_data, 4))
        N = len(test_fourgrams)
        log_prob_sum = 0

        for w1, w2, w3, w4 in test_fourgrams:
            prob = self.fourgram_model[(w1, w2, w3)].get(w4, 'N/A')
            if prob == 'N/A':                
                prob = self.trigram_model[(w2, w3)].get(w4, 'N/A')
            if prob == 'N/A':
                prob = self.bigram_model[w3].get(w4, 'N/A')
            if prob == 'N/A':
                prob = self.unigram_model.get(w4, 1e-10)
            log_prob_sum += math.log(prob, 2)

        # Calculate the perplexity
        perplexity = 2 ** (-log_prob_sum / N)
        return perplexity

    
    
    def predict_sentence(self, test_data, tokenized=True):
        if isinstance(test_data, str):
            if tokenized:
                test_data = test_data.split()
            else:
                test_data = [token.replace('▁', '') for token in self.tokenizer.tokenize(test_data)]
        test_4grams = list(ngrams(test_data, 4))
        N = len(test_4grams)
        log_prob_sum = 0
        
        correct = 0
        for w1, w2, w3, w4 in test_4grams:
            # Get the probability of the trigram (w1, w2, w3)
            next_word = self.predict_next_word(w1,w2,w3)
            if next_word == w4:
                correct += 1
        
        # Calculate the perplexity
        accuracy = correct/len(test_4grams)
        return accuracy
    
    
    def eval_on_corpus(self, test_corpus, tokenized=True):
        if isinstance(test_corpus, str):
            with open(test_corpus, 'r') as f:
                test_corpus = f.readlines()
        
        all_accs = []
        all_ppls = []
        for test_data in tqdm(test_corpus):
            all_ppls.append(self.compute_perplexity(test_data, tokenized=tokenized))
            all_accs.append(self.predict_sentence(test_data, tokenized=tokenized))
            
        avg_acc = sum(all_accs)/len(all_accs)
        avg_ppl = sum(all_ppls)/len(all_ppls)
        
        print("Average accuracy:", avg_acc)
        print("Average perplexity:", avg_ppl)
        return avg_acc,avg_ppl


ngramLM_pre = NGramLM(ds['train']['code_tokens'])
ngramLM_codellama = NGramLM(ds['train']['code'])

100%|█████████████████████████████████████████████████████████████████████████████████| 164923/164923 [00:58<00:00, 2838.97it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 164923/164923 [02:31<00:00, 1090.28it/s]


In [3]:
ngramLM_codellama.eval_on_corpus('all_codes.txt', tokenized=False)

100%|████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 302.84it/s]

Average accuracy: 0.2998862631472445
Average perplexity: 296.52376536921287





(0.2998862631472445, 296.52376536921287)

In [4]:
ngramLM_codellama.eval_on_corpus('all_codes.txt', tokenized=True)

100%|████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 191.41it/s]

Average accuracy: 0.17650285944175992
Average perplexity: 2274776.091388698





(0.17650285944175992, 2274776.091388698)

In [5]:
ngramLM_pre.eval_on_corpus('all_codes.txt', tokenized=False)

100%|█████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:14<00:00,  6.68it/s]

Average accuracy: 0.24366236780749972
Average perplexity: 661483.3497655754





(0.24366236780749972, 661483.3497655754)

In [6]:
ngramLM_pre.eval_on_corpus(ds['test'][:100]['code_tokens'])

100%|█████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:13<00:00,  7.53it/s]

Average accuracy: 0.45521732685478733
Average perplexity: 928.174491875554





(0.45521732685478733, 928.174491875554)

In [16]:
seed = ['', 'if', '{']
for i in range(1):
    out = ngramLM_pre.predict_next_word(seed[-3], seed[-2], seed[-1])
    seed += [out]
print(' '.join(seed))

 if { if ( ! ( ( ( AbstractAttribute ) attribute )


In [None]:
ngramLM.predict_next_word('','(', '!')

In [None]:
ngramLM.trigram_model[('static', 'boolean')]

In [None]:
ngramLM.fourgram_model[('',
              'boolean',
              'check')]