## Подключение к базе и таблицы с юзерами и постами

In [1]:
! pip3 install psycopg2-binary

Collecting psycopg2-binary
  Downloading psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.4 kB)
Downloading psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m41.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: psycopg2-binary
Successfully installed psycopg2-binary-2.9.9


In [2]:
# Build connection_path from a local config file so that the DB login and password
# never appear explicitly in the notebook code.
# Expected file format: one `key=value` pair per line with the keys
# username, password, host, port and database.

config_file = "config.txt"
with open(config_file, "r") as f:
    config_data = f.readlines()

config = {}
for line in config_data:
    line = line.strip()
    if not line:
        # Tolerate blank lines (e.g. an empty line at the end of the file).
        continue
    # Split on the first "=" only: a value such as a password may itself contain "=".
    key, value = line.split("=", 1)
    # Whitespace around "=" must not leak into the key (it would cause a KeyError below) or the value.
    config[key.strip()] = value.strip()

# NOTE(review): the values are inserted into the URL verbatim, so credentials containing
# URL-reserved characters ("@", ":", "/") must already be percent-encoded in config.txt.
connection_path = f"postgresql://{config['username']}:{config['password']}@{config['host']}:{config['port']}/{config['database']}"

In [3]:
from sqlalchemy import create_engine

# One engine for the whole notebook, built from the URL assembled above.
engine = create_engine(connection_path)

# Open a connection first, then switch it to streamed results so that large
# query results are not necessarily buffered in full by the driver.
connection = engine.connect()
connection = connection.execution_options(stream_results=True)

## Работа с данными по постам. Создадим несколько дополнительных признаков.

In [4]:
# Посты и топики

import pandas as pd


# Load the complete posts table (post_id, text, topic) in a single query.
posts_info = pd.read_sql(sql="""SELECT * FROM public.post_text_df""", con=connection)

posts_info

Unnamed: 0,post_id,text,topic
0,1,UK economy facing major risks\n\nThe UK manufa...,business
1,2,Aids and climate top Davos agenda\n\nClimate c...,business
2,3,Asian quake hits European shares\n\nShares in ...,business
3,4,India power shares jump on debut\n\nShares in ...,business
4,5,Lacroix label bought by US firm\n\nLuxury good...,business
...,...,...,...
7018,7315,"OK, I would not normally watch a Farrelly brot...",movie
7019,7316,I give this movie 2 stars purely because of it...,movie
7020,7317,I cant believe this film was allowed to be mad...,movie
7021,7318,The version I saw of this film was the Blockbu...,movie


In [5]:
# Сделаем эмбеддинги постов с помощью одного из трансформеров

from transformers import AutoTokenizer
from transformers import BertModel  # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel
from transformers import RobertaModel  # https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel
from transformers import DistilBertModel  # https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel


def get_model(model_name):
    """Load the tokenizer and the pretrained encoder for a supported model family.

    Args:
        model_name: one of 'bert', 'roberta' or 'distilbert'; any other value
            fails the assertion below.

    Returns:
        A ``(tokenizer, model)`` tuple, both loaded from the same checkpoint.
    """
    assert model_name in ['bert', 'roberta', 'distilbert']

    # Family name -> (checkpoint id, model class used to load that checkpoint).
    registry = {
        'bert': ('bert-base-cased', BertModel),  # https://huggingface.co/bert-base-cased
        'roberta': ('roberta-base', RobertaModel),  # https://huggingface.co/roberta-base
        'distilbert': ('distilbert-base-cased', DistilBertModel),  # https://huggingface.co/distilbert-base-cased
    }
    checkpoint, model_class = registry[model_name]

    # Tokenizer first, then the weights (same order as the original tuple expression).
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    pretrained_model = model_class.from_pretrained(checkpoint)
    return tokenizer, pretrained_model

In [6]:
# DistilBERT ('distilbert-base-cased') is the encoder used for the post embeddings.
tokenizer, model = get_model(model_name='distilbert')

tokenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/263M [00:00<?, ?B/s]

In [7]:
# Сделаем датасет для постов

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding


