<a href="https://colab.research.google.com/github/navjot12/improving_empathetic_nlg/blob/main/Persona_Embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Authenticate the Colab user before any gcloud / gsutil shell command is run
# from this notebook.
from google.colab import auth
auth.authenticate_user()

# https://cloud.google.com/resource-manager/docs/creating-managing-projects
# Set the active gcloud project; `{project_id}` is filled in by IPython from
# the Python variable defined on the line above.
project_id = 'improving-empathetic-nlg'
!gcloud config set project {project_id}

Updated property [core/project].


In [2]:
!pip install datasets
!pip install transformers
!pip install sentencepiece

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
from datetime import datetime
from datasets import load_dataset
import torch
from transformers import RobertaTokenizer, RobertaModel, LongformerTokenizer,  LongformerModel

In [4]:
def get_persona_sentences(split):
  """Collect the unique persona sentences of every response speaker.

  Args:
    split: Key used to pick one split out of the loaded "pec" / "all" dataset.

  Returns:
    Dict mapping each `response_speaker` value to the set of entries found in
    the `personas` field of that speaker's examples.
  """
  # Fetch the requested split from HuggingFace and echo its summary.
  records = load_dataset("pec", "all")[split]
  print(records)

  personas_by_speaker = {}
  for record in records:
    # setdefault registers the speaker on first sight, so a speaker whose
    # examples carry no personas still ends up with an (empty) set.
    known_sentences = personas_by_speaker.setdefault(record['response_speaker'], set())
    # A set ignores duplicates on its own; no membership test is needed.
    known_sentences.update(record['personas'])

  return personas_by_speaker


def get_tokenizer_and_model(model_name):
  """Load the pretrained tokenizer and model for the requested architecture.

  Args:
    model_name: Either 'Roberta' or 'Longformer'.

  Returns:
    A `(tokenizer, model)` tuple, both loaded from the same checkpoint name.

  Raises:
    Exception: If `model_name` is not one of the two supported names.
  """
  # Reject unsupported names up front so the happy path stays flat.
  if model_name not in ('Roberta', 'Longformer'):
    raise Exception('model_name argument not in [Roberta, Longformer]')

  if model_name == 'Roberta':
    tokenizer_cls, model_cls = RobertaTokenizer, RobertaModel
    checkpoint = "roberta-base"
  else:
    tokenizer_cls, model_cls = LongformerTokenizer, LongformerModel
    checkpoint = "allenai/longformer-base-4096"

  # Tokenizer first, then the model, both from the same checkpoint.
  tokenizer = tokenizer_cls.from_pretrained(checkpoint)
  model = model_cls.from_pretrained(checkpoint)
  return tokenizer, model


def model_forward_pass(model_name, model, tokenizer, cat_persona_sentences):
  """Run one gradient-free forward pass over a concatenated persona string.

  Args:
    model_name: 'Roberta' or 'Longformer'; selects how the inputs are built.
    model: Callable model matching `model_name`.
    tokenizer: Tokenizer matching `model_name`.
    cat_persona_sentences: Single string holding all persona sentences.

  Returns:
    Whatever `model` returns when called with `output_hidden_states=True`.

  Raises:
    ValueError: If `model_name` is neither 'Roberta' nor 'Longformer'.
  """
  if model_name not in ('Roberta', 'Longformer'):
    raise ValueError('model_name argument not in [Roberta, Longformer]')

  with torch.no_grad():
    if model_name == 'Roberta':
      inputs = tokenizer(cat_persona_sentences, truncation=True, return_tensors="pt")
      return model(**inputs, output_hidden_states=True)

    # Longformer: batch of size 1. Truncate exactly like the Roberta branch so
    # a very long persona cannot exceed the model's maximum input length.
    token_ids = tokenizer.encode(cat_persona_sentences, truncation=True)
    input_ids = torch.tensor(token_ids).unsqueeze(0)

    # Global attention mask: attend locally within a persona sentence and
    # globally among the sentence-delimiting special tokens. Built in one
    # vectorized comparison instead of a per-token Python loop.
    is_special = (input_ids == tokenizer.bos_token_id) | \
                 (input_ids == tokenizer.eos_token_id)
    global_attention_mask = is_special.long()

    return model(input_ids,
                 global_attention_mask=global_attention_mask,
                 output_hidden_states=True)
    

