In [41]:
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd

from warnings import filterwarnings

filterwarnings('ignore')

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

from sqlalchemy import create_engine

from transformers import AutoTokenizer
from transformers import BertModel
from transformers import RobertaModel
from transformers import DistilBertModel
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

import torch
from tqdm import tqdm

from torch.utils.data import Dataset

from sklearn.decomposition import PCA

# Connection URL for the PostgreSQL instance used by this notebook. It can be
# overridden through the STARTML_DATABASE_URL environment variable so that
# credentials do not have to be edited into the notebook itself.
# NOTE(review): the fallback literal still embeds a username and password in
# the notebook; prefer exporting the environment variable and dropping it.
DATABASE_URL = os.environ.get(
    "STARTML_DATABASE_URL",
    "postgresql://robot-startml-ro:pheiph0hahj1Vaif@"
    "postgres.lab.karpov.courses:6432/startml",
)

# Shared engine with an explicit connection-pool configuration.
engine = create_engine(
    DATABASE_URL,
    pool_size=10,
    max_overflow=20,
    pool_timeout=30,
)

In [None]:
def get_model(model_name):
    """Return a ``(tokenizer, model)`` pair for one of the supported encoders.

    ``model_name`` must be ``'bert'``, ``'roberta'`` or ``'distilbert'``; any
    other value fails the assertion below. Both objects are created with
    ``from_pretrained`` on the checkpoint registered for that name.
    """
    assert model_name in ('bert', 'roberta', 'distilbert')

    # Short name -> (checkpoint identifier, model class).
    registry = {
        'bert': ('bert-base-cased', BertModel),
        'roberta': ('roberta-base', RobertaModel),
        'distilbert': ('distilbert-base-cased', DistilBertModel),
    }
    checkpoint, model_cls = registry[model_name]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = model_cls.from_pretrained(checkpoint)
    return tokenizer, encoder


In [5]:
# Load the tokenizer and encoder registered under the 'bert' key in get_model.
tokenizer, model = get_model('bert')

# Move the encoder to the device selected in the setup cell.
model = model.to(device)

In [8]:
# Reuse the `engine` created in the setup cell instead of building a second
# engine from a duplicated copy of the connection string.
with engine.connect() as conn:
    # `conn.connection` hands pandas the underlying DBAPI connection rather
    # than the SQLAlchemy Connection wrapper.
    post_df = pd.read_sql(
        sql="SELECT * FROM public.post_text_df",
        con=conn.connection
    )


In [9]:
# Select the raw post text column; it is the input passed to PostDataset below.
texts = post_df['text']
texts

0       UK economy facing major risks\n\nThe UK manufa...
1       Aids and climate top Davos agenda\n\nClimate c...
2       Asian quake hits European shares\n\nShares in ...
3       India power shares jump on debut\n\nShares in ...
4       Lacroix label bought by US firm\n\nLuxury good...
                              ...                        
7018    OK, I would not normally watch a Farrelly brot...
7019    I give this movie 2 stars purely because of it...
7020    I cant believe this film was allowed to be mad...
7021    The version I saw of this film was the Blockbu...
7022    Piece of subtle art. Maybe a masterpiece. Doub...
Name: text, Length: 7023, dtype: object

In [33]:
class PostDataset(Dataset):
    """Map-style dataset of tokenized texts, left unpadded until batching.

    All texts are tokenized once in ``__init__`` with truncation enabled and
    padding disabled, so each example keeps its own length. Padding is meant
    to happen per batch in the ``collate_fn`` (e.g. ``DataCollatorWithPadding``),
    which avoids padding every example to the longest sequence in the corpus.

    Args:
        texts: list of strings to tokenize.
        tokenizer: callable tokenizer returning a mapping with ``'input_ids'``
            and ``'attention_mask'`` entries (one list per input text).
    """

    def __init__(self, texts, tokenizer):
        super().__init__()

        # padding=False and no return_tensors: the per-example sequences have
        # different lengths and therefore cannot be stacked into one tensor.
        self.texts = tokenizer(
            texts,
            add_special_tokens=True,
            return_token_type_ids=False,
            truncation=True,
            padding=False,
        )
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        """Return the ``idx``-th example as 1-D int64 tensors (integer ``idx`` only)."""
        return {
            'input_ids': torch.tensor(self.texts['input_ids'][idx], dtype=torch.long),
            'attention_mask': torch.tensor(self.texts['attention_mask'][idx], dtype=torch.long),
        }

    def __len__(self):
        """Number of tokenized examples."""
        return len(self.texts['input_ids'])
    
