In [61]:
#References https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torchtext.data.metrics import bleu_score


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [62]:
SOS_token = 0
EOS_token = 1

In [63]:
class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1


In [64]:
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )


In [65]:
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    # s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [66]:
import os



def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines

    lines = open('./task_2_event_summarization_train.tsv', encoding='utf-8').\
        read().strip().split('\n')


    # Split every line into pairs and normalize


    modified_lines = []
    for l in lines:
        line = l.split('\t')
        s = " ".join(line[:-1])
        e = line[-1]
        modified_lines.append("\t".join([s, e]))
    print(len(modified_lines))
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in modified_lines]
    print(len(lines))
    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
    

    return input_lang, output_lang, pairs

In [67]:
MAX_LENGTH = 250


In [68]:
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    _,_,pairs2 = readValidLangs(lang1, lang2,reverse)
    
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    for pair in pairs2:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])

    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

In [69]:
input_lang, output_lang, pairs = prepareData('keywords', 'english', False)
print(random.choice(pairs))

Reading lines...
90001
90001
Reading lines...
11704
11704
Read 90001 sentence pairs
Trimmed to 90001 sentence pairs
Counting words...
Counted words:
eng 25392
fra 98017
['03-january-2020 twitter 3 explosions/remote violence air/drone strike military forces of the united states (2017-2021) 8 al shabaab 2 28 bacaw', '03 january 2020 . us carried out an airstrike in bacaw and killed three al shabaab militants .']


In [70]:
pairs[1]

['18-december-2019 vanguard (nigeria) 0 protests peaceful protest protesters (nigeria) 6  0 60 iwhreka',
 'on 18 december 2019, tens of women demonstrated over the recent violence between edjophe and iwhreka communities . [size=tens]']

In [71]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [72]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [74]:
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]


def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

In [75]:
teacher_forcing_ratio = 0.5

def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(
            input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)

    decoder_hidden = encoder_hidden

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
 
    if use_teacher_forcing:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(
                 decoder_input, decoder_hidden)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di] 

    else:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach() 

            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [76]:
import time
import math


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [77]:
def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):
    start = time.time()
    plot_losses = []
    print_loss_total = 0 
    plot_loss_total = 0  

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    training_pairs = [tensorsFromPair(random.choice(pairs))
                      for i in range(n_iters)]
    criterion = nn.CrossEntropyLoss()
    p_score = []
    for iter in range(1, n_iters + 1):

        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder,
                     decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                         iter, iter / n_iters * 100, math.exp(print_loss_avg)))
            p_score.append(math.exp(print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
    print("P_SCORE",sum(p_score)/len(p_score))


In [79]:
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei],
                                                     encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS

        decoder_hidden = encoder_hidden

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden)
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:di + 1]

In [80]:
%pip install rouge
%pip install nltk



In [81]:
from rouge import Rouge

from nltk.translate.bleu_score import sentence_bleu
reference = [['this', 'is', 'a', 'test'], ['this', 'is' 'test']]
candidate = ['this', 'is', 'a', 'test']
from torchtext.data.metrics import bleu_score

def evaluateRandomly(encoder, decoder, n=10):
    hyps = []
    refs = []
    bleu = []
    for i in range(n):
        candidate_corpus = []
        references_corpus = []

        pair = random.choice(pairs)
        output_words, attentions = evaluate(encoder, decoder, pair[0])
        output_sentence = ' '.join(output_words)
        candidate_corpus.append(pair[0].split())
        references_corpus.append([output_words])
        reference = [output_words]
        candidate = pair[0].split()
        hyps.append(output_sentence)
        refs.append(pair[1])
        bleu.append(sentence_bleu(reference, candidate))
        
        print("INPUT",pair[0])
        print("OUTPUT",output_sentence)
        print("EXP",pair[1])
        print("BLEU ",sentence_bleu(reference, candidate))
    rouge = Rouge()
    scores = rouge.get_scores(hyps, refs, avg=True)
    print("ROUGE =",scores)
    print("BLEU =",sum(bleu)/len(bleu))



In [82]:
def evaluateRandomlyWith(encoder, decoder,pairs, n=10):
    hyps = []
    refs = []
    bleu = []
    for i in range(n):
        candidate_corpus = []
        references_corpus = []

        pair = random.choice(pairs)
        output_words, attentions = evaluate(encoder, decoder, pair[0])
        output_sentence = ' '.join(output_words)
        candidate_corpus.append(pair[0].split())
        references_corpus.append([output_words])
        reference = [output_words]
        candidate = pair[0].split()
        hyps.append(output_sentence)
        refs.append(pair[1])
        bleu.append(sentence_bleu(reference, candidate))
        
        print("INPUT",pair[0])
        print("OUTPUT",output_sentence)
        print("EXP",pair[1])
        print("BLEU ",sentence_bleu(reference, candidate))
    rouge = Rouge()
    scores = rouge.get_scores(hyps, refs, avg=True)
    print("ROUGE =",scores)
    print("BLEU =",sum(bleu)/len(bleu))

