### Inference of LaMoE language model.

> This notebook walks through the inference using the trained LaMoE model. It uses top-p sampling and temperature for selecting new token.

**Working:** First, it asks for the user prompt and model generates text and prints sequentially. 

**Note:** To exit the inference, type exit.

In [1]:
from typing import Optional
import os, sys
from pathlib import Path
from tqdm import tqdm
import time

import torch

In [2]:
cd ..

d:\Envs\Projects\Transformer_Decoder\MOE


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
from lamoe.transformer import Transformer
from lamoe.config import ModelArgs
from lamoe.tokenizer import BPE
from lamoe.utils import preprocess_Eng, get_vocab, get_model_info

_CudaDeviceProperties(name='NVIDIA GeForce RTX 3060 Laptop GPU', major=8, minor=6, total_memory=6143MB, multi_processor_count=30, uuid=eafe3e1a-82bb-fec8-2948-094f6277e5f8, L2_cache_size=3MB)


In [4]:
def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:

        probs_sort, probs_idx = torch.sort(probs, dim = -1, descending = True)
        probs_sum = torch.cumsum(probs_sort, dim = -1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim = -1, keepdim = True))
        next_token = torch.multinomial(probs_sort, num_samples = 1)
        next_token = torch.gather(probs_idx, -1, next_token)

        return next_token

In [5]:
tokenizer = BPE()
tokenizer_dict = get_vocab(os.path.join('Saved', 'Tokenizer.json'))

'Saved\Tokenizer.json' exists. Loading dictionary values from 'Saved\Tokenizer.json'.
Size of Vocabulary:  29627


In [6]:
def text_completion(model: Transformer, 
                    args: ModelArgs, 
                    prompts: list[str], 
                    temperature: float = 0.6, 
                    top_p: float = 0.9, 
                    max_gen_len: Optional[int] = None) -> tuple[list, list]:

        if max_gen_len is None:
            max_gen_len = args.max_seq_length - 1

        prompts = [preprocess_Eng(prompt) for prompt in prompts]
        prompt_tokens = [tokenizer.encode(prompt, tokenizer_dict, None) for prompt in prompts]
        batch_size = len(prompt_tokens)
        assert batch_size <= args.max_batch_size, f"Batch size must be less than or equal to {args.max_batch_size}"
        max_prompt_len = max(len(prompt) for prompt in prompt_tokens)
        assert max_prompt_len <= args.max_seq_length, f"Sequence length must be less than or equal to {args.max_seq_length}"
        total_len = min(args.max_seq_length, 64 + max_prompt_len)

        pad_id = tokenizer_dict['vocab_dict'].get(args.pad)
        eos_id = tokenizer_dict['vocab_dict'].get(args.eos)
        tokens = torch.full((batch_size, total_len), pad_id, dtype = torch.long, device = args.device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, :len(t)] = torch.tensor(t, dtype = torch.long, device = args.device)

        eos_reached = torch.tensor([False] * batch_size, device = args.device)
        prompt_tokens_mask = tokens != pad_id

        dot_cycle = ['.', '..', '...', '....'] * total_len

        for i, token in enumerate(tokens):
            with tqdm(total = total_len, desc = "Generating Text", bar_format="{desc} {postfix}") as pbar:
                for pos in range(1, total_len):
                    with torch.no_grad():
                            logits, _ = model.forward(token[pos - 1:pos].unsqueeze(0), start_pos = pos)

                    if temperature > 0:
                        probs = torch.softmax(logits[:, -1] / temperature, dim = -1)
                        next_token = sample_top_p(probs, top_p)
                    else:
                        next_token = torch.argmax(logits[:, -1], dim = -1)

                    next_token = next_token.reshape(-1)

                    next_token = torch.where(prompt_tokens_mask[i, pos], tokens[i, pos], next_token)
                    token[pos] = next_token

                    eos_reached |= (~prompt_tokens_mask[:, pos]) & (next_token == eos_id)
                    if all(eos_reached):
                        break
                    
                    dots = dot_cycle[pos-1 % len(dot_cycle)]
                    pbar.set_postfix_str(dots)
                    time.sleep(0.01)
                    pbar.update(1)

        out_tokens = []
        out_text = []
        for prompt_index, current_prompt_tokens in enumerate(tokens.tolist()):
            if eos_id in current_prompt_tokens:
                eos_idx = current_prompt_tokens.index(eos_id)
                current_prompt_tokens = current_prompt_tokens[:eos_idx]
            out_tokens.append(current_prompt_tokens)
            out_text.append(tokenizer.decode(current_prompt_tokens, tokenizer_dict))
        return (out_tokens, out_text)

