In [2]:
!pip install -q fasttext jieba_hant

[0m

In [3]:
!pip install -q -U gensim datasets  # pytorch-nlp torchtext torchdata

[0m

In [4]:
!pip install -q -U tqdm pysnooper

[0m

# FastText Pytorch

In [2]:
from collections import Counter
from gensim.models.fasttext_inner import (
    compute_ngrams,
    compute_ngrams_bytes,
    ft_hash_bytes,
)
import torch
import pickle
# from collections.abc import Iterable  # python >= 3.9
from typing import Iterable  # python < 3.9
from typing import Union, Optional, List, Dict

DEFAULT_RESERVED_TOKENS = ['<pad>', '<unk>', '</s>']
DEFAULT_PAD_INDEX = 0
DEFAULT_UNK_INDEX = 1
DEFAULT_EOS_INDEX = 2

def _tokenize(s):
    return s.split()

def _detokenize(t):
    return ' '.join(t)

class FastTextEncoder():
    """
    Implement char ngram and word ngram.

    Differences:
    - no EOS
    - do not shrink vocab when reach (0.75 * max_vocab_size)
    """
    def __init__(
            self,
            texts: Union[Iterable[List[str]], List[List[str]]],
            min_count=1,
            max_vocab_size=None,
            min_n=0,
            max_n=0,
            word_ngrams=1,
            bucket=2000000):
        """
        Build tokenizer. Require segmented sentences.
        """
        # defaults
        self.reserved_tokens = DEFAULT_RESERVED_TOKENS
        self.pad_index = DEFAULT_PAD_INDEX
        self.unk_index = DEFAULT_UNK_INDEX
        self.eos_index = DEFAULT_EOS_INDEX
        
        self.min_count = min_count
        self.max_vocab_size = max_vocab_size
        self.min_n = min_n
        self.max_n = max_n
        self.word_ngrams = word_ngrams
        self.bucket = bucket
        
        self.tokens = Counter()

        for n, sequence in enumerate(texts):
            self.tokens.update([x.strip() for x in sequence if x.strip() not in ("", None)])
        
        self.corpus_count = n + 1
        self.corpus_total_words = sum(self.tokens.values())
        
        self.index_to_token = self.reserved_tokens.copy()
        self.token_to_index = {token: index for index, token in enumerate(self.reserved_tokens)}
        for token, count in self.tokens.items():
            if count >= self.min_count:
                self.index_to_token.append(token)
                self.token_to_index[token] = len(self.index_to_token) - 1
        
        self.initNgrams()
        
        # release memory
        self.tokens = Counter()
    
    @property
    def vocab(self) -> List:
        return self.index_to_token

    @property
    def vocab_size(self) -> int:
        return len(self.index_to_token)
    
    def initNgrams(self):
        """
        Initialize char ngrams for all vocabularies.
        """
        self.index_to_ngram = [[]] * len(self.index_to_token)  # caveats: all items are actually same object
        if self.max_n >= self.min_n and self.min_n > 0 and self.bucket > 0:
            for n, v in enumerate(self.index_to_token):
                # exclude preserved words
                if n >= len(self.reserved_tokens):
                    self.index_to_ngram[n] = self._compute_ngram_hashes(v, self.min_n, self.max_n, self.bucket)
    
    def decode(self, ids: List[int]) -> List[str]:
        vector = [self.index_to_token[i] for i in ids]
        return vector
    
    def encode(
            self,
            sequence: List[str],
            append_eos: bool=False,
            unk_to_zero: bool=False,
            remove_unk: bool=False) -> List[int]:
        vector = [self.token_to_index.get(token, self.unk_index) for token in sequence]
        if append_eos:
            vector.append(self.eos_index)
        if unk_to_zero:
            vector = [self.pad_index if x == self.unk_index else x for x in vector]
        if remove_unk:
            # but we will need correct length to do ngram hashing
            vector = [x for x in vector if x != self.unk_index]
        return vector
    
    def encode_ngram(self, sequence: List[str], input_ids: List[int]) -> List[int]:
        # char ngram
        char_ngrams = []
        for w, i in zip(sequence, input_ids):
            if i != self.eos_index:
                if i != self.unk_index:
                    char_ngrams += self.index_to_ngram[i]
                else:
                    # oov
                    char_ngrams += self._compute_ngram_hashes(w, self.min_n, self.max_n, self.bucket)
        
        # word ngram
        # we do not care oov, just hash it
        hashes = [self._hash(x.encode("UTF-8")) for x in sequence]
        word_ngrams = self._compute_wordNgram_hashes(hashes, self.word_ngrams, self.bucket)
        return char_ngrams + word_ngrams

    def _encode_ft(self, sequence: List[str]) -> List[int]:
        """
        FastText-like output.
        Ignore UNK. Ngram id goes after end of wid.
        """
        list_wids = []
        list_hashes = []
        list_ngrams = []  # here we separate this out
        for w in sequence:
            wid = self.token_to_index.get(w, self.unk_index)
            h = self._hash(w.encode("UTF-8"))
            list_hashes.append(h)
            if wid != self.unk_index:
                list_wids.append(wid)
                list_ngrams += self.index_to_ngram[wid]
            else:
                # oov
                list_ngrams += self._compute_ngram_hashes(w, self.min_n, self.max_n, self.bucket)
        # word ngrams
        list_ngrams += self._compute_wordNgram_hashes(list_hashes, self.word_ngrams, self.bucket)
        
        list_ngrams = [x + self.vocab_size for x in list_ngrams]
        return list_wids + list_ngrams
    
    def _batch_encode(
            self,
            texts: Union[Iterable[List[str]], List[List[str]]],
            # padding=False,
            # truncation=False,
            # max_length=None,
            return_tensors: Optional[str]=None,
            unk_to_zero: bool=False) -> Dict:
        """
        Separate wid and ngram.
        """
        input_ids = [self.encode(x, unk_to_zero=unk_to_zero) for x in texts]
        input_ngrams = [self.encode_ngram(x, y) for x, y in zip(texts, input_ids)] ### what if unk == 0???
        
        if return_tensors == "pt":
            input_ids = self._pad_sequence_pt(input_ids)
            input_ngrams = self._pad_sequence_pt(input_ngrams)
        
        output = {
            "input_ids": input_ids,
            "input_ngrams": input_ngrams
        }
        return output
    
    def _batch_encode_ft(
            self,
            texts: Union[Iterable[List[str]], List[List[str]]],
            # padding=False,
            # truncation=False,
            # max_length=None,
            return_tensors: Optional[str]=None) -> Dict:
        input_ids = [self._encode_ft(x) for x in texts]
        len_ids = [len(x) for x in input_ids]
        if return_tensors == "pt":
            input_ids = self._pad_sequence_pt(input_ids)
            len_ids = torch.IntTensor(len_ids)
        output = {
            "input_ids": input_ids,
            "len_ids": len_ids
        }
        return output
    
    def __call__(
            self,
            texts: Union[Iterable[List[str]], List[List[str]]],
            ft_mode: bool=False,
            **kwargs) -> Dict:
        if ft_mode:
            return self._batch_encode_ft(texts, **kwargs)
        else:
            return self._batch_encode(texts, **kwargs)
    
    def __contains__(self, word):
        return word in self.vocab
    
    def get_vector(self, embedding: torch.nn.Embedding, word: str):
        wid = self.token_to_index.get(word, self.unk_index)
        if wid != self.unk_index:
            return embedding.weight[wid]
        elif self.max_n >= self.min_n and self.min_n > 0 and self.bucket > 0:
            # oov and ngram is enabled
            ngrams = self._compute_ngram_hashes(word, self.min_n, self.max_n, self.bucket)
            if len(ngrams) == 0:
                return embedding.weight[0]  # PAD and it's zeros
            ngrams = [x + self.vocab_size for x in ngrams]
            return torch.mean(embedding.weight[ngrams], dim=0)
        else:
            raise KeyError("cannot calculate vector for OOV word without ngrams")
    
    def get_sentence_vector(self, embedding: torch.nn.Embedding, sentence: List[str]):
        input_ids = self._encode_ft(sentence)
        return torch.mean(embedding.weight[input_ids], dim=0)
    
    def save_vocab(self, fout):
        with open(fout, "wb") as f:
            pickle.dump(self.vocab, f)
    
    @classmethod
    def load_vocab(cls, fin, **kwargs):
        with open(fin, "rb") as f:
            vocab = pickle.load(f)
        tokenizer = cls(vocab, **kwargs)
        tokenizer.corpus_count = None
        tokenizer.corpus_total_words = None
        return tokenizer
    
    @classmethod
    def _hash(cls, bytez: bytes) -> int:
        return ft_hash_bytes(bytez)
    
    @classmethod
    def _compute_ngrams(cls, word: str, min_n, max_n) -> List[str]:
        if max_n >= min_n and min_n > 0:
            return compute_ngrams(word, min_n, max_n)
        else:
            return []
    
    @classmethod
    def _compute_ngrams_bytes(cls, word: str, min_n, max_n) -> List[bytes]:
        if max_n >= min_n and min_n > 0:
            return compute_ngrams_bytes(word, min_n, max_n)
        else:
            return []
    
    @classmethod
    def _compute_ngram_hashes(cls, word: str, min_n, max_n, bucket) -> List[int]:
        if max_n >= min_n and min_n > 0 and bucket > 0:
            hashes = cls._compute_ngrams_bytes(word, min_n, max_n)
            return [cls._hash(x) % bucket for x in hashes]
        else:
            return []
    
    @classmethod
    def _compute_wordNgram_hashes(cls, word_hashes: List[int], word_ngrams: int, bucket: int) -> List[int]:
        """
        ref: https://github.com/facebookresearch/fastText/blob/a20c0d27cd0ee88a25ea0433b7f03038cd728459/src/dictionary.cc#L312
        """
        wordNgram_hashes = []
        if bucket > 0:
            for i in range(len(word_hashes)):
                h = word_hashes[i]
                for j in range(i+1, min(len(word_hashes), i+word_ngrams)):
                    h = h * 116049371 + word_hashes[j]
                    wordNgram_hashes.append(h % bucket)
        return wordNgram_hashes
    
    @classmethod
    def _pad_sequence_pt(cls, seq):
        return torch.nn.utils.rnn.pad_sequence([torch.IntTensor(x) for x in seq], batch_first=True)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch
import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
from dataclasses import dataclass

@dataclass
class FastTextClassifierConfig():
    vocab_size: int = 0
    min_n: int = 2
    max_n: int = 5
    word_ngrams: int = 1
    dim: int = 100
    bucket: int = 2000000
    lr: float = 0.1
    lrUpdateRate: int = 100  # update lr by n tokens, here we update by batch
    num_classes: int = 1
    epoch: int = 5
    batch_size: int = 256
    
class FastTextClassifier(nn.Module):
    def __init__(self, config):
        super(FastTextClassifier, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size + config.bucket, config.dim, padding_idx=0)
        self.fc1 = nn.Linear(config.dim, config.num_classes)
    
    def forward(self, input_ids):
        ntokens = torch.count_nonzero(input_ids)
        output = self.embedding(input_ids)
        output = torch.sum(output, 1) / ntokens.view([-1, 1])
        output = self.fc1(output)
        return output

In [4]:
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# @pysnooper.snoop()
def train():
    model.train()
    criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.SGD(model.parameters(), config.lr)
    # scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0, total_iters=config.epoch * len(trainloader))
    optimizer = torch.optim.Adam(model.parameters(), config.lr)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0, total_iters=config.epoch)
    
    EPOCH_BEG = 1
    for epoch in range(EPOCH_BEG, config.epoch+1):
        running_loss = 0.
        for data in (pbar:=tqdm(trainloader, desc=f"[epoch {epoch:>2}]")):
            # pbar.set_postfix({"lr": scheduler.get_last_lr()[0]})
            labels, input_ids = data
            optimizer.zero_grad()
            output = model(input_ids=input_ids)
            loss = criterion(output, labels)
            pbar.set_postfix({"loss": loss.item(), "lr": scheduler.get_last_lr()[0]})
            running_loss += loss.item()
            loss.backward()
            optimizer.step()
            # scheduler.step()  # update by batch
        print(f"""[epoch {epoch:>2}] train loss: {running_loss/len(trainloader):.3f} lr: {scheduler.get_last_lr()[0]:.3f}""")
        scheduler.step()  # update by epoch

def test(dataloader, disable_progress=True):
    # model.evalute()
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        running_loss = 0.
        labels_true = []
        labels_pred = []
        labels_prob = []
        for data in tqdm(dataloader, disable=disable_progress):
            labels, input_ids = data
            output = model(input_ids=input_ids)
            loss = criterion(output, labels)
            running_loss += loss.item()
            output_pred = torch.max(nn.functional.softmax(output, 1), 1)
            labels_pred += output_pred[1].tolist()
            labels_prob += output_pred[0].tolist()
            labels_true += labels.tolist()
    acc = accuracy_score(labels_true, labels_pred)
    p, r, f, _ = precision_recall_fscore_support(labels_true, labels_pred, average="weighted", zero_division=0)  # WAF1-scores
    out = {
        "loss": running_loss/len(dataloader),
        "accuracy": acc,
        "precision": p,
        "recall": r,
        "f1": f
    }
    return out

def evalute(dataloader, disable_progress=True):
    # model.evalute()
    with torch.no_grad():
        labels_pred = []
        labels_prob = []
        for data in tqdm(dataloader, disable=disable_progress):
            _, input_ids = data
            output = model(input_ids=input_ids)
            output_pred = torch.max(nn.functional.softmax(output, 1), 1)
            labels_pred += output_pred[1].tolist()
            labels_prob += output_pred[0].tolist()
    out = {
        "labels_pred": labels_pred,
        "labels_prob": labels_prob
    }
    return out

In [5]:
# from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list = torch.LongTensor([x["label"] for x in batch])
    out = tokenizer([_tokenize(x["text"]) for x in batch], return_tensors="pt", ft_mode=True)
    input_ids = out["input_ids"]
    return label_list.to(device), input_ids.to(device)

In [6]:
from datasets import load_dataset

dataset_name = "ag_news"
train_iter = load_dataset(dataset_name, split="train")
test_iter = load_dataset(dataset_name, split="test")

config = FastTextClassifierConfig(
    num_classes=4,
    batch_size=256,
    lr=0.5,
    min_n=2,
    max_n=6,
    word_ngrams=2,
    dim=10,
    bucket=10000
)

train_corpus = [_tokenize(x) for x in train_iter["text"]]
tokenizer = FastTextEncoder(train_corpus, min_n=config.min_n, max_n=config.max_n, word_ngrams=config.word_ngrams, bucket=config.bucket)
config.vocab_size = tokenizer.vocab_size

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trainloader = DataLoader(train_iter, batch_size=config.batch_size, shuffle=False, collate_fn=collate_batch)
testloader = DataLoader(test_iter, batch_size=config.batch_size, shuffle=False, collate_fn=collate_batch)

model = FastTextClassifier(config)
model.to(device)

Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)
Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


FastTextClassifier(
  (embedding): Embedding(198113, 10, padding_idx=0)
  (fc1): Linear(in_features=10, out_features=4, bias=True)
)

In [7]:
torch.set_num_threads(2)

In [8]:
%%time
train()

[epoch  1]:   0%|          | 0/469 [00:00<?, ?it/s]

[2022-09-02 09:41:39.422 pytorch-1-10-cpu-py38-ml-t3-medium-11e7d720a60c1349ee834b037bd3:1296 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None




[2022-09-02 09:41:39.614 pytorch-1-10-cpu-py38-ml-t3-medium-11e7d720a60c1349ee834b037bd3:1296 INFO profiler_config_parser.py:111] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.


[epoch  1]: 100%|██████████| 469/469 [01:08<00:00,  6.82it/s, loss=0.354, lr=0.5]


[epoch  1] train loss: 0.487 lr: 0.500


[epoch  2]: 100%|██████████| 469/469 [01:06<00:00,  7.06it/s, loss=0.244, lr=0.4]


[epoch  2] train loss: 0.250 lr: 0.400


[epoch  3]: 100%|██████████| 469/469 [01:05<00:00,  7.14it/s, loss=0.179, lr=0.3] 


[epoch  3] train loss: 0.199 lr: 0.300


[epoch  4]: 100%|██████████| 469/469 [01:06<00:00,  7.08it/s, loss=0.155, lr=0.2] 


[epoch  4] train loss: 0.164 lr: 0.200


[epoch  5]: 100%|██████████| 469/469 [01:05<00:00,  7.13it/s, loss=0.137, lr=0.1] 

[epoch  5] train loss: 0.141 lr: 0.100
CPU times: user 6min 51s, sys: 12.1 s, total: 7min 3s
Wall time: 5min 32s





In [9]:
test(testloader)

{'loss': 0.27932138641675314,
 'accuracy': 0.9089473684210526,
 'precision': 0.9092335200569491,
 'recall': 0.9089473684210526,
 'f1': 0.908725652790745}

TODO:
- remove `<unk>` in output? 最後平均時unk會被加進去？
- stochastic gradient descent and a linearly decaying learning rate. 為什麼沒用？
- Hierarchical softmax
- ngram index 排除0 `<pad>`，加offset？
- 吃跟ft 一樣的input format？
- 

In [16]:
emb = nn.Embedding.from_pretrained(model.embedding.weight, padding_idx=0)

In [17]:
emb

Embedding(198113, 10, padding_idx=0)

In [14]:
model.embedding.

torch.Size([198113, 10])

# FastText original

In [2]:
import fasttext

In [3]:
from datasets import load_dataset

dataset_name = "ag_news"
train_iter = load_dataset(dataset_name, split="train")
test_iter = load_dataset(dataset_name, split="test")

def _to_file(it, file):
    with open(file, "w") as f:
        for i in it:
            print(f"""__label__{i["label"]} {i["text"]}""", file=f)

_to_file(train_iter, "ft_train.txt")
_to_file(test_iter, "ft_test.txt")

  from .autonotebook import tqdm as notebook_tqdm
Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)
Using custom data configuration default
Reusing dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


In [6]:
%%time
model = fasttext.train_supervised(
    input="ft_train.txt",
    dim=10,
    epoch=5,
    wordNgrams=2,
    minn=2,
    maxn=6,
    lr=0.5,
    bucket=10000,
    # autotuneValidationFile="labels.test.txt",
    # autotuneDuration=600
)

CPU times: user 14.5 s, sys: 172 ms, total: 14.7 s
Wall time: 15.7 s


In [7]:
model.test("ft_test.txt")

(7600, 0.9030263157894737, 0.9030263157894737)