In [4]:
import os

import transformers
from torch import cuda, bfloat16

model_id = 'meta-llama/Llama-2-7b-chat-hf'

# Informational only: with device_map='auto' below, weight placement is chosen
# by the loader; this string merely names the current default device.
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# Optional 4-bit NF4 quantization (needs the `bitsandbytes` package).
# Disabled by default so the model loads exactly as before; switch it on when
# the unquantized weights do not fit in GPU memory.
use_4bit_quantization = False
quantization_kwargs = {}
if use_4bit_quantization:
    quantization_kwargs['quantization_config'] = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=bfloat16
    )

# The Llama 2 repository is gated, so downloads must be authenticated.
# Never paste a token into the notebook: read it from the environment instead.
# When no variable is set, `True` makes transformers fall back to the token
# cached by `huggingface-cli login`.
# NOTE(review): recent transformers releases rename `use_auth_token` to
# `token`; switch once the pinned version is confirmed to support it.
hf_auth = (
    os.environ.get('HF_TOKEN')
    or os.environ.get('HUGGING_FACE_HUB_TOKEN')
    or True
)

model_config = transformers.AutoConfig.from_pretrained(
    model_id,
    use_auth_token=hf_auth
)

# `trust_remote_code` is deliberately left at its default (False): the Llama
# architecture ships with transformers, so no code from the Hub repository
# has to be executed to load this checkpoint.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    config=model_config,
    device_map='auto',
    use_auth_token=hf_auth,
    **quantization_kwargs
)
model.eval()  # inference only: disables dropout and other train-time behaviour
print(f"Model loaded on {device}")

Downloading (…)lve/main/config.json: 100%|██████████| 635/635 [00:00<00:00, 37.0kB/s]
Downloading shards: 100%|██████████| 2/2 [00:00<00:00,  4.92it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:13<00:00,  6.68s/it]
Downloading (…)neration_config.json: 100%|██████████| 167/167 [00:00<00:00, 15.4kB/s]

Model loaded on cuda:0





In [5]:
# Fetch the tokenizer published for the same checkpoint, authenticating with
# the same credential that was used for the model download.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth)

Downloading (…)/main/tokenizer.json: 100%|██████████| 1.84M/1.84M [00:00<00:00, 2.10MB/s]


In [6]:
# Text-generation pipeline around the already-loaded model and tokenizer.
# Extra keyword arguments are forwarded to generation as decoding parameters.
generate_text = transformers.pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # langchain expects the full text
    # NOTE: no custom stopping_criteria is configured (none is defined in this
    # notebook); the original author observed that the model can ramble during
    # chat without one.
    # Deterministic (greedy) decoding. The previous `temperature=0.0` is not a
    # valid sampling temperature: it is ignored when sampling is off and is
    # rejected when sampling is on, so sampling is disabled explicitly instead.
    do_sample=False,
    max_new_tokens=1024,  # max number of tokens to generate in the output
    repetition_penalty=1.1  # without this output begins repeating
)

In [7]:
# Wrap the transformers pipeline in LangChain's LLM adapter so that chains
# can call it like any other LangChain language model.
from langchain import HuggingFacePipeline
llm = HuggingFacePipeline(pipeline=generate_text)

In [8]:
# Connect to the Chinook sample database (a SQLite file resolved relative to
# the working directory) and build a chain that asks the LLM to write SQL,
# runs it against `db`, and has the LLM phrase the answer.
from langchain import SQLDatabase, SQLDatabaseChain
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
# verbose=True prints the generated SQL, the raw result and the final answer.
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [9]:
# Question that only needs a row count from a single table.
db_chain.run("How many people are in Employee table?")



[1m> Entering new SQLDatabaseChain chain...[0m
How many people are in Employee table?
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM Employee;[0m
SQLResult: [33;1m[1;3m[(8,)][0m
Answer:[32;1m[1;3mThere are 8 employees in the Employee table.[0m
[1m> Finished chain.[0m


'There are 8 employees in the Employee table.'

Input: "How many albums by Aerosmith?"

SQLQuery: SELECT COUNT(*) FROM "Album" WHERE "ArtistId" = 2; `wrong` — the model hard-codes an ArtistId instead of looking the artist up by name.

Correct Query: SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Aerosmith')

In [10]:
# Question that requires resolving an artist name to its ArtistId first.
db_chain.run("How many albums by Aerosmith?")



[1m> Entering new SQLDatabaseChain chain...[0m
How many albums by Aerosmith?
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM "Album" WHERE "ArtistId" = 2;[0m
SQLResult: [33;1m[1;3m[(2,)][0m
Answer:[32;1m[1;3mThere are 3 albums by Aerosmith.[0m
[1m> Finished chain.[0m


'There are 3 albums by Aerosmith.'

Correct query for "Provide 10 tracks by Led Zeppelin":

SELECT Track.Name
    FROM Track
    INNER JOIN Album ON Track.AlbumId = Album.AlbumId
    INNER JOIN Artist ON Album.ArtistId = Artist.ArtistId
    WHERE Artist.Name = 'Led Zeppelin'
    LIMIT 10

In [14]:
# Question that needs Track joined to Album and Artist to filter by artist name.
db_chain.run("Provide 10 tracks by Led Zeppelin")



[1m> Entering new SQLDatabaseChain chain...[0m
Provide 10 tracks by Led Zeppelin
SQLQuery:



[32;1m[1;3mSELECT * FROM "Track" WHERE "MediaTypeId" = (SELECT "MediaTypeId" FROM "MediaType" WHERE "Name" = 'Led Zeppelin');[0m
SQLResult: [33;1m[1;3m[][0m
Answer:[32;1m[1;3mNone[0m
[1m> Finished chain.[0m


'None'