In [None]:
!pip install git+https://github.com/KornWtp/ConGen.git

In [None]:
import os
import logging
from datetime import datetime
import io
import math
import numpy as np
import random
from glob import glob 
import pickle

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import models
from sentence_transformers import LoggingHandler, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

from sentence_transformers_congen import SentenceTransformer, losses

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# One fixed seed, applied in turn to PyTorch, Python's `random` and NumPy.
SEED = 1000
for _apply_seed in (torch.manual_seed, random.seed, np.random.seed):
    _apply_seed(SEED)
del _apply_seed

# ตั้งค่า parameters
Best parameters: https://github.com/KornWtp/ConGen#parameters-1

In [None]:
# --- Input / batching ---
max_seq_length = 128     # handed to models.Transformer when the student is built
train_batch_size = 128   # batch size of the training DataLoader

# --- Optimisation schedule ---
num_epochs = 20
learning_rate = 0.0001         # "lr" entry of optimizer_params in fit()
early_stopping_patience = 7    # forwarded to fit(early_stopping_patience=...)

# --- ConGen loss ---
queue_size = 2 ** 16           # 65536 sentences sampled for the instance queue
student_temp = teacher_temp = 0.5   # temperatures passed to losses.ConGenLoss

# โหลด Teacher model

In [None]:
# Name (or path) of the pretrained checkpoint used as the distillation teacher.
teacher_model_name_or_path = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
# Loaded through the SentenceTransformer class imported from
# sentence_transformers_congen; later cells call teacher_model.encode(...) on it
# to produce the target embeddings.
teacher_model = SentenceTransformer(teacher_model_name_or_path)

# โหลด dataset
Link: https://drive.google.com/file/d/1u7kCk9xpTfQkxpJ0zfILpo9SR5KNMfaj/view?usp=share_link

In [None]:
# Training corpus: one sentence pair per line, the two sentences separated by a tab.
train_data_path = "back_translated_mt_scb_2020.txt"

all_pairs = []
with open(train_data_path, mode="rt", encoding="utf-8") as train_file:
    for line_number, raw_line in enumerate(train_file, start=1):
        fields = raw_line.strip().split('\t')
        if fields == ['']:
            # Blank line (e.g. a trailing newline at end of file): nothing to load.
            continue
        if len(fields) < 2:
            # Fail with the offending location instead of an IndexError further down.
            raise ValueError(
                f"{train_data_path}:{line_number}: expected two tab-separated "
                f"sentences, found {len(fields)} field(s)"
            )
        all_pairs.append(fields)

# Two parallel lists of sentences (first and second column of every pair).
sents1 = [p[0] for p in all_pairs]
sents2 = [p[1] for p in all_pairs]

# Teacher embeddings of sents1 are expensive to compute, so they are cached on disk.
# SECURITY NOTE: pickle.load can execute arbitrary code - only load cache files
# that this notebook produced itself.
encoded_cache_path = "data/sents1_encoded.pkl"
sents1_encoded = None
try:
    with open(encoded_cache_path, "rb") as cache_file:
        sents1_encoded = pickle.load(cache_file)
except FileNotFoundError:
    # No cache yet: the embeddings are computed below.
    pass
except (OSError, EOFError, pickle.UnpicklingError, RuntimeError) as exc:
    # Unreadable/corrupt cache, or tensors that cannot be restored in this
    # runtime: recompute rather than abort. Unlike a bare `except:`, this does
    # not swallow KeyboardInterrupt.
    logging.warning("Ignoring unusable cache %s: %s", encoded_cache_path, exc)
    sents1_encoded = None

if sents1_encoded is not None and len(sents1_encoded) != len(sents1):
    # The cache was built from a different corpus; using it would mis-align
    # sentences and embeddings.
    logging.warning(
        "Cache %s holds %d embeddings but the corpus has %d sentences; re-encoding.",
        encoded_cache_path, len(sents1_encoded), len(sents1),
    )
    sents1_encoded = None

if sents1_encoded is None:
    sents1_encoded = teacher_model.encode(sents1, convert_to_tensor=True, normalize_embeddings=True, device=device)
    # Make sure the cache directory exists so the freshly computed embeddings
    # are not lost to a failing open().
    os.makedirs(os.path.dirname(encoded_cache_path), exist_ok=True)
    with open(encoded_cache_path, "wb") as cache_file:
        pickle.dump(sents1_encoded, cache_file, protocol=4)