def get_persona_embeddings(split, model_name):
  """Build one embedding per speaker from that speaker's persona sentences.

  Each speaker's sentences are wrapped in `<s> ... </s>`, joined into a single
  string, and passed through the model. The last four hidden layers are then
  concatenated along the feature dimension and mean-pooled over the tokens.

  Args:
    split: Dataset split passed on to `get_persona_sentences`.
    model_name: Model selector passed on to `get_tokenizer_and_model` and
      `model_forward_pass`.

  Returns:
    Dict mapping each speaker to its pooled embedding tensor (the leading
    batch dimension, assumed to be of size one, is squeezed away).
  """
  # Load persona sentences
  speaker_persona_sentences = get_persona_sentences(split)
  print('>>> get_persona_sentences: persona sentences for %s speakers loaded.' % len(speaker_persona_sentences.keys()))

  # Load the pretrained tokenizer and model used to create persona embeddings
  tokenizer, model = get_tokenizer_and_model(model_name)
  print('>>> %s tokenizer and model loaded.' % model_name)

  # Put model in eval mode
  model.eval()

  # Dictionary from speaker name to persona embedding
  speaker_personas = {}

  # Speaker counts (5%, 15%, ..., 95% of the total) at which progress is logged
  num_speakers = len(speaker_persona_sentences)
  log_points = {int(ix * num_speakers / 100) for ix in range(5, 100, 10)}
  print('>>> Creating %s persona embeddings for %s data at %s' % \
        (num_speakers, split, datetime.now()))

  for count, speaker in enumerate(speaker_persona_sentences, start=1):
    if count in log_points:
      print('- Creating %s th persona embedding at %s' % (count, datetime.now()))

    # The sentences live in a set, whose iteration order for strings can
    # differ between interpreter runs. Sort them so the same speaker always
    # yields the same input string, and therefore the same embedding.
    ordered_sentences = sorted(speaker_persona_sentences[speaker])

    # Concatenate persona sentences and add special tokens in between.
    # NOTE: the separator contributes a second `</s>`, so neighbouring
    # sentences end up divided by `</s> </s> <s>`.
    cat_persona_sentences = ' </s> '.join('<s> ' + sentence + ' </s>'
                                          for sentence in ordered_sentences)

    # Get outputs object with all hidden states
    outputs = model_forward_pass(model_name, model, tokenizer, cat_persona_sentences)

    # Get last four layers, last layer first.
    last_four_layers = [outputs.hidden_states[i] for i in (-1, -2, -3, -4)]

    # Concatenate the layers over the last (feature) dimension
    cat_hidden_states = torch.cat(last_four_layers, dim=-1)

    # Mean over the token dimension, then drop only the batch dimension
    speaker_personas[speaker] = torch.mean(cat_hidden_states, dim=1).squeeze(0)

  return speaker_personas

def serialize_persona_embeddings(dir_path, split='train', model_name='Longformer'):
  """Compute persona embeddings for a split and save them with `torch.save`.

  Args:
    dir_path: Destination directory; a trailing '/' is appended when missing.
    split: Dataset split forwarded to `get_persona_embeddings`.
    model_name: Model selector forwarded to `get_persona_embeddings`; it is
      also part of the output file name.

  Returns:
    Path of the written file:
    `<dir_path>/<split>-<model_name>-persona-embeddings.pt`.
  """
  # The embeddings are computed before the path is touched.
  embeddings = get_persona_embeddings(split, model_name)

  # Normalise the directory so it always ends with exactly the separator
  # needed for plain string concatenation.
  directory = dir_path if dir_path.endswith('/') else dir_path + '/'
  file_path = ''.join((directory, split, '-', model_name, '-persona-embeddings.pt'))

  torch.save(embeddings, file_path)
  print('>>> File serialized at', file_path)

  return file_path

In [None]:
file_path = serialize_persona_embeddings('/tmp/', 'train', 'Longformer')
!gsutil cp {file_path} gs://{moel-data}/



  0%|          | 0/3 [00:00<?, ?it/s]

Dataset({
    features: ['personas', 'context', 'context_speakers', 'response', 'response_speaker'],
    num_rows: 281163
})
>>> get_persona_sentences: persona sentences for 148493 speakers loaded.


Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerModel: ['lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


>>> Longformer tokenizer and model loaded.
>>> Creating 148493 persona embeddings for train data at 2022-10-25 19:33:27.344847


In [None]:
file_path = serialize_persona_embeddings('/tmp/', 'validation', 'Longformer')
!gsutil cp {file_path} gs://{moel-data}/

In [None]:
file_path = serialize_persona_embeddings('/tmp/', 'test', 'Longformer')
!gsutil cp {file_path} gs://{moel-data}/