## Functions

In [1]:
import torch
import numpy as np
import pandas as pd
from transformers import AutoModel, AutoTokenizer
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.pairwise import cosine_similarity
import re
from tqdm import tqdm
import nltk
import kss
from nltk.tokenize import sent_tokenize
from kobert_tokenizer import KoBERTTokenizer
from transformers import BertTokenizerFast
from transformers import BertModel

# 현재 적정 dp alignemtn 하이파라미터
# def paragraph_dp_align_v2(eng_pars, kor_pars, debug = False, max_merge=3, skip_penalty=0.35, method="mean", merge_threshold = 0.45, match_threshold=0.6, merge_bonus=0.08):

#TODO
#나중에 더 적정값 찾기, 아마 



# --- 0. 설정 및 파일 경로 ---
EMBEDDING_DIM = 768
LABSE = "setu4993/LaBSE"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_SENT = 200  

sent_tags = []
for i in range(MAX_SENT):
    sent_tags.append(f"[S{i}]")
    sent_tags.append(f"[/S{i}]")


xlmr_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
xlmr_model = AutoModel.from_pretrained("xlm-roberta-large").to(DEVICE)
xlmr_model.eval()

kobert_tokenizer = KoBERTTokenizer.from_pretrained("skt/kobert-base-v1")
kobert_model = BertModel.from_pretrained("skt/kobert-base-v1").to(DEVICE)
kobert_model.eval()

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_model = AutoModel.from_pretrained("bert-base-uncased").to(DEVICE)
bert_model.eval()

labase_tokenizer = AutoTokenizer.from_pretrained("setu4993/LaBSE")
labase_model = AutoModel.from_pretrained("setu4993/LaBSE").to(DEVICE)
labase_model.eval()

def save_sentences(sent_list, path):
    with open(path, "w", encoding="utf-8") as f:
        for s in sent_list:
            f.write(s.strip() + "\n")


