In [20]:
import json
from pathlib import Path
from collections import Counter

from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

In [21]:
ROOT_DIR = Path.cwd().parent

In [22]:
df_qiita = pd.read_json(ROOT_DIR / "data/qiita.json")
df_zenn = pd.read_json(ROOT_DIR / "data/zenn.json")

In [23]:
df_raw = pd.concat([df_qiita, df_zenn], axis=0)

In [24]:
df_raw.sample(3)

Unnamed: 0,title,tags,cleansed_content,url
401,Androidアプリエンジニアがよく見るサイトとチュートリアル,"[Android, Androidアプリ]",Android アプリエンジニアがよく見るであろうサイトや、Android アプリ開発を勉強...,https://qiita.com/Nabe1216/items/792914ae6803c...
631,「相関係数よ、安らかに眠れ」～ 新たなスコアPPSの紹介,"[Python, EDA, PPS]",### ～8080Labsが考案した新スコア「PPS」は、相関係数を凌駕できるだろうか？！...,https://qiita.com/hima2b4/items/2b500886512f14...
738,スケールする要求を支える仕様の「意図」と「直交性」,"[DDD, 要件定義]",# はじめに どんなソフトウェアエンジニアも拡張しやすくメンテナンスしやすいソフトウェアを...,https://qiita.com/hirokidaichi/items/61ad129ea...


In [25]:
all_tags = []

for tags in df_raw["tags"]:
    all_tags += tags

In [26]:
len(set(all_tags))

1616

In [27]:
tag_counter = Counter(all_tags)

In [28]:
categories_unique = [tag for tag, _ in tag_counter.most_common(300)]
categories_ids = {tag: idx for idx, tag in enumerate(categories_unique)}

In [29]:
df = pd.DataFrame(
    columns=["title", "body", "categories", "category_ids"],
)
idx = 0
for row in df_raw.itertuples():
    categories = []
    for tag in row.tags:
        if tag in categories_unique:
            categories.append(tag)
    df.at[row.Index, "categories"] = categories
    df.at[row.Index, "category_ids"] = [categories_ids[tag] for tag in categories]
    df.at[row.Index, "title"] = row.title
    df.at[row.Index, "body"] = row.cleansed_content

In [30]:
df.sample(3)

Unnamed: 0,title,body,categories,category_ids
406,開発プロセスはアウトカムの質を高めるためにつくられるべき,## メトリクスは指標であって方針ではない Four Keysとか指標として計測して観察する...,"[スクラム, idea]","[108, 6]"
180,Dagger Go SDK でポータブルな CI/CD パイプラインを構築する,CI/CD Advent Calendar 2022 の 20 日目の記事です。 < 先日...,"[Go, tech]","[23, 0]"
681,良い質問をして自己成長に繋げるためのあれこれ,# 質問に関することで悩む若手エンジニアは多い * 「ちゃんと調べてから質問した？」 * ...,"[質問, コミュニケーション, 初心者]","[212, 71, 3]"


In [31]:
train_df, eval_df = train_test_split(df, train_size=0.8)
eval_df, test_df = train_test_split(eval_df, train_size=0.5)

In [13]:
class ArticleDataset(Dataset):
    def __init__(self, df):
        self.features = [
            {"title": row.title, "category_ids": row.category_ids}
            for row in tqdm(df.itertuples(), total=df.shape[0])
        ]

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx]

In [14]:
train_dataset = ArticleDataset(train_df)
eval_dataset = ArticleDataset(eval_df)
test_dataset = ArticleDataset(test_df)

100%|██████████| 782/782 [00:00<00:00, 137316.66it/s]
100%|██████████| 98/98 [00:00<00:00, 170982.44it/s]
100%|██████████| 98/98 [00:00<00:00, 135926.52it/s]


In [15]:
class ArticleCollator:
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        examples = {
            "title": list(map(lambda x: x["title"], examples)),
            "category_id": list(map(lambda x: x["category_id"], examples)),
        }

        encodings = self.tokenizer(
            examples["title"],
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        encodings["category_id"] = torch.tensor(examples["category_id"])
        return encodings

In [18]:
tokenizer = AutoTokenizer.from_pretrained("MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
article_collator = ArticleCollator(tokenizer, max_length=20000)

In [19]:
loader = DataLoader(
    train_dataset,
    collate_fn=article_collator,
    batch_size=8,
    shuffle=True,
)

batch = next(iter(loader))

for k, v in batch.items():
    print(k, v.shape)

ValueError: expected sequence of length 4 at dim 1 (got 5)