In [2]:
import json
import pandas as pd
import numpy as np
from transformers import LlamaTokenizer, LlamaForCausalLM
import torch

from transformers.generation.utils import GreedySearchDecoderOnlyOutput
from transformers import GenerationConfig
import textwrap

DEVICE = 'cuda'

In [3]:
def generate_response(prompt: str, model) -> GreedySearchDecoderOnlyOutput:
    encoding = tokenizer(prompt, return_tensors="pt")
    input_ids = encoding["input_ids"].to(DEVICE)

    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        repetition_penalty=1.1,
    )
    with torch.inference_mode():
        return model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=128,
        )

    
def format_response(response: GreedySearchDecoderOnlyOutput) -> str:
    decoded_output = tokenizer.decode(response.sequences[0])
    response = decoded_output.split("#### Response:")[4].strip()
    return "\n".join(textwrap.wrap(response))


def ask_alpaca(prompt: str, model=model) -> str:
    prompt = create_prompt(prompt)
    response = generate_response(prompt, model)
    return (format_response(response))


def create_prompt(instruction: str) -> str:
    return PROMPT_TEMPLATE1.replace("[INPUT]", instruction).replace("'", '')


def load_checkpoint(filepath=None):
    
    if filepath == None:
        return {}
    
    # Provide the path to your JSON file
    text_file = filepath

    # Load JSON data from the file
    with open(text_file, "r") as json_file:
        generated_formatted_venue_text = json.load(json_file)
        
    return generated_formatted_venue_text

NameError: name 'model' is not defined

## Prediction

In [None]:
# read the target api dataset
with open('../../data/category.json', 'r') as jsonfile:
    prompt_venue_text = json.load(jsonfile)

In [None]:
prompt_venue_text

In [6]:
BASE_MODEL = r"model/llama2-7B-hf"

model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### few-shot

In [None]:
PROMPT_TEMPLATE1 = """
Your task is to determine the venue category based on its given information. You MUST choose ONE category ONLY from the provided list below.
Airport, Animal Shelter, Antique Shop, Aquarium, Arcade, Art Gallery, Arts & Crafts Store, Athletic & Sport, Automotive Shop, Bagel Shop, Bakery, Bank, Bar, Beach, Beer Garden, Bike Shop, Bookstore, Bowling Alley, Brewery, Bridal Shop, Bridge, Building, Bus Station, Camera Store, Campground, Candy Store, Car Dealership, Car Wash, Casino, Cemetery, Church, Clothing Store, Coffee Shop, College & University, College Theater, Comedy Club, Concert Hall, Convenience Store, Convention Center, Cosmetics Shop, Cupcake Shop, Deli & Bodega, Department Store, Design Studio, Dessert Shop, Donut Shop, Drugstore & Pharmacy, Electronics Store, Event Space, Factory, Ferry, Flea Market, Flower Shop, Food Truck, Funeral Home, Furniture & Home Store, Garden, Gas Station & Garage, Gastropub, General Entertainment, Gift Shop, Government Building, Grocery Store, Gym & Fitness Center, Harbor & Marina, Hardware Store, Historic Site, Hobby Shop, Hotel, Housing Development, Ice Cream Shop, Jewelry Store, Laundry Service, Law School, Library, Light Rail, Mall, Medical Center, Medical School, Miscellaneous Shop, Mobile Phone Shop, Mosque, Movie Theater, Museum, Music Store, Music Venue, Nail Salon, Neighborhood, Nursery School, Office, Other Nightlife, Outdoors & Recreation, Paper & Office Supplies Store, Parking, Performing Arts Venue, Pet Store, Playground, Plaza, Pool, Pool Hall, Post Office, Professional & Other Places, Racetrack, Record Shop, Recycling Facility, Residential Building (Apartment & Condo), Rest Area, Restaurant, River, Road, Salon & Barbershop, Scenic Lookout, School, Sculpture Garden, Shop & Service, Smoke Shop, Snack Place, Spa & Massage, Spiritual Center, Sporting Goods Shop, Stadium, Storage Facility, Subway Station, Synagogue, Tanning Salon, Tattoo Parlor, Taxi, Tea Room, Theater, Thrift & Vintage Store, Toy & Game Store, Trade School, Train Station, Travel & Transport, Travel Lounge, Video Game Store, Video Store, Zoo
### Example 1:

#### Input:
Venue name: T-Mobile
Venue description: Visit T-Mobile New York cell phone stores and discover T-Mobile's best smartphones, cell phones, tablets, and internet devices. View our low cost plans with no annual service contracts.
Venue review 1: Employees here try to tricky you into more expensive stuff. I've tried to get a plan for my wife that is the same I've got three hours ago in a different store and the lady said it doesn't exist. 
Venue review 2: LA, today we're spreading Unlimited Cheer with the T-Mobile Girl! Follow us on Twitter for your chance at a Life Without Limits prize pack. http://bit.ly/UFQPCH
Venue review 3: Absolutely the worst customer service experience I have ever had. Can not believe how disorganized and rude these sales people. It's like they would rather not take your money. 

#### Response:
Mobile Phone Shop.

### Example 2:

#### Input:
Venue name: Eataly Flatiron
Venue description: Eataly is a dynamic marketplace with restaurants that was created in Torino, Italy by Oscar Farinetti in 2007. There are 26 Eataly stores in the world.
Venue review 1: Great meat section
Venue review 2: Agnolotti tartuffo and poloeto meatballs. Also de Negroni was great
Venue review 3: Tiramisu more like cream-my-pants-u. Incredibly delicious!

#### Response:
Grocery Store.

### Example 3:

#### Input:
Venue name: Pitkin Education Center
Venue description: None
Venue review 1: Everyone has that face that screams: I want to punch you.
Venue review 2: Agnolotti tartuffo and poloeto meatballs. Also de Negroni was great
Venue review 3: The mayor is Tara...the one in the counseling center, in the back room, with the funk pens. Say ello to her.

#### Response:
College & University.

### Your Task:
#### Input:
{"[INPUT]"}

#### Response:
"""

