In [1]:
import os

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
N = 2048  # tokens per sample window; also the amount of padding added on each side


class BertDataset(Dataset):
    """Fixed-length sliding windows over a token-id file and its target file.

    Both files are read as whitespace-separated integers. The sequences are
    padded with ``N`` entries on each side (token id 0, target -1). Every
    position whose target is ``>= 1`` is then removed from both sequences and
    its target value is written onto the position just before it
    (presumably these positions are punctuation tokens -- confirm against the
    preprocessing that produced the files).

    Args:
        path: text file with the token ids.
        path_targets: text file with the per-token targets.
        is_train: stored on the instance; not otherwise used by this class.
        pred_len: distance, in tokens, between the starts of two consecutive
            windows. Defaults to ``N`` (non-overlapping windows).

    Attributes:
        encoded_texts: 1-D int64 array of token ids after padding/removal.
        targets: 1-D int64 array of targets aligned with ``encoded_texts``.
    """

    def __init__(self, path, path_targets, is_train=False, pred_len=N):
        self.is_train = is_train
        self.pred_len = pred_len

        token_pad = np.zeros(N, dtype=np.int64)
        target_pad = np.full(N, -1, dtype=np.int64)
        tokens = np.concatenate(
            [token_pad, np.asarray(self._read_ints(path), dtype=np.int64), token_pad]
        )
        targets = np.concatenate(
            [target_pad, np.asarray(self._read_ints(path_targets), dtype=np.int64), target_pad]
        )

        # Only positions present in both sequences are inspected, matching the
        # truncating behaviour of zip() in the element-wise formulation.
        limit = min(len(tokens), len(targets))
        drop = np.flatnonzero(targets[:limit] >= 1)

        # Move each removed position's label onto its predecessor. The right
        # hand side is materialised before assignment, so runs of consecutive
        # removed positions behave exactly like a left-to-right loop.
        targets[drop - 1] = targets[drop]

        self.encoded_texts = np.delete(tokens, drop)
        self.targets = np.delete(targets, drop)

    @staticmethod
    def _read_ints(path):
        """Return every whitespace-separated integer in ``path`` as a flat list."""
        values = []
        with open(path, 'r') as handle:
            for line in handle:  # stream line by line instead of readlines()
                values.extend(map(int, line.split()))
        return values

    def __getitem__(self, idx):
        """Return the ``idx``-th window as ``(token_ids, targets)`` LongTensors.

        The window starts at ``idx * pred_len`` (clamped to 0) and spans ``N``
        positions.
        """
        start_idx = max(0, idx * self.pred_len)
        end_idx = start_idx + N
        return torch.LongTensor(self.encoded_texts[start_idx: end_idx]),\
               torch.LongTensor(self.targets[start_idx: end_idx])

    def __len__(self):
        """Number of windows exposed by the dataset (never negative).

        Uses ``N`` rather than a literal so it stays consistent with
        ``__getitem__``. The trailing ``- 1`` is kept from the original
        formula; the result is clamped at 0 because ``len()`` raises
        ``ValueError`` for negative values.
        """
        return max(0, (len(self.encoded_texts) - N) // self.pred_len - 1)


def collate(batch):
    """Turn a list of ``(text, target)`` tensor pairs into two stacked tensors.

    Returns ``(texts, targets)``, each with a new leading batch dimension.
    Note that a later notebook cell defines another function with this name.
    """
    columns = tuple(zip(*batch))
    token_ids, labels = columns
    stacked_tokens = torch.stack(token_ids)
    stacked_labels = torch.stack(labels)
    return stacked_tokens, stacked_labels

def get_datasets(pred_len):
    """Build the training and validation datasets from the processed text files.

    Args:
        pred_len: window shift passed to the validation dataset only; the
            training dataset is built with ``BertDataset``'s default.

    Returns:
        ``(train_dataset, valid_dataset)``.
    """
    train_dataset = BertDataset(
        'processed_train_words.txt',
        'processed_train_targets.txt',
        is_train=True,
    )
    valid_dataset = BertDataset(
        'processed_val_words.txt',
        'processed_val_targets.txt',
        pred_len=pred_len,
    )
    return train_dataset, valid_dataset


def get_data_loaders(train_dataset, valid_dataset):
    """Wrap both datasets in ``DataLoader`` objects that batch with ``collate``.

    The training loader shuffles and yields batches of 2; the validation
    loader keeps dataset order and yields batches of 4.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    train_options = dict(batch_size=2, num_workers=0, collate_fn=collate, shuffle=True)
    valid_options = dict(batch_size=4, collate_fn=collate)

    train_loader = DataLoader(train_dataset, **train_options)
    valid_loader = DataLoader(valid_dataset, **valid_options)
    return train_loader, valid_loader

In [28]:
import os
from glob import glob
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from bert_punc import BertPunctuator, get_eval_metrics

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.optim import AdamW
from torch import nn

import numpy as np
import pickle

from torch.utils.data import Dataset, DataLoader
from itertools import product

def collate(batch):
    """Batch ``(text, target)`` tensor pairs without ever dropping a sample.

    When every sample has the same shape the tensors are simply stacked.
    When lengths differ, the shorter samples are right-padded to the longest
    one in the batch: token ids with 0 and targets with -1, the same padding
    values ``BertDataset`` uses. Previously any stacking error was swallowed
    and only the first sample of the batch was returned, silently discarding
    the rest and shifting predictions relative to their targets.

    Args:
        batch: sequence of ``(text_tensor, target_tensor)`` pairs.

    Returns:
        ``(texts, targets)`` with a leading batch dimension.
    """
    texts, targets = zip(*batch)

    texts_uniform = len({tuple(t.shape) for t in texts}) == 1
    targets_uniform = len({tuple(t.shape) for t in targets}) == 1
    if texts_uniform and targets_uniform:
        return torch.stack(texts), torch.stack(targets)

    # Ragged batch: keep every sample and pad to a common length.
    from torch.nn.utils.rnn import pad_sequence
    padded_texts = pad_sequence(list(texts), batch_first=True, padding_value=0)
    padded_targets = pad_sequence(list(targets), batch_first=True, padding_value=-1)
    return padded_texts, padded_targets

def combine(pred_num, preds):
    """Merge overlapping sliding-window predictions into one per-position array.

    Window ``k`` is taken to start ``k * pred_num`` positions into the
    sequence, so the ``pred_num``-long segment number ``i`` is covered by
    window ``k`` at offset ``(i - k) * pred_num``. For each segment, the
    slices from all covering windows are gathered and averaged in the
    probability domain (``log(mean(exp(.)))``). When more than two windows
    cover a segment, the oldest and newest ones -- where the segment sits at
    the very end / very start of the window -- are left out of the average.

    The window length is read from ``preds.shape[1]`` instead of being fixed
    to 2048, so the function also works for other window sizes.

    Args:
        pred_num: despite its name, the shift in positions between two
            consecutive windows (callers pass ``window_len // n_predictions``).
        preds: array shaped ``(n_windows, window_len, ...)`` holding log-domain
            scores.

    Returns:
        Array with ``n_windows * pred_num`` rows and the trailing dimensions
        of ``preds``.

    Raises:
        ValueError: from ``np.concatenate`` when ``preds`` contains no windows.
    """
    n_windows, window_len = preds.shape[0], preds.shape[1]
    windows_per_segment = window_len // pred_num

    merged = []
    for seg in range(n_windows):
        oldest = max(0, seg - windows_per_segment + 1)
        # Ordered from the oldest covering window to the newest (window `seg`).
        slices = [
            preds[win][(seg - win) * pred_num:(seg - win + 1) * pred_num]
            for win in range(oldest, seg + 1)
        ]
        stacked = np.stack(slices)
        if stacked.shape[0] > 2:
            stacked = stacked[1:-1]  # discard the two window-edge predictions

        merged.append(np.log(np.exp(stacked).mean(0)))
    return np.concatenate(merged)

# Prefer the second GPU (the original setting); fall back to the first GPU or
# to the CPU so the notebook still runs on machines with fewer devices.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
    torch.cuda.set_device(device)
else:
    device = torch.device('cpu')


In [None]:
# Instantiate the punctuation model, move it to the selected device, then
# restore the saved weights into it.
model = BertPunctuator()
model.to(device)
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code. The file name suggests a downloaded artifact -- only load it from a
# trusted source, or pass weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load('model_gdown.pt', map_location=device))

In [52]:
def make_multi_preds(N_PREDICTIONS_FOR_TOKEN, model):
    """Evaluate ``model`` on the validation set using overlapping windows.

    The validation dataset is built with a window shift of
    ``2048 // N_PREDICTIONS_FOR_TOKEN``. Every window is run through the model
    on the global ``device``, the per-window outputs are merged with
    ``combine``, and positions whose target is -1 are removed before scoring.

    Args:
        N_PREDICTIONS_FOR_TOKEN: how many windows should cover each position.
        model: callable returning a pair whose first element is the
            per-position prediction tensor; it is switched to eval mode.

    Returns:
        ``(metrics, predictions)``: the result of ``get_eval_metrics`` and the
        merged predictions for the positions that were kept.
    """
    window_shift = 2048 // N_PREDICTIONS_FOR_TOKEN

    # get_datasets also builds the training set; only the validation half is
    # consumed below.
    train_dataset, valid_dataset = get_datasets(pred_len=window_shift)
    _, valid_loader = get_data_loaders(train_dataset, valid_dataset)

    model.eval()
    window_outputs = []
    for text, _ in tqdm(valid_loader):
        with torch.no_grad():
            preds, _ = model(text.to(device))
        window_outputs.append(preds.detach().cpu().numpy())

    stacked_preds = np.concatenate(window_outputs)
    print(stacked_preds.shape)

    merged = combine(window_shift, stacked_preds)
    # Targets are trimmed to the number of merged prediction rows.
    labels = np.array(valid_dataset.targets[:merged.shape[0]])

    # -1 marks positions to ignore (the padding value used for targets).
    keep = labels != -1
    merged = merged[keep]
    labels = labels[keep]

    return get_eval_metrics(labels, merged), merged


In [None]:
1.00           0.92           0.96
0.85           0.96           0.90
0.42           0.80           0.55
0.51           0.97           0.67
0.86           0.92           0.89
0.35           0.52           0.42

In [54]:
# Baseline: one prediction per token (window shift 2048 // 1 = 2048, so windows do not overlap).
rev, ps1= make_multi_preds(1, model)

100%|███████████████████████████████████████████| 41/41 [00:56<00:00,  1.37s/it]


(161, 2048, 6)
              precision    recall  f1-score   support

           0       1.00      0.92      0.96    253499
           1       0.85      0.96      0.90     13355
           2       0.42      0.80      0.55        35
           3       0.51      0.97      0.67     20247
           4       0.86      0.92      0.89       577
           5       0.35      0.52      0.42       102

    accuracy                           0.93    287815
   macro avg       0.67      0.85      0.73    287815
weighted avg       0.96      0.93      0.94    287815



In [50]:
# Two predictions per token (window shift 2048 // 2 = 1024).
rev, ps2 = make_multi_preds(2, model)

100%|███████████████████████████████████████████| 81/81 [01:52<00:00,  1.39s/it]


(321, 2048, 6)
              precision    recall  f1-score   support

           0       1.00      0.92      0.96    253499
           1       0.85      0.96      0.90     13355
           2       0.43      0.80      0.56        35
           3       0.52      0.97      0.67     20247
           4       0.86      0.92      0.89       577
           5       0.34      0.51      0.41       102

    accuracy                           0.93    287815
   macro avg       0.67      0.85      0.73    287815
weighted avg       0.96      0.93      0.94    287815



{'cls_report': {'0': {'precision': 0.9987220856929159,
   'recall': 0.9218024528696366,
   'f1-score': 0.9587219010654927,
   'support': 253499},
  '1': {'precision': 0.8534007047403763,
   'recall': 0.9611381505054287,
   'f1-score': 0.9040709959149176,
   'support': 13355},
  '2': {'precision': 0.4307692307692308,
   'recall': 0.8,
   'f1-score': 0.56,
   'support': 35},
  '3': {'precision': 0.5156710914454278,
   'recall': 0.9670074578949968,
   'f1-score': 0.6726445074293567,
   'support': 20247},
  '4': {'precision': 0.8617886178861789,
   'recall': 0.9185441941074524,
   'f1-score': 0.889261744966443,
   'support': 577},
  '5': {'precision': 0.3443708609271523,
   'recall': 0.5098039215686274,
   'f1-score': 0.4110671936758893,
   'support': 102},
  'accuracy': 0.9266403766308219,
  'macro avg': {'precision': 0.6674537652435469,
   'recall': 0.8463826961576904,
   'f1-score': 0.7326277238420166,
   'support': 287815},
  'weighted avg': {'precision': 0.9574221766926102,
   'recall

In [None]:
1.00           0.92           0.96
0.85           0.96           0.90
0.43           0.80           0.56
0.52           0.97           0.67
0.86           0.92           0.89
0.34           0.51           0.41

In [51]:
# Four predictions per token (window shift 2048 // 4 = 512).
rev, ps4 = make_multi_preds(4, model)

100%|█████████████████████████████████████████| 161/161 [03:46<00:00,  1.41s/it]


(642, 2048, 6)
              precision    recall  f1-score   support

           0       1.00      0.92      0.96    253499
           1       0.85      0.96      0.90     13355
           2       0.46      0.80      0.58        35
           3       0.52      0.97      0.67     20247
           4       0.86      0.92      0.89       577
           5       0.35      0.51      0.41       102

    accuracy                           0.93    287815
   macro avg       0.67      0.85      0.74    287815
weighted avg       0.96      0.93      0.94    287815



ValueError: too many values to unpack (expected 2)

In [53]:
# Eight predictions per token (window shift 2048 // 8 = 256).
rev, ps8 = make_multi_preds(8, model)

100%|█████████████████████████████████████████| 321/321 [07:28<00:00,  1.40s/it]


(1283, 2048, 6)
              precision    recall  f1-score   support

           0       1.00      0.92      0.96    253499
           1       0.85      0.96      0.90     13355
           2       0.47      0.80      0.59        35
           3       0.52      0.97      0.67     20247
           4       0.86      0.92      0.89       577
           5       0.34      0.51      0.41       102

    accuracy                           0.93    287815
   macro avg       0.67      0.85      0.74    287815
weighted avg       0.96      0.93      0.94    287815



In [58]:
# Sixteen predictions per token (window shift 2048 // 16 = 128).
rev, ps16 = make_multi_preds(16, model)

100%|█████████████████████████████████████████| 642/642 [14:50<00:00,  1.39s/it]


(2566, 2048, 6)
              precision    recall  f1-score   support

           0       1.00      0.92      0.96    253499
           1       0.85      0.96      0.90     13355
           2       0.46      0.80      0.58        35
           3       0.52      0.97      0.67     20247
           4       0.86      0.92      0.89       577
           5       0.34      0.51      0.41       102

    accuracy                           0.93    287815
   macro avg       0.67      0.85      0.74    287815
weighted avg       0.96      0.93      0.94    287815



In [59]:
# Thirty-two predictions per token (window shift 2048 // 32 = 64).
rev, ps32 = make_multi_preds(32, model)

100%|███████████████████████████████████████| 1283/1283 [29:42<00:00,  1.39s/it]


(5132, 2048, 6)
              precision    recall  f1-score   support

           0       1.00      0.92      0.96    253499
           1       0.85      0.96      0.90     13355
           2       0.46      0.80      0.58        35
           3       0.52      0.97      0.67     20247
           4       0.86      0.92      0.89       577
           5       0.34      0.51      0.41       102

    accuracy                           0.93    287815
   macro avg       0.67      0.85      0.74    287815
weighted avg       0.96      0.93      0.94    287815



In [63]:
# Sixty-four predictions per token (window shift 2048 // 64 = 32).
rev, ps64 = make_multi_preds(64, model)

100%|█████████████████████████████████████| 2566/2566 [1:00:19<00:00,  1.41s/it]


(10264, 2048, 6)
              precision    recall  f1-score   support

           0       1.00      0.92      0.96    253499
           1       0.85      0.96      0.90     13355
           2       0.46      0.80      0.58        35
           3       0.52      0.97      0.67     20247
           4       0.86      0.92      0.89       577
           5       0.34      0.51      0.41       102

    accuracy                           0.93    287815
   macro avg       0.67      0.85      0.74    287815
weighted avg       0.96      0.93      0.94    287815



In [64]:
# Display the metrics returned by whichever make_multi_preds call last assigned `rev`.
rev

{'cls_report': {'0': {'precision': 0.998709280359692,
   'recall': 0.9218024528696366,
   'f1-score': 0.9587160009600453,
   'support': 253499},
  '1': {'precision': 0.853316496078692,
   'recall': 0.9613627854736054,
   'f1-score': 0.9041230942572445,
   'support': 13355},
  '2': {'precision': 0.45901639344262296,
   'recall': 0.8,
   'f1-score': 0.5833333333333334,
   'support': 35},
  '3': {'precision': 0.5156608097784569,
   'recall': 0.9668098977626315,
   'f1-score': 0.672587960417812,
   'support': 20247},
  '4': {'precision': 0.8603896103896104,
   'recall': 0.9185441941074524,
   'f1-score': 0.8885163453478626,
   'support': 577},
  '5': {'precision': 0.33986928104575165,
   'recall': 0.5098039215686274,
   'f1-score': 0.407843137254902,
   'support': 102},
  'accuracy': 0.9266369021767454,
  'macro avg': {'precision': 0.6711603118491376,
   'recall': 0.8463872086303255,
   'f1-score': 0.7358533119285333,
   'support': 287815},
  'weighted avg': {'precision': 0.957405302462515

In [61]:
# Compare the merged scores of the first kept position across three overlap settings.
ps16[0], ps32[0], ps1[0]

(array([-0.06916147, -7.937645  , -9.8527    , -2.714312  , -9.3552475 ,
        -9.466344  ], dtype=float32),
 array([-0.11192638, -5.9801354 , -7.689894  , -2.278664  , -8.919961  ,
        -7.959797  ], dtype=float32),
 array([-7.3791533e-03, -9.3142328e+00, -1.0746041e+01, -4.9413085e+00,
        -9.5344534e+00, -1.0690517e+01], dtype=float32))