-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #14 from luismond/dev
Dev
- Loading branch information
Showing
17 changed files
with
995 additions
and
650 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,14 @@ | ||
""" | ||
TM2TB initialization | ||
""" | ||
__author__ = "Luis Mondragon (luismond@gmail.com)" | ||
__version__ = '1.0.3' | ||
__version__ = '1.4.0' | ||
|
||
from tm2tb.transformer_model import TransformerModel | ||
trf_model = TransformerModel().load() | ||
from tm2tb.bitext_reader import BitextReader | ||
from tm2tb.tm2tb import Tm2Tb, Sentence | ||
from tm2tb.sentence import Sentence | ||
from tm2tb.bisentence import BiSentence | ||
from tm2tb.text import Text | ||
from tm2tb.bitext import BiText | ||
from tm2tb.tm2tb import Tm2Tb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
""" | ||
align ngrams | ||
""" | ||
import numpy as np | ||
import pandas as pd | ||
from sklearn.metrics.pairwise import cosine_similarity | ||
pd.options.mode.chained_assignment = None | ||
|
||
def get_seq_similarities(src_embs, trg_embs): | ||
seq_similarities = cosine_similarity(src_embs, trg_embs) | ||
return seq_similarities | ||
|
||
def get_aligned_ngrams(src_ngrams_df, trg_ngrams_df, **kwargs): | ||
#src_ngrams_df, trg_ngrams_df = self.get_ngrams_dfs(**kwargs) | ||
src_ngrams = src_ngrams_df['joined_ngrams'].tolist() | ||
src_tags = src_ngrams_df['tags'].tolist() | ||
src_ranks = src_ngrams_df['rank'].tolist() | ||
src_embeddings = src_ngrams_df['embedding'].tolist() | ||
trg_ngrams = trg_ngrams_df['joined_ngrams'].tolist() | ||
trg_tags = trg_ngrams_df['tags'].tolist() | ||
trg_ranks = trg_ngrams_df['rank'].tolist() | ||
trg_embeddings = trg_ngrams_df['embedding'].tolist() | ||
seq_similarities = get_seq_similarities(src_embeddings, trg_embeddings) | ||
src_idx = list(range(len(src_ngrams))) | ||
trg_idx = list(range(len(trg_ngrams))) | ||
# Get indexes and values of most similar source ngram for each target ngram | ||
trg_max_values = np.max(seq_similarities[src_idx][:, trg_idx], axis=1) | ||
trg_max_idx = np.argmax(seq_similarities[src_idx][:, trg_idx], axis=1) | ||
# Get indexes and values of most similar target ngram for each source ngram | ||
src_max_values = np.max(seq_similarities[src_idx][:, trg_idx], axis=0) | ||
src_max_idx = np.argmax(seq_similarities[src_idx][:, trg_idx], axis=0) | ||
src_aligned_ngrams = pd.DataFrame([(src_ngrams[idx], | ||
src_tags[idx], | ||
src_ranks[idx], | ||
trg_ngrams[trg_max_idx[idx]], | ||
trg_tags[trg_max_idx[idx]], | ||
trg_ranks[trg_max_idx[idx]], | ||
float(trg_max_values[idx])) for idx in src_idx]) | ||
trg_aligned_ngrams = pd.DataFrame([(src_ngrams[src_max_idx[idx]], | ||
src_tags[src_max_idx[idx]], | ||
src_ranks[src_max_idx[idx]], | ||
trg_ngrams[idx], | ||
trg_tags[idx], | ||
trg_ranks[idx], | ||
float(src_max_values[idx])) for idx in trg_idx]) | ||
return src_aligned_ngrams, trg_aligned_ngrams | ||
|
||
def get_top_ngrams(src_ngrams_df, | ||
trg_ngrams_df, | ||
min_similarity=.8, | ||
**kwargs): | ||
# Concatenate source & target ngram alignments | ||
src_aligned_ngrams, trg_aligned_ngrams = get_aligned_ngrams(src_ngrams_df, | ||
trg_ngrams_df, | ||
**kwargs) | ||
bi_ngrams = pd.concat([src_aligned_ngrams, trg_aligned_ngrams]) | ||
bi_ngrams = bi_ngrams.reset_index() | ||
bi_ngrams = bi_ngrams.drop(columns=['index']) | ||
bi_ngrams.columns = ['src_ngram', | ||
'src_ngram_tags', | ||
'src_ngram_rank', | ||
'trg_ngram', | ||
'trg_ngram_tags', | ||
'trg_ngram_rank', | ||
'bi_ngram_similarity'] | ||
|
||
# Keep n-grams above min_similarity | ||
bi_ngrams = bi_ngrams[bi_ngrams['bi_ngram_similarity'] >= min_similarity] | ||
if len(bi_ngrams)==0: | ||
raise ValueError('No ngram pairs above minimum similarity!') | ||
# For one-word terms, keep those longer than 1 character | ||
bi_ngrams = bi_ngrams[bi_ngrams['src_ngram'].str.len()>1] | ||
bi_ngrams = bi_ngrams[bi_ngrams['trg_ngram'].str.len()>1] | ||
# Group by source, get the most similar target n-gram | ||
bi_ngrams = pd.DataFrame([df.loc[df['bi_ngram_similarity'].idxmax()] | ||
for (src_ngram, df) in list(bi_ngrams.groupby('src_ngram'))]) | ||
# Group by target, get the most similar source n-gram | ||
bi_ngrams = pd.DataFrame([df.loc[df['bi_ngram_similarity'].idxmax()] | ||
for (trg_ngram, df) in list(bi_ngrams.groupby('trg_ngram'))]) | ||
# Get bi n-gram rank | ||
bi_ngrams['bi_ngram_rank'] = bi_ngrams['bi_ngram_similarity'] * \ | ||
bi_ngrams['src_ngram_rank'] * bi_ngrams['trg_ngram_rank'] | ||
bi_ngrams = bi_ngrams.sort_values(by='bi_ngram_rank', ascending=False) | ||
bi_ngrams = bi_ngrams.round(4) | ||
return bi_ngrams |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
""" | ||
BiSentence class. | ||
""" | ||
from tm2tb import Sentence | ||
from tm2tb.align_ngrams import get_top_ngrams | ||
|
||
class BiSentence: | ||
def __init__(self, sentence_tuple): | ||
self.src_sentence = Sentence(sentence_tuple[0]) | ||
self.trg_sentence = Sentence(sentence_tuple[1]) | ||
|
||
def get_ngrams_dfs(self, **kwargs): | ||
src_ngrams_df = self.src_sentence.get_top_ngrams(return_embs=True, | ||
**kwargs) | ||
trg_ngrams_df = self.trg_sentence.get_top_ngrams(return_embs=True, | ||
**kwargs) | ||
return src_ngrams_df, trg_ngrams_df | ||
|
||
def get_top_ngrams(self, **kwargs): | ||
src_ngrams_df, trg_ngrams_df = self.get_ngrams_dfs(**kwargs) | ||
top_ngrams = get_top_ngrams(src_ngrams_df, | ||
trg_ngrams_df, | ||
min_similarity=.8, | ||
**kwargs) | ||
return top_ngrams |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
""" | ||
BiText class | ||
""" | ||
from tm2tb import Text | ||
from tm2tb.align_ngrams import get_top_ngrams | ||
|
||
class BiText: | ||
def __init__(self, bitext): | ||
self.src_text = Text(bitext['src'].tolist()) | ||
self.trg_text = Text(bitext['trg'].tolist()) | ||
|
||
def get_ngrams_dfs(self, **kwargs): | ||
src_ngrams_df = self.src_text.get_top_ngrams(return_embs=True, | ||
**kwargs) | ||
trg_ngrams_df = self.trg_text.get_top_ngrams(return_embs=True, | ||
**kwargs) | ||
return src_ngrams_df, trg_ngrams_df | ||
|
||
def get_top_ngrams(self, **kwargs): | ||
src_ngrams_df, trg_ngrams_df = self.get_ngrams_dfs(**kwargs) | ||
top_ngrams = get_top_ngrams(src_ngrams_df, | ||
trg_ngrams_df, | ||
min_similarity=.8, | ||
**kwargs) | ||
return top_ngrams |
Oops, something went wrong.