In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import re
import transformers
import textwrap
from transformers import LlamaTokenizer, LlamaForCausalLM
import os
import sys
from typing import List

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
)

import fire
import torch
from datasets import load_dataset
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from pylab import rcParams

from transformers.generation.utils import GreedySearchDecoderOnlyOutput
from peft import PeftModel
from transformers import GenerationConfig
import json

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Plot defaults: figure size, figure DPI, then style / palette / font scale.
sns.set(rc={'figure.figsize':(10, 7)})
sns.set(rc={'figure.dpi':100})
sns.set(style='white', palette='muted', font_scale=1.2)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE  # bare expression so the notebook displays the selected device

'cuda'

In [3]:
def format_hours(hours_data):
    """Render one hours entry as human-readable text.

    Args:
        hours_data: Mapping with a 'day' number (1 = Monday ... 7 = Sunday) and
            'open' / 'close' values that are sliced as 'HHMM' strings. A 'close'
            value starting with '+' is treated as a time on the following day.

    Returns:
        '<Day> HH:MM to HH:MM', or '<Day> HH:MM to <NextDay> HH:MM' when the
        closing time carries the '+' next-day marker.

    Raises:
        KeyError: If a required key is missing or 'day' is not in 1..7.
    """
    # A dict (not a tuple) so that an out-of-range day raises KeyError.
    names_by_number = dict(enumerate(
        ('Monday', 'Tuesday', 'Wednesday', 'Thursday',
         'Friday', 'Saturday', 'Sunday'),
        start=1,
    ))

    day = hours_data['day']
    start_label = names_by_number[day]

    opens = hours_data['open']
    open_time = f"{opens[:2]}:{opens[2:]}"

    closes = hours_data['close']
    if not closes.startswith('+'):
        # Same-day closing time.
        return f"{start_label} {open_time} to {closes[:2]}:{closes[2:]}"

    # Next-day closing time: Sunday (7) wraps around to Monday (1).
    end_label = names_by_number[day % 7 + 1]
    close_time = f"{closes[1:3]}:{closes[3:]}"
    return f"{start_label} {open_time} to {end_label} {close_time}"

In [4]:
# Read the target API dataset (venue id -> venue record).
# The encoding is given explicitly: the venue tips contain non-ASCII text and the
# platform default encoding is not guaranteed to be UTF-8.
# NOTE(review): the doubled '.json.json' extension looks accidental — confirm
# against the file on disk before renaming anything.
with open('../../data/standard_semantic_features.json.json', 'r', encoding='utf-8') as jsonfile:
    format_venue_text = json.load(jsonfile)

# Load the predicted opening hours, because the popular hours are predicted
# on top of the opening hours.
with open('../../result/openhour_prediction_6711.json', 'r', encoding='utf-8') as jsonfile:
    predicted_value = json.load(jsonfile)

In [5]:
# Spot-check one raw venue record.
format_venue_text['49bbd6c0f964a520f4531fe3']

{'description': None,
 'features': {},
 'hours': None,
 'location': '308 Canal St (btwn Broadway & Mercer), New York, NY 10013',
 'name': 'Pearl Art & Craft Supply',
 'popularity': 0.6463884592658508,
 'tips': [{'created_at': '2014-01-26T23:33:23.000Z',
   'text': "A paradise for any kind of artist, or even if you're not one yet. 10% discount if you have a student or teacher ID !"},
  {'created_at': '2014-01-10T14:03:53.000Z', 'text': 'Take the stairs'},
  {'created_at': '2013-11-22T22:26:08.000Z',
   'text': 'Great place to buy all you need, but the place looks like its going to collapse'},
  {'created_at': '2013-10-18T20:21:36.000Z',
   'text': 'Muy buena selección de cosas y personal experto, un poco en caida pero aun asi de lo mejor de ny'},
  {'created_at': '2013-09-20T22:11:31.000Z',
   'text': "There's an elevator in the back!"}],
 'related_places': {},
 'tastes': None,
 'price': None,
 'menu': None,
 'rating': None,
 'category': ['Arts and Crafts Store']}

