In [1]:
import csv
from charity import Charity

tsv_path = 'data/wikipedia.tsv'

with open(tsv_path, 'r') as tsv_file:
    reader = csv.DictReader(tsv_file, dialect='excel-tab')
    charities = [
        Charity(**row)
        for row in reader
    ]

In [2]:
documents = [charity.description for charity in charities]

In [11]:
import torch

query = torch.arange(3 * 2).view(3, 2)
positive_examples = torch.arange(3 * 2 * 2).view(3, 2, 2)
negative_examples = torch.arange(3 * 4 * 2).view(3, 4, 2)

In [44]:
torch.cat([query.repeat(1, 4).view(3, 4, 2), negative_examples], 2).reshape(-1, 2 * 2)

tensor([[ 0,  1,  0,  1],
        [ 0,  1,  2,  3],
        [ 0,  1,  4,  5],
        [ 0,  1,  6,  7],
        [ 2,  3,  8,  9],
        [ 2,  3, 10, 11],
        [ 2,  3, 12, 13],
        [ 2,  3, 14, 15],
        [ 4,  5, 16, 17],
        [ 4,  5, 18, 19],
        [ 4,  5, 20, 21],
        [ 4,  5, 22, 23]])

In [45]:
import re

from nltk.tokenize import sent_tokenize
from bert import get_embeddings

embeddings_index = []
embeddings_list = []

for doc_index, doc in enumerate(documents):
    print(doc_index)
    sentences = [
        re.sub('\[\d+\]', '', sent).strip().lower()
        for sent in sent_tokenize(doc)
    ]
    sentence_embeddings = get_embeddings(sentences)
    
    embeddings_index += ([doc_index] * len(sentence_embeddings))
    embeddings_list.append(sentence_embeddings)

0
1


KeyboardInterrupt: 

In [64]:
import torch

embeddings = torch.cat(embeddings_list)

In [80]:
import random

import numpy as np

def sample_indices(index, n_positive, n_negative, n_examples):
    all_query_examples = []
    all_positive_examples = []
    all_negative_examples = []
    
    unique_elements = np.unique(index)
    index_indices = torch.arange(len(index))
    
    for _ in range(n_examples):
        elem = random.choice(unique_elements)
        
        elem_indices = index_indices[index == elem]
        query_index = random.choice(elem_indices)
        
        same_document_indices = index_indices[(index == elem)]
        different_document_indices = index_indices[index != elem]
        
        positive_examples = same_document_indices[random.sample(range(0, len(same_document_indices)), n_positive)]
        negative_examples = different_document_indices[random.sample(range(0, len(different_document_indices)), n_negative)]

        all_query_examples.append(query_index)
        all_positive_examples.append(positive_examples)
        all_negative_examples.append(negative_examples)

    query_indices = torch.tensor(all_query_examples)
    positive_example_indices = torch.stack(all_positive_examples)
    negative_example_indices = torch.stack(all_negative_examples)
    
    return query_indices, positive_example_indices, negative_example_indices

In [97]:
import torch.nn as nn

# model
input_size = 768
output_size = 100
linear_model = nn.Linear(input_size, output_size)

torch.nn.init.xavier_uniform_(linear_model.weight)
linear_model.bias.data.fill_(0.1)
    
loss_function = torch.nn.MarginRankingLoss(margin=1)
optimizer = torch.optim.Adam(linear_model.parameters(), lr=0.01)
cos = nn.CosineSimilarity(dim=2, eps=1e6)

In [98]:
shuffle_now = False

if shuffle_now:
    shuffled_indices = np.arange(len(embeddings_index))
    np.random.shuffle(shuffled_indices)
    shuffled_indices

In [99]:
def compute_loss(embeddings, embeddings_index, n_positive, n_negative, n_examples):
    query_indices, positive_example_indices, negative_example_indices = sample_indices(
        embeddings_index,
        n_positive=n_positive,
        n_negative=n_negative,
        n_examples=n_examples,
    )
    
    model_embeddings = linear_model(embeddings)
    embedded_query = model_embeddings[query_indices]
    embedded_positive = model_embeddings[positive_example_indices]
    embedded_negative = model_embeddings[negative_example_indices]
    
    positive_similarity = cos(embedded_query.view(-1, 1, output_size), embedded_positive)    
    negative_similarity = cos(embedded_query.view(-1, 1, output_size), embedded_negative)

    positive_flat = positive_similarity.view(-1, 1).repeat(1, n_negative).view(-1)
    negative_flat = negative_similarity.repeat(1, n_positive).view(-1)
    
    all_ones = torch.ones(len(positive_flat))

    return loss_function(positive_flat, negative_flat, all_ones)

In [100]:
n_examples = 5000
n_positive = 1
n_negative = 10

train_split = 0.6
split_index = int(len(shuffled_indices) * train_split)
train_indices = shuffled_indices[:split_index]
validation_indices = shuffled_indices[split_index:]

train_embeddings_index = torch.tensor(embeddings_index)[train_indices]
validation_embeddings_index = torch.tensor(embeddings_index)[validation_indices]

for epoch in range(1000):
    train_loss = compute_loss(embeddings, train_embeddings_index, n_positive, n_negative, n_examples)
    
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    
    with torch.no_grad():
        validation_loss = compute_loss(
            embeddings,
            validation_embeddings_index,
            n_positive,
            n_negative,
            n_examples,
        )
    
    print('epoch: {} train loss: {} validation loss: {}'.format(
        epoch,
        train_loss,
        validation_loss,
    ))

epoch: 0 train loss: 0.9999994039535522 validation loss: 0.9999371767044067
epoch: 1 train loss: 0.9999579787254333 validation loss: 0.9999967813491821
epoch: 2 train loss: 0.9999977350234985 validation loss: 0.9999866485595703
epoch: 3 train loss: 0.9999920129776001 validation loss: 0.9999663829803467
epoch: 4 train loss: 0.9999759197235107 validation loss: 0.9999755620956421
epoch: 5 train loss: 0.9999852180480957 validation loss: 0.9999975562095642
epoch: 6 train loss: 0.9999985694885254 validation loss: 0.9999995231628418
epoch: 7 train loss: 0.9999997019767761 validation loss: 0.9999879002571106
epoch: 8 train loss: 0.9999924898147583 validation loss: 0.9999814629554749
epoch: 9 train loss: 0.9999879598617554 validation loss: 0.9999890327453613
epoch: 10 train loss: 0.9999915361404419 validation loss: 0.9999983310699463
epoch: 11 train loss: 0.9999985694885254 validation loss: 0.9999999403953552
epoch: 12 train loss: 0.9999999403953552 validation loss: 0.9999967813491821
epoch: 13

KeyboardInterrupt: 

In [None]:
from torch.nn.functional import cosine_similarity

search = "ms"

cos_out = nn.CosineSimilarity(dim=1, eps=1e6)

embedded = get_embeddings([search])
processed = linear_model(embedded)
processed_embeddings = linear_model(embeddings)

similarities = cosine_similarity(
    processed.view(-1, output_size),
    processed_embeddings,
    dim=1)

top_n = similarities.argsort().tolist()[::-1][:5]
doc_indices = torch.tensor(embeddings_index)[top_n]
[charities[i].name for i in doc_indices]

In [12]:
embedded.shape

torch.Size([1, 768])

In [13]:
embeddings.shape

torch.Size([336, 768])

In [17]:
processed_embeddings.shape

torch.Size([336, 10])