class PostDataset(Dataset):
    """Dataset over a list of texts that are tokenized once, at construction time.

    Each item is a dict holding the ``input_ids`` and ``attention_mask`` of one text.
    """

    def __init__(self, texts, tokenizer):
        super().__init__()

        self.tokenizer = tokenizer

        # All texts are encoded in a single call; truncation + padding together with
        # return_tensors='pt' give one rectangular tensor per field.
        encode_options = dict(
            add_special_tokens=True,
            return_token_type_ids=False,
            return_tensors='pt',
            truncation=True,
            padding=True,
        )
        self.texts = tokenizer.batch_encode_plus(texts, **encode_options)

    def __len__(self):
        # One row of input_ids per text.
        return len(self.texts['input_ids'])

    def __getitem__(self, idx):
        encoded = self.texts
        return {field: encoded[field][idx] for field in ('input_ids', 'attention_mask')}
    
    
# The collator turns a list of per-text dicts into one batch, padding shorter texts
# so that every sequence in the batch has the same length.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Tokenize every post text up front.
dataset = PostDataset(posts_info['text'].values.tolist(), tokenizer)

# shuffle=False keeps the batches (and therefore the embeddings) in posts_info row order.
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    pin_memory=True,
    collate_fn=data_collator,
)

2024-03-29 09:23:57.036440: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-29 09:23:57.036570: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-29 09:23:57.187384: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [8]:
import torch
from tqdm import tqdm


@torch.inference_mode()
def get_embeddings_labels(model, loader):
    """Embed every batch produced by ``loader`` and return the result as one CPU tensor.

    Despite the name, only embeddings are returned; no labels are read or produced.

    Args:
        model: a ``torch.nn.Module`` called as ``model(attention_mask=..., input_ids=...)``
            whose output exposes ``['last_hidden_state']``. It is switched to eval mode.
        loader: an iterable of batches; each batch is a mapping containing at least
            the tensors ``'attention_mask'`` and ``'input_ids'``.

    Returns:
        A CPU tensor with one row per input sequence: the hidden state of the first
        token of each sequence, concatenated over all batches in iteration order.
    """
    model.eval()

    # Use the device the model already lives on instead of depending on a notebook-level
    # ``device`` variable having been defined before this function is called.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    total_embeddings = []

    for batch in tqdm(loader):
        # Only these two fields are forwarded to the model; any other keys are ignored.
        batch = {key: batch[key].to(model_device) for key in ['attention_mask', 'input_ids']}

        # [:, 0, :] takes the first token of every sequence. That position usually holds the
        # special start-of-sequence token (e.g. [CLS] in BERT), whose embedding serves as a
        # compressed representation of the whole text for classification, regression and
        # other NLP tasks.
        embeddings = model(**batch)['last_hidden_state'][:, 0, :]

        # Move each batch off the accelerator right away so its memory can be reused.
        total_embeddings.append(embeddings.cpu())

    return torch.cat(total_embeddings, dim=0)

In [9]:
# Prefer the first CUDA device and fall back to the CPU when CUDA is not available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(device)
# The GPU name can only be queried when CUDA is present; calling it unconditionally
# would make the CPU fallback above crash on this line.
if device.type == 'cuda':
    print(torch.cuda.get_device_name())

model = model.to(device)

cuda:0
Tesla T4


In [10]:
# Embed all posts, then expose the resulting CPU tensor as a NumPy array for scikit-learn.
embeddings_tensor = get_embeddings_labels(model, loader)
embeddings = embeddings_tensor.numpy()

embeddings

100%|██████████| 220/220 [01:52<00:00,  1.96it/s]


array([[ 3.6315086e-01,  4.8937496e-02, -2.6408118e-01, ...,
        -1.4159346e-01,  1.5918216e-02,  9.1982896e-05],
       [ 2.3641640e-01, -1.5950108e-01, -3.2779828e-01, ...,
        -2.8993604e-01,  1.1936528e-01, -1.6235473e-03],
       [ 3.7519148e-01, -1.1394388e-01, -2.4054705e-01, ...,
        -3.3891949e-01,  5.8694065e-02, -2.1265799e-02],
       ...,
       [ 3.4038273e-01,  6.6492192e-02, -1.6318429e-01, ...,
        -8.6562753e-02,  2.0340374e-01,  3.2090571e-02],
       [ 4.3209219e-01,  1.1091532e-02, -1.1730607e-01, ...,
         7.5401559e-02,  1.0273975e-01,  1.5274222e-02],
       [ 3.0427766e-01, -7.6215670e-02, -6.7758739e-02, ...,
        -5.4348916e-02,  2.4438348e-01, -1.4148588e-02]], dtype=float32)

