This is an example to show how to use narration features in Ego4D.

This notebook:
1. Data Preparation
    - data pre-processing
    - couple two narration sets via nearest matching

2. (Deliverable) Fine-tuned model
    - filter out mismatched pairs (cos-sim < 0.5)
    - train the model on 1M narration pairs without holding out a test set
    - model saved in 'final_model_path'

3. (Deliverable) Refined narration embeddings
    - encode narrations from annotator1
    - semantic search (compared with pre-trained embedding results)


In [None]:
import random
import math
import time
import os
import json
import torch
import numpy as np
import pandas as pd 

import plotly.express as px

from sentence_transformers import SentenceTransformer, util, LoggingHandler, losses, InputExample
import scipy.spatial as sp, scipy.cluster.hierarchy as hc
from sklearn.metrics import adjusted_mutual_info_score
from sklearn.metrics.pairwise import euclidean_distances
from scipy.spatial import distance

from torch.utils.data import DataLoader
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, BinaryClassificationEvaluator, TripletEvaluator
import matplotlib.pyplot as plt
import pickle

import seaborn as sns

from sklearn.manifold import TSNE

from moviepy.editor import VideoFileClip

from multinegative_loss import MultipleNegativesLoss

## 1. Data Preparation
Load narration data

In [None]:
NARRATION_PATH='/datasets01/ego4d_track2/v1/annotations/narration.json'
# Parse the narration annotations; the context manager closes the file handle
# as soon as parsing finishes instead of leaving it open for the kernel lifetime.
with open(NARRATION_PATH) as narration_file:
    narrations = json.load(narration_file)

Collect narrations from two annotator sources

In [None]:
# Flatten annotator pass 1 into (uid, timestamp_sec, timestamp_frame, narration_text)
# records, skipping every video whose status is 'redacted'.
uid_stamp_narrations_pass_1 = [
    (uid, entry['timestamp_sec'], entry['timestamp_frame'], entry['narration_text'])
    for uid, video in narrations.items()
    if video['status'] != 'redacted'
    for entry in video['narration_pass_1']['narrations']
]

# Same flattening for annotator pass 2.
uid_stamp_narrations_pass_2 = [
    (uid, entry['timestamp_sec'], entry['timestamp_frame'], entry['narration_text'])
    for uid, video in narrations.items()
    if video['status'] != 'redacted'
    for entry in video['narration_pass_2']['narrations']
]

In [None]:
# Column-wise views of the pass-1 records: uid / seconds / frame as NumPy arrays
# (so they can be boolean-masked per uid), text as a plain list.
narration_uid_pass_1 = np.array([record[0] for record in uid_stamp_narrations_pass_1])
narration_stamp_sec_pass_1 = np.array([record[1] for record in uid_stamp_narrations_pass_1])
narration_stamp_frame_pass_1 = np.array([record[2] for record in uid_stamp_narrations_pass_1])
narration_text_pass_1 = [record[3] for record in uid_stamp_narrations_pass_1]

# Same column-wise views for the pass-2 records.
narration_uid_pass_2 = np.array([record[0] for record in uid_stamp_narrations_pass_2])
narration_stamp_sec_pass_2 = np.array([record[1] for record in uid_stamp_narrations_pass_2])
narration_stamp_frame_pass_2 = np.array([record[2] for record in uid_stamp_narrations_pass_2])
narration_text_pass_2 = [record[3] for record in uid_stamp_narrations_pass_2]

In [None]:
from collections import defaultdict

# uid -> {timestamp_frame: narration_text} lookup for annotator pass 1.
# A later record with the same (uid, frame) overwrites the earlier one.
dict_1 = defaultdict(dict)
for record in uid_stamp_narrations_pass_1:
    uid, stamp_frame, text = record[0], record[2], record[3]
    dict_1[uid][stamp_frame] = text

# uid -> {timestamp_frame: narration_text} lookup for annotator pass 2.
dict_2 = defaultdict(dict)
for record in uid_stamp_narrations_pass_2:
    uid, stamp_frame, text = record[0], record[2], record[3]
    dict_2[uid][stamp_frame] = text

