In [1]:
import os
import re
import logging
from tqdm.auto import tqdm

import pandas as pd

## functions

In [2]:
def get_article_path(title):
#     fp = filemap.query('filename == @title.lower().strip()')['relpath'].iloc[0]
    fp = filemap.query('filename == @title.lower().strip()')['path'].iloc[0]
#     fp = os.path.join(articles_dp, fp)
    return fp

def get_article_text(fp):
    with open(fp) as fin: text = fin.read()
    return text

In [3]:
def process_title(title: pd.Series):
    title = title.str.lower().str.strip()
    title = title.str.replace(r'([^\w\s])|_|-', '_', regex=True)
    return title

In [4]:
def two_sets_stats(arr1, arr2):
    print(f'init shapes: {len(arr1), len(arr2)}')
    s1 = set(arr1)
    s2 = set(arr2)
    print(f'unique elements: {len(s1), len(s2)}')
    print(f's1 & s2: {len(s1 & s2)}')
    print(f's1 ^ s2: {len(s1 ^ s2)}')
    print(f's1 - s2: {len(s1 - s2)}')
    print(f's2 - s1: {len(s2 - s1)}')

## data

In [5]:
data_root_dp = '/media/rtn/Windows 10/work/univier/wiki_extract/wiki_parsed'
# articles_dp = os.path.join(data_root_dp, 'articles')
# filemap = pd.read_csv(os.path.join(data_root_dp, 'filemap.csv'))
articles_dp = data_root_dp
filemap = pd.read_csv(os.path.join(data_root_dp, 'filepaths.csv'))
print(filemap.shape)

(223619, 3)


In [6]:
get_article_text(get_article_path('poetry'))

'Poetry is a type of art form and a type of literature. Poetry uses the qualities of words, in different ways, to be artistic. Poetry can be as short as a few words, or as long as a book. A poem as short as one line is called a monostich. A poem that is as long as a book is an epic. There are many "poetic forms" (forms of poetry). Some of forms are: Sonnet, Haiku, Ballad, Stev, Prose poem, Ode, Free verse, Blank verse, thematic, limerick and nursery rhymes. Poetry can be used to describe (comparing, talking about, or expressing emotion) many things. It can make sense or be nonsense, it can rhyme or not. It can have many shapes and sizes; it can be serious or funny. "To say something poetically" means to give information in an artistic way. A more modern approach is digital poetry. Computers and webtechnology is used to express poetry and make it interactive. So called interdisciplinary poetry (wich means combination of different forms of poetry) are made possible by linking the poetic 

In [7]:
labels_test = pd.read_csv('data/queries.tsv', sep='\t', header=None)
labels_test.columns = ['query', 'title']
labels_test['title'] = process_title(labels_test['title'])

print(labels_test.shape)
labels_test.head(3)

(200, 2)


Unnamed: 0,query,title
0,animals that have shells and live in water,shell__zoology_
1,how many different types of scorpions are there,scorpion
2,describe the structure of a scientific name fo...,binomial_nomenclature


In [8]:
labels_train_raw = pd.read_csv('data/train.tsv', sep='\t', header=None)
labels_train_raw.columns = ['query', 'title']
labels_train_raw['title'] = process_title(labels_train_raw['title'])

print(labels_train_raw.shape)
labels_train_raw.head(3)

(50000, 2)


Unnamed: 0,query,title
0,where does the most metabolic activity in the ...,cytoplasm
1,what kind of dog played in turner and hooch,dogue_de_bordeaux
2,when is there gonna be an eclipse 2017,solar_eclipse_of_august_21__2017


In [9]:
filemap['article_id'] = list(range(filemap.shape[0]))

In [10]:
title2id = filemap.set_index('filename')['article_id'].to_dict()
filemap.iloc[title2id['belarus']]

filename                                                belarus
path          /media/rtn/Windows 10/work/univier/wiki_extrac...
html_path     /media/rtn/Windows 10/work/univier/wiki_extrac...
article_id                                                25864
Name: 25864, dtype: object

### filter train, test samples

In [11]:
two_sets_stats(filemap['filename'], labels_train_raw['title'])

init shapes: (223619, 50000)
unique elements: (223619, 16731)
s1 & s2: 15430
s1 ^ s2: 209490
s1 - s2: 208189
s2 - s1: 1301


In [12]:
two_sets_stats(filemap['filename'], labels_test['title'])

init shapes: (223619, 200)
unique elements: (223619, 142)
s1 & s2: 137
s1 ^ s2: 223487
s1 - s2: 223482
s2 - s1: 5


In [13]:
labels_train_raw = labels_train_raw[labels_train_raw['title'].isin(filemap['filename'])]
labels_test = labels_test[labels_test['title'].isin(filemap['filename'])]
print(labels_train_raw.shape[0])
print(labels_test.shape[0])

45260
194


## create dataset

In [14]:
import random
import math

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'using device: {device}')

from torch.utils.data import Dataset, DataLoader

import sentence_transformers as st
import sentence_transformers.losses
from sentence_transformers.datasets import NoDuplicatesDataLoader
from sentence_transformers.readers import InputExample

from sklearn.model_selection import train_test_split

from typing import List, Union

using device: cuda


