In [None]:
from bert import BertModel
from distill_emb import DistillEmbSmall, DistillEmb
from config import DistillModelConfig, DistillEmbConfig
import torch
from transformers import AutoTokenizer, RwkvConfig, RwkvModel, AutoModel
from tokenizer import CharTokenizer
from knn_classifier import KNNTextClassifier
from data_loader import load_sentiment
from data_loader import load_news_dataset
import pandas as pd
from retrieval import build_json_pairs, top1_accuracy
import os

In [None]:
# Number of character slots per token; shared by the tokenizer and the model config below.
num_input_chars = 12

In [None]:
# Character-level tokenizer; its maximum word length is tied to `num_input_chars`.
tokenizer = CharTokenizer(
    charset_file_path='tokenizer/charset.json',
    max_word_length=num_input_chars,
)

In [None]:
# Encoder configuration: a BERT-type encoder that uses the character-based "distill" embedding.
config = DistillModelConfig(
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=9,
    num_attention_heads=8,
    intermediate_size=3072,
    max_position_embeddings=512,
    type_vocab_size=2,
    pad_token_id=0,
    position_embedding_type="absolute",
    use_cache=True,
    classifier_dropout=None,
    embedding_type="distill",  # 'distilemb', 'fasttext'
    encoder_type='bert',
    num_input_chars=num_input_chars,  # number of characters in each token
    char_vocab_size=tokenizer.char_vocab_size,
    distil_config=DistillEmbConfig(
        num_input_chars=tokenizer.max_word_length,  # number of characters in each token
        char_vocab_size=tokenizer.char_vocab_size,
        size="small",
        distill_dropout=0.1,
    )
)
model = BertModel(config)

# Smoke-test batch of random character ids with shape (B, S, N).
# The ids must be valid character-vocabulary indices, so the upper bound is the
# character vocabulary size rather than the number of characters per token.
char_input = torch.randint(0, tokenizer.char_vocab_size, (1, 10, config.num_input_chars))
print("char_input shape:", char_input.shape)
inputs = {
    "input_ids": char_input,
    "attention_mask": torch.tensor([[1] * char_input.size(1)]),  # attention mask for each token
    "token_type_ids": torch.tensor([[0] * char_input.size(1)]),  # token type ids for each token
}
outputs = model(**inputs)

In [None]:
# Smoke-test check: shape of the first element returned by the model
# (presumably the per-token hidden states — TODO confirm against BertModel).
outputs[0].shape

In [None]:
# Instantiate the standalone embedding model and, if available, restore its trained weights.
distill_emb = DistillEmb(config.distil_config)
path = "logs/distill_emb_v0/distill_emb_v0-epoch=95-epoch_val_loss=0.06.ckpt"
if os.path.exists(path):
    # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
    state_dict = torch.load(path, map_location='cpu')['state_dict']
    # Strip only a *leading* 'model.' prefix from the keys. A plain str.replace would also
    # remove 'model.' occurring deeper inside a key and corrupt the parameter names.
    ckpt_prefix = 'model.'
    state_dict = {
        (k[len(ckpt_prefix):] if k.startswith(ckpt_prefix) else k): v
        for k, v in state_dict.items()
    }
    distill_emb.load_state_dict(state_dict)
else:
    print(f"Model checkpoint {path} not found. Please check the path.")
    # Make it explicit that every later evaluation would run on untrained weights.
    print("Continuing with randomly initialized DistillEmb weights.")

In [None]:
# Display the module repr to inspect the embedding model's structure.
distill_emb

In [None]:
# Tokenize a sample phrase without special tokens, returning PyTorch tensors.
out = tokenizer(
    'hello world',
    add_special_tokens=False,
    return_tensors='pt',
)

In [None]:
# Run the embedding model on the first entry of the tokenized `input_ids` and show the output shape.
distill_emb(out['input_ids'][0]).shape

In [None]:
# Use the GPU when one is present; fall back to CPU instead of failing on a hard-coded 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# eval() switches the module to inference mode before it is used for embedding extraction.
distill_emb = distill_emb.to(device).eval()

In [None]:
# KNN text classifier built on the character tokenizer and the DistillEmb model.
classifier = KNNTextClassifier(
    tokenizer,
    model=distill_emb,
)

In [None]:
# Load the sentiment dataset: `df` is sampled into train/test splits below;
# `classes` is presumably the label set — TODO confirm against data_loader.
df, classes = load_sentiment()

In [None]:
# Fixed-seed split: 1000 training rows, then 100 test rows drawn from the remaining rows.
train_df = df.sample(n=1000, random_state=42)
test_df = (
    df.drop(index=train_df.index)
    .sample(n=100, random_state=42)
)

In [None]:
# Evaluate the KNN classifier on the sentiment split; model/tokenizer overrides are left as None.
classifier.classifiy(
    train_df=train_df,
    test_df=test_df,
    k=5,
    batch_size=32,
    model=None,
    tokenizer=None,
)

