In [1]:
import torch
from scripts.transformer.modeling import TinyBertForSequenceClassification
from scripts.transformer.tokenization import BertTokenizer
from scripts.transformer.optimization import BertAdam
from scripts.transformer.file_utils import WEIGHTS_NAME, CONFIG_NAME

from scripts.utils import *

In [10]:
# Load the cased tokenizer and the single-output (num_labels=1) teacher checkpoint.
# NOTE(review): the checkpoint cell further down loads pytorch_model.bin from this
# same directory and reads it through a 'state_dict' key. Confirm from_pretrained
# understands that layout; otherwise the teacher may keep uninitialized weights.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
teacher = TinyBertForSequenceClassification.from_pretrained(
    os.path.join('artifacts', 'BERT-title-content-benchmark:v0'),
    num_labels=1,
)

# Number of visible GPUs (0 on CPU-only machines), used when sizing batches.
n_gpu = torch.cuda.device_count()
# Run on the GPU whenever CUDA is available, else fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
# Initialize the test loader over nela_gt_2018_site_split/test.jsonl.
# Batch size is 8 per visible GPU, with a floor of 8 when no GPU is present.
test_batch_size = max(1, n_gpu) * 8
test_data_loader = create_reliable_news_dataloader(
    os.path.join('data', 'nela_gt_2018_site_split', 'test.jsonl'),
    tokenizer,
    max_len=512,
    batch_size=test_batch_size,
    sample=False,       # passed through as-is; see create_reliable_news_dataloader
    title_only=False,   # passed through as-is; see create_reliable_news_dataloader
)

Max token length: 512 Batch size: 16 Shuffle: False Title only: False


In [17]:
# Evaluate the TinyBERT teacher's accuracy on the test loader.
# NOTE(review): DataParallel wrapping is disabled for the teacher, so it runs on a
# single device even though the loader's batch size scales with n_gpu.
teacher.to(device)
teacher.eval()
correct_predictions = 0
n_examples = 0
with torch.no_grad():
    loop = tqdm(test_data_loader)
    for batch in loop:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        # Add a trailing dim so labels line up with the (batch, 1) logits of a
        # num_labels=1 head (assumes batch["labels"] is 1-D -- TODO confirm).
        labels = batch["labels"].to(device).unsqueeze(1)

        # The teacher returns a 3-tuple; only the logits are needed here.
        logits, _, _ = teacher(input_ids = input_ids,
                               attention_mask = attention_mask,
                               token_type_ids = token_type_ids,
                               labels = labels)
        # Round the single score to the nearest class. Clamping first keeps
        # out-of-range scores (e.g. 1.7 or -0.6) on the nearest valid class instead
        # of rounding them to 2 / -1, which can never match a label.
        # NOTE(review): assumes binary 0/1 labels -- confirm against the dataset.
        preds = torch.round(torch.clamp(logits, 0, 1))

        correct_predictions += (preds == labels).sum().item()
        n_examples += len(labels)

        loop.set_postfix(val_acc = float(correct_predictions/n_examples))

# Final accuracy as the cell output (NaN if the loader yielded no examples).
teacher_val_acc = correct_predictions / n_examples if n_examples else float('nan')
teacher_val_acc


  0%|          | 0/2427 [00:00<?, ?it/s]

In [23]:
# Build the second model via create_model and restore its fine-tuned weights.
# NOTE(review): the positional args 0.1 / False are opaque here -- name them once
# create_model's signature is confirmed. This also rebinds `tokenizer`;
# test_data_loader was already built with the earlier tokenizer.
tokenizer, model = create_model('bert-base-cased', 0.1, False)
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only kernel.
# load_state_dict copies the tensors into the model's parameters on their own device.
# torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(
    os.path.join('artifacts', 'BERT-title-content-benchmark:v0', 'pytorch_model.bin'),
    map_location='cpu',
)
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [24]:
# Wrap in DataParallel (splits batches across visible GPUs; passes straight through
# when there are none), then move the parameters to the chosen device.
model = torch.nn.DataParallel(model)
model = model.to(device)

In [25]:
# Evaluate the restored model's accuracy on the same test loader.
model.eval()
correct_predictions = 0
n_examples = 0
with torch.no_grad():
    loop = tqdm(test_data_loader)
    for batch in loop:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        # Add a trailing dim so labels line up with the single-output logits
        # (assumes batch["labels"] is 1-D -- TODO confirm).
        labels = batch["labels"].to(device).unsqueeze(1)

        outputs = model(input_ids = input_ids,
                        attention_mask = attention_mask,
                        token_type_ids = token_type_ids,
                        labels = labels)
        # Round the single score to the nearest class. Clamping first keeps
        # out-of-range scores (e.g. 1.7 or -0.6) on the nearest valid class instead
        # of rounding them to 2 / -1, which can never match a label.
        # NOTE(review): assumes binary 0/1 labels -- confirm against the dataset.
        preds = torch.round(torch.clamp(outputs['logits'], 0, 1))

        correct_predictions += (preds == labels).sum().item()
        n_examples += len(labels)

        loop.set_postfix(val_acc = float(correct_predictions/n_examples))

# Final accuracy as the cell output (NaN if the loader yielded no examples).
model_val_acc = correct_predictions / n_examples if n_examples else float('nan')
model_val_acc

  0%|          | 0/2427 [00:00<?, ?it/s]