In [11]:
# Кластеризуем тексты

from sklearn.decomposition import PCA

# Centre every embedding dimension on its own mean. `embeddings.mean()` without an axis
# is one scalar over the whole matrix, which does not centre the individual features.
centered = embeddings - embeddings.mean(axis=0)

# A fixed random_state makes the projection reproducible when scikit-learn selects a
# randomized SVD solver; the seed matches the one used for KMeans in this notebook.
pca = PCA(n_components=50, random_state=0)
pca_decomp = pca.fit_transform(centered)

In [12]:
from sklearn.cluster import KMeans

n_clusters = 15

# Cluster the PCA-reduced embeddings with a fixed seed.
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
kmeans.fit(pca_decomp)

# Hard assignment: the cluster id of every post (rows are in posts_info order).
posts_info['TextCluster'] = kmeans.labels_

# Soft features: kmeans.transform() yields one distance column per cluster.
dists_columns = ['DistanceToCluster_' + str(i) for i in range(n_clusters)]
dists_df = pd.DataFrame(kmeans.transform(pca_decomp), columns=dists_columns)

dists_df.head()



Unnamed: 0,DistanceToCluster_0,DistanceToCluster_1,DistanceToCluster_2,DistanceToCluster_3,DistanceToCluster_4,DistanceToCluster_5,DistanceToCluster_6,DistanceToCluster_7,DistanceToCluster_8,DistanceToCluster_9,DistanceToCluster_10,DistanceToCluster_11,DistanceToCluster_12,DistanceToCluster_13,DistanceToCluster_14
0,2.338563,3.409179,3.372166,3.469854,3.386987,3.619977,2.836988,3.000721,1.870061,2.230505,3.441087,3.668975,3.464343,1.892354,3.404402
1,2.309558,3.326556,3.323515,3.245044,3.366533,3.355541,2.554536,2.848635,2.198359,2.240365,2.977058,3.471692,3.139989,1.412387,3.208464
2,2.381367,3.35545,3.260675,3.390605,3.492749,3.361768,2.883928,3.032044,1.812861,3.042849,2.962888,3.452456,3.134252,1.693465,3.26427
3,2.80283,3.737489,3.510809,4.060458,3.744031,3.790728,3.375602,3.273691,2.453177,3.398152,3.711226,3.14823,3.793648,2.439035,3.677795
4,2.025491,2.813406,3.037411,3.242253,2.80553,3.051338,2.147091,2.647032,1.474374,2.94389,2.642455,3.174901,2.783598,2.132235,2.85177


In [13]:
# Drop any distance columns left over from a previous execution of this cell first,
# so that re-running it does not append a second copy of every DistanceToCluster_* column.
posts_info = pd.concat(
    (posts_info.drop(columns=dists_columns, errors='ignore'), dists_df),
    axis=1
)

posts_info

