In [1]:
from bert import BertModel
from distill_emb import DistillEmb
from config import DistillModelConfig, DistillEmbConfig
import torch
from transformers import AutoTokenizer, RwkvConfig, RwkvModel, AutoModel
from tokenizer import CharTokenizer
from knn_classifier import KNNTextClassifier
from data_loader import load_sentiment
from data_loader import load_news_dataset
import pandas as pd
from retrieval import build_json_pairs, top1_accuracy
import os
from fasttext_model import FastTextModel

In [2]:
num_input_chars=12
tokenizer = CharTokenizer(charset_file_path='tokenizer/charset.json',
                          max_word_length=num_input_chars)

In [3]:
config = DistillModelConfig(
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=9,
    num_attention_heads=8,
    intermediate_size=3072,
    max_position_embeddings=512,
    type_vocab_size=2,
    pad_token_id=0,
    position_embedding_type="absolute",
    use_cache=True,
    classifier_dropout=None,
    embedding_type="distill",  # 'distilemb', 'fasttext'
    encoder_type='bert',
    num_input_chars=num_input_chars,  # number of characters in each token
    char_vocab_size=tokenizer.char_vocab_size,
    distil_config=DistillEmbConfig(
        num_input_chars=tokenizer.max_word_length,  # number of characters in each token
        char_vocab_size=tokenizer.char_vocab_size,
        size="small",
        distill_dropout=0.1,
    )
)

In [4]:
def eval_model(model=None, tokenizer=None, pipeline=None):
    # assert that either pipeline or model and tokenizer are provided
    if pipeline is None:
        assert model is not None and tokenizer is not None, "Either pipeline or model and tokenizer must be provided."
    
    with torch.no_grad():
        classifier = KNNTextClassifier(tokenizer=tokenizer, model=model, pipeline=pipeline)

        df, classes = load_sentiment()
        # Sample equal amount for each language in the 'lang' column
        min_count = min(df['lang'].value_counts().min(), 250)
        sent_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)
        sent_train_df = sent_df.sample(frac=0.8, random_state=42)
        sent_test_df = sent_df.drop(sent_train_df.index)
        print(f"train shape: {sent_train_df.shape}, test shape: {sent_test_df.shape}")
        sent_f1, sent_acc, sent_per_lang, sent_test_df = classifier.classifiy(train_df=sent_train_df, test_df=sent_test_df, k=5, batch_size=32, model=None, tokenizer=None)

        df, classes = load_news_dataset()
        min_count = min(df['lang'].value_counts().min(), 250)
        news_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)
        news_train_df = news_df.sample(frac=0.8, random_state=42)
        news_test_df = news_df.drop(news_train_df.index)
        print(f"train shape: {news_train_df.shape}, test shape: {news_test_df.shape}")
        news_f1, news_acc, news_per_lang, news_test_df = classifier.classifiy(train_df=news_train_df, test_df=news_test_df, k=5, batch_size=32, model=None, tokenizer=None)

        df = pd.read_json('downstream-data/news_result.json')
        d = df.to_dict(orient='records')
        ret_acc, _, ret_per_lang = top1_accuracy(d, batch_size=32, model=model, tokenizer=tokenizer, pipeline=pipeline)
    
    return sent_acc, news_acc, ret_acc

In [None]:
distill_emb = DistillEmb(config.distil_config)
path = "logs/distill_emb_v0/distill_emb_v0-epoch=136-epoch_val_loss=0.27.ckpt"
if os.path.exists(path):
    state_dict = torch.load(path, map_location='cpu')['state_dict']
    # remove 'model.' prefix from state_dict keys
    state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
    distill_emb.load_state_dict(state_dict)
else:
    print(f"Model checkpoint {path} not found. Please check the path.")

distill_emb = distill_emb.to('cuda').eval()
eval_model(distill_emb, tokenizer)
# (0.45428571428571435, 0.51125, 0.1453125)

Model checkpoint logs/distill_emb_v0/distill_emb_v0-epoch=136-epoch_val_loss=0.27x.ckpt not found. Please check the path.
Loaded 105862 rows from sentiment.parquet columns Index(['text', 'label', 'lang', 'split'], dtype='object')
train shape: (2800, 4), test shape: (700, 4)


  sent_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Loaded 30809 rows from masakhanews.parquet columns Index(['label', 'headline', 'text', 'headline_text', 'url', 'lang', 'split'], dtype='object')
train shape: (3200, 7), test shape: (800, 7)


  news_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Evaluating:   0%|          | 0/3200 [00:00<?, ?it/s]

(0.4085714285714286, 0.42375, 0.12375)

In [5]:
fasttext_model = FastTextModel(file_path='embeddings/afriberta/afriberta.vec')
class FastTextTokenizer:

    def __init__(self, word2id):
        self.word2id = word2id
        
    def __call__(self, texts, **kwargs):
        tokens = []
        for text in texts:
            text_tokens = []
            for word in text.split():
                if word in self.word2id:
                    text_tokens.append(self.word2id[word])
                else:
                    text_tokens.append(self.word2id['<unk>'])
            tokens.append(text_tokens)
        return tokens

