In [None]:
import torch
import torchvision
import torchaudio

In [None]:
 torch.cuda.empty_cache()

In [None]:
# Display the installed PyTorch version as the cell output.
torch.__version__

In [None]:
# Check for a GPU: the cell output is True when PyTorch can use a CUDA device.
torch.cuda.is_available()

In [None]:
import numpy as np
import json
import random
import time
from transformers import BertTokenizer

In [None]:
# Relation-type name -> integer class id. Ids follow the listing order below
# (0 for 'Comment' up to 16 for 'Break').
map_relations = {
    relation_name: class_id
    for class_id, relation_name in enumerate([
        'Comment',
        'Contrast',
        'Correction',
        'Question-answer_pair',
        'Acknowledgement',
        'Elaboration',
        'Clarification_question',
        'Conditional',
        'Continuation',
        'Result',
        'Explanation',
        'Q-Elab',
        'Alternation',
        'Narration',
        'Confirmation_question',
        'Sequence',
        'Break',
    ])
}

In [None]:
# `%pwd` is an IPython magic: `home` receives the kernel's current working
# directory as a string, so this cell only runs inside IPython/Jupyter.
home=%pwd
# Path of the input JSON, built by plain string concatenation under
# <working directory>/data/. Presumably the combined train+validation split,
# judging by the file name -- confirm.
filename = home + '/data/TRAIN+VAL_407_bert.json'

In [None]:
from utils import load_data, input_format, position_ids_compute, tokenize
from bert_format import undersample, format_time, flat_accuracy

In [None]:
# Load the whole file as training data -- no train/validation split here.
# load_data comes from the local `utils` module and is given the
# relation-name -> id mapping defined above.
train_data = load_data(filename, map_relations)

In [None]:
# Load the pretrained cased BERT vocabulary/tokenizer.
# NOTE(review): `use_fast` is an AutoTokenizer option; with the BertTokenizer
# class it presumably has no effect -- confirm against the transformers docs.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', use_fast=True)

In [None]:
# Single-character alphabets; the coordinate tokens built in the next cell take
# one character from each list, in this order.
put = list('10')
colors = list('rbgoyp')
listx = list('bcdfghjklmn')
listy = list('012345678')
listz = list('aeioupqrxyz')

In [None]:
# Every five-character combination of (put, color, x, y, z), enumerated with the
# right-most alphabet varying fastest: 2 * 6 * 11 * 9 * 11 = 13068 strings.
coord_tokens = [
    flag + color + x_char + y_char + z_char
    for flag in put
    for color in colors
    for x_char in listx
    for y_char in listy
    for z_char in listz
]

In [None]:
# Register the coordinate strings as additional vocabulary entries. The cell
# output is add_tokens' return value (per the transformers docs, the number of
# tokens actually added).
tokenizer.add_tokens(coord_tokens)

In [None]:
# Vocabulary size after adding the coordinate tokens; the same value is passed
# to MultiTaskModel when the model is created later in the notebook.
len(tokenizer)

In [None]:
# Train on the GPU when one is visible to PyTorch; otherwise fall back to the
# CPU so that the `.to(device)` calls below still work (slowly) without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Make attach data

In [None]:
# Build the text inputs, label rows and raw records for the attach task.
# NOTE(review): the meaning of the literal 10 is defined by input_format in the
# local `utils` module -- presumably a distance/window limit; confirm.
inputs, labels_input, raw = input_format(train_data, 10)

In [None]:
# Tokenize every input in one call: add the special tokens, pad to a common
# length, truncate over-long sequences and return PyTorch tensors.
batch_tokenized = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, add_special_tokens=True)

In [None]:
# Move the three tokenizer outputs (token ids, attention mask, segment ids)
# onto the training device.
input_ids, attention_masks, token_type_ids = (
    batch_tokenized[field].to(device)
    for field in ("input_ids", "attention_mask", "token_type_ids")
)

In [None]:
# Element 3 of every label row is used as the classification target; the full
# rows are kept as a tensor as well.
labels = torch.tensor([row[3] for row in labels_input])
labels_complete = torch.tensor(labels_input)

In [None]:
# Compute the custom position ids for the attach inputs with the helper from
# the local `utils` module (it receives the tokenizer, the token ids, the raw
# records and the full label rows).
position_ids = position_ids_compute(tokenizer, input_ids, raw, labels_complete)

In [None]:
# Convert the result of position_ids_compute into a torch tensor, rebinding the
# same name.
position_ids = torch.tensor(position_ids)

