<a href="https://colab.research.google.com/github/komazawa-deep-learning/komazawa-deep-learning.github.io/blob/master/2024notebooks/2024_1025sentence_similarity_by_SBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch

# Pick the compute device: CUDA if available, otherwise the MPS backend, otherwise the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'device:{device}')

import IPython
# True when the current IPython shell is Colab's (detected from the shell object's repr).
isColab = 'google.colab' in str(IPython.get_ipython())

if isColab:

    # Show GPU information
    !nvidia-smi -L

    # Unless termcolor is downgraded before `import bit`, text is not colored on Colab
    !pip install --upgrade termcolor==1.1
    import termcolor

    # NOTE(review): the installs below are unpinned, so results may drift with package releases.
    # openpyxl: presumably for DataFrame.to_excel in the t-SNE plotting helper.
    !pip install --upgrade openpyxl
    !pip install --upgrade pandas
    # fugashi / unidic / ipadic: MeCab bindings and dictionaries, presumably needed by
    # the Japanese BERT tokenizer used later — TODO confirm which dictionary is actually used.
    #!pip install --upgrade fugashi[unidic-lite]
    !pip install --upgrade 'fugashi[unidic]'
    !pip install --upgrade ipadic
    !python -m unidic download
    !pip install transformers
    # jaconv only appears in commented-out normalization code further down.
    !pip install --upgrade jaconv

    !pip install --upgrade xlrd

import os
# NOTE(review): raises KeyError where HOME is unset (e.g. some Windows setups).
HOME = os.environ['HOME']

import sys
from collections import OrderedDict
from termcolor import colored
# japanize_matplotlib: presumably so the Japanese caption labels render in plots.
try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib

from tqdm.notebook import tqdm

# Notebooks have no __file__; set it to the notebook's file name via ipynb_path.
try:
    import ipynb_path
except ImportError:
    !pip install ipynb-path
    import ipynb_path
__file__ = os.path.basename(ipynb_path.get())
print(f'__file__:{__file__}')
print(f'torch.__version__:{torch.__version__}')

In [None]:
import os
import tarfile
import json

if not os.path.exists('STAIR-captions/stair_captions_v1.2.tar.gz'):
    !git clone https://github.com/STAIR-Lab-CIT/STAIR-captions.git
    tgz = tarfile.open('STAIR-captions/stair_captions_v1.2.tar.gz')
    tgz.extractall()

STAIR_dir = '.'
data_ = {'val': {  'fname':os.path.join(STAIR_dir, 'stair_captions_v1.2_val.json')},
        # 'train': {'fname':os.path.join(STAIR_dir, 'stair_captions_v1.2_train.json')},
        }
data_

for _k in data_.keys():
    fname = data_[_k]['fname']
    with open(fname, 'r') as fp:
        data_[_k]['dic'] = json.load(fp)
    print(f"fname:{fname}, data_[_k]['dic']:{len(data_[_k]['dic'])}")

ann = {}
for _k in data_.keys():
    a = data_[_k]['dic']
    for x in a['annotations']:
        imageid, anno_id, caption = x['image_id'], x['id'], x['caption']
        if not imageid in ann:
            ann[imageid] = []
        ann[imageid].append({'anno_id':anno_id,'caption':caption})

image_ids_ = sorted(list(ann.keys()))

# n_samples の数だけデータを印字
n_samples = 5

for imageid in image_ids_[:n_samples]:
    print(ann[imageid])

for imageid in list(ann.keys())[:n_samples]:
    for s in ann[imageid]:
        print(f'画像インデックス(imageid):{imageid}, アノテーション:{s}') # [1])
print(f'len(ann):{len(ann)}')

In [None]:
# Print more of the data.
# idx2flickr_url: image id -> Flickr URL of that image (used below to display it).
idx2flickr_url = {}
for _k in data_.keys():
    for x in data_[_k]['dic']['images']:
        idx = x['id']
        if idx in idx2flickr_url:
            # Report the duplicate; as before, the URL seen last is kept.
            print(f'idx:{idx} duplicated!')
        idx2flickr_url[idx] = x['flickr_url']

keys_ = sorted(list(ann.keys()))
print(data_['val']['dic']['images'][:3])
# A fixed example image. Use .get so that a split without this id does not raise KeyError.
image_id = 391895
print(ann.get(image_id, f'image_id:{image_id} not in ann'))
print(len(ann.keys()))