In [15]:
labels_train, labels_val = train_test_split(labels_train_raw, random_state=12, test_size=0.15, shuffle=True)
print(labels_train.shape)
print(labels_val.shape)

(38471, 2)
(6789, 2)


In [16]:
labels_train.head()

Unnamed: 0,query,title
38886,a sea route through the arctic ocean in canada...,northwest_passage
15973,who plays angie on the george lopez show,constance_marie
46486,who plays queen cersei on game of thrones,lena_headey
7715,who plays the dolphin in the spongebob movie,the_spongebob_movie__sponge_out_of_water
39805,the mask of the red death by edgar allan poe,the_masque_of_the_red_death


In [17]:
class WikiQAInputExample(InputExample):
    def __init__(self, guid: str = '', texts: List[str] = None,  label: Union[int, float] = 0, article_id: int = None):
        self.article_id = article_id
        super().__init__(guid, texts, label)

In [18]:
class WikiQADataset(Dataset):
    def __init__(self, examples: pd.DataFrame, title2id):
        self.examples = examples
        self.title2id = title2id
        assert isinstance(examples, pd.DataFrame)
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, ix):
        assert 0 <= ix < len(self)
        query, article_title = self.examples.iloc[ix][['query', 'title']]
        article = get_article_text(get_article_path(article_title))
        example = WikiQAInputExample(texts=[query, article], label=1, article_id=self.title2id[article_title])
        return example

In [19]:
class CustomNoDuplicatesDataLoader:

    def __init__(self, dataset: WikiQADataset, batch_size):
        """
        A special data loader to be used with MultipleNegativesRankingLoss.
        The data loader ensures that there are no duplicate sentences within the same batch
        """
        self.batch_size = batch_size
        self.collate_fn = None
        self.dataset = dataset
        self.data_pointer = 0
        self.index_order = list(range(len(dataset)))
        random.shuffle(self.index_order)  # shuffle index inplace before the first iteration

    def __iter__(self):
        for _ in range(self.__len__()):
            batch = []
            texts_in_batch = set()

            while len(batch) < self.batch_size:
                ix = self.index_order[self.data_pointer]
                example = self.dataset[ix]

                valid_example = True
                for text in example.texts:
                    if text.strip().lower() in texts_in_batch:
                        valid_example = False
                        break

                if valid_example:
                    batch.append(example)
                    for text in example.texts:
                        texts_in_batch.add(text.strip().lower())

                self.data_pointer += 1
                if self.data_pointer >= len(self.dataset):
                    self.data_pointer = 0
                    random.shuffle(self.index_order)  # reshuffle index order

            yield self.collate_fn(batch) if self.collate_fn is not None else batch

    def __len__(self):
        return math.floor(len(self.dataset) / self.batch_size)

In [21]:
train_dataset = WikiQADataset(labels_train, title2id=title2id)
# val_dataset = WikiQADataset(labels_val, title2id=title2id)
# print(f'train, val lenghts: {len(train_dataset), len(val_dataset)}')

In [22]:
train_dataset[3].texts

['who plays the dolphin in the spongebob movie',
 'The SpongeBob Movie: Sponge Out of Water is a 2015 3D animated/live action superhero comedy movie based on Nickelodeon\'s SpongeBob SquarePants television series. It is preceded by The SpongeBob SquarePants Movie. It was produced by Nickelodeon Movies and Paramount Animation, and was distributed by Paramount Pictures. It was released theatrically on February 6, 2015. The movie released February 6, 2015 on HD Digital. It was nominated in the 2015 Kids\' Choice Awards for "Favorite Animated Movie", but lost to Disney\'s 54th feature film Big Hero 6. A diabolical pirate named Burger Beard above the sea steals the secret Krabby Patty formula. SpongeBob, Patrick, Mr. Krabs, Sandy and Squidward must team up with Plankton in order to get it back unbounce. Tom Kenny as SpongeBob SquarePants/Invincibubble and Gary the Snail Bill Fagerbakke as Patrick Star/Mr. Superawesomeness Mr. Lawrence as Plankton/Plank-Ton Clancy Brown as Mr. Krabs/Sir Pinc

In [24]:
# TODO: has max_seq_length: 128
model = st.SentenceTransformer('distilbert-base-nli-mean-tokens')
model.to(device);

In [26]:
train_loss = st.losses.MultipleNegativesRankingLoss(model=model)

In [29]:
train_batch_size = 32
num_epochs = 2
warmup_steps = math.ceil(len(train_dl) * num_epochs * 0.1)
model_save_path = '/media/rtn/data/fajly2/checkpoints/bert'

train_dl = CustomNoDuplicatesDataLoader(train_dataset, batch_size=train_batch_size)

print(f'train_batch_size: {train_batch_size}')
print(f"num_epochs: {num_epochs}")
print(f"Warmup-steps: {warmup_steps}")

train_batch_size: 32
num_epochs: 2
Warmup-steps: 962


In [30]:
model.fit(
    train_objectives=[(train_dl, train_loss)],
#     evaluator=dev_evaluator,
    epochs=num_epochs,
#     evaluation_steps=int(len(train_dl)*0.1),
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=True
)

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1202 [00:00<?, ?it/s]

KeyboardInterrupt: 