# Create embeddings.

### Description
This pipeline creates sentence embeddings and saves them to the data/ folder in our shared workfolder. The process loads the dataset, which includes transcriptions, labels and topic changes, and creates embedding representations per podcast. Due to the size of the dataset (~4000 videos), we save the embeddings tensor of each video to its own file for later processing. Results are saved as "./data/dataset_name/embeddings/pre-trained model-pooling method/video_id.pt". YouTube datasets add a sub-folder named after the source file, and Universal Sentence Encoder embeddings are saved as .npy. We use two different model architectures to calculate the embeddings, SBERT and Universal Sentence Encoder, and both offer several pre-trained models that can be selected.

 Parameters:
 - dataset_name: Name of the dataset to be loaded and processed.
 - model_name: Type of model to create the sentence embeddings.
 - pre_trained_model: Pre-trained model to use in the sentence embeddings process.
 - dim_redux_method: What type of dimensionality reduction process to use to create the sentence embeddings from the token embeddings.
 - print_debug: Print a message every time an embedding is created.

In [None]:
#@title Imported Packages and Libraries
!pip install nlp --quiet
!pip install transformers --quiet
!pip install datasets --quiet
!pip install -U sentence-transformers --quiet

from collections import Counter
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub
from nlp import load_dataset

import seaborn as sns
from pprint import pprint 

# Utilites
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt # plotting
import numpy as np # linear algebra
from numpy.linalg import norm
from numpy import dot
import os # accessing directory structure
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import csv
import time
import torch
from datetime import datetime
from torch._C import NoneType
from transformers import AutoTokenizer, AutoModel

# JSON
import json

# Embeddings
from transformers import BertTokenizer, TFBertModel
import sklearn as sk
import nltk
from nltk.data import find

# Pandas CSV processing
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Formatting options for float number in numpy
# Print numpy floats with 4 decimal places in every later cell.
float_formatter = "{:.4f}".format
np.set_printoptions(formatter={'float_kind':float_formatter})

# Mount Google Drive (Colab only) so the shared project folder is reachable.
from google.colab import drive
drive.mount('/content/drive')
# Project root on the mounted Drive; load_process_dataset reads from path + "data/<dataset_name>/".
path = "/content/drive/MyDrive/W266/project/nlp_podcast_segmentation/" #@param ["/content/drive/MyDrive/W266/project/nlp_podcast_segmentation/", "other"] {allow-input: true}


