In [1]:
import sys
sys.path.append('..')

import torch
import pandas as pd
import numpy as np
from torch import nn 
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import Dataset
from torchcrf import CRF
from tqdm.notebook import tqdm_notebook
import logging

from utils import *
from dataset import *
from preprocess import *
from wrapper import *
from models import *
from pipeline import POSTaggingPipeline, map_to_df

# Use the first GPU when CUDA is actually usable, otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU. (The availability
# check used to be evaluated and then discarded.)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
# Read the tab-separated training file and index it by its `id` column.
df = pd.read_csv('../data/data-org/train.csv', sep='\t').set_index('id')

# Keep only the rows whose label equals 0, then drop the label column, which
# is constant after the filter.
label_is_zero = df['label'] == 0
corpus = df.loc[label_is_zero].drop(columns=['label'])

In [3]:
# Model identifier handed to the project's POS-tagging pipeline.
# NOTE(review): presumably a Hugging Face Hub checkpoint id — confirm in pipeline.py.
model_name = "KoichiYasuoka/chinese-bert-wwm-ext-upos"

# Build the pipeline once and run it over the whole corpus on `device`.
# The returned object `ds` is used by later cells as a dataset wrapper
# (`ds.test`, `ds.train_val_split`, `ds.construct_dataset()`, `ds.dataset`).
tagger = POSTaggingPipeline(model_name=model_name)
ds = tagger(texts=corpus, device=device, return_tags=False)

  indexed_value = torch.tensor(value[index]).squeeze()
100%|██████████| 718/718 [01:15<00:00,  9.45it/s]


In [4]:
# Configure the dataset wrapper before building it; the attributes are set
# first because construct_dataset() is called right after them.
# NOTE(review): presumably 0.8 is the training fraction of the train/val split —
# confirm against the project dataset class.
ds.test = False
ds.train_val_split = 0.8
ds.construct_dataset()

In [5]:
# The tagger's label map determines how many tags the CRF has to model.
id2label = tagger.model.config.id2label
num_tags = len(id2label)

# Move the CRF to the configured `device` itself. `.cuda()` would place it on
# the current default CUDA device and ignore the index stored in `device`;
# `.to(device)` honours it and is a no-op when `device` is the CPU.
crf = CRF(num_tags=num_tags, batch_first=True).to(device)

In [50]:
# Value of the dataset wrapper's `maxlength` attribute.
# NOTE(review): presumably the padded sequence length; `seq_len` is not used
# anywhere else in the visible cells — confirm whether it is still needed.
seq_len = ds.maxlength
# Batch size used for both the training and the validation DataLoader.
batch_size = 128

def get_tagging_datasets(ds):
    """Return the ``(train, val)`` splits of ``ds.dataset`` prepared for the CRF.

    Each split is formatted with ``with_format('pytorch', ...)`` requesting the
    ``emissions`` and ``attention_mask`` columns, and ``attention_mask`` is
    then renamed to ``mask``.
    """
    def _prepare(split_name):
        # Same two-step preparation for every split: format, then rename.
        formatted = ds.dataset[split_name].with_format(
            'pytorch', columns=['emissions', 'attention_mask']
        )
        return formatted.rename_columns({'attention_mask': 'mask'})

    train_split = _prepare('train')
    val_split = _prepare('val')
    return train_split, val_split

def process_batch(batch, device):
    """Build the keyword arguments for one CRF call from a collated batch.

    Args:
        batch: mapping with a ``'mask'`` tensor and an iterable of tensors
            under ``'emissions'``.
            NOTE(review): assumes one ``(batch, num_tags)`` tensor per sequence
            position — confirm against the DataLoader's collate output.
        device: torch device the returned tensors are placed on.

    Returns:
        dict with ``mask`` (cast to bool), ``emissions`` (the input tensors
        stacked along a new dim 1) and ``tags`` (argmax of ``emissions`` over
        the last dim), all on ``device``.
    """
    inputs = {}
    inputs['mask'] = batch['mask'].bool().to(device=device)
    # torch.stack inserts the new dim itself, replacing the previous
    # concat-of-unsqueeze construction and its no-op single-tensor concat.
    inputs['emissions'] = torch.stack(tuple(batch['emissions']), dim=1).to(device=device)
    # The pseudo-labels are the emissions' own argmax; the result already
    # lives on `device`, so no further transfer is needed.
    inputs['tags'] = inputs['emissions'].argmax(-1)
    return inputs
    
train_set, dev_set = get_tagging_datasets(ds)

