In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from transformers import AutoTokenizer
from transformers import BertModel

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

import torch
from tqdm import tqdm

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances

from catboost import CatBoostClassifier

from sklearn.cluster import KMeans

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [3]:
CONN = "postgresql://robot-startml-ro:pheiph0hahj1Vaif@"\
    "postgres.lab.karpov.courses:6432/startml"

In [4]:
user_df = pd.read_sql(
    """SELECT * FROM public.user_data""",
    con=CONN
)

user_df.head()

Unnamed: 0,user_id,gender,age,country,city,exp_group,os,source
0,200,1,34,Russia,Degtyarsk,3,Android,ads
1,201,0,37,Russia,Abakan,0,Android,ads
2,202,1,17,Russia,Smolensk,4,Android,ads
3,203,0,18,Russia,Moscow,1,iOS,ads
4,204,0,36,Russia,Anzhero-Sudzhensk,3,Android,ads


In [5]:
post_df = pd.read_sql(
    """SELECT * FROM daria_stikheeva_enhanced_model_post_features""", # public.post_text_df
    con=CONN
)

post_df.head()

Unnamed: 0,index,post_id,text,topic,DistanceTo0thCluster,DistanceTo1thCluster,DistanceTo2thCluster,DistanceTo3thCluster,DistanceTo4thCluster,DistanceTo5thCluster,...,DistanceTo20thCluster,DistanceTo21thCluster,DistanceTo22thCluster,DistanceTo23thCluster,DistanceTo24thCluster,DistanceTo25thCluster,DistanceTo26thCluster,DistanceTo27thCluster,DistanceTo28thCluster,DistanceTo29thCluster
0,0,1,UK economy facing major risks\n\nThe UK manufa...,business,9.291727,11.372839,12.15777,9.772949,10.619096,8.222388,...,9.618526,10.472997,11.832791,9.911668,7.689897,10.696194,9.829049,11.195169,7.549977,7.043384
1,1,2,Aids and climate top Davos agenda\n\nClimate c...,business,8.043159,11.037454,10.62526,9.436192,10.358827,7.140182,...,9.395688,10.328116,10.752964,8.603471,7.908851,10.552367,7.942627,10.495662,7.581461,7.720238
2,2,3,Asian quake hits European shares\n\nShares in ...,business,8.370312,10.499072,10.566326,8.872985,9.751165,7.893025,...,9.118278,9.78301,10.474702,8.912313,8.336408,10.022666,7.763417,10.149131,8.149964,6.416407
3,3,4,India power shares jump on debut\n\nShares in ...,business,8.291944,10.849169,10.848932,9.728095,9.979709,6.830172,...,9.899307,9.833903,10.883025,8.405615,8.771545,10.288205,9.148214,10.250868,8.397944,6.955458
4,4,5,Lacroix label bought by US firm\n\nLuxury good...,business,8.128834,8.898019,8.58943,9.796355,8.364449,7.160905,...,9.403523,8.11485,8.45573,6.185608,9.521733,8.62851,8.307801,8.96991,8.678846,8.037048


In [7]:
post_df = post_df.drop('index', axis=1)
post_df