[K     |████████████████████████████████| 1.7 MB 4.3 MB/s 
[K     |████████████████████████████████| 212 kB 66.4 MB/s 
[K     |████████████████████████████████| 5.8 MB 3.8 MB/s 
[K     |████████████████████████████████| 7.6 MB 52.8 MB/s 
[K     |████████████████████████████████| 182 kB 83.0 MB/s 
[K     |████████████████████████████████| 451 kB 4.1 MB/s 
[K     |████████████████████████████████| 132 kB 61.2 MB/s 
[K     |████████████████████████████████| 127 kB 71.3 MB/s 
[K     |████████████████████████████████| 85 kB 2.4 MB/s 
[K     |████████████████████████████████| 1.3 MB 14.7 MB/s 
[?25h  Building wheel for sentence-transformers (setup.py) ... [?25l[?25hdone
Mounted at /content/drive


In [None]:
#@title Choose parameters
# Choose parameters
# Each "#@param" comment drives a Colab form field; keep it on the same line as its assignment.

dataset_name = "YouTube/yt_scripts_segments_split_n3_111422.csv" #@param ["AMIDataset", "YouTube/yt_scripts_segments_yt_simple_110922.csv", "YouTube/yt_scripts_segments_split_n5_111422.csv", "YouTube/yt_scripts_segments_split_n3_111422.csv","YouTube/yt_small_spacy_dev.csv", "YouTube/yt_scripts_segments_spacy_111022.csv", "YouTube/yt_scripts_segments_split_n5_111422_subset10.csv","YouTube/yt_scripts_segments_split_n10_112922.csv"]
model_name = "SBERT" #@param ["SBERT", "Universal Sentence Encoder"]
pre_trained_model =   "all-mpnet-base-v2" #@param ["all-mpnet-base-v2","stsb-mpnet-base-v2", "all-MiniLM-L6-v2", "multi-qa-mpnet-base-dot-v1", "nli-bert-large-max-pooling", "https://tfhub.dev/google/universal-sentence-encoder/4", "https://tfhub.dev/google/universal-sentence-encoder-large/5", "/content/drive/MyDrive/W266/project/nlp_podcast_segmentation/scripts/ricardo/use/"]
dim_redux_method = 'meanpooling' #@param ["meanpooling", "maxpooling"]
print_debug = "Yes" #@param ["Yes", "No"]
if model_name == "Universal Sentence Encoder":
  # Keep the hub URL/path for loading, and derive a folder-friendly model name from
  # its last two "/"-separated segments, e.g. ".../universal-sentence-encoder/4"
  # -> "universal-sentence-encoder-4".
  # NOTE(review): a value ending in "/" (the local Drive option) yields "<folder>-"
  # because the last segment is empty. It is kept as-is because existing embedding
  # folders are named this way.
  pre_trained_model_url = pre_trained_model
  pre_trained_model = str(pre_trained_model.split("/")[-2]) + "-" + str(pre_trained_model.split("/")[-1]) 

# A "YouTube/<file>" selection means "this file inside the YouTube dataset folder":
# split it into the dataset name and the file name. Any other value is used as-is.
filename = ""
dataset_parts = dataset_name.split("/")
if len(dataset_parts) > 1 and dataset_parts[-2] == "YouTube":
  filename = dataset_parts[-1]
  dataset_name = dataset_parts[-2]
  print("Filename:                  " + filename)
print("Dataset:                   " + dataset_name)
print("Model:                     " + model_name)
print("Pre-trained Model:         " + pre_trained_model)
print("Dimensionality Reduction:  " + dim_redux_method)

# Define loading functions

#Remove adjacent topic changes
def clean_adj_topic(topic_list):
  '''
  Drop a topic change that directly follows a kept topic change.

  The scan runs left to right over a copy. A run of consecutive 1s therefore
  alternates 1,0,1,0,... and the input sequence itself is not modified.
  Example: [0,1,1,0,1,1,1,0,1] would be cleaned to [0,1,0,0,1,0,1,0,1]
  '''
  cleaned = topic_list.copy()
  for position in range(1, len(cleaned)):
    follows_change = cleaned[position - 1] == 1
    if follows_change and cleaned[position] == 1:
      cleaned[position] = 0
  return cleaned

# average_sentences
# Calculate average sentences
def average_sentences(topic_list):
  '''
  Average length, in sentences, of the topic segments in a label sequence.

  input: topic labels (1 marks the first sentence of a new topic, 0 otherwise)
  returns: average number of sentences per segment, or 0 when the sequence
           contains no topic change at all

  Each segment starts at a 1 and runs up to the next 1 or the end of the list.
  Any sentences before the first 1 form one extra leading segment, so every
  sentence is counted exactly once.
  '''
  segment_lengths = []
  current_length = 0
  saw_change = False

  for label in topic_list:
    if label == 1:
      saw_change = True
      # Close the running segment. A change on the very first sentence has
      # nothing before it, so no empty (0-length) segment is recorded.
      if current_length > 0:
        segment_lengths.append(current_length)
      current_length = 0
    current_length += 1

  if not saw_change:
    return 0

  # The segment after the last topic change runs to the end of the list.
  segment_lengths.append(current_length)
  return sum(segment_lengths) / len(segment_lengths)

# evaluate_pk
# Evaluates PK
def evaluate_pk(pred, act, k=5):
  '''
  Windowed Pk-style error between predicted and actual boundary labels.

  pred, act: sequences of 0/1 boundary labels (the scan length comes from act).
  k: half window size. The window centred at c covers positions c-(k-1) .. c+k-1.

  For every centre c with k <= c < len(act) - k, the window is a miss when exactly
  one of pred/act contains at least one boundary in it. Returns misses / windows.
  Raises ZeroDivisionError when len(act) <= 2*k, because no window fits.
  '''
  misses = 0
  windows = 0
  for centre in range(k, len(act) - k):
    lo, hi = centre - (k - 1), centre + k
    pred_has_boundary = sum(pred[lo:hi]) >= 1
    act_has_boundary = sum(act[lo:hi]) >= 1
    if pred_has_boundary != act_has_boundary:
      misses += 1.0
    windows += 1.0
  return misses / windows

# evaluate_wd
# Evaluates WD
def evaluate_wd(pred, act, k=5):
  '''
  Windowed WindowDiff-style error between predicted and actual boundary labels.

  pred, act: sequences of 0/1 boundary labels (the scan length comes from act).
  k: half window size. The window centred at c covers positions c-k .. c+k-1.

  For every centre c with k <= c < len(act) - k, the window counts as an error when
  pred and act contain a different number of boundaries in it. Returns
  errors / windows scanned, so the score lies in [0, 1]; evaluate_pk normalises
  the same way. Returns 0.0 when the sequence is too short to place any window.
  '''
  errors = 0
  n_windows = 0
  for centre in range(k, len(act) - k):
    lo, hi = centre - k, centre + k
    # Error when the boundary counts differ, not merely when one side has none.
    if sum(pred[lo:hi]) != sum(act[lo:hi]):
      errors += 1
    n_windows += 1

  if n_windows == 0:
    return 0.0
  # Normalise by the number of windows actually scanned (N - 2k). Dividing by
  # N - k biased the score downwards, so it could never reach 1.
  return errors / n_windows

# get_meeting_sentences:
# retrieves specific sentences from the transcripts_list
def get_meeting_sentences(embedding_name="", S_list=None, T_list=None, Y_list=None, transcripts_list=None):
  '''
  Look up one transcript by name and return its parallel data.

  embedding_name: transcript name (meeting or video_id) to find in transcripts_list.
  S_list, T_list, Y_list: per-transcript sentences, topic tuples and labels,
                          index-aligned with transcripts_list.
  transcripts_list: transcript names.
  returns: (S, T, Y, name) for the first matching transcript, or None when the
           name is not present.
  '''
  # None defaults instead of shared mutable [] defaults.
  S_list = [] if S_list is None else S_list
  T_list = [] if T_list is None else T_list
  Y_list = [] if Y_list is None else Y_list
  transcripts_list = [] if transcripts_list is None else transcripts_list

  for transcript_idx, transcript in enumerate(transcripts_list):
    if transcript == embedding_name:
      return S_list[transcript_idx], T_list[transcript_idx], Y_list[transcript_idx], transcript
  return None

# cos_sim:
# Calculates cosine similarity between two vectors
def cos_sim(a,b):
  '''
  Cosine similarity of a and b: dot(a, b) / (norm(a) * norm(b)).

  If either vector is all zeros, the denominator is 0.
  '''
  magnitude_product = norm(a) * norm(b)
  return dot(a, b) / magnitude_product

# estimate_total:
# Estimates total ETA to finish embeddings
def estimate_total(S_list, done_embeddings, emb_speed, names=None):
  '''
  Estimate how long the remaining embeddings will take.

  S_list: per-transcript lists of sentences.
  done_embeddings: names of transcripts that already have embeddings.
  emb_speed: embedding throughput, in words per unit of time.
  names: transcript names index-aligned with S_list. Defaults to the notebook-level
         transcripts_list, which the original version read implicitly.
  returns: words still to embed / emb_speed, i.e. the ETA in emb_speed's time unit.
  '''
  if names is None:
    names = transcripts_list
  # Set lookup instead of scanning the done list once per transcript.
  done = set(done_embeddings)
  total_words_to_embed = 0
  for idx, sentences_to_embed in enumerate(S_list):
    if names[idx] in done:
      continue
    total_words_to_embed += sum(len(sentence.split()) for sentence in sentences_to_embed)
  return total_words_to_embed / emb_speed

# get_done_embeddings:
# Searches all embeddings done in a folder.
# Inputs: Embeddings Path
# Returns List of transcript names that have embeddings stored in file
def get_done_embeddings(embeddings_path, names=None):
  '''
  List the transcripts that already have an embedding file in embeddings_path.

  embeddings_path: folder to scan. Only its top level is read, and a missing
                   folder gives an empty list.
  names: known transcript names. Files for other names are ignored. Defaults to the
         notebook-level transcripts_list, which the original version read implicitly.
  returns: transcript names (file stem) of every "<name>.pt" / "<name>.npy" file,
           in directory-listing order.
  '''
  if names is None:
    names = transcripts_list
  known = set(names)

  # os.walk yields the top-level folder first (or nothing if it does not exist).
  top_level = next(os.walk(embeddings_path), None)
  if top_level is None:
    return []
  _, _, file_names = top_level

  done_embeddings = []
  for file_name in file_names:
    # splitext strips only the real extension; str.replace removed every ".pt"/".npy".
    stem, extension = os.path.splitext(file_name)
    if extension in (".pt", ".npy") and stem in known:
      done_embeddings.append(stem)
  return done_embeddings

# get_done_metrics:
# Searches all metrics done in a csv file.
def get_done_metrics(metric_results_filename_path="",
                     Y_hat_list_filename_path="", 
                     T_hat_list_filename_path="",
                     sims_list_filename_path=""):
  '''
  Reload previously computed metrics and predictions.

  metric_results_filename_path: CSV read row by row. Columns 1, 2 and 3 are
    collected as the transcript name, PK value and WD value (as strings). Every
    row is read, so a header row, if present, is returned too.
  Y_hat_list_filename_path, T_hat_list_filename_path, sims_list_filename_path:
    .npy files loaded with np.load and converted with .tolist().
  returns: (done_metrics, PK_metrics, WD_metrics, Y_hat_list, T_hat_list, sims_list)
  '''
  done_metrics = []
  PK_metrics = []
  WD_metrics = []
  # Read-only mode is enough; 'r+' needlessly required write permission on the file.
  with open(metric_results_filename_path, 'r', encoding='UTF8', newline='') as f:
    for row in csv.reader(f, delimiter=','):
      done_metrics.append(row[1])
      PK_metrics.append(row[2])
      WD_metrics.append(row[3])
  # NOTE: allow_pickle=True can unpickle arbitrary objects; only load files this project wrote.
  Y_hat_list = np.load(Y_hat_list_filename_path, allow_pickle=True).tolist()
  T_hat_list = np.load(T_hat_list_filename_path, allow_pickle=True).tolist()
  sims_list = np.load(sims_list_filename_path, allow_pickle=True).tolist()
  return done_metrics, PK_metrics, WD_metrics, Y_hat_list, T_hat_list, sims_list

def _clean_ami_sentence(sentence_text):
  '''Tidy one AMI utterance: fix " . " spacing and drop AMI annotation tags.'''
  # Applied twice on purpose: str.replace does not handle overlapping matches such as " . . ".
  sentence_text = sentence_text.replace(" . ", ". ")
  sentence_text = sentence_text.replace(" . ", ". ")
  for tag in ("[gap]", "[vocalsound]", "[disfmarker]"):
    sentence_text = sentence_text.replace(tag, "")
  return sentence_text.replace("[transformerror]", " ")


def _read_ami_transcripts(folder):
  '''Return (names, parsed JSON) for every *.json file directly inside folder.'''
  names = []
  meetings = []
  # os.walk yields the top-level folder first; stop after it (no recursion).
  # A missing folder yields nothing, giving two empty lists.
  for _, _, file_names in os.walk(folder):
    for file_name in file_names:
      # Skip stray non-JSON files instead of failing the whole load on them.
      if not file_name.endswith(".json"):
        continue
      with open(os.path.join(folder, file_name), encoding="utf-8") as f:
        meetings.append(json.load(f))
      # Appended after a successful load so names and meetings stay aligned.
      names.append(file_name[:-len(".json")])
    break
  return names, meetings


# load_ami_dataset:
# Loads Transcripts from .transcripts/ folder and converts them into
# S:                list of M utterances S = {S_1,..., S_M}
# T:                Underlying topic structure Ti ∈ [Sj , Sk]
# Y:                Label sequence Y = {y1,.., yM} yi is binary indicates whether the utterance Si is the start of a new topic segment
# S_List:           List of S_i (utterances) for the i-th transcript or meeting
# T_List:           List of T_i (Topic changes tuples) for the i-th transcript or meeting
# Y_List:           List of Y_i (Topic changes flat) for the i-th transcript or meeting
# transcripts_list: List of transcript names. Meetings or video_id.
def load_ami_dataset(transcript_path=""):
  '''
  Load every AMI meeting JSON in transcript_path into S/T/Y lists.

  transcript_path: folder holding one <meeting>.json per meeting. Each file is read
    as a list of topics; the loader uses each topic's 'topic_idx' (for messages) and
    its 'sentences' list of {'text': ...} dicts.
  returns: (S_list, T_list, Y_list, transcripts_list) as described above. T tuples
    are inclusive (first, last) sentence indices within the meeting. Topics with no
    sentences are skipped. On a load or processing error, the error is printed and
    whatever was built so far is returned.
  '''
  S_list = []
  T_list = []
  Y_list = []

  # Reads JSON from Folder. This used to read an undefined/global transcripts_path
  # instead of this function's own argument, so the first load always failed.
  try:
    transcripts_list, meeting_transcripts = _read_ami_transcripts(transcript_path)
  except Exception as error:
    print(datetime.today().strftime('%Y-%m-%d %H:%M:%S') + " ---  Error Loading JSON files - " + str(error))
    return S_list, T_list, Y_list, []

  # Process data set to return Main Variables
  try:
    for meeting_idx, meeting_transcript in enumerate(meeting_transcripts):
      transcript_id = transcripts_list[meeting_idx]
      S = []
      T = []
      Y = []
      for topics_idx, topics in enumerate(meeting_transcript):
        sentences = topics['sentences']
        # An empty topic has no span. The old code reused the previous topic's
        # last sentence index here and produced a wrong T tuple.
        if not sentences:
          continue
        T_start = len(S)
        for sentence_idx, sentence in enumerate(sentences):
          sentence_text = _clean_ami_sentence(sentence['text'])
          sentence_word_count = len(sentence_text.split())
          # Very long utterances are reported but still kept.
          if sentence_word_count > 512:
            print(datetime.today().strftime('%Y-%m-%d %H:%M:%S') + " ---  Error wordcount over 512 " + str(transcript_id) + " - Topic: "  + str(topics_idx) + " title: " + str(topics['topic_idx']) + " Sentence: " + str(sentence_idx) + " Word Count: " + str(sentence_word_count))
          S.append(sentence_text)
          # Y marks the first sentence of every topic.
          Y.append(1 if sentence_idx == 0 else 0)
        T.append((T_start, len(S) - 1))
      S_list.append(S)
      T_list.append(T)
      Y_list.append(Y)
  except Exception as error:
    print(datetime.today().strftime('%Y-%m-%d %H:%M:%S') + " ---  Error Processing AMI Dataset - " + str(error))
    return S_list, T_list, Y_list, transcripts_list

  return S_list, T_list, Y_list, transcripts_list

# load_youtube_dataset:
# Loads csvs from .YouTube/ folder and converts them into
# S:                list of M utterances S = {S_1,..., S_M}
# T:                Underlying topic structure Ti ∈ [Sj , Sk]
# Y:                Label sequence Y = {y1,.., yM} yi is binary indicates whether the utterance Si is the start of a new topic segment
# S_List:           List of S_i (utterances) for the i-th transcript or meeting
# T_List:           List of T_i (Topic changes tuples) for the i-th transcript or meeting
# Y_List:           List of Y_i (Topic changes flat) for the i-th transcript or meeting
# transcripts_list: List of transcript names. Meetings or video_id.
def load_youtube_dataset(dataset_path="", filename=""):
  '''
  Load one YouTube segmentation file into S/T/Y lists.

  dataset_path, filename: folder and file name. Despite the .csv extension used by
    this project, the file is read with pandas.read_pickle.
  Columns read: 'Video_Id'; 'Sentence_Word_Lists' (per video, a list whose items
    hold the sentence text at index 0); 'Transition_Labels_Tuple'; 'Transition_Labels'.
  returns: (S_list, T_list, Y_list, transcripts_list). On error, the error is printed
    and the lists built so far are returned.
  '''
  S_list = []
  T_list = []
  Y_list = []
  transcripts_list = []

  # Reads CSV from Folder
  try:
    csv_path = os.path.join(dataset_path, filename)
    # NOTE: read_pickle can execute code embedded in the file; only load trusted,
    # project-generated files.
    yt_pods = pd.read_pickle(csv_path)
  except Exception as error:
    print(datetime.today().strftime('%Y-%m-%d %H:%M:%S') + " ---  Error Loading csv from YouTube Dataset - " + str(error))
    return S_list, T_list, Y_list, transcripts_list

  # Initialised up front so the error message below can always be built. An early
  # failure used to raise UnboundLocalError inside the except clause itself.
  video_id = None
  sentences_idx = None

  # Process data set to return Main Variables
  try:
    transcripts_list = list(yt_pods["Video_Id"])
    for sentences_idx, sentences in enumerate(yt_pods["Sentence_Word_Lists"]):
      video_id = transcripts_list[sentences_idx] if sentences_idx < len(transcripts_list) else None
      S_list.append([sentence[0] for sentence in sentences])
    T_list = list(yt_pods["Transition_Labels_Tuple"])
    Y_list = list(yt_pods["Transition_Labels"])
  except Exception as error:
    print(datetime.today().strftime('%Y-%m-%d %H:%M:%S') + " --- Error processing Video: " + str(video_id) + " - Index: " +str(sentences_idx) + " - " + str(error))
    return S_list, T_list, Y_list, transcripts_list

  return S_list, T_list, Y_list, transcripts_list

# load_process_dataset:
# Process all datasets, and returns main variables and
# transcripts, dataset path
def load_process_dataset(dataset_name="", filename=""):
  '''
  Resolve the dataset folders under the notebook's `path` and run the matching loader.

  dataset_name: "AMIDataset" or "YouTube"; any other value loads nothing.
  filename: file inside the YouTube folder (ignored for AMI).
  returns: three tuples
    (S_list, T_list, Y_list, W_count_list)   -- W_count_list is always empty here
    (meeting_transcripts, transcripts_list)  -- meeting_transcripts is always empty here
    (dataset_path, transcripts_path)
  '''
  dataset_path = os.path.join(path, 'data/' + dataset_name + "/")
  transcripts_path = os.path.join(dataset_path, 'transcripts/')

  loaders = {
    "AMIDataset": lambda: load_ami_dataset(transcripts_path),
    "YouTube": lambda: load_youtube_dataset(dataset_path=dataset_path, filename=filename),
  }
  loader = loaders.get(dataset_name)
  if loader is None:
    S_list, T_list, Y_list, transcripts_list = [], [], [], []
  else:
    S_list, T_list, Y_list, transcripts_list = loader()

  # Neither loader produces these; they are returned empty to keep the tuple layout.
  W_count_list = []
  meeting_transcripts = []
  return (S_list, T_list, Y_list, W_count_list), (meeting_transcripts, transcripts_list), (dataset_path, transcripts_path)


Filename:                  yt_scripts_segments_split_n3_111422.csv
Dataset:                   YouTube
Model:                     SBERT
Pre-trained Model:         all-mpnet-base-v2
Dimensionality Redyction:  meanpooling


In [None]:
#@title Load Dataset
# Load the selected dataset and unpack the three groups returned by load_process_dataset.
(S_list, T_list, Y_list, W_count_list), (meeting_transcripts, transcripts_list), (dataset_path, transcripts_path) = load_process_dataset(
    dataset_name=dataset_name,
    filename=filename,
)

# Sanity check: corpus-level sizes, then the sizes of one sample transcript.
index_test = 3
print("\n")
print(f"Transcripts path:            {transcripts_path}")
print(f"Dataset path:                {dataset_path}")
print(f"Numb of transcripts S_list:  {len(S_list)}")
print(f"Numb of Topics T_list:       {len(T_list)}")
print(f"Numb of Outputs Y_list:      {len(Y_list)}")
print(f"Transcripts name:            {transcripts_list[index_test]}")
print(f"Numb of sentences in Test    {len(S_list[index_test])}")
print(f"Numb of topics in Test       {len(T_list[index_test])}")
print(f"Numb of Outputs in Test      {len(Y_list[index_test])}")



Transcripts path:            /content/drive/MyDrive/W266/project/nlp_podcast_segmentation/data/YouTube/transcripts/
Dataset path:                /content/drive/MyDrive/W266/project/nlp_podcast_segmentation/data/YouTube/
Numb of transcripts S_list:  3731
Numb of Topics T_list:       3731
Numb of Outputs Y_list:      3731
Transcripts name:            M5IzFvBQ-Zo
Numb of sentences in Test    904
Numb of topics in Test       6
Numb of Outputs in Test      904


In [None]:
#@title Load pre-trained model

# Load the encoder chosen in "Choose parameters" and define the helpers that the
# embedding cell calls: max_pooling / mean_pooling for SBERT, use_embed for USE.
print("Loading Model: " + str(model_name) + " Pretrained model name: " + str(pre_trained_model))
if model_name == "SBERT":

  # Define specific function of SBERT

  def max_pooling(model_output, attention_mask):
    '''
    Max pooling: element-wise max over the non-padding tokens of each sentence.

    model_output[0] holds the token embeddings. attention_mask marks padding
    positions with 0 and is broadcast over the embedding dimension.
    '''
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # masked_fill returns a new tensor: padding positions get a large negative value
    # so they never win the max. The old in-place assignment overwrote the model's
    # own output tensor.
    token_embeddings = token_embeddings.masked_fill(input_mask_expanded == 0, -1e9)
    return torch.max(token_embeddings, 1)[0]

  def mean_pooling(model_output, attention_mask):
    '''
    Mean pooling: average of the non-padding token embeddings of each sentence.

    The mask zeroes padding tokens before summing. The divisor is clamped to 1e-9
    to avoid dividing by zero for an all-padding row.
    '''
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

  # Load tokenizer and model from the HuggingFace Hub (sentence-transformers namespace).
  tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + pre_trained_model)
  model = AutoModel.from_pretrained('sentence-transformers/' + pre_trained_model)

elif model_name == "Universal Sentence Encoder":
  # pre_trained_model_url is the hub URL/path saved in "Choose parameters".
  model = hub.load(pre_trained_model_url)
  print ("module %s loaded" % pre_trained_model)
  def use_embed(input):
    '''Embed a list of sentences with the loaded Universal Sentence Encoder.'''
    return model(input)

print("Loaded Model: " + str(model_name) + " Pretrained model name: " + str(pre_trained_model))

Loading Model: SBERT Pretrained model name: all-mpnet-base-v2


Downloading:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]

Loaded Model: SBERT Pretrained model name: all-mpnet-base-v2


In [None]:
#@title Create embeddings per transcription

# Compute sentence embeddings for every transcript in S_list and save one file per
# transcript in embeddings_path: SBERT -> <name>.pt via torch.save, USE -> <name>.npy
# via np.save. Transcripts that already have a file there are skipped, so re-running
# the cell resumes an interrupted job.

# Prepare Variables
done_embeddings = []
sentence_embeddings_list = []
pre_trained_model_fullname = str(pre_trained_model + '-' + dim_redux_method)

# Define Embeddings path depending on Dataset
embeddings_path = os.path.join(dataset_path,'embeddings/', str(pre_trained_model_fullname + '/'))
if dataset_name == "YouTube":
  # One sub-folder per source file. splitext strips only the final extension;
  # split(".")[-2] kept just one piece of names containing extra dots.
  filename_folder = os.path.splitext(filename)[0]
  embeddings_path = os.path.join(dataset_path,'embeddings/', str(filename_folder), str(pre_trained_model_fullname + '/'))

# Check if embeddings folder exists if not create it
if not os.path.exists(embeddings_path):
  os.makedirs(embeddings_path)
  print(datetime.today().strftime('%Y-%m-%d %H:%M:%S') + 
        " - " + str(pre_trained_model_fullname) + 
        " - Embedding folder not found. New folder created")

# List all embeddings in folder embeddings_path
done_embeddings = get_done_embeddings(embeddings_path)


def _log_embedding_progress(embedding_name, position, message):
  '''Print one progress line: timestamp - model - transcript - done: i of N - message.'''
  print(datetime.today().strftime('%Y-%m-%d %H:%M:%S') + 
        " - " + str(pre_trained_model_fullname) + 
        " - " + str(embedding_name) + 
        " - done: " + str(position) + " of: " + str(len(S_list)) +
        " - " + message)


if model_name == "SBERT":
  # Never request more tokens than the tokenizer says the model accepts. With the
  # old fixed 3000, longer inputs were passed through untruncated, which can make
  # the forward pass fail and skip the whole transcript. min() keeps 3000 as the
  # upper bound for tokenizers that report no real limit.
  sbert_max_length = min(3000, tokenizer.model_max_length)

# Run for all transcripts
for transcripts_list_idx, sentences in enumerate(S_list):
  embedding_name = transcripts_list[transcripts_list_idx]
  position = transcripts_list_idx + 1

  if embedding_name in done_embeddings:
    if print_debug == "Yes":
      _log_embedding_progress(embedding_name, position, "Skipping embedding, already Done.")
    continue

  if print_debug == "Yes":
    _log_embedding_progress(embedding_name, position, "Start -  Embeddings: " + str(len(sentences)))

  # Compute token embeddings
  if model_name == "SBERT":
    try:
      # Tokenize all sentences of the transcript as one padded batch
      encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt', max_length=sbert_max_length)

      # Create Embedding (inference only, no gradients)
      with torch.no_grad():
        model_output = model(**encoded_input)

      # MAX or MEAN pooling
      if dim_redux_method == "maxpooling":
        sentence_embeddings = max_pooling(model_output, encoded_input['attention_mask'])
      elif dim_redux_method == "meanpooling":
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

      # One embedding per sentence is expected; a mismatch stops the whole run.
      sentence_embeddings_size_total = sentence_embeddings.size()
      if int(sentence_embeddings_size_total[0]) != len(sentences):
        _log_embedding_progress(embedding_name, position,
                                "ERROR on Emb Size Test: " + str(sentence_embeddings_size_total) + " != " + str(len(sentences)))
        break

      # Save tensor to file
      full_embeddings_path = os.path.join(embeddings_path, str(embedding_name + ".pt"))
      torch.save(sentence_embeddings, full_embeddings_path)

    except Exception as error:
      _log_embedding_progress(embedding_name, position, "Error computing embedding, skipping... " + str(error))
      continue

  elif model_name == "Universal Sentence Encoder":
    try:
      # Create Embedding
      sentence_embeddings = use_embed(sentences)

      # One embedding per sentence is expected; a mismatch stops the whole run.
      sentence_embeddings_size_total = sentence_embeddings.shape
      if int(sentence_embeddings_size_total[0]) != len(sentences):
        _log_embedding_progress(embedding_name, position,
                                "ERROR on Emb Size Test: " + str(sentence_embeddings_size_total) + " != " + str(len(sentences)))
        break

      # Save array to file.
      # NOTE(review): dtype=object forces a pickled .npy (readers need allow_pickle=True).
      # It is kept as-is so existing readers of these files keep working.
      sentence_embeddings_np = sentence_embeddings.numpy()
      full_embeddings_path = os.path.join(embeddings_path, str(embedding_name + ".npy"))
      np.save(full_embeddings_path, np.asarray(sentence_embeddings_np, dtype=object))

    except Exception as error:
      _log_embedding_progress(embedding_name, position, "Error computing embedding, skipping... " + str(error))
      continue

  done_embeddings.append(embedding_name)
  _log_embedding_progress(embedding_name, position, "Emb Size: " + str(sentence_embeddings_size_total))


[1;30;43mSe han truncado las últimas 5000 líneas del flujo de salida.[0m
2022-12-03 03:32:02 - all-mpnet-base-v2-meanpooling - VIx8a_Ywrvg - done: 258 of: 3731 - Emb Size: torch.Size([660, 768])
2022-12-03 03:32:02 - all-mpnet-base-v2-meanpooling - AE5Bqw3e4lQ - done: 259 of: 3731 - Start -  Embeddings: 536
2022-12-03 03:32:18 - all-mpnet-base-v2-meanpooling - AE5Bqw3e4lQ - done: 259 of: 3731 - Emb Size: torch.Size([536, 768])
2022-12-03 03:32:18 - all-mpnet-base-v2-meanpooling - 2PNNIlay4cw - done: 260 of: 3731 - Start -  Embeddings: 716
2022-12-03 03:32:40 - all-mpnet-base-v2-meanpooling - 2PNNIlay4cw - done: 260 of: 3731 - Emb Size: torch.Size([716, 768])
2022-12-03 03:32:40 - all-mpnet-base-v2-meanpooling - d_HNYsEEHlQ - done: 261 of: 3731 - Start -  Embeddings: 618
2022-12-03 03:32:59 - all-mpnet-base-v2-meanpooling - d_HNYsEEHlQ - done: 261 of: 3731 - Emb Size: torch.Size([618, 768])
2022-12-03 03:32:59 - all-mpnet-base-v2-meanpooling - wIHJtlXc0CI - done: 262 of: 3731 - Start 

KeyboardInterrupt: ignored