In [4]:
from transformers import BertConfig, BertModel
import tensorflow as tf
import os

In [5]:
# Load the PyTorch BERT encoder from the local weights directory.
conpono = BertModel.from_pretrained("../weights/transformers/")
# Location of the TensorFlow checkpoint that get_cpc_weights() is called with below.
tf_path = "../weights/model.ckpt"

def get_cpc_weights(tf_checkpoint_path):
    """Load the first "cpc" variable stored in a TensorFlow checkpoint.

    Variables whose name contains any of the excluded substrings are ignored.
    Of the remaining variables, the first one whose name contains "cpc" is
    loaded and returned; later matches are never read.

    Args:
        tf_checkpoint_path: Path to the TensorFlow checkpoint. It is converted
            to an absolute path before being handed to TensorFlow.

    Returns:
        Whatever ``tf.train.load_variable`` returns for the selected variable.

    Raises:
        ValueError: If no non-excluded variable name contains "cpc". The
            previous implementation silently returned ``None`` in this case,
            which only surfaced later as an unrelated-looking error.
    """
    checkpoint_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {}".format(checkpoint_path))

    # Presumably optimizer / bookkeeping entries rather than model weights
    # (judging by their names) -- they are skipped either way.
    excluded = ("BERTAdam", "_power", "global_step", "_CHECKPOINTABLE_OBJECT_GRAPH")

    for name, _shape in tf.train.list_variables(checkpoint_path):
        if any(token in name for token in excluded):
            continue
        if "cpc" in name:
            return tf.train.load_variable(checkpoint_path, name)

    raise ValueError(
        "No variable containing 'cpc' found in checkpoint {}".format(checkpoint_path)
    )
    
# Read the "cpc" variable from the TensorFlow checkpoint once; later cells pass it to compute_coherence().
cpc_weights= get_cpc_weights(tf_path)

Converting TensorFlow checkpoint from /Users/daniter/Documents/jurafsky/conpono/weights/model.ckpt


In [82]:
# Toy example: sents[0] is the anchor, the remaining sentences are the
# candidates that get scored against it.
sents = ["First, the opening arguments is made",
        "This statement is unrelated.",
        "This statement is also unrelated but it is much longer because it has many words in it.",
        "Then, the follow up argument is made.",]

In [83]:
import torch
import numpy as np
from transformers import BertTokenizer
# Tokenizer used by encode_sents(); presumably its vocabulary matches the one
# the `conpono` weights were trained with -- TODO confirm.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def encode_sents(sents):
    """Encode an anchor sentence alone and paired with each candidate.

    ``sents[0]`` is the anchor. The batch sent through the module-level
    ``tokenizer`` and ``conpono`` model is, in order:

        [anchor, anchor + " [SEP] " + sents[1], anchor + " [SEP] " + sents[2], ...]

    Args:
        sents: Sequence of strings; the first is the anchor and the rest are
            candidates.

    Returns:
        The model's ``pooler_output``, detached from the autograd graph.
        Entry 0 corresponds to the anchor alone and entry ``i`` (``i >= 1``)
        to the anchor joined with ``sents[i]``.

    Raises:
        IndexError: If ``sents`` is empty.
    """
    # encode anchor and continuations
    anchor = sents[0]
    # NOTE(review): the pair is built by string concatenation with a literal
    # " [SEP] " rather than by passing a text pair to the tokenizer, so the
    # tokenizer treats each input as a single sequence -- confirm this matches
    # how the model was trained.
    inputs = [anchor] + [anchor + " [SEP] " + s for s in sents[1:]]
    tokenized_inputs = tokenizer(inputs, return_tensors="pt", add_special_tokens=True, padding=True)
    # Inference only: skip building the autograd graph for the forward pass.
    with torch.no_grad():
        output = conpono(**tokenized_inputs)['pooler_output']
    return output.detach()

def compute_coherence(encodings, cpc_weights, weight_index=4):
    """Score each candidate encoding against the transformed anchor encoding.

    With ``W = cpc_weights[weight_index]`` the result is
    ``scores[i] = encodings[i + 1] . (encodings[0] @ W)``.

    Args:
        encodings: Tensor whose entry 0 is the anchor encoding and whose
            remaining entries are the candidate encodings (assumes shape
            (num_inputs, hidden) -- TODO confirm).
        cpc_weights: Stack of weight matrices, as an array-like or tensor
            (assumes shape (num_matrices, hidden, hidden) -- TODO confirm).
        weight_index: Which matrix of ``cpc_weights`` to use. Defaults to 4,
            the value that was previously hard-coded. NOTE(review): what each
            index stands for is not visible here -- verify against the
            checkpoint's training code.

    Returns:
        NumPy array with one score per candidate (``len(encodings) - 1``).
    """
    # as_tensor accepts both NumPy arrays and tensors without the
    # copy-construct warning torch.tensor() emits for tensor inputs.
    weight_matrix = torch.as_tensor(cpc_weights)[weight_index]
    anchor_transform = torch.matmul(encodings[0], weight_matrix)
    scores = torch.matmul(encodings[1:], anchor_transform)
    return scores.detach().numpy()

In [89]:
# Encode the anchor and the anchor+candidate pairs, then score every candidate.
output = encode_sents(sents)
scores = compute_coherence(output, cpc_weights)
print(scores)
# Candidate indices ordered from highest to lowest score; index i refers to sents[i + 1].
np.argsort(scores)[::-1]

[1.3673818 1.2968947 1.3872905]


array([2, 0, 1])

In [91]:
# Second example: sents[0] is the anchor and the remaining sentences are the
# candidates that get scored against it.
sents = ["The first two Sherlock Holmes stories, the novels A Study in Scarlet (1887) and The Sign of the Four (1890), were moderately well received, but Holmes first became very popular early in 1891 when the first six short stories featuring the character were published in The Strand Magazine.",
         "Holmes became widely known in Britain and America.",
         "The character was so well-known that in 1893 when Arthur Conan Doyle killed Holmes in the short story \"The Final Problem\", the strongly negative response from readers was unlike any previous public reaction to a fictional event.",
         "The Strand reportedly lost more than 20,000 subscribers as a result of Holmes's death.",
         "Public pressure eventually contributed to Conan Doyle writing another Holmes story in 1901 and resurrecting the character in a story published in 1903.",
         "In Japan, Sherlock Holmes (and Alice from Alice's Adventures in Wonderland) became immensely popular in the country in the 1890s as it was opening up to the West, and they are cited as two British fictional Victorians who left an enormous creative and cultural legacy there."]

In [92]:
# Same pipeline as for the toy example: encode, score, then rank the candidates.
output = encode_sents(sents)
scores = compute_coherence(output, cpc_weights)
print(scores)
# Candidate indices ordered from highest to lowest score; index i refers to sents[i + 1].
np.argsort(scores)[::-1]

[1.1913104 1.1906443 1.1479765 1.0828052 1.1195012]


array([0, 1, 2, 4, 3])