Couple narration pairs via nearest-frame matching (within a frame-stamp threshold)

In [None]:
from scipy.spatial import distance_matrix

# threshold: couple two narrations only if their frame-stamp gap is <= thresh
def collect_narration_pair(thresh):
    """Greedily pair pass-1 and pass-2 narrations of the same video by nearest frame stamp.

    For every uid present in both passes, the globally closest remaining
    (pass-1, pass-2) frame pair is taken repeatedly. It is recorded while the
    frame gap is <= ``thresh``; both narrations are then retired so each one is
    matched at most once.

    Reads the module-level arrays ``narration_uid_pass_*``,
    ``narration_stamp_frame_pass_*``, ``narration_stamp_sec_pass_*`` and the
    text lookups ``dict_1`` / ``dict_2``.

    Args:
        thresh: maximum frame gap (inclusive) allowed between paired narrations.

    Returns:
        dict mapping uid -> {'stamp_frame_pair': [(frame_1, frame_2), ...],
        'stamp_sec_pair': [(sec_1, sec_2), ...],
        'narration_pair': [[text_1, text_2], ...]}, all three lists aligned.
    """
    output={}
    for uid in set(narration_uid_pass_1).intersection(set(narration_uid_pass_2)):
        sentence_pair={}

        # compute each per-video mask once instead of once per derived array
        mask_1=narration_uid_pass_1==uid
        mask_2=narration_uid_pass_2==uid

        stamp_frame_1=narration_stamp_frame_pass_1[mask_1]
        stamp_frame_2=narration_stamp_frame_pass_2[mask_2]

        stamp_sec_1=narration_stamp_sec_pass_1[mask_1]
        stamp_sec_2=narration_stamp_sec_pass_2[mask_2]

        # float copy so that retired rows/columns can be marked with +inf
        stamp_dist=distance_matrix(stamp_frame_1.reshape(-1,1),stamp_frame_2.reshape(-1,1)).astype(float)
        stamp_frame_pair=[];stamp_sec_pair=[];narr_pair=[]

        while True: # find nearest frames under threshold
            indx_1, indx_2=np.unravel_index(np.argmin(stamp_dist, axis=None), stamp_dist.shape)
            if stamp_dist[indx_1, indx_2]>thresh:
                break

            stamp_frame_pair.append((stamp_frame_1[indx_1], stamp_frame_2[indx_2]))
            stamp_sec_pair.append((stamp_sec_1[indx_1], stamp_sec_2[indx_2]))
            narr_pair.append([dict_1[uid][stamp_frame_1[indx_1]],dict_2[uid][stamp_frame_2[indx_2]]])

            # Retire the matched sentences. +inf (rather than a finite sentinel
            # such as 10000) can never be <= thresh, so the loop always ends.
            stamp_dist[indx_1,:]=np.inf
            stamp_dist[:,indx_2]=np.inf

        sentence_pair['stamp_frame_pair']=stamp_frame_pair
        sentence_pair['stamp_sec_pair']=stamp_sec_pair
        sentence_pair['narration_pair']=narr_pair

        output[uid]=sentence_pair
    return output

In [None]:
# Match the two passes, allowing at most 20 frames between paired narrations.
pair_data_thresh_20 = collect_narration_pair(thresh=20)

In [None]:
# write data
# The context manager closes the binary pickle file even if dumping raises.
with open("narration_pair_data_thresh_20.pkl","wb") as f:
    pickle.dump(pair_data_thresh_20,f) # write the python object (dict) to pickle file

# reload data
#with open('narration_pair_data_thresh_20.pkl', 'rb') as f:
#    pair_data_thresh_20 = pickle.load(f)

data preprocessing

