In [1]:
import torch
import torchvision
import torchaudio

In [47]:
 torch.cuda.empty_cache()

In [2]:
torch.__version__

'2.1.1+cu121'

In [3]:
#check for GPU
torch.cuda.is_available()
# device = torch.device('cuda')
# device

True

In [4]:
import numpy as np
import json
import random
import time
from transformers import BertTokenizer

In [5]:
map_relations = {'Comment':0, 'Contrast':1, 'Correction':2, 'Question-answer_pair':3, 'Acknowledgement':4,'Elaboration':5,
                 'Clarification_question':6, 'Conditional':7, 'Continuation':8, 'Result':9, 'Explanation':10, 'Q-Elab':11,
                 'Alternation':12, 'Narration':13, 'Confirmation_question':14, 'Sequence':15, 'Break':16}

In [6]:
home=%pwd
filename = home + '/data/TRAIN_407.json'

In [7]:
from utils import load_data, input_format, position_ids_compute, tokenize
from bert_format import undersample, format_time, flat_accuracy

In [8]:
#no train validation split here
train_data = load_data(filename, map_relations)

Loading data: /home/kate/LREC/data/TRAIN_407.json
407 dialogs, 21822 edus, 26299 relations, 194 backward relations
4787 edus have multiple parents


In [9]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', use_fast=True)

In [10]:
put = ['1','0']
colors = ['r', 'b', 'g', 'o', 'y', 'p']
listx = ['b', 'c', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n']
listy = ['0', '1', '2', '3', '4', '5', '6', '7', '8']
listz = ['a', 'e', 'i', 'o', 'u', 'p', 'q', 'r', 'x', 'y', 'z']

In [11]:
coord_tokens = [''.join([s, t, i, j, k]) for s in put
                for t in colors
                for i in listx
                for j in listy
                for k in listz]

In [12]:
tokenizer.add_tokens(coord_tokens)

13068

In [13]:
len(tokenizer)

42064

In [14]:
device = torch.device('cuda')

Make attach data

In [15]:
inputs, labels_input, raw = input_format(train_data, 10)

25714 relations
195839 candidates
170125 non attached


In [16]:
batch_tokenized = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, add_special_tokens=True)

In [17]:
input_ids = batch_tokenized["input_ids"].to(device) # list of token ids of dialogs in batch
attention_masks = batch_tokenized["attention_mask"].to(device)
token_type_ids = batch_tokenized["token_type_ids"].to(device)

In [18]:
labels = [label[3] for label in list(labels_input)]
labels = torch.tensor(labels)
labels_complete = torch.tensor(labels_input)

In [19]:
position_ids = position_ids_compute(tokenizer, input_ids, raw, labels_complete)

In [20]:
position_ids = torch.tensor(position_ids)

In [21]:
attach_labels_complete, attach_labels, attach_input_ids, attach_attention_masks, attach_token_type_ids, attach_position_ids = undersample(105753, labels_complete, labels, input_ids, attention_masks, token_type_ids, position_ids)

In [22]:
attach_task_ids = torch.tensor([0 for i in range(len(attach_labels))])

make relation data

In [23]:
inputs, labels_input, raw = input_format(train_data, 10, relations=True)

relation types only...
25773 relations/candidates


In [24]:
num_labels = 17
#including break?

In [25]:
batch_tokenized = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, add_special_tokens=True)

In [26]:
relation_input_ids = batch_tokenized["input_ids"].to(device) # list of token ids of dialogs in batch
relation_attention_masks = batch_tokenized["attention_mask"].to(device)
relation_token_type_ids = batch_tokenized["token_type_ids"].to(device)

In [27]:
labels = [label[3] for label in list(labels_input)]
relation_labels = torch.tensor(labels)
relation_labels_complete = torch.tensor(labels_input)

In [28]:
position_ids = position_ids_compute(tokenizer, relation_input_ids, raw, relation_labels_complete)

In [29]:
relation_position_ids = torch.tensor(position_ids)

In [30]:
relation_task_ids = torch.tensor([1 for i in range(len(relation_labels))])

end pre-processing

In [31]:
from torch import nn

In [32]:
# regroup the attach and relation datasets
pad_value = np.shape(attach_input_ids)[1]-np.shape(relation_input_ids)[1]
relation_input_ids = nn.functional.pad(input=relation_input_ids, pad=(0,pad_value), mode='constant', value=0)
relation_attention_masks = nn.functional.pad(input=relation_attention_masks, pad=(0,pad_value), mode='constant', value=0)
relation_token_type_ids = nn.functional.pad(input=relation_token_type_ids, pad=(0,pad_value), mode='constant', value=0)
relation_position_ids = nn.functional.pad(input=relation_position_ids, pad=(0,pad_value), mode='constant', value=0)

In [33]:
pad_value

26

In [35]:
from multitask_format import Task, MultiTaskModel

In [36]:
attach_task = Task(id = 0, name = 'attach prediction', type = "seq_classification", num_labels=2)
relation_task = Task(id = 1, name = 'relation prediction', type = "seq_classification", num_labels = num_labels)
tasks = [attach_task, relation_task]

