In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import transformers
import textwrap
from transformers import LlamaTokenizer, LlamaForCausalLM
import os
import sys
from typing import List

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
)

import fire
import torch
from datasets import load_dataset
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from pylab import rcParams

from transformers.generation.utils import GreedySearchDecoderOnlyOutput
from peft import PeftModel
from transformers import GenerationConfig
import json

%matplotlib inline
# Global plot theme: white background style, muted palette, fonts scaled 1.2x,
# and 10x7 inch figures rendered at 100 dpi -- set in a single seaborn call.
sns.set(
    style='white',
    palette='muted',
    font_scale=1.2,
    rc={'figure.figsize': (10, 7), 'figure.dpi': 100},
)

# Use the GPU when PyTorch can see one (CUDA_VISIBLE_DEVICES is set at the top
# of this cell); the bare name on the last line displays the choice.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
DEVICE

'cuda'

In [4]:
# Load the per-venue semantic feature records. The loop below iterates this as
# a mapping (presumably venue id -> record) and reads each record's 'name',
# 'category', 'description', 'features', 'tips' and 'price' fields.
# JSON text is UTF-8 by specification; say so explicitly instead of relying on
# the platform's locale encoding (venue names/reviews may be non-ASCII).
with open('../../data/standard_semantic_features.json', 'r', encoding='utf-8') as jsonfile:
    format_venue_text = json.load(jsonfile)

In [5]:
def get_keys_with_true_values(d, prefix=''):
    """List the dotted key paths of all leaves in `d` whose value is exactly ``True``.

    Nested dicts are descended depth-first in insertion order; a key inside a
    nested dict is reported as ``"<parent>.<child>"``, and every path is
    prefixed with `prefix`. Truthy values that are not the ``True`` singleton
    (e.g. ``1`` or ``"yes"``) are not reported.

    Args:
        d: Mapping to scan; values that are dicts are treated as sub-trees.
        prefix: String prepended to every reported path.

    Returns:
        A list of key-path strings.
    """
    def _walk(node, path_prefix):
        for name, item in node.items():
            if isinstance(item, dict):
                yield from _walk(item, path_prefix + name + '.')
            elif item is True:
                yield path_prefix + name

    return list(_walk(d, prefix))

In [6]:
# Task description shared by every example. It is kept byte-identical because
# the LoRA adapter was presumably fine-tuned with exactly this text.
_PRICE_INSTRUCTION = """Your task is to predict the average consuming price of a venue based on its description, which includes its name and category. The venue's price will fall into one of four categories: Cheap, Moderate, Expensive, and Very Expensive. Remember, the name and category of the venue can be significant indicators of its price. For instance, fast-food chains like 'McDonald' might typically be 'Cheap', while upscale restaurants with names suggesting fine dining might be 'Expensive' or 'Very Expensive.'"""

# Numeric price tier -> label; any other truthy tier maps to 'Very Expensive'.
_PRICE_TIERS = ((1, 'Cheap'), (2, 'Moderate'), (3, 'Expensive'))


def _price_label(price):
    """Return the text label for a (truthy) price tier, 'Very Expensive' if unmatched."""
    for tier, label in _PRICE_TIERS:
        if price == tier:
            return label
    return 'Very Expensive'


def _describe_venue(info):
    """Render one venue record as the ``input`` text of the prompt.

    Sections, in order and only when the field is truthy: name (always),
    categories, description, names of features flagged True (skipping any
    whose dotted path contains 'payment'), and the numbered customer tips.
    """
    pieces = [f"Venue Name: {info['name']}.\n"]

    if info['category']:
        category_list = ', '.join(info['category'])
        pieces.append(f"Venue Category: {category_list}.\n")

    if info['description']:
        # The leading space belongs to the original template; kept verbatim.
        pieces.append(f" Venue Short Description: {info['description']}.\n")

    if info['features']:
        # Only the last path component is shown, e.g. "x.y.wifi" -> "wifi".
        leaf_names = ', '.join(
            path.split('.')[-1]
            for path in get_keys_with_true_values(info['features'], prefix='')
            if 'payment' not in path
        )
        if leaf_names:
            pieces.append(f"The Features: {leaf_names}\n")

    if info['tips']:
        numbered_tips = '\n'.join(
            f" {rank}. {tip['text']}" for rank, tip in enumerate(info['tips'], start=1)
        )
        pieces.append(f" The reviews of customers are:\n{numbered_tips}.")

    return ''.join(pieces)


