In [1]:
import pandas as pd
import tiktoken
from openai.embeddings_utils import get_embedding

In [2]:
embedding_model = "text-embedding-ada-002"
embedding_encoding = "cl100k_base"
max_tokens = 500

In [3]:
data_path = "data/ju.csv"
df = pd.read_csv(data_path)
df

Unnamed: 0,title,class
0,"Zhu Yanjun, Guangzhou Chengyi Technology Softw...",technology infringement
1,Li Yangyuan and Suzhou Mindray Microelectronic...,patent infringement
2,Civil Judgment of the Second Instance Civil Ju...,abuse of market power
3,The second-instance judgment of Guangxi Nanfan...,Software development contract disputes
4,Civil Judgment of Second Instance Civil Judgme...,technology infringement
5,"Sichuan Zhunda Information Technology Co., Ltd...",Software development contract disputes
6,Civil Judgment of Second Instance Civil Judgme...,Software development contract disputes
7,"Sichuan Xinneng Yufeng E-Commerce Co., Ltd., S...",Software development contract disputes
8,"Guangdong Yirun Network Technology Co., Ltd., ...",Software development contract disputes
9,"Shaanxi Navigation Technology Co., Ltd., Shaan...",Software development contract disputes


In [4]:
# count token number
encoding = tiktoken.get_encoding(embedding_encoding)
df["n_tokens"] = df.title.apply(lambda x: len(encoding.encode(x)))
df

Unnamed: 0,title,class,n_tokens
0,"Zhu Yanjun, Guangzhou Chengyi Technology Softw...",technology infringement,35
1,Li Yangyuan and Suzhou Mindray Microelectronic...,patent infringement,27
2,Civil Judgment of the Second Instance Civil Ju...,abuse of market power,37
3,The second-instance judgment of Guangxi Nanfan...,Software development contract disputes,34
4,Civil Judgment of Second Instance Civil Judgme...,technology infringement,43
5,"Sichuan Zhunda Information Technology Co., Ltd...",Software development contract disputes,31
6,Civil Judgment of Second Instance Civil Judgme...,Software development contract disputes,45
7,"Sichuan Xinneng Yufeng E-Commerce Co., Ltd., S...",Software development contract disputes,38
8,"Guangdong Yirun Network Technology Co., Ltd., ...",Software development contract disputes,33
9,"Shaanxi Navigation Technology Co., Ltd., Shaan...",Software development contract disputes,33


In [5]:
# check openai_key
import os
print(os.getenv("OPENAI_API_KEY_AA01") is None)

False


In [6]:
# check openai api accessibility
import openai
openai.organization = "org-3jO3a0KuKe3PxW3wakO9tPgK"
openai.api_key = os.getenv("OPENAI_API_KEY_AA01")
print(openai.Model.list() is None)

False


In [None]:
# proxy setup. Try this if api call returns error code: 443
os.environ["http_proxy"] = "http://localhost:7890"
os.environ["https_proxy"] = "http://localhost:7890"

In [7]:
# get embeddings
df["embedding"] = df.title.apply(lambda x: get_embedding(x, engine=embedding_model))
df

Unnamed: 0,title,class,n_tokens,embedding
0,"Zhu Yanjun, Guangzhou Chengyi Technology Softw...",technology infringement,35,"[0.016624698415398598, -0.01290762796998024, 0..."
1,Li Yangyuan and Suzhou Mindray Microelectronic...,patent infringement,27,"[-0.002745145233348012, 0.0009384964942000806,..."
2,Civil Judgment of the Second Instance Civil Ju...,abuse of market power,37,"[0.0055884248577058315, -0.0010189167223870754..."
3,The second-instance judgment of Guangxi Nanfan...,Software development contract disputes,34,"[0.022417180240154266, -0.009231014177203178, ..."
4,Civil Judgment of Second Instance Civil Judgme...,technology infringement,43,"[0.010469283908605576, 0.00784167181700468, 0...."
5,"Sichuan Zhunda Information Technology Co., Ltd...",Software development contract disputes,31,"[0.01774647645652294, -0.007223053369671106, 0..."
6,Civil Judgment of Second Instance Civil Judgme...,Software development contract disputes,45,"[0.014836732298135757, -0.005966946482658386, ..."
7,"Sichuan Xinneng Yufeng E-Commerce Co., Ltd., S...",Software development contract disputes,38,"[0.016408946365118027, -0.006478115450590849, ..."
8,"Guangdong Yirun Network Technology Co., Ltd., ...",Software development contract disputes,33,"[0.014321736991405487, -0.003787795314565301, ..."
9,"Shaanxi Navigation Technology Co., Ltd., Shaan...",Software development contract disputes,33,"[0.01996522955596447, -0.002968967193737626, 0..."


In [8]:
df.to_csv("data/ju_embedded.csv", index=None)