# Tokenize every post once up front; `.values.tolist()` turns the pandas Series into a plain Python list.
dataset = PostDataset(texts.values.tolist(), tokenizer)

In [37]:
# Collate function that pads the examples of a batch to a common length.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# shuffle=False keeps the batches in dataset order.
# Pinned host memory only helps host-to-GPU copies on CUDA; `device` in this
# notebook is 'mps' or 'cpu', so pinning is enabled only if that ever changes.
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=data_collator,
    pin_memory=(device.type == 'cuda'),
)

In [38]:
@torch.inference_mode()
def get_embeddings_labels(model, loader):
    """Run ``model`` over ``loader`` and return one embedding row per example.

    Despite the name, only embeddings are returned (no labels).

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``['last_hidden_state']``.
        loader: iterable of batches; each batch is a mapping containing
            ``'input_ids'`` and ``'attention_mask'`` tensors. Other keys are
            ignored.

    Returns:
        CPU tensor obtained by concatenating, along dim 0, the hidden state at
        sequence position 0 of every batch.

    Side effects:
        Puts ``model`` into eval mode.
    """
    model.eval()

    # Send batches to the device the model actually lives on instead of
    # relying on a notebook-global `device` being in sync with the model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    total_embeddings = []

    for batch in tqdm(loader):
        batch = {key: batch[key].to(target_device) for key in ('attention_mask', 'input_ids')}

        # Keep only the hidden state of the first token position
        # (presumably the [CLS]-style token — confirm for the chosen checkpoint).
        embeddings = model(**batch)['last_hidden_state'][:, 0, :]

        # Move results off the accelerator right away so they do not accumulate there.
        total_embeddings.append(embeddings.cpu())

    return torch.cat(total_embeddings, dim=0)

In [39]:
# Embed every post; the loader is not shuffled, so rows follow the dataset order.
embeddings = get_embeddings_labels(model, loader)
embeddings

100%|██████████| 220/220 [31:10<00:00,  8.50s/it]


tensor([[ 0.1404, -0.1407, -0.5757,  ..., -0.1379,  0.0430,  0.1423],
        [ 0.1575, -0.0977, -0.2307,  ..., -0.3009,  0.1905,  0.0198],
        [ 0.3146, -0.1152, -0.1813,  ..., -0.3541, -0.2043, -0.0270],
        ...,
        [ 0.6195,  0.2746, -0.1265,  ..., -0.3581, -0.1643,  0.1710],
        [ 0.6941,  0.0672, -0.2287,  ...,  0.0379,  0.1410,  0.1244],
        [ 0.4166,  0.1736, -0.1788,  ..., -0.2106,  0.3133,  0.0338]])

In [40]:
# Wrap the embedding matrix in a DataFrame: one row per post, one column per
# embedding dimension (default integer column labels).
texts_df = pd.DataFrame(data=embeddings.numpy())
texts_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.140363,-0.140695,-0.575681,-0.118175,-0.315324,-0.114379,0.431977,-0.144286,0.003233,-1.187819,...,0.051768,0.614536,-0.613107,0.131062,0.202857,0.175931,-0.167769,-0.137880,0.042959,0.142284
1,0.157530,-0.097739,-0.230651,-0.364432,-0.242782,0.310065,0.374489,-0.089235,0.202153,-1.130713,...,0.586032,0.652154,-0.112309,-0.085171,-0.051814,0.240048,0.200299,-0.300891,0.190542,0.019754
2,0.314568,-0.115163,-0.181322,-0.274698,-0.357378,0.285227,0.266597,0.002641,-0.033905,-1.092079,...,0.416194,0.641697,-0.326962,-0.042802,-0.073850,0.212023,-0.090193,-0.354050,-0.204323,-0.027024
3,0.415116,-0.241301,-0.260733,-0.436026,-0.194695,0.130077,0.458804,-0.235223,-0.032935,-1.008514,...,0.791115,0.562938,-0.194190,0.022462,0.108904,0.019627,0.362090,-0.150884,-0.048834,0.083071
4,0.614585,-0.235812,-0.047732,-0.406701,-0.284798,0.124150,0.545586,-0.284447,0.047562,-1.139114,...,0.605897,0.518613,0.007372,0.032436,0.015634,0.055696,0.145705,-0.061322,-0.021120,0.121545
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,0.521686,0.292126,-0.210760,-0.219636,0.124439,-0.032429,0.055145,0.118430,-0.033624,-1.167992,...,0.575904,0.319888,-0.427477,0.161139,0.179985,0.053671,0.051577,-0.353712,-0.010725,-0.047067
7019,0.487033,0.442944,-0.251731,-0.303212,0.068587,-0.176149,0.235079,-0.050546,-0.011602,-1.161575,...,0.077758,0.174062,-0.359764,-0.196519,0.043920,0.178851,-0.054635,-0.233469,0.349205,-0.031997
7020,0.619477,0.274618,-0.126504,-0.110225,0.167652,-0.184624,0.244474,0.020286,0.058746,-1.046320,...,0.307723,0.187838,-0.378197,-0.223616,0.019644,0.250250,0.038904,-0.358137,-0.164277,0.171024
7021,0.694089,0.067175,-0.228680,-0.255529,0.134818,0.223566,0.249515,-0.032106,-0.160485,-1.041768,...,0.361884,0.352622,-0.307157,-0.139961,0.150054,0.015053,0.045804,0.037859,0.141017,0.124412