fasttext_tokenizer = FastTextTokenizer(fasttext_model.word2id)


['885388', '512']
['0.44083', '-0.42965', '0.95103', '-0.36828', '0.13706', '-0.56783', '0.7566', '-0.32185', '-0.47046', '0.54537', '-0.56555', '0.2719', '-0.27293', '-0.44355', '0.34395', '-0.24223', '0.78267', '-0.20981', '0.1417', '-0.43735', '-0.32911', '-0.069809', '0.25748', '0.11058', '-0.069525', '-0.47525', '0.4594', '-0.32668', '-0.15052', '-0.24592', '0.56542', '0.56852', '-0.59562', '-0.088829', '-0.02957', '0.025426', '-0.98067', '-0.64689', '0.37806', '0.23468', '0.21492', '-0.64253', '-0.13397', '-0.15448', '-0.33669', '-0.40539', '-0.42543', '-0.27762', '-0.19928', '0.44423', '0.13137', '0.17243', '0.50332', '-0.0012694', '-0.43177', '-0.29488', '0.7005', '0.12004', '0.39133', '-0.46595', '-0.055164', '-0.436', '-0.06664', '-0.51853', '-0.072684', '0.65205', '0.46941', '-0.89723', '0.14686', '-0.20592', '-0.40505', '-0.01811', '0.097973', '-0.24806', '-0.4467', '-0.2761', '0.61319', '-0.14258', '0.21318', '-0.22695', '0.31721', '0.18999', '-0.22422', '-0.050228', '0.35

In [6]:
def pipeline(texts: list):
    """
    A simple pipeline that takes a list of texts and returns their embeddings.
    """
    tokens: list[list[int]] = fasttext_tokenizer(texts)
    # Pad tokens to the longest sequence in the batch
    max_len = min(max(len(seq) for seq in tokens), 512)  # Limit to 512 for compatibility
    tokens = [seq[:max_len] for seq in tokens]  # Truncate sequences longer than max_len
    pad_id = 0
    tokens = [seq + [pad_id] * (max_len - len(seq)) for seq in tokens]
    # print([len(seq) for seq in tokens])  # Debug: print lengths of token sequences
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in tokens]
    
    tokens = torch.tensor(tokens, dtype=torch.long)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long)
    
    with torch.no_grad():
        embs = fasttext_model(tokens, attention_mask, pool=True)
    return embs

eval_model(fasttext_model, fasttext_tokenizer, pipeline=pipeline)

Loaded 105862 rows from sentiment.parquet columns Index(['text', 'label', 'lang', 'split'], dtype='object')
train shape: (2800, 4), test shape: (700, 4)


  sent_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Loaded 30809 rows from masakhanews.parquet columns Index(['label', 'headline', 'text', 'headline_text', 'url', 'lang', 'split'], dtype='object')
train shape: (3200, 7), test shape: (800, 7)


  news_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Evaluating:   0%|          | 0/3200 [00:00<?, ?it/s]

(0.45285714285714296, 0.7137500000000001, 0.2325)

In [8]:
class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, **kwargs):
        embs = self.model(**kwargs).last_hidden_state
        return embs

In [9]:
model_name = "castorini/afriberta_small"
tok = AutoTokenizer.from_pretrained(model_name)
xmodel = AutoModel.from_pretrained(model_name)
xmodel = Wrapper(xmodel)
eval_model(xmodel, tok)

Some weights of XLMRobertaModel were not initialized from the model checkpoint at castorini/afriberta_small and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  sent_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Loaded 105862 rows from sentiment.parquet columns Index(['text', 'label', 'lang', 'split'], dtype='object')
train shape: (2800, 4), test shape: (700, 4)
Loaded 30809 rows from masakhanews.parquet columns Index(['label', 'headline', 'text', 'headline_text', 'url', 'lang', 'split'], dtype='object')
train shape: (3200, 7), test shape: (800, 7)


  news_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Evaluating:   0%|          | 0/3200 [00:00<?, ?it/s]

(0.47428571428571425, 0.6525, 0.244375)

In [10]:
model_name = "castorini/afriberta_large"
tok = AutoTokenizer.from_pretrained(model_name)
xmodel = AutoModel.from_pretrained(model_name)
xmodel = Wrapper(xmodel)
eval_model(xmodel, tok)

Some weights of XLMRobertaModel were not initialized from the model checkpoint at castorini/afriberta_large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  sent_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Loaded 105862 rows from sentiment.parquet columns Index(['text', 'label', 'lang', 'split'], dtype='object')
train shape: (2800, 4), test shape: (700, 4)
Loaded 30809 rows from masakhanews.parquet columns Index(['label', 'headline', 'text', 'headline_text', 'url', 'lang', 'split'], dtype='object')
train shape: (3200, 7), test shape: (800, 7)


  news_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Evaluating:   0%|          | 0/3200 [00:00<?, ?it/s]

(0.4671428571428571, 0.6612500000000001, 0.285)