In [1]:
import logging
import os
import random
from typing import List, Optional

import torch
from transformers import (
    MBartTokenizer,
    MBartForConditionalGeneration,
    BartTokenizer,
    BartForConditionalGeneration,
)

from utils import remove_punct, tokenize

logging.basicConfig(level=logging.INFO)

SEQ_TO_SEQ_MODEL = "mbart"  # "mbart" or "bart"; selects the tokenizer/model pair loaded below
LANGUAGE = "fi"  # key into `language_codes`
DATASET = "wikisource"  # used to build the fine-tuned model directory name
SPECIAL_TOKENS = []  # extra special tokens registered with the tokenizer
CHECKPOINT = ""  # optional checkpoint sub-directory; "" means use the model directory itself
# Candidate generation samples (do_sample=True, random keyword sampling); seeding both RNGs
# makes a Restart-and-Run-All reproduce the same candidates and poems.
SEED = 42

random.seed(SEED)
torch.manual_seed(SEED)

# Short language code -> mBART language identifier.
language_codes = {
    "en": "en_XX",
    "fi": "fi_FI",
}

########################################################################################################################
# load model and tokenizer
########################################################################################################################

# Fine-tuned weights live under $MODELDIR_UNI/finnishPoetryGeneration/ukko2/<dataset>-<language>-<model>,
# optionally inside a CHECKPOINT sub-directory (appended only when CHECKPOINT is non-empty).
MODEL_DIR = os.path.join(
    os.environ["MODELDIR_UNI"],
    "finnishPoetryGeneration",
    "ukko2",
    "-".join((DATASET, LANGUAGE, SEQ_TO_SEQ_MODEL)),
    *([CHECKPOINT] if CHECKPOINT else [])
)

if SEQ_TO_SEQ_MODEL == "mbart":
    MODEL = "facebook/mbart-large-cc25"

    # Source and target language are both LANGUAGE.
    tokenizer = MBartTokenizer.from_pretrained(
        MODEL,
        src_lang=language_codes[LANGUAGE],
        tgt_lang=language_codes[LANGUAGE],
        additional_special_tokens=SPECIAL_TOKENS,
    )
    model = MBartForConditionalGeneration.from_pretrained(MODEL)
    # Decoding starts from the LANGUAGE code token.
    model.config.decoder_start_token_id = tokenizer.lang_code_to_id[
        language_codes[LANGUAGE]
    ]
elif SEQ_TO_SEQ_MODEL == "bart":
    MODEL = "facebook/bart-large"

    tokenizer = BartTokenizer.from_pretrained(
        MODEL, additional_special_tokens=SPECIAL_TOKENS
    )
    model = BartForConditionalGeneration.from_pretrained(MODEL)
else:
    raise NotImplementedError(
        "Unsupported SEQ_TO_SEQ_MODEL {!r}; expected 'mbart' or 'bart'".format(
            SEQ_TO_SEQ_MODEL
        )
    )

# Resize the embeddings to the tokenizer size (which includes SPECIAL_TOKENS) before loading,
# so the fine-tuned state dict below has matching shapes.
model.resize_token_embeddings(len(tokenizer))
logging.info("Model vocab size is {}".format(model.config.vocab_size))
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted MODEL_DIR.
model.load_state_dict(
    torch.load(
        os.path.join(MODEL_DIR, "pytorch_model.bin"), map_location=torch.device("cpu")
    )
)
# This notebook only runs inference: make sure dropout is disabled.
model.eval()


########################################################################################################################
# Helper functions
########################################################################################################################


def get_ngrams(text: str, n: int) -> List[str]:
    """Return all overlapping character n-grams of the letters in ``text``.

    Non-alphabetic characters (spaces, punctuation, digits) are removed first, so
    n-grams may span word boundaries. Text with fewer than ``n`` letters yields ``[]``.
    """
    letters = "".join(c for c in text if c.isalpha())
    # A string of length L has L - n + 1 windows of size n; the "+ 1" keeps the final one.
    return [letters[i : i + n] for i in range(len(letters) - n + 1)]


def ngram_similarity(string1: str, string2: str, n: int) -> float:
    """Jaccard similarity of the character n-gram sets of two strings.

    N-grams come from ``get_ngrams``. Returns a value in [0, 1]. When neither
    string produces any n-gram (e.g. both have fewer than ``n`` letters), the
    similarity is undefined and 0.0 is returned instead of dividing by zero.
    """
    ngrams_1 = set(get_ngrams(string1, n))
    ngrams_2 = set(get_ngrams(string2, n))

    all_ngrams = ngrams_1 | ngrams_2
    if not all_ngrams:
        return 0.0
    return len(ngrams_1 & ngrams_2) / len(all_ngrams)


