In [1]:
from bert import BertModel
from distill_emb import DistillEmbSmall
from config import DistilEmbConfig
import torch
from transformers import AutoTokenizer, RwkvConfig, RwkvModel, AutoModel
from tokenizer import CharTokenizer
from knn_classifier import KNNTextClassifier
from data_loader import load_sentiment
from data_loader import load_news_dataset
import pandas as pd
from retrieval import build_json_pairs, top1_accuracy
import os

In [2]:
num_input_chars=12
tokenizer = CharTokenizer(charset_file_path='tokenizer/charset.json',
                          max_word_length=num_input_chars)

In [3]:
config = DistilEmbConfig(
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=9,
    num_attention_heads=8,
    intermediate_size=3072,
    max_position_embeddings=512,
    type_vocab_size=2,
    pad_token_id=0,
    position_embedding_type="absolute",
    use_cache=True,
    classifier_dropout=None,
    embedding_type="distill",  # 'distilemb', 'fasttext'
    encoder_type='lstm',
    num_input_chars=num_input_chars,  # number of characters in each token
    char_vocab_size=tokenizer.char_vocab_size
)
distill_emb = DistillEmbSmall(config)
path = "logs/distill_emb_v0/distill_emb_v0-epoch=510-epoch_val_loss=0.27.ckpt"
if os.path.exists(path):
    state_dict = torch.load(path, map_location='cpu')['state_dict']
    # remove 'model.' prefix from state_dict keys
    state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
    distill_emb.load_state_dict(state_dict)
else:
    print(f"Model checkpoint {path} not found. Please check the path.")

distill_emb = distill_emb.to('cuda').eval()

In [4]:
def eval_model(model, tokenizer):
    model.eval()
    with torch.no_grad():
        classifier = KNNTextClassifier(tokenizer, model=model)

        df, classes = load_sentiment()
        # Sample equal amount for each language in the 'lang' column
        min_count = min(df['lang'].value_counts().min(), 250)
        sent_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)
        sent_train_df = sent_df.sample(frac=0.8, random_state=42)
        sent_test_df = sent_df.drop(sent_train_df.index)
        print(f"train shape: {sent_train_df.shape}, test shape: {sent_test_df.shape}")
        sent_f1, sent_acc, sent_per_lang, sent_test_df = classifier.classifiy(train_df=sent_train_df, test_df=sent_test_df, k=5, batch_size=32, model=None, tokenizer=None)

        df, classes = load_news_dataset()
        min_count = min(df['lang'].value_counts().min(), 250)
        news_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)
        news_train_df = news_df.sample(frac=0.8, random_state=42)
        news_test_df = news_df.drop(news_train_df.index)
        print(f"train shape: {news_train_df.shape}, test shape: {news_test_df.shape}")
        news_f1, news_acc, news_per_lang, news_test_df = classifier.classifiy(train_df=news_train_df, test_df=news_test_df, k=5, batch_size=32, model=None, tokenizer=None)

        df = pd.read_json('downstream-data/news_result.json')
        d = df.to_dict(orient='records')
        ret_acc, _, ret_per_lang = top1_accuracy(d, batch_size=32, model=model, tokenizer=tokenizer)
    
    return sent_acc, news_acc, ret_acc

In [5]:
eval_model(distill_emb, tokenizer)
# (0.45428571428571435, 0.51125, 0.1453125)

Loaded 105862 rows from sentiment.parquet columns Index(['text', 'label', 'lang', 'split'], dtype='object')
train shape: (2800, 4), test shape: (700, 4)


  sent_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Loaded 30809 rows from masakhanews.parquet columns Index(['label', 'headline', 'text', 'headline_text', 'url', 'lang', 'split'], dtype='object')
train shape: (3200, 7), test shape: (800, 7)


  news_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Evaluating:   0%|          | 0/3200 [00:00<?, ?it/s]

(0.44999999999999996, 0.5387499999999998, 0.174375)

In [6]:
class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, **kwargs):
        embs = self.model(**kwargs).last_hidden_state
        return embs

In [7]:
model_name = "castorini/afriberta_small"
tok = AutoTokenizer.from_pretrained(model_name)
xmodel = AutoModel.from_pretrained(model_name)
xmodel = Wrapper(xmodel)
eval_model(xmodel, tok)

Some weights of XLMRobertaModel were not initialized from the model checkpoint at castorini/afriberta_small and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  sent_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Loaded 105862 rows from sentiment.parquet columns Index(['text', 'label', 'lang', 'split'], dtype='object')
train shape: (2800, 4), test shape: (700, 4)
Loaded 30809 rows from masakhanews.parquet columns Index(['label', 'headline', 'text', 'headline_text', 'url', 'lang', 'split'], dtype='object')
train shape: (3200, 7), test shape: (800, 7)


  news_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Evaluating:   0%|          | 0/3200 [00:00<?, ?it/s]

(0.47428571428571425, 0.6525, 0.244375)

In [8]:
model_name = "castorini/afriberta_large"
tok = AutoTokenizer.from_pretrained(model_name)
xmodel = AutoModel.from_pretrained(model_name)
xmodel = Wrapper(xmodel)
eval_model(xmodel, tok)

Some weights of XLMRobertaModel were not initialized from the model checkpoint at castorini/afriberta_large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded 105862 rows from sentiment.parquet columns Index(['text', 'label', 'lang', 'split'], dtype='object')
train shape: (2800, 4), test shape: (700, 4)


  sent_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Loaded 30809 rows from masakhanews.parquet columns Index(['label', 'headline', 'text', 'headline_text', 'url', 'lang', 'split'], dtype='object')
train shape: (3200, 7), test shape: (800, 7)


  news_df = df.groupby('lang').apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)


Evaluating:   0%|          | 0/3200 [00:00<?, ?it/s]

(0.4671428571428571, 0.6612500000000001, 0.285)