In [None]:
# Every matched narration pair (frame gap <= 20), cleaned for the sentence encoder.
nar_comp=[n for v in pair_data_thresh_20.values() for n in v['narration_pair']]
# Pass 1: lowercase, drop leading '#x' tags, '#unsure' and stray '#', collapse whitespace.
# Raw strings keep '\s' / '\.' as regex escapes rather than invalid Python string escapes.
df=pd.DataFrame(nar_comp).apply(lambda x: x.str.lower().replace({r'^#[a-z]':'',r'^#\s+[a-z]':'',r'^c\s+c':'c','#unsure':'','#':'',r'\s+':' '},regex=True).str.strip())
# Pass 2: spell the stand-alone token 'c' out as 'person c' and drop a trailing period.
df=df.apply(lambda x: x.str.lower().replace({r'\sc\s':' person c ', r'^c\s':'person c ', r'\sc$':' person c', r'\.$':''},regex=True))
pair_narration_all=df.values.tolist()

In [None]:
# Matched pairs whose frame gap is at most 5, cleaned for the sentence encoder.
nar_comp=[n for v in pair_data_thresh_20.values() for f,n in zip(v['stamp_frame_pair'],v['narration_pair']) if abs(f[1]-f[0])<=5]
# Pass 1: lowercase, drop leading '#x' tags, '#unsure' and stray '#', collapse whitespace.
# Raw strings keep '\s' / '\.' as regex escapes rather than invalid Python string escapes.
df=pd.DataFrame(nar_comp).apply(lambda x: x.str.lower().replace({r'^#[a-z]':'',r'^#\s+[a-z]':'',r'^c\s+c':'c','#unsure':'','#':'',r'\s+':' '},regex=True).str.strip())
# Pass 2: spell the stand-alone token 'c' out as 'person c' and drop a trailing period.
df=df.apply(lambda x: x.str.lower().replace({r'\sc\s':' person c ', r'^c\s':'person c ', r'\sc$':' person c', r'\.$':''},regex=True))
pair_narration_train=df.values.tolist()

In [None]:
# Matched pairs whose frame gap is in (5, 10], cleaned for the sentence encoder.
nar_comp=[n for v in pair_data_thresh_20.values() for f,n in zip(v['stamp_frame_pair'],v['narration_pair']) if 5<abs(f[1]-f[0])<=10]
# Pass 1: lowercase, drop leading '#x' tags, '#unsure' and stray '#', collapse whitespace.
# Raw strings keep '\s' / '\.' as regex escapes rather than invalid Python string escapes.
df=pd.DataFrame(nar_comp).apply(lambda x: x.str.lower().replace({r'^#[a-z]':'',r'^#\s+[a-z]':'',r'^c\s+c':'c','#unsure':'','#':'',r'\s+':' '},regex=True).str.strip())
# Pass 2: spell the stand-alone token 'c' out as 'person c' and drop a trailing period.
df=df.apply(lambda x: x.str.lower().replace({r'\sc\s':' person c ', r'^c\s':'person c ', r'\sc$':' person c', r'\.$':''},regex=True))
pair_narration_test=df.values.tolist()

In [None]:
# Number of matched pairs, and the largest gap in seconds between paired narrations.
print('frame_thresh: 20')
max_sec_gap = max(
    abs(sec_2 - sec_1)
    for matched in pair_data_thresh_20.values()
    for sec_1, sec_2 in matched['stamp_sec_pair']
)
len(pair_narration_all), max_sec_gap

In [None]:
# Sizes of the frame-gap <= 5 subset and the 5 < frame-gap <= 10 subset.
(len(pair_narration_train), len(pair_narration_test))

Calculate the cosine similarity of each positive pair

In [None]:
pre_trained_model = SentenceTransformer('all-mpnet-base-v2')

# Encode the pairs in chunks rather than one encode() call per pair: the model
# then processes whole batches, while the chunking bounds the memory held by
# the embeddings at any one time.
encode_chunk_size=4096
sim_pre_trained_model=[]
for start in range(0, len(pair_narration_all), encode_chunk_size):
    chunk=pair_narration_all[start:start+encode_chunk_size]
    embed_1=pre_trained_model.encode([pair[0] for pair in chunk], convert_to_tensor=True)
    embed_2=pre_trained_model.encode([pair[1] for pair in chunk], convert_to_tensor=True)
    # row-wise cosine similarity between the two sides of every pair in the chunk
    sim_pre_trained_model.extend(torch.nn.functional.cosine_similarity(embed_1, embed_2, dim=1).tolist())

