In [1]:
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git

In [2]:
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, pipeline

2024-07-15 03:26:30.535572: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-15 03:26:30.535698: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-15 03:26:30.825447: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
class km_checker:
    
    def __init__(self):
        
        self.double_quant_nf4_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_use_double_quant=True,
        )
        
        self.model = AutoModelForCausalLM.from_pretrained(
            "mlao01/km-checker",
            quantization_config = self.double_quant_nf4_config,
            device_map = 'auto',
            trust_remote_code = True
        )
        
        self.tokenizer = AutoTokenizer.from_pretrained(
            "mlao01/km-checker"
        )
        
        # Initialize the Hugging Face pipeline for text generation
        self.generator = pipeline(
            'text-generation',
            model = self.model,
            tokenizer = self.tokenizer,
            temperature=0.001,
        )
        
        # model tuned w this system message so do not change!
        self.__system_message__ = "Correct the spelling errors in the given Khmer text"
        
    def __generate_prompt_input__(self, misspelled):
        
        prompt_string = f"### system\n{self.__system_message__}\n### user\n{misspelled}"
        prompt_string += f"\n### assistant\n"
        
        return prompt_string
    
    def check(self, misspelled):
        
        OUTPUT_LENGTH = 1.05
        
        response = self.generator(
            self.__generate_prompt_input__(misspelled),
            max_new_tokens = int(OUTPUT_LENGTH*len(self.tokenizer(misspelled)['input_ids'])),
            return_full_text = False,
        )
        
        return response[0]['generated_text']

In [4]:
checker = km_checker()

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/27.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/7 [00:00<?, ?it/s]

model-00001-of-00007.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00007.safetensors:   0%|          | 0.00/4.78G [00:00<?, ?B/s]

model-00003-of-00007.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00004-of-00007.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00005-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00006-of-00007.safetensors:   0%|          | 0.00/3.66G [00:00<?, ?B/s]

model-00007-of-00007.safetensors:   0%|          | 0.00/2.18G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/248 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.30k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/80.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/367 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [18]:
misspelleds = [
    "បន្ទាប់មកបន្ថែមវីតាមីនគ្រាប់ចយូល",
    "កូន៉េខាងជើងប្រកាសឈប់ពិាក្សាជាមួយាមេរិក",
    "លោកសរីថ្លែងមាលោកបានឲ្ដឹថាពួកគេចឹទបាឃ់ត្រូវត្រឡប់ទៅប្រទេសចិនវិញ។"
]

In [19]:
%%time

for misspelled in misspelleds:
    
    corrected = checker.check(misspelled)
    
    print(f'Misspelled \t : {misspelled}')
    print(f'Corrected  \t : {corrected}')
    print('\n\n')

Misspelled 	 : បន្ទាប់មកបន្ថែមវីតាមីនគ្រាប់ចយូល
Corrected  	 : បន្ទាប់មកបន្ថែមវីតាមីនគ្រាប់ចូល៕



Misspelled 	 : កូន៉េខាងជើងប្រកាសឈប់ពិាក្សាជាមួយាមេរិក
Corrected  	 : កូរ៉េខាងជើងប្រកាសឈប់ពិភាក្សាជាមួយអាមេរិក



Misspelled 	 : លោកសរីថ្លែងមាលោកបានឲ្ដឹថាពួកគេចឹទបាឃ់ត្រូវត្រឡប់ទៅប្រទេសចិនវិញ។
Corrected  	 : លោកស្រីថ្លែងថាលោកបានឲ្យដឹងថាពួកគេចឹងតែបាត់ត្រូវត្រឡប់ទៅប្រទេសចិនវ�



CPU times: user 17.6 s, sys: 319 ms, total: 17.9 s
Wall time: 17.9 s