# venue id -> {'instruction', 'input', 'output'} Alpaca-style example.
prompt_venue_text = {}
for venueid, info in format_venue_text.items():
    venue_input = _describe_venue(info)
    prompt_venue_text[venueid] = {
        'instruction': _PRICE_INSTRUCTION,
        'input': venue_input,
        # Venues with a falsy price have no ground-truth label: output is None.
        'output': _price_label(info['price']) if info['price'] else None,
    }

In [8]:
# prompt_venue_text

In [9]:
BASE_MODEL = "model/llama2-7B-hf"
# LoRA adapter directory (the name suggests a price-prediction fine-tune run).
LORA_WEIGHTS = "model/price_lr1e-4_v2"

# NOTE(review): both loads pin everything to GPU index 0 via device_map, even
# though DEVICE may fall back to "cpu"; this cell therefore needs a visible GPU.
# With CUDA_VISIBLE_DEVICES="1" set in the first cell, index 0 is that GPU.

# Base Llama weights, quantised to 8-bit on load, read from local files only.
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    local_files_only=True,
    device_map={'': 0},
    torch_dtype=torch.float16,
    load_in_8bit=True,
)

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

# Wrap the base model with the fine-tuned LoRA adapter; `model` now refers to
# the PeftModel used for generation below.
model = PeftModel.from_pretrained(
    model,
    LORA_WEIGHTS,
    device_map={'': 0},
    torch_dtype=torch.float16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
# First line of the Alpaca prompt template.
# NOTE(review): the trailing "  # noqa: E501" is inside the string, so it is
# literally part of every prompt (a lint directive that leaked into the text).
# Presumably the adapter was trained with the same template, so it is kept
# verbatim -- confirm against the training code before removing it.
_PROMPT_HEADER = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes "
    "the request.  # noqa: E501"
)


def generate_prompt(data_point):
    """Format one example as an Alpaca-style prompt ending at "### Response:".

    Args:
        data_point: Mapping with "instruction" and "input" entries; any other
            keys (such as "output") are ignored.

    Returns:
        The prompt text, terminated by a newline after "### Response:" so the
        model's answer starts on its own line.
    """
    sections = [
        _PROMPT_HEADER,
        "### Instruction:",
        f"{data_point['instruction']}",
        "### Input:",
        f"{data_point['input']}",
        "### Response:",
        "",  # yields the trailing newline after the response marker
    ]
    return "\n".join(sections)

def generate_response(
    prompt: str, model: PeftModel, max_new_tokens: int = 128
) -> GreedySearchDecoderOnlyOutput:
    """Run ``model.generate`` on a single prompt and return the raw output object.

    Args:
        prompt: Fully formatted prompt text (as built by ``generate_prompt``).
        model: The causal LM to generate with (the LoRA-wrapped Llama here).
        max_new_tokens: Maximum number of tokens generated after the prompt
            (defaults to 128, the value previously hard-coded).

    Returns:
        The dict-style result of ``model.generate`` with
        ``return_dict_in_generate=True`` and ``output_scores=True``:
        ``.sequences`` holds prompt + generated token ids, ``.scores`` the
        per-step scores.
    """
    encoding = tokenizer(prompt, return_tensors="pt")
    input_ids = encoding["input_ids"].to(DEVICE)
    # Pass the mask explicitly rather than letting generate() infer it.
    attention_mask = encoding.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(DEVICE)

    # The Llama tokenizer may define no pad token; fall back to EOS.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    generation_config = GenerationConfig(
        # NOTE(review): `do_sample` is not set. GenerationConfig documents its
        # default as False (greedy decoding), in which case temperature/top_p
        # have no effect -- confirm whether sampling was intended.
        temperature=0.1,
        top_p=0.75,
        repetition_penalty=1.1,
        # A hand-built GenerationConfig does not carry the model's special-token
        # ids, and depending on the transformers version they may not be filled
        # in from the model; set them so generation stops at the EOS token.
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )
    with torch.inference_mode():
        return model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
    