In [83]:
from torchtext.data.metrics import bleu_score
def evaluateAll(encoder, decoder):
    candidate_corpus = []
    references_corpus = []
    
    for pair in pairs:
        output_words, attentions = evaluate(encoder, decoder, pair[0])
        
        candidate_corpus.append(pair[0].split())
        references_corpus.append(output_words)
        print("evaluateAll",candidate_corpus[-1],references_corpus[-1])
    print(bleu_score(candidate_corpus, references_corpus))

In [84]:
hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = DecoderRNN(hidden_size, output_lang.n_words).to(device)

trainIters(encoder1, attn_decoder1, 75000           , print_every=5)

0m 1s (- 35m 37s) (5 0%) 75367.7333
0m 3s (- 38m 56s) (10 0%) 77431.8656
0m 4s (- 40m 15s) (15 0%) 145598.6177
0m 5s (- 36m 15s) (20 0%) 61407.0795
0m 7s (- 35m 20s) (25 0%) 164295.6714
0m 8s (- 33m 24s) (30 0%) 1430.8736
0m 8s (- 30m 1s) (35 0%) 27.3658
0m 9s (- 28m 32s) (40 0%) 466.3106
0m 9s (- 27m 17s) (45 0%) 454.9236
0m 10s (- 27m 13s) (50 0%) 11376.6084
0m 12s (- 27m 43s) (55 0%) 18545.3968
0m 13s (- 27m 40s) (60 0%) 19515.6571
0m 14s (- 26m 49s) (65 0%) 650.2898
0m 14s (- 26m 2s) (70 0%) 1247.5989
0m 15s (- 25m 45s) (75 1%) 3109.0999
0m 16s (- 24m 47s) (80 1%) 72.9207
0m 16s (- 23m 50s) (85 1%) 24.0159
0m 17s (- 23m 36s) (90 1%) 5794.7305
0m 18s (- 23m 58s) (95 1%) 1907.8453
0m 19s (- 23m 54s) (100 1%) 38433.6711
0m 20s (- 23m 32s) (105 1%) 3939.8533
0m 21s (- 23m 46s) (110 1%) 10401.4237
0m 22s (- 23m 33s) (115 1%) 80821.8222
0m 23s (- 24m 9s) (120 1%) 1254985.3906
0m 24s (- 24m 25s) (125 1%) 2459966.3919
0m 25s (- 24m 28s) (130 1%) 575742.6235
0m 26s (- 24m 22s) (135 1%) 1686

In [85]:
evaluateRandomly(encoder1, attn_decoder1)
# evaluateAll(encoder1, attn_decoder1)

INPUT 06-december-2019 club mozambique 0 violence against civilians attack islamist militia (mozambique) 3 civilians (mozambique) 7 37 ingoane
OUTPUT on january january 2019, members of the militia attacked and of the . attacked of and . the . . . . the . . . . . . the . . . . <EOS>
EXP on december 6 2019, aswj attacked the village of ingoane (macomia, cabo delgado) . houses were burnt . no fatalities reported .
BLEU  0.17082308213961087
INPUT 18-march-2021 radio okapi 0 protests peaceful protest protesters (democratic republic of congo) 6  0 60 kisangani airport
OUTPUT on 21 january 2020, the of demonstrated demonstrated of the the the their the [size=no report] . <EOS>
EXP on 18 march 2021, agents of the airways authority (rva) gathered outside the company's office at the kisangani airport (kisangani, tshopo) to protest against delayed salary payments, demanding more than 75 month of delayed wages . [size=no report]
BLEU  0.46434528006147097
INPUT 08-august-2018 daily monitor (uganda

Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().


INPUT 17-may-2020 calamada 2 battles armed clash al shabaab 2 amisom: african union mission in somalia (2007-) (kenya) 8 28 kolbiyow
OUTPUT on november . 2020, al shabaab militants carried out and amisom amisom amisom and amisom amisom and amisom amisom and amisom amisom and run lower shabelle) . one soldier were reported . <EOS>
EXP on 17 may 2020, al shabaab militants attacked the amisom/jubaland security forces base at kolbiyow town (badhaadhe, lower juba) . al shabaab claimed to have killed two soldiers after exchange of heavy gunfire from both sides and al shabaab militants retreated from the area .
BLEU  0.14061128804120715
INPUT 17-october-2018 sun (nigeria) 0 protests peaceful protest protesters (nigeria) 6  0 60 alagbaka
OUTPUT on 13 january 2020, supporters of the demonstrated demonstrated at the lga, of the the . the [size=no the report] . the [size=no report] <EOS>
EXP 17 october . pensioners and members of nup protested in front of the governor's home in alagbaka over outs

