In [4]:
import os 
import pickle
import requests
import numpy as np

# Notebook kernels do not define __file__. With an empty string,
# os.path.dirname() returns "" and the path below resolves to 'input.txt'
# relative to the current working directory.
__file__ = ""
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Download BEFORE opening the target file: if the request fails, no empty
    # or partial input.txt is left behind to be mistaken for a valid cache on
    # the next run. raise_for_status() keeps an HTTP error page from being
    # stored as the dataset, and the timeout bounds a stalled connection.
    response = requests.get(data_url, timeout=30)
    response.raise_for_status()
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)
# Read the (possibly cached) corpus back as a single string.
with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

length of dataset in characters: 1,115,394


In [6]:
# Vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(data))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65


In [8]:
# Lookup tables between characters and integer ids; an id is the character's
# position in the sorted `chars` list.
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}

In [14]:
from collections import Counter, defaultdict
from itertools import chain

In [32]:
# Occurrence count of each individual character.
unigrams = dict(Counter(data))

# Occurrence count of each adjacent character pair. Pairs starting at even
# offsets are visited first, then pairs starting at odd offsets; together the
# two streams cover every consecutive pair of the corpus exactly once.
bigrams = dict(Counter(chain(
    zip(data[0::2], data[1::2]),
    zip(data[1::2], data[2::2]),
)))

# Same counts, regrouped as {first char: {second char: count}}.
bigrams_cond = defaultdict(dict)
for (first, second), count in bigrams.items():
    bigrams_cond[first][second] = count

In [48]:
from dataclasses import dataclass
import itertools 
import logging
import random 
import math
import numpy as np
import pickle
import time
import sys
from typing import List, Optional, Tuple

In [50]:
# Set the root logger's threshold to INFO so INFO-level records are emitted.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

In [52]:
@dataclass
class DataArgs:
    """Options read by ``Dataset.__init__``; each field is copied or consulted there."""
    # Number of special-token ids selected when fixed_special_toks is True.
    k: int = 0
    # Stored on the dataset as seq_length; presumably the length of generated
    # sequences (not used in the visible code -- TODO confirm).
    seq_length: int = 256
    # Stored on the dataset; its effect is not visible in this notebook.
    show_latents: bool = False
    # When True, Dataset picks its special-token ids from the unigram
    # frequency ranking; when False, Dataset.idxs stays None.
    fixed_special_toks: bool = False
    # How many of the most frequent tokens to skip before taking the k
    # special tokens from the frequency ranking.
    special_toks_offset: int = 0
    # Stored on the dataset; its effect is not visible in this notebook.
    output_counter: bool = True
    # Stored on the dataset; its effect is not visible in this notebook.
    no_repeat: bool = False

In [54]:
class Dataset:
    """Token statistics loaded from the pickled metadata file ``data/meta.pkl``.

    Construction builds the unigram marginal distribution, one conditional
    (bigram) distribution per token, and optionally a list of special-token
    ids taken from the unigram frequency ranking.
    """

    def __init__(self, args: DataArgs,
                 train_test: Optional[str] = None,
                 bigram_outs: Optional[bool] = False):
        """Load the metadata and derive the distributions.

        Args:
            args: options object; the fields k, seq_length, show_latents,
                output_counter, no_repeat, fixed_special_toks and
                special_toks_offset are read.
            train_test: stored unchanged on the instance.
            bigram_outs: stored unchanged on the instance.

        Raises:
            OSError: if ``data/meta.pkl`` cannot be opened.
            KeyError: if the metadata lacks one of the keys 'itos', 'stoi',
                'vocab_size', 'unigrams', 'bigrams', or a counted character
                is missing from 'stoi'.
        """
        self.k = args.k
        self.seq_length = args.seq_length
        self.show_latents = args.show_latents
        self.train_test = train_test
        self.output_counter = args.output_counter
        self.no_repeat = args.no_repeat
        self.bigram_outs = bigram_outs

        # init distributions
        # The context manager closes the file handle deterministically; the
        # previous bare open() left it to the garbage collector.
        # NOTE: unpickling can execute arbitrary code -- only load a
        # meta.pkl that comes from a trusted source.
        with open('data/meta.pkl', 'rb') as meta_file:
            self.meta = pickle.load(meta_file)
        self.itos = self.meta['itos']
        self.stoi = self.meta['stoi']
        self.num_tokens = self.meta['vocab_size']
        self.tok_range = list(np.arange(self.num_tokens))
        self.n_train_toks = self.num_tokens

        # Unigram marginal: raw counts normalised to sum to 1.
        self.marginal = np.zeros(self.num_tokens)
        for k, cnt in self.meta['unigrams'].items():
            self.marginal[self.stoi[k]] = cnt
        self.marginal /= self.marginal.sum()

        # conditionals: cond[i][j] is the normalised count of token j
        # following token i.
        self.cond = [np.zeros(self.num_tokens) for _ in range(self.num_tokens)]
        for (w1, w2), cnt in self.meta['bigrams'].items():
            self.cond[self.stoi[w1]][self.stoi[w2]] += cnt
        for i in range(self.num_tokens):
            # NOTE: a token that never appears first in a bigram has an
            # all-zero row, so this division yields NaN for that row.
            self.cond[i] /= self.cond[i].sum()

        # special tokens
        self.idxs = None
        if args.fixed_special_toks:
            # use unigram marginals: argsort() is ascending, so this slice
            # takes the k most frequent token ids after skipping the
            # special_toks_offset most frequent ones.
            upper = self.num_tokens - args.special_toks_offset
            lower = upper - self.k
            self.idxs = list(self.marginal.argsort()[lower:upper])
            

