In [2]:
import sys, json, re
import pandas as pd
from pathlib import Path
from decouple import config
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForMaskedLM,  file_utils
# Directory where the transformers library keeps downloaded checkpoints.
cache_dir = Path(file_utils.default_cache_path)

# Let pandas render wide and long frames without truncation.
pd.set_option('display.max_columns', 1000)
pd.set_option('display.width', 1000)
pd.set_option('display.max_rows', 1000)

# The `data` directory sits beside the current working directory.
data_dir = Path(".").absolute().parent / "data"


def ls(p):
  """Print every entry of directory `p`, one path per line."""
  print("\n".join(str(entry) for entry in p.iterdir()))


ls(data_dir)
hf_model_name = "gpt2"
# hf_model_name = "mistralai/Mistral-7B-v0.1"

/home/idan/Documents/llm_workshop/data/sample_apps.parquet


In [3]:
# Work on a small sample to keep the notebook fast. The fixed seed makes the
# sampled rows, and therefore `categories`, identical on every run.
df = pd.read_parquet(data_dir / "sample_apps.parquet").sample(9, random_state=42)
# `category_names` is a comma-separated string: lower-case it, split it, and
# count how often each individual category occurs in the sample.
categories = df["category_names"].str.lower().str.split(',').explode().value_counts()
# The frame already holds only the 9 sampled rows, so display it as is.
df

Unnamed: 0,bundle_id,title,description,store_url,category_names,ios
24041,com.grabtaxi.passenger,Grab Superapp,Grab is Southeast Asia’s leading superapp. We ...,https://play.google.com/store/apps/details?id=...,"TRAVEL_AND_LOCAL,APPLICATION",False
26290,com.hwqgrhhjfd.idlefastfood,Eatventure,Are you looking to become a restaurant million...,https://play.google.com/store/apps/details?id=...,"GAME_SIMULATION,GAME",False
49136,com.tripledot.woodoku,Woodoku - Block Puzzle Games,Woodoku: a wood block puzzle game meets a sudo...,https://play.google.com/store/apps/details?id=...,"GAME_PUZZLE,GAME",False
17760,com.dream.dale,Dreamdale - Fairy Adventure,🌳 ONCE UPON A TIME…\n\nSet off on a fairy tale...,https://play.google.com/store/apps/details?id=...,"GAME_ROLE_PLAYING,GAME",False
29752,com.king.candycrushsodasaga,Candy Crush Soda Saga,You loved playing Candy Crush Saga - Start pla...,https://play.google.com/store/apps/details?id=...,"GAME_CASUAL,GAME",False
15820,com.creditkarma.mobile,Credit Karma,• Check your free credit scores – Learn what a...,https://play.google.com/store/apps/details?id=...,"FINANCE,APPLICATION",False
1021,1105855019,Gardenscapes,Welcome to Gardenscapes—the first hit from Pla...,https://apps.apple.com/us/app/gardenscapes/id1...,"Games,Entertainment,Puzzle,Simulation",True
5315,530168168,Paramount+,Welcome to A Mountain of Entertainment. Stream...,https://apps.apple.com/us/app/paramount/id5301...,Entertainment,True
39809,com.playrix.fishdomdd.gplay,Fishdom,Never Fishdomed before? Take a deep breath and...,https://play.google.com/store/apps/details?id=...,"GAME_PUZZLE,GAME",False


# Verbalizers

## Verbalizers as masks

Most of the generation models we have used so far are `CausalLM` models, trained to predict the next token.

However, we can use `MaskedLM` models (which tend to be smaller) if we are looking for a completion mid-sentence.

In [4]:
def masked_lm_yes_or_no(txt, model_str):
  """Answer a yes/no question by comparing masked-LM logits for "yes" and "no".

  Args:
    txt: Prompt containing the tokenizer's mask token (e.g. ``<mask>``) at the
      position where the answer should go. Only the first mask is used.
    model_str: Model name or path passed to ``AutoTokenizer.from_pretrained``
      and ``AutoModelForMaskedLM.from_pretrained``.

  Returns:
    A 1-D integer tensor with one entry per input row (a single row here):
    ``1`` when the "yes" logit is strictly larger than the "no" logit at the
    mask position, otherwise ``0``.

  Raises:
    AssertionError: if the tokenizer has no mask token or ``txt`` lacks it.
  """
  tokenizer = AutoTokenizer.from_pretrained(model_str)
  mask_token = tokenizer.mask_token
  assert mask_token is not None and mask_token in txt, (
    f"txt must contain the mask token {mask_token!r}"
  )
  # Tokenize both answers in the same mid-sentence position and keep the last
  # token. Encoding "yes or no" in one go yields "yes" without a leading space
  # but " no" with one on BPE vocabularies, so the comparison was lopsided.
  yes, no = (
    tokenizer.encode(f"answer: {word}", add_special_tokens=False)[-1]
    for word in ("yes", "no")
  )
  model = AutoModelForMaskedLM.from_pretrained(model_str)
  X = tokenizer.encode(txt, return_tensors="pt")
  # Inference only: skip building the autograd graph.
  with torch.no_grad():
    y = model(X)
  # Column index of the first mask token in the encoded prompt.
  masked_tup = (X == tokenizer.mask_token_id).nonzero(as_tuple=True)
  mask_idx = int(masked_tup[1][0])
  # Row 0 holds the "no" logit and row 1 the "yes" logit, so argmax is 1 for "yes".
  ret = torch.vstack(
    [y.logits[:, mask_idx, no].reshape(-1),
     y.logits[:, mask_idx, yes].reshape(-1)],
  ).argmax(dim=0)
  return ret

