In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
import os
import json
import torch
import pickle
import random
import math
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, util

import load_data
from load_data import GenderDataset, gender_data_collate_fn
from models.encoder_t5 import EncoderT5
from models.classifier_bert import ClassifierBERT
from models.similarity_sent_enc import encode_for_similarities

In [3]:
from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer

# Load the tokenizer and the LM-head model from the 'gpt2' checkpoint name.
# Both objects are handed to the perplexity helpers in the cells that follow.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

In [21]:
# Show the named parameters of the embedding layer (`wte`) that the
# two-sentence perplexity helpers use to look up input embeddings.
print([*model.transformer.wte.named_parameters()])

[('weight', Parameter containing:
tensor([[-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453],
        [ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],
        [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],
        ...,
        [-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],
        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],
        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]],
       requires_grad=True))]


In [42]:
def calculatePerplexity(sentence, model, tokenizer):
    """Return the perplexity that ``model`` assigns to ``sentence``.

    The sentence is encoded with ``tokenizer`` and passed to ``model`` with
    the token ids doubling as labels; the first element of the model output
    (the loss) is exponentiated.

    Args:
        sentence: Text to score.
        model: Callable accepting ``(input_ids, labels=...)`` whose output is
            indexable with the loss at position 0. If it exposes a ``device``
            attribute the ids are created on that device; otherwise the CPU
            is used.
        tokenizer: Object with an ``encode(text)`` method returning token ids.

    Returns:
        float: ``exp(loss)``.

    Raises:
        ValueError: If ``sentence`` encodes to zero tokens.
    """
    token_ids = tokenizer.encode(sentence)
    if len(token_ids) == 0:
        # An empty id list would otherwise become a float tensor of shape
        # (1, 0) and fail somewhere inside the model.
        raise ValueError("sentence must encode to at least one token")

    # Follow the model's device instead of assuming the CPU.
    device = getattr(model, "device", "cpu")
    input_ids = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss = outputs[0]
    return math.exp(float(loss))

In [80]:
def calculatePerplexityAlter(sentence, model, tokenizer, verbose=True):
    """Recompute the perplexity of ``sentence`` from the model's logits.

    Instead of using the loss returned by the model, this builds one-hot
    target distributions from the token ids, shifts logits and targets by one
    position (the logits at position ``i`` are scored against the token at
    ``i + 1``) and applies ``nn.CrossEntropyLoss`` to them.

    Args:
        sentence: Text to score.
        model: Callable accepting ``(input_ids, labels=...)`` whose output is
            indexable with the logits at position 1. If it exposes a
            ``device`` attribute the ids are created on that device;
            otherwise the CPU is used.
        tokenizer: Object with an ``encode(text)`` method returning token ids.
        verbose: When true (the default), print the per-token logits and the
            tensor shapes as the computation proceeds.

    Returns:
        float: ``exp`` of the cross-entropy between the shifted logits and
        the shifted one-hot targets.

    Raises:
        ValueError: If ``sentence`` encodes to zero tokens.
    """
    token_ids = tokenizer.encode(sentence)
    if len(token_ids) == 0:
        raise ValueError("sentence must encode to at least one token")

    # Follow the model's device instead of assuming the CPU.
    device = getattr(model, "device", "cpu")
    input_ids = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    logits = outputs[1]

    # Size the targets from the logits themselves: the tokenizer's reported
    # vocab_size need not equal the width of the model's output layer.
    num_classes = logits.size(-1)

    if verbose:
        for i, label in enumerate(input_ids[0]):
            print(i, label)
            print(logits[0][i][label])

    # One-hot target distribution per position, on the logits' device/dtype.
    labels_logits = nn.functional.one_hot(input_ids, num_classes=num_classes).to(logits.dtype)

    # Drop the last prediction and the first target so that position i
    # predicts token i + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels_logits = labels_logits[..., 1:, :].contiguous()

    if verbose:
        print(input_ids, len(input_ids))
        print(labels_logits.size())
        print(shift_logits.size())
        print(shift_labels_logits.size())

    # Flatten the tokens
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(
        shift_logits.view(-1, num_classes),
        shift_labels_logits.view(-1, num_classes),
    )
    return math.exp(float(loss))

