# Introduction

<center><h3>**Welcome to the Summarization Notebook.**</h3></center>

In this assignment, you are going to train a neural network to summarize news articles.
Your neural network is going to learn from example, as we provide you with (article, summary) pairs.
We provide you with a **toy dataset** made of only articles about police related news.
Usual datasets can be 20x larger in size, but we have reduced it for computational purposes.

You will do this using a Transformer network, from the __[Attention is all you need](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)__ paper.
In this assignment you will:
- Learn to process text into sub-word tokens, to avoid fixed vocabulary sizes, and UNK tokens.
- Implement the key conceptual blocks of a Transformer.
- Use a Transformer to read a news article, and produce a summary.
- Perform operations on learned word-vectors to examine what the model has learned.

    
** Before you start **

You should read the Attention is all you need paper.
We are providing you with skeleton code for the Transformer, but there will have to implement 5 conceptual blocks of the transformer yourself:
-  AttentionQKV: the Query, Key, Value attention mechanism at the center of the Transformer
- MultiHeadAttention: the multiple heads that enable each input to attend at many places at once.
- PositionEmbedding: the sinusoid-based position embedding of the Transformer.
- Encoder & Decoder: The encoder (that reads inputs, such as news articles), the decoder (that produces the output summary, one token at a time)
- Full Transformer: piecing it all together.

All dataset files should be placed in the `dataset/` folder of this assignment.

If you are using Google Colab, follow the instructions to mount your Google Drive onto the remote machine.

# Library imports

In [1]:
!pip install segtok
!pip install sentencepiece

Collecting segtok
  Downloading segtok-1.5.11-py3-none-any.whl (24 kB)
Installing collected packages: segtok
Successfully installed segtok-1.5.11
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 40.5 MB/s 
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.96


Run the first of the following two cells if you are running the homework locally, and run the second cell if you are running the homework in Colab

In [2]:
DRIVE=False
root_folder = ""
dataset_folder = "dataset/"

In [3]:
from google.colab import drive
drive.mount('/content/drive')
root_folder = "/content/drive/My Drive/Sergey Levine - Deep Learning/hw3_public-master/"
dataset_folder = "/content/drive/My Drive/Sergey Levine - Deep Learning/hw3_public-master/dataset/"

Mounted at /content/drive


In [4]:
%cd /content/drive/My Drive/Sergey Levine - Deep Learning/hw3_public-master/

/content/drive/My Drive/Sergey Levine - Deep Learning/hw3_public-master


In [5]:
%ls