def token_similarity(string1: str, string2: str) -> float:
    """Jaccard similarity of the lower-cased token sets of two strings.

    Each string is passed through ``remove_punct`` and ``tokenize`` (from ``utils``)
    before lower-casing. Returns a value in [0, 1]. When both token sets are empty,
    the similarity is undefined and 0.0 is returned instead of dividing by zero.
    """
    tokens_1 = {t.lower() for t in tokenize(remove_punct(string1))}
    tokens_2 = {t.lower() for t in tokenize(remove_punct(string2))}

    all_tokens = tokens_1 | tokens_2
    if not all_tokens:
        return 0.0
    return len(tokens_1 & tokens_2) / len(all_tokens)


########################################################################################################################
# Generate text
########################################################################################################################


def get_next_line_candidates(
    input_line: str, keywords: Optional[List[str]] = None, separator: str = ">>>SEP<<<"
) -> List[str]:
    """Generate candidate continuations of ``input_line`` with the seq2seq model.

    If ``keywords`` are given, the model source is
    ``"<sampled keywords> <separator> <input_line>"``. The sampled keywords are
    ``len(keywords) - 1`` of them (at least one), in random order. Otherwise the
    source is ``input_line`` alone.

    The source is padded/truncated to 32 tokens. NOTE(review): keywords come first, so
    a long keyword prefix can truncate the end of ``input_line``.

    Five sequences (up to 16 tokens each) are generated with sampling beam search
    (``do_sample=True``, ``num_beams=5``) and decoded without special tokens.

    Uses the module-level ``tokenizer`` and ``model``.

    Returns:
        The decoded candidates (may contain duplicates or punctuation-only lines).
    """
    if keywords:
        sampled_keywords = random.sample(keywords, max(len(keywords) - 1, 1))
        source = " ".join(sampled_keywords) + " " + separator + " " + input_line
    else:
        source = input_line
    logging.debug(source)

    encoded = tokenizer(
        source,
        padding="max_length",
        max_length=32,
        truncation=True,
        return_tensors="pt",
    )

    sample_outputs = model.generate(
        encoded["input_ids"],
        # Pass the mask explicitly so padding positions are never attended to,
        # instead of relying on generate() to infer it from the pad token.
        attention_mask=encoded["attention_mask"],
        do_sample=True,
        max_length=16,
        num_beams=5,
        early_stopping=True,
        num_return_sequences=5,
    )

    candidates = [
        tokenizer.decode(sample_output, skip_special_tokens=True)
        for sample_output in sample_outputs
    ]
    logging.info("Generated candidates {}".format(candidates))
    return candidates


def get_next_line(
    input_line: str,
    keywords: Optional[List[str]] = None,
    separator: str = ">>>SEP<<<",
    token_similarity_threshold: float = 0.4,
    ngram_size: int = 3,
) -> str:
    """Pick one generated continuation of ``input_line`` with a novelty heuristic.

    Candidates from ``get_next_line_candidates`` that are empty after removing
    punctuation are discarded. Of the rest, those whose token similarity with
    ``input_line`` is below ``token_similarity_threshold`` are kept, and the one with
    the lowest character ``ngram_size``-gram similarity is returned. If no candidate
    passes the threshold, the candidate with the lowest token similarity is returned.

    Args:
        input_line: the previous poem line to continue.
        keywords: optional keywords forwarded to the generator.
        separator: keyword/line separator forwarded to the generator.
        token_similarity_threshold: maximum (exclusive) token similarity for a
            candidate to count as sufficiently different from ``input_line``.
        ngram_size: character n-gram size used to rank the kept candidates.

    Returns:
        The selected line, or ``input_line`` itself when no usable candidate exists.
    """
    candidates = [
        candidate
        for candidate in get_next_line_candidates(input_line, keywords, separator)
        if remove_punct(candidate)  # exclude lines containing only punctuation
    ]

    if not candidates:
        return input_line

    # (candidate, token similarity, character n-gram similarity) w.r.t. the input line.
    scored_candidates = [
        (
            candidate,
            token_similarity(input_line, candidate),
            ngram_similarity(input_line, candidate, ngram_size),
        )
        for candidate in candidates
    ]

    novel_candidates = [
        scored for scored in scored_candidates if scored[1] < token_similarity_threshold
    ]

    if not novel_candidates:
        logging.debug(
            "No candidates with token similarity lower than threshold value. Returning candidate with lower "
            "similarity"
        )
        return min(scored_candidates, key=lambda scored: scored[1])[0]

    return min(novel_candidates, key=lambda scored: scored[2])[0]


