In [3]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import transformers
import textwrap
from transformers import LlamaTokenizer, LlamaForCausalLM
import os
import sys
from typing import List

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
)

import fire
import torch
from datasets import load_dataset
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from pylab import rcParams

from transformers.generation.utils import GreedySearchDecoderOnlyOutput
from peft import PeftModel
from transformers import GenerationConfig
import json

%matplotlib inline
# Plotting defaults: 10x7 figures at 100 dpi, then the 'white' style with the
# 'muted' palette and a 1.2x font scale.
sns.set(rc={'figure.figsize':(10, 7)})
sns.set(rc={'figure.dpi':100})
sns.set(style='white', palette='muted', font_scale=1.2)

# Use the GPU when torch reports CUDA as available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Bare expression so the notebook echoes the selected device as the cell output.
DEVICE

'cuda'

In [4]:
# Read the target API dataset. The encoding is pinned to UTF-8 so the load
# does not depend on the platform's default locale encoding.
with open('../../data/standard_semantic_features.json', 'r', encoding='utf-8') as jsonfile:
    format_venue_text = json.load(jsonfile)

In [7]:
def get_keys_with_true_values(d, prefix=''):
    """Collect the dotted paths of every entry whose value is exactly ``True``.

    Nested dictionaries are walked depth-first in insertion order; each level
    contributes ``<key>.`` to the path of the entries below it.

    Args:
        d: Dictionary to inspect (may contain nested dictionaries).
        prefix: Path text prepended to every key found at this level.

    Returns:
        List of dotted key paths, in traversal order.
    """
    flagged = []

    for name, entry in d.items():
        if entry is True:
            # Identity check: truthy values such as 1 or "yes" do not count.
            flagged.append(prefix + name)
        elif isinstance(entry, dict):
            # Descend, extending the path with this key.
            flagged += get_keys_with_true_values(entry, prefix + name + '.')

    return flagged

In [8]:
# Build one instruction/input/output record per venue. The input is a textual
# description assembled from whichever venue fields are non-empty; the output
# is the known rating, or None when the venue has no rating yet.
RATING_INSTRUCTION = """Your task is to predict the rating of the venue based on its name, category, description, average price and customer reviews. The rating should be on a scale of 0.0 to 10.0, with precision limited to one decimal place."""

# Price tier -> wording used in the prompt; any other truthy tier reads as 'very expensive'.
PRICE_LABELS = ((1, 'cheap'), (2, 'moderate'), (3, 'expensive'))

prompt_venue_text = {}
for venueid, info in format_venue_text.items():
    parts = [f"Venue Name: {info['name']}.\n"]

    if info['category']:
        parts.append(f"Venue Category: {', '.join(info['category'])}.\n")

    if info['description']:
        parts.append(f" Venue Short Description: {info['description']}.\n")

    if info['price']:
        price_label = next(
            (label for tier, label in PRICE_LABELS if info['price'] == tier),
            'very expensive',
        )
        parts.append(f" Venue price: {price_label}.\n")

    if info['features']:
        # Keep only the leaf name of each enabled feature, ignoring anything
        # whose dotted path mentions 'payment'.
        feature_text = ', '.join(
            path.split('.')[-1]
            for path in get_keys_with_true_values(info['features'], prefix='')
            if 'payment' not in path
        )
        if feature_text:
            parts.append(f"The Features: {feature_text}\n")

    if info['tips']:
        numbered_tips = '\n'.join(
            f" {rank}. {tip['text']}" for rank, tip in enumerate(info['tips'], start=1)
        )
        parts.append(f" The reviews of customers are:\n{numbered_tips}.")

    prompt_venue_text[venueid] = {
        'instruction': RATING_INSTRUCTION,
        'input': ''.join(parts),
        # A falsy rating is treated as "no rating available".
        'output': info['rating'] if info['rating'] else None,
    }

In [24]:
# prompt_venue_text

In [11]:
# Local checkpoint directories: the base Llama weights and the LoRA adapter.
BASE_MODEL = "model/llama2-7B-hf"
LORA_WEIGHTS = "model/rating_lr1e-4_v2"

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

# Load the base model in 8-bit from local files only, placed on device 0,
# and wrap it with the LoRA adapter weights in a single expression.
model = PeftModel.from_pretrained(
    LlamaForCausalLM.from_pretrained(
        BASE_MODEL,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map={'': 0},
        local_files_only=True,
    ),
    LORA_WEIGHTS,
    torch_dtype=torch.float16,
    device_map={'': 0},
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [27]:
def generate_prompt(data_point):
    """Render a record into the instruction/input/response prompt template.

    Args:
        data_point: Mapping with ``"instruction"`` and ``"input"`` entries.

    Returns:
        The prompt text, ending with the ``### Response:`` header and a
        trailing newline.
    """
    # NOTE(review): the trailing "# noqa: E501" is part of the string and is
    # therefore sent to the model; kept verbatim here -- presumably it matches
    # the template used for fine-tuning, confirm before changing it.
    preamble = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.  # noqa: E501"
    )
    sections = [
        preamble,
        "### Instruction:",
        f"{data_point['instruction']}",
        "### Input:",
        f"{data_point['input']}",
        "### Response:",
        "",  # yields the final newline after the response header
    ]
    return "\n".join(sections)

