In [None]:
import json
from pathlib import Path
from collections import Counter

from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

In [None]:
# Project root = parent of the current working directory.
# NOTE(review): presumably the notebook is launched from a direct subfolder
# (e.g. notebooks/) — TODO confirm, otherwise data paths below break.
ROOT_DIR = Path().absolute().parent

In [None]:
# Load the raw Qiita and Zenn article dumps stored under ROOT_DIR/data.
df_qiita = pd.read_json(ROOT_DIR.joinpath("data/qiita.json"))
df_zenn = pd.read_json(ROOT_DIR.joinpath("data/zenn.json"))

In [None]:
# Stack both sources row-wise. ignore_index=True renumbers rows 0..n-1 so every
# row has a unique label: if both frames carry default 0-based indices (likely
# for list-of-records JSON — confirm), the labels would otherwise collide, and
# later cells that key rows by df_raw's index would overwrite Qiita rows with
# Zenn rows.
df_raw = pd.concat([df_qiita, df_zenn], axis=0, ignore_index=True)

In [None]:
# Peek at three random raw articles (different rows on every run).
df_raw.sample(n=3)

In [None]:
# Flatten every article's tag list into one list (duplicates kept, so it can be counted).
all_tags = [tag for article_tags in df_raw["tags"] for tag in article_tags]

In [None]:
# Number of distinct tags across the whole corpus.
len({*all_tags})

In [None]:
# Frequency of each tag over all articles.
tag_counter = Counter()
tag_counter.update(all_tags)

In [None]:
# Label vocabulary: the NUM_CATEGORIES most frequent tags, in descending frequency.
NUM_CATEGORIES = 300
categories = [tag for tag, _ in tag_counter.most_common(NUM_CATEGORIES)]
# Map each tag to its integer label id (its position in `categories`).
categories_id = {tag: idx for idx, tag in enumerate(categories)}
assert len(categories_id) == len(categories), "duplicate tags in label vocabulary"

In [None]:
# Build the modelling frame in one shot from a list of records instead of
# growing an empty DataFrame one cell at a time with df.at (slow, and with
# duplicate df_raw index labels later rows silently overwrite earlier ones).
records = []
for row in df_raw.itertuples():
    # Keep only tags that are in the label vocabulary; dict membership is O(1),
    # whereas scanning the `categories` list costs O(len(categories)) per tag.
    category = [tag for tag in row.tags if tag in categories_id]
    records.append(
        {
            "title": row.title,
            "body": row.cleansed_content,
            "category": category,
            "category_id": [categories_id[tag] for tag in category],
        }
    )

df = pd.DataFrame(
    records,
    index=df_raw.index,
    columns=["title", "body", "category", "category_id"],
)
# NOTE(review): articles whose tags all fall outside the vocabulary end up with
# empty label lists — decide whether they should be kept or dropped.
assert len(df) == len(df_raw)

In [None]:
# Peek at three random rows of the modelling frame (different rows on every run).
df.sample(n=3)

In [None]:
# 80/10/10 train/eval/test split. A fixed random_state keeps the split
# identical across kernel restarts, so metrics are comparable run to run.
SPLIT_SEED = 42
train_df, holdout_df = train_test_split(df, train_size=0.8, random_state=SPLIT_SEED)
eval_df, test_df = train_test_split(holdout_df, train_size=0.5, random_state=SPLIT_SEED)
assert len(train_df) + len(eval_df) + len(test_df) == len(df)

In [None]:
class ArticleDataset(Dataset):
    """Map-style dataset of ``{'title', 'category_id'}`` dicts built from a DataFrame.

    Every row is materialised eagerly (with a tqdm progress bar) into a list of
    dicts; ``__getitem__`` returns the stored dict for the given position.
    """

    def __init__(self, df):
        progress = tqdm(df.itertuples(), total=df.shape[0])
        collected = []
        for record in progress:
            collected.append({'title': record.title, 'category_id': record.category_id})
        self.features = collected

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx]

In [None]:
train_dataset = ArticleDataset(train_df)
eval_dataset = ArticleDataset(eval_df)
test_dataset = ArticleDataset(test_df)

In [None]:
class ArticleCollator:
    """Batch dataset items into tokenizer encodings plus a ``category_id`` tensor.

    Args:
        tokenizer: callable tokenizer applied to the list of titles with
            ``padding=True``, ``truncation=True`` and ``return_tensors='pt'``.
        max_length: truncation length forwarded to the tokenizer.
        num_labels: optional size of the label vocabulary. When given, each
            example's ``category_id`` list of label indices becomes a multi-hot
            float vector of this length, so examples with different numbers of
            labels (including none) can share a batch. When ``None`` (original
            behaviour) the lists go straight to ``torch.tensor``, which only
            succeeds if every list in the batch has the same length.
    """

    def __init__(self, tokenizer, max_length=512, num_labels=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_labels = num_labels

    def __call__(self, examples):
        titles = [example['title'] for example in examples]
        category_ids = [example['category_id'] for example in examples]

        encodings = self.tokenizer(titles,
                                   padding=True,
                                   truncation=True,
                                   max_length=self.max_length,
                                   return_tensors='pt')
        encodings['category_id'] = self._encode_labels(category_ids)
        return encodings

    def _encode_labels(self, category_ids):
        """Convert per-example label-index lists into one batch tensor."""
        if self.num_labels is None:
            return torch.tensor(category_ids)
        # Multi-hot (batch, num_labels) float targets, e.g. for BCEWithLogitsLoss.
        labels = torch.zeros((len(category_ids), self.num_labels), dtype=torch.float)
        for row, label_ids in enumerate(category_ids):
            for label_id in label_ids:
                labels[row, label_id] = 1.0
        return labels

In [None]:
# Tokenizer of the multilingual mDeBERTa-v3 checkpoint named below.
# NOTE(review): with its default arguments ArticleCollator passes the
# variable-length category_id lists to torch.tensor, which fails for batches
# whose examples carry different numbers of tags — TODO switch to multi-hot labels.
tokenizer = AutoTokenizer.from_pretrained(
    "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
)
article_collator = ArticleCollator(tokenizer=tokenizer)