def iterative_generation(
    input_line: str,
    lines: int,
    keywords: Optional[List[str]] = None,
    separator: str = ">>>SEP<<<",
    max_extra_lines: Optional[int] = 10,
) -> str:
    """Generate a poem by repeatedly continuing the last line with ``get_next_line``.

    Punctuation smoothing: if the previous line ends in "!" and the new one does too,
    the new one's "!" becomes "."; if both end in ".", the new one's "." becomes ",".

    Generation stops once at least ``lines`` lines have been generated and the latest
    one ends with ".", "!" or "?". Before that point, a blank line (stanza break) is
    inserted after a sentence-final line if the current stanza already has more than
    one line.

    Args:
        input_line: first line of the poem (included in the output).
        lines: minimum number of lines to generate after ``input_line``.
        keywords: optional keywords forwarded to ``get_next_line``.
        separator: keyword/line separator forwarded to ``get_next_line``.
        max_extra_lines: safety cap. Generation stops after ``lines + max_extra_lines``
            generated lines even if no sentence-final line appears. ``None`` disables
            the cap, and generation may then never terminate.

    Returns:
        The poem, one line per row, with blank rows between stanzas.
    """
    sentence_endings = (".", "!", "?")

    out = input_line
    last_line = input_line
    counter = 0
    stanza_length = 0
    while True:
        # Strip trailing whitespace so the punctuation checks below see the real last character.
        next_line = get_next_line(last_line, keywords, separator).rstrip()
        previous_end = last_line.rstrip()[-1:]  # "" for an empty line instead of IndexError
        if previous_end == "!" and next_line.endswith("!"):
            next_line = next_line[:-1] + "."
        elif previous_end == "." and next_line.endswith("."):
            next_line = next_line[:-1] + ","
        logging.info("Generated poetry line '{}'".format(next_line))
        counter += 1
        out += "\n"
        out += next_line
        stanza_length += 1
        if next_line.endswith(sentence_endings):
            if counter >= lines:
                break
            elif stanza_length > 1:
                out += "\n"
                stanza_length = 0
        if max_extra_lines is not None and counter >= lines + max_extra_lines:
            logging.warning(
                "Stopping after {} generated lines without a sentence-final line".format(
                    counter
                )
            )
            break
        last_line = next_line

    return out

# Optional: uncomment the two lines below to silence library warnings, e.g. the
# torch.floor_divide deprecation warning printed during generation.
# import warnings
# warnings.filterwarnings("ignore")

# Signals that the setup cell (model loading + helper definitions) has finished.
print("DONE")

INFO:root:Model vocab size is 250027


DONE


# Finnish Poetry Generation

Current implementation - in short:

 * one generative model takes keywords as input and generates the first poem line;
 * another generative model takes the first line as input and generates the next verse. This model is used iteratively;
 * each generative model returns multiple outputs, from which users can select the best candidate verse. For prototyping, we select the candidate automatically with a simple heuristic;
 * iterative generation stops when the generated verse ends with sentence-final punctuation (`.`, `!` or `?`) and at least a given number of lines has been generated;
 * to simulate stanzas, an empty line is added after a sentence-final verse once the current stanza has more than one line.


### Examples - Generating the First Verse

Sequence-to-sequence model: training examples are built by taking poetry lines and sampling words from each line.

The extracted words are used as source (mimic keywords), the poem line from which the words are extracted is used as target.

At inference time, the model returns a poem line given some input keywords provided by the user.

mBART model fine-tuned on poetry lines from Project Gutenberg and Wikisource (@Sardana can comment on this).


#### Some examples with 2 random keywords

**amerikkalainen onnettomuus**
 * Onnettomuus on minussa, minä amerikkalainen,
 * Ja onnettomuus rasittaa, – minä olen amerikkalainen,

**eläin uudistus**
 * Eläin uudistus. Maa vapisi, puu laski,
 * Eläin uudistus. Maa ja taivas, maa
 * Eläin uudistus. Maa vapisi, puu lakastui,
 * Eläin uudistus. Maa vapisi, puu katosi,
 * Eläin uudistus! tuo ihmeinen eläin, tuo

**uudistus ase**
 * Tule, ase uudistus! Tulkoon valkeus tuo!

**ase korvaus**
 * Korvaus, ase oikea on.
 * Se on ase, millä korvaus
 * Ase, korvaus kalliin hinnan!
 * Ase, korvaus kalliin miehen ja vaivan,
 * Se on ase, jonka korvaus

**korvaus pelkästään**
 * On korvaus pelkästään hurskas.
 * On korvaus pelkästään unelmien hullun.

