In [20]:
## MyModel definition
from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config, PreTrainedModel
import torch
from torch.nn import CrossEntropyLoss



class MyModel(PreTrainedModel):
    """GPT-2-based model whose decoder is conditioned on one summary vector.

    Three GPT-2 stacks are built from the same ``config``:

    * ``encoder`` -- its final-position hidden state
      (``last_hidden_state[:, -1, :]``) is the summary vector of ``input_ids``.
    * ``second_encoder`` -- run under ``torch.no_grad`` only to read
      ``hidden_states[0]``. In Hugging Face GPT-2 this is the initial embedding
      output, before the first transformer block. It supplies the per-token
      input embeddings for the decoder.
    * ``decoder`` -- a ``GPT2LMHeadModel`` fed ``[summary vector, token embeddings]``
      through ``inputs_embeds``.
    """
    config_class = GPT2Config

    def __init__(self, config):
        super().__init__(config)
        self.encoder = GPT2Model(config)
        self.second_encoder = GPT2Model(config)
        self.decoder = GPT2LMHeadModel(config)

    def forward(self, input_ids, labels=None, attention_mask=None):
        """Run encoder -> decoder and optionally compute the LM loss.

        Args:
            input_ids: token ids, indexed as ``(batch, seq_len)``.
            labels: optional target ids aligned with ``input_ids``. Presumably
                ``input_ids`` itself -- confirm with the training code. Positions
                where ``attention_mask == 0`` are ignored in the loss via -100.
                The caller's tensor is no longer modified.
            attention_mask: optional mask; only used to mask ``labels``.

        Returns:
            dict with ``'loss'`` (``None`` when ``labels`` is ``None``) and
            ``'logits'`` -- log-probabilities over the vocabulary. Decoder
            position 0, which follows only the summary vector, is dropped.
        """
        # NOTE(review): attention_mask is not passed to the encoders, and the summary
        # vector is taken at the last position; with right padding that is a pad token.
        encoder_outputs = self.encoder(input_ids)
        # (batch, 1, hidden): the summary vector kept as a one-token sequence.
        hidden_embedding = encoder_outputs.last_hidden_state[:, -1, :].unsqueeze(1)
        # Only used to obtain the token embeddings; no gradients flow into second_encoder.
        with torch.no_grad():
            decoder_hidden_inputs = self.second_encoder(
                input_ids, output_hidden_states=True
            ).hidden_states[0]
        updated_input = torch.cat((hidden_embedding, decoder_hidden_inputs), dim=1)
        logits = self.decoder(inputs_embeds=updated_input)['logits']
        # `F` was never imported in this file; use the fully qualified function.
        logits = torch.nn.functional.log_softmax(logits, dim=-1)

        lm_loss = None
        if labels is not None:
            # Clone so the -100 masking does not leak into the caller's batch tensor.
            labels = labels.clone()
            if attention_mask is not None:
                labels[attention_mask == 0] = -100
            # Decoder position k (k >= 1) follows input token k-1, so it is scored
            # against label k: positions 1..L-1 pair with labels[:, 1:].
            shifted_prediction_scores = logits[:, 1:-1, :]
            shifted_labels = labels[:, 1:]
            # CrossEntropyLoss re-applies log_softmax, which is a no-op on
            # log-probabilities, so the loss value is unaffected.
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(
                shifted_prediction_scores.contiguous().view(-1, self.config.vocab_size),
                shifted_labels.contiguous().view(-1),
            )
        return {'loss': lm_loss, 'logits': logits[:, 1:, :]}
    

## defining tokenizer
from transformers import AutoTokenizer
# Hub checkpoint whose tokenizer is used throughout; 'gpt2' was the alternative tried.
# NOTE(review): MyModel is built from a GPT2Config -- confirm the checkpoint's
# vocab_size covers every id this tokenizer can emit.
TOKENIZER_CHECKPOINT = 'google-t5/t5-small'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

context_length = 700


