In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np

from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from torch import nn as nn

import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from datasets import load_dataset

import tqdm
from torcheval.metrics.functional import multiclass_f1_score

  _warn(("h5py is running against HDF5 {0} when it was built against {1}, "
  warn(


In [2]:
# Seed every random number generator this notebook draws from.
# Seeding NumPy alone is not enough: DataLoader(shuffle=True) and torch's
# randomized routines use torch's global generator, so it is seeded as well.
SEED = 22
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
def retrieve_data(language, data_dir):
    """Load the train/dev/test TSV splits of one language as a dataset dict.

    Args:
        language: Name of the sub-directory of ``data_dir`` holding the
            language's ``train.tsv``, ``dev.tsv`` and ``test.tsv`` files.
        data_dir: Directory containing one sub-directory per language. A
            trailing path separator is optional.

    Returns:
        Whatever ``load_dataset("csv", ...)`` returns for the three files,
        keyed by ``"train"``, ``"dev"`` and ``"test"``.
    """
    # os.path.join inserts the separator only when it is missing, so both
    # "some/dir" and "some/dir/" resolve to the same files.
    language_dir = os.path.join(data_dir, language)
    data_files = {
        split: os.path.join(language_dir, split + ".tsv")
        for split in ("train", "dev", "test")
    }
    return load_dataset("csv", data_files=data_files, sep="\t")
    

In [4]:
class EmbeddingExtractor(nn.Module):
    """Wrap a sequence-classification checkpoint and expose pooled hidden states.

    The wrapped model is loaded with ``output_hidden_states=True`` and three
    labels; ``forward`` mean-pools one hidden-state layer over the tokens
    marked as real by the attention mask.

    Args:
        model_string: Checkpoint identifier handed to ``from_pretrained`` for
            both the tokenizer and the model.
        embedding_index: Default position in the model's ``hidden_states``
            output to pool when ``forward`` is called without an explicit
            ``index``. Defaults to ``-1``.
    """

    def __init__(self, model_string, embedding_index=-1):
        super().__init__()
        self._tokenizer = AutoTokenizer.from_pretrained(model_string)
        self._model = AutoModelForSequenceClassification.from_pretrained(model_string, num_labels= 3,  output_hidden_states=True)
        # Previously never assigned, which made ``embedding_idex`` raise
        # AttributeError on every access.
        self._embedding_index = embedding_index

    @property
    def tokenizer(self):
        """Tokenizer loaded for the checkpoint."""
        return self._tokenizer

    @property
    def model(self):
        """The wrapped sequence-classification model."""
        return self._model

    @property
    def embedding_index(self):
        """Default hidden-state index used by ``forward``."""
        return self._embedding_index

    @property
    def embedding_idex(self):
        """Misspelled alias of ``embedding_index``, kept for existing callers."""
        return self._embedding_index

    def forward(self, inputs, index=None, with_prediction =False):
        """Run the model and mean-pool one hidden-state layer.

        Args:
            inputs: Mapping of keyword arguments for the wrapped model. Must
                contain ``'attention_mask'``; ``'labels'`` is optional.
            index: Position in ``output['hidden_states']`` to pool. ``None``
                selects the default given at construction time.
            with_prediction: When true, also return the arg-max class per row.

        Returns:
            ``(pooled, labels)`` or, with ``with_prediction``,
            ``(pooled, labels, predictions)``. ``labels`` is ``inputs['labels']``
            or ``None`` when the batch carries no labels.
        """
        if index is None:
            index = self._embedding_index
        output = self._model(**inputs)
        embedding = output['hidden_states'][index]
        embedding_pooled = self.mean_pooling(embedding, inputs['attention_mask'])
        labels = inputs['labels'] if 'labels' in inputs else None
        if with_prediction:
            # Softmax is monotonic, so the arg-max of the raw logits is the
            # same class without the extra normalisation pass.
            predictions = torch.argmax(output['logits'], dim=1)
            return embedding_pooled, labels, predictions
        return embedding_pooled, labels

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions kept by the mask.

        Args:
            model_output: Token embeddings of one layer; the mask is expanded
                to this tensor's size and the sum runs over dimension 1.
            attention_mask: Mask whose non-zero entries mark tokens to include.

        Returns:
            The masked mean over dimension 1. Rows whose mask is all zero
            divide by ``1e-9`` instead of zero.
        """
        token_embeddings = model_output
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp the token count so fully masked rows do not divide by zero.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts


In [5]:
from ipywidgets import interact

In [7]:
def _project_to_3d(embeddings):
    """Return the coordinates of each row on its first three principal axes.

    The rows are mean-centred and decomposed with an exact thin SVD, so the
    result is deterministic. Missing components (fewer than three non-trivial
    axes) are left as zeros so the result always has three columns.
    """
    centered = embeddings - embeddings.mean(dim=0, keepdim=True)
    left, singular, _ = torch.linalg.svd(centered, full_matrices=False)
    rank = min(3, singular.shape[0])
    coords = torch.zeros(embeddings.shape[0], 3, dtype=centered.dtype)
    # Principal coordinates are U * S (equivalently centered @ V).
    coords[:, :rank] = left[:, :rank] * singular[:rank]
    return coords.numpy()


@interact(model_string={"AfriBerta":"Davlan/naija-twitter-sentiment-afriberta-large"
                        , "Bert-base":"bert-base-multilingual-cased"}
          , lang ={"Ibo": 'ibo', "Yoruba": 'yor', "Pidgen":'pcm', "Hausa": 'hau'}
          , data_dir='naija_sent_data/annotated_tweets/', batch = range(10, 101, 10), index=[-1,0]
          , max_retrieved = range(50, 501, 50))
def visualize_embeddings(model_string, lang, data_dir, batch, index, max_retrieved):
    """Plot a 3-D PCA of pooled sentence embeddings, coloured by sentiment.

    A shuffled sample of at most ``max_retrieved`` training tweets is embedded
    with ``EmbeddingExtractor`` and projected onto its first three principal
    axes. For every label one trace shows the tweets whose true label it is,
    and a second trace (``x`` markers, coloured by the predicted label) shows
    the subset the model misclassified.

    Args:
        model_string: Checkpoint identifier passed to ``EmbeddingExtractor``.
        lang: Language sub-directory passed to ``retrieve_data``.
        data_dir: Root data directory passed to ``retrieve_data``.
        batch: Batch size of the training ``DataLoader``.
        index: Hidden-state index whose embeddings are pooled.
        max_retrieved: Maximum number of tweets to embed and plot.

    Raises:
        ValueError: If the training split yields no batches.
    """
    label2id = {"positive": 0, "neutral": 1, "negative": 2}
    # One colour per label id: positive, neutral, negative.
    colours = ['#00FF00', '#000000', '#FF0000']

    print("Analyzes for  "+lang)
    # NOTE(review): a checkpoint without a fine-tuned 3-way head (presumably
    # "bert-base-multilingual-cased") would predict from an untrained
    # classifier, so its "False_*" traces may be meaningless — confirm.
    extractor = EmbeddingExtractor(model_string)
    extractor.eval()
    tokenizer = extractor.tokenizer
    data_collator = DataCollatorWithPadding(tokenizer)

    def preprocess_function(examples):
        # No padding here: the collator pads each batch to its own longest row.
        tokenized_batch = tokenizer(examples['tweet'], truncation=True)
        tokenized_batch["label"] = [label2id[label] for label in examples["label"]]
        return tokenized_batch

    def collate_fn(rows):
        # Keep only the fields the model consumes; the raw text columns
        # cannot be turned into tensors by the collator.
        features = [{"input_ids": row['input_ids'], "attention_mask": row['attention_mask'], "label": row['label']} for row in rows]
        return data_collator(features)

    # Only the training split is visualised, so only it is tokenized.
    language_data = retrieve_data(lang, data_dir)
    tokenized_train = language_data['train'].map(preprocess_function, batched=True)
    train_loader = DataLoader(tokenized_train, batch_size=batch, shuffle=True, collate_fn=collate_fn)

    # Collect per-batch results and concatenate once after the loop.
    embedding_chunks, label_chunks, prediction_chunks = [], [], []
    sentences = []
    retrieved = 0
    with torch.no_grad():
        for data in train_loader:
            sentences.extend(
                tokenizer.decode(encoded, skip_special_tokens=True)
                for encoded in data['input_ids']
            )
            embedding, labels, predictions = extractor(data, index, with_prediction= True)
            embedding_chunks.append(embedding)
            label_chunks.append(labels)
            prediction_chunks.append(predictions)
            retrieved += embedding.size(0)
            if retrieved >= max_retrieved:
                break

    if not embedding_chunks:
        raise ValueError("No training examples were retrieved for language " + repr(lang))

    # The last batch may overshoot the limit; trim everything consistently.
    retrieved_embeddings = torch.cat(embedding_chunks, dim=0)[:max_retrieved]
    true_labels = torch.cat(label_chunks, dim=0)[:max_retrieved].numpy()
    predicted_labels = torch.cat(prediction_chunks, dim=0)[:max_retrieved].numpy()
    sentences = np.array(sentences[:max_retrieved])

    proj = _project_to_3d(retrieved_embeddings)

    fig = go.Figure()
    for category, value in label2id.items():
        is_category = true_labels == value
        is_wrong = is_category & (true_labels != predicted_labels)

        # Every tweet whose ground-truth label is this category.
        fig.add_trace(go.Scatter3d(
            x=proj[is_category, 0], y=proj[is_category, 1], z=proj[is_category, 2],
            mode='markers',
            name="True_"+category,
            text=sentences[is_category],
            marker=dict(
                color=colours[value],
                opacity=0.8,
                line=dict(width=1, color=colours[value]),
            )))

        # Misclassified tweets of this category: fill shows the predicted
        # label, outline shows the true label.
        predicted_colours = [colours[predicted] for predicted in predicted_labels[is_wrong]]
        fig.add_trace(go.Scatter3d(
            x=proj[is_wrong, 0], y=proj[is_wrong, 1], z=proj[is_wrong, 2],
            mode='markers',
            name="False_"+category,
            text=sentences[is_wrong],
            marker=dict(
                color=predicted_colours,
                symbol="x",
                opacity=0.8,
                line=dict(width=1.5, color=colours[value]),
            )))

    fig.update_layout(
        title='PCA',
        scene=dict(
            xaxis_title='1st Principle Axis',
            yaxis_title='2nd Principle Axis',
            zaxis_title='3rd Principle Axis',
        ),
        width=1200,  # Figure width in pixels
        height=800,  # Figure height in pixels
    )
    fig.show()

interactive(children=(Dropdown(description='model_string', options={'AfriBerta': 'Davlan/naija-twitter-sentime…