In [22]:
import os
import json
import pickle
import random
import time
import random
from contextlib import contextmanager
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from pprint import pprint

from sklearn.feature_extraction.text import TfidfVectorizer

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from datasets import load_dataset, Dataset, concatenate_datasets, load_from_disk
from transformers import (
    AutoTokenizer,
    BertModel, BertPreTrainedModel,
    AdamW, get_linear_schedule_with_warmup,
    TrainingArguments,
)

from sentence_transformers import SentenceTransformer, InputExample, losses, util

In [5]:
datasets = load_from_disk("../data/train_dataset")
datasets['train'] = datasets['train'].remove_columns(['document_id','__index_level_0__'])
datasets['validation'] = datasets['validation'].remove_columns(['document_id','__index_level_0__'])

concat_dataset = concatenate_datasets([datasets['train'], datasets['validation']])

In [59]:
model = SentenceTransformer('jhgan/ko-sroberta-multitask', cache_folder='/data/ephemeral/senttran')

In [7]:
positive_examples = [
    InputExample(texts=[q, c], label=1.0) for q, c in zip(concat_dataset['question'], concat_dataset['context'])
]

In [13]:
from dpr import STNegativeSampling

negative_sampler = STNegativeSampling()
negative_samples = negative_sampler.get_negative_samples(queries=concat_dataset['question'], num_samples=2)

negative_examples = []
for q, c in zip(concat_dataset['question'], negative_samples):
    negative_examples.append(InputExample(texts=[q, c[0]], label=0.0))
    negative_examples.append(InputExample(texts=[q, c[1]], label=0.0))

ST Embedding pickle load.


In [15]:
train_examples = positive_examples + negative_examples

In [16]:
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

In [17]:
train_loss = losses.CosineSimilarityLoss(model)

In [18]:
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    warmup_steps=100,
)

Iteration: 100%|██████████| 786/786 [02:36<00:00,  5.01it/s]
Epoch: 100%|██████████| 1/1 [02:36<00:00, 156.77s/it]


In [60]:
eval_datasets = load_dataset('squad_kor_v1')['validation']

In [61]:
q_emb = model.encode(eval_datasets['question'])
c_emb = model.encode(eval_datasets['context'])

In [62]:
scores = util.pytorch_cos_sim(q_emb, c_emb).numpy()

In [64]:
k = 25
topk = np.argsort(scores, axis=1)[:,::-1][:,:k]

correct = 0
for i, score in enumerate(topk):
    correct += int(i in score)

print('Score:', correct/len(topk))

Score: 0.8112227225493592


In [68]:
scores.shape

(5774, 5774)

In [80]:
np.sort(scores, axis=1)[:,::-1][:,:k].shape

(5774, 25)

In [72]:
topk.tolist()

[[6,
  4,
  5,
  1,
  0,
  2,
  3,
  1154,
  1153,
  1150,
  1152,
  1155,
  1156,
  1151,
  5003,
  5007,
  5006,
  5002,
  5005,
  5008,
  5004,
  5046,
  5041,
  5042,
  5043],
 [1,
  3,
  4,
  5,
  0,
  2,
  6,
  5110,
  5113,
  5109,
  5112,
  5108,
  5111,
  5107,
  874,
  878,
  875,
  876,
  877,
  4953,
  4954,
  4952,
  4948,
  4955,
  4951],
 [3,
  2,
  5,
  1,
  4,
  6,
  0,
  1156,
  1150,
  1153,
  1152,
  1151,
  1154,
  1155,
  5043,
  5041,
  5044,
  5048,
  5042,
  5047,
  5046,
  5045,
  2544,
  2543,
  2546],
 [4,
  6,
  2,
  3,
  5,
  1,
  0,
  782,
  781,
  778,
  779,
  777,
  783,
  776,
  780,
  1820,
  1819,
  1823,
  1821,
  1818,
  1822,
  4692,
  4697,
  4693,
  4695],
 [6,
  5,
  4,
  0,
  1,
  2,
  3,
  5064,
  5066,
  5065,
  5068,
  5063,
  5067,
  5069,
  2545,
  2543,
  2544,
  2546,
  2987,
  2989,
  2984,
  2990,
  2985,
  2988,
  2986],
 [3,
  5,
  4,
  2,
  6,
  1,
  0,
  5007,
  5002,
  5004,
  5006,
  5003,
  5008,
  5005,
  5044,
  5046,
  5045

In [67]:
scores[topk].shape

(5774, 25, 5774)

In [42]:
scores[0,5074]

0.90483755

In [40]:
scores.max(axis=1)

array([0.90483755, 0.845703  , 0.91225696, ..., 0.90231717, 0.83082664,
       0.8739666 ], dtype=float32)