In [55]:
def calculatePerplexity2(sentence1, sentence2, weight, model, tokenizer):
    """Perplexity of ``sentence2`` when the input is a blend of two sentences.

    Both sentences are encoded, looked up in ``model.transformer.wte`` and
    mixed position by position as
    ``weight * embed(sentence1) + (1 - weight) * embed(sentence2)``.
    The mixed embeddings are fed to the model with the token ids of
    ``sentence2`` as labels, and the returned loss is exponentiated.

    Args:
        sentence1: Text whose embeddings receive ``weight``.
        sentence2: Text whose embeddings receive ``1 - weight``; its token
            ids are also the labels.
        weight: Mixing coefficient applied to ``sentence1``.
        model: Object exposing ``transformer.wte`` (an embedding module with
            a ``weight`` parameter) and callable as
            ``model(inputs_embeds=..., labels=...)`` with the loss at
            position 0 of the output.
        tokenizer: Object with an ``encode(text)`` method returning token ids.

    Returns:
        float: ``exp(loss)``.

    Raises:
        ValueError: If either sentence encodes to zero tokens or the two
            sentences encode to different numbers of tokens.
    """
    ids1 = tokenizer.encode(sentence1)
    ids2 = tokenizer.encode(sentence2)
    if len(ids1) == 0 or len(ids2) == 0:
        raise ValueError("both sentences must encode to at least one token")
    if len(ids1) != len(ids2):
        # Without this check a one-token sentence would silently broadcast
        # against the other one, and other mismatches fail with an opaque
        # shape error inside the addition.
        raise ValueError(
            "sentences must encode to the same number of tokens "
            f"(got {len(ids1)} and {len(ids2)})"
        )

    embed = model.transformer.wte
    # Create the ids where the embedding table lives rather than assuming CPU.
    device = embed.weight.device
    input_ids1 = torch.tensor(ids1, dtype=torch.long, device=device).unsqueeze(0)
    input_ids2 = torch.tensor(ids2, dtype=torch.long, device=device).unsqueeze(0)

    # The embedding lookup is part of inference too, so keep it under
    # no_grad to avoid recording an autograd graph.
    with torch.no_grad():
        inputs_embeds = embed(input_ids1) * weight + embed(input_ids2) * (1 - weight)
        outputs = model(inputs_embeds=inputs_embeds, labels=input_ids2)
    loss = outputs[0]
    return math.exp(float(loss))