In [5]:
# The mask marks the slot whose "yes"/"no" logits are compared.
masked_lm_yes_or_no(
  txt="Is an apple a fruit? answer: <mask>",
  model_str="facebook/bart-large",
)

Downloading tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

tensor([0])

## Verbalizers from generation models

In [9]:
def causal_lm_yes_or_no(txt, model_str):
  """Answer a yes/no question by comparing next-token logits for "yes" and "no".

  Args:
    txt: Prompt whose continuation should be the answer. The logits at the
      last prompt position are used, so the prompt should end right before
      the answer.
    model_str: Model name or path passed to ``AutoTokenizer.from_pretrained``
      and ``AutoModelForCausalLM.from_pretrained``.

  Returns:
    A 1-D integer tensor with one entry per input row (a single row here):
    ``1`` when the "yes" logit is strictly larger than the "no" logit,
    otherwise ``0``.
  """
  tokenizer = AutoTokenizer.from_pretrained(model_str)
  # Tokenize both answers in the same mid-sentence position and keep the last
  # token. Encoding "yes or no" in one go yields "yes" without a leading space
  # but " no" with one on BPE vocabularies, so the comparison was lopsided.
  yes, no = (
    tokenizer.encode(f"answer: {word}", add_special_tokens=False)[-1]
    for word in ("yes", "no")
  )
  model = AutoModelForCausalLM.from_pretrained(model_str)
  X = tokenizer.encode(txt, return_tensors="pt")
  # Inference only: skip building the autograd graph.
  with torch.no_grad():
    y = model(X)
  # Row 0 holds the "no" logit and row 1 the "yes" logit, so argmax is 1 for "yes".
  ret = torch.vstack(
    [y.logits[:, -1, no].reshape(-1),
     y.logits[:, -1, yes].reshape(-1)],
  ).argmax(dim=0)
  return ret

In [10]:
# A causal LM has no mask token: the answer is whatever comes next, so the
# prompt must end right before it instead of carrying a literal "<mask>".
causal_lm_yes_or_no("Is an apple a fruit? answer:", "gpt2")

Downloading tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

tensor([0])

# JSONFormer
JSONFormer constrains the decoder to only output the most likely token that would result in valid JSON according to a predefined schema.

In [15]:
from jsonformer import Jsonformer

# Model and tokenizer both come from the checkpoint named by `hf_model_name`.
model = AutoModelForCausalLM.from_pretrained(hf_model_name)
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)

# Shape of the object that the constrained decoder has to fill in.
json_schema = dict(
    type="object",
    properties=dict(
        name=dict(type="string"),
        age=dict(type="number"),
        is_for_kids=dict(type="boolean"),
        categories=dict(type="array", items=dict(type="string")),
    ),
)

prompt = "Please describe 'Candy crush' with the following schema"

# Calling the Jsonformer instance runs the generation and returns the result.
jsonformer = Jsonformer(model, tokenizer, json_schema, prompt)
generated_data = jsonformer()

print(generated_data)

{'name': 'Candy Crush', 'age': 0.5, 'is_for_kids': False, 'categories': ['boolean', 'boolean']}


# Guidance
Guidance is a very popular library for decoder constraints and is much more "user-friendly" than JSONFormer.

In [6]:
# Wrap the checkpoint named by `hf_model_name` in guidance's Transformers model;
# `select` and `gen` are used in the cells that follow.
from guidance import models, select, gen
llm = models.Transformers(hf_model_name)

Downloading config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [8]:
prompt = "Please categorize the mobile app 'slotomania'"
# The prompt has to be appended to the model state. The first positional
# argument of `gen` is the capture name, not the text to continue, so passing
# the prompt there meant the model never saw it.
llm + prompt + gen(max_tokens=10)

In [26]:
app = "Solitaire Grand Harvest"

# Restrict the continuation to one of the category names counted from the sample.
llm + f'{app} is ' + select(categories.index.tolist())

# Exercise 3
Answer the questions in exercise 1 with `Mistral-7B`