In [54]:
# Project root relative to this notebook; the directories below are derived from it.
PATH = '..'
DATA_SOURCE = f'{PATH}/data/source'
DATA_PROCESSED = f'{PATH}/data/processed'  # holds df.pickle and receives the exported result files
DATA_MODELS = f'{PATH}/models'  # holds the serialized model checkpoint

SEED = 42  # NOTE(review): not applied to any RNG in the visible cells
DEVICE = 'cpu'  # torch device the model and every batch are moved to
MAX_LEN = 512  # forwarded to TextDataset as the tokenizer max_length

In [55]:
import pandas as pd
import os
from sklearn.model_selection import train_test_split

from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import os
import random

import torch
from torch import nn, Tensor
from typing import Iterable, Dict


from tqdm.notebook import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline, make_pipeline, make_union
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from torch.optim import AdamW
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score

from IPython.display import clear_output
import gc
from pathlib import Path

from transformers import AutoModelForSequenceClassification
from torch.nn import CrossEntropyLoss

from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import StratifiedKFold

import seaborn as sns

from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score

import warnings
warnings.filterwarnings("ignore", category=UserWarning) 


In [56]:
import torch
import torch.nn as nn
from transformers import BertModel

class MeanPooling(nn.Module):
    """Average token embeddings over the sequence axis, ignoring masked positions."""

    def __init__(self):
        super().__init__()

    def forward(self, last_hidden_state, attention_mask):
        """Return the attention-weighted mean of ``last_hidden_state`` along dim 1.

        Args:
            last_hidden_state: per-token embeddings; the mask is broadcast to this
                tensor's size after gaining a trailing axis.
            attention_mask: weights per token (1 = keep, 0 = ignore).

        Returns:
            ``last_hidden_state`` summed over dim 1 with masked tokens zeroed,
            divided by the number of unmasked tokens.
        """
        # Broadcast the mask across the feature axis so it can weight every component.
        weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        weighted_total = (last_hidden_state * weights).sum(1)
        # The clamp keeps the division finite when a row has no unmasked tokens.
        token_count = weights.sum(1).clamp(min=1e-9)
        return weighted_total / token_count

class ColBERT(nn.Module):
    """Relevance scorer for (query, document) pairs on top of a shared encoder.

    The first-token embeddings of the query and the document are multiplied
    element-wise; the product goes through a linear layer and a sigmoid. With
    ``add_pooling=True`` the mask-aware mean of the document token embeddings is
    concatenated to that product before the linear layer.

    Args:
        bert: encoder invoked as ``bert(**inputs)``; element ``[0]`` of its output
            is used as the per-token hidden states
            (assumed ``(batch, seq_len, dim)`` — TODO confirm against the encoder).
        dim: width of the encoder's hidden states.
        add_pooling: append the mean-pooled document embedding to the features.
    """

    def __init__(self, bert, dim=768, add_pooling=False):
        super(ColBERT, self).__init__()
        self.bert = bert
        # The head sees `dim` features, or `2 * dim` once the pooled vector is appended.
        self.fc = nn.Linear(dim * 2 if add_pooling else dim, 1)
        self.sigmoid = nn.Sigmoid()
        self.pooling = MeanPooling()
        self.add_pooling = add_pooling

    def forward(self, query, document):
        """Score each (query, document) pair of the batch.

        Args:
            query: keyword arguments for the encoder (e.g. ``input_ids``,
                ``attention_mask``) describing the queries.
            document: same, for the documents; ``document['attention_mask']`` is
                additionally read when ``add_pooling`` is enabled.

        Returns:
            Sigmoid scores with the trailing singleton axis removed, i.e. one
            value per pair. A batch of one stays 1-D instead of collapsing to a
            scalar.
        """
        query_hidden = self.bert(**query)[0]
        # Encode the document once and reuse the hidden states for both the
        # first-token embedding and the optional mean pooling.
        document_hidden = self.bert(**document)[0]
        interaction = torch.mul(query_hidden[:, 0, :], document_hidden[:, 0, :])

        if self.add_pooling:
            pooled = self.pooling(document_hidden, document['attention_mask'])
            features = torch.cat((interaction, pooled), -1)
        else:
            features = interaction

        scores = self.fc(features)
        # squeeze(-1) drops only the output-feature axis; a bare squeeze() would
        # also drop the batch axis when the batch holds a single pair.
        return self.sigmoid(scores).squeeze(-1)

