In [None]:
#Configure git user name and user e-mail
!git config global --user.name "root-goksenin"
!git config global --user.emil "goksenin.yuksel@outlook.com"
#Clone the repo into /content files
!git clone https://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@github.com/fbaratov/fact-group21.git
# Change the directory to lick-caption-bias
%cd fact-group21

In [None]:
!pip install jupyterlab
!pip install ipywidgets
!pip install bertviz

In [None]:
from transformers import BertTokenizer
from transformers import PYTORCH_PRETRAINED_BERT_CACHE
from transformers import BertConfig, WEIGHTS_NAME, CONFIG_NAME
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import BertModel
from transformers import BertPreTrainedModel
from bertviz import head_view
import torch
# import age variables from utils
from age_utils import (
  young_words,
  old_words,
  age_words
)
# import helper functions from age_utils
from age_utils import (
  gender_pickle_generator,
  race_pickle_generator,
  label_human_caption,
  label_human_annotations,
  match_labels,
  make_train_test_split,

)
from age_dataset import BERT_ANN_leak_data, BERT_MODEL_leak_data
from collections import namedtuple
import nltk
# Fetch the Punkt sentence-tokenizer data. Recent NLTK releases (>= 3.8.2) look up
# 'punkt_tab' instead; on older releases that id is unknown and download() just
# reports an error and returns False, so requesting both is safe.
# NOTE(review): presumably needed by the tokenization in age_dataset — confirm.
for nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(nltk_resource, quiet=True)

In [10]:
def load_bert_model(file_name, model, map_location=None, strict=True):
  '''
  Load the model from the weight file.
  Please note that the weights will also contain the classification head.
  So initialize the model accordingly.
  Arguments
  ---------
  file_name : str
      path to the weights (a state dict saved with ``torch.save``)
  model : torch.nn.Module
      BERT Model whose architecture matches the checkpoint
  map_location : str, torch.device or None, optional
      Forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load weights that
      were saved on a GPU on a machine without CUDA. ``None`` (default) keeps
      torch's default behaviour.
  strict : bool, optional
      Forwarded to ``load_state_dict``. ``True`` (default) requires the
      checkpoint keys to match the model's keys exactly.

  Returns
  -------
  torch.nn.Module
      The same ``model`` instance, now holding the loaded weights.
  '''
  # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
  state_dict = torch.load(file_name, map_location=map_location)
  model.load_state_dict(state_dict, strict=strict)
  return model

def visualize_attention(model, data, tokenizer, index=0):
  '''
  Visualize the per-head self-attention of a BERT model for one sample.
  Arguments
  ---------
  model : torch.nn.Module
      BERT Model created with ``output_attentions=True`` so that its outputs
      end with the per-layer attention tensors.
  data : torch.utils.data.Dataset
      Annotation data from authors paper; ``data[index]`` must yield
      ``(input_ids, attention_mask, token_type_ids, age_target, img_id)``.
  tokenizer : BertTokenizer
      Uncased bert tokenizer used to map the input ids back to tokens
  index : int, optional
      Which sample of ``data`` to visualize (default 0, the first one).

  Returns
  -------
  The object returned by bertviz ``head_view(..., html_action='return')``
  (an IPython HTML object; its ``.data`` attribute holds the page source).

  Notes
  -----
  The model is temporarily switched to eval mode (dropout off) for the
  forward pass and restored to its previous mode afterwards.
  '''
  # Only the model inputs are needed; the age label and image id are ignored.
  input_ids, attention_mask, token_type_ids, _, _ = data[index]
  # Move inputs to whatever device the model lives on (instead of assuming
  # CUDA) and add a batch dimension of size 1.
  device = next(model.parameters()).device
  input_ids = input_ids.to(device).unsqueeze(0)
  attention_mask = attention_mask.to(device).unsqueeze(0)
  token_type_ids = token_type_ids.to(device).unsqueeze(0)

  # Disable dropout so the attention maps are deterministic, without leaving
  # the caller's model in a different mode than we found it.
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      outputs = model(input_ids,
                      attention_mask=attention_mask,
                      token_type_ids=token_type_ids)
  finally:
    model.train(was_training)

  # With output_attentions=True the attentions are the last output element:
  # one tensor per layer, presumably (batch, heads, seq, seq) as bertviz expects.
  attention = outputs[-1]
  # Convert ids into readable tokens.
  tokens = list(tokenizer.convert_ids_to_tokens(input_ids[0].tolist()))
  # Keep everything up to and including the first separator token; what
  # follows is padding. Fall back to the full sequence if it was truncated.
  sep_token = tokenizer.sep_token or '[SEP]'
  seq_len = tokens.index(sep_token) + 1 if sep_token in tokens else len(tokens)
  # Slice each layer on its own so the number of heads comes from the tensor
  # itself (the previous buffer wrongly used the number of layers for it).
  trimmed_attention = tuple(att[:, :, :seq_len, :seq_len].cpu() for att in attention)
  return head_view(trimmed_attention, tokens[:seq_len], html_action='return')

In [None]:
# Emulate the command-line args expected by the authors' dataset code.
import random
import numpy as np

# Seed every RNG up front. make_train_test_split may shuffle (not visible
# here), so this keeps the split reproducible across Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

ARG = namedtuple('args', ['batch_size', 'workers', 'test_ratio', 'task', 'align_vocab', 'max_seq_length', 'cap_model'])
args = ARG(batch_size=64, workers=1, test_ratio=0.1, task='captioning',
           align_vocab=True, max_seq_length=64, cap_model='nic')

# Use the GPU when present, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Optional path to fine-tuned weights; None keeps the pretrained weights.
# NOTE(review): if the checkpoint also contains a classification head, a plain
# BertModel will not match it; load it into the matching model class instead.
WEIGHTS_FILE = None

# Load the bert model and tokenizer
lang_model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True).to(DEVICE)
if WEIGHTS_FILE is not None:
    load_bert_model(WEIGHTS_FILE, lang_model)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Define the human annotation entries using age_utils.
# NOTE(review): the age labels are built from the *gender* human pickle; confirm intended.
age_val_obj_cap_entries = label_human_annotations(gender_pickle_generator('human'), young_words, old_words)  # Human captions

d_train, d_test = make_train_test_split(args, age_val_obj_cap_entries)

# Define the train/test datasets with the authors' code (first caption only).
trainANNCAPobject = BERT_ANN_leak_data(d_train, d_test, args, age_val_obj_cap_entries, age_words, tokenizer,
                                       args.max_seq_length, split='train', caption_ind=0)
testANNCAPobject = BERT_ANN_leak_data(d_train, d_test, args, age_val_obj_cap_entries, age_words, tokenizer,
                                      args.max_seq_length, split='test', caption_ind=0)


In [14]:
# Render the head view for the first training sample and save it as a standalone page.
html = visualize_attention(lang_model, trainANNCAPobject, tokenizer)
# Write UTF-8 explicitly: the page may contain non-ASCII tokens, and the
# platform default encoding is not guaranteed to handle them.
with open("head_view.html", 'w', encoding='utf-8') as html_file:
    html_file.write(html.data)