In [1]:
!pip -q install git+https://github.com/huggingface/transformers # need to install from github
!pip -q install datasets sentencepiece 
!pip -q install bitsandbytes==0.38.0.post2 accelerate

In [2]:
!nvidia-smi

Thu May 25 10:25:14 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |
|  0%   22C    P8    21W / 300W |      0MiB / 23028MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import re
import sys
import json
import torch
import logging
from random import choice
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

In [4]:
# Notebook-wide logger. At INFO level, logger.debug(...) calls produce no output.
logger = logging.getLogger('api')
logger.setLevel(logging.INFO)

# logging.getLogger returns the same logger object on every call, so re-running
# this cell would otherwise stack one more handler each time and print every
# message repeatedly. Drop any handlers left from a previous run first.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)

logHandler = logging.StreamHandler(sys.stdout)
logger.addHandler(logHandler)

In [5]:
def setup_model(model_name: str, cache_dir: str = None):
    """Load the tokenizer and the 8-bit causal language model for a checkpoint.

    Args:
        model_name: Checkpoint name passed to both ``from_pretrained`` calls.
        cache_dir: Directory passed as ``cache_dir`` to both loaders;
            ``None`` is forwarded unchanged.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

    # Options that apply to the model only: 8-bit loading, float16 dtype,
    # automatic device mapping and reduced CPU memory use while loading.
    model_options = dict(
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map='auto',
        low_cpu_mem_usage=True,
        cache_dir=cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_options)

    return tokenizer, model

In [6]:
# Checkpoint used for every experiment below; the commented-out line is an
# alternative checkpoint that can be swapped in.
model_name = 'TheBloke/wizardLM-7B-HF'
# model_name = 'ehartford/Wizard-Vicuna-7B-Uncensored'
# NOTE(review): absolute path that looks specific to one SageMaker host —
# adjust it when running this notebook anywhere else.
cache_dir = '/home/ec2-user/SageMaker/.cache'
# Loads the tokenizer and the model through setup_model.
tokenizer, model = setup_model(model_name, cache_dir)


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so
CUDA SETUP: CUDA runtime path found: /home/ec2-user/anaconda3/envs/pytorch_p39/lib/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 117
CUDA SETUP: Loading binary /home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...


Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
!nvidia-smi

Thu May 25 10:25:23 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |
|  0%   25C    P0    61W / 300W |   7967MiB / 23028MiB |     19%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [17]:
# Prompt template with two str.format placeholders. {user_input} is substituted
# directly under the "### Response:" header, so whatever is passed there becomes
# the beginning of the text the model is asked to continue.
# NOTE(review): the preamble mentions "an input that provides further context",
# but the template has no separate input section — confirm this is intended.
TEMPLATE = '''
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{user_input}
'''

In [140]:
def generate(
    tokenizer: "AutoTokenizer",
    model: "AutoModelForCausalLM",
    prompt: str,
    top_k = 0,
    top_p = 0.95,
    max_new_tokens = 512,
    temperature = 0.1,
):
    """Sample one continuation of ``prompt`` and return only the new text.

    Args:
        tokenizer: Object providing ``encode`` and ``decode``.
        model: Object providing ``device`` and ``generate``.
        prompt: Text to continue.
        top_k: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.

    Returns:
        The decoded text of the tokens that follow the prompt tokens in the
        generated sequence, with special tokens skipped.

    Note:
        The continuation is cut out at the token level. Slicing the decoded
        string by ``len(prompt)`` is only correct when decoding reproduces the
        prompt character-for-character, which is not guaranteed. Leading
        whitespace of the result may therefore differ from what string
        slicing produced, depending on the tokenizer.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        gen_tokens = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=1.15,
        )
    # Exactly one sequence is requested, so take it directly instead of drawing
    # from the batch at random.
    # assumes the returned sequence starts with the prompt tokens (decoder-only
    # generate output) — TODO confirm if another model family is used
    prompt_length = input_ids.shape[-1]
    new_tokens = gen_tokens[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

## Classification

In [162]:
prompt = '''
This is awesome! // Positive.
This is bad! // Negative.
Wow that movie was rad! // Positive.
What a horrible show! // 
'''.strip()

In [163]:
%%time

logger.debug(prompt)
generation = generate(tokenizer, model, prompt)
logger.info(f'@@gen: {generation}')

@@gen:  Negative.
CPU times: user 561 ms, sys: 6 µs, total: 561 ms
Wall time: 560 ms


## Summarization

In [164]:
prompt = '''
Antibiotics are a type of medication used to treat bacterial infections.
They work by either killing the bacteria or preventing them from reproducing, allowing the body’s immune system to fight off the infection.
Antibiotics are usually taken orally in the form of pills, capsules, or liquid solutions, or sometimes administered intravenously.
They are not effective against viral infections, and using them inappropriately can lead to antibiotic resistance.

Explain the above in one sentence:
'''.strip()

In [165]:
%%time

logger.debug(prompt)
generation = generate(tokenizer, model, prompt)
logger.info(f'@@gen: {generation}')

@@gen:  Antibiotics are medications that combat bacterial infections by killing or hindering their growth, allowing the immune system to fight off the illness.
CPU times: user 4.79 s, sys: 4.01 ms, total: 4.8 s
Wall time: 4.79 s


## Information Extraction

In [166]:
prompt = '''
Author-contribution statements and acknowledgements in research papers should state clearly and specifically whether, and to what extent, the authors used AI technologies such as ChatGPT in the preparation of their manuscript and analysis.
They should also indicate which LLMs were used.
This will alert editors and reviewers to scrutinize manuscripts more carefully for potential biases, inaccuracies and improper source crediting.
Likewise, scientific journals should be transparent about their use of LLMs, for example when selecting submitted manuscripts.

Mention the large language model based product mentioned in the paragraph above:
'''.strip()

In [167]:
%%time

logger.debug(prompt)
generation = generate(tokenizer, model, prompt)
logger.info(f'@@gen: {generation}')

@@gen:  "ChatGPT"
CPU times: user 1.03 s, sys: 0 ns, total: 1.03 s
Wall time: 1.02 s


## Zero-shot prompting

In [168]:
prompt = '''
Classify the text into neutral, negative or positive. 

Text: I think the vacation is okay.
Sentiment:
'''.strip()

In [169]:
%%time

logger.debug(prompt)
generation = generate(tokenizer, model, prompt)
logger.info(f'@@gen: {generation}')

@@gen:  Neutral
CPU times: user 560 ms, sys: 33 µs, total: 560 ms
Wall time: 559 ms


## Few-shot learning
- 3개 이상 shot 을 추가해주면 효과적

In [170]:
instruction='''
The odd numbers in this group add up to an even number: 4, 8, 9, 15, 12, 2, 1.
A: The answer is False.

The odd numbers in this group add up to an even number: 17, 10, 19, 4, 8, 12, 24.
A: The answer is True.

The odd numbers in this group add up to an even number: 16, 11, 14, 4, 8, 13, 24.
A: The answer is True.

The odd numbers in this group add up to an even number: 17, 9, 10, 12, 13, 4, 2.
A: The answer is False.
'''.strip()

user_input = '''
The odd numbers in this group add up to an even number: 15, 32, 5, 13, 82, 7, 1.
A: 
'''.strip()

In [171]:
%%time

prompt = TEMPLATE.format(instruction=instruction, user_input=user_input)
logger.debug(prompt)
generation = generate(tokenizer, model, prompt)
logger.info(f'@@gen: {generation}')

@@gen: The answer is True.
CPU times: user 851 ms, sys: 3.94 ms, total: 855 ms
Wall time: 853 ms


## One-shot CoT (Chain of thought)

In [189]:
instruction = '''
The odd numbers in this group add up to an even number: 4, 8, 9, 15, 12, 2, 1.
A: Adding all the odd numbers (9, 15, 1) gives 25. The answer is False.
'''.strip()

user_input = '''
The odd numbers in this group add up to an even number: 15, 32, 5, 7, 1.
A: 
'''.strip()

In [190]:
%%time

prompt = TEMPLATE.format(instruction=instruction, user_input=user_input)
logger.debug(prompt)
generation = generate(tokenizer, model, prompt, temperature=0.2)
logger.info(f'@@gen: {generation}')

@@gen: Adding all the odd numbers (15, 32, 5, 7, 1) gives 60. However, since there are no additional details provided, I cannot determine if the statement is true or false.
CPU times: user 6.75 s, sys: 0 ns, total: 6.75 s
Wall time: 6.75 s


### With `Let's think step by step.`

In [191]:
instruction = '''
The odd numbers in this group add up to an even number: 4, 8, 9, 15, 12, 2, 1.
A: Adding all the odd numbers (9, 15, 1) gives 25. The answer is False.
'''.strip()

user_input = '''
The odd numbers in this group add up to an even number: 15, 32, 5, 7, 1.
A: Let's think step by step.
'''.strip()

In [197]:
%%time

prompt = TEMPLATE.format(instruction=instruction, user_input=user_input)
logger.debug(prompt)
generation = generate(tokenizer, model, prompt, temperature=0.4)
logger.info(f'@@gen: {generation}')

@@gen: First, we need to find the sum of all the odd numbers in the list. That would be 15 + 32 + 5 + 7 = 60.
Then, we need to check if 60 is an even number. It is! Therefore, the answer is True.
CPU times: user 8.6 s, sys: 0 ns, total: 8.6 s
Wall time: 8.6 s


## PAL (Program-Aided Language models)

In [202]:
# Few-shot PAL prompt: each exemplar answers a date question with Python code,
# and the last question is left open for the model to complete.
# .lstrip() removes only the leading newline; the trailing newline is kept.
prompt = '''
# Q: 2015 is coming in 36 hours. What is the date one week from today in MM/DD/YYYY?
# If 2015 is coming in 36 hours, then today is 36 hours before.
today = datetime(2015, 1, 1) - relativedelta(hours=36)
# One week from today,
one_week_from_today = today + relativedelta(weeks=1)
# The answer formatted with %m/%d/%Y is
one_week_from_today.strftime('%m/%d/%Y')

# Q: The first day of 2019 is a Tuesday, and today is the first Monday of 2019. What is the date today in MM/DD/YYYY?
# If the first day of 2019 is a Tuesday, and today is the first Monday of 2019, then today is 6 days later.
today = datetime(2019, 1, 1) + relativedelta(days=6)
# The answer formatted with %m/%d/%Y is
today.strftime('%m/%d/%Y')

# Q: The concert was scheduled to be on 06/01/1943, but was delayed by one day to today. What is the date 10 days ago in MM/DD/YYYY?
# If the concert was scheduled to be on 06/01/1943, but was delayed by one day to today, then today is one day later.
today = datetime(1943, 6, 1) + relativedelta(days=1)
# 10 days ago,
ten_days_ago = today - relativedelta(days=10)
# The answer formatted with %m/%d/%Y is
ten_days_ago.strftime('%m/%d/%Y')

# Q: It is 4/19/1969 today. What is the date 24 hours later in MM/DD/YYYY?
# It is 4/19/1969 today.
today = datetime(1969, 4, 19)
# 24 hours later,
later = today + relativedelta(hours=24)
# The answer formatted with %m/%d/%Y is
later.strftime('%m/%d/%Y')

# Q: Jane thought today is 3/11/2002, but today is in fact Mar 12, which is 1 day later. What is the date 24 hours later in MM/DD/YYYY?
# If Jane thought today is 3/11/2002, but today is in fact Mar 12, then today is 3/12/2002.
today = datetime(2002, 3, 12)
# 24 hours later,
later = today + relativedelta(hours=24)
# The answer formatted with %m/%d/%Y is
later.strftime('%m/%d/%Y')

# Q: Jane was born on the last day of February in 2001. Today is her 16-year-old birthday. What is the date yesterday in MM/DD/YYYY?
# If Jane was born on the last day of February in 2001 and today is her 16-year-old birthday, then today is 16 years later.
today = datetime(2001, 2, 28) + relativedelta(years=16)
# Yesterday,
yesterday = today - relativedelta(days=1)
# The answer formatted with %m/%d/%Y is
yesterday.strftime('%m/%d/%Y')

# Q: Today is 27 February 2023. I was born exactly 25 years ago. What is the date I was born in MM/DD/YYYY?
'''.lstrip()

In [203]:
%%time

logger.debug(prompt)
generation = generate(tokenizer, model, prompt, temperature=0.1)
logger.info(f'@@gen: {generation}')

@@gen: # If today is 27 February 2023 and you were born exactly 25 years ago, then you were born on 27 February 1998.
birthdate = datetime(1998, 2, 27)
# The answer formatted with %m/%d/%Y is
birthdate.strftime('%m/%d/%Y')
CPU times: user 11.8 s, sys: 0 ns, total: 11.8 s
Wall time: 11.8 s


## Generating data

In [208]:
instruction = '''
Produce 10 exemplars for sentiment analysis. Examples are categorized as either positive or negative. Produce 2 negative examples and 8 positive examples. Use this format for the examples:
Q: <sentence>
A: <sentiment>
'''.strip()

user_input = '''
Q: 
'''.strip()

In [209]:
%%time

prompt = TEMPLATE.format(instruction=instruction, user_input=user_input)
logger.debug(prompt)
generation = generate(tokenizer, model, prompt, temperature=0.8)
logger.info(f'@@gen: {generation}')

@@gen: "I am so excited to receive my new job promotion!"
A: Positive

Q: "The traffic on the way to work was terrible today."
A: Negative

Q: "My favorite restaurant closed down unexpectedly."
A: Negative

Q: "The sunset tonight was absolutely breathtaking."
A: Positive

Q: "I got a good deal on a new pair of shoes at the mall today."
A: Positive

Q: "My boss gave me a great compliment during our meeting today."
A: Positive

Q: "I had a really bad cold all week, but I'm feeling better now."
A: Negative

Q: "The weather forecast predicted rain for tomorrow, but it looks like it might clear up later in the day."
A: Neutral
CPU times: user 25.8 s, sys: 3.88 ms, total: 25.8 s
Wall time: 25.8 s


## Generate Code

In [206]:
prompt = '''
SELECT students.StudentId, students.StudentName
FROM students
INNER JOIN departments
ON students.DepartmentId = departments.DepartmentId
WHERE departments.DepartmentName = 'Computer Science';

Explain the above SQL statement.
'''.strip()

In [207]:
%%time

logger.debug(prompt)
generation = generate(tokenizer, model, prompt, temperature=0.1)
logger.info(f'@@gen: {generation}')

@@gen:  
This SQL statement selects the StudentID and StudentName from the Students table, and then joins it with the Departments table on the DepartmentID field. The WHERE clause filters the results to only include rows where the DepartmentName is "Computer Science". This query will return a list of all Computer Science students along with their names.
CPU times: user 8.93 s, sys: 3.96 ms, total: 8.93 s
Wall time: 8.93 s
