In [88]:
import uuid
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from arango import ArangoClient
from sklearn.preprocessing import LabelEncoder
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.base import BaseEstimator, TransformerMixin
from transformers import pipeline, AutoTokenizer, AutoModel

In [2]:
raw_df = pd.read_csv('./data/raw.csv')

In [3]:
raw_df.head()

Unnamed: 0,country_of_destination,country_of_origin,hs6_code,receiver_address,receiver_id,receiver_name,sender_address,sender_id,sender_name,tier,transaction_date
0,CA,US,321410,"95 SUNRAY STREET, WHITBY, ON, L1N9C9, CANADA T...",4bd999c0cb4346ce047f99a692ec3191,Stonhard Ltd,"60 MARKET SAUQRE, BELIZE CITY, BELIZE",52c8642c34649eafe2d044eee3d884e1,Global Fire Protection,t1->t0,11/30/18
1,,,320890,China,8b1f537a8eeca839cff8a99e08b35690,CARBOLINE DALIAN PAINT COMPANY LTD,"Plot O 356 &357, Sidco Indl Estate Ambattur Ch...",831a31a1466f0ace3eb20b52d4575f92,Carboline (India) Private Limited,t1->t0,5/2/18
2,US,IN,390940,United States,1670c05b29762f5d4ab8980d267bd482,Rust-Oleum Llc,"204, Monarch Chambers, Marol Maroshi Road, Mum...",f580bc1756d06768c94634b0332e1871,Paladin Paints & Chemicals Private Limited,t1->t0,5/2/17
3,US,CN,731100,OF RPM WOOD FINISHES GROUP 3190 HICKORY BLVD H...,210f645e15d75833b25090f2ae509eda,Mohawk Finishing Products Inc,"YUNG-SHI BLVD, SHI-WAN TOWN, PO-LO COUNTY HUI-...",149081be00548b006dc38a88264eae32,Alpha Pacific Group Pte. Ltd.,t1->t0,5/8/16
4,CH,IN,250590,Switzerland,7150428f8927c86ac11c64c1bb9b0285,Vandex International AG,FLOWCRETE INDIA PRIVATE LIMITEDNEW NO.36 OLD N...,f4864ac3d5d716cc586d60afe8a403ef,TREMCO CPG INDIA PRIVATE LIMITED,t1->t0,5/26/17


In [4]:
sender_df = raw_df[['sender_id', 'sender_address', 'country_of_origin']].copy()
sender_df.rename(columns={'sender_id':'organization_id', 'sender_address':'address', 'country_of_origin':'country'}, inplace=True)

In [5]:
receiver_df = raw_df[['receiver_id', 'receiver_address', 'country_of_destination']].copy()
receiver_df.rename(columns={'receiver_id':'organization_id', 'receiver_address':'address', 'country_of_destination':'country'}, inplace=True)

In [6]:
address_df = pd.concat([sender_df, receiver_df], ignore_index=True).drop_duplicates(ignore_index=True)

In [7]:
address_df.head()

Unnamed: 0,organization_id,address,country
0,52c8642c34649eafe2d044eee3d884e1,"60 MARKET SAUQRE, BELIZE CITY, BELIZE",US
1,831a31a1466f0ace3eb20b52d4575f92,"Plot O 356 &357, Sidco Indl Estate Ambattur Ch...",
2,f580bc1756d06768c94634b0332e1871,"204, Monarch Chambers, Marol Maroshi Road, Mum...",IN
3,149081be00548b006dc38a88264eae32,"YUNG-SHI BLVD, SHI-WAN TOWN, PO-LO COUNTY HUI-...",CN
4,f4864ac3d5d716cc586d60afe8a403ef,FLOWCRETE INDIA PRIVATE LIMITEDNEW NO.36 OLD N...,IN


In [8]:
address_df[address_df.country.isna()].address.iloc[8]

'Russian Federation'

In [9]:
address_df.shape

(2546, 3)

In [10]:
# ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english")

In [11]:
# # def extract_country(address):
# #     # Apply NER to the address
# #     ner_results = ner_pipeline(address)
    
# #     # Clean up the results to avoid extra spaces
# #     locations = [entity['word'].strip() for entity in ner_results if 'LOC' in entity['entity']]
    
# #     # Join the locations detected
# #     return ' '.join(locations) if locations else "Unknown"
# def extract_country_code(address):
#     # Run NER on the input address
#     entities = ner_pipeline(address)
    