Unnamed: 0,post_id,text,topic,TextCluster,DistanceToCluster_0,DistanceToCluster_1,DistanceToCluster_2,DistanceToCluster_3,DistanceToCluster_4,DistanceToCluster_5,DistanceToCluster_6,DistanceToCluster_7,DistanceToCluster_8,DistanceToCluster_9,DistanceToCluster_10,DistanceToCluster_11,DistanceToCluster_12,DistanceToCluster_13,DistanceToCluster_14
0,1,UK economy facing major risks\n\nThe UK manufa...,business,8,2.338563,3.409179,3.372166,3.469854,3.386987,3.619977,2.836988,3.000721,1.870061,2.230505,3.441087,3.668975,3.464343,1.892354,3.404402
1,2,Aids and climate top Davos agenda\n\nClimate c...,business,13,2.309558,3.326556,3.323515,3.245044,3.366533,3.355541,2.554536,2.848635,2.198359,2.240365,2.977058,3.471692,3.139989,1.412387,3.208464
2,3,Asian quake hits European shares\n\nShares in ...,business,13,2.381367,3.355450,3.260675,3.390605,3.492749,3.361768,2.883928,3.032044,1.812861,3.042849,2.962888,3.452456,3.134252,1.693465,3.264270
3,4,India power shares jump on debut\n\nShares in ...,business,13,2.802830,3.737489,3.510809,4.060458,3.744031,3.790728,3.375602,3.273691,2.453177,3.398152,3.711226,3.148230,3.793648,2.439035,3.677795
4,5,Lacroix label bought by US firm\n\nLuxury good...,business,8,2.025491,2.813406,3.037411,3.242253,2.805530,3.051338,2.147091,2.647032,1.474374,2.943890,2.642455,3.174901,2.783598,2.132235,2.851770
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,7315,"OK, I would not normally watch a Farrelly brot...",movie,1,2.737494,1.279071,3.131886,3.396509,1.798109,2.958357,2.338325,2.800836,3.002113,3.357036,3.049195,2.142013,1.816416,3.022646,2.067226
7019,7316,I give this movie 2 stars purely because of it...,movie,1,2.448088,0.925397,2.926459,3.382548,1.446728,2.603179,2.237197,2.488354,2.964206,3.192574,3.195076,1.958351,1.836187,3.035999,1.826141
7020,7317,I cant believe this film was allowed to be mad...,movie,1,2.810267,1.499289,2.829880,3.462165,2.010368,2.392578,2.450104,2.506140,3.188105,3.405635,3.153295,2.394759,1.980326,3.273257,2.227760
7021,7318,The version I saw of this film was the Blockbu...,movie,4,2.993838,1.488191,3.429820,3.409822,1.046261,3.312347,2.312419,3.090368,3.195245,3.445812,3.216785,1.788433,1.524476,3.304567,1.927283


## Работа с табличками действий и пользователей

In [14]:
# Fetch up to 6 million feed rows, already filtered down to 'view' actions.
# The query joins every feed row with the attributes of its user and derives
# hour, month and day_of_week from the timestamp.
# NOTE(review): LIMIT is applied without ORDER BY, so which rows come back is up to the database.
feed_query = """
     SELECT
        timestamp,
        post_id,
        gender,
        age,
        country,
        city,
        exp_group,
        os,
        source,
        cast(extract(hour from timestamp) as int) as hour,
        cast(extract(month from timestamp) as int) as month,
        CASE extract(dow from timestamp)
            WHEN 0 THEN 'Sunday'
            WHEN 1 THEN 'Monday'
            WHEN 2 THEN 'Tuesday'
            WHEN 3 THEN 'Wednesday'
            WHEN 4 THEN 'Thursday'
            WHEN 5 THEN 'Friday'
            WHEN 6 THEN 'Saturday'
        END as day_of_week,
        target
    FROM public.feed_data
    JOIN public.user_data ON public.feed_data.user_id = public.user_data.user_id
    WHERE action = 'view'
    LIMIT 6000000

    """

feed_data = pd.read_sql(feed_query, con=connection)

feed_data.head()

Unnamed: 0,timestamp,post_id,gender,age,country,city,exp_group,os,source,hour,month,day_of_week,target
0,2021-10-04 12:46:01,6989,1,23,Belarus,Horad Barysaw,4,Android,ads,12,10,Monday,0
1,2021-10-04 12:46:37,1899,1,23,Belarus,Horad Barysaw,4,Android,ads,12,10,Monday,0
2,2021-10-04 12:48:56,1561,1,23,Belarus,Horad Barysaw,4,Android,ads,12,10,Monday,0
3,2021-10-04 12:50:58,247,1,23,Belarus,Horad Barysaw,4,Android,ads,12,10,Monday,1
4,2021-10-04 12:52:03,1341,1,23,Belarus,Horad Barysaw,4,Android,ads,12,10,Monday,0


In [15]:
# Feature: mean age per city
def users_average_age_per_city(df):
    """Add column ``av_age_per_city``: the mean of ``age`` over all rows sharing the row's ``city``.

    The frame is modified in place and also returned.
    """
    mean_age_by_city = df['age'].groupby(df['city']).mean()
    df['av_age_per_city'] = df['city'].map(mean_age_by_city)
    return df

