In [1]:
import sys
sys.path.append('..')

import torch
import pandas as pd
import numpy as np
from torch import nn 
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import Dataset
from torchcrf import CRF
from tqdm.notebook import tqdm_notebook
import logging

from utils import *
from dataset import *
from preprocess import *
from wrapper import *
from models import *
from pipeline import POSTaggingPipeline, map_to_df

torch.cuda.is_available()
# device = torch.device('cpu')
device = torch.device('cuda:0')

In [2]:
df = pd.read_csv('../data/data-org/train.csv', sep='\t').set_index('id')
corpus = df[df.label == 0].drop(columns=['label'])

In [3]:
model_name = "KoichiYasuoka/chinese-bert-wwm-ext-upos"

tagger = POSTaggingPipeline(model_name=model_name)
ds = tagger(texts=corpus, device=device, return_tags=False)

  indexed_value = torch.tensor(value[index]).squeeze()
100%|██████████| 718/718 [01:18<00:00,  9.15it/s]


In [4]:
ds.test = False
ds.train_val_split = 0.8
ds.construct_dataset()

In [5]:
id2label = tagger.model.config.id2label
num_tags = len(id2label)

crf = CRF(num_tags=num_tags, batch_first=True)
if 'cuda' in device.type:
    crf.cuda()

In [6]:
seq_len = ds.maxlength
batch_size = 32
tagging_ds = ds.dataset['train'].with_format('pytorch', columns=['emissions', 'attention_mask']).rename_columns({'attention_mask':'mask'})

In [7]:
dataloader = DataLoader(
    tagging_ds, 
    batch_size=batch_size, 
    drop_last=True, 
)

learning_rate=1e-5
optimiser = torch.optim.AdamW(crf.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimiser, gamma=0.7)

loss_history = []

nn.init.uniform_(crf.start_transitions, 0, 1)
nn.init.uniform_(crf.end_transitions, 0, 1)
nn.init.uniform_(crf.transitions, 0, 1)

crf.train()
for epoch in range(3):
    for i, batch in enumerate(tqdm_notebook(dataloader)):
        inputs = {}
        inputs['mask'] = batch['mask'].bool().to(device=device)
        inputs['emissions'] = torch.concat([torch.concat([x]).unsqueeze(1) for x in batch['emissions']], dim=1).to(device=device)
        inputs['tags'] = inputs['emissions'].argmax(-1).to(device=device)
        loss = -crf(**inputs)
        loss_history.append(loss.cpu().detach().unsqueeze(0))
        if (i+1) % 5 == 0:
            print(f'Epoch {epoch+1}/3, batch {i+1} - loss={loss_history[-1][0]:.4f}')

        crf.zero_grad()
        loss.backward()
        optimiser.step()
        if i % 100 == 0:
            scheduler.step()
    

  0%|          | 0/294 [00:00<?, ?it/s]

Epoch 1/3, batch 5 - loss=110.5920
Epoch 1/3, batch 10 - loss=113.0825
Epoch 1/3, batch 15 - loss=99.4346


In [1]:
import matplotlib.pyplot as plt

plt.plot(torch.concat([loss.cpu().detach().unsqueeze(0) for loss in loss_history]))

NameError: name 'torch' is not defined