<a href="https://colab.research.google.com/github/githubpradeep/notebooks/blob/main/wordllama.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wordllama

Collecting wordllama
  Downloading wordllama-0.2.6.post2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading wordllama-0.2.6.post2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: wordllama
Successfully installed wordllama-0.2.6.post2


In [1]:
from wordllama import WordLlama

# Load the default WordLlama model
wl = WordLlama.load()

# Calculate similarity between two sentences
similarity_score = wl.similarity("i went to the car", "i went to the pawn shop")
print(similarity_score)  # Output: 0.06641249096796882

# Rank documents based on their similarity to a query
query = "i went to the car"
candidates = ["i went to the park", "i went to the shop", "i went to the truck", "i went to the vehicle"]
ranked_docs = wl.rank(query, candidates)
print(ranked_docs)


0.06641249358654022
[('i went to the vehicle', 0.744164764881134), ('i went to the truck', 0.28326916694641113), ('i went to the shop', 0.1973281353712082), ('i went to the park', 0.1510140299797058)]


In [2]:
from wordllama import WordLlama

# Load pre-trained embeddings
# truncate dimension to 64
wl = WordLlama.load(trunc_dim=64)

# Embed text
embeddings = wl.embed(["the quick brown fox jumps over the lazy dog", "and all that jazz"])
print(embeddings.shape)  # (2, 64)

(2, 64)


In [3]:
embeddings