In [42]:
N_PCA_COMPONENTS = 30  # number of principal components kept as post features
PCA_RANDOM_STATE = 42  # makes the projection repeatable if a randomized SVD solver is selected

# Reduce the embedding columns to a small, fixed number of features.
pca = PCA(n_components=N_PCA_COMPONENTS, random_state=PCA_RANDOM_STATE)
texts_pca = pd.DataFrame(pca.fit_transform(texts_df))
texts_pca.columns = [f"feature_{i}" for i in range(texts_pca.shape[1])]
texts_pca

Unnamed: 0,feature_0,feature_1,feature_2,feature_3,feature_4,feature_5,feature_6,feature_7,feature_8,feature_9,...,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29
0,-0.954854,1.755626,-0.206996,-1.697694,2.381043,0.014745,0.139376,-0.573178,-2.037547,0.976304,...,0.336478,-0.185772,0.194671,-0.565766,0.526488,0.283886,-0.196736,-0.481098,0.786151,0.213414
1,-3.082160,0.872374,1.121520,-0.695177,-0.051367,0.234765,0.309045,-0.000414,-0.246690,-0.531304,...,-1.031197,0.208442,0.432655,0.038369,-0.750575,-0.439471,-0.219063,0.345914,-0.002426,-0.134629
2,-2.298745,0.771816,1.484297,-0.921827,-0.043764,0.577132,-0.102377,-0.678315,-1.164012,1.001123,...,0.150226,0.151768,0.021240,0.073521,-0.177980,0.880746,-0.855430,0.180745,0.400131,-0.311773
3,-3.830453,0.031840,1.307999,2.101638,-0.484519,-0.095317,-0.206712,0.808228,0.210451,-0.079046,...,0.201019,0.089319,0.307492,-0.111880,0.087338,0.031964,0.363902,-0.215565,0.073987,-0.149682
4,-2.248866,-0.231331,1.637734,1.685146,-0.175615,-0.363438,-0.877842,0.201346,0.476313,-0.037055,...,0.085009,0.337344,-0.259403,0.189537,0.754722,-0.078927,0.084609,-0.211754,0.175888,0.168812
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,2.656976,1.065921,1.688354,0.109283,-0.562752,-0.528285,-0.269984,-0.190374,1.075465,0.433447,...,-0.560453,0.112495,0.116528,0.406745,0.143827,-0.364537,-0.393525,-0.520929,-0.366671,0.377670
7019,2.547191,0.549888,0.247580,-0.458580,-0.235311,-0.913796,-0.242978,-0.429370,-0.440352,0.022955,...,0.475006,-0.312133,-0.348936,-0.374635,0.639752,-0.764498,-0.268806,-0.484107,-0.152320,-0.409746
7020,2.438391,0.878832,1.034503,0.249472,-0.589286,-0.427734,-0.258928,-0.983651,0.138922,0.731661,...,-0.285687,-1.025985,0.105423,-0.219085,-0.130385,-0.014398,0.301613,-0.002806,0.104604,0.211356
7021,2.499131,1.253240,0.025757,0.992778,0.589767,0.471231,-0.410792,0.301115,0.129270,0.344709,...,-0.124819,0.211440,0.423616,0.255860,-0.018684,0.207491,0.298303,0.457453,-0.320553,-0.247146


