In [1]:
import os
# Restrict CUDA to device index 0 and enable blocking kernel launches.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid (it makes
# device-side errors surface at the failing call but slows training) -
# confirm it is still wanted for regular runs.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0',
    'CUDA_LAUNCH_BLOCKING': '1',
})

In [2]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, RobertaConfig, RobertaModel, EncoderDecoderModel, AdamW, Trainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

Load model

In [3]:
# Use the first visible GPU when CUDA is available; otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU (no manual
# commenting/uncommenting of the device line needed).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
# Names of the pretrained encoder checkpoint and of the directory that will
# hold the decoder used to initialise the encoder-decoder model.
encoder_name = 'deepset/gbert-base'
decoder_model_path = "../models/decoder-initialization"

# One tokenizer is used for both sides; its CLS/SEP tokens double as BOS/EOS.
tokenizer = AutoTokenizer.from_pretrained(encoder_name)
tokenizer.bos_token, tokenizer.eos_token = tokenizer.cls_token, tokenizer.sep_token

# The decoder is constructed from a config (not loaded from a checkpoint) with
# a vocabulary sized to the tokenizer, then written to disk so that
# from_encoder_decoder_pretrained can pick it up by path.
special_token_ids = {
    "pad_token_id": tokenizer.pad_token_id,
    "bos_token_id": tokenizer.bos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}
decoder_config = RobertaConfig(vocab_size=len(tokenizer), **special_token_ids)
decoder_model = RobertaModel(decoder_config)
decoder_model.save_pretrained(decoder_model_path)

model = EncoderDecoderModel.from_encoder_decoder_pretrained(encoder_name, decoder_model_path)

# Copy the special-token ids and the decoder vocabulary size onto the combined
# model config.
for config_key, config_value in (
    ("decoder_start_token_id", tokenizer.bos_token_id),
    ("eos_token_id", tokenizer.eos_token_id),
    ("pad_token_id", tokenizer.pad_token_id),
    ("vocab_size", model.config.decoder.vocab_size),
):
    setattr(model.config, config_key, config_value)
model.to(device)

Some weights of the model checkpoint at deepset/gbert-base were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForCausalLM were not initialized from the model checkpoint at ../models/decoder-initialization and are newly initialized: ['encoder

EncoderDecoderModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31102, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [12]:
# Smoke test: run one (source, summary) pair through the model and print the
# loss to check that the forward pass works end to end.
source_text = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side.During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was  finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft).Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
summary_text = "the eiffel tower surpassed the washington monument to become the tallest structure in the world. it was the first structure to reach a height of 300 metres in paris in 1930. it is now taller than the chrysler building by 5. 2 metres ( 17 ft ) and is the second tallest free - standing structure in paris."

# Tokenize both texts as single-example PyTorch batches and move them to the
# model's device.
input_ids = tokenizer(source_text, return_tensors="pt").input_ids.to(device)
labels = tokenizer(summary_text, return_tensors="pt").input_ids.to(device)

# the forward function automatically creates the correct decoder_input_ids
loss = model(input_ids=input_ids, labels=labels).loss
print(loss)

tensor(10.0595, device='cuda:0', grad_fn=<NllLossBackward0>)


In [13]:
# Settings read by preprocess_function.
prefix = ''  # string prepended to every source text
max_source_length = max_target_length = 510  # max_length passed to the tokenizer (with truncation=True)
padding = False  # forwarded to the tokenizer's `padding` argument
ignore_pad_token_for_loss = True  # only takes effect when padding == "max_length"


def preprocess_function(examples):
    """Tokenize a batch of source/target pairs for seq2seq training.

    Args:
        examples: mapping with a "source" and a "target" entry, each a list of
            strings (one batch as passed by ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer output for the (prefixed) sources, extended with a
        "labels" entry holding the token ids of the targets.
    """
    sources = examples["source"]
    targets = examples["target"]

    prefixed_sources = [prefix + text for text in sources]
    model_inputs = tokenizer(prefixed_sources, max_length=max_source_length, padding=padding, truncation=True)

    # Targets are encoded inside the tokenizer's target-mode context.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)

    label_ids = target_encoding["input_ids"]

    # When padding to max_length here, swap pad ids for -100 so that padded
    # positions are ignored by the loss.
    if ignore_pad_token_for_loss and padding == "max_length":
        pad_id = tokenizer.pad_token_id
        label_ids = [
            [-100 if token_id == pad_id else token_id for token_id in sequence]
            for sequence in label_ids
        ]

    model_inputs["labels"] = label_ids
    return model_inputs


In [14]:
from datasets import load_dataset

train_file = '../data/processed/d2h-v1-aligned-para/train.json'
validation_file = '../data/processed/d2h-v1-aligned-para/val.json'
test_file = '../data/processed/d2h-v1-aligned-para/test.json'

