<a href="https://colab.research.google.com/github/lamantinushka/cam_summarisation/blob/master/BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Introduction
In this part of the project we investigate potential of BERT architecture in unsupervised summarization task. Similar to GPT-2 being a sequence of transformer decoders, BERT is a self-attentive models consisting of sequentially stacked transformer encoders. There were two versions of BERT released: BERT base (110M parameters) and BERT large (340M parameters). Both of them were pretrained for two tasks: Masked Language Modeling and Next Sentence Prediction on a corpus of ~ 3 Bln words (English books + English Wikipedia). 

In this project we will use base version consisting of 12 transformer blocks with 12 self-attention heads each. It is suggested to fill in the gaps in tokens sequences of form \<Object1 -GAP1- Aspect -GAP2- Object2 -GAP3-\>. Where each GAP is a sequence of [MASK] tokens and number of masks is chosen uniformaly from 1 to 3. 

Original BERT model is proposed to be for the task for the task, but in our case it may not be relevant for several reasons:
- Our aim is to extract of BERT general knowledge about language structure, because task-specific information is already given by the objects and aspects.
- BERT domain is already wide as it was trained on English Wikipedia.
- Dataset of sentences from CAM answers is pretty small and as a result there is a high risk of overfitting. Having 110M parameters BERT will easily restore sentences from the training set.

That is why we use 'bert-base-uncased' model with no finetuning. 

We use three techniques for decoding:
- hard decoding (for every MASK choose the most probable token)
- soft decoding (for every MASK choose token from the predicted distribution)
- iterative decoding (on each step fill only one gap with the most probable token.)

# Loading data

In [0]:
from google.colab import drive
drive.mount('/content/drive')

In [0]:
import json

In [0]:
data = []
for line in open('drive/My Drive/summarization/mined_bow_str.json', 'r'):
    data.append(json.loads(line))

In [0]:
# usefull function that extracts all supporting sentences from the CAM output
def write_sentences(sample, sentences = None):
    if sentences is None:
        sentences = []
    for s in sample['object1']['sentences']:
        sentences.append(s['text'] + '\n')
    for s in sample['object2']['sentences']:
        sentences.append(s['text'] + '\n')
    return sentences

Let's choose some samples to compare model's performance

In [5]:
samples = [data[4], data[155], data[228]]

for s in samples:
  print(s['object1']['name'] + ' vs ' + s['object2']['name'])

python vs java
toyota vs nissan
tea vs juice


In [6]:
for s in samples:
  print(s['object1']['name'], ':', s['extractedAspectsObject1'])
  print(s['object2']['name'], ':', s['extractedAspectsObject2'])
  print()

python : ['simpler', 'older', 'easier to program in', 'bigger', 'libraries', 'higher', 'easier', 'faster to code', 'closer to python', 'easier to read']
java : ['higher for java', 'closer to java', 'longer', 'faster', 'stronger']

toyota : ['gm', 'easier', 'car', 'se', 'longer', 'veichles', 'smarter', 'easier to park', 'corners', 'dealt']
nissan : ['horsepower', 'faster', 'stronger', 'latecomer', 'quality', 'reputation', 'greater', 'wiser']

tea : ['cheaper']
juice : ['healthier']



# Decoding techniques

In [0]:
def gen_masked_input(object1, object2, aspect, obj1 = True):
    if obj1:
        object1_ = torch.tensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(object1)))
        object2_ = torch.tensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(object2)))
    else:
        object1_ = torch.tensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(object2)))
        object2_ = torch.tensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(object1)))
    aspect = torch.tensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(aspect)))
    n1 = np.random.randint(1, 4)
    n2 = np.random.randint(1, 4)
    n3 = np.random.randint(1, 4)
    l1 = len(object1_)
    l2 = len(object2_)
    l3 = len(aspect)
    length = n1 + n2 + n3 + l1 + l2 + l3
    result = 103 * torch.ones(2 + length, dtype = torch.long).cuda()
    result[0] = 101
    result[-1] = 102
    result[1:l1 + 1] = object1_
    result[l1 + n1 + 1: l1 + n1 + l3 + 1] = aspect
    result[l1 + n1 + l3 + n2 + 1: -1] = object2_
    return result.view(1, -1)

In [0]:
def decode_hard(model, input):
    result = input.view(-1)
    with torch.no_grad():
        preds = model(input).view(-1, 30522)
    mask = input == 103
    idxes = mask.nonzero().view(-1)
    for idx in idxes:
        result[idx] = torch.argmax(preds[idx])
    return result[1:-1]

In [0]:
def decode_soft(model, input, temperature = 1.0):
    result = input.view(-1)
    with torch.no_grad():
        preds = model(input).view(-1, 30522)
    mask = input == 103
    idxes = mask.nonzero().view(-1)
    for idx in idxes:
        word_weights = preds[idx].div(temperature).exp().cpu() 
        pred = torch.multinomial(word_weights, 1)[0].cuda()
        result[idx] = pred
    return result[1:-1]

In [0]:
def decode_iteratively(model, begining, maxlen = 20, temperature = 1.0):
    result = input.view(-1)
    mask = input == 103
    idxes = mask.nonzero().view(-1).cpu().numpy()
    np.random.shuffle(idxes)
    for idx in idxes:
        with torch.no_grad():
            preds = model(result.view(1, -1)).view(-1, 30522)
            result[idx] = torch.argmax(preds[idx])
    return result[1:-1]

