In [1]:
#| echo: false
#| output: false
%env HF_DATASETS_CACHE=/pscratch/sd/a/azaidi/llm/cache
%env HF_HOME=/pscratch/sd/a/azaidi/llm/cache

env: HF_DATASETS_CACHE=/pscratch/sd/a/azaidi/llm/cache
env: HF_HOME=/pscratch/sd/a/azaidi/llm/cache


In [2]:
#| echo: false
#| output: false
from transformers import AutoTokenizer, AutoModelForCausalLM
import requests
import glob

In [3]:
def get_models(parent_name):
    cache_dir = '/pscratch/sd/a/azaidi/llm/cache/models--'
    models = [path.split('--')[-1] for path in \
              glob.glob(cache_dir + parent_name)]
    print(f'the {parent_name} models are: {models}')

In [4]:
get_models('mistralai*'), get_models('meta-llama*'); 

the mistralai* models are: ['Mistral-7B-v0.1', 'Mixtral-8x7B-Instruct-v0.1', 'Mistral-7B-Instruct-v0.2']
the meta-llama* models are: ['Llama-2-7b-hf', 'Llama-2-70b-hf', 'Llama-2-70b-chat-hf']


In [5]:
#| echo: false
#| output: false
def call_dce(url=None, pmid=None, context=False):
    if pmid is None: pmid = 'PMC6000758'
    if url is None: 
        if context: url = "https://metadata.namesforlife.com/context?pmcid=" + pmid
        else: url = "https://metadata.namesforlife.com/reference?pmcid=" + pmid
    return requests.get(url).text, pmid

In [6]:
#| echo: false
#| output: false
txt, pmid = call_dce(context=True)
len(txt), pmid

(37593, 'PMC6000758')

In [7]:
#| echo: false
#| output: false
def get_assets(model_name, four_bit=False, eigth_bit=False):
    cache_dir = '/pscratch/sd/a/azaidi/llm/cache'
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                            cache_dir=cache_dir,
                                            load_in_8bit=eigth_bit,
                                            load_in_4bit=four_bit)

    model = AutoModelForCausalLM.from_pretrained(model_name,
                                            cache_dir=cache_dir,
                                            load_in_4bit=four_bit,
                                            load_in_8bit=eigth_bit
                                            );
    return tokenizer, model

In [9]:
#| echo: false
#| output: false
tokenizer, model = get_assets("mistralai/Mixtral-8x7B-Instruct-v0.1", four_bit=True)
#tokenizer, model = get_assets("mistralai/Mistral-7B-Instruct-v0.2", four_bit=False)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [23]:
#| echo: true
#| output: true
model

MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear4bit(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear4bit(in_features=4096, out_features=14336, bias=False)
              (w2): Linear4bit(in_features=14336, out_features=4096, bias=False)
              (w3): Linear4bit(in_features=4096, out_

In [10]:
#model.to('cuda');

In [11]:
#| echo: false
#| output: false

def get_response(prompt, new_tokens=250, rep_penalty=1.0, length_penalty=1.0, device=None):
    if model.device.type == 'cuda': device = 'cuda'
    else: device = 'cpu'
    inputs = tokenizer(prompt,return_tensors="pt").to(device)
    output = tokenizer.batch_decode(model.generate(inputs=inputs['input_ids'],
                                                   max_new_tokens=new_tokens,
                                                   repetition_penalty=rep_penalty,
                                                   length_penalty=length_penalty,
                                                   pad_token_id=tokenizer.eos_token_id), 
                                    skip_special_tokens=True, 
                                    clean_up_tokenization_spaces=False)[0]
    print(output)

In [12]:
len(txt), tokenizer(txt).input_ids.__len__()

(37593, 12299)

In [13]:
prompt = "The ostrich shit its pants while walking across the street. But, the Tyrannosaurus rex was one of the most ferocious predators to ever walk the Earth. With a massive body, sharp teeth, and jaws so powerful they could crush a car, this famous carnivore dominated the forested river valleys in western North America during the late Cretaceous period, 68 million years ago. \n \n"

In [14]:
get_response(prompt=f'what was the previous text about? \n{prompt}')



what was the previous text about? 
The ostrich shit its pants while walking across the street. But, the Tyrannosaurus rex was one of the most ferocious predators to ever walk the Earth. With a massive body, sharp teeth, and jaws so powerful they could crush a car, this famous carnivore dominated the forested river valleys in western North America during the late Cretaceous period, 68 million years ago. 
 

The previous text was about the Tyrannosaurus rex, a large carnivorous dinosaur that lived during the late Cretaceous period. The text describes its massive body, sharp teeth, and powerful jaws, and mentions that it was one of the most ferocious predators to ever walk the Earth. It does not mention the ostrich or its pants.


In [15]:
txt, pmid = call_dce(context=True)
len(txt), pmid

(37593, 'PMC6000758')

In [29]:
#tokenizer.encode([1,])

In [31]:
tokenizer(txt,return_tensors="pt")

{'input_ids': tensor([[    1,   415, 13993,  ...,   302,  2145, 28723]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]])}

In [19]:
def build_prompt(questions, input_text):
    return f"[INST] You are a microbiology journal editor. Please answer as concisely as possible: {questions} \n here is the text in question: {input_text} \n [/INST]"

In [20]:
get_response(prompt=build_prompt(questions="can you provide the exact quote surrounding a mention of GCA_00092485.1?",
             input_text=call_dce(context=True)[0][:3700],),
             new_tokens=1000)

[INST] You are a microbiology journal editor. Please answer as concisely as possible: can you provide the exact quote surrounding a mention of GCA_00092485.1? 
 here is the text in question: The ocean covers 71% of the earth’s surface and is regarded as the largest habitat for life on the planet Earth.
Marine microorganisms are known to play an essential role in energy conservation and biogeochemical cycling in the oceans.
Heterotrophic prokaryotes are considered key players in the decomposition of the dissolved organic matter (DOM) and particulate organic matter (POM) present therein (DeLong and Karl, 2005; Azam and Malfatti, 2007).
Oceanospirillales is an order of proteobacteria with seven families of heterotrophic marine bacteria that are usually associated with oil spills and are known to be involved in xylan and hydrocarbon utilization (Choi et al., 2012; Cao et al., 2014).
Recently, a new family named Kangiellaceae has been proposed within the order Oceanospirillales based on phy

In [83]:
#get_response(prompt=f'can you provide the exact quote surrounding a mention of GCA_000024085.1? Please provide the sentence before and after. \n{txt} -------\n\n')

In [None]:
#mistral prompt format
# <s>[INST] Instruction [/INST] Model answer</s>[INST] Follow-up instruction [/INST]

In [None]:
#prompt = f'<s>[INST] Can you summarize this text? {sample_text} [/INST]'
#prompt = f'<s>[INST] {text} [/INST]'

In [20]:
prompt = "can you tell me about yourself?"

In [21]:
model_inputs = tokenizer([prompt], return_tensors="pt").to('cuda');

In [22]:
len(model_inputs.input_ids[0])

8

In [24]:
generated_ids = model.generate(**model_inputs, max_new_tokens=250,)#do_sample=True)
tokenizer.batch_decode(generated_ids)[0]

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'<s> can you tell me about yourself?\n\nI’m a 20-year-old college student from the United States. I’m currently studying at the University of California, Berkeley. I’m a student of the arts, and I’m currently studying at the University of California, Berkeley. I’m a student of the arts, and I’m currently studying at the University of California, Berkeley. I’m a student of the arts, and I’m currently studying at the University of California, Berkeley. I’m a student of the arts, and I’m currently studying at the University of California, Berkeley. I’m a student of the arts, and I’m currently studying at the University of California, Berkeley. I’m a student of the arts, and I’m currently studying at the University of California, Berkeley. I’m a student of the arts, and I’m currently studying at the University of California, Berkeley. I’m a student of the arts, and I’m currently studying at the University of California, Berkeley. I’m a student of the arts, and I’m currently studying at the

In [14]:
#model_inputs.input_ids
token_dict = dict(zip(tokenizer.vocab.values(), tokenizer.vocab.keys()))

In [15]:
import pandas as pd

In [16]:
model_inputs.input_ids.to('cpu').shape

torch.Size([1, 20533])

In [17]:
sample_text_tokens = pd.Series(model_inputs.input_ids.to('cpu').squeeze()).map(token_dict)
sample_text_tokens.head(5)

0     <s>
1     <s>
2      ▁[
3    INST
4       ]
dtype: object

In [20]:
sample_text_tokens[:50]

0         <s>
1         <s>
2          ▁[
3        INST
4           ]
5        ▁Can
6        ▁you
7        ▁sum
8           m
9         ara
10        ize
11      ▁this
12      ▁text
13          ?
14          ▁
15     <0x0A>
16         ▁▁
17     <0x0A>
18         ▁▁
19     <0x0A>
20       ▁▁▁▁
21     <0x0A>
22      ▁▁▁▁▁
23     ▁Front
24     ▁Micro
25         bi
26         ol
27     <0x0A>
28      ▁▁▁▁▁
29     ▁Front
30     ▁Micro
31         bi
32         ol
33     <0x0A>
34      ▁▁▁▁▁
35     ▁Front
36          .
37     ▁Micro
38         bi
39         ol
40          .
41     <0x0A>
42     ▁▁▁▁▁▁
43     <0x0A>
44    ▁▁▁▁▁▁▁
45     ▁Front
46       iers
47        ▁in
48     ▁Micro
49         bi
dtype: object