In [None]:
generated_formatted_venue_text = load_checkpoint('data/category_predict_name_des_tip_6500.json')
count = 0   

for index, (venue, info) in enumerate(prompt_venue_text.items()):
    if venue not in generated_formatted_venue_text.keys():
        
        new_info = {}
        new_info['text'] = info['prompt']
        new_info['generated'] = ask_alpaca(info['prompt'])
        new_info['truth'] = info['label']
        
        generated_formatted_venue_text[venue] = new_info
    
    if (index+1)%100 == 0:
        print(f"{(index+1)*100/len(prompt_venue_text):.2f} %")
        
    if (index+1)%100 == 0:
        # print(f"{(index+1)*100/len(venue_text):.2f} %")
        
        filename = f"data/category_predict_name_des_tip_{index+1}.json"
        with open(filename, "w") as json_file:
            json.dump(generated_formatted_venue_text, json_file)
            
    count = index

filename = f"data/category_predict_name_des_tip_6711.json"
with open(filename, "w") as json_file:
    json.dump(generated_formatted_venue_text, json_file)

### Second Round few-shot to Correct the Labels that don't Exist

In [13]:
with open('data/category_predict_name_des_tip_6711.json', 'r') as jsonfile:
    generated_formatted_venue_text = json.load(jsonfile)

In [14]:
category_list = []

for index, (venue, info) in enumerate(generated_formatted_venue_text.items()):
    category_list.append(info['truth'])

categories = set(sorted(category_list))

In [16]:
PROMPT_TEMPLATE2 = """
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
You are provided a venue category not present in the given standard list. Match it to the closest category from the standard list. Choose ONLY ONE category as your response.
Standard Venue Categories: Electronics Store, Light Rail, Music Store, Animal Shelter, Theater, Smoke Shop, Historic Site, Airport, Snack Place, Salon & Barbershop, Antique Shop, Jewelry Store, Bookstore, Playground, Toy & Game Store, Pool Hall, Furniture & Home Store, Mobile Phone Shop, Thrift & Vintage Store, College Theater, Bagel Shop, Synagogue, Housing Development, Dessert Shop, Museum, College & University, Tea Room, Tanning Salon, Food Truck, Restaurant, Library, Storage Facility, Plaza, Shop & Service, Miscellaneous Shop, Hobby Shop, Racetrack, Train Station, Pet Store, Laundry Service, Medical School, Scenic Lookout, Zoo, Sporting Goods Shop, Bakery, Deli & Bodega, Road, Harbor & Marina, Office, Bowling Alley, Gym & Fitness Center, Aquarium, Athletic & Sport, Candy Store, Sculpture Garden, Campground, Casino, Convenience Store, Church, Grocery Store, Government Building, Post Office, Gastropub, Bank, Paper & Office Supplies Store, Automotive Shop, Trade School, Other Nightlife, Video Game Store, Department Store, Travel Lounge, Donut Shop, General Entertainment, Rest Area, Stadium, Law School, Video Store, Convention Center, Arts & Crafts Store, Ice Cream Shop, Nursery School, Car Wash, Neighborhood, Brewery, Beach, Pool, Mosque, Ferry, Concert Hall, Coffee Shop, School, Residential Building (Apartment & Condo), Factory, Parking, Art Gallery, Record Shop, Professional & Other Places, Design Studio, Hotel, Arcade, Bike Shop, Spa & Massage, Cupcake Shop, Tattoo Parlor, Comedy Club, Nail Salon, Movie Theater, River, Camera Store, Taxi, Subway Station, Gift Shop, Flower Shop, Event Space, Bar, Building, Beer Garden, Bus Station, Gas Station & Garage, Spiritual Center, Music Venue, Clothing Store, Cosmetics Shop, Medical Center, Performing Arts Venue, Outdoors & Recreation, Funeral Home, Garden, Travel & Transport, Car Dealership, Bridge, Cemetery, Hardware Store, Mall, Bridal Shop, Drugstore & Pharmacy, Recycling Facility, Flea Market.

#### Example 1:
### Input: 
Venue Category: Park

### Response:
Outdoors & Recreation.

#### Example 2:
### Input: 
Venue Category: Computer Store

### Response:
Electronics Store.

#### Your task:
### Input:
Venue Category: {"[INPUT]"}

### Response:
"""

