In [1]:
import numpy as np
import torch as torch
import torch.nn as nn
from tqdm import tqdm
from transformers import GPT2Config, GPT2Model
from random import choices, Random, sample, random
from random import seed as randomseed
from matplotlib import pyplot as plt
import torch.nn.functional as F
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


In [None]:
# TODO put num_symbols inside of the main training loop and parametrize these functions with it
num_symbols = 3

# Dirichlet(1, ..., 1) over each row of a num_symbols x num_symbols matrix: every row is
# uniform on the probability simplex, i.e. uniformly random transition matrices.
# dirichlet_markov_ensemble.sample([n]) -> tensor of shape (n, num_symbols, num_symbols).
dirichlet_markov_ensemble = torch.distributions.Dirichlet(
    concentration=torch.ones(num_symbols, num_symbols, device=device)
)


# Fast approximation of a Markov chain's stationary distribution: raise the transition
# matrix to a high power and average its rows. The exact eigenvector-based method is
# kept, commented out, below.
def stationary_distribution(prob, n_steps=None):
    """Approximate the stationary distribution(s) of row-stochastic transition matrices.

    For an irreducible, aperiodic chain every row of ``prob ** n_steps`` approaches the
    stationary distribution as ``n_steps`` grows, so their average approximates it.

    Args:
        prob: tensor of shape (..., S, S) -- a single matrix or a batch of them.
        n_steps: integer power ``prob`` is raised to. Defaults to the module-level
            ``power`` constant.

    Returns:
        Tensor of shape (..., S): mean of the rows of ``prob ** n_steps``.
    """
    if n_steps is None:
        n_steps = power
    # Reduce over the row axis (dim=-2). For batched (B, S, S) input this is the same
    # axis as dim=1; it also makes unbatched (S, S) input average rows, not columns.
    return torch.linalg.matrix_power(prob, n_steps).mean(dim=-2)
# Exponent used by stationary_distribution: the smallest power of two that is >= 3 * num_symbols
# (16 when num_symbols = 3). Chosen heuristically.
power = 2 ** int(np.ceil(np.log2(3 * num_symbols)))
# def stationary_distribution(prob):
  
#     evals, evecs = torch.linalg.eig(prob.transpose(1,2))
#     # evec1 = evecs[torch.isclose(evals, torch.ones(1, dtype=torch.complex64)),:]
#     #Since np.isclose will return an array, we've indexed with an array
#     #so we still have our 2nd axis.  Get rid of it, since it's only size 1.
#     #  evec1 = evec1[...,0].real
#     evec1 = evecs[range(len(evecs)),:, torch.argmax(evals.real, dim=1).squeeze()]
#     stationary = (evec1.T / evec1.sum(axis=1)).T
#     # print(stationary)

#     #eigs finds complex eigenvalues and eigenvectors, so you'll want the real part.
#     return stationary.real

# Samples sequences autoregressively from Markov chains; the first symbol of each
# sequence is drawn from the chain's (approximate) stationary distribution.
def data_gen(length, transition_matrices=None):
    """Sample one symbol sequence of length `length` per transition matrix.

    Args:
        length: number of symbols per sequence (must be >= 1).
        transition_matrices: (B, S, S) tensor of transition matrices (rows are the
            next-symbol distributions), an int B (sample B uniformly random
            matrices), or None (sample 64).

    Returns:
        (B, length) int64 tensor on `device`; row b is a trajectory of chain b.
    """
    if transition_matrices is None:
        transition_matrices = dirichlet_markov_ensemble.sample([64])
    elif type(transition_matrices) is int:
        transition_matrices = dirichlet_markov_ensemble.sample([transition_matrices])

    n_chains = len(transition_matrices)
    stat_dists = stationary_distribution(transition_matrices)
    output = torch.zeros(n_chains, length, dtype=torch.long, device=device)
    output[:, 0] = torch.multinomial(stat_dists, 1).squeeze(-1)
    chain_idx = torch.arange(n_chains, device=device)
    for ind in range(1, length):
        # Row of chain b's matrix selected by its previous symbol -> (B, S).
        # Indexing the (B, S, S) tensor directly works for any number of symbols S.
        next_probs = transition_matrices[chain_idx, output[:, ind - 1]]
        output[:, ind] = torch.multinomial(next_probs, 1).squeeze(-1)
    return output.to(device)

def mixture_transition_matrices(num_matrices, fixed):
    """Stack `num_matrices` transition matrices, each one independently either a fresh
    uniformly random matrix (when Python's `random()` exceeds 0.5) or `fixed`."""
    chosen = []
    for _ in range(num_matrices):
        draw_fresh = random() > 0.5
        chosen.append(dirichlet_markov_ensemble.sample() if draw_fresh else fixed)
    return torch.stack(chosen)

In [None]:
# if we want to calculate how close the model is to various strategies, these will help:
def unigram_predict_next(sample, num_symbols=num_symbols):
    """Empirical (unsmoothed) unigram distribution of the symbols in `sample`.

    Args:
        sample: 1-D tensor of non-negative integer symbols.
        num_symbols: alphabet size; the result has at least this many entries.

    Returns:
        Float tensor whose i-th entry is the fraction of `sample` equal to i.
    """
    # Pad to num_symbols (not a larger fixed length) so the result lines up with a
    # num_symbols-way model prediction, e.g. in model_alg_kl_div.
    counts = torch.bincount(sample, minlength=num_symbols)
    return counts / counts.sum()