Unnamed: 0,post_id,text,topic,DistanceTo0thCluster,DistanceTo1thCluster,DistanceTo2thCluster,DistanceTo3thCluster,DistanceTo4thCluster,DistanceTo5thCluster,DistanceTo6thCluster,...,DistanceTo20thCluster,DistanceTo21thCluster,DistanceTo22thCluster,DistanceTo23thCluster,DistanceTo24thCluster,DistanceTo25thCluster,DistanceTo26thCluster,DistanceTo27thCluster,DistanceTo28thCluster,DistanceTo29thCluster
0,1,UK economy facing major risks\n\nThe UK manufa...,business,9.291727,11.372839,12.157770,9.772949,10.619096,8.222388,7.176412,...,9.618526,10.472997,11.832791,9.911668,7.689897,10.696194,9.829049,11.195169,7.549977,7.043384
1,2,Aids and climate top Davos agenda\n\nClimate c...,business,8.043159,11.037454,10.625260,9.436192,10.358827,7.140182,7.121281,...,9.395688,10.328116,10.752964,8.603471,7.908851,10.552367,7.942627,10.495662,7.581461,7.720238
2,3,Asian quake hits European shares\n\nShares in ...,business,8.370312,10.499072,10.566326,8.872985,9.751165,7.893025,6.322171,...,9.118278,9.783010,10.474702,8.912313,8.336408,10.022666,7.763417,10.149131,8.149964,6.416407
3,4,India power shares jump on debut\n\nShares in ...,business,8.291944,10.849169,10.848932,9.728095,9.979709,6.830172,4.812199,...,9.899307,9.833903,10.883025,8.405615,8.771545,10.288205,9.148214,10.250868,8.397944,6.955458
4,5,Lacroix label bought by US firm\n\nLuxury good...,business,8.128834,8.898019,8.589430,9.796355,8.364449,7.160905,6.189485,...,9.403523,8.114850,8.455730,6.185608,9.521733,8.628510,8.307801,8.969910,8.678846,8.037048
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,7315,"OK, I would not normally watch a Farrelly brot...",movie,9.616599,4.668174,10.309564,9.104768,7.119634,10.251859,10.797025,...,10.883625,8.089468,7.338254,9.323751,11.798183,6.195448,10.388471,8.767579,11.775478,11.182950
7019,7316,I give this movie 2 stars purely because of it...,movie,8.919337,4.022430,8.595892,8.993085,5.819405,9.316074,9.761645,...,10.398132,6.814692,6.039435,8.233876,10.987240,5.244951,9.519237,7.958955,10.897470,10.397407
7020,7317,I cant believe this film was allowed to be mad...,movie,8.579074,4.124624,9.049135,8.487318,6.127579,9.488235,9.833517,...,10.170457,7.257630,6.763633,8.598885,10.727423,5.492655,9.342589,7.969790,10.759872,9.925567
7021,7318,The version I saw of this film was the Blockbu...,movie,8.847170,5.783009,9.986953,9.128920,5.699656,9.458660,9.713754,...,10.355325,5.373230,6.419209,8.106543,11.178873,4.561183,9.648648,6.807680,11.108180,10.484558


In [6]:
feed_df = pd.read_sql(
    """SELECT * FROM public.feed_data LIMIT 1000000""",
    con=CONN
)

feed_df.head()

Unnamed: 0,timestamp,user_id,post_id,action,target
0,2021-11-21 18:24:23,73613,7170,view,0
1,2021-11-21 18:27:01,73613,4306,view,0
2,2021-11-21 18:28:53,73613,6401,view,0
3,2021-11-21 18:31:51,73613,1346,view,0
4,2021-11-21 18:34:05,73613,1722,view,0


In [7]:
def get_model(model_name):

    checkpoint_names = {
        'bert_cased': 'bert-base-cased',  # https://huggingface.co/bert-base-cased
        'bert_uncased': 'bert-base-uncased',  # https://huggingface.co/bert-base-uncased
    }

    model_classes = {
        'bert_cased': BertModel,
        'bert_uncased': BertModel,
    }

    return (
        AutoTokenizer.from_pretrained(checkpoint_names[model_name]),
        model_classes[model_name].from_pretrained(checkpoint_names[model_name])
    )

In [8]:
tokenizer, model = get_model('bert_uncased')

In [9]:
texts = post_df['text']
texts

0       UK economy facing major risks\n\nThe UK manufa...
1       Aids and climate top Davos agenda\n\nClimate c...
2       Asian quake hits European shares\n\nShares in ...
3       India power shares jump on debut\n\nShares in ...
4       Lacroix label bought by US firm\n\nLuxury good...
                              ...                        
7018    OK, I would not normally watch a Farrelly brot...
7019    I give this movie 2 stars purely because of it...
7020    I cant believe this film was allowed to be mad...
7021    The version I saw of this film was the Blockbu...
7022    Piece of subtle art. Maybe a masterpiece. Doub...
Name: text, Length: 7023, dtype: object

In [10]:
class PostDataset(Dataset):
    def __init__(self, texts, tokenizer):
        super().__init__()

        self.texts = tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=True,
            return_token_type_ids=False,
            return_tensors='pt',
            truncation=True,
            padding=True
        )
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        return {'input_ids': self.texts['input_ids'][idx], 'attention_mask': self.texts['attention_mask'][idx]}

    def __len__(self):
        return len(self.texts['input_ids'])
    
dataset = PostDataset(texts.values.tolist(), tokenizer)

In [11]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

loader = DataLoader(dataset, batch_size=32, collate_fn=data_collator, pin_memory=True, shuffle=False)