In [9]:
import torch

# Toy settings for prototyping the per-sample transition matrices.
# NOTE: this rebinds the notebook-level `vocab_size` computed from the corpus
# in an earlier cell.
vocab_size = 3
alpha = 1.
order = 1
rho = 0.5
# An order-k chain over V symbols has V**k distinct histories (this is the
# count FRMarkovSampler uses); V*k only coincides with it when order == 1.
num_state_order = vocab_size ** order
# Number of rows that get freshly sampled for every sequence.
random_rows_size = int(rho * num_state_order)
batch_size = 4
epochs = 1
num_samples = epochs * batch_size


dirichlet_dist = torch.distributions.Dirichlet(torch.ones(vocab_size, device="cpu")*alpha)
# One shared base matrix: each row is a next-token distribution.
base_trans_mat = dirichlet_dist.sample((num_state_order,))
base_trans_mat /= base_trans_mat.sum(dim=-1, keepdim=True)
# Fixed: this previously sliced with the undefined name `random_row_size`,
# which raised NameError.
random_rows = torch.randperm(num_state_order)[:random_rows_size]



# Copy the base matrix once per sample, then overwrite the chosen rows with
# independent Dirichlet draws.
trans_mat = base_trans_mat.unsqueeze(0).repeat((num_samples, 1, 1)) # Shape: (num_samples, num_state_order, vocab_size)
trans_random = dirichlet_dist.sample((num_samples, random_rows_size,))  # Shape: (num_samples, random_rows_size, vocab_size)
trans_mat[:, random_rows] = trans_random


