In [6]:
import os

import torch
import torch.nn as nn
import numpy as np

from warnings import filterwarnings
from sqlalchemy import create_engine

filterwarnings('ignore')

# Prefer a CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Database credentials must not be hard-coded in the notebook: they end up in version
# control and in every shared copy (the password previously stored here should be
# treated as leaked and rotated). Supply the full SQLAlchemy URL via the environment,
# e.g. postgresql://<user>:<password>@<host>:<port>/<database>.
db_url = os.environ.get('STARTML_DB_URL')
if not db_url:
    raise RuntimeError(
        'Set the STARTML_DB_URL environment variable to the PostgreSQL connection URL '
        '(postgresql://<user>:<password>@<host>:<port>/<database>).'
    )
engine = create_engine(db_url)

In [7]:
from transformers import AutoTokenizer
from transformers import BertModel
from transformers import RobertaModel
from transformers import DistilBertModel


def get_model(model_name):
    """Load a pretrained tokenizer and bare encoder for one supported model family.

    Args:
        model_name: one of ``'bert'``, ``'roberta'`` or ``'distilbert'``.

    Returns:
        ``(tokenizer, model)``: an ``AutoTokenizer`` and the matching encoder
        class, both loaded from the same Hugging Face checkpoint.

    Raises:
        AssertionError: if ``model_name`` is not one of the supported names.
    """
    assert model_name in ['bert', 'roberta', 'distilbert']

    # family name -> (checkpoint id, encoder class used to load it)
    registry = {
        'bert': ('bert-base-cased', BertModel),
        'roberta': ('roberta-base', RobertaModel),
        'distilbert': ('distilbert-base-cased', DistilBertModel),
    }
    checkpoint, encoder_cls = registry[model_name]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = encoder_cls.from_pretrained(checkpoint)
    return tokenizer, encoder

In [11]:
# Load the 'bert' tokenizer/encoder pair and place the encoder on the selected device.
tokenizer, model = get_model(model_name='bert')
model = model.to(device=device)

In [26]:
import pandas as pd

def load_post_data():
    """Read the whole ``post_text_df`` table into a DataFrame.

    Uses the module-level SQLAlchemy ``engine``. The connection is managed by a
    ``with`` block, so it is closed even when the query or the DataFrame
    construction raises.

    Returns:
        pandas.DataFrame containing every column of ``post_text_df``.
    """
    query = """
    SELECT *
    FROM post_text_df;
    """

    with engine.connect() as conn:
        # stream_results asks the driver for a streaming (server-side) cursor
        # where the dialect supports it.
        post_data = pd.read_sql(query, conn.execution_options(stream_results=True))

    return post_data

# Fetch the post table once; its 'text' column is what gets embedded below.
post_data = load_post_data()
texts = post_data.loc[:, 'text']


In [12]:
# Preview the raw post texts that will be embedded (pandas truncates the long display).
texts

0       UK economy facing major risks\n\nThe UK manufa...
1       Aids and climate top Davos agenda\n\nClimate c...
2       Asian quake hits European shares\n\nShares in ...
3       India power shares jump on debut\n\nShares in ...
4       Lacroix label bought by US firm\n\nLuxury good...
                              ...                        
7018    OK, I would not normally watch a Farrelly brot...
7019    I give this movie 2 stars purely because of it...
7020    I cant believe this film was allowed to be mad...
7021    The version I saw of this film was the Blockbu...
7022    Piece of subtle art. Maybe a masterpiece. Doub...
Name: text, Length: 7023, dtype: object

In [13]:
from torch.utils.data import Dataset


class PostDataset(Dataset):
    """Map-style dataset of pre-tokenized post texts.

    All texts are tokenized once, up front, with special tokens and truncation to
    the tokenizer's maximum length, but WITHOUT padding. Padding is left to the
    DataLoader's collate function (e.g. ``DataCollatorWithPadding``), so each batch
    is padded only to its own longest sequence rather than to the longest text in
    the whole corpus; this avoids running the encoder over large amounts of padding.

    Args:
        texts: list of raw strings to encode.
        tokenizer: a Hugging Face tokenizer exposing ``batch_encode_plus``.
    """

    def __init__(self, texts, tokenizer):
        super().__init__()

        # Ragged python lists (one list of token ids per text); no tensors yet,
        # since un-padded sequences cannot be stacked into one tensor.
        self.texts = tokenizer.batch_encode_plus(
            texts,
            add_special_tokens=True,
            return_token_type_ids=False,
            truncation=True,
            padding=False,
        )
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        """Return one example as 1-D ``input_ids`` / ``attention_mask`` tensors."""
        return {
            'input_ids': torch.tensor(self.texts['input_ids'][idx]),
            'attention_mask': torch.tensor(self.texts['attention_mask'][idx]),
        }

    def __len__(self):
        return len(self.texts['input_ids'])