In [15]:
@torch.inference_mode()
def get_embeddings_labels(model, loader):
    model.eval()

    total_embeddings = []

    for batch in tqdm(loader):
        batch = {key: batch[key].to(device) for key in ['attention_mask', 'input_ids']}

        embeddings = model(**batch)['last_hidden_state'][:, 0, :]

        total_embeddings.append(embeddings.cpu())

    return torch.cat(total_embeddings, dim=0)

In [16]:
embeddings = get_embeddings_labels(model, loader)
embeddings

100%|██████████| 220/220 [30:18<00:00,  8.27s/it]


tensor([[-0.6726, -0.0661, -0.1642,  ..., -0.2762,  0.8456,  0.0217],
        [-0.5875, -0.4204,  0.2798,  ..., -0.0183,  0.9172,  0.2609],
        [-0.4472, -0.3181,  0.3197,  ...,  0.1885,  0.8269,  0.0528],
        ...,
        [ 0.0415,  0.1860,  0.1918,  ..., -0.5334,  0.6429,  0.2543],
        [-0.2014, -0.3702, -0.1444,  ..., -0.5631,  0.6295,  0.3032],
        [-0.6395, -0.2433, -0.0190,  ..., -0.3811,  0.6440,  0.2714]])

In [17]:
texts_df = pd.DataFrame(embeddings)
texts_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,-0.672646,-0.066114,-0.164181,-0.434374,-0.511686,0.038449,-0.018457,0.440387,0.486812,-0.541671,...,0.681297,-0.049663,-0.088551,-0.589980,-0.186516,-0.065379,-0.095390,-0.276187,0.845573,0.021685
1,-0.587494,-0.420396,0.279791,-0.401147,-0.389208,-0.316731,-0.045538,0.423480,-0.071106,-0.256330,...,0.392163,-0.280619,0.057718,-0.520464,0.291019,0.033745,-0.615614,-0.018342,0.917151,0.260860
2,-0.447165,-0.318142,0.319700,0.004780,-0.230969,-0.360191,-0.126700,0.454154,0.264152,0.131370,...,0.729271,-0.184884,0.307492,-0.125189,0.161591,0.135195,-0.296847,0.188497,0.826870,0.052817
3,-0.619935,-0.630894,-0.067024,-0.214008,-0.272529,-0.087122,-0.260549,0.538185,0.371644,-0.300203,...,0.147915,0.116734,0.272336,-0.368351,0.123500,-0.470466,-0.006515,0.337206,0.526493,0.287202
4,-0.441943,-0.180337,-0.053377,-0.280836,-0.145585,0.030312,0.059537,0.177812,-0.045922,-0.044165,...,0.095446,-0.269966,0.478367,-0.429246,-0.067937,-0.189718,-0.117214,-0.310984,0.170111,0.268644
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,-0.208001,-0.113126,-0.224824,0.200626,-0.390257,-0.729697,0.045401,0.413207,0.675123,-0.466409,...,0.530587,-0.315324,0.384480,-0.365972,-0.372820,0.537090,-0.060865,-0.329380,0.718099,0.362158
7019,0.464172,-0.059786,0.002884,-0.507847,-0.443837,-0.165406,0.268679,0.650218,0.362003,-0.535136,...,0.227411,-0.242968,0.203928,-0.131606,0.013416,0.446630,-0.383271,-0.416972,0.651742,0.262723
7020,0.041493,0.185998,0.191797,-0.208402,-0.335347,-0.701733,0.422635,0.531157,0.594288,-0.567069,...,0.349404,-0.121715,-0.046800,-0.124749,-0.436591,0.412220,-0.297045,-0.533368,0.642929,0.254283
7021,-0.201409,-0.370204,-0.144415,-0.010181,0.037391,-0.678997,-0.109859,0.646580,0.701867,-0.388431,...,0.106108,-0.503268,0.145551,-0.223082,-0.494686,0.434754,-0.389514,-0.563073,0.629545,0.303248


In [23]:
pca = PCA(n_components=0.9)
texts_pca = pca.fit_transform(texts_df)
texts_pca = pd.DataFrame(texts_pca)
texts_pca.columns = [f"feature_{i}" for i in range(texts_pca.shape[1])]
texts_pca

