In [1]:
import os
import numpy as np
import pandas as pd
from unidecode import unidecode
import spacy
import re
from tqdm import tqdm
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
from sklearn.decomposition import TruncatedSVD
import cloudpickle

In [2]:
def get_csv_test_data(path: str):
    '''
    Load a comma-separated file into a DataFrame.

    path: location (or file-like object) handed straight to pandas.read_csv.
    Returns the parsed DataFrame unchanged.
    '''
    return pd.read_csv(path, sep=",")

In [3]:
# Load the 'en_core_web_sm' spaCy pipeline once at module level; lemmatize() reuses it on every call.
nlp = spacy.load('en_core_web_sm')
def lemmatize(s: str):
    '''
    Return the lemma of a tag.

    The input is converted with str() first. Values of two characters or
    fewer are returned as that string without touching spaCy; anything longer
    is run through the module-level ``nlp`` pipeline and the lemma of the
    FIRST token only is returned.
    '''
    text = str(s)
    # Very short tags are passed through untouched.
    if len(text) <= 2:
        return text
    first_token = nlp(text)[0]
    return str(first_token.lemma_)

In [4]:
def read_acronym_list():
    '''
    Load the acronym replacement table as a plain dict.

    Reads ../utils/acronym_replace.csv, indexes it by its 'from' column,
    transposes it and returns the first row of the transposed table as a
    dict, i.e. {<'from' value>: <value in the first remaining column>}.
    '''
    csv_path = os.path.join("..", 'utils', 'acronym_replace.csv')
    transposed = pd.read_csv(csv_path).set_index('from').T
    return transposed.to_dict(orient='records')[0]

In [5]:
# Acronym lookup keyed by the CSV's 'from' column; loaded once here so hashtag_clean() does not re-read the file per call.
acronym_dict = read_acronym_list()
def hashtag_clean(s: str):
    '''
    Normalise a raw hashtag string into a ", "-joined string of unique tags.

    The text is transliterated with unidecode, lower-cased, and stripped of
    commas, line breaks, tabs, underscores, non-breaking spaces and dots;
    spaces around "#" are removed. Every piece that follows a "#" is replaced
    by its ``acronym_dict`` entry when one exists and then passed through
    ``lemmatize``.

    Parameters
    ----------
    s : str
        Raw hashtag text, e.g. "#AI, #Data". Anything that is not a string
        (for example the NaN pandas produces for an empty cell, or None) is
        treated as "no hashtags" instead of raising.

    Returns
    -------
    str
        The unique cleaned tags joined by ", ", in order of first appearance
        (deterministic across runs); "" when the input holds no hashtags.
    '''
    # Empty CSV cells arrive as float NaN; unidecode()/lower() would raise on them.
    if not isinstance(s, str):
        return ""

    s = unidecode(s).lower()
    # Applied in this exact order because earlier replacements change what the
    # later ones see (e.g. ", " has to be removed before the bare ",").
    for old, new in (
        (", ", ""),
        (",", ""),
        (" # ", "#"),
        (" #", "#"),
        ("# ", "#"),
        ("\n", ""),
        ("\t", ""),
        ("\r", ""),
        ("_", ""),
        ("\xa0", ""),
        (".", ""),
    ):
        s = s.replace(old, new)

    tokens = []
    # Whatever precedes the first "#" is not a hashtag, hence the [1:].
    for token in s.split('#')[1:]:
        replacement = acronym_dict.get(token)
        if replacement is not None:
            token = replacement
        tokens.append(lemmatize(token))

    # dict.fromkeys de-duplicates while keeping first-occurrence order; a set
    # would make the order of the joined string vary between interpreter runs.
    return ", ".join(dict.fromkeys(tokens))

In [6]:
def clean_test_hashtag_column(tags: pd.Series, threshold: float=0.45):
    '''
    Apply hashtag_clean to every entry of a hashtag column.

    tags: Series of raw hashtag strings.
    threshold: accepted for interface compatibility; this function does not
        use it.
    Returns a Series of the same length holding the cleaned hashtag strings.
    '''
    return tags.map(hashtag_clean)

In [7]:
# Tag names saved in ../utils/training_data_tags.csv, read back as a plain Python list;
# spread_test_tags() creates one column per entry.
training_data_tags = list(
    pd.read_csv(os.path.join("..", 'utils', "training_data_tags.csv")).squeeze("columns")
)
training_data_tags

