In [1]:
import uuid
import joblib
import pandas as pd
from arango import ArangoClient
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

In [2]:
raw_df = pd.read_csv('../data/raw.csv')

In [3]:
raw_df.head()

Unnamed: 0,country_of_destination,country_of_origin,hs6_code,receiver_address,receiver_id,receiver_name,sender_address,sender_id,sender_name,tier,transaction_date
0,CA,US,321410,"95 SUNRAY STREET, WHITBY, ON, L1N9C9, CANADA T...",4bd999c0cb4346ce047f99a692ec3191,Stonhard Ltd,"60 MARKET SAUQRE, BELIZE CITY, BELIZE",52c8642c34649eafe2d044eee3d884e1,Global Fire Protection,t1->t0,11/30/18
1,,,320890,China,8b1f537a8eeca839cff8a99e08b35690,CARBOLINE DALIAN PAINT COMPANY LTD,"Plot O 356 &357, Sidco Indl Estate Ambattur Ch...",831a31a1466f0ace3eb20b52d4575f92,Carboline (India) Private Limited,t1->t0,5/2/18
2,US,IN,390940,United States,1670c05b29762f5d4ab8980d267bd482,Rust-Oleum Llc,"204, Monarch Chambers, Marol Maroshi Road, Mum...",f580bc1756d06768c94634b0332e1871,Paladin Paints & Chemicals Private Limited,t1->t0,5/2/17
3,US,CN,731100,OF RPM WOOD FINISHES GROUP 3190 HICKORY BLVD H...,210f645e15d75833b25090f2ae509eda,Mohawk Finishing Products Inc,"YUNG-SHI BLVD, SHI-WAN TOWN, PO-LO COUNTY HUI-...",149081be00548b006dc38a88264eae32,Alpha Pacific Group Pte. Ltd.,t1->t0,5/8/16
4,CH,IN,250590,Switzerland,7150428f8927c86ac11c64c1bb9b0285,Vandex International AG,FLOWCRETE INDIA PRIVATE LIMITEDNEW NO.36 OLD N...,f4864ac3d5d716cc586d60afe8a403ef,TREMCO CPG INDIA PRIVATE LIMITED,t1->t0,5/26/17


In [4]:
sender_df = raw_df[['sender_id', 'sender_address', 'country_of_origin']].copy()
sender_df.rename(columns={'sender_id':'organization_id', 'sender_address':'address', 'country_of_origin':'country'}, inplace=True)

In [5]:
receiver_df = raw_df[['receiver_id', 'receiver_address', 'country_of_destination']].copy()
receiver_df.rename(columns={'receiver_id':'organization_id', 'receiver_address':'address', 'country_of_destination':'country'}, inplace=True)

In [6]:
address_df = pd.concat([sender_df, receiver_df], ignore_index=True).drop_duplicates(ignore_index=True)

In [7]:
address_df.describe()

Unnamed: 0,organization_id,address,country
count,2546,2546,1859
unique,2072,1758,85
top,00750bb980757fd8535d73b90d6608c8,UNKNOWN ADDRESS,US
freq,16,171,545


In [8]:
address_df.shape

(2546, 3)

In [9]:
address_df.head()

Unnamed: 0,organization_id,address,country
0,52c8642c34649eafe2d044eee3d884e1,"60 MARKET SAUQRE, BELIZE CITY, BELIZE",US
1,831a31a1466f0ace3eb20b52d4575f92,"Plot O 356 &357, Sidco Indl Estate Ambattur Ch...",
2,f580bc1756d06768c94634b0332e1871,"204, Monarch Chambers, Marol Maroshi Road, Mum...",IN
3,149081be00548b006dc38a88264eae32,"YUNG-SHI BLVD, SHI-WAN TOWN, PO-LO COUNTY HUI-...",CN
4,f4864ac3d5d716cc586d60afe8a403ef,FLOWCRETE INDIA PRIVATE LIMITEDNEW NO.36 OLD N...,IN


In [10]:
address_df.country.value_counts()

country
US    545
GB    139
CN    127
IT    123
DE    104
     ... 
DJ      1
IR      1
SZ      1
ET      1
GT      1
Name: count, Length: 85, dtype: int64

In [11]:
address_df[address_df.country.isna()].sample(1)

Unnamed: 0,organization_id,address,country
1049,897c3b85769b28051a7253dec2a18325,UNKNOWN ADDRESS,


In [12]:
model_df = address_df[~address_df.country.isna()].reset_index(drop=True)

In [15]:
%%writefile ../artifacts/country_classifier_classes_1.py

import torch
import numpy as np
from tqdm import tqdm
from collections import Counter
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from transformers import pipeline, AutoTokenizer, AutoModel

class TopNClassTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, top_n=10, other_label='Other'):
        self.top_n = top_n
        self.other_label = other_label
        self.top_classes_ = None

    def fit(self, y):
        class_counts = Counter(y)
        self.top_classes_ = [cls for cls, _ in class_counts.most_common(self.top_n)]
        return self

    def transform(self, y):
        return np.where(np.isin(y, self.top_classes_), y, 'Other') 

    def fit_transform(self, y):
        self.fit(y)
        return self.transform(y)

class AddressCountryClassifier(BaseEstimator, ClassifierMixin):
    
    def __init__(self, n_components=100, n_estimators=1000):
        self.n_components = n_components
        self.n_estimators = n_estimators
        self.tnct = TopNClassTransformer()
        self.le = LabelEncoder()
        self.embeddings_model = None
        self.tokenizer = None
        self.pca = PCA(n_components = self.n_components)
        self.rfc = RandomForestClassifier(n_estimators=self.n_estimators, max_depth=7, min_samples_split=8, min_samples_leaf=4, class_weight="balanced_subsample", max_features = "log2", random_state=42)

    def fit(self, X, y):
        y_contained = self.tnct.fit_transform(y)
        y_transformed = self.le.fit_transform(y_contained)

        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', clean_up_tokenization_spaces=True)
        self.model = AutoModel.from_pretrained('bert-base-uncased')

        embeddings = np.array([self.get_embedding(address) for address in tqdm(X, desc="Extracting embeddings")])
        embeddings_reduced = self.pca.fit_transform(embeddings)

        self.rfc.fit(embeddings_reduced, y_transformed)

        return self

    def predict(self, X):
        embeddings = np.array([self.get_embedding(address) for address in tqdm(X, desc="Extracting embeddings")])
        embeddings_reduced = self.pca.transform(embeddings)
        
        y_pred = self.rfc.predict(embeddings_reduced)
        
        return self.le.inverse_transform(y_pred).tolist()

    def get_embedding(self, text):
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=64)
        
        with torch.no_grad():
            outputs = self.model(**inputs)
        
        return outputs.last_hidden_state[:, 0, :].squeeze().numpy()

Overwriting ../artifacts/country_classifier_classes_1.py


In [16]:
%run ../artifacts/country_classifier_classes_1.py

In [15]:
clf = AddressCountryClassifier()

X_train, X_test, y_train, y_test = train_test_split(model_df.address, model_df.country, test_size=0.2, random_state=42)

clf.fit(X_train, y_train)

Extracting embeddings: 100%|██████████| 1487/1487 [05:30<00:00,  4.50it/s]


In [16]:
train_pred = clf.predict(X_train)
print(classification_report(clf.tnct.transform(y_train), train_pred))

Extracting embeddings: 100%|██████████| 1487/1487 [01:46<00:00, 14.02it/s]


              precision    recall  f1-score   support

          BE       1.00      1.00      1.00        33
          CN       0.79      0.99      0.88       104
          DE       0.82      0.99      0.90        83
          FR       0.91      1.00      0.95        31
          GB       0.79      0.98      0.87       121
          IN       0.89      1.00      0.94        62
          IT       0.96      0.99      0.97        88
          MX       0.86      0.96      0.91        53
          NL       0.95      1.00      0.98        59
       Other       0.97      0.86      0.91       409
          US       0.97      0.88      0.92       444

    accuracy                           0.92      1487
   macro avg       0.90      0.97      0.93      1487
weighted avg       0.93      0.92      0.92      1487



In [17]:
y_pred = clf.predict(X_test)
print(classification_report(clf.tnct.transform(y_test), y_pred))

Extracting embeddings: 100%|██████████| 372/372 [00:25<00:00, 14.76it/s]


              precision    recall  f1-score   support

          BE       1.00      0.50      0.67         4
          CN       0.67      0.96      0.79        23
          DE       0.64      0.76      0.70        21
          FR       1.00      0.71      0.83         7
          GB       0.64      0.89      0.74        18
          IN       0.95      0.95      0.95        19
          IT       0.95      1.00      0.97        35
          MX       0.69      0.85      0.76        13
          NL       0.73      0.92      0.81        12
       Other       0.91      0.73      0.81       119
          US       0.87      0.85      0.86       101

    accuracy                           0.83       372
   macro avg       0.82      0.83      0.81       372
weighted avg       0.85      0.83      0.83       372



In [18]:
clf.predict(["Guanzhou", "Guanzou", "NY", "Arlington, VA"])

Extracting embeddings: 100%|██████████| 4/4 [00:00<00:00, 19.61it/s]


['CN', 'CN', 'US', 'US']