def tokenize(element):
    """Tokenize a batch of texts into chunks of at most ``context_length`` tokens.

    Meant for ``Dataset.map(..., batched=True)``: ``element['text']`` is a list of
    strings. The global ``tokenizer`` is called with truncation at
    ``context_length``, ``return_overflowing_tokens=True``, and
    ``return_length=True``; per the Hugging Face docs, the first keeps the
    overflow of long texts as extra rows and the second reports each row's length.

    Returns:
        ``{'input_ids': [...]}`` containing every row whose reported length is
        ``<= context_length``.
    """
    encoded = tokenizer(
        element['text'],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    # NOTE(review): with truncation at max_length=context_length every row should
    # already satisfy this test, so it looks like a no-op. If the intent was to drop
    # over-long stories or short tail chunks, the condition needs to change.
    kept_rows = [
        ids
        for n_tokens, ids in zip(encoded['length'], encoded['input_ids'])
        if n_tokens <= context_length
    ]
    return {'input_ids': kept_rows}




In [21]:
checkpoint = './model3weights_2024-07-04--16:34:15'
model = MyModel.from_pretrained(checkpoint)
# Loading this checkpoint warns that 'decoder.lm_head.weight' is newly initialized.
# GPT2LMHeadModel normally ties its lm_head to its input embeddings, so re-tie them
# explicitly to rebuild the head from the loaded embeddings rather than a random one.
# This changes nothing if they are already tied.
model.decoder.tie_weights()


Some weights of MyModel were not initialized from the model checkpoint at ./model3weights_2024-07-04--16:34:15 and are newly initialized: ['decoder.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
## don't run this!!
from datasets import load_from_disk
# Directory holding a tokenized dataset, presumably written earlier with save_to_disk.
TOKENIZED_DATASET_DIR = 'model3-outputs'
tokenized_dataset = load_from_disk(TOKENIZED_DATASET_DIR)
tokenized_dataset

In [7]:
## Load the TinyStories dataset, in which every row is a single story
from datasets import load_dataset
TINYSTORIES_REPO = 'roneneldan/TinyStories'  # Hugging Face Hub dataset id
ds = load_dataset(TINYSTORIES_REPO)


Repo card metadata block was not found. Setting CardData to empty.


In [22]:
# Drop the raw columns (e.g. 'text') so only the 'input_ids' produced by tokenize remain.
# NOTE(review): this tokenizes every split, including 'train', although only
# 'validation' is used below -- consider mapping just that split.
original_columns = ds['train'].column_names
tokenized_dataset_raw = ds.map(
    tokenize,
    batched=True,
    remove_columns=original_columns,
)

In [23]:

import numpy as np
# Keep only the first NUM_STORIES validation rows; later cells index into this subset.
NUM_STORIES = 10
validation_split = tokenized_dataset_raw['validation']
tokenized_dataset = validation_split.select(np.arange(NUM_STORIES))


In [17]:
#tokenized_dataset
# Decode one story of the subset back to text to serve as the query.
story_index = 3
story_ids = tokenized_dataset['input_ids'][story_index]
input_text = tokenizer.decode(story_ids)
input_text

'Once upon a time, there was a thoughtful girl named Sue. Sue loved to help her mom around the house. One day, her mom asked her to wipe the table after they ate their lunch. Sue was happy to help.\n\nAs Sue was wiping the table, she saw a pretty candle on the window sill. The candle was her mom\'s favorite. Sue wanted to do something nice for her mom, so she said, "Mom, can I light the candle for you?" Her mom said, "Yes, but be very careful."\n\nSue carefully lit the candle and put it on the table. Her mom was so happy to see the pretty candle. They both sat and watched the candle burn. Sue\'s mom said, "Thank you, Sue, for being so thoughtful and careful." Sue felt proud that she could help her mom.\n\nThe moral of the story is to always be thoughtful and careful when helping others.'

In [8]:
## don't  run this!!
import torch
# Compare the encoder's final-position vectors for two near-identical sentences.
# Inference only: no_grad avoids building an autograd graph (the old output showed grad_fn).
with torch.no_grad():
    # Despite the name, these hold the tokenizer's full output (passed via **),
    # not just the ids.
    input_ids1 = tokenizer('Hello I am a professor!', return_tensors='pt')
    input_ids2 = tokenizer('Hi I am a professor!', return_tensors='pt')
    hidden_embedding1 = model.encoder(**input_ids1).last_hidden_state[0, -1, :]
    hidden_embedding2 = model.encoder(**input_ids2).last_hidden_state[0, -1, :]
print('norm is ', torch.norm(hidden_embedding1 - hidden_embedding2))


norm is  tensor(3.1669, grad_fn=<LinalgVectorNormBackward0>)


In [None]:
## don't run this!!
# Final-position encoder vector of the first story in the subset (inference only).
with torch.no_grad():
    first_story_ids = torch.tensor([tokenized_dataset['input_ids'][0]])
    hidden_embedding1 = model.encoder(first_story_ids).last_hidden_state[0, -1, :]
hidden_embedding1

In [24]:
# find the closest neighbor
def find_similar(input_tokens, list_of_tokens, encoder=None):
    """Return the candidate whose encoder summary vector is closest to the query's.

    A sequence's summary vector is the encoder's final-position hidden state,
    ``last_hidden_state[0, -1, :]``. This is the same position ``MyModel.forward``
    uses for its summary vector. Distance is the Euclidean norm of the difference.

    Args:
        input_tokens: token ids (list of ints) of the query sequence.
        list_of_tokens: iterable of candidate token-id lists.
        encoder: callable mapping a ``(1, seq_len)`` id tensor to an output with
            ``.last_hidden_state``. Defaults to the global ``model.encoder``.

    Returns:
        ``[close_tokens, dist_min]``: the first candidate with minimal distance and
        that distance (a 0-d tensor). For an empty ``list_of_tokens`` this is
        ``[[], inf]``.
    """
    if encoder is None:
        encoder = model.encoder

    def summary_vector(tokens):
        # Final-position hidden state of a single (batch of one) sequence.
        return encoder(torch.tensor([tokens])).last_hidden_state[0, -1, :]

    close_tokens = []
    dist_min = float('inf')
    # Pure inference: don't keep autograd graphs for every candidate.
    with torch.no_grad():
        # Bug fix: the query used to be the whole (1, seq_len, hidden) output, which
        # broadcast against each candidate's single vector and summed the error over
        # every query position. Now query and candidates use the same vector type.
        query_vector = summary_vector(input_tokens)
        for tokens in list_of_tokens:
            dist = torch.norm(query_vector - summary_vector(tokens))
            # Strict '<' keeps the first of equally close candidates.
            if dist < dist_min:
                close_tokens, dist_min = tokens, dist
    return [close_tokens, dist_min]

In [25]:
# Query with story 3 and search among the stories after it in the subset.
query_tokens = tokenized_dataset['input_ids'][3]
candidate_tokens = tokenized_dataset['input_ids'][4:]
close_tokens, dist_min = find_similar(query_tokens, candidate_tokens)
close_text = tokenizer.decode(close_tokens)


In [26]:
import textwrap
def wrap(line):
    """Hard-wrap ``line`` to 70 columns without splitting words.

    Whitespace, including newlines, is normalized to single spaces by textwrap.
    The output lines are joined with newline characters.
    """
    return textwrap.fill(line, width=70, break_long_words=False)

# Show the query story and its nearest neighbour, each wrapped for readability.
w_input_text, w_close_text = wrap(input_text), wrap(close_text)
for label, wrapped in (('INPUT TEXT IS: ', w_input_text), ('CLOSE TEXT IS: ', w_close_text)):
    print(label, wrapped)

INPUT TEXT IS:  Once upon a time, there was a thoughtful girl named Sue. Sue loved to
help her mom around the house. One day, her mom asked her to wipe the
table after they ate their lunch. Sue was happy to help.  As Sue was
wiping the table, she saw a pretty candle on the window sill. The
candle was her mom's favorite. Sue wanted to do something nice for her
mom, so she said, "Mom, can I light the candle for you?" Her mom said,
"Yes, but be very careful."  Sue carefully lit the candle and put it
on the table. Her mom was so happy to see the pretty candle. They both
sat and watched the candle burn. Sue's mom said, "Thank you, Sue, for
being so thoughtful and careful." Sue felt proud that she could help
her mom.  The moral of the story is to always be thoughtful and
careful when helping others.
CLOSE TEXT IS:  Once upon a time, there was a kind farmer. He had a big cow. The cow
was sad. The farmer did not know why. One day, a little boy came to
the farm. He saw the sad cow. The boy knee

In [20]:
## don't run this!!
## look at an arbitrary story
# Stored as `story_text`; the old name `str` shadowed the builtin for the rest of the kernel.
story_text = wrap(tokenizer.decode(tokenized_dataset['input_ids'][3]))
story_text

'Once upon a time, there was a thoughtful girl named Sue. Sue loved to\nhelp her mom around the house. One day, her mom asked her to wipe the\ntable after they ate their lunch. Sue was happy to help. As Sue was\nwiping the table, she saw a pretty candle on the window sill. The\ncandle was her mom\'s favorite. Sue wanted to do something nice for her\nmom, so she said, "Mom, can I light the candle for you?" Her mom said,\n"Yes, but be very careful." Sue carefully lit the candle and put it on\nthe table. Her mom was so happy to see the pretty candle. They both\nsat and watched the candle burn. Sue\'s mom said, "Thank you, Sue, for\nbeing so thoughtful and careful." Sue felt proud that she could help\nher mom. The moral of the story is to always be thoughtful and careful\nwhen helping others.</s>'

In [40]:
import nbformat as nbf
from pathlib import Path

# Copy this notebook's file on disk under a descriptive name. Only what has been
# saved to disk is copied, not unsaved kernel state.
SOURCE_NOTEBOOK = Path('load_model_hidden_state.ipynb')
SNAPSHOT_NOTEBOOK = Path('./load_model_hidden_state/two-similar-stories-found-within-700-stories.ipynb')
# .ipynb files are UTF-8 JSON; don't depend on the platform's default encoding.
with SOURCE_NOTEBOOK.open('r', encoding='utf-8') as f:
    nb = nbf.read(f, as_version=4)
# Opening for writing fails if the target folder is missing, so create it first.
SNAPSHOT_NOTEBOOK.parent.mkdir(parents=True, exist_ok=True)
with SNAPSHOT_NOTEBOOK.open('w', encoding='utf-8') as f:
    nbf.write(nb, f)