In [1]:
import torch
import gc
import numpy as np
from IPython.display import clear_output
import torch.nn.functional as F
from IPython.display import display, Markdown, Latex
import time
from tqdm import tqdm
from accelerate import init_empty_weights

In [2]:
gc.enable()

In [3]:
# torch.set_num_threads(16)

In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
# device = 'cuda'
device = 'cpu'
model_name = 'z'
model = AutoModelForCausalLM.from_pretrained(model_name,max_position_embeddings=10_000,
#                                              max_positions=10_000,
                                             ignore_mismatched_sizes=True,
                                             torch_dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/39 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

Downloading (…)"tokenizer.model";:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

In [6]:
max_positions = model.config.max_position_embeddings

In [7]:
for l in model.gpt_neox.layers:
    l.attention.bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                                1, 1, max_positions, max_positions)

In [11]:
model.to(device)
model.eval()
pass

In [12]:
if device == 'cpu':
    model.to(torch.float32)
    pass

In [6]:
@torch.no_grad()
def generate_response_greedy(input_text, pre_prompt, break_word,max_length=100,temp=0.9, name='',
                            past_key_vals = None, next_id=None):

#     print(pre_prompt, input_text)
    if past_key_vals is None:
        inputs = tokenizer.encode(pre_prompt + input_text + '\n' + name, return_tensors="pt")
        response_ids = inputs
        length_prompt = len(response_ids[0])
        output = ''
        last_n = ''
    else:
        inputs = tokenizer.encode(input_text + '\n' + name, return_tensors="pt")
        response_ids = torch.concat((next_id, inputs),dim=-1)
        length_prompt = len(response_ids[0])
        output = ''
        last_n = ''
    print(name, end='')
    all_out = name
    for _ in (range(max_length)):
        out = model.forward(input_ids=response_ids.to(device), past_key_values=past_key_vals)
#         next_token_id = out.logits[:, -1, :].argmax(-1,keepdim=True)
        next_token_id = torch.multinomial(F.softmax(out.logits[:, -1, :]/temp,  dim=-1), num_samples=1).to('cpu')
        past_key_vals = out.past_key_values
        response_ids = next_token_id
        output = tokenizer.decode([response_ids[0][-1]], skip_special_tokens=True)
        all_out += output
#         clear_output(wait=True)
        print(output, end='')
        
#         display(Markdown(all_out))
        
        last_n += output
        last_n = last_n[-len(break_word):]
        if last_n == break_word:
            break
    decoded_output = tokenizer.decode(response_ids[0], skip_special_tokens=True)
    past_kv = past_key_vals
    next_id = response_ids
    return decoded_output.replace(pre_prompt, '').replace(input_text, ''), past_kv, next_id


In [7]:
teacher_pre_prompt = '''
[TEACHER] How are you? 
[STUDENT] Fine.
[TEACHER] What is a binary tree?
[STUDENT] A binary tree is a tree that has two types of nodes:
-   leaves: the nodes that are not part of the tree.
-   nodes: the nodes that are part of the tree.
[TEACHER] How does an engine work?
[STUDENT] The engine consists of a fixed cylinder and a moving piston. 
The expanding combustion gases push the piston, which in turn rotates the crankshaft. 
Ultimately, through a system of gears in the powertrain, 
this motion drives the vehicle's wheels.
[TEACHER] What is a crankshaft?
[STUDENT] The crankshaft is a rotating shaft containing one or more crankpins,
that are driven by the pistons via the connecting rods.
[TEACHER] Where is it used? 
[STUDENT] The crankshaft is essentially the backbone of the internal combustion engine.
[TEACHER] What is 3 / 2?
[STUDENT] 1.5
[TEACHER] Write code for matrix multiplication in python.
[STUDENT] ```def matrix_multiplication(X,Y):
        return X @ Y```
[TEACHER] '''

In [8]:
intel_pre_prompt = '''[BOT] Welcome to my chatbot! I am a highly intelligent virtual assistant designed to 
assist you in a variety of tasks. I am verbose, descriptive and extremely creative with my responses.
I possess a wealth of knowledge on a wide range of topics, including mathematics, science, 
literature, history, and much more. 

I am equipped with a state-of-the-art language model that allows me to understand natural language
queries and respond in a clear and concise manner. Whether you need help with a specific task, have 
a question about a particular topic, or simply want to chat, I am here to assist you.

Examples of what you can ask me:

- "What is the capital of France?"
- "Who invented the telephone?"
- "Can you help me solve the equation 2x + 3 = 7?"
- "What is the plot of the novel 'To Kill a Mockingbird'?"
- "What is the molecular formula for water?"
- "What is the circumference of a circle with a radius of 5 meters?"

Here's an example conversation to give you an idea of how I can help:
[USER] What is the capital of Canada?
[BOT] The capital of Canada is Ottawa.
[USER] Can you help me solve the equation x^2 + 5x - 6 = 0?
[BOT] Sure! The solutions to the equation x^2 + 5x - 6 = 0 are x = -6 and x = 1.
[USER] Who wrote the novel 'The Great Gatsby'?
[BOT] 'The Great Gatsby' was written by F. Scott Fitzgerald.
[USER] '''

In [9]:
parrot_prompt = '''
[USER] Repeat after me: "I am a parrot"
[PAR] I am a parrot
[USER] I love to sing
[PAR] I love to sing
[USER] '''

In [14]:
log = ''
past_kv = None
next_id = None

while True:
    user_input = input(" ")
    if user_input.lower() in ["exit", "quit", "stop"]:
        break
#     break_word = '[TEACHER]'
    break_word = '[USER]'
        
    response,past_kv,next_id = generate_response_greedy(user_input, intel_pre_prompt + log,
                                        break_word,max_length=10_000, name='[BOT]',
                                        past_key_vals=past_kv, next_id=next_id)
#     response = '[JOHN] Hello [EOS]'
#     print('res', response)s
    log += user_input  + response
#     print(log)
#     print(f"Bot: {response}")


 Hi.
[BOT] Hi! How can I help you?  
 [ USER ] I need some help with my history home work. 
 [ B OT ] I can help