# save result in txt (one similarity per line, same order as pair_narration_all)
os.makedirs('cos_sim_rec', exist_ok=True) # the output directory may not exist yet
with open('cos_sim_rec/all_data_pre_trained_model.txt', 'w') as output_file:
    for x in sim_pre_trained_model:
        output_file.write(str(x) + '\n')

Filter out mismatched pairs whose similarity is not above a threshold (sim_thresh)

In [None]:
sim_thresh=0.5

# Frame gap of every matched pair, in the same order as pair_narration_all and
# sim_pre_trained_model (both are built by iterating pair_data_thresh_20 the same way).
frame_gaps=[abs(f[1]-f[0]) for k, v in pair_data_thresh_20.items() for f in v['stamp_frame_pair']]

# For each frame-gap limit, report how many of the pairs within that limit have a
# pre-trained cosine similarity ABOVE sim_thresh (i.e. would survive the filter).
for frame_delta in (5, 10, 15, 20):
    list_under_thresh=[1 if gap<=frame_delta else 0 for gap in frame_gaps]
    x=sum(flag for flag, sim in zip(list_under_thresh, sim_pre_trained_model) if sim>sim_thresh)
    total=sum(list_under_thresh)
    # round after scaling to avoid artefacts such as 56.120000000000005,
    # and guard against a limit under which no pair falls
    percent=round(100*x/total,2) if total else 0.0
    print(f'when frame_delta={frame_delta:>2}, # of samples > sim thresh: ',
        x, 'out of', total, 'is', percent,'%'
        )

## 2. (Deliverable) Fine-tuning on all filtered samples without splitting off a test set

In [None]:
sim_thresh=0.5

# Keep only the pairs whose pre-trained cosine similarity is strictly above sim_thresh.
pair_narration_all_filtered = [
    pair
    for position, pair in enumerate(pair_narration_all)
    if sim_pre_trained_model[position] > sim_thresh
]

Hold out 20% of the training data as a validation set, used to pick the best model

In [None]:
valid_ratio=0.2
max_negative_draws=10 # bounded retries when a drawn negative has the same text as the positive

pair_all_train_samples=[]
pair_all_valid_samples=[]
process_label=set(range(len(pair_narration_all_filtered)))
num_pairs=len(process_label)
for i in process_label:
    pos_sample=pair_narration_all_filtered[i]

    # A dedicated generator seeded per sample draws the same number as
    # np.random.seed(100*i+1) followed by np.random.random(), so the train/validation
    # assignment is unchanged, but the global NumPy RNG state is left untouched.
    rand_num=np.random.RandomState(100*i+1).random_sample() # decides train vs validation
    if rand_num < valid_ratio:
        if num_pairs<2:
            raise ValueError('need at least two narration pairs to draw a negative sample')
        # Draw the negative index in O(1): pick from num_pairs-1 slots and skip over i,
        # instead of materialising list(process_label-{i}) for every validation sample.
        neg_rng=random.Random(i)
        for _ in range(max_negative_draws):
            neg_indx=neg_rng.randrange(num_pairs-1)
            if neg_indx>=i:
                neg_indx+=1
            neg_sample=pair_narration_all_filtered[neg_indx]
            # identical narration text would make a mislabeled "negative"; redraw
            if neg_sample[1]!=pos_sample[1]:
                break
        inp_example = [
        InputExample(texts=[pos_sample[0], pos_sample[1]], label=1), # one positive sample
        InputExample(texts=[pos_sample[0], neg_sample[1]], label=0)  # one negative sample
        ]
        pair_all_valid_samples+=inp_example

    else:
        inp_example = [
        InputExample(texts=[pos_sample[0], pos_sample[1]]) # one positive sample
        ]
        pair_all_train_samples+=inp_example

In [None]:
num_epochs = 8
batch_size = 32