array([[ 0.12470592, -0.07770053,  0.02884188,  0.21934648,  0.18028675,
        -0.12956376, -0.04828436,  0.41509733,  0.07201316,  0.13917126,
         0.0855602 ,  0.12068315,  0.12699752,  0.11927379,  0.00974343,
        -0.564209  , -0.1578487 ,  0.03860474, -0.07334492, -0.26644066,
        -0.5641868 , -0.32762977, -0.38049594,  0.21294056,  0.08007257,
         0.16110507, -0.00196422,  0.3599271 ,  0.0900435 ,  0.17322887,
         0.29006264, -0.29262888,  0.41510564,  0.15712148, -0.02885853,
        -0.11202171,  0.06890869, -0.03536294, -0.06598455,  0.24811624,
        -0.20313574, -0.30398142,  0.07795299, -0.42450505,  0.10664229,
         0.7530911 ,  0.02276056,  0.12312213, -0.3221824 ,  0.15558554,
         0.29816782, -0.02808033, -0.05978116,  0.20450106, -0.4262626 ,
        -0.14082475,  0.22328836, -0.3291959 , -0.18767756, -0.29354027,
        -0.5189202 , -0.4972867 , -0.12053888,  0.11252386],
       [-0.06514788,  0.45723152,  0.11624908, -0.07243347, -0.

In [4]:
# Binary embeddings are packed into uint64
# 64-dims => array of 1x uint64
wl = WordLlama.load(trunc_dim=64, binary=True)  # this will download the binary model from huggingface
wl.embed("I went to the car") # Output: array([[3029168427562626]], dtype=uint64)

# load binary trained model trained with straight through estimator
wl = WordLlama.load(dim=1024, binary=True)

# Uses the hamming similarity to binarize
similarity_score = wl.similarity("i went to the car", "i went to the pawn shop")
print(similarity_score)  # Output: 0.57421875

ranked_docs = wl.rank("i went to the car", ["van", "car"])
print(ranked_docs)

0.06640625
[('car', 0.677734375), ('van', 0.203125)]


In [5]:
import requests
from io import StringIO
import pandas as pd
sentences=[]
urls = [
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.train.tsv',
]
# each of these dataset have the same structure, so we loop through each creating our sentences data
for url in urls:
    res = requests.get(url)
    # extract to dataframe
    data = pd.read_csv(StringIO(res.text), sep='\t', header=None, on_bad_lines='warn')
    # add to columns 1 and 2 to sentences list
    sentences.extend(data[1].tolist())
    sentences.extend(data[2].tolist())

Skipping line 191: expected 3 fields, saw 4
Skipping line 206: expected 3 fields, saw 4
Skipping line 295: expected 3 fields, saw 4
Skipping line 695: expected 3 fields, saw 4
Skipping line 699: expected 3 fields, saw 4



In [6]:
sentences = [word for word in list(set(sentences)) if type(word) is str]
print(len(set(sentences)))

1425


In [7]:
sentences[:10]

['Prairie dogs sold as exotic pets are believed to have been infected in an Illinois pet shop by a Gambian giant rat imported from Africa.',
 'Dennehy, who transferred to Baylor last year after getting kicked off the University of New Mexico Lobos for temper tantrums, had begun to read the Bible daily.',
 'A soldier was killed Monday and another wounded when their convoy was ambushed in northern Iraq.',
 'Redman has allowed two earned runs or less in six of his nine starts.',
 'The government recently shelved peace talks with the MILF, being brokered by Malaysia, after a string of attacks, including three bombings, on Mindanao.',
 'The recent turnaround in the stock market and an easing in unemployment claims should keep consumer expectations at current levels and may signal more favorable economic times ahead.',
 "The 27-year-old rapper's attorney in the civil matter, Mark Gann, did not return calls for comment.",
 'Volume came to 439.66 million shares, below 450.39 million at the sam

In [8]:
import numpy as np

In [9]:
class VectorStore():

    def __init__(self, dim = 1024, binary=False):
        self.embeds = {}
        self.embedding_model =  WordLlama.load(dim=dim, binary=binary)
    def get(self, text):
        return self.embeds[text]

    def add(self, docs):
        embeds = self.embedding_model.embed(docs)
        self.docs = docs
        self.vectors = embeds

    def query(self, text, num_results=10):
        query_embed = self.embedding_model.embed(text)

        results = []

        scores = self.embedding_model.vector_similarity(query_embed[0], self.vectors)
        scores = scores.squeeze()
        similarities = list(zip(self.docs, scores.tolist()))
        similarities.sort(key=lambda x: x[1], reverse=True)
        results = similarities[:num_results]
        return results


In [10]:
vector_store = VectorStore()

In [11]:
vector_store.add(sentences)

In [12]:
vector_store.query('football')

[('The pressure may well rise on Thursday, with national coverage of the final round planned by ESPN, the cable sports network.',
  0.2862741947174072),
 ('Stanford (46-15) plays South Carolina (44-20) today in a first-round game at Rosenblatt Stadium in Omaha, Neb.',
  0.23061472177505493),
 ('That is up from $1.14 billion during the same quarter last year.',
  0.1976434588432312),
 ("Moore of Alabama says he will appeal his case to the nation's highest court.",
  0.18732286989688873),
 ('The pressure will intensify today, with national coverage of the final round planned by ESPN and words that are even more difficult.',
  0.1754290759563446),
 ('Stanford (51-17) and Rice (57-12) will play for the national championship tonight.',
  0.17530187964439392),
 ('The charges of espionage and aiding the enemy can carry the death penalty.',
  0.16965657472610474),
 ('The Dodgers won their sixth consecutive game their longest win streak since 2001 as they edged Colorado, 3-2, Wednesday in front

In [13]:
vector_store.query('what are the violations', num_results=5)

[('The federal courts have ruled that the monument violates the constitutional ban against state-established religion.',
  0.39459460973739624),
 ('In that case, the court held that the city of Cincinnati had violated the First Amendment in banning, in the interest of aesthetics, only the advertising pamphlets.',
  0.3732220530509949),
 ('In that case, the court held that Cincinnati had violated the First Amendment in banning only the advertising pamphlets in the interest of aesthetics.',
  0.3661629557609558),
 ('Lay had argued that handing over the documents would be a violation of his Fifth Amendment rights against self-incrimination.',
  0.3621053993701935),
 ('Under the law, telemarketers who call numbers on the list can be fined up to $11,000 for each violation.',
  0.35536259412765503)]

In [14]:
vector_store.query('tropical storm', num_results=5)

[('A tropical storm rapidly developed in the Gulf of Mexico on Sunday and could have hurricane-force winds when it hits land somewhere along the Louisiana coast Monday night.',
  0.6678519248962402),
 ('A tropical storm rapidly developed in the Gulf of Mexico Sunday and was expected to hit somewhere along the Texas or Louisiana coasts by Monday night.',
  0.5823907852172852),
 ('Crews worked to install a new culvert and prepare the highway so motorists could use the eastbound lanes for travel as storm clouds threatened to dump more rain.',
  0.3848031461238861),
 ("The heatwave was due to a mass of hot, dry air from the southeast, said Mario Almeida of Portugal's weather service.",
  0.31797584891319275),
 ('The weather service reported maximum sustained winds of nearly 105 miles an hour with stronger gusts.',
  0.2692662179470062)]