In [1]:
from dotenv import load_dotenv

load_dotenv("../.env.prod")
import os
from utils.chunker import chunker, chunk_single_text
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import pandas as pd

import json
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the processed training split and list the columns it provides.
train_path = "../files/processed/final_datasets/train.parquet"
ds = pd.read_parquet(train_path)
ds.columns

Index(['candidate_id', 'vacant_id', 't_apply', 'stage_max', 'publish_date',
       'label', 'vacant_full_text', 'vacant_city_ids', 'candidate_full_text',
       'candidate_city_id', 'candidate_fourier_features',
       'no_valid_vacant_city_ids', 'selected_city_id', 'selected_distance',
       'exact_match', 'vacant_fourier_feature', 'job_chunks_input_ids',
       'job_chunks_attention_mask', 'cand_chunks_input_ids',
       'cand_chunks_attention_mask'],
      dtype='object')

In [3]:
def _to_signed(label_value):
    """Return -1 for a label equal to 0; return any other value unchanged."""
    if label_value == 0:
        return -1
    return label_value


# Signed copy of the label column, stored alongside the original.
ds["label_2"] = ds["label"].apply(_to_signed)


In [4]:
# Re-list the columns of ds now that 'label_2' has been added to it.
ds.columns

Index(['candidate_id', 'vacant_id', 't_apply', 'stage_max', 'publish_date',
       'label', 'vacant_full_text', 'vacant_city_ids', 'candidate_full_text',
       'candidate_city_id', 'candidate_fourier_features',
       'no_valid_vacant_city_ids', 'selected_city_id', 'selected_distance',
       'exact_match', 'vacant_fourier_feature', 'job_chunks_input_ids',
       'job_chunks_attention_mask', 'cand_chunks_input_ids',
       'cand_chunks_attention_mask', 'label_2'],
      dtype='object')

In [5]:
# Total of the signed labels (label_2) for every vacant_id.
sums = ds["label_2"].groupby(ds["vacant_id"]).sum()


In [6]:
# Build a table with one row per vacant_id and one count column per label.
# Step 1: number of rows for every (vacant_id, label) combination.
rows_per_label = ds.groupby(['vacant_id', 'label']).size()

# Step 2: pivot the label level into columns; combinations that never occur
# are filled with 0.
label_table = rows_per_label.unstack(fill_value=0)

# Step 3: give the 0 / 1 label columns readable names and turn vacant_id back
# into a regular column.
label_table = label_table.rename(columns={0: 'neg_candidate', 1: 'pos_candidate'})
cnt = label_table.reset_index()

print(cnt.head())


label  vacant_id  neg_candidate  pos_candidate
0          57283              3              0
1          92879             20              0
2         136132             40              0
3         138383              1              0
4         144608              2              0


In [7]:
# Keep only the vacancies that have at least one positive and at least one
# negative candidate.
has_positive = cnt['pos_candidate'].gt(0)
has_negative = cnt['neg_candidate'].gt(0)
cnt_valid = cnt.loc[has_positive & has_negative]

# Column totals over the retained vacancies. The vacant_id entry is a sum of
# identifiers; the two *_candidate entries are the candidate counts.
cnt_valid.sum()

label
vacant_id        8516840831
neg_candidate        521200
pos_candidate        271782
dtype: int64

In [9]:
# `sums` already holds one signed-label total per vacant_id (it is the result
# of a groupby on vacant_id), so grouping it by vacant_id and summing again
# returns the same Series. Display it directly instead.
sums

vacant_id
57283     -3
92879    -20
136132   -40
138383    -1
144608    -2
          ..
420718    -1
420728     1
420730    -1
420755     1
420759     1
Name: label_2, Length: 67936, dtype: int64

In [10]:
import numpy as np
import pandas as pd

# Build (positive, negative) candidate pairs within each vacancy.
#
# Start from ds and keep only 0/1 labels just in case. Boolean indexing
# already produces a new frame, so one .copy() is enough; ds is left untouched.
df = ds[ds['label'].isin([0, 1])].copy()

# Every column except the two identifiers is carried into the pair table
# twice: once prefixed 'pos_' and once prefixed 'neg_'. The list is the same
# for every vacancy, so it is computed once, outside the loop.
other_cols = [c for c in df.columns if c not in ['candidate_id', 'vacant_id']]

pairs = []

for vac_id, g in df.groupby('vacant_id'):
    pos = g[g['label'] == 1].sort_values('t_apply')
    neg = g[g['label'] == 0].sort_values('t_apply')

    # skip vacancies with no pos or no neg
    if pos.empty or neg.empty:
        continue

    n_pos = len(pos)
    n_neg = len(neg)

    # One pair per positive row. Positive i (in t_apply order) is matched with
    # negative i % n_neg (also in t_apply order): negatives are reused
    # cyclically when there are fewer of them than positives, and negatives
    # beyond the first n_pos are not used when there are more.
    neg_idx = np.arange(n_pos) % n_neg
    neg_rep = neg.iloc[neg_idx]

    # Collect all columns first and build the frame once. Key order defines
    # the column order: identifiers, then pos_/neg_ interleaved per column.
    pair_data = {
        'vacant_id': vac_id,
        'pos_candidate_id': pos['candidate_id'].values,
        'neg_candidate_id': neg_rep['candidate_id'].values,
    }
    for col in other_cols:
        pair_data['pos_' + col] = pos[col].values
        pair_data['neg_' + col] = neg_rep[col].values

    pair = pd.DataFrame(pair_data)
    pairs.append(pair)