Unnamed: 0,feature_0,feature_1,feature_2,feature_3,feature_4,feature_5,feature_6,feature_7,feature_8,feature_9,...,feature_164,feature_165,feature_166,feature_167,feature_168,feature_169,feature_170,feature_171,feature_172,feature_173
0,4.965774,3.913946,0.328075,-3.397384,-2.014962,-0.686577,1.049654,-0.635119,-1.892467,0.993359,...,0.324711,-0.000049,-0.164753,0.057744,-0.112642,-0.044534,-0.166250,-0.262771,0.260583,0.430576
1,3.040278,4.737936,0.715282,-0.058879,-0.389829,0.839894,1.047372,0.599039,-0.122321,0.892257,...,-0.176441,0.035042,-0.097097,0.171203,-0.144840,0.179231,0.019793,-0.124224,-0.063878,-0.113431
2,3.437137,3.810128,0.550966,-0.944149,-0.384335,-0.673771,2.232732,-0.309600,-2.267626,0.290142,...,-0.154946,-0.389230,0.321191,-0.259879,-0.032076,0.030327,0.252765,0.060212,-0.074356,0.000437
3,3.604059,3.987221,1.791548,-2.339946,0.460470,-0.596540,1.720387,-1.064763,-0.682786,1.130211,...,-0.162577,-0.215497,-0.234299,-0.085590,0.025233,-0.296933,0.060260,-0.084292,0.059296,-0.440590
4,0.358839,2.872988,1.794738,-0.376588,-1.147558,0.378130,-0.334092,-1.454386,0.994920,-0.534544,...,-0.045388,0.153395,0.081182,-0.157869,-0.069920,0.007391,-0.080908,0.404459,-0.012734,-0.007555
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,0.139852,-4.388859,-2.678585,-1.261501,-0.419838,0.715340,1.233368,-0.280138,0.998698,-0.889041,...,-0.219553,-0.220185,-0.100847,0.187979,-0.040853,0.021129,-0.121519,0.024364,-0.123461,0.049068
7019,-0.800363,-3.031542,-1.500390,-0.814136,0.210720,0.066633,-0.317377,-0.964372,0.650700,-1.336906,...,-0.136821,-0.216000,-0.050558,0.017762,-0.069943,-0.007033,-0.163437,0.206713,-0.035293,-0.109510
7020,-0.003255,-3.198644,-2.120910,-0.673796,0.890260,-0.273607,0.449289,-0.419640,-0.232311,-1.907923,...,0.020670,0.016576,-0.044304,-0.450725,-0.290186,0.400071,-0.047218,-0.345415,0.060905,0.006027
7021,0.503562,-3.950541,0.700295,-0.481486,-1.810637,0.223786,0.392069,-0.149856,-0.194369,0.638790,...,0.246727,-0.023337,-0.161811,-0.273735,0.042827,-0.115642,0.121783,-0.130900,-0.331327,-0.099284


In [30]:
kmeans_pca = KMeans(n_clusters=30, random_state=0, n_init=20).fit(texts_pca)
centers = kmeans_pca.cluster_centers_
distances = pairwise_distances(texts_pca, centers)
distances

array([[ 9.291727 , 11.372839 , 12.15777  , ..., 11.195169 ,  7.5499773,
         7.0433836],
       [ 8.043159 , 11.037454 , 10.62526  , ..., 10.495662 ,  7.581461 ,
         7.720238 ],
       [ 8.370312 , 10.499072 , 10.566326 , ..., 10.149131 ,  8.149964 ,
         6.4164066],
       ...,
       [ 8.579074 ,  4.1246243,  9.049135 , ...,  7.96979  , 10.759872 ,
         9.925567 ],
       [ 8.84717  ,  5.7830086,  9.986953 , ...,  6.8076797, 11.10818  ,
        10.484558 ],
       [ 9.596365 ,  8.882109 ,  9.516864 , ...,  7.7811866, 11.656442 ,
        11.141298 ]], dtype=float32)