def create_prompt(instruction: str) -> str:
    return PROMPT_TEMPLATE2.replace("[INPUT]", instruction).replace("'", '')

def generate_response(prompt: str, model) -> GreedySearchDecoderOnlyOutput:
    encoding = tokenizer(prompt, return_tensors="pt")
    input_ids = encoding["input_ids"].to(DEVICE)

    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        repetition_penalty=1.1,
    )
    with torch.inference_mode():
        return model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=128,
        )

def format_response(response: GreedySearchDecoderOnlyOutput) -> str:
    decoded_output = tokenizer.decode(response.sequences[0])
    response = decoded_output.split("### Response:")[3].strip()
    return "\n".join(textwrap.wrap(response))

def ask_alpaca(prompt: str, model=model) -> str:
    prompt = create_prompt(prompt)
    response = generate_response(prompt, model)
    return (format_response(response))

In [22]:
#corrected_prediction = load_checkpoint('data/category_predict_name_des_tip_corrected_5500.json')
corrected_prediction = load_checkpoint()

for index, (venue, info) in enumerate(generated_formatted_venue_text.items()):
    
    if venue not in corrected_prediction.keys():
        if info['generated'] == '' or info['generated'][:4] == 'None':
            corrected_prediction[venue] = info
        else:
            generated = info['generated'].split('<')[0].strip().strip('.')

            if generated not in categories:
                new_info = {}
                new_info['text'] = info['text']
                new_info['generated'] = ask_alpaca(generated)
                new_info['truth'] = info['truth']
                corrected_prediction[venue] = new_info

            else:
                corrected_prediction[venue] = info
        
    if (index+1)%100 == 0:
        print(f"{(index+1)*100/len(generated_formatted_venue_text):.2f} %")
        
    if (index+1)%1000 == 0:
        # print(f"{(index+1)*100/len(venue_text):.2f} %")
        
        filename = f"../../result/category_predict_name_des_tip_corrected_{index+1}.json"
        with open(filename, "w") as json_file:
            json.dump(corrected_prediction, json_file)
            
    count = index

filename = f"../../result/category_predict_name_des_tip_corrected_{count+1}.json"
with open(filename, "w") as json_file:
    json.dump(corrected_prediction, json_file)

1.49 %
2.98 %
4.47 %
5.96 %
7.45 %
8.94 %
10.43 %
11.92 %
13.41 %
14.90 %
16.39 %
17.88 %
19.37 %
20.86 %
22.35 %
23.84 %
25.33 %
26.82 %
28.31 %
29.80 %
31.29 %
32.78 %
34.27 %
35.76 %
37.25 %
38.74 %
40.23 %
41.72 %
43.21 %
44.70 %
46.19 %
47.68 %
49.17 %
50.66 %
52.15 %
53.64 %
55.13 %
56.62 %
58.11 %
59.60 %
61.09 %
62.58 %
64.07 %
65.56 %
67.05 %
68.54 %
70.03 %
71.52 %
73.01 %
74.50 %
75.99 %
77.48 %
78.97 %
80.46 %
81.95 %
83.45 %
84.94 %
86.43 %
87.92 %
89.41 %
90.90 %
92.39 %
93.88 %
95.37 %
96.86 %
98.35 %
99.84 %