# Load Text
def load_and_segment_text(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        return text
    except FileNotFoundError:
        print(f"오류: 파일 없음 -> {file_path}")
        return "" 
    
def clean_non_text_lines(text):
    clean = []
    for line in text.splitlines():
        # 영문, 한글,숫자 중 하나라도 포함하면 정상
        if re.search(r'[A-Za-z0-9가-힣]', line):
            clean.append(line)
        # 아니면 문자가 없는 장식이면 skip
    return "\n".join(clean)



# Split sentence

def split_paragraphs(raw_text):
    raw_paragraphs = re.split(r'\n\s*\n+', raw_text)
    paragraphs = [
        re.sub(r'\n+', ' ', p).strip()
        for p in raw_paragraphs 
        if (stripped := p.strip()) and re.search(r'[A-Za-z0-9가-힣]', stripped)
    ]
    return paragraphs

def split_by_punctuation(text):
    parts = re.split(r'(?<=[.!?])\s+(?=[A-Za-z0-9“"‘\'가-힣])', text)
    return parts

def split_sentences(text):

    text = re.sub(r'\s+', ' ', text).strip()

    sentences = []

    quote_blocks = re.split(r'(“[^”]+”|"[^"]+")', text)

    for block in quote_blocks:
        if not block.strip():
            continue

        if (block.startswith("“") and block.endswith("”")) or \
           (block.startswith('"') and block.endswith('"')):
            sentences.extend(split_by_punctuation(block))
            continue

        paren_blocks = re.split(r'(\([^()]+\)[.!?;,]?)', block)

        for pblock in paren_blocks:
            if not pblock.strip():
                continue

            if pblock.startswith("(") and pblock.endswith(")"):
                sentences.extend(split_by_punctuation(pblock))
            else:
                sentences.extend(split_by_punctuation(pblock))

    return [s.strip() for s in sentences if s.strip()]


# functions

def get_embeddings_cached(sentences):
    inputs = labase_tokenizer(
        sentences, padding=True, truncation=True, return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        outputs = labase_model(**inputs)
        cls_emb = outputs.last_hidden_state[:, 0, :]

    return cls_emb.cpu().numpy()

def merge_embeddings(emb_list, method="mean"):
    emb = np.stack(emb_list, axis=0)
    if method == "mean":
        return emb.mean(axis=0)
    else:  # sum 방식
        return emb.sum(axis=0)
    
# DP aligner

def get_paragraph_embedding_xlmr(paragraph, tokenizer, model,
                                 layer_indices=[8,10,12]):
    encoded = tokenizer(
        paragraph,
        return_tensors="pt",
        truncation=True,
        return_offsets_mapping=False
    )
    encoded = {k:v.to(DEVICE) for k,v in encoded.items()}
    
    with torch.no_grad():
        out = model(**encoded, output_hidden_states=True)

    hidden_states = out.hidden_states
    pooled = []

    for li in layer_indices:
        h = hidden_states[li][0]  
        pooled.append(h.mean(dim=0).cpu().numpy())

    return np.concatenate(pooled, axis=0)   

#TODO
# N:M 집어넣기

def paragraph_dp_align_v2(eng_pars, kor_pars, debug = False, max_merge=3, skip_penalty=0.7, method="mean", k_bonus = 0.0, e_bonus =0.0):

    N = len(eng_pars)
    M = len(kor_pars)

    # Fast embedding
    eng_emb = get_embeddings_cached(eng_pars)
    kor_emb = get_embeddings_cached(kor_pars)

    dp = np.zeros((N+1, M+1))
    ptr = [[None] * (M+1) for _ in range(N+1)]

    sim_11 = cosine_similarity(eng_emb, kor_emb)

    for j in range(1, M+1):
        dp[0][j] = dp[0][j-1] - skip_penalty
        ptr[0][j] = (0, j-1, 0, 1) # skip KOR

    for i in range(1, N+1):
        dp[i][0] = dp[i-1][0] - skip_penalty
        ptr[i][0] = (i-1, 0, 1, 0) # skip ENG

    for i in range(1, N+1):
        for j in range(1, M+1):

            best_score = -1e15
            best_ptr = None

            # 1:1 match
            score = dp[i-1][j-1] + sim_11[i-1][j-1]
            if score > best_score:
                best_score = score
                best_ptr = (i-1, j-1, 1, 1)

            # N:1 merge (ENG block)
            for k in range(2, max_merge+1):
                if i-k < 0: break
                merged_e = merge_embeddings(eng_emb[i-k:i], method)
                s = cosine_similarity(merged_e.reshape(1,-1),
                                      kor_emb[j-1].reshape(1,-1))[0][0]
                score = dp[i-k][j-1] + s + e_bonus
                if score > best_score:
                    best_score = score
                    best_ptr = (i-k, j-1, k, 1)

            # 1:N merge (KOR block)
            for k in range(2, max_merge+1):
                if j-k < 0: break
                merged_k = merge_embeddings(kor_emb[j-k:j], method)
                s = cosine_similarity(eng_emb[i-1].reshape(1,-1),
                                      merged_k.reshape(1,-1))[0][0]
                score = dp[i-1][j-k] + s + k_bonus
                if score > best_score:
                    best_score = score
                    best_ptr = (i-1, j-k, 1, k)

            # Skip ENG
            score = dp[i-1][j] - skip_penalty
            if score > best_score:
                best_score = score
                best_ptr = (i-1, j, 1, 0)

            # Skip KOR
            score = dp[i][j-1] - skip_penalty
            if score > best_score:
                best_score = score
                best_ptr = (i, j-1, 0, 1)

            if best_ptr is None:
                best_ptr = (i-1, j-1, 1, 1)

            dp[i][j] = best_score
            ptr[i][j] = best_ptr

    # Back Tranking
    aligned = []
    unused_eng = set()
    unused_kor = set()

    i, j = N, M

    while i > 0 or j > 0:
        pi, pj, ei, ej = ptr[i][j]

        if ei > 0 and ej > 0:  # matched block
            aligned.append({
                "eng_idx": list(range(pi, pi+ei)),
                "kor_idx": list(range(pj, pj+ej)),
                "eng": " ".join(eng_pars[pi:pi+ei]),
                "kor": " ".join(kor_pars[pj:pj+ej]),
            })
        elif ei > 0 and ej == 0:
            unused_eng.add(i-1)
        elif ei == 0 and ej > 0:
            unused_kor.add(j-1)

        i, j = pi, pj

    aligned = aligned[::-1]
    if debug:
        print(f"Unused ENG = {len(unused_eng)}, Unused KOR = {len(unused_kor)}")

    return aligned

def align_sentences_from_paragraphs_v2(aligned_paragraphs):
    final_pairs = []

    for p in aligned_paragraphs:
        eng_text = p["eng"]
        kor_text = p["kor"]

        # 문장 분리
        eng_sents = split_sentences(eng_text)
        kor_sents = split_sentences(kor_text)

        if len(kor_sents) < len(eng_sents): # 영어가 더 많으니까 영어 블락 만드는걸 장려
            kor_bonus = 0.06
            eng_bonus = 0.12
        elif len(kor_sents) > len(eng_sents):
            kor_bonus = 0.12
            eng_bonus = 0.06
        else:
            kor_bonus = 0.06
            eng_bonus = 0.06


        # block DP sentence alignment
        sent_align = paragraph_dp_align_v2(
            eng_sents, kor_sents,
            debug = True,
            max_merge=3,       
            method="mean",
            k_bonus = kor_bonus,
            e_bonus = eng_bonus
        )

        # 결과 저장
        for a in sent_align:
            final_pairs.append({
                "eng": a["eng"],
                "kor": a["kor"],
                "eng_idx": a["eng_idx"],
                "kor_idx": a["kor_idx"],
                "paragraph_eng": eng_text,
                "paragraph_kor": kor_text
            })

    return final_pairs

# Embedding for ENG

def get_embeddings_flexible(sentences, model_name, debug=True, output_type="mean" ):

    print(f"Load Model: {model_name} | Mode: {output_type}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(DEVICE)
    except Exception as e:
        print(f" Error : faield to load {model_name} : {e}")
        return None

    model.eval()
    
    # 배치 처리
    batch_size = 32
    all_results = []
    iterator = range(0, len(sentences), batch_size)
    if debug:
        iterator = tqdm(iterator, desc=f"  -> Embedding Extract ({model_name})")
    
    for i in iterator:
        batch_sentences = sentences[i:i+batch_size]
        inputs = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors="pt")
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

            last_hidden_state = outputs.last_hidden_state
            attention_mask = inputs["attention_mask"]
            hidden_states = outputs.hidden_states  

        if output_type == "mean":
            layer_indices = [8, 10, 12] 
            pooled_layers = []

            for li in layer_indices:
                h = hidden_states[li]    
                mask_exp = attention_mask.unsqueeze(-1).float()  

                sum_h = torch.sum(h * mask_exp, dim=1) 
                denom = torch.clamp(mask_exp.sum(dim=1), min=1e-9)
                mean_h = sum_h / denom        

                pooled_layers.append(mean_h)   

            sent_vec = torch.cat(pooled_layers, dim=-1) 
            all_results.append(sent_vec.cpu().numpy())

        elif output_type == "token":
            batch_list = []
            for j in range(len(batch_sentences)):
                seq_len = attention_mask[j].sum().item()
                token_emb = last_hidden_state[j, :seq_len, :].cpu().numpy()
                batch_list.append(token_emb)
            all_results.extend(batch_list)

        else: 
            cls_emb = last_hidden_state[:, 0, :]
            all_results.append(cls_emb.cpu().numpy())

    if output_type == "token":
        return np.array(all_results, dtype=object)
    else:
        return np.vstack(all_results)


# Embedding for Kor
def get_kobert_embeddings_flexible(sentences, debug=True, output_type="mean"):

    batch_size = 32
    all_results = []
    iterator = range(0, len(sentences), batch_size)
    if debug:
        iterator = tqdm(iterator, desc="  -> Extracting Embedding with KoBERT")

    for i in iterator:
        batch_sentences = sentences[i:i+batch_size]
        inputs = kobert_tokenizer(
            batch_sentences,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        with torch.no_grad():
            out = kobert_model(**inputs, output_hidden_states=True)

            last_hidden_state = out.last_hidden_state
            attention_mask = inputs['attention_mask']
            hidden_states = out.hidden_states    

        if output_type == "mean":
            layer_indices = [8, 10, 12]
            pooled_layers = []

            for li in layer_indices:
                h = hidden_states[li] 
                mask_exp = attention_mask.unsqueeze(-1).float()

                sum_h = torch.sum(h * mask_exp, dim=1) 
                denom = torch.clamp(mask_exp.sum(dim=1), min=1e-9)
                mean_h = sum_h / denom

                pooled_layers.append(mean_h)

            sent_vec = torch.cat(pooled_layers, dim=-1) 
            all_results.append(sent_vec.cpu().numpy())

        elif output_type == "token":
            batch_list = []
            for j in range(len(batch_sentences)):
                seq_len = attention_mask[j].sum().item()
                token_emb = last_hidden_state[j, :seq_len, :].cpu().numpy()
                batch_list.append(token_emb)
            all_results.extend(batch_list)
            
        else: # cls
            cls_emb = last_hidden_state[:, 0, :]
            all_results.append(cls_emb.cpu().numpy())

    if output_type == "token":
        return np.array(all_results, dtype=object)
    else:
        return np.vstack(all_results)

def get_boundary_tokens(model_type: str):

    model_type = model_type.lower()

    # SentencePiece 계열 (kobert, xlm-r, electra-kor 등)
    if "kobert" in model_type or "xlm" in model_type or "sentencepiece" in model_type:
        return "▁▲", "▁△"

    # WordPiece 계열 (bert-base, bert-multilingual, roberta 등)
    return "■", "●"



def build_marked_paragraph(paragraph_text, target_idx, model_type):
    """
    wrap target_idx sentence with START/END
    Create new marked_paragraph.
    """

    START, END = get_boundary_tokens(model_type)

    if isinstance(target_idx, list):
        start_i = target_idx[0]
        end_i   = target_idx[-1]
    else:
        start_i = end_i = target_idx

    sentences = split_sentences(paragraph_text)
    target_chunk = " ".join(sentences[start_i:end_i+1])

    new_sentences = []
    new_sentences.extend(sentences[:start_i])              
    new_sentences.append(f"{START} {target_chunk} {END}")  
    new_sentences.extend(sentences[end_i+1:])             

    marked_paragraph = " ".join(new_sentences)
    return marked_paragraph

def get_contextual_sentence_embedding(
    paragraph_text,
    target_idx,
    tokenizer,
    model,
    model_type,
    layer_indices=[8, 10, 12]
):
    """
    모델 타입(KoBERT/BERT/XLM-R)에 따라 boundary 토큰 자동 선택.
    1. target sentence만 boundary로 감싸기
    2. tokenize + forward
    3. boundary span pooling

    """

    marked_paragraph = build_marked_paragraph(paragraph_text, target_idx, model_type)

    encoded = tokenizer(
        marked_paragraph,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=True
    )
    encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

    input_ids = encoded["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # boundary 토큰 가져오기
    START, END = get_boundary_tokens(model_type)

    try:
        start_pos = tokens.index(START) + 1
        end_pos   = tokens.index(END) - 1
    except ValueError:
        print("boundary span 매칭 실패 -> CLS fallback")
        with torch.no_grad():
            out = model(**encoded, output_hidden_states=True)
        return out.hidden_states[-1][0, 0].cpu().numpy()

    with torch.no_grad():
        out = model(**encoded, output_hidden_states=True)

    hidden_states = out.hidden_states


    pooled = []
    for li in layer_indices:
        h = hidden_states[li][0]  
        vec = h[start_pos:end_pos+1].mean(dim=0)
        pooled.append(vec.cpu().numpy())

    return np.concatenate(pooled, axis=0)



  from .autonotebook import tqdm as notebook_tqdm
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'XLNetTokenizer'. 
The class this function is called from is 'KoBERTTokenizer'.


# 1. Converting and result Check

## Converter

In [2]:
# TODO
# 차원수가 지금 2304라 이거 각 layer의 mean을 쓰거나, 아니면 나중에 데이터양을 늘리자, 책 2권 더 해서
# Setting
ENG_FILE_PATH = "text_source/eng_testp1.txt"
KOR_FILE_PATH = "text_source/kor_testp1.txt"
PARAGRAPH_FILE_PATH = "days_results/day1_paragraph_data.npz"
SENTENCE_FILE_PATH = "days_results/day1_sentence_data.npz"
EDU_FILE_PATH = "days_results/day1_edu_data.npz"
EMBED_DEBUG = False

# Loading 메인 실행
if __name__ == "__main__":

    eng_raw = load_and_segment_text(ENG_FILE_PATH)
    eng_paragraph = split_paragraphs(eng_raw)
    eng_sentences = split_sentences("\n".join(eng_paragraph))

    kor_raw = load_and_segment_text(KOR_FILE_PATH)
    kor_paragraph = split_paragraphs(kor_raw)
    kor_sentences = split_sentences("\n".join(kor_paragraph))

    #---------------------------------------------------------------------------------------#

    save_sentences(eng_paragraph, "text_results/eng_paragraph.txt")
    save_sentences(kor_paragraph, "text_results/kor_paragraph.txt")

    print(f"Extract Paragraph: ENG={len(eng_paragraph)}, KOR={len(kor_paragraph)}")

    aligned_paragraph = paragraph_dp_align_v2(eng_paragraph, kor_paragraph, debug = True)
    print(f"Aligned Paragraph Amount: {len(aligned_paragraph)}")

    df = pd.DataFrame(aligned_paragraph)

    print("Extracting Eng Embedding")
    final_par_eng = get_embeddings_flexible(df["eng"].tolist(), "bert-base-uncased",  EMBED_DEBUG)

    print("Extracting Kor Embedding")
    final_par_kor = get_kobert_embeddings_flexible(df["kor"].tolist(), EMBED_DEBUG)

    np.savez(
        PARAGRAPH_FILE_PATH,
        eng_embs=final_par_eng,
        kor_embs=final_par_kor,
        eng_sents=df["eng"].values,
        kor_sents=df["kor"].values
    )
    #---------------------------------------------------------------------------------------

# TODO
# 문단 자르기 
# 문단을 dp 정렬하되, xml-r을 쓰기 함
# 그리고 문단마다 페어된 상태로 문장들을 페어 맞추기 (이때도 똑같이 xml-r을 사용한 dp aligment)- 함
#---------------------------------------------------------------------------------------#
    aligned_sentence = align_sentences_from_paragraphs_v2(aligned_paragraph)
    print(f"Aligned Sentences in aligned pargraph: {len(aligned_sentence)}")

    df_sen = pd.DataFrame(aligned_sentence)
    #위에까진 xml-r로 정렬이 됌, 그럼 이제 밑에서 
#---------------------------------------------------------------------------------------#
    # 이거 하려고 했던게, 원래는 문장 정렬
    final_sen_eng = []
    final_sen_kor = []

    for row in df_sen.itertuples():
        paragraph_eng = row.paragraph_eng
        paragraph_kor = row.paragraph_kor

        eng_vec = get_contextual_sentence_embedding(
                paragraph_text= row.paragraph_eng,
                target_idx = row.eng_idx,
                tokenizer = bert_tokenizer,
                model = bert_model,
                model_type = "bert",
        )

        kor_vec = get_contextual_sentence_embedding(
                paragraph_text= row.paragraph_kor,
                target_idx = row.kor_idx,
                tokenizer = kobert_tokenizer,
                model = kobert_model,
                model_type = "kobert",
        )

        final_sen_eng.append(eng_vec)
        final_sen_kor.append(kor_vec)

    final_sen_eng = np.vstack(final_sen_eng)
    final_sen_kor = np.vstack(final_sen_kor)

    np.savez(
        SENTENCE_FILE_PATH,
        eng_embs=final_sen_eng,
        kor_embs=final_sen_kor,
        eng_sents=df_sen["eng"].values,
        kor_sents=df_sen["kor"].values,

        paragraph_eng=df_sen["paragraph_eng"].values,
        paragraph_kor=df_sen["paragraph_kor"].values,
        eng_idx=df_sen["eng_idx"].values,
        kor_idx=df_sen["kor_idx"].values
    )
#---------------------------------------------------------------------------------------#
    # This is for without context
    print("without context Eng Embedding")
    final_sen_eng_without = get_embeddings_flexible(df_sen["eng"].tolist(), "bert-base-uncased", EMBED_DEBUG)

    print("without context Kor Embedding")
    final_sen_kor_without = get_kobert_embeddings_flexible(df_sen["kor"].tolist(), EMBED_DEBUG)
  
    np.savez(
        "days_results/day1_sentence_data_without.npz",
        eng_embs=final_sen_eng_without,
        kor_embs=final_sen_kor_without,
        eng_sents=df_sen["eng"].values,
        kor_sents=df_sen["kor"].values,
        
        paragraph_eng=df_sen["paragraph_eng"].values,
        paragraph_kor=df_sen["paragraph_kor"].values,
        eng_idx=df_sen["eng_idx"].values,
        kor_idx=df_sen["kor_idx"].values
    )
    print("Without context sentence npz saved")
#---------------------------------------------------------------------------------------#
    print(f"Finished Saving")


Extract Paragraph: ENG=27, KOR=30


  ret = a @ b
  ret = a @ b
  ret = a @ b


Unused ENG = 0, Unused KOR = 0
Aligned Paragraph Amount: 26
Extracting Eng Embedding
Load Model: bert-base-uncased | Mode: mean
Extracting Kor Embedding
Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0


  ret = a @ b
  ret = a @ b
  ret = a @ b


Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0


  ret = a @ b
  ret = a @ b
  ret = a @ b


Unused ENG = 0, Unused KOR = 0


  ret = a @ b
  ret = a @ b
  ret = a @ b


Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 1


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0


  ret = a @ b
  ret = a @ b
  ret = a @ b


Unused ENG = 0, Unused KOR = 0


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


Unused ENG = 0, Unused KOR = 0
Unused ENG = 0, Unused KOR = 0


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Unused ENG = 0, Unused KOR = 0
Aligned Sentences in aligned pargraph: 154
without context Eng Embedding
Load Model: bert-base-uncased | Mode: mean
without context Kor Embedding
Without context sentence npz saved
Finished Saving


## .npz information checker

In [3]:
import numpy as np

# File input
# example : 'day1_sentence_data_mean.npz' 또는 'day1_sentence_data_token.npz'
FILE_NAME = 'days_results/day1_sentence_data.npz' 


NUM_SAMPLES_TO_SHOW = 5 

try:
    day1_data = np.load(FILE_NAME, allow_pickle=True)
    
    print(f"Checking '{FILE_NAME}' ")
    
    print(day1_data.files)

    # Check data Type
    if 'eng_embs' in day1_data:
        eng_embs = day1_data['eng_embs']
        
        is_token_sequence = (eng_embs.dtype == 'O') or (len(eng_embs.shape) == 1)
        
        print(f"Aligned Amount: {eng_embs.shape[0]}")
        
        if is_token_sequence:
            first_shape = eng_embs[0].shape if len(eng_embs) > 0 else "Unknown"
            print(f"Embedding Dimension: {first_shape}")
        else:
            print(f"Embedding Dimension: {eng_embs.shape[1]}")
    
    # Load Aligned sentences
    eng_sents = day1_data['eng_sents']
    kor_sents = day1_data['kor_sents']
    
    print(f"\n Aligned Sentences {NUM_SAMPLES_TO_SHOW} List")
    
    for i in range(len(eng_sents)):
        print(f"\n----# {i+1}----")
        print(f" ENG: {eng_sents[i]}")
        print(f" KOR: {kor_sents[i]}")
        
    day1_data.close()
    
except FileNotFoundError:
    print(f"Error: cannot find '{FILE_NAME}'.")
except Exception as e:
    print(f"Error: {e}")

Checking 'days_results/day1_sentence_data.npz' 
['eng_embs', 'kor_embs', 'eng_sents', 'kor_sents', 'paragraph_eng', 'paragraph_kor', 'eng_idx', 'kor_idx']
Aligned Amount: 154
Embedding Dimension: 2304

 Aligned Sentences 5 List

----# 1----
 ENG: CHAPTER II. The Pool of Tears
 KOR: 제2장 눈물 웅덩이

----# 2----
 ENG: “Curiouser and curiouser!”
 KOR: “요상하고도 요상해!”

----# 3----
 ENG: cried Alice
 KOR: 앨리스는 소리쳤다.

----# 4----
 ENG: (she was so much surprised, that for the moment she quite forgot how to speak good English);
 KOR: (앨리스는 너무 놀란 나머지 말조차 똑바로 하지 못했다.)

----# 5----
 ENG: “now I’m opening out like the largest telescope that ever was!
 KOR: “이젠 내가 세상에서 가장 큰 망원경처럼 펼쳐져 버렸어.

----# 6----
 ENG: Good-bye, feet!”
 KOR: 잘있어 - 내 발아!”

----# 7----
 ENG: (for when she looked down at her feet, they seemed to be almost out of sight, they were getting so far off).
 KOR: (앨리스가 발을 쳐다 보니 까마득히 멀리 있어서 겨우 보일락 말락 할 지경이었다.)

----# 8----
 ENG: “Oh, my poor little feet, I wonder who will put on your shoes and s

## .npz NaN Checker

In [4]:
import numpy as np
import os

# --- 0. 설정 ---
# .npz List you want to check
FILES_TO_CHECK = [
    "days_results/day1_sentence_data.npz",
]

def check_abnormal_values(arr, name):
    if arr.dtype == 'O': # (Option B)
        total_elements = sum(x.size for x in arr)
        nan_count = sum(np.isnan(x).sum() for x in arr)
        inf_count = sum(np.isinf(x).sum() for x in arr)
    else: # (Option A)
        total_elements = arr.size
        nan_count = np.isnan(arr).sum()
        inf_count = np.isinf(arr).sum()

    if nan_count > 0 or inf_count > 0:
        print(f"Warning")
        print(f"- NaN: {nan_count}")
        print(f"- Inf: {inf_count}")
    else:
        print(f"{name}: no NaN/Inf")

def inspect_data(file_path):
    if not os.path.exists(file_path):
        print(f"Error : cannot find the '{file_path}' ")
        return

    try:
        data = np.load(file_path, allow_pickle=True)
        print(f"Analyzing: {file_path}")
        
        eng_sents = data['eng_sents'] if 'eng_sents' in data else []
        kor_sents = data['kor_sents'] if 'kor_sents' in data else []
        print(f"Total sentences amount: {len(eng_sents)}")

        # 2. NaN/Inf check
        if 'eng_embs' in data:
            eng_embs = data['eng_embs']
            kor_embs = data['kor_embs'] 
            
            # .npz type check
            if eng_embs.dtype != 'O' and len(eng_embs.shape) == 2:
                print(f"Data Type: [Option A] Fixed Vector {eng_embs.shape}")
            elif eng_embs.dtype == 'O' or len(eng_embs.shape) == 1:
                print(f"Data Type: [Option B] non Fixed Vector")
            
            # Check Abnormal (NaN/Inf)
            check_abnormal_values(eng_embs, "Eng Embedding")
            check_abnormal_values(kor_embs, "Kor Embedding")

        # 3. Printing out 
        print(f"\n Sample Printing ({len(eng_sents)}):")
        
        for i in range(10):
            print(f"\n[Pair {i+1}]")
            
            print(f"ENG: {eng_sents[i]}")
            print(f"KOR: {kor_sents[i]}")
            
            if 'eng_embs' in locals():
                if eng_embs.dtype != 'O': 
                    if np.isnan(eng_embs[i]).any() or np.isinf(eng_embs[i]).any():
                        print(f"There is a NaN or inf in this sentence !")
                    else:
                        norm_val = np.linalg.norm(eng_embs[i])
                        print(f"Pass (Norm: {norm_val:.4f})")

        data.close()

    except Exception as e:
        print(f" Error : {e}")

if __name__ == "__main__":
    for f in FILES_TO_CHECK:
        inspect_data(f)

Analyzing: days_results/day1_sentence_data.npz
Total sentences amount: 154
Data Type: [Option A] Fixed Vector (154, 2304)
Eng Embedding: no NaN/Inf
Kor Embedding: no NaN/Inf

 Sample Printing (154):

[Pair 1]
ENG: CHAPTER II. The Pool of Tears
KOR: 제2장 눈물 웅덩이
Pass (Norm: 23.0947)

[Pair 2]
ENG: “Curiouser and curiouser!”
KOR: “요상하고도 요상해!”
Pass (Norm: 25.0098)

[Pair 3]
ENG: cried Alice
KOR: 앨리스는 소리쳤다.
Pass (Norm: 26.5862)

[Pair 4]
ENG: (she was so much surprised, that for the moment she quite forgot how to speak good English);
KOR: (앨리스는 너무 놀란 나머지 말조차 똑바로 하지 못했다.)
Pass (Norm: 21.5968)

[Pair 5]
ENG: “now I’m opening out like the largest telescope that ever was!
KOR: “이젠 내가 세상에서 가장 큰 망원경처럼 펼쳐져 버렸어.
Pass (Norm: 22.4167)

[Pair 6]
ENG: Good-bye, feet!”
KOR: 잘있어 - 내 발아!”
Pass (Norm: 22.5986)

[Pair 7]
ENG: (for when she looked down at her feet, they seemed to be almost out of sight, they were getting so far off).
KOR: (앨리스가 발을 쳐다 보니 까마득히 멀리 있어서 겨우 보일락 말락 할 지경이었다.)
Pass (Norm: 21.9558)

[P

# 2. Training Adapter

## Adapter Trainer

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from scipy.linalg import orthogonal_procrustes
from sklearn.metrics.pairwise import cosine_similarity
import copy

# Settings

SENTENCE_INPUT_FILE = "days_results/day1_sentence_data.npz"
SENTENCE_OUTPUT_FILE = "days_results/day2_sentence_results.npz"

SENTENCE_LONG_INPUT_FILE = "days_results/day1_sentence_data_long.npz"
SENTENCE_LONG_OUTPUT_FILE = "days_results/day2_sentence_results_long.npz"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 2304

LR = 5e-4
EPOCHS = 1000
PATIENCE = 40
TEMP = 0.07


# 1. Load Data
def load_raw_embeddings(path):
    print(f"\nLoading: {path}")
    data = np.load(path, allow_pickle=True)

    eng = data["eng_embs"]
    kor = data["kor_embs"]
    eng_sents = data["eng_sents"]
    kor_sents = data["kor_sents"]
    paragraph_eng = data.get("paragraph_eng", None)
    paragraph_kor = data.get("paragraph_kor", None)
    eng_idx = data.get("eng_idx", None)
    kor_idx = data.get("kor_idx", None)

    if eng.ndim == 3:
        print("Token embeddings detected -> mean pooled")
        eng = np.vstack([np.mean(x, axis=0) for x in tqdm(eng)])
        kor = np.vstack([np.mean(x, axis=0) for x in tqdm(kor)])
    else:
        print("Sentence embeddings detected")

    eng = np.nan_to_num(eng)
    kor = np.nan_to_num(kor)

    return (
        eng, kor, # sentence embeddings
        eng_sents, kor_sents, # sentence text
        paragraph_eng, # context paragraph
        paragraph_kor,
        eng_idx, kor_idx # in-paragraph indices
    )


def compute_procrustes(eng_raw, kor_raw):
    print("\n Computing Procrustes")

    R, _ = orthogonal_procrustes(eng_raw, kor_raw)
    aligned = eng_raw @ R

    aligned = aligned / (np.linalg.norm(aligned, axis=1, keepdims=True) + 1e-12)
    kor_norm = kor_raw / (np.linalg.norm(kor_raw, axis=1, keepdims=True) + 1e-12)

    sim = cosine_similarity(aligned, kor_norm)
    acc = np.mean(np.argmax(sim, axis=1) == np.arange(len(sim)))
    return acc, R, aligned


# Model Definitions

class ProcrustesModel(nn.Module):
    def __init__(self, R):
        super().__init__()
        self.register_buffer("R", torch.tensor(R, dtype=torch.float32))

    def forward(self, x):
        return x @ self.R  
    
class LinearMap(nn.Module):
    def __init__(self, dim=EMBED_DIM, R_init=None):
        super().__init__()
        self.fc = nn.Linear(dim, dim, bias=False)

        if R_init is not None:
            if R_init.shape == (dim, dim):
                self.fc.weight.data.copy_(torch.from_numpy(R_init.T.astype(np.float32)))
                print("Linear model initialized with Procrustes matrix")
            else:
                print("Warning: R_init shape mismatch, skipping initialization.")

    def forward(self, x):
        return self.fc(x)


class MLP(nn.Module):
    def __init__(self, dim=EMBED_DIM, R_init=None):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(dim, dim)

        if R_init is not None:
            if R_init.shape == (dim, dim):
                self.fc1.weight.data.copy_(torch.from_numpy(R_init.T.astype(np.float32)))
                print("MLP initialized with Procrustes matrix")
            else:
                print("Warning: R_init shape mismatch, skipping initialization.")

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class ResidualMLP(nn.Module):
    def __init__(self, dim=EMBED_DIM, R_init=None):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(dim, dim)

        if R_init is not None:
            if R_init.shape == (dim, dim):
                self.fc1.weight.data.copy_(torch.from_numpy(R_init.T.astype(np.float32)))
                print("Residual MLP initialized with Procrustes matrix")
            else:
                print("Warning: R_init shape mismatch, skipping initialization.")

        nn.init.zeros_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        return x + self.fc2(self.act(self.fc1(x)))


class L2Loss(nn.Module):
    def forward(self, a, b):
        return torch.mean((a - b)**2)


class InfoNCELoss(nn.Module):
    def __init__(self, temp=TEMP):
        super().__init__()
        self.temp = temp
        self.ce = nn.CrossEntropyLoss()

    def forward(self, a, b):
        a = F.normalize(a, dim=1)
        b = F.normalize(b, dim=1)
        logits = (a @ b.T) / self.temp
        labels = torch.arange(len(a), device=logits.device)
        return (self.ce(logits, labels) + self.ce(logits.T, labels)) / 2

# Training Loop 
def train(model, eng, kor):
    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    loss_fn = InfoNCELoss()
    # loss_fn = L2Loss()
    t_eng = torch.tensor(eng, dtype=torch.float32).to(DEVICE)
    t_kor = torch.tensor(kor, dtype=torch.float32).to(DEVICE)

    best_loss = np.inf
    patience = 0
    best_state = None

    for ep in range(EPOCHS):
        opt.zero_grad()
        out = model(t_eng)
        loss = loss_fn(out, t_kor)
        loss.backward()
        opt.step()

        loss_val = loss.item()

        if loss_val < best_loss:
            best_loss = loss_val
            best_state = model.state_dict()
            patience = 0
        else:
            patience += 1

        if (ep + 1) % 40 == 0:
            print(f" Epoch {ep+1}/{EPOCHS} | loss={loss_val:.4f}")

        if patience >= PATIENCE:
            print("Early stopping")
            break

    model.load_state_dict(best_state)
    return model

def train_with_r1(model, train_eng, train_kor, val_eng, val_kor):
    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    loss_fn = InfoNCELoss()

    t_train_eng = torch.tensor(train_eng, dtype=torch.float32).to(DEVICE)
    t_train_kor = torch.tensor(train_kor, dtype=torch.float32).to(DEVICE)

    best_r1 = -1
    patience = 0
    best_state = None

    for ep in range(EPOCHS):
        # Train step
        opt.zero_grad()
        out = model(t_train_eng)
        loss = loss_fn(out, t_train_kor)
        loss.backward()
        opt.step()

        # Validation (R@1)
        model.eval()
        with torch.no_grad():
            proj = model(torch.tensor(val_eng, dtype=torch.float32).to(DEVICE))
            proj = proj.cpu().numpy()
            proj = proj / (np.linalg.norm(proj, axis=1, keepdims=True) + 1e-12)
            sim = cosine_similarity(proj, val_kor)
            r1 = np.mean(np.argmax(sim, axis=1) == np.arange(len(sim)))

        # Print log
        if (ep + 1) % 20 == 0:
            print(f" Epoch {ep+1}/{EPOCHS} | loss={loss.item():.4f} | R@1={r1:.2%} | best={best_r1:.2%}")

        # Update best model
        if r1 > best_r1:
            best_r1 = r1
            best_state = copy.deepcopy(model.state_dict())
            patience = 0
        else:
            patience += 1

        # Early stopping
        if patience >= PATIENCE:
            print(f"Early stopping (no improvement for {PATIENCE} epochs)")
            break

        model.train()

    # Load best model
    model.load_state_dict(best_state)
    return model



# Evaluation
def evaluate(model, eng, kor):
    model.eval()
    with torch.no_grad():
        proj = model(torch.tensor(eng, dtype=torch.float32).to(DEVICE))
        proj = proj.cpu().numpy()
        proj = proj / (np.linalg.norm(proj, axis=1, keepdims=True) + 1e-12)

    # kor = kor / (np.linalg.norm(kor, axis=1, keepdims=True) + 1e-12)
    sim = cosine_similarity(proj, kor)
    acc = np.mean(np.argmax(sim, axis=1) == np.arange(len(sim)))

    return acc, proj, sim


# Main
def run_day2(INPUT, OUTPUT):
    (
    eng_raw,
    kor_raw,
    eng_sents,
    kor_sents,
    paragraph_eng,
    paragraph_kor,
    eng_idx,
    kor_idx
    ) = load_raw_embeddings(INPUT)


    eng_norm = eng_raw / (np.linalg.norm(eng_raw, axis=1, keepdims=True) + 1e-12)
    kor_norm = kor_raw / (np.linalg.norm(kor_raw, axis=1, keepdims=True) + 1e-12)
    # Split test and train data
    N = len(eng_raw)
    S = int(N * 0.8)

    train_eng, test_eng = eng_norm[:S], eng_norm[S:]
    train_kor, test_kor = kor_norm[:S], kor_norm[S:]

    procrustes_acc, R, procrustes_proj = compute_procrustes(train_eng, train_kor)

    print("\nTRAINING MODELSS")

    models = {
        # "Procrustes" : ProcrustesModel(R).to(DEVICE),
        "MLP_Pro": MLP(R_init=R).to(DEVICE),
        "MLP_Rand": MLP().to(DEVICE),
        "Res_Pro": ResidualMLP(R_init=R).to(DEVICE),
        "Res_Rand": ResidualMLP().to(DEVICE),
        "Linear_Pro": LinearMap(R_init=R).to(DEVICE),
        "Linear_Rand": LinearMap().to(DEVICE),
    }

    results = {}

    for name, model in models.items():
        print(f"\nTraining {name}")
        if name == "Procrustes":
            trained = model
        else:
            trained = train_with_r1(model, train_eng, train_kor, test_eng,test_kor)
        acc, proj, sim = evaluate(trained, test_eng, test_kor)
        results[name] = {"acc": acc, "proj": proj, "model": trained}
        print(f"{name} R@1: {acc:.2%}")

    # -------- SELECT BEST NON-PROCRUSTES MODEL --------
    print("\nSelecting Best Model")
    best_model_name = max(
        [k for k in results.keys() if "Procrustes" not in k], 
        key=lambda x: results[x]["acc"]
    )
    best_model_proj = results[best_model_name]["proj"]
    best_model_acc = results[best_model_name]["acc"]
    best_model = results[best_model_name]["model"]
    # Save best model weights for Day3
    torch.save(best_model.state_dict(), OUTPUT.replace(".npz", "_best_model.pt"))


    print(f"Best model: {best_model_name} (R@1={best_model_acc:.2%})")
    # ---- SAVE RESULTS (FULL FORMAT FOR DAY3) ----
    print("\nSaving all results")

    # baseline similarity (basline)
    baseline_scores = np.diag(cosine_similarity(test_eng, test_kor))

    # aligned similarity (best model)
    aligned_scores = np.diag(cosine_similarity(best_model_proj, test_kor))

    pro_aligned_scores = np.diag(cosine_similarity(procrustes_proj, train_kor))

    np.savez(
        OUTPUT,

        # --- For Day3 baseline ---
        procrustes_proj=procrustes_proj, 
        procrustes_acc=procrustes_acc,
        pro_aligned_scores = pro_aligned_scores,
        R = R, 

        # --- Best model info ---
        best_model_name=best_model_name,
        best_model_acc=best_model_acc,
        projected_eng_embs=best_model_proj,   

        # --- For visualization + analysis ---
        test_eng_embs=test_eng,
        test_kor_embs=test_kor,
        baseline_scores=baseline_scores,       
        aligned_scores=aligned_scores,         

        # --- Sentence metadata ---
        test_eng_sents=eng_sents[S:],          
        test_kor_sents=kor_sents[S:],

        test_paragraph_eng = paragraph_eng[S:],
        test_paragraph_kor = paragraph_kor[S:],
        test_eng_idx = eng_idx[S:],
        test_kor_idx = kor_idx[S:],
    )

    print(f"Saved to {OUTPUT}")

# run_day2("days_results/day1_sentence_data_without.npz", "days_results/day2_sentence_results_without.npz")
run_day2(SENTENCE_INPUT_FILE, SENTENCE_OUTPUT_FILE)
# run_day2("days_results/day1_sentence_data_long_without.npz", "days_results/day2_sentence_results_long_without.npz")
# run_day2(SENTENCE_LONG_INPUT_FILE, SENTENCE_LONG_OUTPUT_FILE)







Loading: days_results/day1_sentence_data.npz
Sentence embeddings detected

 Computing Procrustes


  u, w, vt = svd((B.T @ np.conjugate(A)).T)
  u, w, vt = svd((B.T @ np.conjugate(A)).T)
  u, w, vt = svd((B.T @ np.conjugate(A)).T)
  R = u @ vt
  R = u @ vt
  R = u @ vt
  aligned = eng_raw @ R
  aligned = eng_raw @ R
  aligned = eng_raw @ R
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b



TRAINING MODELSS
MLP initialized with Procrustes matrix
Residual MLP initialized with Procrustes matrix
Linear model initialized with Procrustes matrix

Training MLP_Pro


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 20/1000 | loss=0.3231 | R@1=67.74% | best=70.97%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 40/1000 | loss=0.1566 | R@1=67.74% | best=70.97%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


Early stopping (no improvement for 40 epochs)
MLP_Pro R@1: 70.97%

Training MLP_Rand


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 20/1000 | loss=0.4298 | R@1=61.29% | best=61.29%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 40/1000 | loss=0.1843 | R@1=67.74% | best=74.19%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret 

 Epoch 60/1000 | loss=0.1379 | R@1=67.74% | best=74.19%
Early stopping (no improvement for 40 epochs)
MLP_Rand R@1: 74.19%

Training Res_Pro


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret 

 Epoch 20/1000 | loss=0.7273 | R@1=61.29% | best=67.74%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 40/1000 | loss=0.3042 | R@1=64.52% | best=67.74%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret 

 Epoch 60/1000 | loss=0.1909 | R@1=67.74% | best=70.97%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 80/1000 | loss=0.1500 | R@1=61.29% | best=70.97%
Early stopping (no improvement for 40 epochs)
Res_Pro R@1: 70.97%

Training Res_Rand


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret 

 Epoch 20/1000 | loss=0.9898 | R@1=64.52% | best=64.52%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 40/1000 | loss=0.4131 | R@1=70.97% | best=70.97%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 60/1000 | loss=0.2425 | R@1=74.19% | best=74.19%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 80/1000 | loss=0.1768 | R@1=64.52% | best=74.19%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


Early stopping (no improvement for 40 epochs)
Res_Rand R@1: 74.19%

Training Linear_Pro


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 20/1000 | loss=0.1999 | R@1=70.97% | best=70.97%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


 Epoch 40/1000 | loss=0.1396 | R@1=64.52% | best=74.19%
 Epoch 60/1000 | loss=0.1251 | R@1=58.06% | best=74.19%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


Early stopping (no improvement for 40 epochs)
Linear_Pro R@1: 74.19%

Training Linear_Rand
 Epoch 20/1000 | loss=0.4614 | R@1=67.74% | best=70.97%


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret 

 Epoch 40/1000 | loss=0.2343 | R@1=67.74% | best=70.97%
Early stopping (no improvement for 40 epochs)
Linear_Rand R@1: 70.97%

Selecting Best Model
Best model: MLP_Rand (R@1=74.19%)

Saving all results
Saved to days_results/day2_sentence_results.npz


  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b
  ret = a @ b


# 3. Test the trained Adapter

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
import random
from openai import OpenAI
import os
from sklearn.manifold import TSNE

from dotenv import load_dotenv
load_dotenv()


EMBED_DIM = 2304
EMBED_DEBUG = False   # If True debugging printing on / False debugging print off
INPUT = "days_results/day2_sentence_results.npz"


# 0. Model Definitions

class ProcrustesModel(nn.Module):
    def __init__(self, R):
        super().__init__()
        self.register_buffer("R", torch.tensor(R, dtype=torch.float32))

    def forward(self, x):
        return x @ self.R 
    
class LinearMap(nn.Module):
    def __init__(self, dim=EMBED_DIM, R_init=None):
        super().__init__()
        self.fc = nn.Linear(dim, dim, bias=False)

        if R_init is not None:
            if R_init.shape == (dim, dim):
                self.fc.weight.data.copy_(torch.from_numpy(R_init.T.astype(np.float32)))
                print("Linear model initialized with Procrustes matrix")
            else:
                print("Warning: R_init shape mismatch, skipping initialization.")

    def forward(self, x):
        return self.fc(x)


class MLP(nn.Module):
    def __init__(self, dim=EMBED_DIM, R_init=None):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(dim, dim)

        if R_init is not None:
            if R_init.shape == (dim, dim):
                self.fc1.weight.data.copy_(torch.from_numpy(R_init.T.astype(np.float32)))
                print("MLP initialized with Procrustes matrix")
            else:
                print("Warning: R_init shape mismatch, skipping initialization.")

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class ResidualMLP(nn.Module):
    def __init__(self, dim=EMBED_DIM, R_init=None):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(dim, dim)

        if R_init is not None:
            if R_init.shape == (dim, dim):
                self.fc1.weight.data.copy_(torch.from_numpy(R_init.T.astype(np.float32)))
                print("Residual MLP initialized with Procrustes matrix")
            else:
                print("Warning: R_init shape mismatch, skipping initialization.")

        nn.init.zeros_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        return x + self.fc2(self.act(self.fc1(x)))




# GPT based nuance-shift paraphrase creator 

client = OpenAI()

def gpt_generate_paraphrases(sentence):
    prompt = f"""
Sentence: "{sentence}"

Generate 4 English paraphrases that keep a roughly similar surface meaning,
but the nuance, tone, attitude, or emotional framing should be noticeably different.
Each paraphrase must feel distinct in nuance, even if the literal meaning is similar.

Format:

[PARA1]
sentence

[PARA2]
sentence

[PARA3]
sentence

[PARA4]
sentence
"""

    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}]
    )

    res = completion.choices[0].message.content

    paras = []
    for tag in ["[PARA1]", "[PARA2]", "[PARA3]", "[PARA4]"]:
        try:
            part = res.split(tag)[1].split("[")[0].strip()
            paras.append(part)
        except:
            paras.append("")

    return paras



# 1. Sentence Quality checker
def perform_nuance_analysis(eng_sents, kor_sents, baseline_scores, aligned_scores, top_k=3):

    print("(Qualitative Case Study)")

    diff = aligned_scores - baseline_scores

    # best improvements
    top_idx = np.argsort(diff)[::-1][:top_k]

    print("\n Quality increased sentences Top-K")
    for i in top_idx:
        print(f"\n Increased amount: {diff[i]:.4f}")
        print("ENG:", eng_sents[i])
        print("KOR:", kor_sents[i])
        print(f"Before={baseline_scores[i]:.4f} -> After={aligned_scores[i]:.4f}")

    return top_idx


def visualize_alignment_arrows(eng_before, kor, eng_after, perplexity=10):
    """
    # Arrow plot showing ENG_before -> ENG_after movement using t-SNE (cosine distance)

    eng_before: english embeddings before alignment
    eng_after: english embeddings after alignment
    kor: korean embeddings 
    """

    print("\nRunning t-SNE for Arrow Plot ")

    all_data = np.vstack([eng_before, eng_after, kor])

    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        metric="cosine",
        learning_rate="auto",
        random_state=42
    )

    all_2d = tsne.fit_transform(all_data)

    N = len(eng_before)

    eng_before_2d = all_2d[:N]
    eng_after_2d  = all_2d[N:2*N]
    kor_2d        = all_2d[2*N:3*N]

    # Plot
    plt.figure(figsize=(12, 10))

    # before ENG
    plt.scatter(eng_before_2d[:,0], eng_before_2d[:,1],
                color="blue", alpha=0.7, label="ENG_before")

    # after ENG 
    plt.scatter(eng_after_2d[:,0], eng_after_2d[:,1],
                color="green", alpha=0.7, label="ENG_after")

    # korean target
    plt.scatter(kor_2d[:,0], kor_2d[:,1],
                color="orange", alpha=0.7, label="KOR")

    # arrows: ENG_before -> ENG_after
    for i in range(N):
        plt.arrow(
            eng_before_2d[i,0], eng_before_2d[i,1],
            eng_after_2d[i,0] - eng_before_2d[i,0],
            eng_after_2d[i,1] - eng_before_2d[i,1],
            color="gray",
            alpha=0.5,
            width=0.05,
            length_includes_head=True,
            head_width=1.0
        )

    plt.title("ENG Alignment Path (Arrow Plot via t-SNE)")
    plt.legend()
    plt.grid(True)
    plt.show()

# run nuance test for no ctx npz. 
def run_nuance_test_without(eng_sents, kor_sents, best_model, indices):

    print("Nuance Perturbation Test")

    for idx in indices:
        src = eng_sents[idx]
        tgt = kor_sents[idx]

        print("ENG :", src)
        print("KOR :", tgt)

        # ENG embedding
        src_emb = get_embeddings_flexible([src], "bert-base-uncased", EMBED_DEBUG)
        tgt_emb = get_kobert_embeddings_flexible([tgt], EMBED_DEBUG)

        src_emb = src_emb / (np.linalg.norm(src_emb, axis=1, keepdims=True) + 1e-12)
        tgt_emb = tgt_emb / (np.linalg.norm(tgt_emb, axis=1, keepdims=True) + 1e-12)

        # Best model forward 
        with torch.no_grad():
            proj_src = best_model(torch.tensor(src_emb, dtype=torch.float32).to(DEVICE))
        proj_src = proj_src.cpu().numpy()
        proj_src = proj_src / (np.linalg.norm(proj_src, axis=1, keepdims=True) + 1e-12)

        sim_orig = cosine_similarity(proj_src, tgt_emb)[0][0]
        print(f"원문 sim = {sim_orig:.4f}")

        # GPT paraphrase
        paras = gpt_generate_paraphrases(src)

        for j, p in enumerate(paras, start=1):
            emb_p = get_embeddings_flexible([p], "bert-base-uncased", EMBED_DEBUG)
            emb_p = emb_p / (np.linalg.norm(emb_p, axis=1, keepdims=True) + 1e-12)

            with torch.no_grad():
                proj_p = best_model(torch.tensor(emb_p, dtype=torch.float32).to(DEVICE))
            proj_p = proj_p.cpu().numpy()
            proj_p = proj_p / (np.linalg.norm(proj_p, axis=1, keepdims=True) + 1e-12)

            sim_p = cosine_similarity(proj_p, tgt_emb)[0][0]

            print(f"\nPARA{j}: sim={sim_p:.4f}")
            print("->", p)

# function to repleac the idx sentence to GPT PARA
def get_contextual_embedding_with_replacement(
    paragraph_text,
    target_idx,
    new_sentence,
    tokenizer,
    model,
    model_type,
    layer_indices=[8, 10, 12]
):

    sentences = split_sentences(paragraph_text)

    # target idx arrange
    if isinstance(target_idx, list):
        start_i = target_idx[0]
        end_i   = target_idx[-1]
    else:
        start_i = end_i = target_idx

    if start_i < 0 or end_i >= len(sentences):
        raise ValueError("invalid target_idx")

    # target span sentence replacement
    new_sent_list = sentences[:start_i] + [new_sentence] + sentences[end_i+1:]

    modified_paragraph = " ".join(new_sent_list)

    return get_contextual_sentence_embedding(
        paragraph_text = modified_paragraph,
        target_idx = start_i,        
        tokenizer = tokenizer,
        model = model,
        model_type = model_type,
        layer_indices = layer_indices
    )


def run_nuance_test(
    eng_sents,
    kor_sents,
    paragraph_eng,
    paragraph_kor,
    eng_idx,
    kor_idx,
    best_model,
    indices
):

    print("Nuance Perturbation Test (Contextual)")

    for idx in indices:
        src = eng_sents[idx]
        tgt = kor_sents[idx]

        print("ENG :", src)
        print("ENG paragraph:", paragraph_eng[idx])
        print("KOR :", tgt)
        print("KOR paragraph:", paragraph_kor[idx])

        # ============================
        # 1) 원문 contextual embedding
        # ============================
        src_emb = get_contextual_sentence_embedding(
            paragraph_text = paragraph_eng[idx],
            target_idx = eng_idx[idx],
            tokenizer = bert_tokenizer,
            model = bert_model,
            model_type = "bert",
        )

        tgt_emb = get_contextual_sentence_embedding(
            paragraph_text = paragraph_kor[idx],
            target_idx = kor_idx[idx],
            tokenizer = kobert_tokenizer,
            model = kobert_model,
            model_type = "kobert",
        )

        # normalize
        src_emb = src_emb.reshape(1, -1)
        tgt_emb = tgt_emb.reshape(1, -1)

        src_emb /= (np.linalg.norm(src_emb, axis=1, keepdims=True) + 1e-12)
        tgt_emb /= (np.linalg.norm(tgt_emb, axis=1, keepdims=True) + 1e-12)

        #  Model forward 
        with torch.no_grad():
            proj_src = best_model(torch.tensor(src_emb, dtype=torch.float32).to(DEVICE))
        proj_src = proj_src.cpu().numpy()
        proj_src /= (np.linalg.norm(proj_src, axis=1, keepdims=True) + 1e-12)

        sim_orig = cosine_similarity(proj_src, tgt_emb)[0][0]
        print(f"원문 sim = {sim_orig:.4f}")


        # 2) 4 GPT paraphrase 
        paras = gpt_generate_paraphrases(src)

        for j, p in enumerate(paras, start=1):

            emb_p = get_contextual_embedding_with_replacement(
                paragraph_text = paragraph_eng[idx],
                target_idx = eng_idx[idx],
                new_sentence = p,
                tokenizer = bert_tokenizer,
                model = bert_model,
                model_type = "bert"
            )

            emb_p = emb_p.reshape(1, -1)
            emb_p /= (np.linalg.norm(emb_p, axis=1, keepdims=True) + 1e-12)

            with torch.no_grad():
                proj_p = best_model(torch.tensor(emb_p, dtype=torch.float32).to(DEVICE))
            proj_p = proj_p.cpu().numpy()
            proj_p /= (np.linalg.norm(proj_p, axis=1, keepdims=True) + 1e-12)

            sim_p = cosine_similarity(proj_p, tgt_emb)[0][0]

            print(f"\nPARA{j}: sim={sim_p:.4f}")
            print("->", p)


# 4. Main : Load Saved Outputs & Run Evaluation
if __name__ == "__main__":

    print("Trained Adapter Tester Loading")


    # # Sentence npz data
    # data = np.load("days_results/day2_sentence_results.npz", allow_pickle=True)

    # Paragraph npz data
    data = np.load(INPUT, allow_pickle = True)


    # ----- Load according to Day2 keys -----
    procrustes_proj  = data["procrustes_proj"]
    procrustes_acc   = float(data["procrustes_acc"])
    R = data["R"]

    best_name        = str(data["best_model_name"])
    best_acc         = float(data["best_model_acc"])
    proj_eng_embs    = data["projected_eng_embs"]    # aligned ENG embeddings

    test_eng_embs    = data["test_eng_embs"]
    test_kor_embs    = data["test_kor_embs"]

    baseline_scores  = data["baseline_scores"]
    aligned_scores   = data["aligned_scores"]

    test_eng_sents        = data["test_eng_sents"]
    test_kor_sents        = data["test_kor_sents"]

    paragraph_eng       = data["test_paragraph_eng"]
    paragraph_kor       = data["test_paragraph_kor"]
    eng_idx         = data["test_eng_idx"]
    kor_idx         = data["test_kor_idx"]

    print(f"Loaded best model = {best_name} (acc={best_acc:.2%})")

    if best_name == "MLP_Pro" or best_name == "MLP_Rand":
        best_model = MLP().to(DEVICE)
    elif best_name == "Res_Pro" or best_name == "Res_Rand":
        best_model = ResidualMLP().to(DEVICE)
    elif best_name == "Linear_Pro" or best_name == "Linear_Rand":
        best_model = LinearMap().to(DEVICE)
    else:
        raise ValueError("Unknown model type")

    # weight load
    best_model.load_state_dict(
        torch.load(INPUT.replace(".npz","_best_model.pt"), map_location=DEVICE)
    )

    best_model.eval()


    # 1. Quality Test
    top_indices_local = perform_nuance_analysis(
        test_eng_sents, test_kor_sents,
        baseline_scores,
        aligned_scores,
        top_k=3
    )

    # 2. Create arrow plot 
    visualize_alignment_arrows(test_eng_embs, test_kor_embs, proj_eng_embs)


    # 3. Nuance Perturbation Test

    # run_nuance_test_without(eng_sents, kor_sents, best_model, top_indices)
    run_nuance_test(
        eng_sents = test_eng_sents,
        kor_sents = test_kor_sents,
        paragraph_eng = paragraph_eng,
        paragraph_kor = paragraph_kor,
        eng_idx = eng_idx,
        kor_idx = kor_idx,
        best_model = best_model,
        indices = top_indices_local
    )

    print("\n Finished Printing Trained Adapter Result ")


OpenAIError: The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable