In [None]:
# Start by installing required libraries (mainly Transformers)
!pip install transformers==4.17.0
!pip install scikit-learn
!pip install hydra-core
!pip install pronouncing
!pip install spacy

In [None]:
# Colab-only cell: mount Google Drive so the data paths under
# /content/drive/ used below resolve. Skip this cell outside Colab.
import google.colab.drive as drive

drive.mount("/content/drive/", force_remount=True)

In [None]:
!git clone https://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@github.com/coderalo/11785-automatic-poetry-generation.git

In [None]:
import copy
import glob
import json
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pronouncing
import random
import shutil
import string as string_utils
import sys
import tempfile
import torch
import torch.nn.functional as F
import torch.optim as optim
import tqdm.notebook as tqdm
import yaml

from hydra import compose
from hydra import initialize_config_dir
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM
from transformers import GPT2LMHeadModel
from transformers import GPT2Model
from transformers import GPT2Tokenizer

## Rhyming distance calculation

In [None]:
from spacy.lang.en import English
from spacy.tokenizer import Tokenizer

# Blank English pipeline: the metrics below only call its tokenizer
# (`tokenizer(line)` -> tokens with `.text`).
# NOTE(review): this notebook also imports GPT2Tokenizer; if a later cell
# rebinds `tokenizer` to a GPT-2 tokenizer, the metric functions that read this
# global will silently use it instead — confirm the name is not reused.
nlp = English()
tokenizer = nlp.tokenizer

In [None]:
def check_rhyme(limerick):
    """Measure how far a limerick's line endings deviate from the AABBA rhyme scheme.

    The last word of every line (found with the module-level spaCy
    ``tokenizer``) is looked up with ``pronouncing``. The rhyming parts of line
    pairs (1, 2), (3, 4), (1, 5) and (2, 5) are then compared. Only the first
    pronunciation listed for each word is used.

    Args:
        limerick: Sequence of exactly 5 line strings. It is not modified.

    Returns:
        The fraction of the 4 checked pairs that do not rhyme, i.e. one of
        0.0, 0.25, 0.5, 0.75 or 1.0 (0.0 means every pair rhymes). Returns
        None if a line is empty once trailing punctuation/whitespace is
        removed, or if any final word has no pronunciation in the dictionary.

    Raises:
        AssertionError: If ``limerick`` does not have exactly 5 lines.
    """
    assert len(limerick) == 5

    # Strip into a new list instead of writing back into ``limerick`` so the
    # caller's data is left untouched. Whitespace is stripped together with
    # punctuation: for "word ." removing only "." would leave a trailing space,
    # which spaCy can return as its own token and hide the real last word.
    lines = [line.rstrip(string_utils.punctuation + string_utils.whitespace) for line in limerick]
    if any(line == "" for line in lines):
        return None

    rhyming_parts = []
    for line in lines:
        last_word = tokenizer(line)[-1].text
        phones = pronouncing.phones_for_word(last_word)
        if not phones:
            # Word missing from the pronouncing dictionary: rhyme can't be judged.
            return None
        rhyming_parts.append(pronouncing.rhyming_part(phones[0]))

    # AABBA: lines 1/2/5 share one rhyme, lines 3/4 another.
    pairs = [(0, 1), (2, 3), (0, 4), (1, 4)]
    mismatches = sum(rhyming_parts[a] != rhyming_parts[b] for a, b in pairs)
    return mismatches / len(pairs)

## Vocabulary coverage calculation

In [None]:
from collections import Counter

In [None]:
# Word frequencies over the OEDILF limerick corpus: this is the reference
# vocabulary used by the coverage metric. Single-character punctuation tokens
# are dropped after counting.
OEDILF_PATH = "/content/drive/MyDrive/11-785-final/data/limericks.json"

# `with` closes the file handle (the bare `open(...)` passed to json.load leaked it).
with open(OEDILF_PATH, 'r', encoding="utf-8") as limericks_file:
    data = json.load(limericks_file)

oedilf_word_freq = Counter()
for value in data["limericks"].values():
    for line in value["lines"]:
        oedilf_word_freq.update(token.text for token in tokenizer(line))

for punct in string_utils.punctuation:
    oedilf_word_freq.pop(punct, None)

In [None]:
def get_word_freq(files, num_limericks=100):
    """Count word tokens in files of generated limericks.

    Each file is read as consecutive records of 5 poem lines followed by one
    separator line (which is skipped). Lines are tokenized with the
    module-level spaCy ``tokenizer``.

    Args:
        files: Iterable of paths to generated-limerick text files.
        num_limericks: Number of records read from each file (default 100).
            If a file holds fewer, the extra reads return empty strings and
            contribute nothing; records beyond this count are ignored.

    Returns:
        A Counter mapping token text to its count over all files, with
        single-character punctuation tokens removed.
    """
    generated_word_freq = Counter()

    for filename in files:
        with open(filename, 'r') as file:
            for _ in range(num_limericks):
                limerick = [file.readline().strip() for _ in range(5)]
                file.readline()  # separator line between limericks
                for line in limerick:
                    generated_word_freq.update(token.text for token in tokenizer(line))

    for punct in string_utils.punctuation:
        generated_word_freq.pop(punct, None)

    return generated_word_freq

In [None]:
def get_coverage(oedilf_word_freq, generated_word_freq, min_word_freq):
    """Fraction of generated word tokens that belong to the frequent OEDILF vocabulary.

    Args:
        oedilf_word_freq: Mapping word -> count over the reference corpus.
        generated_word_freq: Mapping word -> count over the generated text.
        min_word_freq: A reference word is in the vocabulary when its count
            is at least this value.

    Returns:
        covered / total, where total is the number of generated tokens and
        covered is how many of those tokens are vocabulary words (token-weighted,
        not type-weighted). Returns 0.0 when there are no generated tokens,
        instead of raising ZeroDivisionError.
    """
    # Same set the sorted most_common() scan produced, without the sort.
    vocabulary = {word for word, count in oedilf_word_freq.items() if count >= min_word_freq}

    total = sum(generated_word_freq.values())
    if total == 0:
        return 0.0
    covered = sum(count for word, count in generated_word_freq.items() if word in vocabulary)
    return covered / total