'1 Language Modeling.ipynb'        prepare_submission.sh
'2 Summarization.ipynb'            [0m[01;34m__pycache__[0m/
'3 Knowledge Distillation.ipynb'   README.md
 [01;34mbest_models[0m/                      requirements.txt
 capita.py                         [01;34msubmission_logs[0m/
 [01;34mdataset[0m/                          transformer_attention.py
 download_data.sh                  [01;34mtransformer_checks[0m/
 [01;34mkd_checks[0m/                        transformer.py
 kd_loss.py                        transformer_utils.py
 language_model.py                 utils.py
 [01;34mmodels[0m/


In [6]:
!bash download_data.sh

Downloading data... Please wait, this might take a while...
--2022-03-07 13:11:49--  https://bcourses.berkeley.edu/files/74751488/download?download_frd=1
Resolving bcourses.berkeley.edu (bcourses.berkeley.edu)... 52.70.183.143, 52.7.69.253, 52.204.28.3
Connecting to bcourses.berkeley.edu (bcourses.berkeley.edu)|52.70.183.143|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://a1072-74751488.cluster71.canvas-user-content.com/files/1072~74751488/download?download_frd=1 [following]
--2022-03-07 13:11:50--  https://a1072-74751488.cluster71.canvas-user-content.com/files/1072~74751488/download?download_frd=1
Resolving a1072-74751488.cluster71.canvas-user-content.com (a1072-74751488.cluster71.canvas-user-content.com)... 3.212.179.111, 34.229.37.149, 3.211.85.118
Connecting to a1072-74751488.cluster71.canvas-user-content.com (a1072-74751488.cluster71.canvas-user-content.com)|3.212.179.111|:443... connected.
HTTP request sent, awaiting response... 302 Found
Lo

In [7]:
# This cell autoreloads the notebook when you change you python file code.
# If you think the notebook did not reload, rerun this cell.
%load_ext autoreload
%autoreload 2

In [8]:
import os
import sys
sys.path.append(root_folder)
#from transformer import Transformer
import sentencepiece as spm
import torch as th
from torch import nn
from torch.nn import functional as F
from torch import optim
import numpy as np
import json
import capita
import os
from transformer_utils import set_device
import gc
from utils import validate_to_array, model_out_to_list

device = th.device('cpu')
list_to_device = lambda th_obj: [tensor.to(device) for tensor in th_obj]

In [9]:
# Load the word piece model that will be used to tokenize the texts into
# word pieces with a vocabulary size of 10000
sp = spm.SentencePieceProcessor()
sp.Load(root_folder+"dataset/wp_vocab10000.model")

vocab = [line.split('\t')[0] for line in open(root_folder+"dataset/wp_vocab10000.vocab", "r")]
pad_index = vocab.index('#')

def pad_sequence(numerized, pad_index, to_length):
    pad = numerized[:to_length]
    padded = pad + [pad_index] * (to_length - len(pad))
    mask = [w != pad_index for w in padded]
    return padded, mask

# Building blocks of a Transformer


**TODO**:

Implement the 5 blocks of the Transformer. In order to finish this section, you should get very small error <1e-7 on each of the 5 checks in this section.


The Transformer is split into 3 files: transformer_attention.py, transformer_utils.py and transformer.py

Each section below gives you directions and a way to verify your code works properly.

You do not need to modify the rest of the code provided, but should read it to understand overall architecture.

Our Transformer is built as a Pytorch model, a standard that is good for you to get accustomed to.



## (1) Implementing the Query-Key-Value Attention (AttentionQKV)

This part is located in AttentionQKV in transformer_attention.py. You must implement the call function of the class.
You will need to implement the mathematical procedure of AttentionQKV that is described in the [Attention is all you need paper](https://arxiv.org/pdf/1706.03762.pdf).

In [83]:
from transformer_attention import AttentionQKV

batch_size = 2;
n_queries = 3;
n_keyval = 5;
depth_k = 2;
depth_v = 2

with open(root_folder+"transformer_checks/attention_qkv_io.json", "r") as f:
    io = json.load(f)
    queries = th.tensor(io['queries'])
    keys = th.tensor(io['keys'])
    values = th.tensor(io['values'])
    expected_output  = th.tensor(io['output'])
    expected_weights = th.tensor(io['weights'])

attn_qkv = AttentionQKV()
output, weights = attn_qkv(queries, keys, values)
validate_to_array(model_out_to_list,((queries,keys,values),attn_qkv),'attentionqkv', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output)).item(), "(should be 0.0 or close to 0.0)")
print("Total error on the weights:",th.sum(th.abs(expected_weights-weights)).item(), "(should be 0.0 or close to 0.0)")

Total error on the output: 2.8312206268310547e-07 (should be 0.0 or close to 0.0)
Total error on the weights: 2.849847078323364e-07 (should be 0.0 or close to 0.0)


  arr = np.asanyarray(arr)


## (2) Implementing Multi-head attention

This part is located in the class MultiHeadProjection in transformer_attention.py.
You must implement the call, \_split_heads, and \_combine_heads functions.

**Procedure**

The objective is to leverage the AttentionQKV class you already wrote.

Your input are the queries, keys, values as 3-d tensors (batch_size, sequence_length, feature_size).

Split them into 4-d tensors (batch_size, n_heads, sequence_length, new_feature_size). Where:
$$feature\_size = n\_heads * new_{feature\_size}.$$

You can then feed the split qkv to your implemented AttentionQKV, which will treat each head as an independent attention function.

Then the output must be combined back into a 3-d tensor.
You can test the validity of your implementation in the cell below.

In [84]:
from transformer_attention import MultiHeadProjection

batch_size = 2;
n_queries = 3;
n_heads = 4
n_keyval = 5;
depth_k = 8;
depth_v = 8;

with open(root_folder+"transformer_checks/multihead_io.json", "r") as f:
    io = json.load(f)
    queries = th.tensor(io['queries'])
    keys = th.tensor(io['keys'])
    values = th.tensor(io['values'])
    expected_output  = th.tensor(io['output'])

mhp = MultiHeadProjection(n_heads, (depth_k,depth_v))
multihead_output = mhp((queries, keys, values))
validate_to_array(model_out_to_list,(((queries,keys,values),),mhp),'multihead', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-multihead_output)).item(), "(should be 0.0 or close to 0.0)")

Total error on the output: 1.5934929251670837e-06 (should be 0.0 or close to 0.0)


## (3) Position Embedding 

You must implement the FeedForward and PositionEmbedding classes in transformer.py.


The cell below helps you verify the validity of your implementation


In [85]:
from transformer import PositionEmbedding

batch_size = 2;
sequence_length = 3;
dim = 4;

with open(root_folder+"transformer_checks/position_embedding_io.json", "r") as f:
    io = json.load(f)
    inputs = th.tensor(io['inputs'])
    expected_output  = th.tensor(io['output'])

pos_emb = PositionEmbedding(dim)
(inputs,expected_output,pos_emb) = list_to_device((inputs,expected_output,pos_emb))
output_t = pos_emb(inputs)
validate_to_array(model_out_to_list,((inputs,),pos_emb),'position_embedding', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output_t)).item(), "(should be 0.0 or close to 0.0)")

Total error on the output: 2.980232238769531e-07 (should be 0.0 or close to 0.0)


In [87]:
a = th.Tensor([[1,1], [1,1]])
b = th.Tensor([[2,2], [2,2]])
print(th.stack((a,b), dim=2))
print(th.stack((a,b), dim=2).view(2,4))

tensor([[[1., 2.],
         [1., 2.]],

        [[1., 2.],
         [1., 2.]]])
tensor([[1., 2., 1., 2.],
        [1., 2., 1., 2.]])


## (4) Transformer Encoder / Transformer Decoder

You now have all the blocks needed to implement the Transformer.
For this part, you have to fill in 2 classes in the transformer.py file: TransformerEncoderBlock, TransformerDecoderBlock.

The code below will verify the accuracy of each block

In [86]:
from transformer import TransformerEncoderBlock

batch_size = 2
sequence_length = 5
hidden_size = 6
filter_size = 12
n_heads = 2

with open(root_folder+"transformer_checks/transformer_encoder_block_io_new.json", "r") as f:
    io = json.load(f)
    inputs = th.tensor(io['inputs'])
    expected_output = th.tensor(io['output'])
enc_block = TransformerEncoderBlock(input_size=6, n_heads=n_heads, filter_size=filter_size, hidden_size=hidden_size)
# th.save(enc_block.state_dict(),root_folder+"transformer_checks/transformer_encoder_block")
enc_block.load_state_dict(th.load(root_folder+"transformer_checks/transformer_encoder_block"))
(inputs,expected_output,enc_block) = list_to_device((inputs,expected_output,enc_block))
output_t = enc_block(inputs)
validate_to_array(model_out_to_list,((inputs,),enc_block),'encoder_block', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output_t)).item(), "(should be 0.0 or close to 0.0)")

Total error on the output: 4.999339580535889e-06 (should be 0.0 or close to 0.0)


In [88]:
from transformer import TransformerDecoderBlock
batch_size = 2
encoder_length = 5
decoder_length = 3
hidden_size = 6
filter_size = 12
n_heads = 2

with open(root_folder+"transformer_checks/transformer_decoder_block_io_new.json", "r") as f:
    io = json.load(f)
    decoder_inputs = th.tensor(io['decoder_inputs'])
    encoder_output = th.tensor(io['encoder_output'])
    expected_output = th.tensor(io['output'])

dec_block = TransformerDecoderBlock(input_size=6, n_heads=n_heads, filter_size=filter_size, hidden_size=hidden_size)
dec_block.load_state_dict(th.load(root_folder+"transformer_checks/transformer_decoder_block"))
(decoder_inputs,encoder_output,expected_output,dec_block) = list_to_device((decoder_inputs,encoder_output,expected_output,dec_block))
output_t = dec_block(decoder_inputs, encoder_output)
validate_to_array(model_out_to_list,((decoder_inputs, encoder_output),dec_block),'decoder_block', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output_t)).item(), "(should be 0.0 or close to 0.0)")


Total error on the output: 3.2186508178710938e-06 (should be 0.0 or close to 0.0)


## (5) Transformer

This is the final high-level function that pieces it all together.

You have to implement the call function of the Transformer class in the `transformer.py` file.

The block below verifies your implementation is correct.

In [90]:
from transformer import Transformer

batch_size = 2
vocab_size = 11
n_layers = 3
n_heads = 4
d_model = 8
d_filter = 16
input_length = 5
output_length = 3

with open(root_folder+"transformer_checks/transformer_io_new.json", "r") as f:
    io = json.load(f)
    enc_input = th.tensor(io['enc_input'])
    dec_input = th.tensor(io['dec_input'])
    enc_mask = th.tensor(io['enc_mask'])
    dec_mask = th.tensor(io['dec_mask'])
    expected_output = th.tensor(io['output'])
transformer = Transformer(vocab_size=vocab_size, n_layers=n_layers, n_heads=n_heads, d_model=d_model, d_filter=d_filter)
transformer.load_state_dict(th.load(root_folder+"transformer_checks/transformer"))
(enc_input,dec_input,enc_mask,dec_mask,expected_output,transformer) \
    = list_to_device((enc_input,dec_input,enc_mask,dec_mask,expected_output,transformer))
output_t = transformer(enc_input, target_sequence=dec_input, encoder_mask=enc_mask, decoder_mask=dec_mask)
validate_to_array(model_out_to_list, ((enc_input, dec_input, enc_mask, dec_mask),transformer),'transformer', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output_t)).item(), "(should be 0.0 or close to 0.0)")

Total error on the output: 5.9545040130615234e-05 (should be 0.0 or close to 0.0)


# Training the model

Your objective is to train the Language on the dataset you are provided to reach a **validation loss <= 6.50**

Careful: we will be testing this loss on an unreleased test set, so make sure to evaluate properly on a validation set and not overfit.

You must save the model you want us to test under: models/final_transformer_summarization (the .index, .meta and .data files)

**Advice**:
- It should be possible to attain validation loss <= 6.50 with the model dimensions we've specified (n_layers=6, d_model=104, d_filter=416), but you can tune these hyperparameters. Increasing d_model will yield better model, at the cost of longer training time.
- You should try tuning the learning rate, as well as what optimizer you use.
- You might need to train for a few (up to 2 hours) to obtain our expected loss. Remember to tune your hyperparameters first, once you find ones that work well, let it train for longer.

**Dataset**: as in the previous notebook, make sure the dataset files are in the `dataset` folder. These can be found on the Google Drive.


In [91]:
with open(root_folder+"dataset/summarization_dataset_preprocessed.json", "r") as f:
    dataset = json.load(f)

# We load the dataset, and split it into 2 sub-datasets based on if they are training or validation.
# Feel free to split this dataset another way, but remember, a validation set is important, to have an idea of 
# the amount of overfitting that has occurred!

d_train = [d for d in dataset if d['cut'] == 'training']
d_valid = [d for d in dataset if d['cut'] == 'evaluation']

len(d_train), len(d_valid)

(61055, 1558)

In [92]:
# An example (article, summary) pair in the training data:

print(d_train[145]['story'])
print("=======================\n=======================")
print(d_train[145]['summary'])

Tbilisi, Georgia (CNN)Police have shot and killed a white tiger that killed a man Wednesday in Tbilisi, Georgia, a Ministry of Internal Affairs representative said, after severe flooding allowed hundreds of wild animals to escape the city zoo. 
The tiger attack happened at a warehouse in the city center. The animal had been unaccounted for since the weekend floods destroyed the zoo premises.
The man killed, who was 43, worked in a company based in the warehouse, the Ministry of Internal Affairs said. Doctors said he was attacked in the throat and died before reaching the hospital. 
Experts are still searching the warehouse, the ministry said, adding that earlier reports that the tiger had injured a second man were unfounded. 
The zoo administration said Wednesday that another tiger was still missing. It was unable to confirm if the creature was dead or had escaped alive.
Georgian Prime Minister Irakli Garibashvili apologized to the public, saying he had been misinformed by the zoo's ma

In [93]:
print(d_train[0])

{'story': 'Wang Deqing (in black) and his students display their skills at the United Nations office in Vienna. Provided to China Daily\nHungarians claim to be the only Europeans with Eastern roots. That may explain the decision to incorporate elements of Zen philosophy and kung fu in training courses for the country\'s top police units and the presidential bodyguard. \nWang Deqing, a 32nd-generation Shaolin Temple warrior monk who moved to Hungary in 1999, works at Hungary\'s National Police School as head coach of the Special Police Force and coach of the president\'s escort. \n"Many people think I equip them with kung fu skills", said Wang, speaking at the International Chan Wu Federation center he established in 2003. \n"That\'s not true, especially when training the president\'s sniper team", he said. \nIn practice, "Chan", which refers to Zen Buddhism, and "Wu" (martial arts) are equally important: "I mainly use Chan to cultivate the minds of the snipers, who always work under hi

In [None]:
'''
d_train[0]

{'story': 'Wang Deqing (in black) and his students display their skills at the United Nations office in Vienna. Provided to China Daily\nHungarians claim to be the only Europeans with Eastern roots. That may explain the decision to incorporate elements of Zen philosophy and kung fu in training courses for the country\'s top police units and the presidential bodyguard. \nWang Deqing, a 32nd-generation Shaolin Temple warrior monk who moved to Hungary in 1999, works at Hungary\'s National Police School as head coach of the Special Police Force and coach of the president\'s escort. \n"Many people think I equip them with kung fu skills", said Wang, speaking at the International Chan Wu Federation center he established in 2003. \n"That\'s not true, especially when training the president\'s sniper team", he said. \nIn practice, "Chan", which refers to Zen Buddhism, and "Wu" (martial arts) are equally important: "I mainly use Chan to cultivate the minds of the snipers, who always work under high-pressure conditions", he said. \n"I have taught them how to be calm and have a peaceful mind, even in super-dangerous and critical situations, or when working alone in concealed surroundings." \nHis efforts were recognized when he was appointed executive chairman of the China-Hungary Police Exchange Association, associated with Hungary\'s Ministry of Interior Affairs. \nRecently, Peter Medgyessy, Hungary\'s former prime minister, invited Wang to dinner and thanked him for his contribution to boosting Chinese culture in the country. \nWang said Medgyessy\'s assistance as prime minister from 2002 to 2004 was essential to the promotion of traditional Chinese medicine, acupuncture, kung fu and even the Chinese language in Hungary: "Medgyessy is a visionary and respected Hungarian leader who boosted Sino-Hungarian exchanges." \nWang also trains Chan Wu coaches in Europe, home to many of the 30 branches of the Chan Wu Federation across the world, which cater to about 200,000 practitioners. \n"I think it is most popular in Hungary. One Hungarian coach told me he has taught about 1,000 students", said Wang. "That\'s an amazing achievement." \nHis students are required to preserve and promote authentic Shaolin kung fu as it was taught to him by his masters, including traditional etiquette and disciplines. \nHe said the rules encourage students to cultivate martial virtues and establish harmonious, happy attitudes and values.A A', 
'cut': 'training', 
'avg_sent_length': 18.0, 
'n_sents': 2, 
'summary': "Hungarians claim to be the only Europeans with Eastern roots. That may explain the decision to incorporate elements of Zen philosophy and kung fu in training courses for the country's top police units and the presidential bodyguard.", 
'source': 'chinadaily.com.cn', 
'n_words': 36, 
'type': 'newslens', 
'input': [3, 4, 8618, 3, 4, 258, 1777, 29, 3, 0, 13, 405, 3, 0, 11, 34, 841, 1849, 59, 3135, 35, 5, 3, 4, 232, 3, 4, 1843, 391, 13, 8706, 23, 3, 4, 2142, 9, 3, 4, 826, 3, 4, 644, 3, 4, 5213, 4054, 8, 1031, 9, 39, 5, 139, 3, 4, 901, 8, 31, 3, 4, 2454, 4527, 8, 23, 3, 4, 20, 156, 2356, 5, 592, 9, 8936, 6124, 12, 3, 4, 3, 5757, 8654, 11, 613, 2403, 3360, 13, 816, 812, 8, 21, 5, 221, 44, 8, 298, 107, 5766, 11, 5, 184, 573, 378, 4965, 23, 3, 4, 8618, 3, 4, 258, 1777, 29, 7, 10, 1922, 1653, 16, 9025, 3, 4, 1253, 110, 768, 3, 4, 5561, 7574, 8116, 52, 1038, 9, 3, 4, 5213, 895, 13, 4331, 7, 1513, 35, 5213, 895, 44, 8, 3, 4, 276, 3, 4, 107, 3, 4, 222, 38, 309, 1096, 12, 5, 3, 4, 741, 3, 4, 107, 3, 4, 777, 11, 1096, 12, 5, 184, 44, 8, 5746, 23, 18, 3, 4, 165, 85, 205, 3, 15, 36, 3, 57, 1202, 214, 112, 31, 613, 2403, 3360, 3135, 18, 7, 32, 3, 4, 8618, 7, 1136, 35, 5, 3, 4, 384, 3, 4, 3, 2214, 3, 4, 885, 353, 3, 4, 6773, 559, 28, 4180, 13, 2956, 23, 18, 20, 44, 8, 53, 1579, 7, 1429, 67, 816, 5, 184, 44, 8, 8901, 209, 18, 7, 28, 32, 23, 3, 4, 13, 2135, 7, 18, 3, 4, 3, 2214, 18, 7, 70, 5735, 8, 9, 3, 4, 3, 5757, 3, 4, 7211, 241, 1277, 7, 11, 18, 3, 4, 885, 353, 18, 3, 0, 905, 3797, 4434, 3, 0, 48, 6524, 695, 33, 18, 3, 15, 36, 973, 63, 333, 3, 4, 3, 2214, 9, 3, 5367, 1960, 692, 5, 1186, 8, 12, 5, 8901, 8, 7, 52, 427, 186, 188, 230, 16, 2932, 1305, 1489, 18, 7, 28, 32, 23, 18, 3, 15, 36, 40, 3, 4718, 112, 142, 9, 39, 3188, 11, 40, 10, 5024, 1186, 7, 163, 13, 1234, 16, 5168, 997, 846, 11, 2384, 984, 8, 7, 77, 67, 463, 1612, 13, 7416, 22, 2822, 8, 23, 18, 3, 4, 34, 1703, 58, 7151, 67, 28, 19, 3, 4866, 1008, 1808, 12, 5, 826, 16, 241, 2403, 895, 3, 4, 107, 3, 4, 2227, 3, 4, 1655, 7, 2712, 22, 31], 
'output': [3, 4, 5213, 4054, 8, 1031, 9, 39, 5, 139, 3, 4, 901, 8, 31, 3, 4, 2454, 4527, 8, 23, 3, 4, 20, 156, 2356, 5, 592, 9, 8936, 6124, 12, 3, 4, 3, 5757, 8654, 11, 613, 2403, 3360, 13, 816, 812, 8, 21, 5, 221, 44, 8, 298, 107, 5766, 11, 5, 184, 573, 378, 4965, 6, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998], 
'input_mask': [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True], 
'output_mask': [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]}
'''

Similarly to the previous assignment, we create a function to get a random batch to train on, given a dataset.

In [94]:
def build_batch(dataset, batch_size):
    indices = list(np.random.randint(0, len(dataset), size=batch_size))
    
    batch = [dataset[i] for i in indices]
    batch_input = np.array([a['input'] for a in batch])
    batch_input_mask = np.array([a['input_mask'] for a in batch])
    batch_output = np.array([a['output'] for a in batch])
    batch_output_mask = np.array([a['output_mask'] for a in batch])
    
    return batch_input, batch_input_mask, batch_output, batch_output_mask

We now instantiate the Transformer with our sets of hyperparameters specific to the task of summarization.
In summarization, we are going to go from documents with up to 400 words, to documents with up to 100 words.
The vocabulary size is set for you, and is of 10,000 words (we are using WordPieces, [here is a paper about subword encoding](http://aclweb.org/anthology/P18-1007), if you are interested).

In [95]:
# Use this trainer to train a Transformer model

class TransformerTrainer(nn.Module):
    def __init__(self, vocab_size, d_model, input_length, output_length, n_layers, d_filter, dropout=0, learning_rate=1e-3):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, d_filter=d_filter)

        # Summarization loss
        criterion = nn.CrossEntropyLoss(reduce='none')
        #Masking된 부분은 loss 계산에서 제외해야함 (즉 Mask = False부분) 그렇지 않으면 loss가 underpreciated됨
        #assignment 3.1 참조
        self.loss_fn = lambda pred,target,mask: (criterion(pred.permute(0,2,1),target)*mask).sum()/mask.sum()
        self.learning_rate = learning_rate
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
    def forward(self,batch,optimize=True):
        pred_logits = self.model(**batch)
        target,mask = batch['target_sequence'],batch['decoder_mask']
        loss = self.loss_fn(pred_logits,target,mask)
        accuracy = (th.eq(pred_logits.argmax(dim=2,keepdim=False),target).float()*mask).sum()/mask.sum()
        
        if optimize:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
                
        return loss, accuracy

In [96]:
# Dataset related parameters
vocab_size = len(vocab)
ilength = 400 # Length of the article
olength  = 100 # Length of the summaries

# Model related parameters, feel free to modify these.
n_layers = 6
d_model  = 320
d_filter = 600
batch_size = 16

dropout = 0.5
learning_rate = 1e-3
trainer = TransformerTrainer(vocab_size, d_model, ilength, olength, n_layers, d_filter, dropout)
model_id = 'test1'
os.makedirs(root_folder+'models/part2/',exist_ok=True)

device = th.device("cuda" if th.cuda.is_available() else "cpu")
print(device)
set_device(device)

cuda




In [97]:
# Skeleton code, as in the previous notebook.
# Write code training code and save your best performing model on the
# validation set. We will be testing the loss on a held-out test dataset.
from tqdm import tqdm
gc.collect()
trainer.model.to(device)
trainer.model.train()
losses,accuracies = [],[]
t = tqdm(range(int(1e4)+1))
for i in t:
    trainer.model.train()
    # Create a random mini-batch from the training dataset
    batch = build_batch(d_train, batch_size)
    # Build the feed-dict connecting placeholders and mini-batch
    batch_input, batch_input_mask, batch_output, batch_output_mask = [th.tensor(tensor) for tensor in batch]
    batch_input, batch_input_mask, batch_output, batch_output_mask \
                = list_to_device([batch_input, batch_input_mask, batch_output, batch_output_mask])
    batch = {'source_sequence': batch_input, 'target_sequence': batch_output,
            'encoder_mask': batch_input_mask, 'decoder_mask': batch_output_mask}

    # Obtain the loss. Be careful when you use the train_op and not, as previously.
    train_loss, accuracy = trainer(batch)
    losses.append(train_loss.item()),accuracies.append(accuracy.item())
    if i % 10 == 0:
        t.set_description(f"Iteration: {i} Loss: {np.mean(losses[-10:]):0.4f} Accuracy: {np.mean(accuracies[-10:]):0.2f}")
    if i % 100 == 0:
      trainer.model.eval()
      valid_losses = []
      for i in range(100):
        batch = build_batch(d_valid, 1)
        # Build the feed-dict connecting placeholders and mini-batch
        batch_input, batch_input_mask, batch_output, batch_output_mask = [th.tensor(tensor) for tensor in batch]
        batch_input, batch_input_mask, batch_output, batch_output_mask \
                = list_to_device([batch_input, batch_input_mask, batch_output, batch_output_mask])
        batch = {'source_sequence': batch_input, 'target_sequence': batch_output,
              'encoder_mask': batch_input_mask, 'decoder_mask': batch_output_mask}
        valid_loss, accuracy = trainer(batch,optimize=False)
        valid_losses.append(float(valid_loss.cpu().item()))
      print("\nValidation loss:", np.mean(valid_losses))

      save_dict = dict(
            kwargs = dict(
                vocab_size=vocab_size,
                d_model=d_model,
                n_layers=n_layers, 
                d_filter=d_filter
            ),
            model_state_dict = trainer.model.state_dict(),
            notes = ""
      )
      th.save(save_dict, root_folder+f'models/part2/model_{model_id}.pt')

Iteration: 0 Loss: 169.0814 Accuracy: 0.00:   0%|          | 0/10001 [00:00<?, ?it/s]


Validation loss: 197.76614238739015


Iteration: 100 Loss: 21.1189 Accuracy: 0.15:   1%|          | 100/10001 [00:26<38:10,  4.32it/s]


Validation loss: 27.75264456272125


Iteration: 200 Loss: 13.0469 Accuracy: 0.15:   2%|▏         | 200/10001 [00:52<38:18,  4.26it/s]


Validation loss: 21.657522740364076


Iteration: 300 Loss: 9.5846 Accuracy: 0.16:   3%|▎         | 300/10001 [01:19<38:23,  4.21it/s]


Validation loss: 12.986054615974426


Iteration: 400 Loss: 35.7179 Accuracy: 0.12:   4%|▍         | 400/10001 [01:46<37:59,  4.21it/s]


Validation loss: 35.5619765663147


Iteration: 500 Loss: 7.3016 Accuracy: 0.17:   5%|▍         | 500/10001 [02:13<38:00,  4.17it/s]


Validation loss: 9.484641375541687


Iteration: 600 Loss: 4.9648 Accuracy: 0.22:   6%|▌         | 600/10001 [02:40<37:32,  4.17it/s]


Validation loss: 6.742714303731918


Iteration: 700 Loss: 4.4069 Accuracy: 0.22:   7%|▋         | 700/10001 [03:07<37:21,  4.15it/s]


Validation loss: 5.95758271932602


Iteration: 800 Loss: 4.8018 Accuracy: 0.22:   8%|▊         | 800/10001 [03:35<37:07,  4.13it/s]


Validation loss: 6.005355693101883


Iteration: 900 Loss: 4.0667 Accuracy: 0.25:   9%|▉         | 900/10001 [04:02<37:01,  4.10it/s]


Validation loss: 5.396253039836884


Iteration: 1000 Loss: 3.7115 Accuracy: 0.26:  10%|▉         | 1000/10001 [04:29<36:17,  4.13it/s]


Validation loss: 5.534274363517762


Iteration: 1100 Loss: 4.0214 Accuracy: 0.27:  11%|█         | 1100/10001 [04:57<36:11,  4.10it/s]


Validation loss: 5.176209925413132


Iteration: 1200 Loss: 3.9293 Accuracy: 0.27:  12%|█▏        | 1200/10001 [05:24<35:39,  4.11it/s]


Validation loss: 5.172869830131531


Iteration: 1300 Loss: 3.5923 Accuracy: 0.26:  13%|█▎        | 1300/10001 [05:52<35:30,  4.08it/s]


Validation loss: 5.098096271753311


Iteration: 1400 Loss: 3.5282 Accuracy: 0.28:  14%|█▍        | 1400/10001 [06:19<34:53,  4.11it/s]


Validation loss: 5.166384080052376


Iteration: 1500 Loss: 3.6449 Accuracy: 0.28:  15%|█▍        | 1500/10001 [06:47<34:33,  4.10it/s]


Validation loss: 5.407085096240044


Iteration: 1600 Loss: 3.5115 Accuracy: 0.28:  16%|█▌        | 1600/10001 [07:14<34:18,  4.08it/s]


Validation loss: 4.858137022256852


Iteration: 1700 Loss: 3.3303 Accuracy: 0.29:  17%|█▋        | 1700/10001 [07:42<33:45,  4.10it/s]


Validation loss: 4.7070702487230305


Iteration: 1800 Loss: 3.3302 Accuracy: 0.28:  18%|█▊        | 1800/10001 [08:09<33:19,  4.10it/s]


Validation loss: 4.6834423604607585


Iteration: 1900 Loss: 3.2559 Accuracy: 0.29:  19%|█▉        | 1900/10001 [08:37<32:51,  4.11it/s]


Validation loss: 5.085670074820518


Iteration: 2000 Loss: 3.6211 Accuracy: 0.27:  20%|█▉        | 2000/10001 [09:04<32:46,  4.07it/s]


Validation loss: 4.779442867040634


Iteration: 2100 Loss: 3.4601 Accuracy: 0.28:  21%|██        | 2100/10001 [09:32<32:17,  4.08it/s]


Validation loss: 4.907604376375676


Iteration: 2200 Loss: 3.2785 Accuracy: 0.28:  22%|██▏       | 2200/10001 [09:59<31:49,  4.09it/s]


Validation loss: 4.727091668844223


Iteration: 2300 Loss: 3.4714 Accuracy: 0.30:  23%|██▎       | 2300/10001 [10:27<31:31,  4.07it/s]


Validation loss: 5.000888011455536


Iteration: 2400 Loss: 3.6639 Accuracy: 0.28:  24%|██▍       | 2400/10001 [10:54<31:10,  4.06it/s]


Validation loss: 5.061364965438843


Iteration: 2500 Loss: 3.2082 Accuracy: 0.29:  25%|██▍       | 2500/10001 [11:22<30:51,  4.05it/s]


Validation loss: 4.467503894865513


Iteration: 2600 Loss: 3.6575 Accuracy: 0.30:  26%|██▌       | 2600/10001 [11:50<30:20,  4.06it/s]


Validation loss: 4.499722100049257


Iteration: 2700 Loss: 3.6265 Accuracy: 0.29:  27%|██▋       | 2700/10001 [12:17<29:52,  4.07it/s]


Validation loss: 4.763921056985855


Iteration: 2800 Loss: 3.3648 Accuracy: 0.28:  28%|██▊       | 2800/10001 [12:44<29:23,  4.08it/s]


Validation loss: 4.464279787540436


Iteration: 2900 Loss: 3.2683 Accuracy: 0.31:  29%|██▉       | 2900/10001 [13:12<28:49,  4.10it/s]


Validation loss: 4.29125749707222


Iteration: 3000 Loss: 3.4780 Accuracy: 0.31:  30%|██▉       | 3000/10001 [13:40<28:33,  4.09it/s]


Validation loss: 4.448152059465647


Iteration: 3100 Loss: 3.2936 Accuracy: 0.32:  31%|███       | 3100/10001 [14:07<27:58,  4.11it/s]


Validation loss: 4.598234370201826


Iteration: 3200 Loss: 3.2080 Accuracy: 0.30:  32%|███▏      | 3200/10001 [14:35<27:43,  4.09it/s]


Validation loss: 4.539089050889015


Iteration: 3300 Loss: 3.1432 Accuracy: 0.30:  33%|███▎      | 3300/10001 [15:02<27:24,  4.07it/s]


Validation loss: 4.50026059538126


Iteration: 3400 Loss: 3.2197 Accuracy: 0.30:  34%|███▍      | 3400/10001 [15:30<27:01,  4.07it/s]


Validation loss: 4.52170858681202


Iteration: 3500 Loss: 3.2252 Accuracy: 0.30:  35%|███▍      | 3500/10001 [15:57<26:29,  4.09it/s]


Validation loss: 4.229967430531978


Iteration: 3600 Loss: 3.1636 Accuracy: 0.31:  36%|███▌      | 3600/10001 [16:25<26:02,  4.10it/s]


Validation loss: 4.814085837602615


Iteration: 3700 Loss: 3.3057 Accuracy: 0.32:  37%|███▋      | 3700/10001 [16:52<25:35,  4.10it/s]


Validation loss: 4.526949077993631


Iteration: 3800 Loss: 3.3420 Accuracy: 0.30:  38%|███▊      | 3800/10001 [17:20<25:15,  4.09it/s]


Validation loss: 4.2904874366521835


Iteration: 3900 Loss: 3.1343 Accuracy: 0.29:  39%|███▉      | 3900/10001 [17:47<24:53,  4.09it/s]


Validation loss: 4.742253116369247


Iteration: 4000 Loss: 3.0831 Accuracy: 0.31:  40%|███▉      | 4000/10001 [18:15<24:22,  4.10it/s]


Validation loss: 4.274132409989834


Iteration: 4100 Loss: 3.1320 Accuracy: 0.32:  41%|████      | 4100/10001 [18:42<24:05,  4.08it/s]


Validation loss: 4.260210165083408


Iteration: 4200 Loss: 3.4072 Accuracy: 0.30:  42%|████▏     | 4200/10001 [19:10<23:43,  4.08it/s]


Validation loss: 4.2933994761109355


Iteration: 4300 Loss: 3.1505 Accuracy: 0.31:  43%|████▎     | 4300/10001 [19:37<23:22,  4.07it/s]


Validation loss: 4.395208732187748


Iteration: 4400 Loss: 3.1165 Accuracy: 0.31:  44%|████▍     | 4400/10001 [20:05<22:49,  4.09it/s]


Validation loss: 4.483158076405525


Iteration: 4500 Loss: 3.3200 Accuracy: 0.31:  45%|████▍     | 4500/10001 [20:32<22:23,  4.09it/s]


Validation loss: 4.598066391348839


Iteration: 4600 Loss: 3.3462 Accuracy: 0.31:  46%|████▌     | 4600/10001 [21:00<22:05,  4.08it/s]


Validation loss: 4.691667890548706


Iteration: 4700 Loss: 3.1999 Accuracy: 0.31:  47%|████▋     | 4700/10001 [21:27<21:33,  4.10it/s]


Validation loss: 4.451281911879778


Iteration: 4800 Loss: 2.9015 Accuracy: 0.32:  48%|████▊     | 4800/10001 [21:55<21:15,  4.08it/s]


Validation loss: 4.583016339540482


Iteration: 4900 Loss: 2.9850 Accuracy: 0.32:  49%|████▉     | 4900/10001 [22:22<20:47,  4.09it/s]


Validation loss: 4.680178520083428


Iteration: 5000 Loss: 3.1398 Accuracy: 0.32:  50%|████▉     | 5000/10001 [22:50<20:30,  4.06it/s]


Validation loss: 4.308933596909046


Iteration: 5100 Loss: 2.9472 Accuracy: 0.32:  51%|█████     | 5100/10001 [23:18<20:00,  4.08it/s]


Validation loss: 4.212918794751167


Iteration: 5200 Loss: 3.0893 Accuracy: 0.30:  52%|█████▏    | 5200/10001 [23:45<19:36,  4.08it/s]


Validation loss: 4.291961791962385


Iteration: 5300 Loss: 2.9088 Accuracy: 0.31:  53%|█████▎    | 5300/10001 [24:13<19:14,  4.07it/s]


Validation loss: 4.332009802013635


Iteration: 5400 Loss: 2.9839 Accuracy: 0.33:  54%|█████▍    | 5400/10001 [24:40<18:46,  4.08it/s]


Validation loss: 4.1100115495175125


Iteration: 5500 Loss: 3.2452 Accuracy: 0.32:  55%|█████▍    | 5500/10001 [25:08<18:28,  4.06it/s]


Validation loss: 4.571375412940979


Iteration: 5600 Loss: 3.1027 Accuracy: 0.32:  56%|█████▌    | 5600/10001 [25:35<17:59,  4.08it/s]


Validation loss: 4.26611090451479


Iteration: 5700 Loss: 3.3076 Accuracy: 0.32:  57%|█████▋    | 5700/10001 [26:03<17:34,  4.08it/s]


Validation loss: 4.161788202226162


Iteration: 5800 Loss: 3.2099 Accuracy: 0.33:  58%|█████▊    | 5800/10001 [26:31<17:16,  4.05it/s]


Validation loss: 4.521788949370384


Iteration: 5900 Loss: 3.2088 Accuracy: 0.33:  59%|█████▉    | 5900/10001 [26:58<16:47,  4.07it/s]


Validation loss: 4.454361218437552


Iteration: 6000 Loss: 3.0011 Accuracy: 0.32:  60%|█████▉    | 6000/10001 [27:26<16:19,  4.08it/s]


Validation loss: 4.091231398880482


Iteration: 6100 Loss: 3.0588 Accuracy: 0.32:  61%|██████    | 6100/10001 [27:53<15:58,  4.07it/s]


Validation loss: 4.27446113422513


Iteration: 6200 Loss: 3.1235 Accuracy: 0.33:  62%|██████▏   | 6200/10001 [28:21<15:31,  4.08it/s]


Validation loss: 4.128960973024368


Iteration: 6300 Loss: 3.0669 Accuracy: 0.32:  63%|██████▎   | 6300/10001 [28:48<15:11,  4.06it/s]


Validation loss: 4.521404674053192


Iteration: 6400 Loss: 3.0771 Accuracy: 0.32:  64%|██████▍   | 6400/10001 [29:16<14:46,  4.06it/s]


Validation loss: 4.185351933240891


Iteration: 6500 Loss: 3.0160 Accuracy: 0.32:  65%|██████▍   | 6500/10001 [29:44<14:18,  4.08it/s]


Validation loss: 4.340578725188971


Iteration: 6600 Loss: 2.9424 Accuracy: 0.32:  66%|██████▌   | 6600/10001 [30:11<13:52,  4.09it/s]


Validation loss: 4.031976688206196


Iteration: 6700 Loss: 3.1250 Accuracy: 0.32:  67%|██████▋   | 6700/10001 [30:39<13:30,  4.07it/s]


Validation loss: 4.412088760137558


Iteration: 6800 Loss: 2.8688 Accuracy: 0.32:  68%|██████▊   | 6800/10001 [31:06<13:06,  4.07it/s]


Validation loss: 4.4759162622690205


Iteration: 6900 Loss: 2.9188 Accuracy: 0.32:  69%|██████▉   | 6900/10001 [31:34<12:45,  4.05it/s]


Validation loss: 3.969870668053627


Iteration: 7000 Loss: 3.0841 Accuracy: 0.33:  70%|██████▉   | 7000/10001 [32:02<12:17,  4.07it/s]


Validation loss: 3.9860355344414713


Iteration: 7100 Loss: 3.1680 Accuracy: 0.32:  71%|███████   | 7100/10001 [32:29<11:54,  4.06it/s]


Validation loss: 4.032039672359824


Iteration: 7200 Loss: 2.8131 Accuracy: 0.34:  72%|███████▏  | 7200/10001 [32:57<11:28,  4.07it/s]


Validation loss: 4.200180932283401


Iteration: 7300 Loss: 2.9516 Accuracy: 0.33:  73%|███████▎  | 7300/10001 [33:24<11:00,  4.09it/s]


Validation loss: 4.234337040185928


Iteration: 7400 Loss: 2.9177 Accuracy: 0.33:  74%|███████▍  | 7400/10001 [33:52<10:37,  4.08it/s]


Validation loss: 4.18317928776145


Iteration: 7500 Loss: 2.8358 Accuracy: 0.34:  75%|███████▍  | 7500/10001 [34:20<10:15,  4.06it/s]


Validation loss: 3.924242662191391


Iteration: 7600 Loss: 3.0616 Accuracy: 0.32:  76%|███████▌  | 7600/10001 [34:47<09:47,  4.08it/s]


Validation loss: 4.141073478087783


Iteration: 7700 Loss: 3.0573 Accuracy: 0.33:  77%|███████▋  | 7700/10001 [35:15<09:24,  4.08it/s]


Validation loss: 4.136583731174469


Iteration: 7800 Loss: 3.0180 Accuracy: 0.34:  78%|███████▊  | 7800/10001 [35:42<08:58,  4.09it/s]


Validation loss: 4.256031985580921


Iteration: 7900 Loss: 2.8239 Accuracy: 0.34:  79%|███████▉  | 7900/10001 [36:10<08:35,  4.08it/s]


Validation loss: 4.440949386060238


Iteration: 8000 Loss: 3.2049 Accuracy: 0.33:  80%|███████▉  | 8000/10001 [36:37<08:09,  4.09it/s]


Validation loss: 4.16936356022954


Iteration: 8100 Loss: 2.8701 Accuracy: 0.33:  81%|████████  | 8100/10001 [37:05<07:43,  4.10it/s]


Validation loss: 3.9654847410321237


Iteration: 8200 Loss: 3.0995 Accuracy: 0.35:  82%|████████▏ | 8200/10001 [37:32<07:19,  4.09it/s]


Validation loss: 3.9095156174898147


Iteration: 8300 Loss: 2.9062 Accuracy: 0.34:  83%|████████▎ | 8300/10001 [38:00<06:58,  4.07it/s]


Validation loss: 4.104664561748504


Iteration: 8400 Loss: 3.0897 Accuracy: 0.34:  84%|████████▍ | 8400/10001 [38:28<06:30,  4.10it/s]


Validation loss: 4.3487956093996765


Iteration: 8500 Loss: 2.9534 Accuracy: 0.33:  85%|████████▍ | 8500/10001 [38:55<06:09,  4.07it/s]


Validation loss: 3.890183217227459


Iteration: 8600 Loss: 2.9886 Accuracy: 0.32:  86%|████████▌ | 8600/10001 [39:23<05:44,  4.07it/s]


Validation loss: 3.974294940829277


Iteration: 8700 Loss: 2.7853 Accuracy: 0.34:  87%|████████▋ | 8700/10001 [39:50<05:19,  4.08it/s]


Validation loss: 4.004639595225453


Iteration: 8800 Loss: 2.9896 Accuracy: 0.33:  88%|████████▊ | 8800/10001 [40:18<04:54,  4.07it/s]


Validation loss: 4.18040486574173


Iteration: 8900 Loss: 2.9481 Accuracy: 0.35:  89%|████████▉ | 8900/10001 [40:45<04:29,  4.09it/s]


Validation loss: 4.200621882081032


Iteration: 9000 Loss: 3.0525 Accuracy: 0.33:  90%|████████▉ | 9000/10001 [41:13<04:05,  4.09it/s]


Validation loss: 4.18631453871727


Iteration: 9100 Loss: 2.8092 Accuracy: 0.34:  91%|█████████ | 9100/10001 [41:41<03:40,  4.08it/s]


Validation loss: 3.856468292027712


Iteration: 9200 Loss: 3.0249 Accuracy: 0.34:  92%|█████████▏| 9200/10001 [42:08<03:17,  4.06it/s]


Validation loss: 4.301836274862289


Iteration: 9300 Loss: 2.8934 Accuracy: 0.33:  93%|█████████▎| 9300/10001 [42:36<02:53,  4.05it/s]


Validation loss: 3.9562834337353707


Iteration: 9400 Loss: 2.9110 Accuracy: 0.35:  94%|█████████▍| 9400/10001 [43:05<02:27,  4.08it/s]


Validation loss: 4.075473356917501


Iteration: 9500 Loss: 2.8693 Accuracy: 0.35:  95%|█████████▍| 9500/10001 [43:33<02:02,  4.08it/s]


Validation loss: 4.156623038053513


Iteration: 9600 Loss: 2.7464 Accuracy: 0.35:  96%|█████████▌| 9600/10001 [44:00<01:38,  4.09it/s]


Validation loss: 4.206954821944237


Iteration: 9700 Loss: 2.8813 Accuracy: 0.35:  97%|█████████▋| 9700/10001 [44:28<01:13,  4.08it/s]


Validation loss: 4.170395869910717


Iteration: 9800 Loss: 2.9146 Accuracy: 0.36:  98%|█████████▊| 9800/10001 [44:56<00:49,  4.07it/s]


Validation loss: 4.169097366333008


Iteration: 9900 Loss: 2.8694 Accuracy: 0.35:  99%|█████████▉| 9900/10001 [45:23<00:24,  4.06it/s]


Validation loss: 4.121846368312836


Iteration: 10000 Loss: 2.7013 Accuracy: 0.35: 100%|█████████▉| 10000/10001 [45:51<00:00,  4.06it/s]


Validation loss: 4.067913110554218


Iteration: 10000 Loss: 2.7013 Accuracy: 0.35: 100%|██████████| 10001/10001 [45:54<00:00,  3.63it/s]


# Using the Summarization model

Now that you have trained a Transformer to perform Summarization, we will use the model on news articles from the wild.

The three subsections below explore what the model has learned.

## The validation loss

Measure the validation loss of your model. This part could be used, as in our previous notebook, in deciding what is a likely, vs. unlikely summary for an article.

We will use the code here with the unreleased test-set to evaluate your model.

In [98]:
gc.collect()
model_id = "test1"
save_dict = th.load(root_folder+'models/part2/'+f"model_{model_id}.pt", map_location='cpu')
model = Transformer(**save_dict['kwargs'])
model.load_state_dict(save_dict['model_state_dict'])
set_device('cpu')
model.eval()
trainer.model = model

In [99]:
gc.collect()
losses = []
for i in tqdm(range(100)):
    batch = build_batch(d_valid, 1)
    # Build the feed-dict connecting placeholders and mini-batch
    batch_input, batch_input_mask, batch_output, batch_output_mask = [th.tensor(tensor) for tensor in batch]
    batch = {'source_sequence': batch_input, 'target_sequence': batch_output,
            'encoder_mask': batch_input_mask, 'decoder_mask': batch_output_mask}
    valid_loss, accuracy = trainer(batch,optimize=False)
    losses.append(float(valid_loss.cpu().item()))
print("Validation loss:", np.mean(losses))

100%|██████████| 100/100 [00:30<00:00,  3.25it/s]

Validation loss: 3.901614794721827





In [100]:
# Your best performing model should go here.
os.makedirs(root_folder+"best_models",exist_ok=True)
best_model_file = root_folder+"best_models/part2_best_model.pt"
th.save(save_dict,best_model_file)

## Generating an article's summary

This model we have built is meant to be used to generate summaries for new articles we do not have summaries for.
We got a [news article](https://www.chicagotribune.com/news/local/breaking/ct-met-officer-shot-20190309-story.html) from the Chicago Tribune about a police shooting, and want to use our model to produce a summary.

As you will see, our model is still limited in its ability, and will most likely not produce an interpretible summary, however, with more data and training, this model would be able to produce good summaries.

In [48]:
### No dropout, 1000 iter 결과
article_text = "A 34-year-old Chicago police officer has been shot in the shoulder during the execution of a search warrant in the Humboldt Park neighborhood, police say. The alleged shooter, a 19-year-old woman, was in custody. The shooting happened about 7:20 p.m. in the 2700 block of West Potomac Avenue, police said. The officer, part of the Grand Central District tactical unit, was taken to Stroger Hospital. While officers were serving a \"typical\" search warrant for \"narcotics and illegal weapons\" and were attempting to reach a rear door, \"a shot was fired,\" striking the tactical officer in the shoulder, said Chicago police Superintendent Eddie Johnson during a news briefing outside the hospital. He said the officer, who has about four or five years on the job, was \"stable\" but in critical condition. \"His family is here,\" Johnson said. \"He’s talking a lot and just wants the ordeal to be over.\" He said this incident serves as just another reminder of how dangerous a police officer’s job is. At the scene of the shooting, crime tape closed Potomac from Washtenaw Avenue to California Avenue and encompassed the alley west of the brick apartment building, south of Potomac. Dozens of officers stood in the alley, while even more walked up and down the street. Neighbors gathered at the edge of the yellow tape on the sidewalk along California and watched them work. Standing next to a man, a woman talked to police in the crime scene, across the street. \"We're not under arrest? We can go?\" the woman checked with officers. They told her she could go, and she and the man walked underneath the yellow tape and out of the crime scene."
input_length = 400
output_length = 100

# Process the capitalization with the preprocess_capitalization of the capita package.
article_text = capita.preprocess_capitalization(article_text)

# Numerize the tokens of the processed text using the loaded sentencepiece model.
numerized = sp.EncodeAsIds(article_text)
# Pad the sequence and keep the mask of the input
padded, mask = pad_sequence(numerized, pad_index, input_length)

# Making the news article into a batch of size one, to be fed to the neural network.
encoder_input = np.array([padded])
encoder_mask = np.array([mask])

decoded_so_far = [0]

for j in range(output_length):
    padded_decoder_input, decoder_mask = pad_sequence(decoded_so_far, pad_index, output_length)
    padded_decoder_input = [padded_decoder_input]
    decoder_mask = [decoder_mask]
    print("========================")
    print(padded_decoder_input)
    # Use the model to find the distrbution over the vocabulary for the next word
    batch = (encoder_input,encoder_mask,padded_decoder_input,decoder_mask)
    batch_input, batch_input_mask, batch_output, batch_output_mask = [th.tensor(tensor) for tensor in batch]
    batch = {'source_sequence': batch_input, 'target_sequence': batch_output,
            'encoder_mask': batch_input_mask, 'decoder_mask': batch_output_mask}
    logits = trainer.model(**batch).cpu().detach().numpy()

    chosen_words = np.argmax(logits, axis=2) # Take the argmax, getting the most likely next word
    decoded_so_far.append(int(chosen_words[0, j])) # We add it to the summary so far


print("The final summary:")
print("".join([vocab[i] for i in decoded_so_far]).replace("▁", " "))

[[0, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 99

In [117]:
### dropout, 10000 iter, feature 수 수정 결과
article_text = "A 34-year-old Chicago police officer has been shot in the shoulder during the execution of a search warrant in the Humboldt Park neighborhood, police say. The alleged shooter, a 19-year-old woman, was in custody. The shooting happened about 7:20 p.m. in the 2700 block of West Potomac Avenue, police said. The officer, part of the Grand Central District tactical unit, was taken to Stroger Hospital. While officers were serving a \"typical\" search warrant for \"narcotics and illegal weapons\" and were attempting to reach a rear door, \"a shot was fired,\" striking the tactical officer in the shoulder, said Chicago police Superintendent Eddie Johnson during a news briefing outside the hospital. He said the officer, who has about four or five years on the job, was \"stable\" but in critical condition. \"His family is here,\" Johnson said. \"He’s talking a lot and just wants the ordeal to be over.\" He said this incident serves as just another reminder of how dangerous a police officer’s job is. At the scene of the shooting, crime tape closed Potomac from Washtenaw Avenue to California Avenue and encompassed the alley west of the brick apartment building, south of Potomac. Dozens of officers stood in the alley, while even more walked up and down the street. Neighbors gathered at the edge of the yellow tape on the sidewalk along California and watched them work. Standing next to a man, a woman talked to police in the crime scene, across the street. \"We're not under arrest? We can go?\" the woman checked with officers. They told her she could go, and she and the man walked underneath the yellow tape and out of the crime scene."
input_length = 400
output_length = 100

# Process the capitalization with the preprocess_capitalization of the capita package.
article_text = capita.preprocess_capitalization(article_text)

# Numerize the tokens of the processed text using the loaded sentencepiece model.
numerized = sp.EncodeAsIds(article_text)
# Pad the sequence and keep the mask of the input
padded, mask = pad_sequence(numerized, pad_index, input_length)

# Making the news article into a batch of size one, to be fed to the neural network.
encoder_input = np.array([padded])
encoder_mask = np.array([mask])

decoded_so_far = [0]

for j in range(output_length):
    padded_decoder_input, decoder_mask = pad_sequence(decoded_so_far, pad_index, output_length)
    padded_decoder_input = [padded_decoder_input]
    decoder_mask = [decoder_mask]
    print("========================")
    print(padded_decoder_input)
    # Use the model to find the distrbution over the vocabulary for the next word
    batch = (encoder_input,encoder_mask,padded_decoder_input,decoder_mask)
    batch_input, batch_input_mask, batch_output, batch_output_mask = [th.tensor(tensor) for tensor in batch]
    batch = {'source_sequence': batch_input, 'target_sequence': batch_output,
            'encoder_mask': batch_input_mask, 'decoder_mask': batch_output_mask}
    logits = trainer.model(**batch).cpu().detach().numpy()
    
    chosen_words = np.argmax(logits, axis=2) # Take the argmax, getting the most likely next word
    decoded_so_far.append(int(chosen_words[0, j])) # We add it to the summary so far


print("The final summary:")
print("".join([vocab[i] for i in decoded_so_far]).replace("▁", " "))

[[0, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 99

## Word vectors

The model we train learns word representations for each word in our vocabulary. A word represention is a vector of **dim** size.

It is common in NLP to inspect the word vectors, as some properties of language often appear in the embedding structure.


We are going to load the word embeddings learned by our model, and inspect it.
Because our network was not trained for long, we are going for the simplest patterns, but if we let the network train longer, it learns more complex, semantic patterns.

In [102]:
# We help you load the matrix, as it is hidden within the Transformer structure.
E = trainer.model.encoder.embedding_layer.embedding.weight.cpu().detach().numpy()

print("The embedding matrix has shape:", E.shape)
print("The vocabulary has length:", len(vocab))

The embedding matrix has shape: (10000, 320)
The vocabulary has length: 10000


Pronouns serve very similar purposes, therefore we should expect the representation of "he" and "she" to be similar, and have cosine similarity.

- **TODO**:  Find the cosine similarity between the vectors that represent words "she" and "he".
- **TODO**:  Find the cosine similarity between the vectors that represent words "more" and "less".

We can contrast that with the cosine similarity to a random, non-related word, like "ball", or "gorilla".
- **TODO**: Compute the cosine similarity between "she" and "ball".
- **TODO**: Compute the cosine similarity between "more" and "protest".



In [121]:
def cosine_sim(v1, v2):
    # TODO: Implement the cosine similarity of 2 vectors. Careful: the words might not have unit norm.
    output = np.dot(v1, v2) / np.sqrt(np.sum(v1**2) * np.sum(v2**2))
    return output

for w1, w2 in [("she", "he"), ("more", "less"), ("she", "ball"), ("more", "gorilla")]:
    w1_index = vocab.index('▁'+w1) # The index of the first  word in our vocabulary
    w2_index = vocab.index('▁'+w2) # The index of the second word in our vocabulary
    w1_vec = E[w1_index] # Get the embedding vector of the first  word
    w2_vec = E[w2_index] # Get the embedding vector of the second word
    
    print(w1," vs. ", w2, "similarity:",cosine_sim(w1_vec, w2_vec))
validate_to_array(lambda f,i: (f(*i),i), (cosine_sim,tuple(20*np.random.random((2,1000))-1)),'cosine_sim', 'cosine_sim') 

she  vs.  he similarity: 0.0071817436
more  vs.  less similarity: 0.047976136
she  vs.  ball similarity: -0.047493313
more  vs.  gorilla similarity: -0.011038568


  arr = np.asanyarray(arr)


These effects are unfortunately small, as we have only trained the network on a few hours on a few thousand articles.
However, the same model trained for longer on more data exhibits many interesting semantic and syntactic patterns, such as:

- Words vectors with high cosine similarity usually represent words that have semantic similarity (such as duck and pigeon)
- Analogies can occur, a famous case is that of: woman - man + king ≈ queen. Or france - paris + rome ≈ italy.

- Looking at top-k similar words can help find synonyms.

To read examples of more complex patterns that appear in word embedding spaces, read [this blog](https://explosion.ai/blog/sense2vec-with-spacy). To play with a live demo and try similarities on rich word embeddings, [go here.](https://explosion.ai/demos/sense2vec)