In [82]:
def calculatePerplexity2Alter(sentence1, sentence2, weight, model, tokenizer, verbose=True):
    """Perplexity of a blended input against equally blended targets.

    The input embeddings are mixed as
    ``weight * embed(sentence1) + (1 - weight) * embed(sentence2)`` using
    ``model.transformer.wte``. The targets are mixed the same way: each
    position's target distribution puts ``weight`` on the token of
    ``sentence1`` and ``1 - weight`` on the token of ``sentence2``. Logits and
    targets are shifted by one position and scored with
    ``nn.CrossEntropyLoss``.

    Args:
        sentence1: Text receiving ``weight`` in both inputs and targets.
        sentence2: Text receiving ``1 - weight`` in both inputs and targets.
        weight: Mixing coefficient applied to ``sentence1``.
        model: Object exposing ``transformer.wte`` (an embedding module with
            a ``weight`` parameter) and callable as
            ``model(inputs_embeds=..., labels=...)`` with the logits at
            position 1 of the output.
        tokenizer: Object with an ``encode(text)`` method returning token ids.
        verbose: When true (the default), print the token ids and the tensor
            shapes as the computation proceeds.

    Returns:
        float: ``exp`` of the cross-entropy between shifted logits and
        shifted mixed targets.

    Raises:
        ValueError: If either sentence encodes to zero tokens or the two
            sentences encode to different numbers of tokens.
    """
    ids1 = tokenizer.encode(sentence1)
    ids2 = tokenizer.encode(sentence2)
    if len(ids1) == 0 or len(ids2) == 0:
        raise ValueError("both sentences must encode to at least one token")
    if len(ids1) != len(ids2):
        # The targets are mixed position by position, so both sentences must
        # have the same number of tokens; otherwise some rows would not be a
        # full probability distribution (or indexing would run out of range).
        raise ValueError(
            "sentences must encode to the same number of tokens "
            f"(got {len(ids1)} and {len(ids2)})"
        )

    embed = model.transformer.wte
    # Create the ids where the embedding table lives rather than assuming CPU.
    device = embed.weight.device
    input_ids1 = torch.tensor(ids1, dtype=torch.long, device=device).unsqueeze(0)
    input_ids2 = torch.tensor(ids2, dtype=torch.long, device=device).unsqueeze(0)

    # Keep the embedding lookup under no_grad as well; this is inference only.
    with torch.no_grad():
        inputs_embeds = embed(input_ids1) * weight + embed(input_ids2) * (1 - weight)
        outputs = model(inputs_embeds=inputs_embeds, labels=input_ids2)
    logits = outputs[1]

    # Size the targets from the logits themselves: the tokenizer's reported
    # vocab_size need not equal the width of the model's output layer.
    num_classes = logits.size(-1)

    one_hot1 = nn.functional.one_hot(input_ids1, num_classes=num_classes).to(logits.dtype)
    one_hot2 = nn.functional.one_hot(input_ids2, num_classes=num_classes).to(logits.dtype)
    labels_logits = one_hot1 * weight + one_hot2 * (1 - weight)

    # Drop the last prediction and the first target so that position i
    # predicts token i + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels_logits = labels_logits[..., 1:, :].contiguous()

    if verbose:
        print(input_ids1, len(input_ids1))
        print(input_ids2, len(input_ids2))
        print(labels_logits.size())
        print(shift_logits.size())
        print(shift_labels_logits.size())

    # Flatten the tokens
    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(
        shift_logits.view(-1, num_classes),
        shift_labels_logits.view(-1, num_classes),
    )
    return math.exp(float(loss))

In [59]:
# Four-word probe sentences: t1/t2 differ only in "good" vs "bad";
# t3/t4 put that same word in the second position as well.
t1, t2, t3, t4 = (
    "I feel good today",
    "I feel bad today",
    "I good good today",
    "I bad bad today",
)

In [56]:
# Probe pair differing in a single word: "became" (t5) vs "become" (t6).
t5, t6 = (
    "I think this has became better",
    "I think this has become better",
)

In [60]:
# Model-reported perplexity for each sentence of the good/bad pair.
for sentence in (t1, t2):
    print(calculatePerplexity(sentence, model, tokenizer))

319.5520239164024
406.9546691785979


In [81]:
# Same pair, scored with the hand-built cross-entropy variant.
for sentence in (t1, t2):
    print(calculatePerplexityAlter(sentence, model, tokenizer))

0 tensor(40)
tensor(-38.7881)
1 tensor(1254)
tensor(-110.9616)
2 tensor(922)
tensor(-97.1950)
3 tensor(1909)
tensor(-84.6861)
tensor([[  40, 1254,  922, 1909]]) 1
torch.Size([1, 4, 50257])
torch.Size([1, 3, 50257])
torch.Size([1, 3, 50257])
319.55187154215974
0 tensor(40)
tensor(-38.7881)
1 tensor(1254)
tensor(-110.9616)
2 tensor(2089)
tensor(-96.1995)
3 tensor(1909)
tensor(-92.9789)
tensor([[  40, 1254, 2089, 1909]]) 1
torch.Size([1, 4, 50257])
torch.Size([1, 3, 50257])
torch.Size([1, 3, 50257])
406.9544751275362


In [61]:
# Blend t1 into t2 at several mixing weights and report each perplexity.
for mix_weight in (0.5, 0, 1, 0.75):
    print(calculatePerplexity2(t1, t2, mix_weight, model, tokenizer))

296.994513333816
406.9546691785979
269.46985470486254
273.69302759955264


