In [1]:
import torch
from transformers import BertForSequenceClassification, RobertaForSequenceClassification, \
AlbertForSequenceClassification, XLNetForSequenceClassification, CamembertForSequenceClassification, \
FlaubertForSequenceClassification, AdamW, get_linear_schedule_with_warmup, \
BertTokenizer, RobertaTokenizer, AlbertTokenizer, XLNetTokenizer, CamembertTokenizer, FlaubertTokenizer
from sklearn.metrics import classification_report
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import torch.nn.functional as F
from torch.autograd import Variable
import random
import numpy as np
import time
import seaborn as sns
import matplotlib.pyplot as plt
import os

In [2]:
# PARAMETERS — everything a reader may want to tune for a run.
transformer_model = 'bert-base-cased'  # HF model id; its name selects the tokenizer/model classes
num_epochs = 4
batch_size = 32
use_segment_ids = True  # mark the verb's tokens with token_type_id 1 during training
# Tag used in the output plot paths; derived from use_segment_ids so the two can never disagree.
verb_segment_ids = 'yes' if use_segment_ids else 'no'

In [4]:
def find_real_verb_idx(sentence_list, encoded_sentence, tokenizer):
    """Locate the verb's sub-word tokens inside an encoded sentence.

    Args:
        sentence_list: one annotated row ``[sentence, verb, verb_idx]``; only
            the verb string (index 1) is used here.
        encoded_sentence: token ids of the full sentence, including the
            tokenizer's special tokens.
        tokenizer: object exposing ``encode`` and ``convert_ids_to_tokens``.

    Returns:
        A list with the positions in ``encoded_sentence`` of the verb's
        sub-word tokens (one contiguous run), or ``False`` when the verb
        cannot be found (e.g. it was truncated away).
    """
    encoded_verb = tokenizer.encode(sentence_list[1])
    verb_tokens = tokenizer.convert_ids_to_tokens(encoded_verb)
    sent_tokens = tokenizer.convert_ids_to_tokens(encoded_sentence)

    # Drop special tokens by their content, not by their position: BERT-style
    # tokenizers put [CLS] first, but XLNet appends <sep> <cls> at the end, so
    # ``encoded_verb[1]`` is not always the verb.
    verb_segment = [seg for seg in verb_tokens
                    if not any(seg.startswith(x) for x in ['[', '<'])]
    if not verb_segment:
        return False

    # Match the whole run of sub-words at once. Looking each piece up on its
    # own could pick an earlier, unrelated occurrence (e.g. another "##s").
    span = len(verb_segment)
    for start in range(len(sent_tokens) - span + 1):
        if sent_tokens[start:start + span] == verb_segment:
            return list(range(start, start + span))
    return False

def tokenize_and_pad(transformer_model, sentences):
    """ We are using .encode_plus. This does not make specialized attn masks 
        like in our selectional preferences experiment. Revert to .encode if
        necessary."""
    
    input_ids = []
    segment_ids = [] # token type ids
    attention_masks = []
    
    if transformer_model.split("-")[0] == 'bert':
        tok = BertTokenizer.from_pretrained(transformer_model)
    elif transformer_model.split("-")[0] == 'roberta':
        tok = RobertaTokenizer.from_pretrained(transformer_model)
    elif transformer_model.split("-")[0] == 'albert':
        tok = AlbertTokenizer.from_pretrained(transformer_model)
    elif transformer_model.split("-")[0] == 'xlnet':
        tok = XLNetTokenizer.from_pretrained(transformer_model)
    elif 'camembert' in transformer_model:
        tok = CamembertTokenizer.from_pretrained(transformer_model)
    elif 'flaubert' in transformer_model:
        tok = FlaubertTokenizer.from_pretrained(transformer_model)

    for sent in sentences:
        sentence = sent[0]

        # encode_plus is a prebuilt function that will make input_ids, 
        # add padding/truncate, add special tokens, + make attention masks 
        encoded_dict = tok.encode_plus(
                        sentence,                      
                        add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                        max_length = 128,      # Pad & truncate all sentences.
                        padding = 'max_length',
                        truncation = True,
                        return_attention_mask = True, # Construct attn. masks.
                        # return_tensors = 'pt',     # Return pytorch tensors.
                   )

        # Add the encoded sentence to the list.
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])
        # Add segment ids, add 1 for verb idx
        segment_id = [0] * 128        
        verb_idx = find_real_verb_idx(sent, encoded_dict['input_ids'], tok)       
        if verb_idx: # if False, the verb is not in the first 128 tokens
            for idx in verb_idx:
                segment_id[idx] = 1            
        segment_ids.append(segment_id)    

    return input_ids, attention_masks, segment_ids