# Start from the pre-trained checkpoint and fine-tune it on the positive pairs.
model = SentenceTransformer('all-mpnet-base-v2')
evaluator0 = BinaryClassificationEvaluator.from_input_examples(
    pair_all_valid_samples,
    name='validation',
)
train_dataloader = DataLoader(
    pair_all_train_samples,
    shuffle=True,
    batch_size=batch_size,
    drop_last=True,
)
# neg_size: proportion of negatives; use (n-1) negatives if -1
train_loss = MultipleNegativesLoss(model=model, neg_size=-1)
# warm up over the first 10% of all optimisation steps (batches per epoch * epochs)
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)

# Train the model, evaluating on the validation pairs every 1000 steps;
# results are written under 'final_model_path'.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    evaluator=evaluator0,
    evaluation_steps=1000,
    warmup_steps=warmup_steps,
    show_progress_bar=True,
    output_path='final_model_path',
)

In [None]:
validation_results = pd.read_csv("final_model_path/eval/binary_classification_evaluation_validation_results.csv")

# x position of each evaluation row: row index * 1000
eval_steps = np.arange(len(validation_results)) * 1000
ap_curve = validation_results.cossim_ap
acc_curve = validation_results.cossim_accuracy

fig, ax = plt.subplots(figsize=(10, 7))

ax.plot(eval_steps, ap_curve, label='ap')

ax.plot(eval_steps, acc_curve, label='acc')

# dotted red marker at the evaluation row with the highest average precision
ax.vlines(
    x=np.argmax(ap_curve) * 1000,
    ymin=min(ap_curve.min(), acc_curve.min()),
    ymax=max(ap_curve.max(), acc_curve.max()) + 0.005,
    color='red',
    linestyle='dotted',
)

ax.set_xlabel('iteration steps')

ax.legend();

## 3. (Deliverable) Narration Embeddings
### 3.1 encode narrations from annotator1 as corpus

In [None]:
# Corpus records from annotator pass 1: (uid, timestamp_sec, timestamp_frame, narration_text),
# skipping every video whose status is 'redacted'.
uid_stamp_narrations = [
    (uid, entry['timestamp_sec'], entry['timestamp_frame'], entry['narration_text'])
    for uid, video in narrations.items()
    if video['status'] != 'redacted'
    for entry in video['narration_pass_1']['narrations']
]

# Parallel lists, index-aligned with the corpus records.
narration_uid = [record[0] for record in uid_stamp_narrations]
narration_stamp_sec = [record[1] for record in uid_stamp_narrations]
narration_stamp_frame = [record[2] for record in uid_stamp_narrations]
narration_text = [record[3] for record in uid_stamp_narrations]

In [None]:
# Clean the corpus narrations with the same two regex passes used for the training pairs.
# Pass 1: lowercase, drop leading '#x' tags, '#unsure' and stray '#', collapse whitespace.
# Raw strings keep '\s' / '\.' as regex escapes rather than invalid Python string escapes.
df=pd.DataFrame(narration_text).apply(lambda x: x.str.lower().replace({r'^#[a-z]':'',r'^#\s+[a-z]':'',r'^c\s+c':'c','#unsure':'','#':'',r'\s+':' '},regex=True).str.strip())
# Pass 2: spell the stand-alone token 'c' out as 'person c' and drop a trailing period.
df=df.apply(lambda x: x.str.lower().replace({r'\sc\s':' person c ', r'^c\s':'person c ', r'\sc$':' person c', r'\.$':''},regex=True))
narration_text_proceeded=df.iloc[:,0].tolist()

### 3.2 semantic search (compared with pre-trained model)

In [None]:
# Load the fine-tuned model saved under 'final_model_path' and embed the cleaned corpus.
final_model = SentenceTransformer('final_model_path')

corpus_embeddings = final_model.encode(
    narration_text_proceeded,
    convert_to_tensor=True,
)

In [None]:
# Baseline: embed the same cleaned corpus with the original pre-trained checkpoint.
pretrain_model = SentenceTransformer('all-mpnet-base-v2')