In [19]:
def infer_missing_country_codes(dataframe, classifier):
    missing_mask = dataframe['country'].isnull()  # Mask for missing country codes
    
    addresses_to_predict = dataframe.loc[missing_mask, 'address']
    predictions = classifier.predict(addresses_to_predict)

    dataframe.loc[missing_mask, 'country'] = predictions
    return dataframe

In [20]:
address_df_filled = infer_missing_country_codes(address_df, clf)

Extracting embeddings: 100%|██████████| 687/687 [00:35<00:00, 19.35it/s]


In [21]:
address_df['country'].isna().sum()

np.int64(0)

In [22]:
address_df_filled.country.isna().sum()

np.int64(0)

In [23]:
address_df_filled

Unnamed: 0,organization_id,address,country
0,52c8642c34649eafe2d044eee3d884e1,"60 MARKET SAUQRE, BELIZE CITY, BELIZE",US
1,831a31a1466f0ace3eb20b52d4575f92,"Plot O 356 &357, Sidco Indl Estate Ambattur Ch...",Other
2,f580bc1756d06768c94634b0332e1871,"204, Monarch Chambers, Marol Maroshi Road, Mum...",IN
3,149081be00548b006dc38a88264eae32,"YUNG-SHI BLVD, SHI-WAN TOWN, PO-LO COUNTY HUI-...",CN
4,f4864ac3d5d716cc586d60afe8a403ef,FLOWCRETE INDIA PRIVATE LIMITEDNEW NO.36 OLD N...,IN
...,...,...,...
2541,00750bb980757fd8535d73b90d6608c8,4518 HAMILTON AVENUE CLEVELAND OHIO 44114 USA ...,US
2542,2c398f54dacfc67b9fbe414565285dba,19320 REDWOOD RD. 44110-2799 CLEVELAND CLEVELAND,US
2543,2c398f54dacfc67b9fbe414565285dba,"19218 REDWOOD ROAD CLEVELAND, OH CLEVELANDOH ...",US
2544,1b4334d45be8824a1f6b071a9a2f75a8,United States,US


In [26]:
joblib.dump(clf, '../artifacts/country_classifier_pipeline_1.joblib')

['./artifacts/country_classifier_pipeline_1.joblib']

In [18]:
# address_df_filled.to_csv('./data/address.csv', index=False)
address_df_filled = pd.read_csv('../data/address.csv')

# Country collection

In [19]:
address_df_filled.country.unique()

array(['US', 'Other', 'IN', 'CN', 'PT', 'FR', 'ZA', 'MX', 'DE', 'TW',
       'AE', 'MY', 'CA', 'KZ', 'IT', 'NP', 'KR', 'BE', 'JP', 'AU', 'PL',
       'DO', 'NL', 'TH', 'CG', 'GB', 'HK', 'TJ', 'CR', 'SG', 'NZ', 'NO',
       'SK', 'IR', 'PH', 'VN', 'CL', 'RO', 'IL', 'ES', 'DK', 'TR', 'TT',
       'LI', 'CO', 'CH', 'FI', 'HU', 'SI', 'ID', 'AR', 'SE', 'BR', 'ET',
       'HR', 'IE', 'LR', 'BY', 'GY', 'CZ', 'AT', 'SA', 'RU', 'CU', 'SZ',
       'SV', 'PA', 'DJ', 'SN', 'RS', 'AZ', 'UY', 'SD', 'LU', 'EE', 'MN',
       'SS', 'CY', 'MT', 'TM', 'SM', 'EC', 'NG', 'BG', 'TZ', 'GT'],
      dtype=object)

In [20]:
client = ArangoClient(hosts="https://2ae4f052d710.arangodb.cloud:8529")

db = client.db('machine_learning', username='lab_test', password='z-rRdN-Enf4qQwybGiVdbG')

In [21]:
if not db.has_collection('countries'):
    countries = db.create_collection('countries')
else:
    countries = db.collection('countries')

In [22]:
countries_list = [{'_key': code} for code in address_df_filled.country.unique()]

In [23]:
countries.insert_many(countries_list)