In [87]:
def readValidLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines

    lines = open('./task_2_event_summarization_valid.tsv', encoding='utf-8').\
        read().strip().split('\n')


    # Split every line into pairs and normalize


    modified_lines = []
    for l in lines:
        line = l.split('\t')
        s = " ".join(line[:-1])
        e = line[-1]
        modified_lines.append("\t".join([s, e]))
    print(len(modified_lines))
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in modified_lines]
    print(len(lines))
    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

In [88]:
_,_,pairs2 = readValidLangs('','',False)

evaluateRandomlyWith(encoder1, attn_decoder1,pairs2,100)


Reading lines...
11704
11704
INPUT 02-april-2019 this day (nigeria); vanguard (nigeria) 5 violence against civilians attack fulani ethnic militia (nigeria) 4 civilians (nigeria) 7 47 gaambetiev
OUTPUT on 13 january 2021, unidentified gunmen attacked and killed and and at the village the at the . the at the . the . <EOS>
EXP 02 april . fulani militias attacked mondo village in gaambetiev area . five residents killed
BLEU  0
INPUT 16-january-2019 newsday (zimbabwe) 0 violence against civilians attack police forces of zimbabwe (2017-) 1 civilians (zimbabwe) 7 17 kuwadzana
OUTPUT on 21 january 2020, police officers shot and injured a police officers at the police station of the . the <EOS>
EXP on jan 16th, on the third day of protests against fuel price hikes, police and soldiers went door to door in mabvuku, kuwaduana and marondera of mashonaland beating people up and forcing them to remove barricades in the streets .
BLEU  0.51268887076046
INPUT 05-january-2022 sun (nigeria) 0 strategic 

Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().


INPUT 08-december-2019 menastream; whatsapp 3 battles armed clash military forces of burkina faso (2015-) 1 jnim: group for support of islam and muslims and/or islamic state (west africa) - greater sahara faction 2 12 ti-n-agadel
OUTPUT on 21 january 2019, iswap (greater sahara) and seized seized and seized seized and seized seized and seized seized and seized seized and seized seized and seized seized and seized and seized seized and seized of and . the the . . <EOS>
EXP on 8 december 2019, the burkinabe army killed three suspected jnim and/or isgs militants in the area of tin agadel, and seized two motorbikes, two ak rifles, a grenade, possible explosives and ied-making materials .
BLEU  0.37794400911552684
INPUT 30-january-2019 flash burkina 0 protests peaceful protest protesters (burkina faso) 6  0 60 hounde
OUTPUT on 21 december 2019, students demonstrated at the school of the school of the school of the school of the school of the school . [size=no report] <EOS>
EXP on january 30

Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().


INPUT 15-september-2021 twitter 0 protests peaceful protest protesters (democratic republic of congo) 6  0 60 moanda
OUTPUT on 21 december 2019, students demonstrated in front of the the city of the the report] . the [size=no report] . <EOS>
EXP on 15 september 2021, a group of people, including lamuka activists, gathered in moanda (moanda, kongo-central) to demand the depoliticization of the country's electoral commission ahead of the 2023 presidential elections . [size=no report]
BLEU  0.3186445002375616
INPUT 17-september-2020 herald (south africa) 0 protests peaceful protest protesters (south africa) 6  0 60 gqeberha
OUTPUT on 21 march 2020, the demonstrated of the the the the the the the . the the the . the the the . the the . report] the the . the report] the . the report] the . report] the . report] the <EOS>
EXP around 17 september 2020 (as reported), civil society, churches, unions and business owners demonstrated outside the city hall in nelson mandela bay also known as port 

Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().


 0.08386721022543209
INPUT 22-august-2019 7 sur 7; politico (drc) 0 riots mob violence rioters (democratic republic of congo) 5 civilians (democratic republic of congo) 7 57 lubumbashi
OUTPUT on 21 january 2021, a group of people attacked killed and injured a woman and injured a woman of the woman and his people of the the . his people and injured a woman of the . the the . report] . the . <EOS>
EXP on 22 august, in lubumbashi, supporters of football club saint eloi lupopo attacked a residence of a civilian, causing damage to the property . [size=no report]
BLEU  0.2239736960328794
INPUT 06-april-2018 twitter 3 battles armed clash cross river communal militia (nigeria) 4 cross river communal militia (nigeria) 4 44 obubra
OUTPUT on 13 january 2021, unidentified gunmen clashed with the unidentified militia and clashed with the lga, and the at . the lga, of the . the . . . the . . the . . . the . . <EOS>
EXP 06 april . 5 people were reported killed and over 2,000 displaced in a clash betw