In [None]:
import csv

# Load the song table: the first CSV row is kept separately as `header`,
# every remaining row becomes a list of strings in `data`.
# newline="" is what the csv module expects for file objects; without it,
# newlines embedded inside quoted fields may be mis-parsed.
with open("data/songs.csv", newline="") as f:
    reader = csv.reader(f)
    header = next(reader)
    data = list(reader)

In [None]:
# Display the first data row (the header row was consumed separately above).
data[0]

In [None]:
from transformers import GPT2TokenizerFast, PreTrainedTokenizerFast

# Load the pretrained GPT-2 fast tokenizer and override its special-token
# attributes with the sentence markers used by the n-gram model below.
# NOTE(review): '<s>' and '</s>' are presumably not in the GPT-2 vocabulary;
# assigning these attributes may not add them as real tokens — confirm whether
# tokenizer.add_special_tokens(...) is needed if their ids are ever used.
tokenizer = GPT2TokenizerFast.from_pretrained('openai-community/gpt2')
tokenizer.bos_token = '<s>'
tokenizer.eos_token = '</s>'
tokenizer.pad_token = '<|endoftext|>'

In [None]:
import numpy as np
from tqdm import tqdm
from typing import List, Union, Dict

class TrigramLM:
    """Add-one (Laplace) smoothed unigram/bigram/trigram language model.

    Texts are split into token strings by the supplied tokenizer, wrapped in
    "<s>" / "</s>" markers, and counted into nested dictionaries:

    * ``unigram_counts[w]``            -> occurrences of ``w``
    * ``bigram_counts[u][w]``          -> occurrences of ``u w``
    * ``trigram_counts[t][u][w]``      -> occurrences of ``t u w``
    """

    def __init__(self, tokenizer: "PreTrainedTokenizerFast"):
        """Store the tokenizer and initialise empty count tables."""
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size + 2 # for <s> and </s>
        # TODO: check if this is correct and if we need to add <|endoftext|> to the vocab
        self.unigram_counts = {}
        self.bigram_counts = {}
        self.trigram_counts = {}

    def tokenize(self, text: str) -> List[str]:
        """Encode ``text`` with the tokenizer and return the token strings."""
        return self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))

    def _count_ngrams(self, row: List[str]) -> None:
        """Add the unigram, bigram and trigram counts of one token sequence."""
        for j, tok in enumerate(row):
            self.unigram_counts[tok] = self.unigram_counts.get(tok, 0) + 1

            if j > 0:
                followers = self.bigram_counts.setdefault(row[j - 1], {})
                followers[tok] = followers.get(tok, 0) + 1

            if j > 1:
                followers = (
                    self.trigram_counts
                    .setdefault(row[j - 2], {})
                    .setdefault(row[j - 1], {})
                )
                followers[tok] = followers.get(tok, 0) + 1

    def train(self, data: List[str]) -> None:
        """Accumulate n-gram counts from a list of raw text strings.

        Each text is tokenized and wrapped in "<s>" ... "</s>" before
        counting. Calling ``train`` again adds to the existing counts.
        """
        for text in tqdm(data, desc="Counting"):
            self._count_ngrams(["<s>"] + self.tokenize(text) + ["</s>"])

    def add_one_smoothed_prob(self, n_counts: Union[int, List[int]], d_counts: int) -> Union[float, List[float]]:
        """Return ``(n + 1) / (d + vocab_size)`` for one count or a list of counts."""
        if isinstance(n_counts, int):
            return (n_counts + 1) / (d_counts + self.vocab_size)
        else:
            return [self.add_one_smoothed_prob(n, d_counts) for n in n_counts]

    def nextProb(self, history_toks: List[str], next_toks: List[str]) -> List[float]:
        """Smoothed probability of each candidate token following a history.

        Args:
            history_toks: Preceding tokens. An empty history uses unigram
                counts, a single token uses bigram counts, and two or more
                tokens use trigram counts conditioned on the last two.
            next_toks: Candidate next tokens to score.

        Returns:
            One add-one smoothed probability per entry of ``next_toks``.
        """
        assert hasattr(history_toks, "__len__")
        if len(history_toks) == 0:
            n_counts = [self.unigram_counts.get(tok, 0) for tok in next_toks]
            # With no context the denominator is the total number of tokens
            # counted (previously read a non-existent attribute and raised
            # AttributeError).
            d_counts = sum(self.unigram_counts.values())

        elif len(history_toks) == 1:
            prev_tok = history_toks[0]
            n_counts = [self.bigram_counts.get(prev_tok, {}).get(tok, 0) for tok in next_toks]
            d_counts = self.unigram_counts.get(prev_tok, 0)

        else:
            prev_tok1, prev_tok2 = history_toks[-2:]
            n_counts = [self.trigram_counts.get(prev_tok1, {}).get(prev_tok2, {}).get(tok, 0) for tok in next_toks]
            d_counts = self.bigram_counts.get(prev_tok1, {}).get(prev_tok2, 0)

        return self.add_one_smoothed_prob(n_counts, d_counts)


In [None]:
# Token strings for column index 2 of every row (presumably the lyrics text —
# confirm against the CSV header).
tokenized_data = [
    tokenizer.convert_ids_to_tokens(token_ids)
    for token_ids in map(tokenizer.encode, (row[2] for row in data))
]
# TODO: check if there is an issue due to sequence length > 1024
# TODO: check if custom newline handling is needed

In [None]:
# Build the trigram model and count n-grams over column index 2 of every row.
lmodel = TrigramLM(tokenizer)
lmodel.train([row[2] for row in data])

In [None]:
# Check whether the end-of-text token (used as pad_token above) is in the
# tokenizer vocabulary. The literal was previously misspelled "|<endoftext>|",
# which can never match.
"<|endoftext|>" in tokenizer.vocab

In [None]:
# Add-one smoothed probabilities of five candidate next tokens given the
# two-token history ["<s>", "I"] (this exercises the trigram branch).
lmodel.nextProb(["<s>", "I"], ["Ġremember", "Ġwhen", "ĠI", "Ġwas", "Ġyoung"])

In [None]:
# Total number of trigram occurrences recorded in the model.
sum(
    count
    for by_second in lmodel.trigram_counts.values()
    for by_third in by_second.values()
    for count in by_third.values()
)

In [None]:
# Number of distinct tokens seen during training (the "<s>" and "</s>"
# markers added by train() are counted as tokens too).
len(lmodel.unigram_counts)

In [None]:
# Vocabulary size used in the smoothing denominator
# (tokenizer.vocab_size + 2 for "<s>" and "</s>").
lmodel.vocab_size

In [None]:
# Total number of characters across column index 2 of all rows.
sum(len(row[2]) for row in data)

In [None]:
# NOTE(review): hard-coded ratio of two unexplained numbers — presumably
# copied from earlier cell outputs; confirm their origin and compute this
# from the corresponding variables so it survives a re-run.
142982/490572