In [57]:
class TextDataset(Dataset):
    """Pre-tokenized (query, document) text pairs with optional targets.

    Both text collections are tokenized eagerly in the constructor, each in a
    single tokenizer call with padding and truncation to ``max_length`` enabled,
    and kept as tensors. Items are produced by indexing those tensors, so ``idx``
    may be an integer (one pair) or a slice (a ready-made batch).

    Args:
        tokenizer: callable tokenizer returning a mapping with ``input_ids`` and
            ``attention_mask`` tensors.
        query: iterable of query texts; its length defines the dataset length.
        document: iterable of document texts, aligned with ``query``.
        relevance: optional per-pair targets; defaults to a float64 tensor of ones.
        cluster: optional per-pair values; defaults to a float64 tensor of ones.
            Stored on the instance but not part of the returned items.
        max_length: truncation limit passed to the tokenizer.
    """

    def __init__(self, tokenizer, query, document, relevance=None, cluster=None, max_length=128):
        self.len = len(query)
        self.query = self._encode(tokenizer, query, max_length)
        self.document = self._encode(tokenizer, document, max_length)
        self.relevance = self._values_or_ones(relevance, self.len)
        self.cluster = self._values_or_ones(cluster, self.len)

    @staticmethod
    def _encode(tokenizer, texts, max_length):
        """Tokenize ``texts`` in one call and return the tokenizer's tensor mapping."""
        # Calling the tokenizer directly replaces the deprecated batch_encode_plus.
        return tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        )

    @staticmethod
    def _values_or_ones(values, size):
        """Return ``values`` as a tensor, or ``size`` ones when no values are given."""
        if values is None:
            # float64 matches what converting a NumPy ones-array would yield.
            return torch.ones(size, dtype=torch.float64)
        return torch.tensor(list(values))

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return the ``query``/``document`` encoder inputs and ``relevance`` at ``idx``."""

        def pick(encoded):
            # Only the two fields the model consumes are exposed.
            return dict(
                input_ids=encoded['input_ids'][idx],
                attention_mask=encoded['attention_mask'][idx],
            )

        return dict(
            query=pick(self.query),
            document=pick(self.document),
            relevance=self.relevance[idx],
        )

In [58]:
def batch_generator(sentences, batch_size):
    """Yield consecutive slices of ``sentences`` holding at most ``batch_size`` items.

    ``sentences`` only needs ``len()`` and slice indexing; the final slice is
    shorter when the length is not a multiple of ``batch_size``.
    """
    total = len(sentences)
    for lower in range(0, total, batch_size):
        upper = lower + batch_size
        if upper > total:
            # Never ask for a slice that runs past the end.
            upper = total
        yield sentences[lower:upper]

In [69]:
tokenizer = AutoTokenizer.from_pretrained('cointegrated/rubert-tiny2')
# NOTE(security): torch.load unpickles arbitrary objects — only load checkpoints you trust.
# NOTE(review): the file presumably holds a whole pickled module; recent PyTorch
# releases default to weights_only=True and may need weights_only=False here — confirm.
# map_location lets a checkpoint saved on another device open on DEVICE.
model = torch.load(f'{DATA_MODELS}/rubert-tiny2_1.pt', map_location=DEVICE)
model.to(DEVICE)
# torch.no_grad() does not switch off dropout; eval mode does.
model.eval();

In [70]:
# Keep only the rows marked as the test split and renumber them from zero.
df = (
    pd.read_pickle(f'{DATA_PROCESSED}/df.pickle')
    .query('split == "test"')
    .reset_index(drop=True)
)

# Vacancy descriptions act as queries, resume descriptions as documents.
ds = TextDataset(
    tokenizer,
    df['description_sorted_vacancy'],
    df['description_sorted_resume'],
    max_length=MAX_LEN,
)

In [None]:
# Number of (vacancy, resume) pairs scored per forward pass.
BATCH_SIZE = 10

with torch.no_grad():
    batch_scores = []
    for batch in batch_generator(ds, BATCH_SIZE):
        scores = model(
            {k: v.to(DEVICE) for k, v in batch['query'].items()},
            {k: v.to(DEVICE) for k, v in batch['document'].items()}
        )
        # reshape(-1) keeps every chunk 1-D. squeeze() would turn a one-row final
        # batch into a 0-d array, which np.concatenate refuses to join.
        batch_scores.append(scores.cpu().numpy().reshape(-1))
    df['predict_proba'] = np.concatenate(batch_scores)

In [None]:
clms = ['uuid_vacancy', 'uuid_resume', 'rank', 'label']
# Score above which a pair is labelled 1.
LABEL_THRESHOLD = 0.45

# Position of each resume within its vacancy, best score first; method='first'
# breaks score ties by row order so ranks inside a vacancy are unique.
df['rank'] = df.groupby('uuid_vacancy')['predict_proba'].rank(method='first', ascending=False).astype(int)
# Rank stays the primary sort key; the vacancy id only fixes the otherwise
# unspecified order of equal ranks coming from different vacancies.
df = df.sort_values(['rank', 'uuid_vacancy'])
df['label'] = (df['predict_proba'] > LABEL_THRESHOLD).astype(int)

In [67]:
# Share of rows whose binary `label` is 1.
df['label'].mean()

0.7522123893805309

In [68]:
# Preview of the columns selected by `clms`.
df[clms]

Unnamed: 0,uuid_vacancy,uuid_resume,rank,label
37,8b9c8d16-c7f0-38a2-b80c-d94030c15a6f,ebcd86ef-6e1f-39cf-8af3-85adaec6d3b3,1,1
29,8b9c8d16-c7f0-38a2-b80c-d94030c15a6f,2444cdfb-370c-3f84-b97b-9462255688f2,2,1
62,8b9c8d16-c7f0-38a2-b80c-d94030c15a6f,ff2a61d1-b70b-352b-8f08-ebc0a84de7ca,3,1
20,8b9c8d16-c7f0-38a2-b80c-d94030c15a6f,73592479-12bf-38d4-84f0-91fe33518b47,4,1
73,8b9c8d16-c7f0-38a2-b80c-d94030c15a6f,fd0ccbd0-3a58-3818-8691-98f31de17527,5,1
...,...,...,...,...
64,8b9c8d16-c7f0-38a2-b80c-d94030c15a6f,fd9c4130-177f-3546-8974-189a52fcc496,109,0
92,8b9c8d16-c7f0-38a2-b80c-d94030c15a6f,6065ee07-4c42-3cf4-8e45-9d053b9041f7,110,0
33,8b9c8d16-c7f0-38a2-b80c-d94030c15a6f,53a2a38b-4fef-3d65-8d6d-818d5e80dbfa,111,0
104,8b9c8d16-c7f0-38a2-b80c-d94030c15a6f,e59d1c07-489b-3299-803f-5dea7da43b56,112,0


In [65]:
# Write the selected columns once as a spreadsheet and once as CSV, without the index.
export_frame = df[clms]
export_frame.to_excel(f'{DATA_PROCESSED}/result.xlsx', index=False)
export_frame.to_csv(f'{DATA_PROCESSED}/result.csv', index=False)