In [16]:
# Feature: number of rows per city
def count_users_in_city(df):
    """Add column ``users_in_city``: how many rows of ``df`` have the same ``city`` as this row.

    The count is over rows of the given frame (not over distinct users).
    The frame is modified in place and also returned.
    """
    df['users_in_city'] = df['city'].map(df['city'].value_counts())
    return df

In [17]:
def user_feature_creation(df):
    """Apply the city-level feature builders in order: row count per city, then mean age per city."""
    return df.pipe(count_users_in_city).pipe(users_average_age_per_city)

In [18]:
# Add the city-level features (users_in_city, av_age_per_city) to the feed frame.
feed_data = feed_data.pipe(user_feature_creation)

In [19]:
# Quick look: the two new city-level columns should now be the last ones.
feed_data.head(n=5)

Unnamed: 0,timestamp,post_id,gender,age,country,city,exp_group,os,source,hour,month,day_of_week,target,users_in_city,av_age_per_city
0,2021-10-04 12:46:01,6989,1,23,Belarus,Horad Barysaw,4,Android,ads,12,10,Monday,0,2446,23.817253
1,2021-10-04 12:46:37,1899,1,23,Belarus,Horad Barysaw,4,Android,ads,12,10,Monday,0,2446,23.817253
2,2021-10-04 12:48:56,1561,1,23,Belarus,Horad Barysaw,4,Android,ads,12,10,Monday,0,2446,23.817253
3,2021-10-04 12:50:58,247,1,23,Belarus,Horad Barysaw,4,Android,ads,12,10,Monday,1,2446,23.817253
4,2021-10-04 12:52:03,1341,1,23,Belarus,Horad Barysaw,4,Android,ads,12,10,Monday,0,2446,23.817253


In [20]:
# Attach the post features to every view. validate= makes pandas raise a MergeError
# if post_id is not unique in posts_info, instead of silently duplicating view rows.
feed_data = pd.merge(
    feed_data,
    posts_info,
    on='post_id',
    how='left',
    validate='many_to_one'
)
feed_data

Unnamed: 0,timestamp,post_id,gender,age,country,city,exp_group,os,source,hour,...,DistanceToCluster_5,DistanceToCluster_6,DistanceToCluster_7,DistanceToCluster_8,DistanceToCluster_9,DistanceToCluster_10,DistanceToCluster_11,DistanceToCluster_12,DistanceToCluster_13,DistanceToCluster_14
0,2021-10-04 12:46:01,6989,1,23,Belarus,Horad Barysaw,4,Android,ads,12,...,2.959846,2.024314,2.770907,2.789429,2.870943,3.042505,2.241158,1.260157,2.831721,1.665336
1,2021-10-04 12:46:37,1899,1,23,Belarus,Horad Barysaw,4,Android,ads,12,...,3.549987,2.831552,3.460938,3.093694,2.947919,2.205883,3.505349,2.830170,3.158619,3.373848
2,2021-10-04 12:48:56,1561,1,23,Belarus,Horad Barysaw,4,Android,ads,12,...,3.860373,2.929727,3.583459,2.956222,3.364670,2.861131,3.542736,3.215307,2.960007,3.401962
3,2021-10-04 12:50:58,247,1,23,Belarus,Horad Barysaw,4,Android,ads,12,...,3.538179,3.151393,3.093237,2.060922,2.997500,3.197528,3.655066,3.376617,1.408696,3.468939
4,2021-10-04 12:52:03,1341,1,23,Belarus,Horad Barysaw,4,Android,ads,12,...,3.498249,2.805018,3.059329,2.492012,2.287634,3.045402,3.564595,3.246394,1.653775,3.400759
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5999995,2021-11-11 14:05:33,5578,1,17,Russia,Sylva,4,Android,ads,14,...,3.471966,2.571451,3.118466,3.302845,3.355639,3.516023,2.487076,2.353326,3.308369,2.525243
5999996,2021-11-11 14:06:49,391,1,17,Russia,Sylva,4,Android,ads,14,...,3.620740,2.981066,3.129497,2.614989,2.429788,3.251314,3.494147,3.226930,1.807104,3.446127
5999997,2021-11-11 14:09:32,3997,1,17,Russia,Sylva,4,Android,ads,14,...,1.999876,3.028530,1.696507,3.042752,3.465952,3.300930,3.144799,3.012338,3.113544,3.011846
5999998,2021-11-11 14:10:29,3616,1,17,Russia,Sylva,4,Android,ads,14,...,1.647646,3.099917,1.719748,3.175679,3.639952,3.373162,3.429652,3.101620,3.306071,3.139479


