In [4]:
import os 
import pickle
import requests
import numpy as np

# Notebooks do not define __file__; with an empty string, dirname() returns ""
# and the path below resolves to 'input.txt' in the current working directory.
__file__ = ""
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Download and validate BEFORE opening the target for writing: opening first
    # would leave an empty input.txt behind if the request failed, and the
    # exists() check above would then treat that empty file as a valid cache.
    response = requests.get(data_url, timeout=60)
    # Fail loudly on 4xx/5xx instead of storing an error page as the dataset.
    response.raise_for_status()
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)
# Explicit encoding keeps the read independent of the platform's default codec.
with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

length of dataset in characters: 1,115,394


In [6]:
# Vocabulary: the distinct characters of the corpus, sorted in place.
chars = list(set(data))
chars.sort()
vocab_size = len(chars)

# Show the whole alphabet on one line, followed by its size.
alphabet = ''.join(chars)
print("all the unique characters:", alphabet)
print(f"vocab size: {vocab_size:,}")

all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65


In [8]:
# index -> character, numbered by position in `chars`
itos = dict(enumerate(chars))
# character -> index, obtained by inverting the mapping above
stoi = {ch: idx for idx, ch in itos.items()}

In [14]:
from collections import Counter, defaultdict
from itertools import chain

In [32]:
# Character frequencies over the whole corpus.
unigrams = dict(Counter(data))

# Adjacent-character pair frequencies. Pairs starting at even positions are
# counted first, then pairs starting at odd positions; together they cover
# every (data[i], data[i + 1]) exactly once.
bigram_counter = Counter()
bigram_counter.update(zip(data[0::2], data[1::2]))
bigram_counter.update(zip(data[1::2], data[2::2]))
bigrams = dict(bigram_counter)

# Same counts regrouped by first character: bigrams_cond[first][second] -> count.
bigrams_cond = defaultdict(dict)
for pair, count in bigrams.items():
    first, second = pair
    bigrams_cond[first][second] = count

In [48]:
from dataclasses import dataclass
import itertools 
import logging
import random 
import math
import numpy as np
import pickle
import time
import sys
from typing import List, Optional, Tuple

In [50]:
# Lower the root logger's threshold so INFO-level records are not filtered out.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

In [52]:
@dataclass
class DataArgs:
    """Configuration options read by ``Dataset.__init__``.

    Field order and defaults define the generated ``__init__`` signature, so
    they must stay as they are for positional callers.
    """
    # Number of special-token indices selected when `fixed_special_toks` is set
    # (it is the width of the slice taken from the sorted unigram marginals).
    k: int = 0
    # Stored on the dataset as `seq_length`; presumably the length of each
    # sequence -- TODO confirm against the code that uses it.
    seq_length: int = 256
    # Stored on the dataset as `show_latents`; its effect is not visible in
    # this part of the notebook.
    show_latents: bool = False
    # When True, the special-token indices are fixed up front from the unigram
    # marginals: the k most frequent tokens, after skipping `special_toks_offset`.
    fixed_special_toks: bool = False
    # How many of the most frequent tokens to skip before taking the k special
    # tokens (shifts the slice window over the ascending argsort).
    special_toks_offset: int = 0
    # Stored on the dataset as `output_counter`; its effect is not visible in
    # this part of the notebook.
    output_counter: bool = True
    # Stored on the dataset as `no_repeat`; its effect is not visible in this
    # part of the notebook.
    no_repeat: bool = False

In [54]:
class Dataset:
    def __init__(self, args: DataArgs,
                 train_test: Optional[str] = None,
                 bigram_outs: Optional[bool] = False):
        self.k = args.k
        self.seq_length = args.seq_length
        self.show_latents = args.show_latents
        self.train_test = train_test
        self.output_counter = args.output_counter
        self.no_repeat = args.no_repeat
        self.bigram_outs = bigram_outs

        # init distributions
        self.meta = pickle.load(open('data/meta.pkl', 'rb'))
        self.itos = self.meta['itos']
        self.stoi = self.meta['stoi']
        self.num_tokens = self.meta['vocab_size']
        self.tok_range = list(np.arange(self.num_tokens))
        self.n_train_toks = self.num_tokens

        self.marginal = np.zeros(self.num_tokens)
        for k, cnt in self.meta['unigrams'].items():
            self.marginal[self.stoi[k]] = cnt
        self.marginal /= self.marginal.sum()

         # conditionals
        self.cond = [np.zeros(self.num_tokens) for _ in range(self.num_tokens)]
        for (w1, w2), cnt in self.meta['bigrams'].items():
            self.cond[self.stoi[w1]][self.stoi[w2]] += cnt
        for i in range(self.num_tokens):
            self.cond[i] /= self.cond[i].sum()

        # special tokens
        self.idxs = None
        if args.fixed_special_toks:
            # use unigram marginals
            self.idxs = list(self.marginal.argsort()[
                             self.num_tokens-self.k-args.special_toks_offset:self.num_tokens-args.special_toks_offset])
            