In [43]:
# Attach the PCA features column-wise (rows are aligned on the index).
# Any feature columns left over from a previous run of this cell are dropped
# first, so re-running the cell does not append duplicate columns.
post_df = pd.concat(
    [post_df.drop(columns=texts_pca.columns, errors='ignore'), texts_pca],
    axis=1,
)
post_df

Unnamed: 0,post_id,text,topic,feature_0,feature_1,feature_2,feature_3,feature_4,feature_5,feature_6,...,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29
0,1,UK economy facing major risks\n\nThe UK manufa...,business,-0.954854,1.755626,-0.206996,-1.697694,2.381043,0.014745,0.139376,...,0.336478,-0.185772,0.194671,-0.565766,0.526488,0.283886,-0.196736,-0.481098,0.786151,0.213414
1,2,Aids and climate top Davos agenda\n\nClimate c...,business,-3.082160,0.872374,1.121520,-0.695177,-0.051367,0.234765,0.309045,...,-1.031197,0.208442,0.432655,0.038369,-0.750575,-0.439471,-0.219063,0.345914,-0.002426,-0.134629
2,3,Asian quake hits European shares\n\nShares in ...,business,-2.298745,0.771816,1.484297,-0.921827,-0.043764,0.577132,-0.102377,...,0.150226,0.151768,0.021240,0.073521,-0.177980,0.880746,-0.855430,0.180745,0.400131,-0.311773
3,4,India power shares jump on debut\n\nShares in ...,business,-3.830453,0.031840,1.307999,2.101638,-0.484519,-0.095317,-0.206712,...,0.201019,0.089319,0.307492,-0.111880,0.087338,0.031964,0.363902,-0.215565,0.073987,-0.149682
4,5,Lacroix label bought by US firm\n\nLuxury good...,business,-2.248866,-0.231331,1.637734,1.685146,-0.175615,-0.363438,-0.877842,...,0.085009,0.337344,-0.259403,0.189537,0.754722,-0.078927,0.084609,-0.211754,0.175888,0.168812
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,7315,"OK, I would not normally watch a Farrelly brot...",movie,2.656976,1.065921,1.688354,0.109283,-0.562752,-0.528285,-0.269984,...,-0.560453,0.112495,0.116528,0.406745,0.143827,-0.364537,-0.393525,-0.520929,-0.366671,0.377670
7019,7316,I give this movie 2 stars purely because of it...,movie,2.547191,0.549888,0.247580,-0.458580,-0.235311,-0.913796,-0.242978,...,0.475006,-0.312133,-0.348936,-0.374635,0.639752,-0.764498,-0.268806,-0.484107,-0.152320,-0.409746
7020,7317,I cant believe this film was allowed to be mad...,movie,2.438391,0.878832,1.034503,0.249472,-0.589286,-0.427734,-0.258928,...,-0.285687,-1.025985,0.105423,-0.219085,-0.130385,-0.014398,0.301613,-0.002806,0.104604,0.211356
7021,7318,The version I saw of this film was the Blockbu...,movie,2.499131,1.253240,0.025757,0.992778,0.589767,0.471231,-0.410792,...,-0.124819,0.211440,0.423616,0.255860,-0.018684,0.207491,0.298303,0.457453,-0.320553,-0.247146


In [44]:
# Remove the raw text column from the feature table.
# errors='ignore' keeps the cell re-runnable after 'text' has already been dropped.
post_df = post_df.drop(columns='text', errors='ignore')
post_df