**pelkästään päätyä**
 * Ei ole meidän päätyä pelkästään.
 * Se elämän on pelkästään päätyä.

**päätyä heikko**
 * Pian alkoi päivä heikko päätyä,
 * Pian alkoi sen matka heikko päätyä.

**heikko säilyttää**
 * Säilyttää, mi heikko on vaan.

**säilyttää kyetä**:
 * Kyetä ei, vaan säilyttää aatetta
 * Kyetä ei, jot'säilyttää vain haluaa.

### Examples - Next-Line Iterative Generation
 
Sequence-to-sequence model (fine-tuned mBART); the training data is built by concatenating lines from poems scraped from the Finnish Wikisource.

First sentence from previous model, with keywords _amerikkalainen, onnettomuus_.

In [2]:
input_line = "Ja tiede koira on : se etsii,"

# Drop punctuation-only lines and de-duplicate. dict.fromkeys keeps generation order,
# so the numbering is stable across kernel restarts (set order depends on string hashing).
candidates = list(
    dict.fromkeys(c for c in get_next_line_candidates(input_line) if remove_punct(c))
)

for i, candidate in enumerate(candidates):
    print("{}) {}".format(i + 1, candidate))

out = input_line
next_candidates = candidates

1) Se puuhaa ja puuhaa ja puuhaa,
2) Ja koiran tiede koiran tekee.
3) Kun tiede koira on : se puuhaa,
4) Ja tutkii ja tekee työtä.


This cell can be run multiple times to generate additional lines: each run appends candidate `n` to the poem and proposes new candidates for the following line.

In [4]:
n = 2  # 1-based index of the candidate (from the most recently printed list) to append

# Check before mutating `out`, so a bad choice leaves the partial poem untouched
# (n = 0 would otherwise silently pick the last candidate).
assert 1 <= n <= len(next_candidates), "n must be between 1 and {}".format(
    len(next_candidates)
)
next_input_line = next_candidates[n - 1]
out += "\n" + next_input_line

print("Partial poem:\n#############\n\n{}\n".format(out))

# Order-preserving de-duplication keeps the numbering reproducible across restarts.
next_candidates = list(
    dict.fromkeys(
        c for c in get_next_line_candidates(next_input_line) if remove_punct(c)
    )
)

print("Next-Line Candidates:\n#####################\n")
for i, candidate in enumerate(next_candidates):
    print("{}) {}".format(i + 1, candidate))

Partial poem:
#############

Ja tiede koira on : se etsii,
Ja tutkii ja tekee työtä.
Oi, Herra, anna armosta

Next-Line Candidates:
#####################

1) Sun siunauksens’ aina,
2) Ja voimaa, tarmoa meissä,
3) Ja anteeks’ syntimme,
4) Sun siunattu pelastus!


### Example - Unsupervised Iterative Generation

Used for prototyping; it was useful for observing looping behaviour (partially solved by increasing the number of epochs for the next-line generative model).

In [2]:
input_line = "Ja tiede koira on : se etsii,"

# Fully automatic continuation: the heuristic in get_next_line picks every line, and
# generation runs until at least 5 lines exist and the last one ends a sentence.
poem = iterative_generation(input_line, lines=5)
print(poem)  # print (not repr) so the poem's line breaks render

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
INFO:root:Generated candidates ['se puuhaa, tutkii ja puuhaa.', 'Ja tutkii ja tekee työtä.', 'Se tutkii ja tekee työtä;', 'Se tutkii ja tekee työtä;', 'Ja tutkii ja tekee työtä.']
INFO:root:Generated poetry line 'se puuhaa, tutkii ja puuhaa.'
INFO:root:Generated candidates ['Oi ihminen, niin monta on murhaa,', 'Ja aina, aina se puuhaa vaan,', 'Niin kauan kuin Suomessa yksikään', 'Se tietää – mut ei tiedä – kumpikaan,', 'Se tietää, ett’ on ihminen se,']
INFO:root:Generated poetry line 'Oi ihminen, niin monta on murhaa,'
INFO:root:Generated candidates ['ei yksi murhaa murhaa.', 'on murhaa, vainoa ja riettautta.', 'on murhaa, vainoa, vainoa.', 'on murhaa, vainoa, vainoa.', 'ei murhaa, vaan taistelua, taistelua.']
INFO:root:Generated p

Ja tiede koira on : se etsii,
se puuhaa, tutkii ja puuhaa.
Oi ihminen, niin monta on murhaa,
ei yksi murhaa murhaa.

Ei estä yksikään, ei!
Vaan vapauden päivä koittaa,
Ja oikeuden aamu koittaa,
Ja oikeutta kaikille se julistaa,
Ja vapauden kieltä se lausuu.