In [None]:
# Undersample the attach examples with the helper from the local `bert_format`
# module; it returns the six tensors in the same order they are passed in.
# NOTE(review): 105753 is a hard-coded argument whose meaning is defined by
# undersample -- presumably a target count or size for this dataset; confirm.
(
    attach_labels_complete,
    attach_labels,
    attach_input_ids,
    attach_attention_masks,
    attach_token_type_ids,
    attach_position_ids,
) = undersample(
    105753,
    labels_complete,
    labels,
    input_ids,
    attention_masks,
    token_type_ids,
    position_ids,
)

In [None]:
# One task id per attach example; 0 is the id given to the attach Task below.
attach_task_ids = torch.tensor([0] * len(attach_labels))

Make relation data

In [None]:
# Rebuild inputs, label rows and raw records, this time for the relation task
# (relations=True). This overwrites the attach-task `inputs`, `labels_input`
# and `raw` from the earlier cell.
inputs, labels_input, raw = input_format(train_data, 10, relations=True)

In [None]:
# Number of relation classes. Derived from the relation-name -> id mapping
# (17 entries) instead of being hard-coded, so the classifier head size cannot
# drift out of sync with the label set.
num_labels = len(map_relations)

In [None]:
# Tokenize the relation-task inputs with the same settings as the attach task:
# special tokens added, padded to a common length, truncated, PyTorch tensors.
batch_tokenized = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, add_special_tokens=True)

In [None]:
# Move the relation-task tokenizer outputs (token ids, attention mask, segment
# ids) onto the training device.
relation_input_ids, relation_attention_masks, relation_token_type_ids = (
    batch_tokenized[field].to(device)
    for field in ("input_ids", "attention_mask", "token_type_ids")
)

In [None]:
# Element 3 of every label row is the relation target. `labels` is a plain
# Python list at this point; the tensor versions get relation_* names.
labels = [row[3] for row in labels_input]
relation_labels = torch.tensor(labels)
relation_labels_complete = torch.tensor(labels_input)

In [None]:
# Compute the custom position ids for the relation inputs. This reuses (and
# overwrites) the generic name `position_ids`; the next cell stores the tensor
# under a relation-specific name.
position_ids = position_ids_compute(tokenizer, relation_input_ids, raw, relation_labels_complete)

In [None]:
# Convert the relation position ids into a torch tensor.
relation_position_ids = torch.tensor(position_ids)

In [None]:
# One task id per relation example; 1 is the id given to the relation Task below.
relation_task_ids = torch.tensor([1] * len(relation_labels))

End of pre-processing

In [None]:
from torch import nn

In [None]:
# Regroup the attach and relation datasets: both sets of tensors need the same
# sequence length (dimension 1) before they can be concatenated.
# `pad_value` is how many columns the relation tensors are narrower than the
# attach tensors; it is negative when the relation tensors are the wider ones.
pad_value = attach_input_ids.shape[1] - relation_input_ids.shape[1]

if pad_value >= 0:
    # Right-pad the relation tensors with zeros up to the attach width.
    (
        relation_input_ids,
        relation_attention_masks,
        relation_token_type_ids,
        relation_position_ids,
    ) = (
        nn.functional.pad(input=tensor, pad=(0, pad_value), mode='constant', value=0)
        for tensor in (
            relation_input_ids,
            relation_attention_masks,
            relation_token_type_ids,
            relation_position_ids,
        )
    )
else:
    # The relation tensors are wider. Handing the negative width to F.pad would
    # not pad them (it presumably crops columns instead -- confirm against the
    # F.pad docs), so right-pad the attach tensors up to the relation width.
    (
        attach_input_ids,
        attach_attention_masks,
        attach_token_type_ids,
        attach_position_ids,
    ) = (
        nn.functional.pad(input=tensor, pad=(0, -pad_value), mode='constant', value=0)
        for tensor in (
            attach_input_ids,
            attach_attention_masks,
            attach_token_type_ids,
            attach_position_ids,
        )
    )

In [None]:
# Display the width difference computed in the previous cell (attach sequence
# length minus relation sequence length).
pad_value

In [None]:
from multitask_format import Task, MultiTaskModel

In [None]:
# The two classification heads of the multitask model: binary attach
# prediction (task id 0) and relation-type prediction (task id 1).
attach_task = Task(
    id=0,
    name='attach prediction',
    type="seq_classification",
    num_labels=2,
)
relation_task = Task(
    id=1,
    name='relation prediction',
    type="seq_classification",
    num_labels=num_labels,
)
tasks = [attach_task, relation_task]