In [None]:
import numpy as np
import IPython
from IPython.display import Image

# Draw one image id at random, list its captions, then render the image from Flickr.
image_id = np.random.choice([*idx2flickr_url])
print(f'image_id:{image_id}, ann[{image_id}]')

for entry in ann[image_id]:
    print(entry['caption'])

url = idx2flickr_url[image_id]
print(url)
Image(url=url)


In [None]:
from transformers import BertJapaneseTokenizer, BertModel
import torch

class SentenceBertJapanese:
    """Japanese Sentence-BERT encoder that mean-pools BERT token embeddings into one vector per sentence.

    Args:
        model_name_or_path: Hugging Face model id or local directory, passed to
            ``from_pretrained`` for both the tokenizer and the model.
        device: torch device used for inference. Defaults to the notebook-level
            ``device`` chosen in the setup cell. ``None`` falls back to ``"cuda"``
            when available, otherwise ``"cpu"``.
    """

    def __init__(self,
                 model_name_or_path:str = "sonoisa/sentence-bert-base-ja-mean-tokens-v2",
                 device=device):

        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()  # inference only (disables dropout)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average each sentence's token embeddings, ignoring padding positions.

        ``model_output[0]`` holds the per-token embeddings. The attention mask is
        broadcast over the hidden dimension so padded tokens contribute nothing,
        and the clamp avoids division by zero for an all-padding row.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def encode(self, sentences, batch_size=8):
        """Encode sentences into mean-pooled sentence embeddings.

        Args:
            sentences: a sequence of strings. A single ``str`` is treated as one
                sentence rather than being iterated character by character.
            batch_size: number of sentences per forward pass (must be >= 1).

        Returns:
            A CPU ``torch.Tensor`` with one row per sentence and no autograd
            history. An empty input yields a ``(0, hidden_size)`` tensor.

        Raises:
            ValueError: if ``batch_size < 1``.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')

        batch_embeddings = []
        # Inference only: without no_grad, every returned embedding would keep its
        # forward pass's autograd graph (and activations) alive.
        with torch.no_grad():
            for start in range(0, len(sentences), batch_size):
                batch = list(sentences[start:start + batch_size])
                encoded_input = self.tokenizer(batch, padding="longest",
                                               truncation=True, return_tensors="pt").to(self.device)
                model_output = self.model(**encoded_input)
                pooled = self._mean_pooling(model_output, encoded_input["attention_mask"])
                batch_embeddings.append(pooled.to('cpu'))

        if not batch_embeddings:
            # torch.cat/stack reject an empty list; return an empty matrix instead.
            return torch.empty((0, self.model.config.hidden_size))
        return torch.cat(batch_embeddings, dim=0)


# Version 2 of sonoisa's Japanese Sentence-BERT (mean-token pooling).
MODEL_NAME = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"
sbert_model = SentenceBertJapanese(model_name_or_path=MODEL_NAME)

# Quick sanity check on a pair of paraphrases ("runaway AI" / "runaway artificial intelligence").
sentences = ["暴走したAI", "暴走した人工知能"]
sentence_embeddings = sbert_model.encode(sentences=sentences, batch_size=8)
print("Sentence embeddings:", sentence_embeddings)

In [None]:
# Flatten every caption of every image into one list, in ann's insertion order.
sents = [entry['caption'] for captions in ann.values() for entry in captions]
print(f'総文数 len(sents):{len(sents)}')

# Embed only the first top_n captions to keep this demo fast.
top_n = 100
coco_anno = sbert_model.encode(sents[:top_n])
print(coco_anno.size(), len(sents))

In [None]:
%%time
from tqdm.notebook import tqdm
coco = sents
_coco = {}
n_coco = len(coco)
for i in tqdm(range(n_coco)):
    #sent = coco[i]['caption']
    sent = coco[i]
    _coco[i] = {}
    _coco[i]['tokens'] = sbert_model.tokenizer.tokenize(sent)
    _coco[i]['input_ids'] = sbert_model.tokenizer(sent)['input_ids']


In [None]:
# Every caption stays a candidate answer, but only the first len(coco_anno)
# captions have embeddings: row i of coco_vects belongs to coco_sents[i].
# (An earlier variant normalized the captions with jaconv.normalize and
# re-encoded all of them with sbert_model.encode.)
coco_sents = sents
coco_vects = coco_anno.detach().cpu().numpy()

In [None]:
from termcolor import colored
import scipy.spatial

def search_sim_sents(queries:list,
                     answers:list,
                     model:'SentenceBertJapanese',
                     vectors:'np.ndarray | torch.Tensor',
                     top_n:int = 5,
                     verbose:bool=False,
                     skip_self:bool=True,
                    ):
    """For each query sentence, find the most similar candidate sentences.

    Args:
        queries: sentences to look up. They are embedded with ``model.encode``.
        answers: sentences aligned with ``vectors`` (row i of ``vectors`` is the
            embedding of ``answers[i]``). If None, ``queries`` is used.
        model: an encoder exposing ``encode(list_of_str)`` that returns one
            embedding per sentence (e.g. ``SentenceBertJapanese``).
        vectors: 2-D array or tensor of candidate embeddings, one row per candidate.
        top_n: number of neighbours returned per query.
        verbose: print each ``query, answer, similarity`` line.
        skip_self: drop the single nearest hit. This is the query itself when
            the queries are drawn from the candidates (the original behaviour).

    Returns:
        dict mapping each query string to a list of ``(answer, similarity, index)``
        tuples, most similar first. ``similarity`` is ``1 - cosine_distance / 2``,
        i.e. ``(1 + cos) / 2``, which lies in [0, 1].
    """
    if answers is None:
        answers = queries

    # Use the encoder that was passed in (previously the global sbert_model was used).
    query_embeddings = model.encode(queries)
    if isinstance(query_embeddings, torch.Tensor):
        query_embeddings = query_embeddings.detach().cpu().numpy()
    # cdist converts its input with np.asarray, which fails on tensors that require grad.
    if isinstance(vectors, torch.Tensor):
        vectors = vectors.detach().cpu().numpy()

    start = 1 if skip_self else 0
    ret = {}
    for query, query_embedding in zip(queries, query_embeddings):
        distances = scipy.spatial.distance.cdist([query_embedding],
                                                 vectors,
                                                 metric="cosine")[0]
        # Use a stable sort so that ties keep index order, as sorted() did before.
        order = np.argsort(distances, kind='stable')
        ret[query] = []
        for idx in order[start:start + top_n]:
            idx = int(idx)
            similarity = 1 - distances[idx] / 2
            if verbose:
                print(f'{query}, {answers[idx]}, {similarity:.3f}')
            ret[query].append((answers[idx], similarity, idx))
    return ret


# Sanity check: the 10 nearest captions for each of the first top_n (= 3) captions.
top_n = 3
ret = search_sim_sents(queries=coco_sents[:top_n],
                       answers=coco_sents,
                       model=sbert_model,
                       vectors=coco_vects,
                       top_n=10,
                      )
for i, (query, neighbours) in enumerate(ret.items()):
    print(colored(f'文番号 {i:3d} {query}', 'grey', attrs=['bold']))
    for answer, similarity, idx in neighbours:
        print(f'\t近接文 {idx:4d}:{answer} ({similarity:.3f})')

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Load a Python 3 port of the t-SNE reference code published by its original author,
# cloning the ShinAsakawa/ccap repository into the working directory on first use.
# NOTE(review): if a `ccap` directory already exists but the import still fails,
# `git clone` fails as well; remove the directory and re-run.
try:
    from ccap import tsne
except ImportError:
    !git clone https://github.com/ShinAsakawa/ccap.git
    from ccap import tsne

In [None]:
def _tsne_as_numpy(vectors):
    """Return `vectors` as a NumPy array, detaching and moving torch tensors to the CPU."""
    if isinstance(vectors, torch.Tensor):
        return vectors.detach().cpu().numpy()
    return np.asarray(vectors)


def _pca_project(data, no_dims):
    """Project the mean-centred rows of `data` onto its top `no_dims` principal axes.

    The scatter matrix is symmetric, so `eigh` gives real eigenpairs (in ascending
    order). They are reordered so that column 0 is the direction of largest variance.
    """
    centred = data - data.mean(axis=0)
    eigvals, eigvecs = np.linalg.eigh(centred.T @ centred)
    order = np.argsort(eigvals)[::-1][:no_dims]
    return centred @ eigvecs[:, order]


def _confirm_overwrite(fname):
    """Return True if `fname` may be written: it does not exist, or the user agrees to replace it.

    The prompt reads ``[Y/n]``, so Enter, 'Y' and 'y' mean yes; anything else keeps
    the existing file and skips the write.
    """
    if not os.path.exists(fname):
        return True
    print(f'File: {fname} exists.', end=" ")
    answer = input('delete[Y/n]').strip()
    if answer in ('', 'Y', 'y'):
        os.remove(fname)
        return True
    return False


def calc_and_plot_tsne_pca(
    # Either sbert_model.model or _bert_model.model can be given. It is not used by
    # the computation and is kept only for call-site compatibility.
    model:BertModel=sbert_model.model,

    # Embeddings to project (torch tensor or array): e.g. sbert_snow_vectors or _bert_snow_vectors
    vectors:torch.Tensor=coco_anno,

    perplexity=30.0,         # t-SNE perplexity; the original paper suggests roughly 5 to 50
    pca = False,             # if True, project with PCA instead of t-SNE
    no_dims = 50,            # number of principal components kept by PCA (the first two are plotted)

    figsize:tuple=(20,20),   # figure size in inches
    tag:list=sents,          # label drawn at each point (None: no labels)
    fontsize:int=5,          # font size of the labels
    title:str=None,          # figure title
    fig_fname:str=None,      # file name to save the figure to
    excel_fname:str=None,    # Excel file name to save the 2-D coordinates to
    mod_size:int=5,          # number of consecutive points drawn in the same colour
    marker_size:int=20,      # marker size
    ):
    """Project embeddings to 2-D with t-SNE (or PCA) and draw a labelled scatter plot.

    Consecutive groups of `mod_size` points share a colour, and points within a
    group get distinct markers.

    Returns:
        The projected coordinates. Only columns 0 and 1 are plotted.

    Raises:
        ValueError: if ``mod_size < 1``.
    """
    if mod_size < 1:
        raise ValueError(f'mod_size must be >= 1, got {mod_size}')

    _X = _tsne_as_numpy(vectors)
    n = _X.shape[0]
    colors = ["orange", #"purple",
              "brown", "black", "cyan",
              "magenta", "red", "green", "blue", "yellow",
              "pink"]
    markers = ['o','v','^', '<', '>', '+','*','x', 'D','8','s', 'D','8','s','_']

    if pca:
        X = _pca_project(_X, no_dims)
    else:
        X = tsne.tsne(_X, perplexity=perplexity)  # run t-SNE
    plt.figure(figsize=figsize)                   # figure size in inches

    for i in range(n):
        group, member = divmod(i, mod_size)
        plt.scatter(X[i,0], X[i,1],
                    # Wrap both indices so that any mod_size / group count is valid.
                    marker=markers[member % len(markers)],
                    color=colors[group % len(colors)],
                    s=marker_size)

    if tag is not None:                           # write the labels into the figure
        for label, xy in zip(tag, X):
            plt.annotate(label, (xy[0], xy[1]),
                         alpha=0.7,
                         ha='center',
                         fontsize=fontsize)
    if title is not None:
        plt.title(title)

    # Save the figure, unless the user declined to replace an existing file.
    if fig_fname is not None and _confirm_overwrite(fig_fname):
        plt.savefig(fig_fname)
    plt.show()

    # Save the plotted coordinates (and labels, if any) to an Excel file.
    if excel_fname is not None and _confirm_overwrite(excel_fname):
        labels = tag if tag is not None else [None] * len(X)
        records = [{'tag': label, 'x': xy[0], 'y': xy[1]} for label, xy in zip(labels, X)]
        pd.DataFrame(records).to_excel(excel_fname)

    return X     # the projected coordinates


In [None]:
%%time
N = 100
_Xvects = coco_anno[:N,:]
_Xsents = coco_sents[:N]

np.random.seed(42)
for _model, title in [(sbert_model, 'センテンスBERT')]:
    calc_and_plot_tsne_pca(
        model=_model,
        vectors=_Xvects,
        tag=_Xsents,
        #tag = None,
        fontsize=8,
        #fontsize=4,
        figsize=(12,12),
        perplexity=30,
        pca=False,
        title=title,
        #marker_size=50,
        marker_size=100,
        fig_fname=None, #'2022_0915sbert_staircoco500.pdf',
    )