#     # Extract unique country identifiers (assuming they have been recognized)
#     country_codes = []
#     for entity in entities:
#         if entity['entity'] == 'B-LOC' or entity['entity'] == 'I-LOC':
#             country_codes.append(entity['word'])
#     print(entities)
    
#     # Convert extracted country names to codes
#     unique_country_codes = set(country_codes)
#     return list(unique_country_codes)

In [12]:
# site_df['extracted_country'] = site_df['address'].apply(extract_country)

In [13]:
# extract_country_code("60 MARKET SAUQRE, BELIZE CITY, BELIZE")

In [14]:
address_df.country.value_counts()

country
US    545
GB    139
CN    127
IT    123
DE    104
     ... 
DJ      1
IR      1
SZ      1
ET      1
GT      1
Name: count, Length: 85, dtype: int64

In [15]:
address_df.country.isna().mean()

np.float64(0.26983503534956793)

In [16]:
address_df.country.str.len().unique()

array([ 2., nan])

In [17]:
address_df[address_df.country.isna()]

Unnamed: 0,organization_id,address,country
1,831a31a1466f0ace3eb20b52d4575f92,"Plot O 356 &357, Sidco Indl Estate Ambattur Ch...",
8,c61f1f95b68a9c7caf0e642be8e36e07,CR 73 48 46 BRR NORMANDIA II SECTOR,
13,140c63ccedc8eac19c62398e6c1af6d0,AVENIDA LAS MISIONES 21 PARQUE INDUSTRIAL BERN...,
21,80e44a0be590c7c9ae03e5d83ad0a6e5,S. SUKHATME ) 1105/3 MANAS NARGIS DUTTA RD. M...,
28,233cbea86cfae088bcd6680aadd4fcc4,ADD: NO.3 OF GONGYE 3RD ROAD. BIJIA INDUSTRIAL...,
...,...,...,...
2500,1d73f9dd5be62952c256d617d199aa24,PP 8105 95TH STREET PLEASANT PRAIRI PRAIRIE WI...,
2507,59b9fcc328c9c3edd4c7d29c21f1cf28,"8905 WOLLARD BLVD RICHMOND, MO 640 85 TAX ID ...",
2508,183a723c366275c52a311db74d68402b,Malaysia,
2509,1d73f9dd5be62952c256d617d199aa24,8105 95TH STREET PLEASANT PRAIRIE WI 53158 UN...,


In [18]:
address_df[address_df.country.isna()].sample(1)

Unnamed: 0,organization_id,address,country
2229,1184d782d482ec68beddd3469ccad59e,China,


In [19]:
model_df = address_df[~address_df.country.isna()].reset_index(drop=True)

In [20]:
model_df

Unnamed: 0,organization_id,address,country
0,52c8642c34649eafe2d044eee3d884e1,"60 MARKET SAUQRE, BELIZE CITY, BELIZE",US
1,f580bc1756d06768c94634b0332e1871,"204, Monarch Chambers, Marol Maroshi Road, Mum...",IN
2,149081be00548b006dc38a88264eae32,"YUNG-SHI BLVD, SHI-WAN TOWN, PO-LO COUNTY HUI-...",CN
3,f4864ac3d5d716cc586d60afe8a403ef,FLOWCRETE INDIA PRIVATE LIMITEDNEW NO.36 OLD N...,IN
4,a26b29b47ac4c97f4285096343d85b79,"RUA DO RIBEIRINHO, 202 APARTADO 13 S.PAIO OL...",PT
...,...,...,...
1854,00750bb980757fd8535d73b90d6608c8,4518 HAMILTON AVENUE CLEVELAND OHIO 44114 USA ...,US
1855,2c398f54dacfc67b9fbe414565285dba,19320 REDWOOD RD. 44110-2799 CLEVELAND CLEVELAND,US
1856,2c398f54dacfc67b9fbe414565285dba,"19218 REDWOOD ROAD CLEVELAND, OH CLEVELANDOH ...",US
1857,1b4334d45be8824a1f6b071a9a2f75a8,United States,US


In [21]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', clean_up_tokenization_spaces=True)

In [22]:
model = AutoModel.from_pretrained('bert-base-uncased')

In [23]:
def get_embedding(text):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=64)
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the CLS token's hidden state as the sentence embedding
    return outputs.last_hidden_state[:, 0, :].squeeze().numpy()

In [24]:
embeddings = np.array([get_embedding(address) for address in tqdm(model_df['address'], desc="Extracting embeddings")])

Extracting embeddings: 100%|██████████| 1859/1859 [02:25<00:00, 12.80it/s]


In [25]:
embeddings.shape

(1859, 768)

