In [1]:
import os
from getpass import getpass
from warnings import filterwarnings

import numpy as np
import torch
import torch.nn as nn

# Silence every warning globally to keep cell output short.
# NOTE(review): this also hides potentially useful warnings — consider narrowing the filter.
filterwarnings('ignore')

# Prefer the Apple Metal (MPS) backend when available, otherwise run on CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

from sqlalchemy import create_engine

# Credentials must not live in notebook source: take the connection URL from the
# environment and only prompt for it (without echo) when the variable is not set.
# NOTE(review): the password previously committed in this cell should be treated as leaked.
db_url = os.environ.get("STARTML_DB_URL") or getpass("PostgreSQL URL (postgresql://user:password@host:port/db): ")
engine = create_engine(db_url)

In [2]:
from transformers import AutoTokenizer
from transformers import BertModel
from transformers import RobertaModel
from transformers import DistilBertModel


def get_model(model_name):
    """Load a pretrained tokenizer and encoder for one supported model family.

    Args:
        model_name: one of ``'bert'``, ``'roberta'``, ``'distilbert'``.

    Returns:
        ``(tokenizer, model)``: an ``AutoTokenizer`` and the matching encoder class,
        both loaded from the same Hugging Face checkpoint.

    Raises:
        AssertionError: if ``model_name`` is not supported (when assertions are enabled).
    """
    assert model_name in ['bert', 'roberta', 'distilbert']

    # family -> (checkpoint, encoder class); the tokenizer is loaded from the same checkpoint.
    registry = {
        'bert': ('bert-base-cased', BertModel),
        'roberta': ('roberta-base', RobertaModel),
        'distilbert': ('distilbert-base-cased', DistilBertModel),
    }
    checkpoint, encoder_cls = registry[model_name]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = encoder_cls.from_pretrained(checkpoint)
    return tokenizer, encoder

In [3]:
# Load the BERT tokenizer/encoder pair and place the encoder on the selected device.
tokenizer, model = get_model('bert')
model = model.to(device=device)

In [44]:
import pandas as pd

def load_post_data(db_engine=None):
    """Fetch the whole ``post_text_df`` table into a DataFrame.

    Args:
        db_engine: SQLAlchemy engine to read from. Defaults to the module-level
            ``engine``, so existing zero-argument calls behave as before.

    Returns:
        ``pandas.DataFrame`` containing every column of ``post_text_df``.
    """
    if db_engine is None:
        db_engine = engine

    query = """
    SELECT *
    FROM post_text_df;
    """

    # The context manager closes the connection even when the query fails; an explicit
    # close() after read_sql was skipped on any exception and leaked the connection.
    # NOTE(review): stream_results mostly matters together with read_sql(chunksize=...);
    # without chunksize pandas still builds the DataFrame from the full result.
    with db_engine.connect().execution_options(stream_results=True) as conn:
        post_data = pd.read_sql(query, conn)

    return post_data

post_data = load_post_data()
texts = post_data.loc[:, 'text']  # raw post bodies, sharing post_data's row index


In [5]:
# Quick look at the raw post texts before tokenization.
texts

0       UK economy facing major risks\n\nThe UK manufa...
1       Aids and climate top Davos agenda\n\nClimate c...
2       Asian quake hits European shares\n\nShares in ...
3       India power shares jump on debut\n\nShares in ...
4       Lacroix label bought by US firm\n\nLuxury good...
                              ...                        
7018    OK, I would not normally watch a Farrelly brot...
7019    I give this movie 2 stars purely because of it...
7020    I cant believe this film was allowed to be mad...
7021    The version I saw of this film was the Blockbu...
7022    Piece of subtle art. Maybe a masterpiece. Doub...
Name: text, Length: 7023, dtype: object

In [6]:
from torch.utils.data import Dataset


class PostDataset(Dataset):
    """Map-style dataset of tokenized post texts.

    All texts are tokenized once, up front, in a single tokenizer call. Each item is a
    dict with ``'input_ids'`` and ``'attention_mask'`` for one text.

    Args:
        texts: list of raw strings to encode.
        tokenizer: a Hugging Face tokenizer (or any callable with the same interface).
        padding: forwarded to the tokenizer. The default ``True`` pads every text to the
            longest one in the whole corpus and stores tensors (the previous behaviour).
            Pass ``False`` to keep ragged Python lists and let a collator such as
            ``DataCollatorWithPadding`` pad per batch, so the encoder does not run over
            corpus-wide padding.
        max_length: forwarded to the tokenizer; ``None`` keeps the tokenizer's default.
    """

    def __init__(self, texts, tokenizer, padding=True, max_length=None):
        super().__init__()

        # Calling the tokenizer directly replaces the deprecated ``batch_encode_plus``.
        self.texts = tokenizer(
            texts,
            add_special_tokens=True,
            return_token_type_ids=False,
            # Ragged (unpadded) sequences cannot be stacked into one tensor.
            return_tensors='pt' if padding else None,
            truncation=True,
            padding=padding,
            max_length=max_length,
        )
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        return {'input_ids': self.texts['input_ids'][idx], 'attention_mask': self.texts['attention_mask'][idx]}

    def __len__(self):
        return len(self.texts['input_ids'])

In [7]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

dataset = PostDataset(texts.tolist(), tokenizer)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# shuffle=False keeps loader order equal to post_data row order, so embedding i
# belongs to post i when the features are concatenated back later.
# NOTE(review): pin_memory mainly helps host->CUDA copies; with the MPS/CPU device
# chosen above it is likely ignored — confirm and consider `device.type == 'cuda'`.
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=data_collator,
    pin_memory=True,
)

In [8]:
import torch
from tqdm import tqdm


@torch.inference_mode()
def get_embeddings_labels(model, loader, device=None, progress=True):
    """Return the first-token hidden state (``[CLS]`` for BERT) for every example in ``loader``.

    Despite the name, no labels are produced — only embeddings.

    Args:
        model: encoder whose output supports ``['last_hidden_state']`` indexing
            (Hugging Face ``*Model`` classes do).
        loader: iterable of batches mapping ``'input_ids'`` / ``'attention_mask'`` to tensors.
        device: where to run the forward pass. Defaults to the device of the model's
            first parameter (CPU if it has none), so inputs always match the model
            instead of relying on a global ``device`` variable.
        progress: wrap ``loader`` in a ``tqdm`` progress bar.

    Returns:
        CPU tensor with one row per example, in loader order; assuming
        ``last_hidden_state`` is ``(batch, seq, hidden)``, its shape is ``(N, hidden)``.
    """
    model.eval()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    batches = tqdm(loader) if progress else loader

    total_embeddings = []
    for batch in batches:
        # Only these two inputs are forwarded to the model; other keys are ignored.
        batch = {key: batch[key].to(device) for key in ['attention_mask', 'input_ids']}

        # Position 0 of the sequence axis from the last hidden layer.
        embeddings = model(**batch)['last_hidden_state'][:, 0, :]

        total_embeddings.append(embeddings.cpu())

    return torch.cat(total_embeddings, dim=0)

In [9]:
cls_embeddings = get_embeddings_labels(model, loader)
embeddings = cls_embeddings.numpy()  # NumPy copy-free view for the pandas/sklearn steps

  0%|          | 0/220 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 220/220 [04:59<00:00,  1.36s/it]


In [10]:
# One row per post, one integer-labelled column per embedding dimension.
texts_df = pd.DataFrame(data=embeddings)
texts_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.140364,-0.140695,-0.575681,-0.118174,-0.315324,-0.114378,0.431977,-0.144286,0.003233,-1.187819,...,0.051768,0.614536,-0.613106,0.131062,0.202857,0.175931,-0.167768,-0.137880,0.042959,0.142283
1,0.157531,-0.097739,-0.230650,-0.364431,-0.242781,0.310066,0.374489,-0.089235,0.202153,-1.130713,...,0.586033,0.652154,-0.112307,-0.085171,-0.051814,0.240047,0.200299,-0.300891,0.190541,0.019753
2,0.314569,-0.115163,-0.181322,-0.274698,-0.357378,0.285229,0.266597,0.002641,-0.033905,-1.092078,...,0.416196,0.641697,-0.326961,-0.042803,-0.073850,0.212023,-0.090192,-0.354049,-0.204324,-0.027025
3,0.415117,-0.241300,-0.260733,-0.436027,-0.194695,0.130077,0.458805,-0.235223,-0.032935,-1.008515,...,0.791115,0.562938,-0.194189,0.022461,0.108905,0.019628,0.362090,-0.150884,-0.048834,0.083071
4,0.614585,-0.235812,-0.047732,-0.406702,-0.284798,0.124151,0.545586,-0.284447,0.047562,-1.139113,...,0.605897,0.518613,0.007371,0.032436,0.015634,0.055695,0.145705,-0.061323,-0.021120,0.121545
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,0.521686,0.292126,-0.210759,-0.219636,0.124439,-0.032428,0.055145,0.118430,-0.033625,-1.167994,...,0.575905,0.319888,-0.427476,0.161139,0.179985,0.053670,0.051578,-0.353713,-0.010725,-0.047067
7019,0.487033,0.442944,-0.251731,-0.303212,0.068587,-0.176148,0.235078,-0.050545,-0.011601,-1.161574,...,0.077759,0.174062,-0.359764,-0.196519,0.043920,0.178852,-0.054636,-0.233469,0.349204,-0.031998
7020,0.619477,0.274618,-0.126503,-0.110224,0.167653,-0.184623,0.244474,0.020286,0.058746,-1.046320,...,0.307723,0.187838,-0.378196,-0.223615,0.019644,0.250250,0.038904,-0.358137,-0.164277,0.171023
7021,0.694089,0.067175,-0.228681,-0.255529,0.134819,0.223566,0.249515,-0.032106,-0.160485,-1.041768,...,0.361884,0.352623,-0.307156,-0.139961,0.150053,0.015053,0.045804,0.037859,0.141017,0.124412


In [38]:
from sklearn.decomposition import PCA

N_PCA_COMPONENTS = 30
PCA_SEED = 42

# For input of this size PCA's default svd_solver='auto' can choose the randomized
# solver, whose result depends on random_state; without a seed the features written
# to the database change on every re-run.
pca = PCA(n_components=N_PCA_COMPONENTS, random_state=PCA_SEED)
texts_pca = pd.DataFrame(
    pca.fit_transform(texts_df),
    columns=[f"feature_{i}" for i in range(pca.n_components_)],
)
texts_pca

Unnamed: 0,feature_0,feature_1,feature_2,feature_3,feature_4,feature_5,feature_6,feature_7,feature_8,feature_9,...,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29
0,-0.954874,1.755627,-0.206941,-1.697690,2.381061,0.014774,0.139381,-0.573169,-2.037523,0.976315,...,0.336253,-0.185460,0.194123,0.564932,0.525173,0.280055,-0.193786,-0.484033,0.806804,0.231849
1,-3.082160,0.872352,1.121503,-0.695175,-0.051375,0.234762,0.309035,-0.000412,-0.246709,-0.531299,...,-1.030952,0.208925,0.433281,-0.035999,-0.746134,-0.434780,-0.217824,0.344667,-0.003610,-0.201203
2,-2.298752,0.771814,1.484302,-0.921823,-0.043765,0.577132,-0.102365,-0.678309,-1.164002,1.001131,...,0.150932,0.150034,0.019595,-0.074723,-0.176677,0.879467,-0.843823,0.186000,0.404308,-0.331484
3,-3.830448,0.031835,1.308001,2.101628,-0.484515,-0.095320,-0.206715,0.808234,0.210454,-0.079063,...,0.199250,0.091348,0.306556,0.114134,0.085322,0.039163,0.362099,-0.216753,0.067907,-0.138530
4,-2.248873,-0.231333,1.637718,1.685153,-0.175618,-0.363433,-0.877844,0.201343,0.476305,-0.037055,...,0.084889,0.336607,-0.261492,-0.188992,0.750385,-0.081467,0.087014,-0.205021,0.161054,0.134185
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,2.656971,1.065925,1.688354,0.109288,-0.562751,-0.528288,-0.269987,-0.190375,1.075467,0.433451,...,-0.559354,0.110457,0.117231,-0.406836,0.144725,-0.372739,-0.408536,-0.509177,-0.357776,0.364907
7019,2.547191,0.549891,0.247585,-0.458580,-0.235306,-0.913794,-0.242980,-0.429371,-0.440356,0.022951,...,0.474300,-0.310189,-0.347831,0.376716,0.641611,-0.760319,-0.267477,-0.496896,-0.132134,-0.384323
7020,2.438385,0.878835,1.034503,0.249474,-0.589289,-0.427737,-0.258931,-0.983647,0.138919,0.731666,...,-0.284936,-1.029313,0.106715,0.217382,-0.126448,-0.017247,0.306917,-0.000604,0.103954,0.173402
7021,2.499133,1.253243,0.025751,0.992776,0.589763,0.471235,-0.410792,0.301116,0.129271,0.344713,...,-0.124480,0.209611,0.424176,-0.256761,-0.019671,0.201820,0.301852,0.456726,-0.311787,-0.238059


In [42]:
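# Optional manual reset: uncomment to drop the target table before re-writing it.
# Usually unnecessary, because `to_sql(..., if_exists="replace")` already drops the table
# before inserting. `engine.begin()` commits on exit; a plain `engine.connect()` block
# would roll the DROP back under SQLAlchemy 2.x, which only commits explicitly.
# from sqlalchemy import text

# with engine.begin() as connection:
#     connection.execute(text("DROP TABLE IF EXISTS post_process_features_dl"))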

In [45]:
# Attach the PCA features to the post table. feature_* columns left over from a previous
# run of this cell are dropped first, so re-running it does not append a second copy.
assert len(post_data) == len(texts_pca), "PCA features must be row-aligned with post_data"
post_data = pd.concat(
    [post_data.drop(columns=texts_pca.columns, errors='ignore'), texts_pca],
    axis=1,
)
post_data

Unnamed: 0,post_id,text,topic,feature_0,feature_1,feature_2,feature_3,feature_4,feature_5,feature_6,...,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29
0,1,UK economy facing major risks\n\nThe UK manufa...,business,-0.954874,1.755627,-0.206941,-1.697690,2.381061,0.014774,0.139381,...,0.336253,-0.185460,0.194123,0.564932,0.525173,0.280055,-0.193786,-0.484033,0.806804,0.231849
1,2,Aids and climate top Davos agenda\n\nClimate c...,business,-3.082160,0.872352,1.121503,-0.695175,-0.051375,0.234762,0.309035,...,-1.030952,0.208925,0.433281,-0.035999,-0.746134,-0.434780,-0.217824,0.344667,-0.003610,-0.201203
2,3,Asian quake hits European shares\n\nShares in ...,business,-2.298752,0.771814,1.484302,-0.921823,-0.043765,0.577132,-0.102365,...,0.150932,0.150034,0.019595,-0.074723,-0.176677,0.879467,-0.843823,0.186000,0.404308,-0.331484
3,4,India power shares jump on debut\n\nShares in ...,business,-3.830448,0.031835,1.308001,2.101628,-0.484515,-0.095320,-0.206715,...,0.199250,0.091348,0.306556,0.114134,0.085322,0.039163,0.362099,-0.216753,0.067907,-0.138530
4,5,Lacroix label bought by US firm\n\nLuxury good...,business,-2.248873,-0.231333,1.637718,1.685153,-0.175618,-0.363433,-0.877844,...,0.084889,0.336607,-0.261492,-0.188992,0.750385,-0.081467,0.087014,-0.205021,0.161054,0.134185
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,7315,"OK, I would not normally watch a Farrelly brot...",movie,2.656971,1.065925,1.688354,0.109288,-0.562751,-0.528288,-0.269987,...,-0.559354,0.110457,0.117231,-0.406836,0.144725,-0.372739,-0.408536,-0.509177,-0.357776,0.364907
7019,7316,I give this movie 2 stars purely because of it...,movie,2.547191,0.549891,0.247585,-0.458580,-0.235306,-0.913794,-0.242980,...,0.474300,-0.310189,-0.347831,0.376716,0.641611,-0.760319,-0.267477,-0.496896,-0.132134,-0.384323
7020,7317,I cant believe this film was allowed to be mad...,movie,2.438385,0.878835,1.034503,0.249474,-0.589289,-0.427737,-0.258931,...,-0.284936,-1.029313,0.106715,0.217382,-0.126448,-0.017247,0.306917,-0.000604,0.103954,0.173402
7021,7318,The version I saw of this film was the Blockbu...,movie,2.499133,1.253243,0.025751,0.992776,0.589763,0.471235,-0.410792,...,-0.124480,0.209611,0.424176,-0.256761,-0.019671,0.201820,0.301852,0.456726,-0.311787,-0.238059


In [46]:
# Overwrite the feature table with the posts plus their PCA features.
# NOTE(review): index=True (the pandas default) also stores the unnamed RangeIndex as an
# extra "index" column next to post_id — confirm downstream readers expect it.
post_data.to_sql(
    name='post_process_features_dl',
    con=engine,
    if_exists="replace",
    index=True,
)

296