In [1]:
from fasttext_classifier.model import FastTextClassifier, FastTextClassifierConfig
from fasttext_classifier.encoder import FastTextEncoder
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from pytorch_lightning import Trainer
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor

load dataset

In [2]:
dataset_name = "ag_news"
dataset = load_dataset(dataset_name)

Found cached dataset ag_news (/Users/joseph.lee/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})

configuration

In [4]:
config = FastTextClassifierConfig(
    num_classes=4,
    batch_size=256,
    lr=0.5,
    min_n=2,
    max_n=6,
    word_ngrams=2,
    dim=10,
    bucket=10000,
)

initialize tokenizer

In [5]:
def _tokenize(s):
    return s.split()

def collate_batch(batch):
    label_list = torch.LongTensor([x["label"] for x in batch])
    out = tokenizer(
        [_tokenize(x["text"]) for x in batch], return_tensors="pt", ft_mode=True
    )
    return {
        "label": label_list,
        "input_ids": out["input_ids"]
    }

In [6]:
tokenizer = FastTextEncoder([_tokenize(x) for x in dataset["train"]["text"]], config=config)

In [7]:
config.vocab_size = tokenizer.vocab_size

generate dataloader with tokenizer

In [8]:
trainloader = DataLoader(
    dataset["train"], batch_size=config.batch_size, shuffle=False, collate_fn=collate_batch
)
testloader = DataLoader(
    dataset["test"], batch_size=config.batch_size, shuffle=False, collate_fn=collate_batch
)

initialize model

In [9]:
model = FastTextClassifier(config)

train!

In [10]:
loggers = [
    pl_loggers.TensorBoardLogger(save_dir="./"), 
    pl_loggers.CSVLogger(save_dir="./"),
]
callbacks = [
    LearningRateMonitor(logging_interval='step'),
]

trainer = Trainer(
    accelerator="cpu",
    max_epochs=2,
    logger=loggers,
    callbacks=callbacks,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [11]:
trainer.fit(model, trainloader, testloader)

Loading `train_dataloader` to estimate number of stepping batches.
  rank_zero_warn(

  | Name       | Type                | Params
---------------------------------------------------
0 | criterion  | CrossEntropyLoss    | 0     
1 | embedding  | Embedding           | 2.0 M 
2 | fc1        | Linear              | 44    
3 | val_acc    | MulticlassAccuracy  | 0     
4 | val_prec   | MulticlassPrecision | 0     
5 | val_recall | MulticlassRecall    | 0     
6 | val_f1     | MulticlassF1Score   | 0     
---------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.925     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [12]:
trainer.validate(model, testloader)

Validation: 0it [00:00, ?it/s]

[{'eval:acc': 0.9027631282806396,
  'eval:precision': 0.902481734752655,
  'eval:recall': 0.9027631282806396,
  'eval:f1score': 0.9024980664253235}]

In [13]:
outs = trainer.predict(model, testloader)

  rank_zero_warn(


Predicting: 469it [00:00, ?it/s]

In [14]:
len(outs)

30

In [15]:
outs[0]

{'label': array([2, 3, 3, 1, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3,
        3, 2, 2, 3, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 0, 0, 0, 1, 2, 0, 3,
        1, 1, 0, 2, 0, 1, 0, 1, 0, 3, 2, 1, 2, 0, 2, 2, 1, 1, 1, 3, 0, 3,
        0, 0, 1, 0, 3, 3, 3, 0, 3, 1, 0, 1, 0, 1, 0, 1, 2, 3, 0, 0, 2, 0,
        0, 3, 0, 2, 3, 2, 0, 1, 1, 2, 0, 2, 1, 2, 3, 2, 0, 1, 2, 0, 0, 3,
        3, 3, 3, 3, 3, 1, 3, 3, 2, 1, 2, 1, 3, 0, 3, 3, 0, 1, 1, 0, 0, 0,
        1, 0, 1, 1, 0, 1, 1, 2, 1, 0, 1, 0, 0, 0, 1, 2, 1, 1, 1, 0, 1, 0,
        2, 0, 0, 1, 1, 0, 1, 2, 3, 0, 0, 2, 2, 2, 1, 0, 3, 3, 2, 3, 0, 0,
        3, 1, 3, 1, 2, 1, 1, 2, 2, 0, 3, 0, 1, 3, 3, 0, 0, 0, 2, 2, 2, 1,
        2, 1, 3, 3, 3, 0, 1, 1, 1, 2, 1, 3, 1, 0, 1, 1, 1, 2, 2, 2, 2, 1,
        1, 0, 2, 1, 2, 2, 0, 1, 2, 0, 1, 1, 2, 3, 2, 1, 2, 1, 0, 2, 3, 1,
        1, 3, 2, 2, 3, 3, 2, 0, 2, 0, 1, 2, 2, 3]),
 'score': array([0.7313169 , 0.9992514 , 0.968678  , 0.4581957 , 0.5157    ,
        0.9993032 , 0.9998324 , 0.99973446, 0.45

visualize logs using tensorboard

In [None]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/