In [26]:
pca = PCA(n_components=50)  # You can adjust n_components based on variance explained
embeddings_reduced = pca.fit_transform(embeddings)

In [27]:
embeddings_reduced.shape

(1859, 50)

In [28]:
le=LabelEncoder()

In [29]:
model_df['country_encoded'] = le.fit_transform(model_df['country'])

In [30]:
X_train, X_test, y_train, y_test = train_test_split(embeddings_reduced, model_df['country_encoded'], test_size=0.2, random_state=42)

In [95]:
clf = RandomForestClassifier(n_estimators=100, max_depth=5, min_samples_split=5, random_state=42)

In [96]:
clf.fit(X_train, y_train)

In [97]:
train_pred = clf.predict(X_train)
train_pred_labels = le.inverse_transform(train_pred)
train_labels = le.inverse_transform(y_train)
print(classification_report(train_labels, train_pred_labels))

              precision    recall  f1-score   support

          AE       0.00      0.00      0.00        17
          AR       0.00      0.00      0.00         6
          AT       0.00      0.00      0.00        11
          AU       0.00      0.00      0.00        15
          AZ       0.00      0.00      0.00         1
          BE       1.00      0.42      0.60        33
          BG       0.00      0.00      0.00         1
          BR       0.00      0.00      0.00         6
          BY       0.00      0.00      0.00         2
          CA       0.00      0.00      0.00        19
          CH       0.00      0.00      0.00        13
          CL       0.00      0.00      0.00         1
          CN       0.86      0.74      0.79       104
          CO       0.00      0.00      0.00         9
          CR       0.00      0.00      0.00         4
          CU       0.00      0.00      0.00         2
          CY       0.00      0.00      0.00         1
          CZ       0.00    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [98]:
y_pred = clf.predict(X_test)
y_pred_labels = le.inverse_transform(y_pred)
y_test_labels = le.inverse_transform(y_test)
print(classification_report(y_test_labels, y_pred_labels))

              precision    recall  f1-score   support

          AE       0.00      0.00      0.00         5
          AR       0.00      0.00      0.00         2
          AT       0.00      0.00      0.00         2
          AU       0.00      0.00      0.00         7
          BE       1.00      0.50      0.67         4
          BR       0.00      0.00      0.00         1
          CA       0.00      0.00      0.00         4
          CG       0.00      0.00      0.00         1
          CH       0.00      0.00      0.00        10
          CL       0.00      0.00      0.00         3
          CN       0.77      0.43      0.56        23
          CO       0.00      0.00      0.00         2
          CY       0.00      0.00      0.00         2
          CZ       0.00      0.00      0.00         3
          DE       0.79      0.52      0.63        21
          DK       0.00      0.00      0.00         3
          DO       0.00      0.00      0.00         2
          ES       0.00    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [38]:
def predict_country_code(address):
    embedding = get_embedding(address)
    embedding_reduced = pca.transform([embedding])  # Apply PCA transformation
    prediction = clf.predict(embedding_reduced)
    return le.inverse_transform(prediction)[0]

In [None]:
class EmbeddingExtractor(BaseEstimator, TransformerMixin):
    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = AuthTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)
        self.model = AutoModel.from_pretrained(model_name)

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        embeddings = []
        for text in tqdm(X, desc="Extracting embeddings"):
            inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=64)
            with torch.no_grad():
                outputs = self.model(**outputs)
            embeddings.append(outputs.last_hidden_state[:, 0, :].squeeze().numpy())
        return np.array(embeddings)

In [None]:
class AddressCountryClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, n_components=50, n_estimators=100)
        self.n_components = n_components
        self.n_estimators = n_estimators
        self.le = LabelEncoder()
        self.embeddings_model = None
        self.tokenizer = None
        self.pca = PCA(n_components = self.n_components)
        self.clf = RandomForestClassifier(n_estimators=self.n_estimators, random_state=42)

    def fit(self, X, y):
        self.le.fit(y)

        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.model = AutoModel.from_pretrained('bert-base-uncased')

        embeddings = np.array([self.get_embedding(address) for address in tqdm(X, desc="Extracting embeddings")])

        embeddings_reduced = self.pca.fit_transform(embeddings)

        self.clf.fit(embeddings_reduced, self.le.transform(y))

        return self

    def predict(self, X):
        embeddings = np.array([self.get_embedding(address) for address in tqdm(X, desc="Extracting embeddings")])
        embeddings_reduced = self.pca.transform(embeddings)

    

In [40]:
predict_country_code("Guanzhou")

'CN'

In [41]:
predict_country_code("Guanhou")

'CN'

