In [1]:
import torch
from torchtext.datasets import AG_NEWS
import tiktoken
from torch.utils.data import Dataset, DataLoader
import lightning as L
import finalnlp
from finalnlp.replacer import replace_linears_in_pytorch_model
from finalnlp import bitnet1
from finalnlp import bitnet158
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.tuner import Tuner
from finalnlp.sentence_classifier.model import SentenceClassifier
import random

%load_ext autoreload
%autoreload 2

In [2]:
# Base BPE vocabulary to extend; "cl100k_base" is a larger alternative.
encoding_name = "p50k_base"

og_encoder = tiktoken.get_encoding(encoding_name)

# Two extra special tokens are appended directly after the base vocabulary:
# "<CLS>" takes the first free id and "<NULL>" the one after it.
extra_special_tokens = {
    "<CLS>": og_encoder.n_vocab,
    "<NULL>": og_encoder.n_vocab + 1,
}

# The encoding gets its own name because its special-token set differs from
# the base encoding; the name should make the expected behaviour clear.
tokenizer = tiktoken.Encoding(
    name=f"{encoding_name}_with_cls",
    pat_str=og_encoder._pat_str,
    mergeable_ranks=og_encoder._mergeable_ranks,
    special_tokens={**og_encoder._special_tokens, **extra_special_tokens},
)

# Sanity check: display the id assigned to "<CLS>".
tokenizer.encode("<CLS>", allowed_special={"<CLS>", "<NULL>"})

[50281]

In [3]:
# Display names of the AG_NEWS categories. The notebook indexes this list with
# `label - 1`, so the order must follow the dataset's 1-based labels.
ag_news_classes = ["World", "Sports", "Business", "Sci/Tech"]

# Every example is truncated or padded to exactly this many token ids.
max_toks_len = 128

def yield_tokens(data_iter, num_classes, max_len=None, encoder=None):
    """Turn ``(label, text)`` pairs into fixed-length token tensors.

    Each text is encoded, prefixed with the ``<CLS>`` token, then truncated or
    right-padded with ``<NULL>`` tokens so that it is exactly ``max_len`` ids
    long.

    Args:
        data_iter: Iterable of ``(label, text)`` pairs; ``label`` is shifted
            down by one before being returned.
        num_classes: Unused; kept so existing callers keep working.
        max_len: Target sequence length. Defaults to the notebook-level
            ``max_toks_len``.
        encoder: Object with an ``encode(text, allowed_special=...)`` method.
            Defaults to the notebook-level ``tokenizer``.

    Yields:
        ``(tokens, label)`` where ``tokens`` is a 1-D int64 tensor of length
        ``max_len`` and ``label`` is a 0-dim tensor holding ``label - 1``.
    """
    # Fall back to the notebook-level settings when no override is given.
    if max_len is None:
        max_len = max_toks_len
    if encoder is None:
        encoder = tokenizer

    special = {"<CLS>", "<NULL>"}
    cls_tok = torch.tensor(
        encoder.encode("<CLS>", allowed_special=special), dtype=torch.long
    )
    null_tok = torch.tensor(
        encoder.encode("<NULL>", allowed_special=special), dtype=torch.long
    )

    for label, text in data_iter:
        # Explicit dtype: an empty text encodes to [], which torch.tensor would
        # otherwise create as float32 and drag the concatenated ids with it.
        text_toks = torch.tensor(encoder.encode(text), dtype=torch.long)
        tok_list = torch.cat([cls_tok, text_toks])[:max_len]

        # Pad in one shot instead of concatenating one tensor per position.
        pad = max_len - len(tok_list)
        if pad > 0:
            tok_list = torch.cat([tok_list, null_tok.repeat(pad)])

        assert len(tok_list) == max_len
        yield (
            tok_list,
            torch.tensor(label - 1),
        )


# Fresh iterators over the two AG_NEWS splits; the "test" split doubles as
# the validation set for this notebook.
train_iter = iter(AG_NEWS(split="train"))
val_iter = iter(AG_NEWS(split="test"))

# Tokenize everything up front so the datasets are plain in-memory lists of
# (tokens, label) pairs.
num_classes = len(ag_news_classes)
train_dataset = list(yield_tokens(train_iter, num_classes))
test_dataset = [sample for sample in yield_tokens(val_iter, num_classes)]

