In [19]:
""" 
Notebook used to compute
logperplexity differences for snippets of text
"""

from modules import get_tokenizer
from modules import MinGPT_Trainer, MinGPT
import os,torch
import torch.nn.functional as F

tokenizer_location = 'modules/tokenizers/en_tokenizer'

f_model_location = 'saved_models/en_med'
f_model_name = 'en_med'

b_model_location = 'saved_models/en_med_b'
b_model_name = 'en_med_b'	

toki = get_tokenizer(m_path=tokenizer_location)

def load_model(model_location,model_name):
    """
    Loads a model which has been exported with
    export_model.

    Args :
        model_location : folder containing exported .py and .config files.
        model_name : name of exported model, i.e. [name].py
    """
    model = MinGPT(**torch.load(os.path.join(model_location,model_name+'.config')))
    model.load_state_dict(torch.load(os.path.join(model_location,model_name+'.pt')))
    return model

def export_model(state_path, save_location, model_name):
    MinGPT_Trainer.save_model_from_state(state_path, save_location, name=model_name)


In [20]:
model_f = load_model(f_model_location,f_model_name)
model_b = load_model(b_model_location,b_model_name)

number of parameters: 405.50M
Without head : 354.04M
number of parameters: 405.50M
Without head : 354.04M


In [20]:
export_model('en_med_backwards.state','saved_models/en_med_b','en_med')   

Saved weights of en_med.pt at saved_models/en_med_b/en_med.pt  !


In [21]:
def tokenize_with_bos(text,tokenizer,backward=False):
    text = tokenizer.tokenize(text)

    if(backward):
        text=torch.flip(text,[1])
    # return text
    return torch.cat([torch.tensor([[0]]),text],dim=1)

In [22]:

text_loc = 'sample_text.txt'
text_to_test = open(text_loc).read()

tokenized = tokenize_with_bos(text_to_test,toki)
inv_tokenized = tokenize_with_bos(text_to_test,toki,backward=True)
# print(tokenized)
# print(inv_tokenized)


print(toki.detokenize(tokenized)[:50])
print(toki.detokenize(inv_tokenized.flip([1]))[:50])

once upon a time, there was a prince. He was absol
once upon a time, there was a prince. He was absol


In [23]:
def get_perplexities(text,model,tokenizer,backward=False):
    text = tokenize_with_bos(text,tokenizer,backward) # (1,seq_len)
    model.eval()

    if(text.shape[1]>256):
        print('text too long, truncating')
        text = text[:,:256]
    print('end of text : ',toki.detokenize(text[:,-10:]))

    with torch.no_grad():
        out = model(text) # (1,seq_len,vocab_size)
        B,T,V = out.size()
        out = out.reshape(B*T,V)
        text = text.reshape(B*T)
        loss = F.cross_entropy(out[:-1,:], text[1:] ,reduction='none')
        return loss

f_perp = get_perplexities(text_to_test,model_f,toki,backward=False)
b_perp = get_perplexities(text_to_test,model_b,toki,backward=True)

print('Forward perp: ', f_perp, ' : ', f_perp.mean())
print('Back perp: ', b_perp, ' : ', b_perp.mean())

end of text :   best there was, no question.
end of text :   wasere th, time a upononce
Forward perp:  tensor([4.4576e+00, 1.5666e-02, 2.5615e-02, 1.7123e+00, 2.1794e+00, 3.3385e-04,
        8.7502e-01, 7.2953e-05, 4.9712e-01, 5.8143e-01, 7.0945e+00, 3.2831e+00,
        2.5014e+00, 9.4023e-01, 7.6368e+00, 3.2378e+00, 4.5349e-04, 3.2358e+00,
        6.3885e+00, 3.8427e-03, 8.4452e-01, 3.6221e-04, 4.8725e-01, 1.7730e+00,
        4.8164e+00, 2.3415e+00, 9.9012e-01])  :  tensor(2.0711)
Back perp:  tensor([7.5085e+00, 4.9830e+00, 6.2113e-01, 7.5087e+00, 2.5591e+00, 1.1853e-03,
        7.3156e-04, 1.2159e-02, 5.4379e+00, 2.2889e-01, 1.5461e-02, 6.2103e+00,
        4.3245e-01, 3.9755e+00, 1.5103e-01, 1.0787e+01, 1.7348e+00, 2.4978e+00,
        3.1441e+00, 2.6153e-02, 5.3300e-03, 1.4460e-02, 1.3916e+00, 3.7435e+00,
        2.5579e-01, 3.6349e-03, 5.6981e+00])  :  tensor(2.5536)