def generate_response(prompt: str, model: PeftModel) -> GreedySearchDecoderOnlyOutput:
    """Tokenize ``prompt`` and generate up to 128 new tokens with ``model``.

    Uses the notebook-level ``tokenizer`` and ``DEVICE``.

    Args:
        prompt: Fully rendered prompt text.
        model: Model whose ``generate`` method is called.

    Returns:
        The dict-style generation output (``return_dict_in_generate=True``),
        including per-step scores.
    """
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE)

    # NOTE(review): do_sample is not set here, so temperature/top_p presumably
    # have no effect and decoding is greedy -- confirm against the installed
    # transformers version.
    decoding_options = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        repetition_penalty=1.1,
    )

    with torch.inference_mode():
        generated = model.generate(
            input_ids=prompt_ids,
            generation_config=decoding_options,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=128,
        )
    return generated
    
def format_response(response: GreedySearchDecoderOnlyOutput) -> str:
    """Extract the model's answer from a generation result as wrapped text.

    Uses the notebook-level ``tokenizer`` to decode.

    Args:
        response: Generation output; only ``response.sequences[0]`` is read.

    Returns:
        The text between the first ``### Response:`` marker and the next one
        (or the end of the text), stripped and re-wrapped with the
        ``textwrap.wrap`` defaults, lines joined by newlines.

    Raises:
        IndexError: If the decoded text contains no ``### Response:`` marker.
    """
    # Drop special tokens (e.g. BOS/EOS) while decoding so their literal text
    # does not end up inside the extracted answer.
    decoded_output = tokenizer.decode(response.sequences[0], skip_special_tokens=True)
    answer = decoded_output.split("### Response:")[1].strip()
    return "\n".join(textwrap.wrap(answer))

def ask_alpaca(prompt: dict, model: PeftModel = model) -> str:
    """Run one record through the model and return the formatted answer.

    Args:
        prompt: Record with ``"instruction"`` and ``"input"`` entries; it is
            passed straight to ``generate_prompt``, so it is a mapping rather
            than pre-rendered text despite the parameter name.
        model: Model to generate with. The default is bound when this ``def``
            executes, so re-run this cell after reloading ``model``.

    Returns:
        The result of ``format_response`` for the generation.
    """
    prompt_text = generate_prompt(prompt)
    return format_response(generate_response(prompt_text, model))

In [31]:
# Predict a rating for every venue that lacks one; venues that already carry
# a rating are passed through unchanged. Records are updated in place, so
# prompt_venue_text also receives the predicted outputs.
filename = "../../result/rating_prediction_6711.json"
predictions = {}
count = 0
for venue_key, record in prompt_venue_text.items():
    if not record['output']:
        record['output'] = ask_alpaca(record)
    predictions[venue_key] = record

    # Every 100 venues: report progress and checkpoint everything so far.
    if count % 100 == 0:
        print(str(count/len(prompt_venue_text) * 100) + ' %')
        with open(filename, "w") as json_file:
            json.dump(predictions, json_file)
    count += 1

# Final write with the complete set of records.
with open(filename, "w") as json_file:
    json.dump(predictions, json_file)

0.0 %
1.4900908955446281 %
2.9801817910892563 %
4.470272686633884 %
5.960363582178513 %
7.450454477723141 %
8.940545373267769 %
10.430636268812398 %
11.920727164357025 %
13.410818059901656 %
14.900908955446281 %
16.390999850990912 %
17.881090746535538 %
19.371181642080167 %
20.861272537624796 %
22.351363433169425 %
23.84145432871405 %
25.33154522425868 %
26.821636119803312 %
28.311727015347937 %
29.801817910892563 %
31.291908806437192 %
32.781999701981825 %
34.27209059752645 %
35.762181493071076 %
37.2522723886157 %
38.742363284160334 %
40.232454179704966 %
41.72254507524959 %
43.21263597079422 %
44.70272686633885 %
46.192817761883475 %
47.6829086574281 %
49.172999552972726 %
50.66309044851736 %
52.153181344061984 %
53.643272239606624 %
55.13336313515125 %
56.623454030695875 %
58.1135449262405 %
59.603635821785126 %
61.09372671732976 %
62.583817612874384 %
64.07390850841901 %
65.56399940396365 %
67.05409029950827 %
68.5441811950529 %
70.03427209059753 %
71.52436298614215 %
73.014453881