In [7]:
model_args = ModelArgs(inference = True)
model_args.vocab_size = len(tokenizer_dict['vocab_dict'])
model_args.aux_loss = False
print("Model Args: ", model_args)

Model Args:  ModelArgs(dim=512, ffn_hidden_dim=2048, n_layers=4, n_heads=8, n_kv_heads=4, vocab_size=29627, norm_eps=1e-05, num_experts=8, k=2, eos='<eos>', pad='<pad>', unk='<unk>', aux_loss=False, aux_loss_coeff=0.01, inference=True, cache=True, max_batch_size=32, max_seq_length=300, device=device(type='cuda'))


In [8]:
model = Transformer(model_args) 

model_path = os.path.join('Saved', "model", "MoE-LM.pth")

try:
        model.load_state_dict(torch.load(model_path, weights_only = True, map_location = 'cpu'), strict = True)
        model.to(model_args.device)
        print("Model is loaded with trained weights successfully.\n")
except Exception as e:
        print(e, f"\n{model_path} is not present. Train the model to get final {model_path}")
        exit(0)

Model is loaded with trained weights successfully.



In [9]:
get_model_info(model, model_args)

-------------------------------------------------------- Model Summary --------------------------------------------------------

Model size: 512.232 MB
Number of Experts: 8 (Loaded) vs 2 (Inference)
Total params: 134.188 M (Loaded)                             || Total params: 58.692 M (Inference)
----------------------------------------------------------------------------------------------------------------------------------
Name          Parameters(M)                                  || Name          Parameters(M)
Embeddings           15.169                                  || Embeddings           15.169
Layers - 4          103.85                                   || Layers - 4           28.354
LLM-Head             15.169                                  || LLM-Head             15.169

Layer: 25.961 M                                              || Layer: 7.087 M
---------------------------------------------------------------------------------------------------------------------------

In [10]:
while True:
    text = input("User: ")
    if text == "exit":
            break
    print("User: ", text)
    prompts = [text]
    out_tokens, out_texts = (text_completion(model, model_args, prompts, temperature = 0.6, top_p = 0.9, max_gen_len = 150))
    assert len(out_texts) == len(prompts)       
    print("Model: ")
    for i in range(len(out_texts)):
        for word in f'{out_texts[i]}':
            print(word, end = '', flush = True)
            time.sleep(0.05)
    print()

User:  String theory states that 


Generating Text , ....

Model: 
S




tring theory states that the forces are considered to be a possible explanation for the weak interaction Conservation of nature. The weak interaction Conservation laws are associated with Maxwells field equations and the weak interaction Conservation laws and the weak interaction Conservation laws under the Lagrangian laws of nature. The twobody problem is that the exact solutions. From the Lagrangian of the Standard Model the Lagrangian and the
User:  Artificial Intelligence is the 


Generating Text , ....


Model: 
Artificial Intelligence is the ultimate goal of understanding the evidence of the cosmological principlet. Theoretical cosmologists led astronomers to the discovery of a void. Using the Copernican Revolution the Copernican Revolution and the Copernican Revolution and the Copernican Revolution and the Copernican Revolution as the Copernican principle is the basis for the Copernican principle for the Copernican principle and the Copernican principle of nature of