In [6]:
# Spot-check the open-hour prediction stored for one venue.
predicted_value['49bbd6c0f964a520f4531fe3']

{'instruction': 'Your task is to predict the open hour of the venue based on its name, category, description, average price, customer reviews and other features. ',
 'input': "Venue Name: Pearl Art & Craft Supply.\nVenue Category: Arts and Crafts Store.\nThe Customer Reviews:\n 1. A paradise for any kind of artist, or even if you're not one yet. 10% discount if you have a student or teacher ID !\n 2. Take the stairs\n 3. Great place to buy all you need, but the place looks like its going to collapse\n 4. Muy buena selección de cosas y personal experto, un poco en caida pero aun asi de lo mejor de ny\n 5. There's an elevator in the back!.",
 'output': 'Mon-Sat 9:00 AM-8:00 PM; Sun 10:00 AM-7:00 PM</s>'}

In [7]:
# Give every venue an opening-hours display string, falling back to the model's
# open-hour prediction when the dataset has none.
for venueid, venue in format_venue_text.items():
    hours = venue.get('hours')
    if isinstance(hours, dict) and hours.get('display'):
        # Real opening hours are available; keep them untouched.
        continue

    # Hours are missing, malformed or empty: use the predicted value instead.
    # (The previous bare `except:` only handled the missing/malformed case and
    # silently skipped venues whose 'display' was present but empty.)
    # A KeyError here means the prediction file has no entry for this venue.
    predicted_openhour = predicted_value[venueid]['output']
    # Drop the end-of-sequence marker left in the generated text.
    predicted_openhour = predicted_openhour.replace('</s>', '')

    venue['hours'] = {'display': predicted_openhour}

In [8]:
# format_venue_text

In [9]:
def get_keys_with_true_values(d, prefix=''):
    """Collect the dotted paths of every entry whose value is exactly ``True``.

    Args:
        d: Possibly nested dictionary to scan.
        prefix: Text prepended to each reported key; nested levels extend it
            with '<key>.'.

    Returns:
        List of dotted key paths, in the dictionary's iteration order. Only the
        ``True`` singleton counts; other truthy values (e.g. 1) are ignored.
    """
    found = []
    for name, item in d.items():
        if isinstance(item, dict):
            # Descend, carrying the path to this sub-dictionary.
            found += get_keys_with_true_values(item, prefix + name + '.')
            continue
        if item is True:
            found.append(prefix + name)
    return found

In [10]:
## useful features: name, category, tips, description, price, features, hours
# Build one {'instruction', 'input', 'output'} record per venue.
# NOTE(review): the prompt text is kept character-for-character (including the
# leading space in " Venue price:") — presumably it has to match the prompts the
# adapter was fine-tuned on; confirm before "tidying" any of these strings.
POPULAR_HOUR_INSTRUCTION = """Your task is to predict the popular hour of the venue based on its name, category, description, average price, customer reviews and other features."""