In [18]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

dataset = PostDataset(texts.values.tolist(), tokenizer)

# Pads each batch to the longest sequence within that batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Page-locked (pinned) host memory only accelerates host -> CUDA copies; on MPS/CPU
# PyTorch skips pinning with a warning (hidden by the global warning filter), so
# request it only when running on CUDA.
loader = DataLoader(
    dataset,
    batch_size=32,
    collate_fn=data_collator,
    pin_memory=(device.type == 'cuda'),
    shuffle=False,  # keep row order aligned with post_data for the later concat
)

In [19]:
import torch
from tqdm import tqdm


@torch.inference_mode()
def get_embeddings_labels(model, loader, device=None):
    """Encode every batch from ``loader`` and collect the first-token hidden states.

    Despite the name, only embeddings are returned (no labels); the name is kept
    so existing callers keep working.

    Args:
        model: encoder whose call returns a mapping with a ``'last_hidden_state'``
            tensor. It is indexed as ``[:, 0, :]``, i.e. assumed to be
            (batch, seq_len, hidden) with the [CLS]-style token at position 0.
        loader: iterable of dict batches holding ``'input_ids'`` and
            ``'attention_mask'`` tensors (any other keys are ignored).
        device: where to move each batch. Defaults to the device of the model's
            parameters (CPU for a parameter-less model) instead of relying on a
            notebook-global ``device`` variable.

    Returns:
        CPU tensor with one row per example, rows in loader order.
    """
    model.eval()  # disable dropout so the embeddings are deterministic

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    total_embeddings = []
    for batch in tqdm(loader):
        inputs = {key: batch[key].to(device) for key in ('attention_mask', 'input_ids')}
        first_token_states = model(**inputs)['last_hidden_state'][:, 0, :]
        # Move results off the accelerator per batch so device memory does not grow.
        total_embeddings.append(first_token_states.cpu())

    return torch.cat(total_embeddings, dim=0)

In [20]:
# First-token embeddings for every post as a numpy array (rows follow loader order).
embeddings = get_embeddings_labels(model=model, loader=loader).numpy()

100%|██████████| 220/220 [08:07<00:00,  2.22s/it]


In [21]:
# Wrap the 2-D embedding matrix in a DataFrame (integer column labels) for PCA.
texts_df = pd.DataFrame(data=embeddings)
texts_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.140364,-0.140695,-0.575681,-0.118174,-0.315324,-0.114378,0.431977,-0.144286,0.003233,-1.187819,...,0.051768,0.614536,-0.613106,0.131062,0.202857,0.175931,-0.167768,-0.137880,0.042959,0.142283
1,0.157531,-0.097739,-0.230650,-0.364431,-0.242781,0.310066,0.374489,-0.089235,0.202153,-1.130713,...,0.586033,0.652154,-0.112307,-0.085171,-0.051814,0.240047,0.200299,-0.300891,0.190541,0.019753
2,0.314569,-0.115163,-0.181322,-0.274698,-0.357378,0.285229,0.266597,0.002641,-0.033905,-1.092078,...,0.416196,0.641697,-0.326961,-0.042803,-0.073850,0.212023,-0.090192,-0.354049,-0.204324,-0.027025
3,0.415117,-0.241300,-0.260733,-0.436027,-0.194695,0.130077,0.458805,-0.235223,-0.032935,-1.008515,...,0.791115,0.562938,-0.194189,0.022461,0.108905,0.019628,0.362090,-0.150884,-0.048834,0.083071
4,0.614585,-0.235812,-0.047732,-0.406702,-0.284798,0.124151,0.545586,-0.284447,0.047562,-1.139113,...,0.605897,0.518613,0.007371,0.032436,0.015634,0.055695,0.145705,-0.061323,-0.021120,0.121545
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,0.521686,0.292126,-0.210759,-0.219636,0.124439,-0.032428,0.055145,0.118430,-0.033625,-1.167994,...,0.575905,0.319888,-0.427476,0.161139,0.179985,0.053670,0.051578,-0.353713,-0.010725,-0.047067
7019,0.487033,0.442944,-0.251731,-0.303212,0.068587,-0.176148,0.235078,-0.050545,-0.011601,-1.161574,...,0.077759,0.174062,-0.359764,-0.196519,0.043920,0.178852,-0.054636,-0.233469,0.349204,-0.031998
7020,0.619477,0.274618,-0.126503,-0.110224,0.167653,-0.184623,0.244474,0.020286,0.058746,-1.046320,...,0.307723,0.187838,-0.378196,-0.223615,0.019644,0.250250,0.038904,-0.358137,-0.164277,0.171023
7021,0.694089,0.067175,-0.228681,-0.255529,0.134819,0.223566,0.249515,-0.032106,-0.160485,-1.041768,...,0.361884,0.352623,-0.307156,-0.139961,0.150053,0.015053,0.045804,0.037859,0.141017,0.124412


