In [0]:
# Shell escape: print the NVIDIA GPU/driver status of this runtime.
!nvidia-smi

In [0]:
# Mount Google Drive at /content/drive; the corpus, cache and checkpoint
# paths configured further down all point below this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
pip install transformers

In [0]:
import os
import pickle
import random

import numpy as np
from sklearn.metrics import classification_report 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader, Sampler
from torch.nn.utils.rnn import pad_sequence
import transformers
from transformers import BertTokenizer, BertForTokenClassification


# One seed for every RNG the notebook touches: Python's `random` drives the
# data shuffles, torch drives weight initialisation and dropout.
SEED = 17
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Locations on the mounted Drive: raw corpus, tokenised-corpus cache, and the
# directory that receives checkpoints, `last_i` and `summary.txt`.
train_data_dir = '/content/drive/My Drive/Colab Notebooks/ezafe/data/bijankhan_corpus.tsv'
pickled_data_dir = '/content/drive/My Drive/Colab Notebooks/ezafe/data/bert_data.pickle'
model_dir = '/content/drive/My Drive/Colab Notebooks/ezafe/model/'

NUM_EPOCHS = 100
BATCH_SIZE = 8
LEARNING_RATE = 2e-5
DROPOUT_RATE = .1
# Label id given to positions that must not contribute to the loss or the
# metrics (special tokens, non-initial word pieces, padding).
TAG_PAD_OR_MASK_TOKEN = 2

In [0]:
def data_reader(directory):
    """Read the tab-separated corpus and turn it into token-id / tag sequences.

    Each non-blank line must hold three tab-separated fields: word, POS tag
    and ezafe tag. A line consisting of a single newline ends the current
    sentence. Words are normalised (Arabic yeh/kaf/teh-marbuta replaced by
    their Persian forms) and encoded with the module-level ``tokenizer``.

    Every sentence is wrapped in the ids 101 and 102 (presumably the
    [CLS]/[SEP] ids of the BERT vocabulary in use -- confirm if the tokenizer
    changes). Only the first word piece of a word carries the word's ezafe
    tag; the wrapper ids and all further word pieces get
    ``TAG_PAD_OR_MASK_TOKEN``.

    Args:
        directory: path of the corpus file.

    Returns:
        ``(sents, all_ezafe_tags)``: two parallel lists with one list of
        token ids and one equally long list of tag ids per sentence.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields or
            its ezafe tag is not an integer.
    """
    sents, sent = [], [101]
    all_ezafe_tags, ezafe_tags = [], [TAG_PAD_OR_MASK_TOKEN]
    # The corpus is Persian text: decode explicitly instead of relying on the
    # platform's default encoding.
    with open(directory, encoding='utf-8') as bijankhan_corpus:
        for line in bijankhan_corpus:
            if line != '\n':
                word, _pos_tag, ezafe_tag = line.strip().split('\t')

                word = word.replace('ي', 'ی').replace('ك', 'ک').replace('ة', 'ه')
                word_pieces = tokenizer.encode(word, add_special_tokens=False)
                if not word_pieces:
                    # The tokenizer dropped the whole word; appending its tag
                    # anyway would shift every later tag by one position.
                    continue
                sent.extend(word_pieces)

                ezafe_tags.append(int(ezafe_tag))
                ezafe_tags.extend([TAG_PAD_OR_MASK_TOKEN] * (len(word_pieces) - 1))
            else:
                sents.append(sent + [102])
                all_ezafe_tags.append(ezafe_tags + [TAG_PAD_OR_MASK_TOKEN])

                sent = [101]
                ezafe_tags = [TAG_PAD_OR_MASK_TOKEN]

    # A file that does not end with a blank line still has one open sentence.
    if len(sent) > 1:
        sents.append(sent + [102])
        all_ezafe_tags.append(ezafe_tags + [TAG_PAD_OR_MASK_TOKEN])

    return sents, all_ezafe_tags


class MySampler(Sampler):
    """Sampler yielding a fixed, shuffled index order that can skip a prefix.

    The order is drawn once, at construction, from Python's global ``random``
    state, so seeding ``random`` identically before construction reproduces
    the same order. The first ``i * batch_size`` indices of that order are
    dropped, which lets an interrupted run continue after ``i`` batches.

    Args:
        data: any sized collection; only ``len(data)`` is used, the
            collection itself is never modified.
        i: number of batches to skip (negative values are treated as 0).
        batch_size: samples per batch; defaults to the module-level
            ``BATCH_SIZE``.
    """

    def __init__(self, data, i=0, batch_size=None):
        if batch_size is None:
            batch_size = BATCH_SIZE
        # Shuffle indices rather than `data`: random.shuffle needs item
        # assignment, which a torch Dataset does not provide.
        order = list(range(len(data)))
        random.shuffle(order)
        self.seq = order[max(i, 0) * batch_size:]

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)


class EzafeDataset(Dataset):
    """Map-style dataset over ``(token_ids, tag_ids)`` pairs.

    ``data`` is stored as-is; item ``idx`` is converted to tensors on access.
    """

    def __init__(self, data):
        self.samples = data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(token_ids, mask, tag_ids)`` tensors for sample ``idx``.

        The mask is the element-wise sign of the token ids, i.e. 1 for every
        positive id and 0 for id 0.
        """
        pair = self.samples[idx]
        token_ids = torch.tensor(pair[0])
        tag_ids = torch.tensor(pair[1])
        return token_ids, token_ids.sign(), tag_ids


def collate_fn(batch, tag_pad_value=None):
    """Pad a list of ``(x, mask, y)`` samples into batch tensors.

    Token ids and masks are right-padded with 0; tags are right-padded with
    ``tag_pad_value`` so that padded positions carry the ignore label.

    Args:
        batch: iterable of ``(x, mask, y)`` tuples of 1-D tensors.
        tag_pad_value: label used for padded tag positions; defaults to the
            module-level ``TAG_PAD_OR_MASK_TOKEN``.

    Returns:
        ``({'input_ids': ..., 'attention_masks': ...}, padded_tags)`` where
        every tensor has the batch dimension first.
    """
    if tag_pad_value is None:
        tag_pad_value = TAG_PAD_OR_MASK_TOKEN

    xs, masks, ys = zip(*batch)
    padded_xs = pad_sequence(xs, batch_first=True)
    padded_masks = pad_sequence(masks, batch_first=True)
    padded_ys = pad_sequence(ys, batch_first=True, padding_value=tag_pad_value)

    return {'input_ids': padded_xs, 'attention_masks': padded_masks}, padded_ys

In [0]:
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# `last_i` is the step number written at the most recent checkpoint; a missing
# file means training has not produced a checkpoint yet.
try:
    with open(model_dir + 'last_i') as in_file:
        last_i = int(in_file.read().strip())
except FileNotFoundError:
    last_i = 0

print(last_i)

# Checkpoints are saved to `model_dir + str(step)` by the training loop, so
# look for exactly that directory when resuming.
checkpoint_dir = model_dir + str(last_i)
if os.path.isdir(checkpoint_dir):
    model = BertForTokenClassification.from_pretrained(checkpoint_dir)
    print('model is loaded:', last_i)
else:
    if last_i:
        # Make the fallback visible instead of silently restarting from the
        # base weights while the step counter keeps its old value.
        print('WARNING: checkpoint not found, starting from base model:', checkpoint_dir)
    model = BertForTokenClassification.from_pretrained('bert-base-multilingual-cased')

# Use the cached, already tokenised corpus when it exists; otherwise build it
# from the raw corpus file and write the cache for later sessions.
# NOTE(review): pickle.load runs arbitrary code from the file -- only point
# `pickled_data_dir` at a cache this notebook produced itself.
try:
    with open(pickled_data_dir, 'rb') as cache_in:
        cached_corpus = pickle.load(cache_in)
    sents, ezafe_tags = cached_corpus
except FileNotFoundError:
    fresh_corpus = data_reader(train_data_dir)
    sents, ezafe_tags = fresh_corpus
    with open(pickled_data_dir, 'wb') as cache_out:
        pickle.dump([sents, ezafe_tags], cache_out)

# Deterministic shuffle: reseed, then permute the sentence indices so that
# sentences and their tags stay aligned.
index_shuf = list(range(len(sents)))
random.seed(17)
random.shuffle(index_shuf)

sents_shuf = [sents[pos] for pos in index_shuf]
ezafe_tags_shuf = [ezafe_tags[pos] for pos in index_shuf]

# Keep only sentences of at most 512 token ids.
kept_pairs = [
    (sent, ezafe)
    for sent, ezafe in zip(sents_shuf, ezafe_tags_shuf)
    if len(sent) <= 512
]
sents_shuf, ezafe_tags_shuf = zip(*kept_pairs)

# First 10% -> test, next 10% -> validation, remaining 80% -> train.
data_split_1 = int(len(sents_shuf) * .1)
data_split_2 = int(len(sents_shuf) * .2)

test_data = kept_pairs[:data_split_1]
valid_data = kept_pairs[data_split_1:data_split_2]
train_data = kept_pairs[data_split_2:]

train_dataset = EzafeDataset(train_data)
train_sampler = MySampler(train_dataset, last_i)

valid_dataset = EzafeDataset(valid_data)

print(len(train_dataset))
print(len(valid_dataset))

In [0]:
# Training batches follow the order fixed by `train_sampler`; validation
# batches are read in dataset order.
train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=BATCH_SIZE,
                               collate_fn=collate_fn,
                               sampler=train_sampler,
                               shuffle=False,
                               num_workers=4)

valid_data_loader = DataLoader(dataset=valid_dataset,
                               batch_size=BATCH_SIZE,
                               collate_fn=collate_fn,
                               shuffle=False,
                               num_workers=4)

# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends.
model = model.to(device)

# Positions labelled TAG_PAD_OR_MASK_TOKEN are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TAG_PAD_OR_MASK_TOKEN).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    # from_pretrained() hands the model back in evaluation mode; switch to
    # training mode so dropout is active while fine-tuning.
    model.train()

    for i, (x, y_true) in enumerate(train_data_loader):
        # Continue the step count of an earlier, interrupted run.
        i += last_i

        if i % 1000 == 0 and i != 0:
            print(i)

        optimizer.zero_grad()

        y_true = y_true.to(device)

        y_pred = model(x['input_ids'].to(device),
                       attention_mask=x['attention_masks'].to(device))[0]

        # Flatten (batch, sequence) so every token position is one sample.
        loss = criterion(y_pred.view(-1, y_pred.shape[-1]), y_true.view(-1))

        loss.backward()
        optimizer.step()

        if i % 2500 == 0 and i != 0:
            print('validating...')
            valid_true_labels, valid_pred_labels = [], []

            # Evaluate with dropout off and without recording gradients.
            model.eval()
            with torch.no_grad():
                for valid_x, valid_y_true in valid_data_loader:
                    valid_y_pred = model(valid_x['input_ids'].to(device),
                                         attention_mask=valid_x['attention_masks'].to(device))[0]

                    valid_true_labels.extend(valid_y_true.reshape(-1).tolist())
                    valid_pred_labels.extend(torch.argmax(valid_y_pred, -1).reshape(-1).cpu().tolist())
            model.train()

            # Score only positions that carry a real tag.
            valid_true_labels_ = np.array(valid_true_labels)
            valid_pred_labels_ = np.array(valid_pred_labels)
            scored_positions = valid_true_labels_ != TAG_PAD_OR_MASK_TOKEN
            valid_true_labels_ = valid_true_labels_[scored_positions]
            valid_pred_labels_ = valid_pred_labels_[scored_positions]

            valid_f1 = classification_report(valid_true_labels_, valid_pred_labels_, digits=4)

            summary = f'epoch: {epoch + 1} | step: {i}:\n\n{valid_f1}\n'
            print(summary)

            # save everything
            with open(model_dir + 'summary.txt', 'a') as out_file:
                out_file.write(summary + '\n')

            # The step counter restarts at `last_i` every epoch, so the same
            # step directory recurs; reuse it instead of failing on mkdir.
            checkpoint_dir = model_dir + str(i)
            os.makedirs(checkpoint_dir, exist_ok=True)
            model.save_pretrained(checkpoint_dir)

            # Record the step only after its checkpoint is on disk, so
            # `last_i` never names a checkpoint that was not written.
            with open(model_dir + 'last_i', 'w') as out_file:
                out_file.write(str(i))

In [0]:
# testing
test_dataset = EzafeDataset(test_data)

test_data_loader = DataLoader(dataset=test_dataset,
                              batch_size=BATCH_SIZE,
                              collate_fn=collate_fn,
                              shuffle=False,
                              num_workers=4)

# Step number of the saved checkpoint to evaluate (a directory of that name
# under `model_dir`).
checkpoint_step = 2500
model = BertForTokenClassification.from_pretrained(model_dir + str(checkpoint_step))

model.to(device)
# Inference only: dropout off, no gradient bookkeeping.
model.eval()

test_true_labels, test_pred_labels = [], []
with torch.no_grad():
    for x, y_true in test_data_loader:
        y_pred = model(x['input_ids'].to(device),
                       attention_mask=x['attention_masks'].to(device))[0]

        test_true_labels.extend(y_true.reshape(-1).tolist())
        test_pred_labels.extend(torch.argmax(y_pred, -1).reshape(-1).cpu().tolist())

# Score only positions that carry a real tag.
test_true_labels_ = np.array(test_true_labels)
test_pred_labels_ = np.array(test_pred_labels)
scored_positions = test_true_labels_ != TAG_PAD_OR_MASK_TOKEN
test_true_labels_ = test_true_labels_[scored_positions]
test_pred_labels_ = test_pred_labels_[scored_positions]

test_f1 = classification_report(test_true_labels_, test_pred_labels_, digits=4)

# Report the evaluated checkpoint. The training loop's `epoch` variable is
# unrelated to this checkpoint and undefined when only this cell is run.
summary = f'test | step: {checkpoint_step}:\n\n{test_f1}\n'
print(summary)