@torch.no_grad()
def ngram(inp, num_symbols=num_symbols, n=2):
    """Add-one-smoothed n-gram estimate of the next-symbol distribution after `inp`.

    Every n-gram in `inp` whose first n-1 symbols equal the last n-1 symbols of `inp`
    (every symbol, when n == 1) adds one to the count of its final symbol; all counts
    start at 1.

    Args:
        inp: 1-D sequence of integer symbols (tensor or Python sequence).
        num_symbols: alphabet size.
        n: n-gram order; n == 2 conditions on the previous symbol only.

    Returns:
        CPU float tensor of length num_symbols that sums to 1.
    """
    # Work on Python ints: comparing tuples of 0-d tensors element by element is slow,
    # and this also avoids indexing the CPU `counts` tensor with device scalars.
    seq = inp.tolist() if torch.is_tensor(inp) else list(inp)
    counts = torch.ones(num_symbols, dtype=torch.float)
    context = tuple(seq[-n + 1:])
    for gram in zip(*(seq[i:] for i in range(n))):
        if n == 1 or gram[:-1] == context:
            counts[gram[-1]] += 1
    return F.normalize(counts, p=1, dim=0)


def model_alg_kl_div(model, test_batch, alg):
    """Batch-averaged KL(alg || model) for the prediction that follows the last token.

    `alg` maps one sequence to a probability vector over symbols (e.g. `ngram`); it is
    compared against the model's logits at the final position of that sequence.
    """
    final_logits = model(test_batch)[:, -1]
    reference = torch.stack([alg(seq) for seq in test_batch]).to(final_logits.device)
    # 'batchmean' sums over every element and divides by the batch size
    return F.kl_div(final_logits.log_softmax(dim=1), reference, reduction="batchmean")

In [None]:
class gpt(nn.Module):
    """GPT-2 backbone followed by a linear head producing `output_dim` logits per position.

    Args:
        input_dim: vocabulary size of the token ids fed to the model.
        output_dim: number of output logits per position.
        drop: dropout probability for residual, embedding and attention dropout.
        hid_dim: embedding width; the MLP inner width is 4 * hid_dim.
        n_head: attention heads per layer.
        n_layer: number of transformer blocks.
        max_position: size of the learned position table, i.e. the longest input.
    """
    def __init__(self, input_dim, output_dim, drop, hid_dim = 512, n_head = 8, n_layer = 6, max_position = 50):
        super().__init__()
        config = GPT2Config(
            vocab_size=input_dim,
            n_positions=max_position,
            n_embd=hid_dim,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=hid_dim * 4,
            activation_function='gelu',
            resid_pdrop=drop,
            embd_pdrop=drop,
            attn_pdrop=drop,
            use_cache=False,
        )
        self.GPT2 = GPT2Model(config)
        self.lin2 = nn.Linear(hid_dim, output_dim)

    def forward(self, x):
        """Map token ids `x` (assumed (batch, seq)) to logits of shape (batch, seq, output_dim)."""
        # Explicit all-ones mask: every position may attend to (causally) earlier ones
        backbone_out = self.GPT2(input_ids=x, attention_mask=torch.ones_like(x))
        return self.lin2(backbone_out.last_hidden_state)
    
def test(model, loss_fn, test_batch):
    x_t = test_batch
    test_starting_at = 0
    preds_valid = model(x_t[:, test_starting_at:-1])
    return loss_fn(preds_valid.mT, x_t[:, test_starting_at+1:]).cpu()

def test_true(model, loss_fn, test_batch, test_matrices, device):
    x_t = test_batch 
    test_starting_at = 0
    preds_valid = model(x_t[:, test_starting_at:-1])
    labels = test_matrices.to(device)[torch.arange(test_matrices.size(0)).unsqueeze(1), x_t[:, test_starting_at:-1]]
    return F.kl_div(F.log_softmax(preds_valid, dim=-1), labels, reduction="none").mT.mean(dim=[0,1]).cpu()