In [25]:
# Fixed Random Markov chain sampler
class FRMarkovSampler:
    """Samples token sequences from order-k Markov chains.

    A single base transition matrix is drawn once at construction. For every
    generated sequence, a fixed subset of its rows (chosen once, a ``rho``
    fraction of all rows) is replaced by fresh Dirichlet draws; the other
    rows are shared by all sequences.
    """

    def __init__(self, config):
        """Draw the base transition matrix and choose the resampled rows.

        Args:
            config: object exposing seq_len, vocab_size, order, batch_size,
                test_size, device, alpha (Dirichlet concentration) and rho
                (fraction of rows resampled per sequence).

        Raises:
            ValueError: if ``config.rho`` is outside [0, 1]; such a value
                could never produce a consistent number of resampled rows.
        """
        if not 0.0 <= config.rho <= 1.0:
            raise ValueError(f"rho must be in [0, 1], got {config.rho}")
        self.seq_len = config.seq_len
        self.num_states = config.vocab_size
        self.order = config.order
        self.num_states_order = self.num_states ** self.order
        self.batch_size = config.batch_size
        self.test_size = config.test_size
        self.device = config.device
        self.dirichlet_dist = torch.distributions.Dirichlet(torch.ones(self.num_states, device=self.device)*config.alpha)
        # Positional weights turning a window of `order` tokens into a row
        # index (base-num_states number, oldest token most significant).
        self.powers = (self.num_states ** torch.arange(self.order - 1, -1, -1, device=self.device)).long()
        # Sample all transition probabilities in one go
        self.base_trans_matrix = self.dirichlet_dist.sample((self.num_states_order,))  # Shape: (num_states_order, num_states)
        self.base_trans_matrix /= self.base_trans_matrix.sum(dim=1, keepdim=True)
        self.random_rows_size = int(config.rho * self.num_states_order) # proportion of rows that have a random transition
        self.random_rows = torch.randperm(self.num_states_order)[:self.random_rows_size] # pick random rows

    def generate(self, epochs=1, mode: str = "train") -> "Tuple[torch.Tensor, torch.Tensor]":
        """Sample ``epochs`` batches of sequences.

        Args:
            epochs: number of batches to draw.
            mode: "train" uses ``batch_size`` sequences per batch; any other
                value uses ``test_size``.

        Returns:
            A pair ``(samples, probs)``: ``samples`` has shape
            (epochs, per-batch size, seq_len) and holds the token ids;
            ``probs`` has shape (epochs, per-batch size, num_states) and
            holds the distribution the final token of each sequence was
            drawn from.

        Raises:
            ValueError: if ``seq_len`` is not larger than ``order`` (no
                transition would be sampled, so ``probs`` would not exist).
        """
        if self.seq_len <= self.order:
            raise ValueError(
                f"seq_len ({self.seq_len}) must be greater than order ({self.order})"
            )

        num_samples = self.batch_size if mode == "train" else self.test_size
        num_samples *= epochs
        trans_mat = self.base_trans_matrix.unsqueeze(0).repeat((num_samples, 1, 1)) # Shape: (num_samples, num_states_order, num_states)
        # Nothing to overwrite when no rows were selected (rho too small).
        if self.random_rows_size > 0:
            trans_random = self.dirichlet_dist.sample((num_samples, self.random_rows_size,))  # Shape: (num_samples, random_rows_size, num_states)
            trans_mat[:, self.random_rows] = trans_random

        range_vecs = torch.arange(num_samples, device=self.device)

        # Initialize the samples tensor
        samples = torch.zeros((num_samples, self.seq_len), dtype=torch.long, device=self.device)

        # The first `order` tokens are uniform random and seed the window.
        state = torch.randint(high=self.num_states, size=(num_samples, self.order), device=self.device) # Shape: (num_samples, order)
        samples[:, :self.order] = state

        for t in range(self.order, self.seq_len):
            state_indices = torch.sum(state * self.powers, dim=1) #shape: (num_samples,)
            probs = trans_mat[range_vecs, state_indices, :]  # Shape: (num_samples, num_states)

            # Sample the next states for the entire batch
            next_states = torch.multinomial(probs, num_samples=1).squeeze(1)

            # Update the sequence with the sampled next states
            samples[:, t] = next_states

            # Slide the window: drop the oldest token and append the new one.
            # Built as a new tensor because an in-place
            # `state[:, :-1] = state[:, 1:]` copies between overlapping views
            # of the same storage when order >= 3.
            state = torch.cat((state[:, 1:], next_states.unsqueeze(1)), dim=1)

        return samples.reshape(epochs, -1, self.seq_len), probs.reshape(epochs, -1, self.num_states)

In [29]:
from dataclasses import dataclass

@dataclass
class test_config:
    """Small default settings for trying out FRMarkovSampler."""
    # Length of each generated sequence.
    seq_len: int = 20
    # Number of distinct states (tokens) of the chain.
    vocab_size: int = 3
    # Markov order: how many previous tokens determine the next one.
    order: int = 2
    # Sequences per batch when generating in "train" mode.
    batch_size: int = 4
    # Sequences per batch in any other mode.
    test_size: int = 3
    # Torch device on which tensors are created.
    device: str = "cpu"
    # Concentration of the Dirichlet prior over transition rows.
    alpha: float = 1.
    # Fraction of transition rows resampled for every sequence.
    rho: float = 0.2

# Smoke test: build a sampler from the default settings and draw one training
# batch. The config is instantiated; binding the class itself only worked
# because every field happens to have a class-level default.
config = test_config()
sampler = FRMarkovSampler(config)
sampler.generate()

(tensor([[[2, 1, 2, 0, 0, 1, 1, 1, 1, 2, 0, 0, 1, 1, 1, 2, 1, 2, 1, 2],
          [0, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2],
          [1, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0],
          [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 0, 0, 1, 2, 0, 2, 0, 2]]]),
 tensor([[[0.0060, 0.0371, 0.9569],
          [0.0060, 0.0371, 0.9569],
          [0.0801, 0.1762, 0.7437],
          [0.5488, 0.2621, 0.1891]]]))