In [23]:
from sklearn.decomposition import PCA

# Reduce the embeddings to 30 principal components. For inputs of this size,
# scikit-learn's default svd_solver='auto' may select the randomized SVD solver,
# whose output depends on the RNG; pin random_state so the exported features are
# reproducible across runs.
pca = PCA(n_components=30, random_state=42)
texts_pca = pd.DataFrame(pca.fit_transform(texts_df))
texts_pca

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,-0.954863,1.755616,-0.206993,-1.697674,2.381038,0.014747,0.139377,-0.573150,-2.037540,0.976351,...,0.338530,-0.182809,0.190631,0.565924,0.521890,0.278507,-0.183996,-0.476529,0.777698,0.195513
1,-3.082158,0.872354,1.121489,-0.695172,-0.051389,0.234753,0.309024,-0.000409,-0.246721,-0.531308,...,-1.032854,0.208780,0.433039,-0.034591,-0.746456,-0.431069,-0.222845,0.334198,0.011763,-0.138270
2,-2.298751,0.771814,1.484301,-0.921830,-0.043765,0.577134,-0.102373,-0.678311,-1.163997,1.001126,...,0.149604,0.152361,0.019646,-0.075164,-0.181193,0.878006,-0.849582,0.181071,0.402798,-0.285991
3,-3.830445,0.031834,1.308007,2.101631,-0.484517,-0.095318,-0.206729,0.808245,0.210474,-0.079059,...,0.200726,0.088172,0.309463,0.112293,0.090700,0.034563,0.362013,-0.214260,0.076631,-0.149233
4,-2.248871,-0.231334,1.637724,1.685151,-0.175624,-0.363424,-0.877828,0.201336,0.476315,-0.037059,...,0.085179,0.336614,-0.259093,-0.190053,0.754564,-0.082125,0.089698,-0.211874,0.175428,0.156080
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7018,2.656969,1.065924,1.688353,0.109284,-0.562751,-0.528287,-0.269984,-0.190379,1.075467,0.433443,...,-0.559607,0.112844,0.114708,-0.406652,0.138423,-0.367053,-0.392969,-0.520288,-0.361686,0.389871
7019,2.547189,0.549891,0.247583,-0.458581,-0.235307,-0.913794,-0.242980,-0.429367,-0.440355,0.022959,...,0.473830,-0.309549,-0.350118,0.377452,0.640814,-0.760392,-0.267643,-0.487458,-0.144611,-0.388277
7020,2.438384,0.878835,1.034504,0.249472,-0.589289,-0.427737,-0.258931,-0.983647,0.138920,0.731668,...,-0.283128,-1.025396,0.102983,0.217687,-0.131774,-0.016691,0.304594,-0.000696,0.091741,0.177297
7021,2.499132,1.253244,0.025755,0.992777,0.589764,0.471234,-0.410792,0.301115,0.129272,0.344713,...,-0.123717,0.213075,0.419677,-0.255996,-0.023847,0.205595,0.304831,0.459587,-0.324650,-0.258726


In [27]:
# pd.concat(axis=1) aligns on the index: fail loudly on a mismatch instead of
# silently producing NaN rows in the exported table.
assert post_data.index.equals(texts_pca.index), (len(post_data), len(texts_pca))
# Drop PCA columns left over from a previous run of this cell so re-running it is
# idempotent rather than appending a second copy of the components.
post_data = pd.concat(
    [post_data.drop(columns=texts_pca.columns, errors='ignore'), texts_pca],
    axis=1,
)
post_data.to_sql('post_process_features_dl', con=engine, if_exists='replace')

296