In [None]:
# Baseline for comparison: pretrained AfroLM tokenizer and encoder, loaded by model name.
model_name = "bonadossou/afrolm_active_learning"
tok, xmodel = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)
class Wrapper(torch.nn.Module):
    """Adapter that reduces a wrapped model's output to its `last_hidden_state`.

    All keyword arguments given to the wrapper are forwarded unchanged to the
    wrapped model; only the `last_hidden_state` attribute of the result is returned.
    """

    def __init__(self, model):
        """Keep a reference to the model whose output is unwrapped in `forward`."""
        super().__init__()
        self.model = model

    def forward(self, **kwargs):
        """Call the wrapped model with `kwargs` and return its `last_hidden_state`."""
        model_output = self.model(**kwargs)
        return model_output.last_hidden_state

# Use the GPU when one is present; fall back to CPU instead of failing on a hard-coded 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
wrapper_model = Wrapper(xmodel).to(device).eval()
# NOTE: this rebinds `classifier`, so any later cell calling `classifier` uses this AfroLM baseline.
classifier = KNNTextClassifier(tokenizer=tok, model=wrapper_model)
classifier.classifiy(train_df=train_df, test_df=test_df, k=5, batch_size=32, model=wrapper_model, tokenizer=tok)

In [None]:
# Load the news dataset; this rebinds `classes` (previously set by the sentiment loader).
# The cells below read the 'split', 'lang' and 'headline' columns of `data`.
data, classes = load_news_dataset()

In [None]:
# Fixed-seed split of the news data: 1000 training rows, then 100 test rows from the remainder.
train_df = data.sample(n=1000, random_state=42)
test_df = (
    data.drop(index=train_df.index)
    .sample(n=100, random_state=42)
)

In [None]:
# Evaluate on the news split with whichever `classifier` is currently bound.
# NOTE(review): in top-to-bottom execution that is the AfroLM-based classifier, not the
# DistillEmb one — confirm this is the intended model for this run.
classifier.classifiy(
    train_df=train_df,
    test_df=test_df,
    k=5,
    batch_size=32,
    model=None,
    tokenizer=None,
)

In [None]:
# Select up to 200 training rows per language, with a fixed seed for reproducibility.
SAMPLES_PER_LANG = 200
train_split = data[data['split'] == 'train']
# Sample each language group explicitly instead of going through groupby.apply: newer pandas
# releases deprecate handing the grouping column ('lang') to apply, which can drop it from
# the result. min(...) keeps languages that have fewer rows than requested instead of raising.
per_lang_samples = [
    group.sample(min(SAMPLES_PER_LANG, len(group)), random_state=42)
    for _, group in train_split.groupby('lang')
]
train_df = (
    pd.concat(per_lang_samples, ignore_index=True)
    if per_lang_samples
    else train_split.iloc[:0].reset_index(drop=True)  # no training rows: keep an empty frame
)

In [None]:
# langs = ['amh', 'hau', 'ibo', 'lug', 'pcm','yor']
# train_df = train_df[train_df['lang'].isin(langs)].reset_index(drop=True)

In [None]:
# Peek at one randomly chosen headline; no random_state is set, so the value changes between runs.
train_df['headline'].sample(1).values[0]

In [None]:
# result = build_json_pairs(train_df, model_name="Davlan/afro-xlmr-large",
#                  n_samples=200, m_candidates=100, k_top=9, text_col="text", headline_col="headline")
# # save to json file
# import json
# with open('news_result.json', 'w', encoding='utf-8') as f:
#     json.dump(result, f, indent=4, ensure_ascii=False)  

In [None]:
# result = build_json_pairs(train_df, model_name="Davlan/afro-xlmr-large",
#                  n_samples=200, m_candidates=100, k_top=9, text_col="headline", headline_col="text")
# # save to json file
# import json
# with open('headline_result.json', 'w', encoding='utf-8') as f:
#     json.dump(result, f, indent=4, ensure_ascii=False)  

In [None]:
# Read the retrieval pairs from 'news_result.json' (the file written by the commented-out
# build_json_pairs cell) and score top-1 retrieval accuracy with the AfroLM encoder.
# Note that this rebinds `df`, which previously held the sentiment data.
df = pd.read_json('news_result.json')
d = df.to_dict('records')
top1_accuracy(
    d,
    batch_size=32,
    model=xmodel,
    tokenizer=tok,
)
# top1_accuracy(d, batch_size=32, model=distill_emb, tokenizer=tokenizer)

In [None]:
# Build the project's FastText model wrapper from the AfriBERTa `.vec` file
# (path is relative to the notebook's working directory).
from fasttext_model import FastTextModel
fasttext_model = FastTextModel(file_path='embeddings/afriberta/afriberta.vec')
# Optional: uncomment to keep the embedding weights fixed during training.
# fasttext_model.embedding.weight.requires_grad = False  # freeze the weights