In [42]:
predict_country_code("New York")

'US'

In [43]:
predict_country_code("United Kingdom")

'GB'

In [44]:
def infer_missing_country_codes(df):
    missing_mask = df['country'].isnull()  # Mask for missing country codes
    addresses_to_predict = df.loc[missing_mask, 'address']
    
    # Get embeddings and reduce dimensionality for missing addresses in bulk
    embeddings = np.array([get_embedding(address) for address in tqdm(addresses_to_predict, desc="Extracting embeddings"]))
    embeddings_reduced = pca.transform(embeddings)
    
    # Predict country codes in bulk
    predictions = clf.predict(embeddings_reduced)
    predicted_country_codes = le.inverse_transform(predictions)
    
    # Fill the missing country codes
    df.loc[missing_mask, 'country'] = predicted_country_codes
    
    return df

In [45]:
address_df_filled = infer_missing_country_codes(address_df)

In [80]:
address_df_filled.to_csv('./data/address.csv', index=False)

In [46]:
address_df_filled.country.isna().sum()

np.int64(0)

In [47]:
address_df_filled

Unnamed: 0,organization_id,address,country
0,52c8642c34649eafe2d044eee3d884e1,"60 MARKET SAUQRE, BELIZE CITY, BELIZE",US
1,831a31a1466f0ace3eb20b52d4575f92,"Plot O 356 &357, Sidco Indl Estate Ambattur Ch...",US
2,f580bc1756d06768c94634b0332e1871,"204, Monarch Chambers, Marol Maroshi Road, Mum...",IN
3,149081be00548b006dc38a88264eae32,"YUNG-SHI BLVD, SHI-WAN TOWN, PO-LO COUNTY HUI-...",CN
4,f4864ac3d5d716cc586d60afe8a403ef,FLOWCRETE INDIA PRIVATE LIMITEDNEW NO.36 OLD N...,IN
...,...,...,...
2541,00750bb980757fd8535d73b90d6608c8,4518 HAMILTON AVENUE CLEVELAND OHIO 44114 USA ...,US
2542,2c398f54dacfc67b9fbe414565285dba,19320 REDWOOD RD. 44110-2799 CLEVELAND CLEVELAND,US
2543,2c398f54dacfc67b9fbe414565285dba,"19218 REDWOOD ROAD CLEVELAND, OH CLEVELANDOH ...",US
2544,1b4334d45be8824a1f6b071a9a2f75a8,United States,US


In [65]:
site_df = address_df_filled[['organization_id', 'country']].drop_duplicates()

In [66]:
site_df.insert(0, '_key', [uuid.uuid4().hex for _ in range(len(site_df))])

In [68]:
site_df['organization_id'] = 'organizations/' + site_df['organization_id']

In [69]:
site_df

Unnamed: 0,_key,organization_id,country
0,5f5d506385474ebf9cf2c97e748dd07b,organizations/52c8642c34649eafe2d044eee3d884e1,US
1,54c78f96680e4b7b912a28ba11ddc0eb,organizations/831a31a1466f0ace3eb20b52d4575f92,US
2,8d6cddc3df8040b3b15f6eb78b419edd,organizations/f580bc1756d06768c94634b0332e1871,IN
3,093b7092701f43fbaad91d6faf8a20e9,organizations/149081be00548b006dc38a88264eae32,CN
4,8853aa0467ee4d1c97bbf341c4a1927c,organizations/f4864ac3d5d716cc586d60afe8a403ef,IN
...,...,...,...
2522,2be5373e281249ae851b80111d3e7579,organizations/2fc4208e46b525d0b90e4662a7d7e2f4,US
2524,fcdc3607e8e6491c83f6a71362b215c1,organizations/d3a2485b31c1e99218252f115a29d02e,US
2525,972e3126b93a40d59f6fb42047a0baa0,organizations/ba616735135653cb6ff28ca5856e6599,US
2539,034a81c43ff0453dac52d247f3cbd226,organizations/68b95e65719c3deb6c32fda273487981,US


In [52]:
client = ArangoClient(hosts="https://2ae4f052d710.arangodb.cloud:8529")

db = client.db('machine_learning', username='lab_test', password='z-rRdN-Enf4qQwybGiVdbG')

In [74]:
# db.delete_collection('sites')

True

In [75]:
if not db.has_collection('sites'):
    sites = db.create_collection('sites')
else:
    sites = db.collection('sites')

In [76]:
sites_list = site_df.to_dict(orient='records')

In [78]:
# sites_list
sites.insert_many(sites_list, overwrite_mode="replace", silent=True)

True