def decode_result(transformer_model, encoded_sequence):

    if transformer_model.split("-")[0] == 'bert':
        tok = BertTokenizer.from_pretrained(transformer_model)
    elif transformer_model.split("-")[0] == 'roberta':
        tok = RobertaTokenizer.from_pretrained(transformer_model)
    elif transformer_model.split("-")[0] == 'albert':
        tok = AlbertTokenizer.from_pretrained(transformer_model)
    elif transformer_model.split("-")[0] == 'xlnet':
        tok = XLNetTokenizer.from_pretrained(transformer_model)
    elif 'camembert' in transformer_model:
        tok = CamembertTokenizer.from_pretrained(transformer_model)
    elif 'flaubert' in transformer_model:
        tok = FlaubertTokenizer.from_pretrained(transformer_model)
    
    # decode + remove special tokens
    tokens_to_remove = ['[PAD]', '<pad>', '<s>', '</s>']
    decoded_sequence = [w.replace('Ġ', '').replace('▁', '').replace('</w>', '')
                        for w in list(tok.convert_ids_to_tokens(encoded_sequence))
                        if not w.strip() in tokens_to_remove]
    
    return decoded_sequence

def read_sents(path, marker):
    """Read the train/val/test .tsv files with the annotated sentences.

    File format (tab separated): sent_id, sentence, verb, verb_idx, label.
    Columns are taken from the right, so only the last four are required.
    Blank lines (e.g. a trailing newline at EOF) are skipped.

    Args:
        path: directory containing the files.
        marker: file prefix, e.g. 'telicity' -> '<path>/telicity_train.tsv'.

    Returns:
        ``(train_sentences, train_labels, val_sentences, val_labels,
        test_sentences, test_labels)``. Each sentence entry is
        ``[sentence, verb, verb_idx]`` with ``verb_idx`` an int; each label
        is an int.
    """

    def open_file(file):
        sentences = []
        labels = []
        with open(file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:  # previously crashed with IndexError on blank lines
                    continue
                l = line.split('\t')
                sentences.append([l[-4], l[-3], int(l[-2])])
                labels.append(int(l[-1]))
        return sentences, labels

    train_sentences, train_labels = open_file(path + '/' + marker + '_train.tsv')
    val_sentences, val_labels = open_file(path + '/' + marker + '_val.tsv')
    test_sentences, test_labels = open_file(path + '/' + marker + '_test.tsv')

    return train_sentences, train_labels, val_sentences, val_labels, test_sentences, test_labels

In [5]:
width = 3          # horizontal distance between the left and right token columns of one head
example_sep = 3    # multiplier on width giving the horizontal spacing between heads
word_height = 1    # vertical distance between consecutive tokens
pad = 0.1          # horizontal gap between a token label and its attention lines


def plot_attn(sentence, attentions, layer, heads, ax=None):
    """Plot attention maps for the given example and attention heads.

    Heads are drawn side by side. For each head the tokens are listed twice
    (left and right column), and a line from token i (left) to token j
    (right) is drawn with opacity equal to the row-normalised weight
    ``attentions[layer][head][i, j]``. Row/column 0 is skipped (the loops
    start at 1), presumably to leave out the leading [CLS] token.

    Args:
        sentence: list of token strings, one per attention row/column.
        attentions: indexable as ``attentions[layer][head]`` -> square
            array-like of size ``len(sentence)``.
        layer: index of the layer to plot.
        heads: iterable of head indices to plot.
        ax: matplotlib Axes to draw on; defaults to the current axes.
    """
    if ax is None:
        ax = plt.gca()

    n_words = len(sentence)
    for ei, head in enumerate(heads):
        yoffset = 1
        xoffset = ei * width * example_sep

        # np.array copies, so normalising never mutates the caller's data.
        attn = np.array(attentions[layer][head])
        attn /= attn.sum(axis=-1, keepdims=True)

        for position, word in enumerate(sentence):
            ax.text(xoffset + 0, yoffset - position * word_height, word,
                    ha="right", va="center")
            ax.text(xoffset + width, yoffset - position * word_height, word,
                    ha="left", va="center")
        for i in range(1, n_words):
            for j in range(1, n_words):
                ax.plot([xoffset + pad, xoffset + width - pad],
                        [yoffset - word_height * i, yoffset - word_height * j],
                        color="blue", linewidth=1, alpha=attn[i, j].item())

In [None]:
print('Uses verb segment ids: ' + str(use_segment_ids))
print('Model: ' + transformer_model)

device = torch.device("cpu")
epochs = num_epochs

# read friedrich sentences, choose labels of telicity/duration
train_sentences, train_labels, val_sentences, val_labels, \
    test_sentences, test_labels = read_sents("data/friedrich_captions_data", 'telicity')

# make input ids, attention masks, segment ids, depending on the model we will use
train_inputs, train_masks, train_segments = tokenize_and_pad(transformer_model, train_sentences)
print('Loaded sentences and converted.')

train_inputs = torch.tensor(train_inputs)
train_masks = torch.tensor(train_masks)
train_labels = torch.tensor(train_labels)
assert train_inputs.shape == train_masks.shape
assert len(train_labels) == len(train_inputs)

# Batch layout: (input_ids, attention_mask, labels[, segment_ids]) — the
# training loop reads batch[3] only when use_segment_ids is set.
train_tensors = [train_inputs, train_masks, train_labels]
if use_segment_ids:
    train_segments = torch.tensor(train_segments)
    train_tensors.append(train_segments)
train_data = TensorDataset(*train_tensors)

train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

In [None]:
# train

# Seed every RNG *before* building the model: from_pretrained randomly
# initialises the classification head, so seeding afterwards left the head's
# initial weights (and therefore the whole run) non-reproducible.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)

# Same dispatch order as the tokenizer loader: name prefix first, then
# substring checks for camembert/flaubert model ids.
prefix_to_model_class = {
    'bert': BertForSequenceClassification,
    'roberta': RobertaForSequenceClassification,
    'albert': AlbertForSequenceClassification,
    'xlnet': XLNetForSequenceClassification,
}
model_prefix = transformer_model.split("-")[0]
if model_prefix in prefix_to_model_class:
    model_class = prefix_to_model_class[model_prefix]
elif 'flaubert' in transformer_model:
    model_class = FlaubertForSequenceClassification
elif 'camembert' in transformer_model:
    model_class = CamembertForSequenceClassification
else:
    raise ValueError('Unsupported transformer model: ' + transformer_model)

model = model_class.from_pretrained(
    transformer_model,
    num_labels=2,
    output_attentions=True,      # attentions are plotted in the test cells
    output_hidden_states=False,
)
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)