prompt_venue_text = {}
for venueid, info in format_venue_text.items():
    prompt_str = ""
    prompt_str += f"Venue Name: {info['name']}.\n"

    if info['category']:
        categories = ', '.join(info['category'])
        prompt_str += f"Venue Category: {categories}.\n"

    if info['description']:
        prompt_str += f"Venue Short Description: {info['description']}.\n"

    if info['price']:
        # Numeric price tier -> wording used in the prompt.
        if info['price'] == 1:
            price = 'cheap'
        elif info['price'] == 2:
            price = 'moderate'
        elif info['price'] == 3:
            price = 'expensive'
        else:
            price = 'very expensive'

        prompt_str += f" Venue price: {price}.\n"

    if info['features']:
        # Keep the leaf names of all enabled features, ignoring payment options.
        features = get_keys_with_true_values(info['features'], prefix='')
        features = [f for f in features if 'payment' not in f]
        features = [f.split('.')[-1] for f in features]
        features = ', '.join(features)

        if len(features):
            prompt_str += f"The Features: {features}\n"

    if info['tips']:
        tips = [f" {i + 1}. {tip['text']}" for i, tip in enumerate(info['tips'])]
        tips_str = '\n'.join(tips)
        prompt_str += f"The Customer Reviews:\n{tips_str}."

    if info['hours']['display']:
        prompt_str += f"\nVenue Open Hour: {info['hours']['display']}.\n"

    new_info = {}
    new_info['instruction'] = POPULAR_HOUR_INSTRUCTION
    new_info['input'] = prompt_str

    # Ground-truth popular hours when the venue has them; otherwise None, which
    # marks the venue as one that still needs a prediction.
    try:
        new_info['output'] = ', '.join(format_hours(day) for day in info['hours_popular'])
    except Exception:
        # Best effort: missing or malformed 'hours_popular' means "no label".
        # `Exception` (not a bare except) so Ctrl-C still interrupts the cell.
        new_info['output'] = None

    prompt_venue_text[venueid] = new_info

In [11]:
# Local paths of the base model and of the adapter weights applied on top of it.
BASE_MODEL = "model/llama2-7B-hf"
LORA_WEIGHTS = "model/popular_hour_lr7e-5"

# Load the base model from local files only, with 8-bit loading enabled.
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map={'': 0},  # presumably places the whole model on device 0 — confirm
    local_files_only=True,
)

# Tokenizer is taken from the base model directory.
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

# Wrap the base model with the weights stored at LORA_WEIGHTS; `model` is rebound
# to the wrapped model from here on.
model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16, device_map={'': 0})

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
def generate_prompt(data_point):
    """Render an instruction/input record into the full prompt text.

    Args:
        data_point: Mapping with "instruction" and "input" entries.

    Returns:
        The prompt: a fixed preamble followed by the "### Instruction:",
        "### Input:" and "### Response:" sections, ending with a newline.
    """
    instruction = data_point["instruction"]
    context = data_point["input"]
    # The trailing "  # noqa: E501" is part of the prompt text itself (it sits
    # inside the string), so it is reproduced verbatim here.
    preamble = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501"
    sections = [
        preamble,
        "### Instruction:",
        f"{instruction}",
        "### Input:",
        f"{context}",
        "### Response:",
        "",  # yields the final newline after "### Response:"
    ]
    return "\n".join(sections)

def generate_response(prompt: str, model: PeftModel) -> GreedySearchDecoderOnlyOutput:
    """Tokenize `prompt` and run `model.generate` on it.

    Args:
        prompt: Fully rendered prompt text.
        model: Model whose `generate` method is called.

    Returns:
        The object returned by `model.generate` with
        `return_dict_in_generate=True` and `output_scores=True`.

    Uses the module-level `tokenizer` and `DEVICE`.
    """
    # NOTE(review): `do_sample` is not set here, so `temperature` / `top_p` may
    # have no effect if decoding is greedy — confirm against the transformers docs.
    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        repetition_penalty=1.1,
    )

    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE)

    with torch.inference_mode():
        output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=128,
        )
    return output
    
def format_response(response: GreedySearchDecoderOnlyOutput) -> str:
    """Extract the answer text from a generation result.

    Decodes the first generated sequence with the module-level `tokenizer`,
    keeps what follows the first "### Response:" marker (up to any later
    marker), strips it, and re-wraps it with `textwrap` defaults.

    Raises:
        IndexError: If the decoded text contains no "### Response:" marker.
    """
    full_text = tokenizer.decode(response.sequences[0])
    pieces = full_text.split("### Response:")
    answer = pieces[1].strip()
    # textwrap.fill(text) is shorthand for "\n".join(textwrap.wrap(text)).
    return textwrap.fill(answer)