In [31]:
distances = pd.DataFrame(distances)
distances

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,9.291727,11.372839,12.157770,9.772949,10.619096,8.222388,7.176412,12.964418,10.913680,4.568590,...,9.618526,10.472997,11.832791,9.911668,7.689898,10.696194,9.829049,11.195169,7.549977,7.043384
1,8.043159,11.037454,10.625260,9.436192,10.358827,7.140182,7.121281,11.155830,9.781802,6.717895,...,9.395688,10.328116,10.752964,8.603471,7.908851,10.552367,7.942627,10.495662,7.581461,7.720238
2,8.370312,10.499072,10.566326,8.872985,9.751165,7.893025,6.322171,11.314504,9.492709,5.197131,...,9.118278,9.783010,10.474702,8.912313,8.336408,10.022666,7.763417,10.149131,8.149964,6.416407
3,8.291944,10.849169,10.848932,9.728095,9.979709,6.830172,4.812199,11.612259,10.548822,6.141023,...,9.899307,9.833903,10.883025,8.405615,8.771545,10.288205,9.148214,10.250868,8.397944,6.955458
4,8.128834,8.898019,8.589430,9.796355,8.364449,7.160905,6.189486,8.601188,9.910775,8.113894,...,9.403523,8.114850,8.455729,6.185608,9.521733,8.628510,8.307801,8.969910,8.678846,8.037048
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,9.616599,4.668174,10.309564,9.104768,7.119634,10.251859,10.797025,10.104640,11.813255,11.341706,...,10.883625,8.089468,7.338254,9.323751,11.798183,6.195448,10.388471,8.767579,11.775478,11.182950
7019,8.919337,4.022430,8.595892,8.993085,5.819405,9.316074,9.761645,8.213351,10.715178,10.552584,...,10.398132,6.814692,6.039435,8.233876,10.987240,5.244951,9.519238,7.958955,10.897470,10.397407
7020,8.579074,4.124624,9.049135,8.487318,6.127579,9.488235,9.833517,8.852289,10.804566,10.464577,...,10.170457,7.257630,6.763633,8.598885,10.727423,5.492655,9.342589,7.969790,10.759872,9.925567
7021,8.847170,5.783009,9.986953,9.128920,5.699656,9.458660,9.713754,9.753468,10.775406,10.298182,...,10.355325,5.373230,6.419209,8.106543,11.178873,4.561183,9.648648,6.807680,11.108180,10.484558


In [32]:
distances.columns = [f"DistanceTo{i}thCluster" for i in range(distances.shape[1])]
distances

Unnamed: 0,DistanceTo0thCluster,DistanceTo1thCluster,DistanceTo2thCluster,DistanceTo3thCluster,DistanceTo4thCluster,DistanceTo5thCluster,DistanceTo6thCluster,DistanceTo7thCluster,DistanceTo8thCluster,DistanceTo9thCluster,...,DistanceTo20thCluster,DistanceTo21thCluster,DistanceTo22thCluster,DistanceTo23thCluster,DistanceTo24thCluster,DistanceTo25thCluster,DistanceTo26thCluster,DistanceTo27thCluster,DistanceTo28thCluster,DistanceTo29thCluster
0,9.291727,11.372839,12.157770,9.772949,10.619096,8.222388,7.176412,12.964418,10.913680,4.568590,...,9.618526,10.472997,11.832791,9.911668,7.689898,10.696194,9.829049,11.195169,7.549977,7.043384
1,8.043159,11.037454,10.625260,9.436192,10.358827,7.140182,7.121281,11.155830,9.781802,6.717895,...,9.395688,10.328116,10.752964,8.603471,7.908851,10.552367,7.942627,10.495662,7.581461,7.720238
2,8.370312,10.499072,10.566326,8.872985,9.751165,7.893025,6.322171,11.314504,9.492709,5.197131,...,9.118278,9.783010,10.474702,8.912313,8.336408,10.022666,7.763417,10.149131,8.149964,6.416407
3,8.291944,10.849169,10.848932,9.728095,9.979709,6.830172,4.812199,11.612259,10.548822,6.141023,...,9.899307,9.833903,10.883025,8.405615,8.771545,10.288205,9.148214,10.250868,8.397944,6.955458
4,8.128834,8.898019,8.589430,9.796355,8.364449,7.160905,6.189486,8.601188,9.910775,8.113894,...,9.403523,8.114850,8.455729,6.185608,9.521733,8.628510,8.307801,8.969910,8.678846,8.037048
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,9.616599,4.668174,10.309564,9.104768,7.119634,10.251859,10.797025,10.104640,11.813255,11.341706,...,10.883625,8.089468,7.338254,9.323751,11.798183,6.195448,10.388471,8.767579,11.775478,11.182950
7019,8.919337,4.022430,8.595892,8.993085,5.819405,9.316074,9.761645,8.213351,10.715178,10.552584,...,10.398132,6.814692,6.039435,8.233876,10.987240,5.244951,9.519238,7.958955,10.897470,10.397407
7020,8.579074,4.124624,9.049135,8.487318,6.127579,9.488235,9.833517,8.852289,10.804566,10.464577,...,10.170457,7.257630,6.763633,8.598885,10.727423,5.492655,9.342589,7.969790,10.759872,9.925567
7021,8.847170,5.783009,9.986953,9.128920,5.699656,9.458660,9.713754,9.753468,10.775406,10.298182,...,10.355325,5.373230,6.419209,8.106543,11.178873,4.561183,9.648648,6.807680,11.108180,10.484558