pretrain_corpus_embeddings = pretrain_model.encode(
    narration_text_proceeded,
    convert_to_tensor=True,
)

In [None]:
from collections import defaultdict

def corpus_search(model, query, corpus_embeddings, top_k):
    """Retrieve the top_k corpus narrations for a query and group them by video uid.

    Only the hits of the first query (``hits[0]``) are used. The module-level
    lists ``narration_uid``, ``narration_stamp_sec`` and ``narration_text`` are
    indexed by each hit's ``corpus_id``.

    Returns:
        defaultdict(list) mapping uid -> [(timestamp_sec, narration_text, score), ...]
        in the order the hits were returned.
    """
    encoded_query = model.encode(query, convert_to_tensor=True)
    ranked = util.semantic_search(encoded_query, corpus_embeddings, score_function=util.cos_sim, top_k=top_k)

    # aggregate the k narrations per video
    uid_dict = defaultdict(list)
    for hit in ranked[0]:
        corpus_id = hit['corpus_id']
        entry = (narration_stamp_sec[corpus_id], narration_text[corpus_id], hit['score'])
        uid_dict[narration_uid[corpus_id]].append(entry)

    return uid_dict

Compare the search results with those from the pre-trained model

In [None]:
query = ['eating a meal']

# fine-tuned model: top 10000 narrations, grouped per video
output = corpus_search(final_model, query, corpus_embeddings, 10000)

# count the videos whose first recorded hit scores above 0.6
videos_above_cutoff = sum(1 for hits in output.values() if hits[0][2] > 0.6)
print('# of videos with cosine similarity > 0.6: ', videos_above_cutoff)

[(uid, hits[0]) for uid, hits in output.items()]


In [None]:
query = ['eating a meal']

# pre-trained baseline: top 10000 narrations, grouped per video
output = corpus_search(pretrain_model, query, pretrain_corpus_embeddings, 10000)

# count the videos whose first recorded hit scores above 0.6
videos_above_cutoff = sum(1 for hits in output.values() if hits[0][2] > 0.6)
print('# of videos with cosine similarity > 0.6: ', videos_above_cutoff)

[(uid, hits[0]) for uid, hits in output.items()]


In [None]:
query = ['construction']

# fine-tuned model: top 10000 narrations, grouped per video
output = corpus_search(final_model, query, corpus_embeddings, 10000)

# count the videos whose first recorded hit scores above 0.6
videos_above_cutoff = sum(1 for hits in output.values() if hits[0][2] > 0.6)
print('# of videos with cosine similarity > 0.6: ', videos_above_cutoff)

[(uid, hits[0]) for uid, hits in output.items()]

In [None]:
query = ['construction']

# pre-trained baseline: top 10000 narrations, grouped per video
output = corpus_search(pretrain_model, query, pretrain_corpus_embeddings, 10000)

# count the videos whose first recorded hit scores above 0.6
videos_above_cutoff = sum(1 for hits in output.values() if hits[0][2] > 0.6)
print('# of videos with cosine similarity > 0.6: ', videos_above_cutoff)

[(uid, hits[0]) for uid, hits in output.items()]

In [None]:
query = ['tree']

# fine-tuned model: top 10000 narrations, grouped per video
output = corpus_search(final_model, query, corpus_embeddings, 10000)

# count the videos whose first recorded hit scores above 0.6
videos_above_cutoff = sum(1 for hits in output.values() if hits[0][2] > 0.6)
print('# of videos with cosine similarity > 0.6: ', videos_above_cutoff)

[(uid, hits[0]) for uid, hits in output.items()]

In [None]:
query = ['tree']

# pre-trained baseline: top 10000 narrations, grouped per video
output = corpus_search(pretrain_model, query, pretrain_corpus_embeddings, 10000)

# count the videos whose first recorded hit scores above 0.6
videos_above_cutoff = sum(1 for hits in output.values() if hits[0][2] > 0.6)
print('# of videos with cosine similarity > 0.6: ', videos_above_cutoff)

[(uid, hits[0]) for uid, hits in output.items()]