In [None]:
# Stack the attach examples on top of the relation examples along the first
# (example) dimension, field by field.
input_ids = torch.cat([attach_input_ids, relation_input_ids], dim=0)
attention_masks = torch.cat([attach_attention_masks, relation_attention_masks], dim=0)
token_type_ids = torch.cat([attach_token_type_ids, relation_token_type_ids], dim=0)
position_ids = torch.cat([attach_position_ids, relation_position_ids], dim=0)
labels = torch.cat([attach_labels, relation_labels], dim=0)
task_ids = torch.cat([attach_task_ids, relation_task_ids], dim=0)

In [None]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

In [None]:
# Bundle the six aligned tensors; each dataset item is a tuple in this order
# (input ids, attention mask, token type ids, position ids, label, task id),
# which is the order the training loop indexes the batch by.
dataset = TensorDataset(input_ids, attention_masks, token_type_ids, position_ids, labels, task_ids)

In [None]:
# Batches of 32 examples, drawn in random order from the combined dataset.
train_dataloader = DataLoader(dataset, batch_size=32, sampler=RandomSampler(dataset))

In [None]:
# Build the multitask model from the local `multitask_format` module. It is
# given the pretrained checkpoint name, the task list and the tokenizer length
# (which includes the added coordinate tokens).
model = MultiTaskModel('bert-base-cased', tasks, len(tokenizer))

In [None]:
# Move the model to the training device. As the last expression of the cell,
# the return value of .to() is also shown as the cell output.
model.to(device)

In [None]:
from transformers import AdamW
import random, time

In [None]:
# Optimizer over all model parameters.
optimizer = AdamW(model.parameters(),
                  lr = 1.5e-5,
                  eps = 1e-8
                )

# Container for per-epoch training statistics.
training_stats = []

# Fix the seeds of Python's, NumPy's and PyTorch's random number generators.
# NOTE(review): in top-to-bottom execution the model has already been created
# by this point, so anything randomly initialised earlier is not covered.
seed_val = 18

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
# `device` is a torch.device object, so the former `device == 'cuda'` string
# comparison was never true and the CUDA generators were never seeded here.
# Comparing the device *type* works for torch.device objects and plain strings.
if torch.device(device).type == 'cuda':
    torch.cuda.manual_seed_all(seed_val)

# Wall-clock reference taken just before training starts.
total_t0 = time.time()

In [None]:

# Output location of the trained weights. Both placeholders must be replaced
# before running. The pieces are joined by plain string concatenation (here and
# again when the checkpoint path is built), so the folder value needs its own
# leading and trailing '/', e.g. '/models/' and 'multitask.pth'.
model_path = home + '<name of your model folder>'
save_multitask_name =  '<name of your multitask .pth file output>'

In [None]:
# Fine-tune the multitask model on the combined attach + relation dataset.
# Number of passes over the training data (previously the literal 3, repeated
# in both the loop header and the progress message).
epochs = 3

for epoch_i in range(epochs):

    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))

    t0 = time.time()

    # Sum of the per-batch losses of this epoch.
    total_train_loss = 0

    model.train()

    for step, batch in enumerate(train_dataloader):
        # Progress report every 500 batches (skipping batch 0).
        if step % 500 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        # Clear gradients left over from the previous step.
        model.zero_grad()

        # Batch fields follow the TensorDataset order: input ids, attention
        # mask, token type ids, position ids, labels, task ids.
        outputs, embed = model(input_ids=batch[0].to(device),
                               attention_mask=batch[1].to(device),
                               token_type_ids=batch[2].to(device),
                               position_ids=batch[3].to(device),
                               labels=batch[4].to(device),
                               task_ids=batch[5].to(device)
                               )

        # The first element of the model output is used as the training loss.
        loss = outputs[0]
        total_train_loss += loss.item()
        loss.backward()
        # Clip the gradient norm to 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    # Epoch summary. This used to be recomputed on every step inside the batch
    # loop and was never reported or stored.
    num_batches = len(train_dataloader)
    avg_train_loss = total_train_loss / num_batches if num_batches else 0.0
    training_time = format_time(time.time() - t0)

    print('  Average training loss: {0:.4f}'.format(avg_train_loss))
    print('  Training epoch took: {:}'.format(training_time))

    training_stats.append({
        'epoch': epoch_i + 1,
        'Training Loss': avg_train_loss,
        'Training Time': training_time,
    })

# Checkpoint destination: folder and file name joined by plain concatenation.
output_model = model_path + save_multitask_name

print('finished_training, saving to : ', output_model)

# Persist only the model weights, stored under the 'model_state_dict' key.
torch.save({'model_state_dict': model.state_dict()}, output_model)