In [33]:
post_df = pd.concat([post_df, distances], axis=1)
post_df

Unnamed: 0,post_id,text,topic,DistanceTo0thCluster,DistanceTo1thCluster,DistanceTo2thCluster,DistanceTo3thCluster,DistanceTo4thCluster,DistanceTo5thCluster,DistanceTo6thCluster,...,DistanceTo20thCluster,DistanceTo21thCluster,DistanceTo22thCluster,DistanceTo23thCluster,DistanceTo24thCluster,DistanceTo25thCluster,DistanceTo26thCluster,DistanceTo27thCluster,DistanceTo28thCluster,DistanceTo29thCluster
0,1,UK economy facing major risks\n\nThe UK manufa...,business,9.291727,11.372839,12.157770,9.772949,10.619096,8.222388,7.176412,...,9.618526,10.472997,11.832791,9.911668,7.689898,10.696194,9.829049,11.195169,7.549977,7.043384
1,2,Aids and climate top Davos agenda\n\nClimate c...,business,8.043159,11.037454,10.625260,9.436192,10.358827,7.140182,7.121281,...,9.395688,10.328116,10.752964,8.603471,7.908851,10.552367,7.942627,10.495662,7.581461,7.720238
2,3,Asian quake hits European shares\n\nShares in ...,business,8.370312,10.499072,10.566326,8.872985,9.751165,7.893025,6.322171,...,9.118278,9.783010,10.474702,8.912313,8.336408,10.022666,7.763417,10.149131,8.149964,6.416407
3,4,India power shares jump on debut\n\nShares in ...,business,8.291944,10.849169,10.848932,9.728095,9.979709,6.830172,4.812199,...,9.899307,9.833903,10.883025,8.405615,8.771545,10.288205,9.148214,10.250868,8.397944,6.955458
4,5,Lacroix label bought by US firm\n\nLuxury good...,business,8.128834,8.898019,8.589430,9.796355,8.364449,7.160905,6.189486,...,9.403523,8.114850,8.455729,6.185608,9.521733,8.628510,8.307801,8.969910,8.678846,8.037048
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,7315,"OK, I would not normally watch a Farrelly brot...",movie,9.616599,4.668174,10.309564,9.104768,7.119634,10.251859,10.797025,...,10.883625,8.089468,7.338254,9.323751,11.798183,6.195448,10.388471,8.767579,11.775478,11.182950
7019,7316,I give this movie 2 stars purely because of it...,movie,8.919337,4.022430,8.595892,8.993085,5.819405,9.316074,9.761645,...,10.398132,6.814692,6.039435,8.233876,10.987240,5.244951,9.519238,7.958955,10.897470,10.397407
7020,7317,I cant believe this film was allowed to be mad...,movie,8.579074,4.124624,9.049135,8.487318,6.127579,9.488235,9.833517,...,10.170457,7.257630,6.763633,8.598885,10.727423,5.492655,9.342589,7.969790,10.759872,9.925567
7021,7318,The version I saw of this film was the Blockbu...,movie,8.847170,5.783009,9.986953,9.128920,5.699656,9.458660,9.713754,...,10.355325,5.373230,6.419209,8.106543,11.178873,4.561183,9.648648,6.807680,11.108180,10.484558


In [10]:
df = pd.merge(feed_df, post_df, on='post_id', how='left')
df = pd.merge(df, user_df, on='user_id', how='left')

df['hour'] = pd.to_datetime(df['timestamp']).apply(lambda x: x.hour)
df['month'] = pd.to_datetime(df['timestamp']).apply(lambda x: x.month)

df = df.drop([
    # 'timestamp', 
    'action',
    'text',
    'topic',
    ], axis=1
    )

df = df.set_index(['user_id', 'post_id'])
df