In [9]:
# human generate some tags, calculate the embeddings of these tags
target_classes = ["technology infringement", "patent infringement", "abuse of market power", "Software development contract disputes", "pig"]
target_embeddings = [get_embedding(label, engine=embedding_model) for label in target_classes]

In [11]:
# check the cosine similarities between these tags
from openai.embeddings_utils import cosine_similarity
for i in range(len(target_embeddings)):
    similarities = [cosine_similarity(target_embeddings[i], target_embedding) for target_embedding in target_embeddings]
    print(similarities)

[1.0000000000000002, 0.925647054759034, 0.8157555596300413, 0.8129563699138294, 0.7628307118616373]
[0.925647054759034, 1.0000000000000002, 0.8198857640829815, 0.8201911197535084, 0.7673704351730666]
[0.8157555596300413, 0.8198857640829815, 1.0000000000000002, 0.783226435790348, 0.7489179160732189]
[0.8129563699138294, 0.8201911197535084, 0.783226435790348, 1.0000000000000002, 0.7410881313891087]
[0.7628307118616373, 0.7673704351730666, 0.7489179160732189, 0.7410881313891087, 1.0000000000000002]


In [12]:
# calculate the cosine similarities between titles and tags
from sklearn.metrics import PrecisionRecallDisplay

def classify(review_embedding):
    similarities = [cosine_similarity(review_embedding, target_embedding) for target_embedding in target_embeddings]
    print(similarities)
    p = similarities.index(max(similarities))
    return target_classes[p]


df["zero-shot-classification"] = df.embedding.apply(classify)
df

[0.8286762021856187, 0.8401095200372273, 0.7593370004595354, 0.8195197717758946, 0.7398602899512295]
[0.7960117927736009, 0.8309362597692548, 0.7510981958436052, 0.7916498966206983, 0.736089563184387]
[0.7648413965627104, 0.8043717541149226, 0.8312954669544891, 0.7887688180461454, 0.7430132521652483]
[0.8066082818476685, 0.8154351849772411, 0.7668447606185017, 0.8645124892039635, 0.7500207929965194]
[0.7994759753203882, 0.8243809563494928, 0.7424984158223059, 0.7851505524294107, 0.7303098955415537]
[0.7908814007156175, 0.789900231027692, 0.7469979865287012, 0.8497875378480931, 0.7326679685106556]
[0.790437979547356, 0.8097571712833654, 0.7575435554173968, 0.8513298431243538, 0.7486480360593541]
[0.7878007725093376, 0.7904966869704027, 0.7453725350084663, 0.8284685973468541, 0.732272613248034]
[0.7963056742101442, 0.7994569562673329, 0.7515204707811226, 0.8428107028640334, 0.7363382474345929]
[0.7863566082113133, 0.7862016646290221, 0.7409261689470825, 0.8352032168633531, 0.723321201989

Unnamed: 0,title,class,n_tokens,embedding,zero-shot-classification
0,"Zhu Yanjun, Guangzhou Chengyi Technology Softw...",technology infringement,35,"[0.016624698415398598, -0.01290762796998024, 0...",patent infringement
1,Li Yangyuan and Suzhou Mindray Microelectronic...,patent infringement,27,"[-0.002745145233348012, 0.0009384964942000806,...",patent infringement
2,Civil Judgment of the Second Instance Civil Ju...,abuse of market power,37,"[0.0055884248577058315, -0.0010189167223870754...",abuse of market power
3,The second-instance judgment of Guangxi Nanfan...,Software development contract disputes,34,"[0.022417180240154266, -0.009231014177203178, ...",Software development contract disputes
4,Civil Judgment of Second Instance Civil Judgme...,technology infringement,43,"[0.010469283908605576, 0.00784167181700468, 0....",patent infringement
5,"Sichuan Zhunda Information Technology Co., Ltd...",Software development contract disputes,31,"[0.01774647645652294, -0.007223053369671106, 0...",Software development contract disputes
6,Civil Judgment of Second Instance Civil Judgme...,Software development contract disputes,45,"[0.014836732298135757, -0.005966946482658386, ...",Software development contract disputes
7,"Sichuan Xinneng Yufeng E-Commerce Co., Ltd., S...",Software development contract disputes,38,"[0.016408946365118027, -0.006478115450590849, ...",Software development contract disputes
8,"Guangdong Yirun Network Technology Co., Ltd., ...",Software development contract disputes,33,"[0.014321736991405487, -0.003787795314565301, ...",Software development contract disputes
9,"Shaanxi Navigation Technology Co., Ltd., Shaan...",Software development contract disputes,33,"[0.01996522955596447, -0.002968967193737626, 0...",Software development contract disputes


In [13]:
df.to_csv("data/ju_zero_shot_classification.csv", index=None)