In [84]:
# Same weights, with the targets blended as well as the inputs.
for mix_weight in (0.5, 0, 1, 0.75):
    print(calculatePerplexity2Alter(t1, t2, mix_weight, model, tokenizer))

tensor([[  40, 1254,  922, 1909]]) 1
tensor([[  40, 1254, 2089, 1909]]) 1
torch.Size([1, 4, 50257])
torch.Size([1, 3, 50257])
torch.Size([1, 3, 50257])
323.41764039783124
tensor([[  40, 1254,  922, 1909]]) 1
tensor([[  40, 1254, 2089, 1909]]) 1
torch.Size([1, 4, 50257])
torch.Size([1, 3, 50257])
torch.Size([1, 3, 50257])
406.9544751275362
tensor([[  40, 1254,  922, 1909]]) 1
tensor([[  40, 1254, 2089, 1909]]) 1
torch.Size([1, 4, 50257])
torch.Size([1, 3, 50257])
torch.Size([1, 3, 50257])
319.55187154215974
tensor([[  40, 1254,  922, 1909]]) 1
tensor([[  40, 1254, 2089, 1909]]) 1
torch.Size([1, 4, 50257])
torch.Size([1, 3, 50257])
torch.Size([1, 3, 50257])
311.0192516335582


In [62]:
# Model-reported perplexity for the "become" sentence, then the "became" one.
for sentence in (t6, t5):
    print(calculatePerplexity(sentence, model, tokenizer))

127.09786136846539
402.74780255825715


In [83]:
# Same two sentences, scored with the hand-built cross-entropy variant.
for sentence in (t6, t5):
    print(calculatePerplexityAlter(sentence, model, tokenizer))

0 tensor(40)
tensor(-38.7881)
1 tensor(892)
tensor(-109.3005)
2 tensor(428)
tensor(-94.1043)
3 tensor(468)
tensor(-102.1964)
4 tensor(1716)
tensor(-95.2517)
5 tensor(1365)
tensor(-111.6195)
tensor([[  40,  892,  428,  468, 1716, 1365]]) 1
torch.Size([1, 6, 50257])
torch.Size([1, 5, 50257])
torch.Size([1, 5, 50257])
127.09786136846539
0 tensor(40)
tensor(-38.7881)
1 tensor(892)
tensor(-109.3005)
2 tensor(428)
tensor(-94.1043)
3 tensor(468)
tensor(-102.1964)
4 tensor(2627)
tensor(-98.9415)
5 tensor(1365)
tensor(-113.3018)
tensor([[  40,  892,  428,  468, 2627, 1365]]) 1
torch.Size([1, 6, 50257])
torch.Size([1, 5, 50257])
torch.Size([1, 5, 50257])
402.74780255825715


In [86]:
# Sweep the mixing weight from all-t5 (0) to all-t6 (1).
for mix_weight in (0, 0.5, 0.75, 1):
    print(calculatePerplexity2Alter(t6, t5, mix_weight, model, tokenizer))

tensor([[  40,  892,  428,  468, 1716, 1365]]) 1
tensor([[  40,  892,  428,  468, 2627, 1365]]) 1
torch.Size([1, 6, 50257])
torch.Size([1, 5, 50257])
torch.Size([1, 5, 50257])
402.74780255825715
tensor([[  40,  892,  428,  468, 1716, 1365]]) 1
tensor([[  40,  892,  428,  468, 2627, 1365]]) 1
torch.Size([1, 6, 50257])
torch.Size([1, 5, 50257])
torch.Size([1, 5, 50257])
227.46442371626114
tensor([[  40,  892,  428,  468, 1716, 1365]]) 1
tensor([[  40,  892,  428,  468, 2627, 1365]]) 1
torch.Size([1, 6, 50257])
torch.Size([1, 5, 50257])
torch.Size([1, 5, 50257])
171.145853289958
tensor([[  40,  892,  428,  468, 1716, 1365]]) 1
tensor([[  40,  892,  428,  468, 2627, 1365]]) 1
torch.Size([1, 6, 50257])
torch.Size([1, 5, 50257])
torch.Size([1, 5, 50257])
127.09786136846539
