In [9]:
import warnings
from collections import Counter

import torch

from src.data import load_dataset
from src.model import load_model
from src.utils import load_config


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [2]:
# Build the "classifier" model and its tokenizer from the experiment config,
# asking load_model to also load the weights stored in the named checkpoint.
# NOTE(review): the output below shows load_model's internal
# `torch.load(model_path)` call emitting a warning — presumably torch's
# `weights_only` FutureWarning; TODO confirm inside src.model.
config = load_config("classifier")
model_params = config["model"]["params"]
model, tokenizer = load_model(
    "classifier",
    model_params,
    weights_filename="misty-voice-22.pt",
    load_weights=True,
)

  model.load_state_dict(torch.load(model_path))


In [3]:
# Move the model to DEVICE and switch it to evaluation mode. A freshly built
# nn.Module starts in training mode, so unless load_model already called
# eval() (not visible here) the Dropout layers shown in the repr below
# (p=0.1) would stay active during inference and make predictions random.
# Both .to() and .eval() return the module itself, so the cell still
# displays the model repr.
model.to(DEVICE).eval()

Classifier(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise

In [4]:
# Load the "q_timelines" dataset; it comes back as two splits, unpacked here
# as train and validation. The cells below only use `validset`.
(trainset, validset) = load_dataset("q_timelines")

In [5]:
def classify(example):
    """Predict a class index for every text in a batch of examples.

    Written for ``Dataset.map(..., batched=True)``: ``example["text"]`` holds
    the batch's texts, and the returned ``"preds"`` list (one int per text)
    is added as a new column.

    Relies on the module-level ``tokenizer``, ``model`` and ``DEVICE``.

    Raises:
        ValueError: if the model's last output dimension has size < 2, in
            which case ``argmax`` over it would always return 0.
    """
    inputs = tokenizer(
        example["text"], return_tensors="pt", padding="max_length", truncation=True
    )
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}

    # Pure inference: without this, autograd records a graph for every batch,
    # spending memory and time on activations that are never backpropagated.
    with torch.inference_mode():
        # presumably (batch, num_classes) class scores — TODO confirm head shape
        logits = model(**inputs)

    # argmax over a size-1 class axis (e.g. a single-logit binary head) is
    # always 0, which would silently produce an all-zero "preds" column.
    if logits.shape[-1] < 2:
        raise ValueError(
            "expected at least 2 class scores per example, got output shape "
            f"{tuple(logits.shape)}; argmax over it would always be 0"
        )

    preds = logits.argmax(dim=-1).tolist()
    return {"preds": preds}

In [6]:
# Batched inference over the validation split: `classify` is called on up to
# 128 examples at a time, and the "preds" it returns is read back as a column
# in the next cell.
valid_set = validset.map(
    classify,
    batched=True,
    batch_size=128,
)

Map: 100%|██████████| 14424/14424 [00:55<00:00, 261.96 examples/s]


In [10]:
preds_counter = Counter(valid_set["preds"])

# Flag a degenerate result: if every example receives the same predicted
# class, the predictions carry no information and the model head / inference
# code (e.g. eval mode, argmax axis) should be checked before going further.
if len(preds_counter) == 1:
    warnings.warn(
        f"classifier predicted a single class for all "
        f"{sum(preds_counter.values())} examples: {dict(preds_counter)}"
    )

# Last expression is rendered by Jupyter's rich display (no print needed).
preds_counter

Counter({0: 14424})