Unnamed: 0,post_id,topic,feature_0,feature_1,feature_2,feature_3,feature_4,feature_5,feature_6,feature_7,...,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29
0,1,business,-0.954854,1.755626,-0.206996,-1.697694,2.381043,0.014745,0.139376,-0.573178,...,0.336478,-0.185772,0.194671,-0.565766,0.526488,0.283886,-0.196736,-0.481098,0.786151,0.213414
1,2,business,-3.082160,0.872374,1.121520,-0.695177,-0.051367,0.234765,0.309045,-0.000414,...,-1.031197,0.208442,0.432655,0.038369,-0.750575,-0.439471,-0.219063,0.345914,-0.002426,-0.134629
2,3,business,-2.298745,0.771816,1.484297,-0.921827,-0.043764,0.577132,-0.102377,-0.678315,...,0.150226,0.151768,0.021240,0.073521,-0.177980,0.880746,-0.855430,0.180745,0.400131,-0.311773
3,4,business,-3.830453,0.031840,1.307999,2.101638,-0.484519,-0.095317,-0.206712,0.808228,...,0.201019,0.089319,0.307492,-0.111880,0.087338,0.031964,0.363902,-0.215565,0.073987,-0.149682
4,5,business,-2.248866,-0.231331,1.637734,1.685146,-0.175615,-0.363438,-0.877842,0.201346,...,0.085009,0.337344,-0.259403,0.189537,0.754722,-0.078927,0.084609,-0.211754,0.175888,0.168812
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,7315,movie,2.656976,1.065921,1.688354,0.109283,-0.562752,-0.528285,-0.269984,-0.190374,...,-0.560453,0.112495,0.116528,0.406745,0.143827,-0.364537,-0.393525,-0.520929,-0.366671,0.377670
7019,7316,movie,2.547191,0.549888,0.247580,-0.458580,-0.235311,-0.913796,-0.242978,-0.429370,...,0.475006,-0.312133,-0.348936,-0.374635,0.639752,-0.764498,-0.268806,-0.484107,-0.152320,-0.409746
7020,7317,movie,2.438391,0.878832,1.034503,0.249472,-0.589286,-0.427734,-0.258928,-0.983651,...,-0.285687,-1.025985,0.105423,-0.219085,-0.130385,-0.014398,0.301613,-0.002806,0.104604,0.211356
7021,7318,movie,2.499131,1.253240,0.025757,0.992778,0.589767,0.471231,-0.410792,0.301115,...,-0.124819,0.211440,0.423616,0.255860,-0.018684,0.207491,0.298303,0.457453,-0.320553,-0.247146


In [45]:
def batch_load_sql(query: str, chunksize: int = 200000) -> pd.DataFrame:
    """Read the result of ``query`` from the database in chunks.

    Args:
        query: SQL text to execute.
        chunksize: number of rows fetched per chunk (defaults to the value the
            function previously hard-coded).

    Returns:
        A single DataFrame with all chunks concatenated and a fresh 0..n-1
        index; an empty DataFrame if no chunk was produced.

    Raises:
        Whatever ``create_engine`` / ``pd.read_sql`` raise; the connection is
        closed and the engine disposed in that case as well.
    """
    # NOTE(review): credentials are embedded in this URL; consider moving them
    # to configuration outside the notebook.
    engine = create_engine("postgresql://robot-startml-ro:pheiph0hahj1Vaif@postgres.lab.karpov.courses:6432/startml")
    try:
        # stream_results=True requests streamed results so that rows are not
        # all buffered at once (where the driver supports it).
        conn = engine.connect().execution_options(stream_results=True)
        try:
            chunks = list(pd.read_sql(query, conn, chunksize=chunksize))
        finally:
            # Always release the connection, even if reading fails midway.
            conn.close()
    finally:
        # The engine is private to this call, so release its pool too.
        engine.dispose()

    if not chunks:
        return pd.DataFrame()
    return pd.concat(chunks, ignore_index=True)


def load_features(table_name) -> pd.DataFrame:
    """Load every row and column of ``table_name`` through ``batch_load_sql``.

    ``table_name`` is interpolated into the SQL text as-is, so it must come
    from a trusted source.
    """
    # Read the DataFrame from the database -->>
    return batch_load_sql(f"SELECT * FROM {table_name}")


def load_to_sql(table_name, data):
    """Write ``data`` to the database table ``table_name``.

    The table is replaced if it already exists, the DataFrame index is not
    written, and rows are sent in chunks of 10000.
    """
    # Write the DataFrame to the database -->>
    db_engine = create_engine("postgresql://robot-startml-ro:pheiph0hahj1Vaif@postgres.lab.karpov.courses:6432/startml")
    write_options = dict(
        con=db_engine,
        if_exists='replace',
        index=False,
        chunksize=10000,
    )
    data.to_sql(table_name, **write_options)
    

In [46]:
# Left disabled on purpose: load_to_sql writes with if_exists='replace', so
# uncommenting this line overwrites the target table with `post_df`.
# load_to_sql('darja_stiheeva_lms4973_post_features_bert_2', post_df)