# Seed first so the shuffles are reproducible, then shuffle both splits in
# place (train first, then test).
L.seed_everything(42, workers=True)
for split in (train_dataset, test_dataset):
    random.shuffle(split)

# Show the first training example as a quick sanity check.
train_dataset[0]

Seed set to 42


(tensor([50281,   817,   667,    11, 16132,   290,   347, 22090,   319, 18692,
         14473,   422,  5478,   220, 13077, 19266,    57,    11, 23731,   357,
         12637,     8,   532, 23888,  3790,   287,   428, 36972,   220, 22982,
         19173,  3240,  2900,   257,  7770,  4151,   355,  1865,  1194,  1126,
           868,  7779,   220, 33654,   351, 13020,    11, 46738, 17626,   290,
          9264,   286, 12783,   220,  9087,   329,  2031,   374, 11137,   572,
           656,   262, 10326,    13, 50282, 50282, 50282, 50282, 50282, 50282,
         50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282,
         50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282,
         50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282,
         50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282,
         50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282, 50282,
         50282, 50282, 50282, 50282, 50282, 50282, 5

In [4]:
# Decode the first training example back to text and show its class name.
first_tokens, first_label = train_dataset[0]
print(tokenizer.decode(list(first_tokens)))
print(ag_news_classes[first_label])


<CLS>Thirst, Fear and Bribes on Desert Escape from Africa  AGADEZ, Niger (Reuters) - Customs officers in this dusty  Saharan town turned a blind eye as yet another creaking truck  piled with grain, smuggled cigarettes and dozens of migrants  heading for Europe rumbled off into the desert.<NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL>
World


In [5]:
# Model hyper-parameters handed to SentenceClassifier.
d_model = 128
nhead = 8
num_layers = 2
d_ffl = 512

# Training schedule handed to the Lightning Trainer.
max_steps = 100_000
val_check_interval = 5_000

# Baseline model with the default linear layers. To try a replacement layer,
# add the following keyword below:
#     linear_replacer=bitnet1.BitLinear1B
model_kwargs = dict(
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    d_ffl=d_ffl,
)
model = SentenceClassifier(
    tokenizer.n_vocab,
    len(ag_news_classes),
    **model_kwargs,
)

# Record the run configuration on Weights & Biases.
wandb_logger = pl_loggers.WandbLogger("Classify-Plain")
run_config = wandb_logger.experiment.config
run_config.update({"d_model": d_model, "nhead": nhead, "num_layers": num_layers})
run_config.update({"problem": "classify", "linear_replacer": "None"})

# Optional early stopping, currently disabled:
#     callbacks=[EarlyStopping(monitor="train_loss", mode="min")]
trainer = L.Trainer(
    logger=wandb_logger,
    max_steps=max_steps,
    val_check_interval=val_check_interval,
)

# Optional batch-size search, currently disabled:
#     tuner = Tuner(trainer)
#     tuner.scale_batch_size(model, mode="power")

wandb_logger.watch(model)
torch.set_float32_matmul_precision('medium')

# NOTE(review): no batch_size is passed, so DataLoader's default applies --
# confirm that this is intended for the chosen max_steps.
train_loader = DataLoader(train_dataset, num_workers=15)
val_loader = DataLoader(test_dataset, num_workers=15)

trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcandrewlee14[0m ([33mandrews-org[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


RuntimeError: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero.

In [None]:
# Spot-check the classifier on the first few test examples, printing the text,
# the predicted and expected class, and whether they match.
num_samples = min(10, len(test_dataset))
count = 0

# Run inference in eval mode and without autograd so that train-only layers
# such as dropout cannot perturb the predictions and no graph is built. The
# previous mode is restored afterwards so later training is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for i in range(num_samples):
            tokens, label = test_dataset[i]
            # Add a batch dimension and move the ids to the model's device.
            logits = model(tokens.unsqueeze(0).to(model.device))
            guess = torch.argmax(logits, dim=-1).item()
            expected = label.item()

            print(tokenizer.decode(tokens.tolist()))
            print("Guess:", ag_news_classes[guess])
            print("Exp:", ag_news_classes[expected])
            is_match = guess == expected
            print(is_match)
            count += int(is_match)
finally:
    model.train(was_training)

print(f"Accuracy: {count}/{num_samples}")

<CLS>Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.<NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL><NULL>
Guess: Sports
Exp: Business
False
<CLS>The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\privately funded suborbita