#BERT

In [0]:
!pip install pytorch-pretrained-bert

In [0]:
import numpy as np
import pandas as pd
import torch
from pytorch_pretrained_bert import *
from torch.nn import CrossEntropyLoss, KLDivLoss
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [9]:
model = BertForMaskedLM.from_pretrained('bert-base-uncased').cuda()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

100%|██████████| 407873900/407873900 [00:13<00:00, 29950466.41B/s]
100%|██████████| 231508/231508 [00:00<00:00, 907920.79B/s]


In [10]:
tokenizer.convert_tokens_to_ids(['[SEP]'])

[102]

In [15]:
decode = decode_hard

for s in samples:
    name1 = s['object1']['name']
    name2 = s['object2']['name']
    aspects = {a: True for a in s['extractedAspectsObject1']}
    for a in s['extractedAspectsObject2']:
        aspects[a] = False

    sentences = []
    for a in aspects:
        input = gen_masked_input(name1, name2, a, obj1 = aspects[a])
        result = decode(model, input)
        tokens = tokenizer.convert_ids_to_tokens(result.view(-1).detach().cpu().numpy())
        sentences.append(' '.join(tokens).replace(' ##', ''))
  
    print(name1 + ' vs ' + name2)
    for sent in sentences:
        print(sent)
    print()

python vs java
python ( simpler java java python java java
python java is the older versions versions of java java
python is easier to program in parallel ) java java java
python is is bigger than java java
python java java libraries java java java java java java java
python java java higher ) java ) java java java java
python is easier than java java java java
python is is is faster to code than ) java java java
python ( is closer to python ) java java java java
python is is easier to read than java java java
java and higher for java and python python
java is closer to java . . : python python python
java ( ( ( longer ) python python python
java ( ( , faster ) python python
java ( ( ( stronger ) ) python python python python

toyota vs nissan
toyota nissan nissan gm nissan nissan nissan nissan nissan
toyota - easier to nissan nissan nissan nissan nissan
toyota toyota nissan car nissan nissan nissan nissan nissan nissan nissan
toyota nissan se nissan nissan nissan nissan nissan nissan 

In [16]:
decode = decode_soft

for s in samples:
    name1 = s['object1']['name']
    name2 = s['object2']['name']
    aspects = {a: True for a in s['extractedAspectsObject1']}
    for a in s['extractedAspectsObject2']:
        aspects[a] = False

    sentences = []
    for a in aspects:
        input = gen_masked_input(name1, name2, a, obj1 = aspects[a])
        result = decode(model, input)
        tokens = tokenizer.convert_ids_to_tokens(result.view(-1).detach().cpu().numpy())
        sentences.append(' '.join(tokens).replace(' ##', ''))
  
    print(name1 + ' vs ' + name2)
    for sent in sentences:
        print(sent)
    print()

python vs java
python ( ( the simpler . in java java
python ( ( & older ) classes java java java
python , ( is easier to program in c java java java java
python web java python bigger ; java java java java java
python java extended python libraries . java java java java
pythoned @ q higher type java java
python c ; easier than java java java java
python java these is faster to code python source java java java java
python is is closer to python : - java java java java
python almost is easier to read ) java java java
java is higher for java , : in python python python
java ( is closer to java ) python python python python
java no longer than python python python
java ( faster schedule denim python python python
java is ( stronger ) python python

toyota vs nissan
toyota columbia chrysler gm nissan nissan nissan nissan
toyota suspension is - easier than the nissan nissan nissan nissan
toyota stock car nissan bmw nissan nissan
toyota nissan seso hyundai nissan nissan nissan
toyota cunning

In [17]:
decode = decode_iteratively

for s in samples:
    name1 = s['object1']['name']
    name2 = s['object2']['name']
    aspects = {a: True for a in s['extractedAspectsObject1']}
    for a in s['extractedAspectsObject2']:
        aspects[a] = False

    sentences = []
    for a in aspects:
        input = gen_masked_input(name1, name2, a, obj1 = aspects[a])
        result = decode(model, input)
        tokens = tokenizer.convert_ids_to_tokens(result.view(-1).detach().cpu().numpy())
        sentences.append(' '.join(tokens).replace(' ##', ''))
  
    print(name1 + ' vs ' + name2)
    for sent in sentences:
        print(sent)
    print()

python vs java
python ( simpler language ) java java java
python java ( and older ) java java java
python is also much easier to program in . net java java java
python java bigger java java java java java java java
python . java libraries . java java
python java higher . python python java java
python java easier than java java java java
python java is much faster to code in . java java java java
python is closer to python than to java java java
python . java is easier to read than . java java
javascript or higher for javascript . python python python
java is also much closer to java . python python python python python python
java ( longer ) python python python python
java ( is faster ) . python python python
java python python - stronger java python python python python

toyota vs nissan
toyota nissan gm nissan gm nissan nissan nissan nissan
toyota nissan ( much easier ) nissan nissan nissan nissan
toyota kei - car nissan nissan nissan nissan nissan nissan
toyota nissan nissan se ni

# Outline
- Interative generation gives relatively consistent sentences, but they are not very distinct and probably can be generated using templates.
- Probably 3 gaps is too much

## TBD:
- finetuning for MLM task on CAM outputs (?)
- some beam-search technique for iterative generation