# Width of the teacher embeddings; the student's Dense layer projects to this size.
teacher_dimension = sents1_encoded.shape[1]

# โหลด Student model

In [None]:
# Checkpoint used as the student's transformer backbone.
student_model_name_or_path = "airesearch/wangchanberta-base-att-spm-uncased"

# Stage 1: transformer backbone producing token embeddings.
student_word_embedding_model = models.Transformer(
    student_model_name_or_path,
    max_seq_length=max_seq_length,
)
student_dimension = student_word_embedding_model.get_word_embedding_dimension()

# Stage 2: pooling over the token embeddings (library-default pooling settings).
student_pooling_model = models.Pooling(word_embedding_dimension=student_dimension)

# Stage 3: dense projection from the student's width to the teacher's width,
# followed by a Tanh activation.
dense_model = models.Dense(
    in_features=student_dimension,
    out_features=teacher_dimension,
    activation_function=nn.Tanh(),
)

# Chain the three stages into a single sentence-embedding model.
student_model = SentenceTransformer(
    modules=[student_word_embedding_model, student_pooling_model, dense_model]
)

# สร้าง instance queue
instance queue คืออะไร? รายละเอียดอยู่ใน https://github.com/KornWtp/ConGen/blob/main/ConGen__Unsupervised_Control_and_Generalization_Distillation_For_Sentence_Representation.pdf Section ที่ 3.2

In [None]:
# Sentences and teacher embeddings are paired positionally by zip() below, which
# would silently truncate on a length mismatch (e.g. a stale embedding cache).
if len(sents1_encoded) != len(sents1):
    raise ValueError(
        f"sents1_encoded has {len(sents1_encoded)} rows but sents1 has "
        f"{len(sents1)} sentences; re-create the teacher embeddings"
    )

# Draw queue_size positions with a dedicated fixed-seed generator. Sampling
# indices rather than the sentence list itself avoids NumPy first copying the
# whole corpus into a fixed-width unicode array; RandomState.choice permutes
# the population size either way, so the same sentences are selected.
queue_indices = np.random.RandomState(16349).choice(len(sents1), queue_size, replace=False)
text_in_queue = [sents1[i] for i in queue_indices]
text_in_q_set = set(text_in_queue)

# Split the corpus: sampled sentences feed the instance queue, everything else
# becomes a training example labelled with its teacher embedding. Membership is
# decided by sentence text, so every duplicate of a sampled sentence also lands
# in the queue (which may then hold more than queue_size entries).
train_samples = []
instance_queue = []
for s1, s2, s1_encoded in zip(sents1, sents2, sents1_encoded):
    if s1 in text_in_q_set:
        instance_queue.append(s1)
    else:
        train_samples.append(InputExample(texts=[s1, s2], label=s1_encoded))
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)

# Teacher embeddings of the queue sentences, handed to the ConGen loss.
instance_queue_encoded = teacher_model.encode(instance_queue,
                                              convert_to_tensor=True,
                                              normalize_embeddings=True,
                                              device=device)

training_loss = losses.ConGenLoss(instanceQ_encoded=instance_queue_encoded,
                                  model=student_model,
                                  student_temp=student_temp,
                                  teacher_temp=teacher_temp)

# Drop the notebook-level names that are not used again (presumably to release
# memory before training). NOTE: after this, re-running this cell requires
# re-running the cells that create teacher_model and sents1_encoded.
del instance_queue, sents1_encoded, teacher_model, instance_queue_encoded

# Warm up over the first 10% of all optimisation steps (batches per epoch * epochs).
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)


# Train model

In [None]:
# Keyword arguments handed to the optimizer through fit(optimizer_params=...).
optimizer_settings = {"lr": learning_rate, 'eps': 1e-6, 'correct_bias': False}

# Train the student on the ConGen objective. The run uses mixed precision
# (use_amp=True) and writes to "congen-model-thai"; save_best_model and
# early_stopping_patience are forwarded as-is to the fit() implementation of
# the sentence_transformers_congen package.
student_model.fit(
    train_objectives=[(train_dataloader, training_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params=optimizer_settings,
    output_path="congen-model-thai",
    save_best_model=True,
    early_stopping_patience=early_stopping_patience,
    use_amp=True,
)