[{'_id': 'countries/US', '_key': 'US', '_rev': '_ihFCKhG---'},
 {'_id': 'countries/Other', '_key': 'Other', '_rev': '_ihFCKhG--_'},
 {'_id': 'countries/IN', '_key': 'IN', '_rev': '_ihFCKhG--A'},
 {'_id': 'countries/CN', '_key': 'CN', '_rev': '_ihFCKhG--B'},
 {'_id': 'countries/PT', '_key': 'PT', '_rev': '_ihFCKhG--C'},
 {'_id': 'countries/FR', '_key': 'FR', '_rev': '_ihFCKhG--D'},
 {'_id': 'countries/ZA', '_key': 'ZA', '_rev': '_ihFCKhG--E'},
 {'_id': 'countries/MX', '_key': 'MX', '_rev': '_ihFCKhG--F'},
 {'_id': 'countries/DE', '_key': 'DE', '_rev': '_ihFCKhG--G'},
 {'_id': 'countries/TW', '_key': 'TW', '_rev': '_ihFCKhG--H'},
 {'_id': 'countries/AE', '_key': 'AE', '_rev': '_ihFCKhG--I'},
 {'_id': 'countries/MY', '_key': 'MY', '_rev': '_ihFCKhG--J'},
 {'_id': 'countries/CA', '_key': 'CA', '_rev': '_ihFCKhG--K'},
 {'_id': 'countries/KZ', '_key': 'KZ', '_rev': '_ihFCKhG--L'},
 {'_id': 'countries/IT', '_key': 'IT', '_rev': '_ihFCKhG--M'},
 {'_id': 'countries/NP', '_key': 'NP', '_rev': '_

# Sites collection

In [21]:
address_df_filled['country_id'] = 'countries/' + address_df_filled['country']

In [22]:
site_df = address_df_filled[['organization_id', 'country_id']].drop_duplicates()

In [23]:
site_df.insert(0, '_key', [uuid.uuid4().hex for _ in range(len(site_df))])

In [24]:
site_df['organization_id'] = 'organizations/' + site_df['organization_id']

In [25]:
site_df

Unnamed: 0,_key,organization_id,country_id
0,18857ddacc1f473bb65d0a42030601a6,organizations/52c8642c34649eafe2d044eee3d884e1,countries/US
1,8659f74cc67d4c53b80693eaae86c6f1,organizations/831a31a1466f0ace3eb20b52d4575f92,countries/Other
2,14c27d3af19a40259c5fbea6c774fd00,organizations/f580bc1756d06768c94634b0332e1871,countries/IN
3,540ee69808d3411a8bed3bd349d3f6a4,organizations/149081be00548b006dc38a88264eae32,countries/CN
4,52b076a7d8c74afaabba4f40af58bf0e,organizations/f4864ac3d5d716cc586d60afe8a403ef,countries/IN
...,...,...,...
2522,01eecdce31a140ba8b35bcd2b7d04772,organizations/2fc4208e46b525d0b90e4662a7d7e2f4,countries/US
2524,33cb29307f924b32b555c60b5080da3b,organizations/d3a2485b31c1e99218252f115a29d02e,countries/US
2525,5273ec48ebb14e289c51cb3c42379b29,organizations/ba616735135653cb6ff28ca5856e6599,countries/US
2539,3d1862033760464f98b728b174b9c6c0,organizations/68b95e65719c3deb6c32fda273487981,countries/US


In [26]:
db.delete_collection('sites')

True

In [27]:
if not db.has_collection('sites'):
    sites = db.create_collection('sites')
else:
    sites = db.collection('sites')

In [28]:
sites_list = site_df.to_dict(orient='records')

In [29]:
# sites_list
sites.insert_many(sites_list, overwrite_mode="replace")

[{'_id': 'sites/18857ddacc1f473bb65d0a42030601a6',
  '_key': '18857ddacc1f473bb65d0a42030601a6',
  '_rev': '_igxsi8a---'},
 {'_id': 'sites/8659f74cc67d4c53b80693eaae86c6f1',
  '_key': '8659f74cc67d4c53b80693eaae86c6f1',
  '_rev': '_igxsi8a--_'},
 {'_id': 'sites/14c27d3af19a40259c5fbea6c774fd00',
  '_key': '14c27d3af19a40259c5fbea6c774fd00',
  '_rev': '_igxsi8a--A'},
 {'_id': 'sites/540ee69808d3411a8bed3bd349d3f6a4',
  '_key': '540ee69808d3411a8bed3bd349d3f6a4',
  '_rev': '_igxsi8a--B'},
 {'_id': 'sites/52b076a7d8c74afaabba4f40af58bf0e',
  '_key': '52b076a7d8c74afaabba4f40af58bf0e',
  '_rev': '_igxsi8a--C'},
 {'_id': 'sites/cc1f0ee8f269425398d7724f37156fb3',
  '_key': 'cc1f0ee8f269425398d7724f37156fb3',
  '_rev': '_igxsi8a--D'},
 {'_id': 'sites/798de52a12224634877eeeafadb6be6a',
  '_key': '798de52a12224634877eeeafadb6be6a',
  '_rev': '_igxsi8a--E'},
 {'_id': 'sites/20d590d0f0fb4b35842d8942b8a57063',
  '_key': '20d590d0f0fb4b35842d8942b8a57063',
  '_rev': '_igxsi8a--F'},
 {'_id': 'sites/