In this notebook, we attempt to find semantically similar classes between the following pairs of datasets: 

> 1) VGGSound - UCF101

> 2) VGGSound - HMDB51

> 3) AudioSet - UCF101

> 4) AudioSet - HMDB51

> 5) Kinetics400 - UCF101

> 6) Kinetics400 - HMDB51

In each pair, the first dataset is the pretraining dataset and the second is the target (downstream) dataset. Finding such overlaps can help establish a connection between the pretraining dataset (in a self-supervised learning setting) and the downstream datasets.

### 1) Install [SentenceTransformers](https://www.sbert.net/)

In [None]:
!pip install -U sentence-transformers


### 2) Download and extract files

In [None]:
%%bash
wget -q https://www.robots.ox.ac.uk/~vgg/data/vggsound/vggsound.csv

In [None]:
%%bash
wget -q --no-check-certificate https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip
unzip -q UCF101TrainTestSplits-RecognitionTask.zip
mv ucfTrainTestlist/classInd.txt ./ucf101_classes.txt

rm UCF101TrainTestSplits-RecognitionTask.zip
rm -rf ucfTrainTestlist/

In [None]:
%%bash
wget -q http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar
unrar x test_train_splits.rar -inul
rm test_train_splits.rar

touch hmdb51_classes.txt
cd testTrainMulti_7030_splits/
for filename in ./*_split1.txt; do
    filename=${filename##*/}
    echo "$filename" >> "../hmdb51_classes.txt"
done

In [None]:
!rm -rf testTrainMulti_7030_splits/

### 3) Extract class names for each dataset

In [None]:
import csv
import numpy as np
import re
import string

In [None]:
vggsound_classes = []
with open('vggsound.csv', 'r') as f:
    reader = csv.reader(f)
    for line in reader:
        vggsound_classes.append(line[-2])
vggsound_classes = np.unique(vggsound_classes).tolist()
print(len(vggsound_classes))
# remove punctuation
for i, cls in enumerate(vggsound_classes):
    vggsound_classes[i] = cls.translate(str.maketrans('', '', string.punctuation))

309
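The punctuation removal in the cell above relies on `str.maketrans('', '', string.punctuation)`: the third argument lists characters to delete, so `translate` drops them without inserting spaces. A minimal standalone sketch; the helper name `strip_punctuation` and the sample labels are illustrative, not part of the notebook. Note that `string.punctuation` only covers ASCII, so non-ASCII marks such as curly apostrophes would survive.

```python
import string

# translation table that maps every ASCII punctuation character to None (deletion)
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)

def strip_punctuation(label: str) -> str:
    """Hypothetical helper: remove ASCII punctuation from a class label."""
    return label.translate(_PUNCT_TABLE)

print(strip_punctuation("alligators, crocodiles hissing"))  # alligators crocodiles hissing
print(strip_punctuation("massaging person's head"))         # massaging persons head
```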


In [None]:
ucf101_classes = []
with open('ucf101_classes.txt', 'r') as f:
    for line in f.readlines():
        ucf101_classes.append(line.rstrip().split()[1])
print(len(ucf101_classes))
# split class names into multiple tokens
for idx, cls in enumerate(ucf101_classes):
    ucf101_classes[idx] = (' '.join(re.findall('[A-Z][a-z]*', cls))).lower()

101
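Each line of `classInd.txt` has the form `<index> <ClassName>` with CamelCase names, which the cell above splits into lowercase words. A standalone sketch of that step (the helper name `split_camel_case` is hypothetical); it assumes every word starts with one capital letter followed by lowercase letters, so a name containing an acronym or digits would not split cleanly.

```python
import re

def split_camel_case(name: str) -> str:
    """Hypothetical helper mirroring the cell above: 'ApplyEyeMakeup' -> 'apply eye makeup'."""
    # each token is one capital letter followed by any run of lowercase letters
    return ' '.join(re.findall('[A-Z][a-z]*', name)).lower()

# classInd.txt lines look like "1 ApplyEyeMakeup"; the second field is the class name
line = "1 ApplyEyeMakeup\n"
print(split_camel_case(line.rstrip().split()[1]))  # apply eye makeup
```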


In [None]:
hmdb51_classes = []
with open('hmdb51_classes.txt', 'r') as f:
    for line in f.readlines():
        res = line.rstrip().split('_')
        idx = res.index('test')
        hmdb51_classes.append(' '.join(res[:idx]))
print(len(hmdb51_classes))

51
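The HMDB51 class names come from the split filenames collected by the bash loop earlier (e.g. `brush_hair_test_split1.txt`): the name is split on `_` and every token before `test` is kept. A standalone sketch of the same rule, with a hypothetical helper name; `list.index` raises `ValueError` if a filename has no `test` token, which would flag an unexpected file.

```python
def hmdb_class_from_filename(filename: str) -> str:
    """Hypothetical helper: 'brush_hair_test_split1.txt' -> 'brush hair'."""
    tokens = filename.rstrip().split('_')
    # all tokens before the 'test' marker form the class name
    return ' '.join(tokens[:tokens.index('test')])

print(hmdb_class_from_filename("brush_hair_test_split1.txt"))  # brush hair
```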


In [None]:
# print a few class names
print(vggsound_classes[:5], '\n', ucf101_classes[:5], '\n', hmdb51_classes[:5])

['air conditioning noise', 'air horn', 'airplane', 'airplane flyby', 'alarm clock ringing'] 
 ['apply eye makeup', 'apply lipstick', 'archery', 'baby crawling', 'balance beam'] 
 ['brush hair', 'cartwheel', 'catch', 'chew', 'clap']


### 4) Initialize pretrained LMs, extract embeddings and calculate similarities

In [None]:
from sentence_transformers import SentenceTransformer, util

Two models are used here: 1) [MiniLM](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) and 2) [Phrase-BERT](https://huggingface.co/whaleloops/phrase-bert).

In [None]:
mini_lm = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
phrase_bert = SentenceTransformer('whaleloops/phrase-bert')


In [None]:
def MiniLM_similarities(query_class, target_classes, topn=5):
    """Return the `topn` target classes most similar to `query_class` under MiniLM,
    together with their cosine similarities in descending order."""

    # embedding for the query class
    q_emb = mini_lm.encode(query_class, convert_to_tensor=True)
    # embeddings for all target classes, encoded in one batch instead of one call per class
    t_embs = mini_lm.encode(target_classes, convert_to_tensor=True)

    # cosine similarity between the query and every target class, shape [len(target_classes)]
    all_sims = util.pytorch_cos_sim(q_emb, t_embs)[0].cpu().numpy()

    order = np.argsort(-all_sims)[:topn]  # indices sorted by descending similarity
    trg = [target_classes[i] for i in order]

    return trg, all_sims[order].tolist()

In [None]:
import torch
import torch.nn as nn

def PhraseBERT_similarities(query_emb, target_embs, topn=5):
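    """
    Args:
        query_emb: Embedding for the query class, numpy.ndarray of shape [embedding_dim]
        target_embs: Embeddings for all target classes, numpy.ndarray of shape [N x embedding_dim],
            where N is the total number of target classes
        topn: Number of most similar target classes to return

    Returns:
        indices: torch.Tensor of shape [topn] with the indices of the most similar target classes
        max_sims: torch.Tensor of shape [topn] with the corresponding cosine similarities (descending)
    """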

    query_emb = torch.tensor(query_emb).unsqueeze(0)
    target_embs = torch.tensor(target_embs)

    cos_sim = nn.CosineSimilarity(dim=1)
    similarities = cos_sim(query_emb, target_embs)
    max_sims, indices = torch.topk(similarities, topn)

    return indices, max_sims
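Both helpers rank target classes by cosine similarity, cos(u, v) = u·v / (‖u‖ ‖v‖). The sketch below is an illustrative NumPy re-implementation of that ranking (not the notebook's code), using toy 2-D vectors in place of sentence embeddings so it runs without loading a model. Unlike `nn.CosineSimilarity`, it adds no small epsilon to the norms, so it assumes no embedding is all zeros.

```python
import numpy as np

def topn_cosine(query_emb, target_embs, topn=3):
    """Rank the rows of target_embs [N x D] by cosine similarity to query_emb [D]."""
    q = query_emb / np.linalg.norm(query_emb)                            # unit-length query
    t = target_embs / np.linalg.norm(target_embs, axis=1, keepdims=True)  # unit-length rows
    sims = t @ q                        # cosine similarity per target, shape [N]
    order = np.argsort(-sims)[:topn]    # indices in descending order of similarity
    return order, sims[order]

query = np.array([1.0, 0.0])
targets = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
idx, sims = topn_cosine(query, targets, topn=2)
print(idx, sims)
```

Here target 2 points in the same direction as the query (similarity 1.0) and target 1 is at 45 degrees (similarity about 0.7071), so they are returned in that order; vector length does not affect the ranking.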

#### VGGSound - UCF101 pair

In [None]:
TOPN_RESULTS = 3

In [None]:
# Results using MiniLM for VGGSound - UCF101 pair

trg_classes_minilm, sims_minilm = [], []

for query_cls in vggsound_classes:
    trg, sim = MiniLM_similarities(query_cls, ucf101_classes, topn=TOPN_RESULTS)
    trg_classes_minilm.append(trg)
    sims_minilm.append(sim)

In [None]:
# Results using Phrase-BERT for VGGSound - UCF101 pair

vgg_embs = phrase_bert.encode(vggsound_classes)
ucf_embs = phrase_bert.encode(ucf101_classes)  # get embeddings

trg_classes_phrase, sims_phrase = [], []

for q_emb in vgg_embs:
    indices, sims = PhraseBERT_similarities(q_emb, ucf_embs, topn=TOPN_RESULTS)
    sims_phrase.append(sims)
    trg_classes_phrase.append([ucf101_classes[idx] for idx in indices])

In [None]:
from prettytable import PrettyTable

# pretty print results
t = PrettyTable(["Query class", "Target class (MiniLM)", "Similarity (MiniLM)", "Target class (Phrase-BERT)", "Similarity (Phrase-BERT)"])

for q, t1, s1, t2, s2 in zip(vggsound_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if len(s1) > 1:
        t.add_row([q, '\n'.join(t1), '\n'.join([f"{s:.4f}" for s in s1]), '\n'.join(t2), '\n'.join([f"{s:.4f}" for s in s2])])
    else:
        t.add_row([q, t1[0], f"{s1[0]:.4f}", t2[0], f"{s2[0]:.4f}"])
    
    t.add_row(['', '', '', '', ''])

print(t)

+-----------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|               Query class               | Target class (MiniLM) | Similarity (MiniLM) | Target class (Phrase-BERT) | Similarity (Phrase-BERT) |
+-----------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|          air conditioning noise         |     mopping floor     |        0.1746       |           skijet           |          0.6056          |
|                                         |         lunges        |        0.1445       |      blowing candles       |          0.5895          |
|                                         |     blow dry hair     |        0.1437       |      floor gymnastics      |          0.5877          |
|                                         |                       |                     |                            |      

Next, we filter the results shown in the table above. A predicted target class is kept only if both of the following criteria hold:

- The similarity between the query class and the top-1 prediction of the MiniLM model exceeds a predefined threshold (set empirically; in the cells below, 0.47 when UCF101 is the target dataset and 0.46 for HMDB51).

- The class appears in the top-N predictions of both models, i.e. it belongs to their intersection, which must therefore be non-empty.

Finally, a few outlier target classes are removed manually (e.g. skijet, pizza tossing and shotput for UCF101).
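The two criteria and the manual outlier list can be sketched as a single pure-Python function. The name `filter_targets` is hypothetical; one deliberate difference from the notebook cells is that surviving classes keep MiniLM's ranking order instead of the arbitrary order of a set intersection. The example inputs are the Kinetics-400 query "abseiling" and the VGGSound query "air conditioning noise", taken from the result tables in this notebook.

```python
def filter_targets(top_minilm, sims_minilm, top_phrase, threshold, outliers=()):
    """Hypothetical helper applying the two filtering criteria.

    top_minilm / top_phrase: top-N target class names from each model (best first)
    sims_minilm: MiniLM similarities aligned with top_minilm (descending)
    """
    # criterion 1: the MiniLM top-1 similarity must exceed the threshold (strict >)
    if not top_minilm or sims_minilm[0] <= threshold:
        return []
    # criterion 2: keep only classes predicted by both models, minus manual outliers
    phrase_set = set(top_phrase)
    return [c for c in top_minilm if c in phrase_set and c not in outliers]

# abseiling (Kinetics-400 -> UCF101): top-1 0.5721 > 0.47, models agree on 'rope climbing'
print(filter_targets(["rafting", "rope climbing", "jump rope"], [0.5721, 0.4401, 0.4323],
                     ["rope climbing", "long jump", "push ups"], threshold=0.47))
```

For "air conditioning noise" the MiniLM top-1 similarity is only 0.1746, so the function returns an empty list, matching the `-` entry in the final VGGSound - UCF101 table.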

In [None]:
# print final results

sim_threshold = 0.47
table2 = PrettyTable(["Query class", "Final target class(es)"])

final_targets = []
for q, t1, s1, t2, s2 in zip(vggsound_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if s1[0] > sim_threshold and len(set(t1) & set(t2)) != 0:
        # manually filter some outlier target classes (optional step)
        res = [temp_t for temp_t in set(t1) & set(t2) if temp_t not in ['skijet', 'shotput', 'pizza tossing']]
        if len(res) == 0:
            table2.add_row([q, '-'])
        else:
            table2.add_row([q, '\n'.join(res)])
            final_targets.extend(res)
    else:
        table2.add_row([q, '-'])
    table2.add_row(['', ''])

print(table2)

+-----------------------------------------+------------------------+
|               Query class               | Final target class(es) |
+-----------------------------------------+------------------------+
|          air conditioning noise         |           -            |
|                                         |                        |
|                 air horn                |           -            |
|                                         |                        |
|                 airplane                |           -            |
|                                         |                        |
|              airplane flyby             |           -            |
|                                         |                        |
|           alarm clock ringing           |           -            |
|                                         |                        |
|      alligators crocodiles hissing      |           -            |
|                                 

In [None]:
# print final outcome

print(f"Approximately, {len(np.unique(final_targets))}/101 classes of UCF-101 dataset are similar to those of VGGSound dataset.")

Approximately, 54/101 classes of UCF-101 dataset are similar to those of VGGSound dataset.


In [None]:
# print rest of the classes from UCF-101 + save them in txt

rest_classes = set(ucf101_classes) - set(final_targets)
print("Classes from UCF-101 dataset that are not covered by the pretraining dataset in total: {}/101.".format(len(rest_classes)))

with open('ucf_rest_classes.txt', 'w') as f:
    for cl in rest_classes:
        f.write(cl + '\n')

Classes from UCF-101 dataset that are not covered by the pretraining dataset in total: 47/101.


#### VGGSound - HMDB51 pair

In [None]:
TOPN_RESULTS = 3

In [None]:
# Results using MiniLM for VGGSound - HMDB51 pair

trg_classes_minilm, sims_minilm = [], []

for query_cls in vggsound_classes:
    trg, sim = MiniLM_similarities(query_cls, hmdb51_classes, topn=TOPN_RESULTS)
    trg_classes_minilm.append(trg)
    sims_minilm.append(sim)

In [None]:
# Results using Phrase-BERT for VGGSound - HMDB51 pair

vgg_embs = phrase_bert.encode(vggsound_classes)
hmdb_embs = phrase_bert.encode(hmdb51_classes)  # get embeddings

trg_classes_phrase, sims_phrase = [], []

for q_emb in vgg_embs:
    indices, sims = PhraseBERT_similarities(q_emb, hmdb_embs, topn=TOPN_RESULTS)
    sims_phrase.append(sims)
    trg_classes_phrase.append([hmdb51_classes[idx] for idx in indices])

In [None]:
from prettytable import PrettyTable

# pretty print results
t = PrettyTable(["Query class", "Target class (MiniLM)", "Similarity (MiniLM)", "Target class (Phrase-BERT)", "Similarity (Phrase-BERT)"])

for q, t1, s1, t2, s2 in zip(vggsound_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if len(s1) > 1:
        t.add_row([q, '\n'.join(t1), '\n'.join([f"{s:.4f}" for s in s1]), '\n'.join(t2), '\n'.join([f"{s:.4f}" for s in s2])])
    else:
        t.add_row([q, t1[0], f"{s1[0]:.4f}", t2[0], f"{s2[0]:.4f}"])
    
    t.add_row(['', '', '', '', ''])

print(t)

+-----------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|               Query class               | Target class (MiniLM) | Similarity (MiniLM) | Target class (Phrase-BERT) | Similarity (Phrase-BERT) |
+-----------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|          air conditioning noise         |         smoke         |        0.1755       |           smoke            |          0.6367          |
|                                         |       fall floor      |        0.1234       |            wave            |          0.6052          |
|                                         |          wave         |        0.1221       |         handstand          |          0.5716          |
|                                         |                       |                     |                            |      

In [None]:
# print final results

sim_threshold = 0.46
table2 = PrettyTable(["Query class", "Final target class(es)"])

final_targets = []
for q, t1, s1, t2, s2 in zip(vggsound_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if s1[0] > sim_threshold and len(set(t1) & set(t2)) != 0:
        # manually filter some outlier target classes (optional step)
        res = [temp_t for temp_t in set(t1) & set(t2) if temp_t not in ['hug', 'shoot bow', 'cartwheel', 'smoke', 'somersault', 'wave', 'pushup', 'handstand']]
        if len(res) == 0:
            table2.add_row([q, '-'])
        else:
            table2.add_row([q, '\n'.join(res)])
            final_targets.extend(res)
    else:
        table2.add_row([q, '-'])
    table2.add_row(['', ''])

print(table2)

+-----------------------------------------+------------------------+
|               Query class               | Final target class(es) |
+-----------------------------------------+------------------------+
|          air conditioning noise         |           -            |
|                                         |                        |
|                 air horn                |           -            |
|                                         |                        |
|                 airplane                |           -            |
|                                         |                        |
|              airplane flyby             |           -            |
|                                         |                        |
|           alarm clock ringing           |           -            |
|                                         |                        |
|      alligators crocodiles hissing      |           -            |
|                                 

In [None]:
# print final outcome

print(f"Approximately, {len(np.unique(final_targets))}/51 classes of HMDB-51 dataset are similar to those of VGGSound dataset.")

Approximately, 22/51 classes of HMDB-51 dataset are similar to those of VGGSound dataset.


In [None]:
# print rest of the classes from HMDB-51 + save them in txt

rest_classes = set(hmdb51_classes) - set(final_targets)
print("Classes from HMDB-51 dataset that are not covered by the pretraining dataset in total: {}/51.".format(len(rest_classes)))

with open('hmdb_rest_classes.txt', 'w') as f:
    for cl in rest_classes:
        f.write(cl + '\n')

Classes from HMDB-51 dataset that are not covered by the pretraining dataset in total: 29/51.


#### Same experiments using the [Kinetics-400](https://github.com/cvdfoundation/kinetics-dataset#kinetics-400-info) and [AudioSet](https://raw.githubusercontent.com/facebookresearch/AVID-CMA/main/datasets/cache/audioset/class_labels_indices.csv) classes

More information about the AudioSet ontology can be found [here](https://research.google.com/audioset/ontology/index.html).



In [None]:
!wget https://raw.githubusercontent.com/facebookresearch/AVID-CMA/main/datasets/cache/audioset/class_labels_indices.csv

--2022-03-06 15:56:57--  https://raw.githubusercontent.com/facebookresearch/AVID-CMA/main/datasets/cache/audioset/class_labels_indices.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 14675 (14K) [text/plain]
Saving to: ‘class_labels_indices.csv’


2022-03-06 15:56:57 (23.2 MB/s) - ‘class_labels_indices.csv’ saved [14675/14675]



In [None]:
audioset_classes = []

with open('class_labels_indices.csv', 'r') as f:
    reader = csv.reader(f)
    next(reader, None)
    for line in reader:
        audioset_classes.append(line[-1].rstrip().lower())

audioset_classes = np.unique(audioset_classes)
print(len(audioset_classes))
# remove punctuation
for i, cls in enumerate(audioset_classes):
    audioset_classes[i] = cls.translate(str.maketrans('', '', string.punctuation))

print(audioset_classes[:5])

527
['a capella' 'accelerating revving vroom' 'accordion' 'acoustic guitar'
 'afrobeat']
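AudioSet display names can contain commas (for example, the class printed above as "accelerating revving vroom"), so they are quoted in the CSV and must be read with `csv.reader` rather than a plain `split(',')`; `line[-1]` then picks the `display_name` column. The rows below are an illustrative sample in the `index,mid,display_name` layout, not a verified copy of the file.

```python
import csv
import io

# illustrative rows in the index,mid,display_name layout of class_labels_indices.csv
sample = (
    'index,mid,display_name\n'
    '0,/m/09x0r,"Speech"\n'
    '1,/m/05zppz,"Male speech, man speaking"\n'
)

reader = csv.reader(io.StringIO(sample))
next(reader, None)  # skip the header row, as the notebook does
names = [row[-1].rstrip().lower() for row in reader]
print(names)  # ['speech', 'male speech, man speaking']

# a naive comma split cuts the quoted display name in two
naive = sample.splitlines()[2].split(',')[-1]
```

With the naive split, `naive` ends up holding only the fragment after the last comma, so the class name would be corrupted.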


In [None]:
!wget https://s3.amazonaws.com/kinetics/400/annotations/train.csv

--2022-03-06 16:25:37--  https://s3.amazonaws.com/kinetics/400/annotations/train.csv
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.81.150
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.81.150|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10529265 (10M) [text/csv]
Saving to: ‘train.csv’


2022-03-06 16:25:38 (45.8 MB/s) - ‘train.csv’ saved [10529265/10529265]



In [None]:
kinetics_classes = []

with open('train.csv', 'r') as f:
    reader = csv.reader(f)
    next(reader, None)
    for line in reader:
        kinetics_classes.append(line[0].rstrip().lower())

kinetics_classes = np.unique(kinetics_classes)
print(len(kinetics_classes))
# remove punctuation
for i, cls in enumerate(kinetics_classes):
    kinetics_classes[i] = cls.translate(str.maketrans('', '', string.punctuation))

print(kinetics_classes[:5])

400
['abseiling' 'air drumming' 'answering questions' 'applauding'
 'applying cream']
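`np.unique` returns a fixed-width unicode array (dtype `<U<max length>`), and the cells above write the cleaned labels back into it in place. That is safe only because stripping punctuation never makes a string longer; NumPy silently truncates longer strings on assignment. A small sketch with illustrative labels showing both behaviours and converting to a plain list afterwards:

```python
import string
import numpy as np

# np.unique sorts the labels and stores them in a fixed-width unicode array
labels = np.unique(["tying knot (not on a tie)", "abseiling"])
table = str.maketrans('', '', string.punctuation)
for i, cls in enumerate(labels):
    labels[i] = cls.translate(table)  # safe: removing characters only shortens strings
labels = labels.tolist()              # plain Python list avoids the fixed-width pitfall
print(labels)  # ['abseiling', 'tying knot not on a tie']

# writing a longer string into a fixed-width array is silently truncated
demo = np.array(["ab"])  # dtype '<U2'
demo[0] = "abcd"
print(demo[0])  # ab
```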


##### AudioSet - UCF101 pair

In [None]:
TOPN_RESULTS = 3

In [None]:
trg_classes_minilm, sims_minilm = [], []

for query_cls in audioset_classes:
    trg, sim = MiniLM_similarities(query_cls, ucf101_classes, topn=TOPN_RESULTS)
    trg_classes_minilm.append(trg)
    sims_minilm.append(sim)

In [None]:
audioset_embs = phrase_bert.encode(audioset_classes)
ucf_embs = phrase_bert.encode(ucf101_classes)  # get embeddings

trg_classes_phrase, sims_phrase = [], []

for q_emb in audioset_embs:
    indices, sims = PhraseBERT_similarities(q_emb, ucf_embs, topn=TOPN_RESULTS)
    sims_phrase.append(sims)
    trg_classes_phrase.append([ucf101_classes[idx] for idx in indices])

In [None]:
from prettytable import PrettyTable

# pretty print results
t = PrettyTable(["Query class", "Target class (MiniLM)", "Similarity (MiniLM)", "Target class (Phrase-BERT)", "Similarity (Phrase-BERT)"])

for q, t1, s1, t2, s2 in zip(audioset_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if len(s1) > 1:
        t.add_row([q, '\n'.join(t1), '\n'.join([f"{s:.4f}" for s in s1]), '\n'.join(t2), '\n'.join([f"{s:.4f}" for s in s2])])
    else:
        t.add_row([q, t1[0], f"{s1[0]:.4f}", t2[0], f"{s2[0]:.4f}"])
    
    t.add_row(['', '', '', '', ''])

print(t)

+----------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|              Query class               | Target class (MiniLM) | Similarity (MiniLM) | Target class (Phrase-BERT) | Similarity (Phrase-BERT) |
+----------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|               a capella                |         lunges        |        0.3196       |       playing tabla        |          0.7119          |
|                                        |     playing tabla     |        0.2994       |       playing violin       |          0.7055          |
|                                        |        surfing        |        0.2966       |         salsa spin         |          0.6963          |
|                                        |                       |                     |                            |             

In [None]:
sim_threshold = 0.47
table2 = PrettyTable(["Query class", "Final target class(es)"])

final_targets = []
for q, t1, s1, t2, s2 in zip(audioset_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if s1[0] > sim_threshold and len(set(t1) & set(t2)) != 0:
        # manually filter some outlier target classes (optional step)
        res = [temp_t for temp_t in set(t1) & set(t2) if temp_t not in ['skijet', 'shotput', 'pizza tossing']]
        if len(res) == 0:
            table2.add_row([q, '-'])
        else:
            table2.add_row([q, '\n'.join(res)])
            final_targets.extend(res)
    else:
        table2.add_row([q, '-'])
    table2.add_row(['', ''])

print(table2)

+----------------------------------------+------------------------+
|              Query class               | Final target class(es) |
+----------------------------------------+------------------------+
|               a capella                |           -            |
|                                        |                        |
|       accelerating revving vroom       |           -            |
|                                        |                        |
|               accordion                |           -            |
|                                        |                        |
|            acoustic guitar             |     playing cello      |
|                                        |     playing violin     |
|                                        |     playing guitar     |
|                                        |                        |
|                afrobeat                |           -            |
|                                        |      

In [None]:
# print final outcome

print(f"Approximately, {len(np.unique(final_targets))}/101 classes of UCF-101 dataset are similar to those of AudioSet dataset.")

Approximately, 50/101 classes of UCF-101 dataset are similar to those of AudioSet dataset.


##### AudioSet - HMDB51 pair

In [None]:
TOPN_RESULTS = 3

In [None]:
trg_classes_minilm, sims_minilm = [], []

for query_cls in audioset_classes:
    trg, sim = MiniLM_similarities(query_cls, hmdb51_classes, topn=TOPN_RESULTS)
    trg_classes_minilm.append(trg)
    sims_minilm.append(sim)

In [None]:
audioset_embs = phrase_bert.encode(audioset_classes)
hmdb_embs = phrase_bert.encode(hmdb51_classes)  # get embeddings

trg_classes_phrase, sims_phrase = [], []

for q_emb in audioset_embs:
    indices, sims = PhraseBERT_similarities(q_emb, hmdb_embs, topn=TOPN_RESULTS)
    sims_phrase.append(sims)
    trg_classes_phrase.append([hmdb51_classes[idx] for idx in indices])

In [None]:
from prettytable import PrettyTable

# pretty print results
t = PrettyTable(["Query class", "Target class (MiniLM)", "Similarity (MiniLM)", "Target class (Phrase-BERT)", "Similarity (Phrase-BERT)"])

for q, t1, s1, t2, s2 in zip(audioset_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if len(s1) > 1:
        t.add_row([q, '\n'.join(t1), '\n'.join([f"{s:.4f}" for s in s1]), '\n'.join(t2), '\n'.join([f"{s:.4f}" for s in s2])])
    else:
        t.add_row([q, t1[0], f"{s1[0]:.4f}", t2[0], f"{s2[0]:.4f}"])
    
    t.add_row(['', '', '', '', ''])

print(t)

+----------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|              Query class               | Target class (MiniLM) | Similarity (MiniLM) | Target class (Phrase-BERT) | Similarity (Phrase-BERT) |
+----------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|               a capella                |       flic flac       |        0.4216       |         flic flac          |          0.6223          |
|                                        |         climb         |        0.2727       |           smoke            |          0.5622          |
|                                        |          turn         |        0.2581       |           situp            |          0.5609          |
|                                        |                       |                     |                            |             

In [None]:
# print final results

sim_threshold = 0.46
table2 = PrettyTable(["Query class", "Final target class(es)"])

final_targets = []
for q, t1, s1, t2, s2 in zip(audioset_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if s1[0] > sim_threshold and len(set(t1) & set(t2)) != 0:
        # manually filter some outlier target classes (optional step)
        res = [temp_t for temp_t in set(t1) & set(t2) if temp_t not in ['hug', 'shoot bow', 'cartwheel', 'smoke', 'somersault', 'wave', 'pushup', 'handstand']]
        if len(res) == 0:
            table2.add_row([q, '-'])
        else:
            table2.add_row([q, '\n'.join(res)])
            final_targets.extend(res)
    else:
        table2.add_row([q, '-'])
    table2.add_row(['', ''])

print(table2)

+----------------------------------------+------------------------+
|              Query class               | Final target class(es) |
+----------------------------------------+------------------------+
|               a capella                |           -            |
|                                        |                        |
|       accelerating revving vroom       |           -            |
|                                        |                        |
|               accordion                |           -            |
|                                        |                        |
|            acoustic guitar             |           -            |
|                                        |                        |
|                afrobeat                |           -            |
|                                        |                        |
|               air brake                |           -            |
|                                        |      

In [None]:
print(f"Approximately, {len(np.unique(final_targets))}/51 classes of HMDB-51 dataset are similar to those of AudioSet dataset.")

Approximately, 28/51 classes of HMDB-51 dataset are similar to those of AudioSet dataset.


##### Kinetics-400 - UCF101 pair

In [None]:
TOPN_RESULTS = 3

In [None]:
trg_classes_minilm, sims_minilm = [], []

for query_cls in kinetics_classes:
    trg, sim = MiniLM_similarities(query_cls, ucf101_classes, topn=TOPN_RESULTS)
    trg_classes_minilm.append(trg)
    sims_minilm.append(sim)

In [None]:
kinetics_embs = phrase_bert.encode(kinetics_classes)
ucf_embs = phrase_bert.encode(ucf101_classes)  # get embeddings

trg_classes_phrase, sims_phrase = [], []

for q_emb in kinetics_embs:
    indices, sims = PhraseBERT_similarities(q_emb, ucf_embs, topn=TOPN_RESULTS)
    sims_phrase.append(sims)
    trg_classes_phrase.append([ucf101_classes[idx] for idx in indices])

In [None]:
from prettytable import PrettyTable

# pretty print results
t = PrettyTable(["Query class", "Target class (MiniLM)", "Similarity (MiniLM)", "Target class (Phrase-BERT)", "Similarity (Phrase-BERT)"])

for q, t1, s1, t2, s2 in zip(kinetics_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if len(s1) > 1:
        t.add_row([q, '\n'.join(t1), '\n'.join([f"{s:.4f}" for s in s1]), '\n'.join(t2), '\n'.join([f"{s:.4f}" for s in s2])])
    else:
        t.add_row([q, t1[0], f"{s1[0]:.4f}", t2[0], f"{s2[0]:.4f}"])
    
    t.add_row(['', '', '', '', ''])

print(t)

+---------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|              Query class              | Target class (MiniLM) | Similarity (MiniLM) | Target class (Phrase-BERT) | Similarity (Phrase-BERT) |
+---------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|               abseiling               |        rafting        |        0.5721       |       rope climbing        |          0.7466          |
|                                       |     rope climbing     |        0.4401       |         long jump          |          0.7314          |
|                                       |       jump rope       |        0.4323       |          push ups          |          0.7262          |
|                                       |                       |                     |                            |                    

In [None]:
sim_threshold = 0.47
table2 = PrettyTable(["Query class", "Final target class(es)"])

final_targets = []
for q, t1, s1, t2, s2 in zip(kinetics_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if s1[0] > sim_threshold and len(set(t1) & set(t2)) != 0:
        # no outlier target classes are filtered manually for this pair
        res = list(set(t1) & set(t2))
        if len(res) == 0:
            table2.add_row([q, '-'])
        else:
            table2.add_row([q, '\n'.join(res)])
            final_targets.extend(res)
    else:
        table2.add_row([q, '-'])
    table2.add_row(['', ''])

print(table2)

+---------------------------------------+------------------------+
|              Query class              | Final target class(es) |
+---------------------------------------+------------------------+
|               abseiling               |     rope climbing      |
|                                       |                        |
|              air drumming             |        drumming        |
|                                       |                        |
|          answering questions          |           -            |
|                                       |                        |
|               applauding              |           -            |
|                                       |                        |
|             applying cream            |     shaving beard      |
|                                       |                        |
|                archery                |        archery         |
|                                       |                     
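
The selection rule in the cell above can also be written as a standalone function, which may make it easier to reuse for the other dataset pairs. This is a minimal sketch, assuming each model returns its top-N target names as a list and that the MiniLM similarities are sorted in descending order (as the tables above suggest). `consensus_targets` is a hypothetical helper, not part of the notebook; the example inputs are the "abseiling" row from the Kinetics-400 - UCF101 table.

```python
def consensus_targets(minilm_hits, minilm_sims, phrase_hits, threshold, excluded=()):
    """Return the target classes both models agree on, or [] if the query is rejected.

    Hypothetical helper mirroring the filtering cell above. Assumes minilm_sims is
    sorted in descending order, so minilm_sims[0] is the best MiniLM score.
    """
    if not minilm_sims or minilm_sims[0] <= threshold:
        return []  # best MiniLM match is too weak
    phrase_set = set(phrase_hits)
    # iterate over MiniLM's ranked list (not a set) so the output order is deterministic;
    # drop any manually excluded outlier classes
    return [t for t in minilm_hits if t in phrase_set and t not in excluded]


# "abseiling" row from the Kinetics-400 - UCF101 table
print(consensus_targets(["rafting", "rope climbing", "jump rope"], [0.5721, 0.4401, 0.4323],
                        ["rope climbing", "long jump", "push ups"], threshold=0.47))
```

This returns `['rope climbing']`, matching the final table. The set of returned classes is the same as with the notebook's `set(t1) & set(t2)`; only the ordering differs, since iteration order over a set of strings can change between Python sessions.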

In [None]:
# print final outcome

print(f"Approximately {len(np.unique(final_targets))}/101 UCF-101 classes have a semantically similar counterpart in Kinetics-400.")

Approximately 86/101 UCF-101 classes have a semantically similar counterpart in Kinetics-400.
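
The coverage figure counts each target class once, however many query classes matched it, and divides by the number of target classes. A small sketch of that arithmetic, assuming `final_targets` is a flat list of matched target names as built above; `coverage` is a hypothetical helper, and it uses `set` in place of `np.unique` (both count distinct values).

```python
def coverage(final_targets, n_target_classes):
    """Number and fraction of distinct target classes matched at least once."""
    n_matched = len(set(final_targets))  # duplicates across query classes count once
    return n_matched, n_matched / n_target_classes


# "archery" matched by two different query classes still counts once
matched, frac = coverage(["archery", "drumming", "archery"], 101)
print(f"{matched}/101 ({frac:.1%})")
```

By the same arithmetic, 86/101 corresponds to about 85.1% of UCF-101 classes.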


##### Kinetics-400 - HMDB51 pair

In [None]:
TOPN_RESULTS = 3

In [None]:
trg_classes_minilm, sims_minilm = [], []

for query_cls in kinetics_classes:
    trg, sim = MiniLM_similarities(query_cls, hmdb51_classes, topn=TOPN_RESULTS)
    trg_classes_minilm.append(trg)
    sims_minilm.append(sim)
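
`MiniLM_similarities` is defined earlier in the notebook and is not reproduced here. Assuming it ranks target class names by cosine similarity between sentence embeddings (the usual SentenceTransformers recipe), the ranking step reduces to the pure-Python sketch below. `top_n_cosine` is hypothetical, and the 2-D vectors stand in for real embeddings; with actual models, `sentence_transformers.util.cos_sim` computes the same similarities on tensors.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def top_n_cosine(query_emb, target_embs, target_names, topn=3):
    """Return the topn target names and their similarities, highest first."""
    scored = [(cosine(query_emb, emb), name) for emb, name in zip(target_embs, target_names)]
    scored.sort(key=lambda pair: pair[0], reverse=True)  # stable sort keeps ties in input order
    best = scored[:topn]
    return [name for _, name in best], [sim for sim, _ in best]


names, sims = top_n_cosine([1.0, 0.0], [[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]], ["c", "b", "a"], topn=2)
print(names, [round(s, 4) for s in sims])
```

Here "a" points in exactly the query's direction (similarity 1.0) and "b" is at 45 degrees (about 0.7071), so the call returns `['a', 'b']`; vector length does not affect the ranking.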

In [None]:
# get Phrase-BERT embeddings for all class names once
kinetics_embs = phrase_bert.encode(kinetics_classes)
hmdb_embs = phrase_bert.encode(hmdb51_classes)

trg_classes_phrase, sims_phrase = [], []

for q_emb in kinetics_embs:
    indices, sims = PhraseBERT_similarities(q_emb, hmdb_embs, topn=TOPN_RESULTS)
    sims_phrase.append(sims)
    trg_classes_phrase.append([hmdb51_classes[idx] for idx in indices])

In [None]:
from prettytable import PrettyTable

# pretty print results
t = PrettyTable(["Query class", "Target class (MiniLM)", "Similarity (MiniLM)", "Target class (Phrase-BERT)", "Similarity (Phrase-BERT)"])

for q, t1, s1, t2, s2 in zip(kinetics_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    if len(s1) > 1:
        t.add_row([q, '\n'.join(t1), '\n'.join([f"{s:.4f}" for s in s1]), '\n'.join(t2), '\n'.join([f"{s:.4f}" for s in s2])])
    else:
        t.add_row([q, t1[0], f"{s1[0]:.4f}", t2[0], f"{s2[0]:.4f}"])
    
    t.add_row(['', '', '', '', ''])

print(t)

+---------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|              Query class              | Target class (MiniLM) | Similarity (MiniLM) | Target class (Phrase-BERT) | Similarity (Phrase-BERT) |
+---------------------------------------+-----------------------+---------------------+----------------------------+--------------------------+
|               abseiling               |         climb         |        0.4464       |         somersault         |          0.7696          |
|                                       |       fall floor      |        0.4263       |           pushup           |          0.6752          |
|                                       |          jump         |        0.3909       |           climb            |          0.6238          |
|                                       |                       |                     |                            |                    

In [None]:
# print final results
# Same rule as for UCF-101: keep a query class only if its best MiniLM similarity
# exceeds the threshold and both models share at least one top-N target class.
sim_threshold = 0.46
table2 = PrettyTable(["Query class", "Final target class(es)"])

final_targets = []  # accepted target classes (may repeat across query classes)
for q, t1, s1, t2, s2 in zip(kinetics_classes, trg_classes_minilm, sims_minilm, trg_classes_phrase, sims_phrase):
    common = set(t1) & set(t2)  # target classes proposed by both models
    if s1[0] > sim_threshold and common:
        # manually filter outlier target classes (optional step)
        res = [temp_t for temp_t in common if temp_t not in ['wave', 'handstand']]
        if len(res) == 0:
            table2.add_row([q, '-'])
        else:
            table2.add_row([q, '\n'.join(res)])
            final_targets.extend(res)
    else:
        table2.add_row([q, '-'])
    table2.add_row(['', ''])  # blank spacer row between query classes

print(table2)

+---------------------------------------+------------------------+
|              Query class              | Final target class(es) |
+---------------------------------------+------------------------+
|               abseiling               |           -            |
|                                       |                        |
|              air drumming             |           -            |
|                                       |                        |
|          answering questions          |           -            |
|                                       |                        |
|               applauding              |         smile          |
|                                       |         laugh          |
|                                       |          clap          |
|                                       |                        |
|             applying cream            |           -            |
|                                       |                     

In [None]:
print(f"Approximately {len(np.unique(final_targets))}/51 HMDB-51 classes have a semantically similar counterpart in Kinetics-400.")

Approximately 41/51 HMDB-51 classes have a semantically similar counterpart in Kinetics-400.