# Both loaders use the same settings: fixed-size batches, with the final
# incomplete batch dropped.
# NOTE(review): neither loader passes `shuffle`, so the training data is
# iterated in the same order every epoch — confirm this is intended.
_loader_kwargs = {'batch_size': batch_size, 'drop_last': True}
train_dataloader = DataLoader(train_set, **_loader_kwargs)
dev_dataloader = DataLoader(dev_set, **_loader_kwargs)

In [39]:
class EarlyStopping():
    """Flag a stop when validation loss keeps exceeding training loss.

    Each call compares the gap ``validation_loss - train_loss`` with
    ``min_delta``. A gap above ``min_delta`` increments ``counter``; once
    ``counter`` reaches ``tolerance``, ``early_stop`` is set to ``True`` and
    stays set. A call whose gap is within ``min_delta`` resets ``counter``, so
    only *consecutive* offending calls count towards the tolerance.

    Args:
        tolerance: number of consecutive offending calls before stopping.
        min_delta: largest validation/training gap that is still acceptable.
    """

    def __init__(self, tolerance=5, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0          # consecutive calls with gap > min_delta
        self.early_stop = False   # sticky flag read by the training loop

    def __call__(self, train_loss, validation_loss):
        """Update the counter and the stop flag from one pair of losses."""
        if (validation_loss - train_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True
        else:
            # The gap closed again: forget earlier offences so that isolated
            # spikes far apart cannot add up to a stop.
            self.counter = 0

In [51]:
# NOTE(review): this cell re-defines `process_batch`, which is already defined
# in the earlier DataLoader cell; one of the two definitions can be deleted.
def process_batch(batch, device):
    """Build the keyword arguments for one CRF call from a collated batch.

    Args:
        batch: mapping with a ``'mask'`` tensor and an iterable of tensors
            under ``'emissions'``.
            NOTE(review): assumes one ``(batch, num_tags)`` tensor per sequence
            position — confirm against the DataLoader's collate output.
        device: torch device the returned tensors are placed on.

    Returns:
        dict with ``mask`` (cast to bool), ``emissions`` (the input tensors
        stacked along a new dim 1) and ``tags`` (argmax of ``emissions`` over
        the last dim), all on ``device``.
    """
    inputs = {}
    inputs['mask'] = batch['mask'].bool().to(device=device)
    # torch.stack inserts the new dim itself, replacing the previous
    # concat-of-unsqueeze construction and its no-op single-tensor concat.
    inputs['emissions'] = torch.stack(tuple(batch['emissions']), dim=1).to(device=device)
    # The pseudo-labels are the emissions' own argmax; the result already
    # lives on `device`, so no further transfer is needed.
    inputs['tags'] = inputs['emissions'].argmax(-1)
    return inputs

In [74]:
learning_rate = 5e-3
num_epochs = 3
optimiser = torch.optim.AdamW(crf.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimiser, gamma=0.7)

# Per-batch loss records: one 1-element CPU tensor per training / dev batch.
loss_history = []
eval_loss_history = []

# Re-initialise all three CRF parameter tensors before training. The end
# transitions were previously skipped because the start transitions were
# initialised twice.
nn.init.xavier_normal_(crf.transitions)
nn.init.normal_(crf.start_transitions)
nn.init.normal_(crf.end_transitions)

crf.train()

for epoch in range(num_epochs):
    for i, batch in enumerate(tqdm_notebook(train_dataloader)):
        inputs = process_batch(batch, device=device)
        # The CRF call's output is negated so that minimising `loss`
        # maximises the value the CRF returns.
        loss = -crf(**inputs)
        loss_history.append(loss.cpu().detach().unsqueeze(0))

        crf.zero_grad()
        loss.backward()
        optimiser.step()

        print(f'Epoch {epoch+1}/{num_epochs}, batch {i+1} - loss={loss_history[-1][0]:.4f}')

        # Every 10th training batch, run one full pass over the dev set.
        if (i+1) % 10 == 0:
            crf.eval()
            # No gradients are needed for evaluation; skipping graph
            # construction saves memory and time.
            with torch.no_grad():
                for dev_batch in tqdm_notebook(dev_dataloader):
                    dev_inputs = process_batch(dev_batch, device=device)
                    eval_loss = -crf(**dev_inputs)
                    eval_loss_history.append(eval_loss.cpu().detach().unsqueeze(0))
            # The printed value is the loss of the LAST dev batch only, not an
            # average over the dev set.
            print(f'Epoch {epoch+1}/{num_epochs}, batch {i+1} - eval_loss={eval_loss_history[-1][0]:.4f}')
            crf.train()

        # NOTE(review): this fires for i == 0, 50, ..., i.e. right after the
        # first batch of every epoch, while the checks above use (i+1) —
        # confirm the intended decay schedule.
        if i % 50 == 0:
            scheduler.step()
    

  0%|          | 0/73 [00:00<?, ?it/s]

Epoch 1/3, batch 1 - loss=423.0383
Epoch 1/3, batch 2 - loss=384.0992
Epoch 1/3, batch 3 - loss=440.8916
Epoch 1/3, batch 4 - loss=434.1529
Epoch 1/3, batch 5 - loss=411.3040
Epoch 1/3, batch 6 - loss=409.9524
Epoch 1/3, batch 7 - loss=400.6971
Epoch 1/3, batch 8 - loss=429.5026
Epoch 1/3, batch 9 - loss=416.3147
Epoch 1/3, batch 10 - loss=395.2034


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 1/3, batch 10 - eval_loss=401.0936
Epoch 1/3, batch 11 - loss=427.2185
Epoch 1/3, batch 12 - loss=422.9277
Epoch 1/3, batch 13 - loss=403.8193
Epoch 1/3, batch 14 - loss=410.3123
Epoch 1/3, batch 15 - loss=392.2834
Epoch 1/3, batch 16 - loss=412.0963
Epoch 1/3, batch 17 - loss=421.4496
Epoch 1/3, batch 18 - loss=402.5340
Epoch 1/3, batch 19 - loss=398.7207
Epoch 1/3, batch 20 - loss=393.5833


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 1/3, batch 20 - eval_loss=396.7120
Epoch 1/3, batch 21 - loss=437.3124
Epoch 1/3, batch 22 - loss=401.1963
Epoch 1/3, batch 23 - loss=413.6838
Epoch 1/3, batch 24 - loss=404.9303
Epoch 1/3, batch 25 - loss=416.8388
Epoch 1/3, batch 26 - loss=429.9406
Epoch 1/3, batch 27 - loss=433.1796
Epoch 1/3, batch 28 - loss=419.8614
Epoch 1/3, batch 29 - loss=389.1666
Epoch 1/3, batch 30 - loss=443.5643


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 1/3, batch 30 - eval_loss=392.5508
Epoch 1/3, batch 31 - loss=409.0026
Epoch 1/3, batch 32 - loss=430.6183
Epoch 1/3, batch 33 - loss=404.6949
Epoch 1/3, batch 34 - loss=409.4085
Epoch 1/3, batch 35 - loss=420.1406
Epoch 1/3, batch 36 - loss=399.0977
Epoch 1/3, batch 37 - loss=408.4098
Epoch 1/3, batch 38 - loss=407.6553
Epoch 1/3, batch 39 - loss=395.3962
Epoch 1/3, batch 40 - loss=410.2718


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 1/3, batch 40 - eval_loss=388.6565
Epoch 1/3, batch 41 - loss=380.2943
Epoch 1/3, batch 42 - loss=413.4385
Epoch 1/3, batch 43 - loss=415.3864
Epoch 1/3, batch 44 - loss=394.0885
Epoch 1/3, batch 45 - loss=422.6746
Epoch 1/3, batch 46 - loss=429.2598
Epoch 1/3, batch 47 - loss=411.3257
Epoch 1/3, batch 48 - loss=419.4204
Epoch 1/3, batch 49 - loss=400.8913
Epoch 1/3, batch 50 - loss=426.1041


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 1/3, batch 50 - eval_loss=384.9334
Epoch 1/3, batch 51 - loss=385.0130
Epoch 1/3, batch 52 - loss=396.1605
Epoch 1/3, batch 53 - loss=406.7057
Epoch 1/3, batch 54 - loss=407.1459
Epoch 1/3, batch 55 - loss=396.9057
Epoch 1/3, batch 56 - loss=405.9680
Epoch 1/3, batch 57 - loss=409.4943
Epoch 1/3, batch 58 - loss=408.0563
Epoch 1/3, batch 59 - loss=405.7230
Epoch 1/3, batch 60 - loss=399.4765


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 1/3, batch 60 - eval_loss=382.4023
Epoch 1/3, batch 61 - loss=404.7292
Epoch 1/3, batch 62 - loss=399.7075
Epoch 1/3, batch 63 - loss=402.0840
Epoch 1/3, batch 64 - loss=387.5009
Epoch 1/3, batch 65 - loss=425.2590
Epoch 1/3, batch 66 - loss=404.6015
Epoch 1/3, batch 67 - loss=404.3321
Epoch 1/3, batch 68 - loss=386.5385
Epoch 1/3, batch 69 - loss=383.0920
Epoch 1/3, batch 70 - loss=415.2444


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 1/3, batch 70 - eval_loss=380.1236
Epoch 1/3, batch 71 - loss=369.6786
Epoch 1/3, batch 72 - loss=405.6377
Epoch 1/3, batch 73 - loss=404.1538


  0%|          | 0/73 [00:00<?, ?it/s]

Epoch 2/3, batch 1 - loss=401.0311
Epoch 2/3, batch 2 - loss=360.5025
Epoch 2/3, batch 3 - loss=415.7914
Epoch 2/3, batch 4 - loss=410.8350
Epoch 2/3, batch 5 - loss=391.2723
Epoch 2/3, batch 6 - loss=388.6416
Epoch 2/3, batch 7 - loss=377.8962
Epoch 2/3, batch 8 - loss=407.7055
Epoch 2/3, batch 9 - loss=394.2191
Epoch 2/3, batch 10 - loss=373.2824


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 2/3, batch 10 - eval_loss=377.8331
Epoch 2/3, batch 11 - loss=406.3651
Epoch 2/3, batch 12 - loss=401.2399
Epoch 2/3, batch 13 - loss=383.5083
Epoch 2/3, batch 14 - loss=390.1543
Epoch 2/3, batch 15 - loss=369.9137
Epoch 2/3, batch 16 - loss=392.2078
Epoch 2/3, batch 17 - loss=399.8146
Epoch 2/3, batch 18 - loss=383.0024
Epoch 2/3, batch 19 - loss=377.6260
Epoch 2/3, batch 20 - loss=374.1037


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 2/3, batch 20 - eval_loss=376.2682
Epoch 2/3, batch 21 - loss=418.3176
Epoch 2/3, batch 22 - loss=383.4324
Epoch 2/3, batch 23 - loss=394.8574
Epoch 2/3, batch 24 - loss=387.5274
Epoch 2/3, batch 25 - loss=400.4399
Epoch 2/3, batch 26 - loss=411.2983
Epoch 2/3, batch 27 - loss=416.8867
Epoch 2/3, batch 28 - loss=401.6122
Epoch 2/3, batch 29 - loss=371.8058
Epoch 2/3, batch 30 - loss=428.1654


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 2/3, batch 30 - eval_loss=374.7233
Epoch 2/3, batch 31 - loss=393.6469
Epoch 2/3, batch 32 - loss=415.7027
Epoch 2/3, batch 33 - loss=386.7276
Epoch 2/3, batch 34 - loss=394.8110
Epoch 2/3, batch 35 - loss=403.5959
Epoch 2/3, batch 36 - loss=382.1316
Epoch 2/3, batch 37 - loss=393.3418
Epoch 2/3, batch 38 - loss=391.5686
Epoch 2/3, batch 39 - loss=379.9682
Epoch 2/3, batch 40 - loss=395.6619


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 2/3, batch 40 - eval_loss=373.2346
Epoch 2/3, batch 41 - loss=366.3731
Epoch 2/3, batch 42 - loss=397.6897
Epoch 2/3, batch 43 - loss=401.6999
Epoch 2/3, batch 44 - loss=380.8022
Epoch 2/3, batch 45 - loss=407.7068
Epoch 2/3, batch 46 - loss=417.2634
Epoch 2/3, batch 47 - loss=398.4596
Epoch 2/3, batch 48 - loss=407.8481
Epoch 2/3, batch 49 - loss=388.1662
Epoch 2/3, batch 50 - loss=413.5661


  0%|          | 0/17 [00:00<?, ?it/s]

Epoch 2/3, batch 50 - eval_loss=371.7516
Epoch 2/3, batch 51 - loss=373.4078
Epoch 2/3, batch 52 - loss=385.1151
Epoch 2/3, batch 53 - loss=394.3046
Epoch 2/3, batch 54 - loss=395.0264
Epoch 2/3, batch 55 - loss=385.6790
Epoch 2/3, batch 56 - loss=393.6330


In [64]:
import json
import os
from time import time


# Snapshot of the CRF parameters that goes into the checkpoint below.
state_dict = crf.state_dict()

state_dicts_dir = 'crf_runs'
# makedirs(exist_ok=True) replaces the exists()+mkdir() pair and cannot fail
# if the directory appears in between.
os.makedirs(state_dicts_dir, exist_ok=True)
# Name the file after the whole-second Unix timestamp so checkpoints sort
# chronologically. The previous name used only the digits after the decimal
# point of time(), which are effectively arbitrary.
cp_dir = os.path.join(state_dicts_dir, f'{int(time())}.tar')
torch.save({
    'epoch': epoch,
    'model_state_dict': state_dict,
    'optimizer_state_dict': optimiser.state_dict(),
    # Store a detached CPU copy so loading the checkpoint needs neither the
    # autograd graph nor a GPU.
    'loss': loss.detach().cpu(),
}, cp_dir)