def ask_alpaca(prompt: dict, model: PeftModel = model) -> str:
    """Generate the model's answer for one instruction/input record.

    Args:
        prompt: Record with "instruction" and "input" entries. Despite the name,
            this is the data point, not the rendered prompt string; it is passed
            straight to `generate_prompt`.
        model: Model used for generation. The default is bound to the global
            `model` when this function is defined, so re-run this cell after
            reloading the model.

    Returns:
        The response text produced by `format_response`.
    """
    # Separate name for the rendered text instead of rebinding the dict parameter.
    prompt_text = generate_prompt(prompt)
    response = generate_response(prompt_text, model)
    return format_response(response)

In [13]:
def load_checkpoint(filepath=None):
    """Load previously saved results from a JSON file.

    Args:
        filepath: Path of the JSON file to read, or None to start from scratch.

    Returns:
        The parsed JSON content, or an empty dict when `filepath` is None.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Identity check: `== None` can be fooled by objects overriding __eq__.
    if filepath is None:
        return {}

    # Explicit encoding so the result does not depend on the platform default.
    with open(filepath, "r", encoding="utf-8") as json_file:
        return json.load(json_file)

In [14]:
# Start from previously saved predictions; venues already present are not re-generated.
# NOTE(review): this reads 'data/..._v3.json' while the prediction loop writes
# '../../result/..._v2.json' — confirm the differing paths are intentional.
predictions = load_checkpoint('data/popularhour_prediction_6711_v3.json')

In [None]:
# Fill `predictions` for every venue: reuse the ground-truth popular hours when the
# record has them, otherwise ask the model. Progress is checkpointed to disk.
filename = "../../result/popularhour_prediction_6711_v2.json"
SAVE_EVERY = 10  # report progress / checkpoint every N venues


def _save_predictions(path, data):
    """Write `data` as JSON to `path` through a temp file and a rename, so an
    interrupted write does not leave a truncated checkpoint behind."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as json_file:
        json.dump(data, json_file)
    os.replace(tmp_path, path)


count = 0
total = len(prompt_venue_text)
unsaved = False  # True when `predictions` changed since the last write

try:
    for k, v in prompt_venue_text.items():
        if k not in predictions:
            if not v['output']:
                # No ground-truth label for this venue: generate one.
                v['output'] = ask_alpaca(v)
            predictions[k] = v
            unsaved = True

        if count % SAVE_EVERY == 0:
            print(str(count / total * 100) + ' %')
            # Skip the rewrite when nothing changed (e.g. resuming past cached venues).
            if unsaved:
                _save_predictions(filename, predictions)
                unsaved = False
        count += 1
finally:
    # Always persist what we have, including after an error or an interrupt.
    _save_predictions(filename, predictions)

0.0 %
0.14900908955446282 %
0.29801817910892564 %
0.44702726866338843 %
0.5960363582178513 %
0.7450454477723141 %
0.8940545373267769 %
1.0430636268812399 %
1.1920727164357026 %
1.3410818059901655 %
1.4900908955446281 %
1.639099985099091 %
1.7881090746535537 %
1.9371181642080169 %
2.0861272537624798 %
2.235136343316942 %
2.384145432871405 %
2.5331545224258676 %
2.682163611980331 %
2.831172701534794 %
2.9801817910892563 %
3.1291908806437196 %
3.278199970198182 %
3.427209059752645 %
3.5762181493071075 %
3.7252272388615704 %
3.8742363284160337 %
4.023245417970496 %
4.1722545075249595 %
4.3212635970794215 %
4.470272686633884 %
4.619281776188347 %
4.76829086574281 %
4.917299955297273 %
5.066309044851735 %
5.215318134406199 %
5.364327223960662 %
5.513336313515124 %
5.662345403069588 %
5.811354492624051 %
5.960363582178513 %
6.1093726717329755 %
6.258381761287439 %
6.407390850841901 %
6.556399940396364 %
6.705409029950828 %
6.85441811950529 %
7.003427209059752 %
7.152436298614215 %
7.301445388