Unnamed: 0_level_0,Unnamed: 1_level_0,timestamp,target,DistanceTo0thCluster,DistanceTo1thCluster,DistanceTo2thCluster,DistanceTo3thCluster,DistanceTo4thCluster,DistanceTo5thCluster,DistanceTo6thCluster,DistanceTo7thCluster,...,DistanceTo29thCluster,gender,age,country,city,exp_group,os,source,hour,month
user_id,post_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
73613,7170,2021-11-21 18:24:23,0,9.695014,8.380120,12.401138,9.454146,7.849140,10.575001,11.043717,12.663692,...,10.435617,0,36,Russia,Moscow,3,Android,ads,18,11
73613,4306,2021-11-21 18:27:01,0,8.726352,7.158506,10.671004,8.903807,7.374293,9.794944,9.996875,10.618178,...,10.365971,0,36,Russia,Moscow,3,Android,ads,18,11
73613,6401,2021-11-21 18:28:53,0,7.874333,5.012115,9.430361,8.554598,5.722526,8.626222,8.872305,9.324366,...,9.395555,0,36,Russia,Moscow,3,Android,ads,18,11
73613,1346,2021-11-21 18:31:51,0,8.041052,10.894142,12.343847,8.696802,10.813454,8.269463,8.608901,12.915394,...,7.231206,0,36,Russia,Moscow,3,Android,ads,18,11
73613,1722,2021-11-21 18:34:05,0,8.581766,11.185467,13.294253,6.645728,10.787417,10.360549,10.180234,13.811675,...,7.811396,0,36,Russia,Moscow,3,Android,ads,18,11
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22682,3463,2021-12-18 16:57:39,0,11.050144,8.386330,3.888911,11.842509,8.509528,9.848718,9.992439,3.273405,...,11.866624,1,26,Russia,Stavropol,1,Android,ads,16,12
22682,3939,2021-12-18 16:59:39,0,10.993773,9.291432,3.835819,11.907921,9.537594,10.134304,10.054476,4.338567,...,11.710479,1,26,Russia,Stavropol,1,Android,ads,16,12
22682,5424,2021-12-18 17:00:58,0,9.516913,6.041898,7.059989,10.318454,5.660382,9.259231,9.667165,6.431178,...,10.882202,1,26,Russia,Stavropol,1,Android,ads,17,12
22682,731,2021-12-18 17:01:24,0,7.400699,8.726024,11.774611,6.827807,9.558574,9.158106,9.635085,12.008074,...,8.798364,1,26,Russia,Stavropol,1,Android,ads,17,12


In [11]:
def prepare_data(df):
    ### Split by 2021-12-15

    df_train = df[df.timestamp < '2021-12-15']
    df_test = df[df.timestamp >= '2021-12-15']

    df_train = df_train.drop('timestamp', axis=1)
    df_test = df_test.drop('timestamp', axis=1)

    X_train = df_train.drop('target', axis=1)
    X_test = df_test.drop('target', axis=1)

    y_train = df_train['target']
    y_test = df_test['target']

    return X_train, y_train, X_test, y_test

X_train, y_train, X_test, y_test = prepare_data(df)
X_train.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,DistanceTo0thCluster,DistanceTo1thCluster,DistanceTo2thCluster,DistanceTo3thCluster,DistanceTo4thCluster,DistanceTo5thCluster,DistanceTo6thCluster,DistanceTo7thCluster,DistanceTo8thCluster,DistanceTo9thCluster,...,DistanceTo29thCluster,gender,age,country,city,exp_group,os,source,hour,month
user_id,post_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
73613,7170,9.695014,8.38012,12.401138,9.454146,7.84914,10.575001,11.043717,12.663692,12.230749,11.088776,...,10.435617,0,36,Russia,Moscow,3,Android,ads,18,11
73613,4306,8.726352,7.158506,10.671004,8.903807,7.374293,9.794944,9.996875,10.618178,10.257808,10.227919,...,10.365971,0,36,Russia,Moscow,3,Android,ads,18,11
73613,6401,7.874333,5.012115,9.430361,8.554598,5.722526,8.626222,8.872305,9.324366,10.125917,9.601742,...,9.395555,0,36,Russia,Moscow,3,Android,ads,18,11
73613,1346,8.041052,10.894142,12.343847,8.696802,10.813454,8.269463,8.608901,12.915394,10.830884,8.164646,...,7.231206,0,36,Russia,Moscow,3,Android,ads,18,11
73613,1722,8.581766,11.185467,13.294253,6.645728,10.787417,10.360549,10.180234,13.811675,10.253342,9.530361,...,7.811396,0,36,Russia,Moscow,3,Android,ads,18,11