['deeplearne',
 'blockchain',
 'datum',
 'design',
 'machinelearne',
 'bitcoin',
 'code',
 'artificialintelligence',
 'product',
 'bigdata',
 'businessowner',
 'other_tag',
 'entrepreneur',
 'datascience',
 'programming',
 'crypto',
 'cryptocurrency',
 'business',
 'startup',
 'iot',
 'technology',
 'development',
 'ethereum']

In [8]:
def spread_test_tags(data: pd.DataFrame):
    '''
    Expand the 'clean_hashtags' column into one column per training tag.

    Works on a copy with a fresh 0..n-1 index, so the caller's frame is left
    untouched. Every name in ``training_data_tags`` becomes a column that
    starts at 0. For each ", "-separated tag of a row: a tag known from
    training sets its column to 1, any other tag adds 1 to 'other_tag'
    (which therefore has to be one of the training tags). Finally every NaN
    in the whole frame is replaced by 0.

    NOTE(review): an empty 'clean_hashtags' value splits into [''] and is
    therefore counted as one 'other_tag' -- confirm this matches how the
    training features were built before changing it.
    '''
    spread = data.reset_index(drop=True).copy()
    for tag_column in training_data_tags:
        spread[tag_column] = 0

    hashtags = spread['clean_hashtags']
    for position in tqdm(range(len(hashtags))):
        for tag in hashtags[position].split(", "):
            if tag in training_data_tags:
                spread.loc[position, tag] = 1
            else:
                spread.loc[position, 'other_tag'] += 1

    return spread.fillna(0)

In [9]:
def prepare_test_data(path: str):
    '''
    Turn the raw test CSV into model inputs.

    Steps: read the CSV at ``path``, keep a copy of its 'Likes' column as the
    target, clean the 'Hashtags' column into 'clean_hashtags', spread the tags
    into indicator columns, then keep only the columns listed in
    ../utils/training_features.csv (in that file's order).

    Returns (x, y): the feature DataFrame and the 'Likes' Series.
    '''
    frame = get_csv_test_data(path)
    y = frame['Likes'].copy()

    frame['clean_hashtags'] = clean_test_hashtag_column(frame['Hashtags'])
    frame = spread_test_tags(frame)

    feature_names = pd.read_csv(os.path.join("..", 'utils', 'training_features.csv')).squeeze('columns')
    x = frame[list(feature_names)].copy()
    return x, y

In [10]:
# Build the feature matrix and the 'Likes' target from the raw test CSV.
x_test, y_test = prepare_test_data(os.path.join("..", 'data', "test_data.csv"))

100%|████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 1116.84it/s]


In [11]:
# Save features and target side by side in one CSV (index column omitted).
pd.concat([x_test, y_test], axis= 1).to_csv(os.path.join("..", "data", "test_preprocessed.csv"), index=False)

In [12]:
# Sanity check: features and target should have the same number of rows.
x_test.shape, y_test.shape

((10, 24), (10,))

In [13]:
def predict(x_test: pd.DataFrame, model_path: str=os.path.join("..", 'models', "best_model.pkl")):
    '''
    Score ``x_test`` with a pickled model.

    x_test: feature frame passed unchanged to the model's predict().
    model_path: file holding the serialized model; defaults to
        ../models/best_model.pkl.
    Returns whatever the loaded model's predict() returns.
    '''
    # NOTE(review): unpickling can execute arbitrary code -- only point
    # model_path at model files from a trusted source.
    with open(model_path, "rb") as model_file:
        model = cloudpickle.load(model_file)
    return model.predict(x_test)

In [14]:
# Score the preprocessed test features with the default saved model (../models/best_model.pkl).
pred = predict(x_test)

In [15]:
# Show the predicted values.
pred

array([27.89148679, 27.40785436, 43.79557062, 27.53590198, 27.77690198,
       53.08459714, 43.79557062, 29.63899406, 27.77690198, 52.16938022])

In [16]:
# Show the true 'Likes' values for comparison with the predictions.
y_test

0     14
1     24
2     21
3     31
4     16
5    136
6     20
7     28
8     31
9    139
Name: Likes, dtype: int64