# training code!
# length is the length of the training sequences
# every is how often to test the model
# fixed is the fixed transition matrix
# data_gen_batches is how many batches to compute at once. bigger is better efficiency (None sets it to do all) but uses up more memory
# TODO add ability to train on different distributions (mixture, just uniform, just fixed, etc)
def trainer(model, iters, opt, loss_fn, length, every= 2000, device= device, fixed = None, batchsize=64, data_gen_batches = None):
    """Train `model` on next-token prediction over a mixture of Markov chains.

    Each training sequence comes from either a uniformly random transition matrix or
    `fixed` (via `mixture_transition_matrices`). Every `every` iterations the model is
    evaluated on a held-out batch from uniformly random matrices and on one from `fixed`.

    Args:
        model: module mapping (B, length-1) token ids to per-position logits.
        iters: the loop runs iters + 1 optimisation steps.
        opt: optimizer over `model`'s parameters.
        loss_fn: next-token loss, e.g. nn.CrossEntropyLoss().
        length: length of every generated sequence.
        every: evaluation period in iterations.
        device: device that training and evaluation batches are moved to.
        fixed: the fixed transition matrix; required.
        batchsize: sequences per optimisation step and per evaluation batch.
        data_gen_batches: batches generated per refill of the data pool; None
            generates all iters + 1 batches up front.

    Returns:
        (test_loss, fixed_test_loss, train_loss): the first two are stacked
        `test_true` results, one row per evaluation, on the random and fixed test
        batches; `train_loss` is a list of `test` results (loss_fn on the random
        test batch), one per evaluation.

    Raises:
        ValueError: if `fixed` is None or `data_gen_batches` is smaller than 1.
    """
    if fixed is None:
        raise ValueError("trainer needs a fixed transition matrix: pass fixed=...")
    if data_gen_batches is None:
        data_gen_batches = iters + 1
    if data_gen_batches < 1:
        raise ValueError("data_gen_batches must be a positive integer or None")
    pool_size = batchsize * data_gen_batches

    def _fresh_pool():
        # New pool of mixture training sequences, on the same device as the test batches
        return data_gen(length, mixture_transition_matrices(pool_size, fixed)).to(device)

    model.train()
    test_loss = []
    train_loss = []
    fixed_test_loss = []
    fixed_batch = torch.stack([fixed] * batchsize)
    fixed_test_batch = data_gen(length, fixed_batch).to(device)
    test_batch_matrices = dirichlet_markov_ensemble.sample([batchsize])
    test_batch = data_gen(length, test_batch_matrices).to(device)
    data = _fresh_pool()
    data_index = 0
    #TODO have tqdm show the train loss
    for i in tqdm(range(iters+1), ncols = 100, desc = "Progress", position = 0, leave = True):

        if i % every == 0:
            model.eval()
            with torch.no_grad():
                train_loss.append(test(model, loss_fn, test_batch))
                test_loss.append(test_true(model, loss_fn, test_batch, test_batch_matrices, device))
                fixed_test_loss.append(test_true(model, loss_fn, fixed_test_batch, fixed_batch, device))
            model.train()

        x_t = data[data_index:data_index + batchsize]
        data_index += batchsize
        if data_index >= pool_size:
            data = _fresh_pool()
            data_index = 0
        preds = model(x_t[:, :-1])
        loss = loss_fn(preds.mT, x_t[:, 1:])
        loss.backward()
        opt.step()
        opt.zero_grad()

    print("final loss= %f" % loss.item())

    return torch.stack(test_loss), torch.stack(fixed_test_loss), train_loss

In [None]:
# Main cell to run the training!
# Seed every RNG in play: torch (Dirichlet/multinomial sampling, model init),
# Python's `random` (used by mixture_transition_matrices) and numpy's global RNG.
seed = 4
torch.manual_seed(seed)
randomseed(seed)
np.random.seed(seed)
fixed = dirichlet_markov_ensemble.sample()  # the fixed chain mixed into the training data
length = 30  # sequence length; the model is fed length - 1 tokens per forward pass
# Vocabulary and output size follow num_symbols instead of a hard-coded 3
model = gpt(num_symbols, num_symbols, drop = 0, hid_dim = 16, n_head = 1, n_layer = 2, max_position = length - 1)
model.to(device= device)
model.train()
opt = torch.optim.AdamW(model.parameters(), lr= 1e-3, weight_decay= 0)
loss_fn = nn.CrossEntropyLoss()
print(fixed)
accs, fixed_accs, train_loss = trainer(model, 4000, opt, loss_fn, length, every = 4, device= device, fixed = fixed)

In [None]:
# Titles are built from the actual run settings instead of hard-coded values.
run_settings = f"({num_symbols} symbols, length {length} training)"

# Plot 1: KL to the true next-token distribution, averaged over context positions,
# at every evaluation (uniformly random chains vs. the fixed chain).
plt.plot(accs.mean(axis=1), label = "Test loss")
plt.plot(fixed_accs.mean(axis=1), label = "Test loss on fixed transition matrix")
plt.xlabel("Evaluation index (one per `every` training steps)")
plt.ylabel("Test loss averaged over context positions")
plt.title(f"Training Transformer on MC-ICL\n{run_settings}")
plt.legend()
plt.show()

# Plot 2: the same loss per context position at the final evaluation.
plt.plot(accs[-1], label = "Test loss")
plt.plot(fixed_accs[-1], label = "Test loss on fixed transition matrix")
plt.xlabel("Tokens of Context")
plt.ylabel("Test loss at final evaluation")
plt.title(f"Training Transformer on Mixture\n{run_settings}")
plt.legend()
plt.show()

# Plot 3: loss_fn (cross-entropy) on the uniformly random test batch at every evaluation.
plt.plot(torch.stack(train_loss))
plt.xlabel("Evaluation index (one per `every` training steps)")
plt.ylabel("Cross-entropy on test batch")
plt.title(f"Next-token loss on uniformly random chains\n{run_settings}")
plt.show()
#plot loss at different tokens