if pairs:
    pairs_df = pd.concat(pairs, ignore_index=True)
else:
    # pd.concat raises ValueError on an empty list. When no vacancy has both
    # a positive and a negative row, produce an empty frame with the same
    # column layout instead of failing.
    pair_columns = ['vacant_id', 'pos_candidate_id', 'neg_candidate_id']
    for col in other_cols:
        pair_columns.extend(['pos_' + col, 'neg_' + col])
    pairs_df = pd.DataFrame(columns=pair_columns)


In [11]:
# Persist the paired dataset. The target directory is created first so the
# write does not fail when the folder does not exist yet.
paired_train_path = "../files/processed/paired_datasets/train.parquet"
os.makedirs(os.path.dirname(paired_train_path), exist_ok=True)
pairs_df.to_parquet(paired_train_path)

In [14]:
# Preview the identifier columns next to the chunk input-id / attention-mask
# columns of the pair table.
preview_cols = [
    'vacant_id',
    'pos_candidate_id',
    'neg_candidate_id',
    'pos_cand_chunks_input_ids',
    'neg_cand_chunks_input_ids',
    'pos_cand_chunks_attention_mask',
    'neg_cand_chunks_attention_mask',
    'pos_job_chunks_input_ids',
    'neg_job_chunks_input_ids',
]
pairs_df[preview_cols]

Unnamed: 0,vacant_id,pos_candidate_id,neg_candidate_id,pos_cand_chunks_input_ids,neg_cand_chunks_input_ids,pos_cand_chunks_attention_mask,neg_cand_chunks_attention_mask,pos_job_chunks_input_ids,neg_job_chunks_input_ids
0,156860,5429402,5582321,"[[0, 8919, 2282, 3634, 21201, 2084, 4376, 1622...","[[0, 25180, 11117, 18174, 2854, 9534, 6899, 66...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[0, 1071, 4899, 5852, 2084, 4301, 2894, 7032,...","[[0, 1071, 4899, 5852, 2084, 4301, 2894, 7032,..."
1,156860,6934742,6815063,"[[0, 8919, 2282, 8717, 2084, 4454, 2393, 4376,...","[[0, 16220, 16647, 7860, 2103, 1065, 2526, 855...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[0, 1071, 4899, 5852, 2084, 4301, 2894, 7032,...","[[0, 1071, 4899, 5852, 2084, 4301, 2894, 7032,..."
2,187799,2719424,7497666,"[[0, 17820, 25782, 18174, 9534, 9831, 3531, 43...","[[0, 7574, 2725, 1014, 25180, 13727, 29612, 13...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[0, 9037, 21868, 18706, 14266, 2116, 2067, 21...","[[0, 9037, 21868, 18706, 14266, 2116, 2067, 21..."
3,187799,2434828,3963750,"[[0, 8919, 2282, 8717, 2084, 4376, 2008, 2233,...","[[0, 25180, 8919, 2282, 3634, 21201, 2054, 437...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[0, 9037, 21868, 18706, 14266, 2116, 2067, 21...","[[0, 9037, 21868, 18706, 14266, 2116, 2067, 21..."
4,187799,6409856,24464,"[[0, 2937, 18517, 2084, 1045, 10412, 2084, 587...","[[0, 2008, 27874, 10115, 4376, 3584, 28998, 51...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[0, 9037, 21868, 18706, 14266, 2116, 2067, 21...","[[0, 9037, 21868, 18706, 14266, 2116, 2067, 21..."
...,...,...,...,...,...,...,...,...,...
271777,420679,188031,5685341,"[[0, 17820, 25782, 18174, 9534, 23717, 6636, 4...","[[0, 25180, 1045, 17607, 3740, 3531, 2143, 137...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[0, 28620, 3262, 3976, 21000, 2144, 1028, 201...","[[0, 28620, 3262, 3976, 21000, 2144, 1028, 201..."
271778,420679,4458792,5652823,"[[0, 17820, 25782, 22848, 5816, 21339, 10450, ...","[[0, 11712, 2054, 4017, 9052, 2867, 18503, 246...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[0, 28620, 3262, 3976, 21000, 2144, 1028, 201...","[[0, 28620, 3262, 3976, 21000, 2144, 1028, 201..."
271779,420697,6251487,1057116,"[[0, 8919, 2282, 3634, 21201, 2084, 4376, 2030...","[[0, 11272, 2233, 19305, 5816, 21339, 3531, 10...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[0, 7865, 28998, 2054, 2132, 15553, 7873, 317...","[[0, 7865, 28998, 2054, 2132, 15553, 7873, 317..."
271780,420759,3693274,8070994,"[[0, 6743, 2084, 4376, 5712, 2054, 1065, 16095...","[[0, 25180, 8919, 2282, 12786, 7371, 27394, 27...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...","[[0, 2065, 15534, 6587, 2550, 6587, 2550, 1014...","[[0, 2065, 15534, 6587, 2550, 6587, 2550, 1014..."