In [37]:
input_ids = torch.cat((attach_input_ids,relation_input_ids))
attention_masks = torch.cat((attach_attention_masks, relation_attention_masks))
token_type_ids = torch.cat((attach_token_type_ids, relation_token_type_ids))
position_ids = torch.cat((attach_position_ids, relation_position_ids))
labels = torch.cat((attach_labels ,relation_labels))
task_ids = torch.cat((attach_task_ids ,relation_task_ids))

In [38]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

In [39]:
dataset = TensorDataset(input_ids, attention_masks, token_type_ids, position_ids, labels, task_ids)

In [40]:
train_dataloader = DataLoader(
            dataset,
            sampler = RandomSampler(dataset),
            batch_size = 32
        )

In [41]:
model = MultiTaskModel('bert-base-cased', tasks, len(tokenizer))

In [42]:
model.to(device)

MultiTaskModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(42064, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
     

In [43]:
from transformers import AdamW
import random, time

In [44]:
optimizer = AdamW(model.parameters(),
                  lr = 1.5e-5,
                  eps = 1e-8
                )

training_stats = []

seed_val = 18

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
if device == 'cuda' :
    torch.cuda.manual_seed_all(seed_val)

total_t0 = time.time()



In [45]:
model_path = home + '/models/'

In [46]:
for epoch_i in range(3):

    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, 3))

    t0 = time.time()

    total_train_loss = 0

    model.train()

    for step, batch in enumerate(train_dataloader):
      #  if (epoch_i == 0 and step not in [242, 245, 615, 1294]) or (epoch_i == 1 and step not in [303, 422, 1105, 1240]) or (epoch_i == 2 and step not in [788, 1002, 1667, 1690]):
      # if (epoch_i == 0 and step not in [123, 425]) or (epoch_i == 1 and step not in [48, 74]) or (epoch_i == 2 and step not in [564, 762]):
    #   if (epoch_i == 0 and step not in [471]) or (epoch_i == 1 and step not in []) or (epoch_i == 2 and step not in []):
          if step % 500 == 0 and not step == 0:
              elapsed = format_time(time.time() - t0)
              print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

          model.zero_grad()

          outputs, embed = model(input_ids=batch[0].to(device),
                  attention_mask=batch[1].to(device),
                  token_type_ids=batch[2].to(device),
                  position_ids=batch[3].to(device),
                  labels=batch[4].to(device),
                  task_ids=batch[5].to(device)
                  # labels=batch[3].to(device),
                  # task_ids=batch[4].to(device)
                  )
          # if step == 123:
          #   print(len(batch[0]))
          #   print(len(batch[1]))
          #   print(len(batch[2]))
          #   print(len(batch[3]))
          #   print(len(batch[4]))
          #   print(len(batch[5]))

          # try:
          #   loss = outputs[0]
          #   print(step, loss)
          #   total_train_loss += loss.item()
          #   loss.backward()
          #   torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
          #   optimizer.step()
          #   avg_train_loss = total_train_loss / len(train_dataloader)
          #   training_time = format_time(time.time() - t0)

          # except RuntimeError as e:
          #   print('{}...Skipping epoch {} step {}'.format(e, epoch_i, step))
          #   continue

          loss = outputs[0]
          print(step, loss)
          total_train_loss += loss.item()
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
          optimizer.step()
          avg_train_loss = total_train_loss / len(train_dataloader)
          training_time = format_time(time.time() - t0)

torch.save({
    'model_state_dict': model.state_dict(),
}, model_path + 'multitask_d10.pth')

0 tensor(1.7936, device='cuda:0', grad_fn=<MeanBackward0>)
1 tensor(1.9270, device='cuda:0', grad_fn=<MeanBackward0>)
2 tensor(1.6034, device='cuda:0', grad_fn=<MeanBackward0>)
3 tensor(1.8977, device='cuda:0', grad_fn=<MeanBackward0>)
4 tensor(1.8006, device='cuda:0', grad_fn=<MeanBackward0>)
5 tensor(1.7963, device='cuda:0', grad_fn=<MeanBackward0>)
6 tensor(1.7757, device='cuda:0', grad_fn=<MeanBackward0>)
7 tensor(1.6586, device='cuda:0', grad_fn=<MeanBackward0>)
8 tensor(1.6789, device='cuda:0', grad_fn=<MeanBackward0>)
9 tensor(1.7983, device='cuda:0', grad_fn=<MeanBackward0>)
10 tensor(1.6627, device='cuda:0', grad_fn=<MeanBackward0>)
11 tensor(1.7174, device='cuda:0', grad_fn=<MeanBackward0>)
12 tensor(1.7397, device='cuda:0', grad_fn=<MeanBackward0>)
13 tensor(1.5707, device='cuda:0', grad_fn=<MeanBackward0>)
14 tensor(1.5589, device='cuda:0', grad_fn=<MeanBackward0>)
15 tensor(1.5331, device='cuda:0', grad_fn=<MeanBackward0>)
16 tensor(1.6968, device='cuda:0', grad_fn=<MeanBa