# Map split name -> file path; splits whose path is None are left out.
split_files = {
    "train": train_file,
    "validation": validation_file,
    "test": test_file,
}

data_files = {}
for split_name, split_path in split_files.items():
    if split_path is not None:
        data_files[split_name] = split_path
        # The loader type is the file extension of the last configured split.
        extension = split_path.split(".")[-1]

raw_datasets = load_dataset(extension, data_files=data_files)

Using custom data configuration default-1b258437e2dfe0ac
Reusing dataset json (/homes/jan/.cache/huggingface/datasets/json/default-1b258437e2dfe0ac/0.0.0/da492aad5680612e4028e7f6ddc04b1dfcec4b64db470ed7cc5f2bb265b9b6b5)


  0%|          | 0/3 [00:00<?, ?it/s]

In [15]:
# Look up the three raw splits, then tokenize each one with
# preprocess_function, which receives examples in batches.
raw_splits = [raw_datasets[split_name] for split_name in ("train", "validation", "test")]
train_dataset, eval_dataset, predict_dataset = [
    raw_split.map(preprocess_function, batched=True) for raw_split in raw_splits
]

  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [16]:
# Print the largest token id found in each column of each split, in the order
# input_ids (train, validation, test) followed by labels (train, validation, test).
# NOTE(review): presumably a check that every id is below the model's
# vocabulary size - confirm against the embedding size shown in the model repr.
#
# The ids are scanned with a nested generator instead of flattening via
# sum(list_of_lists, []), which copies the growing list on every addition
# (quadratic in the number of tokens).
for column_name in ('input_ids', 'labels'):
    for split_dataset in (train_dataset, eval_dataset, predict_dataset):
        print(max(token_id for sequence in split_dataset[column_name] for token_id in sequence))

31027
30975
30975
30975
30975
30975


In [11]:
# Manual single-pass training loop over the tokenized training split.

# Drop the raw text and metadata columns so the collator only receives the
# tokenized fields.
train_tokenized = train_dataset.remove_columns(['resource.id', 'type', 'target_index', 'target_title_span', 'target_text_span', 'target', 'source_index', 'source_title_span', 'source_text_span', 'source'])
optimizer = AdamW(model.parameters(), lr=3e-5)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
train_dataloader = DataLoader(
    train_tokenized,
    shuffle=False,
    collate_fn=data_collator,
    batch_size=8,
)

model.train()
for i, batch in enumerate(train_dataloader):
    # Move every tensor of the batch to the device the model lives on.
    for key in batch:
        batch[key] = batch[key].to(device)

    # Start each step from clean gradients, so a step that failed half-way
    # through backward cannot leak partial gradients into the next update.
    optimizer.zero_grad()

    try:
        outputs = model(**batch)
        loss = outputs.loss
        # Plain backward on the scalar loss. Passing the loss itself as the
        # `gradient` argument would scale every gradient by the loss value.
        loss.backward()

        optimizer.step()
        print(f'{i}/{len(train_dataloader)}')
    except RuntimeError as err:
        # Best-effort debugging aid: report the offending batch and keep going.
        # NOTE(review): after a CUDA device-side error, later batches may fail
        # as well - confirm whether continuing is actually useful here.
        print('error in batch =', i)
        print(err)
        print(batch)



0/405
1/405
2/405
3/405
4/405
5/405
6/405
7/405
8/405
9/405
10/405
11/405
12/405
13/405
14/405
15/405
16/405
17/405
18/405
19/405
20/405
21/405
22/405
23/405
24/405
25/405
26/405
27/405
28/405
29/405
30/405
31/405
32/405
33/405
34/405
35/405
36/405
37/405
38/405
39/405
40/405
41/405
42/405
43/405
44/405
45/405
46/405
47/405
48/405
49/405
50/405
51/405
52/405
53/405
54/405
55/405
56/405
57/405
58/405
59/405
60/405
61/405
62/405
63/405
64/405
65/405
66/405
67/405
68/405
69/405
70/405
71/405
72/405
73/405
74/405
75/405
76/405
77/405
78/405
79/405
80/405
81/405
82/405
83/405
84/405
85/405
86/405
87/405
88/405
89/405
90/405
91/405
92/405
93/405
94/405
95/405
96/405
97/405
98/405
99/405
100/405
101/405
102/405
103/405
104/405
105/405
106/405
107/405
108/405
109/405
110/405
111/405
112/405
113/405
114/405
115/405
116/405
117/405
118/405
119/405
120/405
121/405
122/405
123/405
124/405
125/405
126/405
127/405
128/405
129/405
130/405
131/405
132/405
133/405
134/405
135/405
136/405
137/405
138/40