def format_response(response: GreedySearchDecoderOnlyOutput) -> str:
    """Extract and re-wrap the text the model generated after "### Response:".

    Args:
        response: Output of ``generate_response``; only ``sequences[0]`` is read.

    Returns:
        The answer with surrounding whitespace stripped, re-wrapped by
        ``textwrap.wrap`` (default width) and joined with newlines; '' when
        the answer is empty.

    Raises:
        ValueError: If the decoded sequence contains no "### Response:" marker.
    """
    marker = "### Response:"
    # skip_special_tokens=True drops tokens such as <s> and </s>, so the stored
    # label reads e.g. "Cheap" instead of "Cheap</s>".
    decoded_output = tokenizer.decode(response.sequences[0], skip_special_tokens=True)
    segments = decoded_output.split(marker)
    if len(segments) < 2:
        raise ValueError(f"{marker!r} not found in decoded model output")
    # Text between the first marker (end of the prompt) and the next one, in
    # case the model went on to generate another response section.
    answer = segments[1].strip()
    return "\n".join(textwrap.wrap(answer))

def ask_alpaca(prompt: dict, model: PeftModel = model) -> str:
    """Run one Alpaca-style example through the model and return its answer.

    Args:
        prompt: Despite the name, a mapping (not a string) with "instruction"
            and "input" entries; ``generate_prompt`` ignores any other keys.
        model: Model to generate with. The default is bound to the global
            ``model`` at the moment this cell runs, so re-loading the model
            requires re-running this cell for the default to follow.

    Returns:
        The formatted text the model produced after "### Response:".
    """
    full_prompt = generate_prompt(prompt)
    response = generate_response(full_prompt, model)
    return format_response(response)

In [11]:
# Checkpoint file for all venues: ground-truth label where one exists, model
# prediction otherwise. Rewritten every 10 venues and once at the end.
filename = "../../result/price_prediction_6711.json"


def _save_predictions(preds: dict, path: str = filename) -> None:
    """Write `preds` to `path` as JSON via a temp file and ``os.replace``.

    Dumping into ``<path>.tmp`` first means an interrupt during ``json.dump``
    leaves the previous checkpoint at `path` intact.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w") as json_file:
        json.dump(preds, json_file)
    os.replace(tmp_path, path)


total = len(prompt_venue_text)
predictions = {}
for count, (k, v) in enumerate(prompt_venue_text.items()):
    if v['output']:
        # Ground-truth price label available: keep the example unchanged.
        predictions[k] = v
    else:
        # Build a new record instead of writing into `v`, so prompt_venue_text
        # still tells labelled and unlabelled venues apart after this cell and
        # re-running it starts from the same inputs.
        predictions[k] = {**v, 'output': ask_alpaca(v)}

    if count % 10 == 0:
        print(f"{count / total:.2%}")
        _save_predictions(predictions)

_save_predictions(predictions)

0.0 %
0.14900908955446282 %
0.29801817910892564 %
0.44702726866338843 %
0.5960363582178513 %
0.7450454477723141 %
0.8940545373267769 %
1.0430636268812399 %
1.1920727164357026 %
1.3410818059901655 %
1.4900908955446281 %
1.639099985099091 %
1.7881090746535537 %
1.9371181642080169 %
2.0861272537624798 %
2.235136343316942 %
2.384145432871405 %
2.5331545224258676 %
2.682163611980331 %
2.831172701534794 %
2.9801817910892563 %
3.1291908806437196 %
3.278199970198182 %
3.427209059752645 %
3.5762181493071075 %
3.7252272388615704 %
3.8742363284160337 %
4.023245417970496 %
4.1722545075249595 %
4.3212635970794215 %
4.470272686633884 %
4.619281776188347 %
4.76829086574281 %
4.917299955297273 %
5.066309044851735 %
5.215318134406199 %
5.364327223960662 %
5.513336313515124 %
5.662345403069588 %
5.811354492624051 %
5.960363582178513 %
6.1093726717329755 %
6.258381761287439 %
6.407390850841901 %
6.556399940396364 %
6.705409029950828 %
6.85441811950529 %
7.003427209059752 %
7.152436298614215 %
7.301445388