loss_values = []

for epoch_i in range(epochs):
    print('\n\t======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('\nTraining...')
    total_loss = 0
    model.train()

    for batch in train_dataloader:
        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch[:3])
        # Verb segment ids are the 4th tensor only when use_segment_ids is set.
        b_segments = batch[3].to(device) if use_segment_ids else None

        model.zero_grad()
        outputs = model(b_input_ids,
                        token_type_ids=b_segments,
                        attention_mask=b_input_mask,
                        labels=b_labels)

        loss = outputs[0]  # loss comes first when labels are passed
        total_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

    avg_train_loss = total_loss / len(train_dataloader)
    loss_values.append(avg_train_loss)
    print('\n\tAverage training loss: {0:.2f}'.format(avg_train_loss))

In [None]:
# open unseen test set
# Run the fine-tuned model on the qualitative "unseen" test set, then save a
# prediction log plus per-layer and verb-focused attention plots.

test_file = 'data/unseen_tests/telicity_unseen.tsv'  # qualitative test 1
out_dir = 'plots/unseen/'

test_sentences = []
test_labels = []
with open(test_file, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        l = line.strip().split('\t')
        test_sentences.append([l[-4], l[-3], int(l[-2])])
        test_labels.append(int(l[-1]))

test_inputs, test_masks, test_segments = tokenize_and_pad(transformer_model, test_sentences)
test_inputs = torch.tensor(test_inputs)
test_labels = torch.tensor(test_labels)
test_masks = torch.tensor(test_masks)
test_segments = torch.tensor(test_segments)

# Collect predictions and per-sentence attentions (not batched per layer).
model.eval()
sent_attentions = []
sentences = []
predicted_labels = []
prob_prediction = []

for inputs, mask, segments in zip(test_inputs, test_masks, test_segments):
    with torch.no_grad():
        # Use the same padding mask and (if trained with them) verb segment
        # ids as training. Passing None attended to [PAD] tokens and hid the
        # verb marker.
        outputs = model(inputs.unsqueeze(0).to(device),
                        token_type_ids=segments.unsqueeze(0).to(device) if use_segment_ids else None,
                        attention_mask=mask.unsqueeze(0).to(device))

    logits = outputs[0].detach().cpu()
    attentions = outputs[1]
    probs = F.softmax(logits, dim=-1)  # class probabilities (not log-probs)
    predicted_labels += logits.argmax(dim=1).tolist()
    prob_prediction += probs.tolist()

    sentence = decode_result(transformer_model, inputs.tolist())
    sentences.append(sentence)
    len_sequence = len(sentence)
    # (layer, head, size, size), cropped to the decoded (non-pad) tokens
    sent_attentions.append(np.asarray([
        layer_att.squeeze(0).cpu().numpy()[:, :len_sequence, :len_sequence]
        for layer_att in attentions
    ]))

os.makedirs(out_dir, exist_ok=True)
with open(out_dir + 'sent_indices.txt', 'w') as logz:
    logz.write('ID\tSentence\tLabel\tPred.\tProbabilities\n')

    for sent_idx, (sentence, sent_att) in enumerate(zip(sentences, sent_attentions)):
        print(sent_idx)
        save_path = out_dir + str(sent_idx) + '/' + transformer_model + '_' + verb_segment_ids + '/'
        os.makedirs(save_path, exist_ok=True)
        logz.write('\t'.join([str(sent_idx), ' '.join(sentence),
                              str(test_labels[sent_idx].item()),
                              str(predicted_labels[sent_idx]),
                              str(prob_prediction[sent_idx][0]),
                              str(prob_prediction[sent_idx][1])]))
        logz.write('\n')

        n_layers, n_heads = sent_att.shape[:2]

        # One figure per layer, all heads side by side.
        for layer in range(n_layers):
            fig = plt.figure(figsize=(50, 4))
            plt.axis("off")
            plot_attn(sentence, sent_att, layer, list(range(n_heads)))
            fig.savefig(save_path + str(layer + 1) + '.png')
            plt.close(fig)

        # Attention from the verb's first token. Segment ids are indexed by
        # input position, which already counts [CLS] at 0, so there is no +1 offset.
        segment_row = test_segments[sent_idx].tolist()
        if 1 not in segment_row or segment_row.index(1) >= sent_att.shape[2]:
            print('Verb not found for sentence', sent_idx, '- skipping verb attention plot.')
            continue
        verb_idx = segment_row.index(1)
        # (layer, head, from, to) -> (from, head, layer, to), then pick the verb row
        verb_attn = sent_att.transpose(2, 1, 0, 3)[verb_idx]

        cols = 2
        rows = -(-n_heads // cols)  # ceil division
        fig, axes = plt.subplots(rows, cols, figsize=(14, 30))
        axes = axes.flat
        for i, att in enumerate(verb_attn):
            sns.heatmap(att, vmin=0, vmax=1, ax=axes[i], xticklabels=sentence)
            axes[i].set_title(f'head - {i+1} ')
            axes[i].set_ylabel('layers')
        fig.savefig(save_path + 'verb_attn.png')
        plt.close(fig)

In [None]:
# open min test set
# Run the fine-tuned model on the minimal-pairs test set, then save a
# prediction log plus per-layer and verb-focused attention plots.

test_file = 'data/unseen_tests/telicity_min_pairs.tsv'  # minimal pairs test set
out_dir = 'plots/min_pairs/'

test_sentences = []
test_labels = []
with open(test_file, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        l = line.strip().split('\t')
        test_sentences.append([l[-4], l[-3], int(l[-2])])
        test_labels.append(int(l[-1]))

test_inputs, test_masks, test_segments = tokenize_and_pad(transformer_model, test_sentences)
test_inputs = torch.tensor(test_inputs)
test_labels = torch.tensor(test_labels)
test_masks = torch.tensor(test_masks)
test_segments = torch.tensor(test_segments)

# Collect predictions and per-sentence attentions (not batched per layer).
model.eval()
sent_attentions = []
sentences = []
predicted_labels = []
prob_prediction = []

for inputs, mask, segments in zip(test_inputs, test_masks, test_segments):
    with torch.no_grad():
        # Use the same padding mask and (if trained with them) verb segment
        # ids as training. Passing None attended to [PAD] tokens and hid the
        # verb marker.
        outputs = model(inputs.unsqueeze(0).to(device),
                        token_type_ids=segments.unsqueeze(0).to(device) if use_segment_ids else None,
                        attention_mask=mask.unsqueeze(0).to(device))

    logits = outputs[0].detach().cpu()
    attentions = outputs[1]
    probs = F.softmax(logits, dim=-1)  # class probabilities (not log-probs)
    predicted_labels += logits.argmax(dim=1).tolist()
    prob_prediction += probs.tolist()

    sentence = decode_result(transformer_model, inputs.tolist())
    sentences.append(sentence)
    len_sequence = len(sentence)
    # (layer, head, size, size), cropped to the decoded (non-pad) tokens
    sent_attentions.append(np.asarray([
        layer_att.squeeze(0).cpu().numpy()[:, :len_sequence, :len_sequence]
        for layer_att in attentions
    ]))

os.makedirs(out_dir, exist_ok=True)
with open(out_dir + 'sent_indices.txt', 'w') as logz:
    logz.write('ID\tSentence\tLabel\tPred.\tProbabilities\n')

    for sent_idx, (sentence, sent_att) in enumerate(zip(sentences, sent_attentions)):
        print(sent_idx)
        save_path = out_dir + str(sent_idx) + '/' + transformer_model + '_' + verb_segment_ids + '/'
        os.makedirs(save_path, exist_ok=True)
        logz.write('\t'.join([str(sent_idx), ' '.join(sentence),
                              str(test_labels[sent_idx].item()),
                              str(predicted_labels[sent_idx]),
                              str(prob_prediction[sent_idx][0]),
                              str(prob_prediction[sent_idx][1])]))
        logz.write('\n')

        n_layers, n_heads = sent_att.shape[:2]

        # One figure per layer, all heads side by side.
        for layer in range(n_layers):
            fig = plt.figure(figsize=(50, 4))
            plt.axis("off")
            plot_attn(sentence, sent_att, layer, list(range(n_heads)))
            fig.savefig(save_path + str(layer + 1) + '.png')
            plt.close(fig)

        # Attention from the verb's first token. Segment ids are indexed by
        # input position, which already counts [CLS] at 0, so there is no +1 offset.
        segment_row = test_segments[sent_idx].tolist()
        if 1 not in segment_row or segment_row.index(1) >= sent_att.shape[2]:
            print('Verb not found for sentence', sent_idx, '- skipping verb attention plot.')
            continue
        verb_idx = segment_row.index(1)
        # (layer, head, from, to) -> (from, head, layer, to), then pick the verb row
        verb_attn = sent_att.transpose(2, 1, 0, 3)[verb_idx]

        cols = 2
        rows = -(-n_heads // cols)  # ceil division
        fig, axes = plt.subplots(rows, cols, figsize=(14, 30))
        axes = axes.flat
        for i, att in enumerate(verb_attn):
            sns.heatmap(att, vmin=0, vmax=1, ax=axes[i], xticklabels=sentence)
            axes[i].set_title(f'head - {i+1} ')
            axes[i].set_ylabel('layers')
        fig.savefig(save_path + 'verb_attn.png')
        plt.close(fig)

In [None]:
# open more test set
# Run the fine-tuned model on the additional test set, then save a prediction
# log plus per-layer and verb-focused attention plots. The log now goes to
# plots/more_tests/ instead of overwriting plots/min_pairs/sent_indices.txt.

test_file = 'data/unseen_tests/telicity_more_tests.tsv'
out_dir = 'plots/more_tests/'

test_sentences = []
test_labels = []
with open(test_file, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        l = line.strip().split('\t')
        test_sentences.append([l[-4], l[-3], int(l[-2])])
        test_labels.append(int(l[-1]))

test_inputs, test_masks, test_segments = tokenize_and_pad(transformer_model, test_sentences)
test_inputs = torch.tensor(test_inputs)
test_labels = torch.tensor(test_labels)
test_masks = torch.tensor(test_masks)
test_segments = torch.tensor(test_segments)

# Collect predictions and per-sentence attentions (not batched per layer).
model.eval()
sent_attentions = []
sentences = []
predicted_labels = []
prob_prediction = []

for inputs, mask, segments in zip(test_inputs, test_masks, test_segments):
    with torch.no_grad():
        # Use the same padding mask and (if trained with them) verb segment
        # ids as training. Passing None attended to [PAD] tokens and hid the
        # verb marker.
        outputs = model(inputs.unsqueeze(0).to(device),
                        token_type_ids=segments.unsqueeze(0).to(device) if use_segment_ids else None,
                        attention_mask=mask.unsqueeze(0).to(device))

    logits = outputs[0].detach().cpu()
    attentions = outputs[1]
    probs = F.softmax(logits, dim=-1)  # class probabilities (not log-probs)
    predicted_labels += logits.argmax(dim=1).tolist()
    prob_prediction += probs.tolist()

    sentence = decode_result(transformer_model, inputs.tolist())
    sentences.append(sentence)
    len_sequence = len(sentence)
    # (layer, head, size, size), cropped to the decoded (non-pad) tokens
    sent_attentions.append(np.asarray([
        layer_att.squeeze(0).cpu().numpy()[:, :len_sequence, :len_sequence]
        for layer_att in attentions
    ]))

os.makedirs(out_dir, exist_ok=True)
with open(out_dir + 'sent_indices.txt', 'w') as logz:
    logz.write('ID\tSentence\tLabel\tPred.\tProbabilities\n')

    for sent_idx, (sentence, sent_att) in enumerate(zip(sentences, sent_attentions)):
        print(sent_idx)
        save_path = out_dir + str(sent_idx) + '/' + transformer_model + '_' + verb_segment_ids + '/'
        os.makedirs(save_path, exist_ok=True)
        logz.write('\t'.join([str(sent_idx), ' '.join(sentence),
                              str(test_labels[sent_idx].item()),
                              str(predicted_labels[sent_idx]),
                              str(prob_prediction[sent_idx][0]),
                              str(prob_prediction[sent_idx][1])]))
        logz.write('\n')

        n_layers, n_heads = sent_att.shape[:2]

        # One figure per layer, all heads side by side.
        for layer in range(n_layers):
            fig = plt.figure(figsize=(50, 4))
            plt.axis("off")
            plot_attn(sentence, sent_att, layer, list(range(n_heads)))
            fig.savefig(save_path + str(layer + 1) + '.png')
            plt.close(fig)

        # Attention from the verb's first token. Segment ids are indexed by
        # input position, which already counts [CLS] at 0, so there is no +1 offset.
        segment_row = test_segments[sent_idx].tolist()
        if 1 not in segment_row or segment_row.index(1) >= sent_att.shape[2]:
            print('Verb not found for sentence', sent_idx, '- skipping verb attention plot.')
            continue
        verb_idx = segment_row.index(1)
        # (layer, head, from, to) -> (from, head, layer, to), then pick the verb row
        verb_attn = sent_att.transpose(2, 1, 0, 3)[verb_idx]

        cols = 2
        rows = -(-n_heads // cols)  # ceil division
        fig, axes = plt.subplots(rows, cols, figsize=(14, 30))
        axes = axes.flat
        for i, att in enumerate(verb_attn):
            sns.heatmap(att, vmin=0, vmax=1, ax=axes[i], xticklabels=sentence)
            axes[i].set_title(f'head - {i+1} ')
            axes[i].set_ylabel('layers')
        fig.savefig(save_path + 'verb_attn.png')
        plt.close(fig)