In [21]:
# post_id was only needed as the join key, and the raw text is already represented by the
# cluster features; neither column is passed to the model.
feed_data = feed_data.drop(columns=['post_id', 'text'])

## Обучение модели

Обучим модель с теми же параметрами, что и в первом варианте, чтобы понять, насколько векторизация текстов через трансформеры улучшила качество.

In [22]:
### Validation:
# The data is time-ordered, so train and test are split on the timestamp column
# in order not to "peek at the answers".

# Series.max()/min() are vectorised; the builtin max()/min() iterate over every row in Python.
feed_data.timestamp.max(), feed_data.timestamp.min()

(Timestamp('2021-12-29 23:44:39'), Timestamp('2021-10-01 06:01:40'))

In [23]:
# Cut-off date: rows before 2021-12-15 go to train, rows on or after it go to test.
split_date = '2021-12-15'

# timestamp is only needed for the split itself, so it is removed from both parts.
df_train = feed_data[feed_data.timestamp < split_date].drop(columns='timestamp')
df_test = feed_data[feed_data.timestamp >= split_date].drop(columns='timestamp')

# Separate the features from the label column.
X_train, y_train = df_train.drop(columns='target'), df_train['target']
X_test, y_test = df_test.drop(columns='target'), df_test['target']

y_train.shape, y_test.shape

((5006916,), (993084,))

In [25]:
from catboost import CatBoostClassifier


# Columns CatBoost has to treat as categorical rather than numeric.
object_cols = [
    'topic', 'TextCluster', 'gender', 'country', 
    'city', 'exp_group', 'hour', 'month', 'day_of_week',
    'os', 'source'
]

catboost = CatBoostClassifier(
    iterations=200,
    random_seed=111,
    thread_count=-1,
    # Train on the GPU when one is visible and fall back to the CPU otherwise, in line with
    # the device selection used for the transformer earlier in this notebook.
    # CPU and GPU training are not guaranteed to produce identical models.
    task_type="GPU" if torch.cuda.is_available() else "CPU",
    verbose=50
)

# cat_features is passed by keyword so the call does not rely on the positional order of fit().
catboost.fit(X_train, y_train, cat_features=object_cols)

Learning rate set to 0.089243
0:	learn: 0.6244748	total: 562ms	remaining: 1m 51s
50:	learn: 0.3399895	total: 30.9s	remaining: 1m 30s
100:	learn: 0.3377428	total: 1m	remaining: 59.8s
150:	learn: 0.3367510	total: 1m 29s	remaining: 28.9s
199:	learn: 0.3360913	total: 1m 55s	remaining: 0us


<catboost.core.CatBoostClassifier at 0x7dde8d4b1de0>

In [26]:
# Замерим качество работы такой модели
# Возьмем ROC-AUC
from sklearn.metrics import roc_auc_score

# ROC-AUC of the predicted probability of the positive class, on each split.
train_auc = roc_auc_score(y_train, catboost.predict_proba(X_train)[:, 1])
test_auc = roc_auc_score(y_test, catboost.predict_proba(X_test)[:, 1])

print(f"Качество на трейне: {train_auc}")
print(f"Качество на тесте: {test_auc}")

Качество на трейне: 0.6926968721127715
Качество на тесте: 0.6644254796965121


По сравнению с моделью на TF-IDF использование трансформеров повысило качество лишь на 0.0004, что нельзя считать значимым улучшением.

### Сохраним модель и положим в базу фичи, необходимые для функционала нашей модели

In [27]:
# Write the trained model to the file 'model_test' in CatBoost's "cbm" format.
catboost.save_model('model_test', format="cbm")

In [28]:
# Store the per-post features in the database; an existing table of the same name is replaced.
posts_info.to_sql(
    name="koriakov_posts_info_features_dl",
    con=connection_path,
    schema="public",
    if_exists='replace',
)

23