In [13]:
catboost_default = CatBoostClassifier(task_type='GPU')

object_cols = ['country', 'city', 'os', 'source']

catboost_default.fit(X_train, y_train, object_cols, verbose=False)

<catboost.core.CatBoostClassifier at 0x26a72efe0b0>

In [14]:
print(f"ROCAUC TRAIN: {roc_auc_score(y_train, catboost_default.predict_proba(X_train)[:, 1])}")
print(f"ROCAUC TEST: {roc_auc_score(y_test, catboost_default.predict_proba(X_test)[:, 1])}")

ROCAUC TRAIN: 0.6689334738782516
ROCAUC TEST: 0.6363864272891095


In [15]:
catboost_default.save_model(
    'catboost_enhanced_model',
    format="cbm"                  
)

In [16]:
from_file = CatBoostClassifier()  # здесь не указываем параметры, которые были при обучении, в дампе модели все есть

from_file.load_model("catboost_enhanced_model")

from_file.predict(X_train)

array([0, 0, 0, ..., 0, 0, 0], dtype=int64)

In [40]:
post_df.to_sql(    
   "daria_stikheeva_enhanced_model_post_features",                    
    con=CONN,
    schema="public",                   
    if_exists='replace'            
   )

23

In [41]:
post_df

Unnamed: 0,post_id,text,topic,DistanceTo0thCluster,DistanceTo1thCluster,DistanceTo2thCluster,DistanceTo3thCluster,DistanceTo4thCluster,DistanceTo5thCluster,DistanceTo6thCluster,...,DistanceTo20thCluster,DistanceTo21thCluster,DistanceTo22thCluster,DistanceTo23thCluster,DistanceTo24thCluster,DistanceTo25thCluster,DistanceTo26thCluster,DistanceTo27thCluster,DistanceTo28thCluster,DistanceTo29thCluster
0,1,UK economy facing major risks\n\nThe UK manufa...,business,9.291727,11.372839,12.157770,9.772949,10.619096,8.222388,7.176412,...,9.618526,10.472997,11.832791,9.911668,7.689898,10.696194,9.829049,11.195169,7.549977,7.043384
1,2,Aids and climate top Davos agenda\n\nClimate c...,business,8.043159,11.037454,10.625260,9.436192,10.358827,7.140182,7.121281,...,9.395688,10.328116,10.752964,8.603471,7.908851,10.552367,7.942627,10.495662,7.581461,7.720238
2,3,Asian quake hits European shares\n\nShares in ...,business,8.370312,10.499072,10.566326,8.872985,9.751165,7.893025,6.322171,...,9.118278,9.783010,10.474702,8.912313,8.336408,10.022666,7.763417,10.149131,8.149964,6.416407
3,4,India power shares jump on debut\n\nShares in ...,business,8.291944,10.849169,10.848932,9.728095,9.979709,6.830172,4.812199,...,9.899307,9.833903,10.883025,8.405615,8.771545,10.288205,9.148214,10.250868,8.397944,6.955458
4,5,Lacroix label bought by US firm\n\nLuxury good...,business,8.128834,8.898019,8.589430,9.796355,8.364449,7.160905,6.189486,...,9.403523,8.114850,8.455729,6.185608,9.521733,8.628510,8.307801,8.969910,8.678846,8.037048
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,7315,"OK, I would not normally watch a Farrelly brot...",movie,9.616599,4.668174,10.309564,9.104768,7.119634,10.251859,10.797025,...,10.883625,8.089468,7.338254,9.323751,11.798183,6.195448,10.388471,8.767579,11.775478,11.182950
7019,7316,I give this movie 2 stars purely because of it...,movie,8.919337,4.022430,8.595892,8.993085,5.819405,9.316074,9.761645,...,10.398132,6.814692,6.039435,8.233876,10.987240,5.244951,9.519238,7.958955,10.897470,10.397407
7020,7317,I cant believe this film was allowed to be mad...,movie,8.579074,4.124624,9.049135,8.487318,6.127579,9.488235,9.833517,...,10.170457,7.257630,6.763633,8.598885,10.727423,5.492655,9.342589,7.969790,10.759872,9.925567
7021,7318,The version I saw of this film was the Blockbu...,movie,8.847170,5.783009,9.986953,9.128920,5.699656,9.458660,9.713754,...,10.355325,5.373230,6.419209,8.106543,11.178873,4.561183,9.648648,6.807680,11.108180,10.484558
