In [1]:
import sys

# Make the project's `scripts/` directory importable (presumably where the
# `generative` package imported below lives). Relative sys.path entries are
# resolved against the kernel's working directory, normally this notebook's
# folder. The membership guard keeps the cell idempotent: re-running it does
# not append duplicate entries.
SCRIPTS_DIR = '../scripts'
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)

In [2]:
import os

os.environ.update({
    # Turn off parallelism inside the Hugging Face `tokenizers` library.
    "TOKENIZERS_PARALLELISM": "false",
    # Number GPUs by PCI bus ID (the order nvidia-smi shows) and expose only
    # physical GPU 2 to this process; it then appears as device 0 / cuda:0.
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "2",
    # Disable Weights & Biases logging (if installed). transformers warns that
    # this variable is deprecated in favour of `report_to="none"`.
    "WANDB_DISABLED": "true",
})
# To log to W&B instead, drop WANDB_DISABLED above and set the project:
# os.environ["WANDB_PROJECT"] = "ggponc_ellipses"

In [3]:
from pathlib import Path

from generative.dataset import EllipsesDataset
from generative.run_experiment import get_training_args, get_trainer, get_tokenizer

  from .autonotebook import tqdm as notebook_tqdm


# Training

In [4]:
import hydra
from hydra import compose, initialize

# Reset Hydra's global state so this cell can be re-run in the same kernel
# (Hydra refuses to be initialised twice).
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Hydra documents `config_path` as a str relative to the caller; pass a plain
# string rather than a pathlib.Path, which parts of Hydra's search-path
# handling may not accept.
initialize(config_path='..', job_name='foo', version_base='1.1')
config = compose(config_name='experiment.yaml')

# Notebook-level overrides of values from experiment.yaml.
config.model_name = "google/mt5-base"
config.metrics = ['exact_match', 'google_bleu']
config.learning_rate = 5e-5

In [5]:
# Build the training arguments and the tokenizer from the composed config
# (helpers from generative.run_experiment).
# NOTE(review): if `report_to` is forwarded to transformers' TrainingArguments,
# None there means "all installed integrations", not "none"; the warning this
# cell prints recommends `report_to="none"` over WANDB_DISABLED — confirm.
training_args = get_training_args(config, report_to=None)
tokenizer = get_tokenizer(config)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [6]:
import pandas as pd
from pathlib import Path

# Directory holding both annotation tables, relative to this notebook's working
# directory. Defined once so the data location can be changed in one place.
ANNOTATION_DIR = Path('../../ggponc_annotation/notebooks')

# Ellipsis sentences; columns used below: raw_sentence, full_resolution, split.
df_ellipses = pd.read_csv(ANNOTATION_DIR / 'ggponc_ccnfs.tsv', sep='\t')
# Control sentences; columns used below: raw_sentence, split.
df_controls = pd.read_csv(ANNOTATION_DIR / 'ggponc_cnfs_controls_small.tsv', sep='\t')

In [7]:
# Partition both tables by their `split` column; 'dev' serves as validation.
_SPLIT_NAMES = ('train', 'dev', 'test')
train_cnfs, valid_cnfs, test_cnfs = (
    df_ellipses[df_ellipses['split'] == name] for name in _SPLIT_NAMES
)
train_controls, valid_controls, test_controls = (
    df_controls[df_controls['split'] == name] for name in _SPLIT_NAMES
)

# Sizes: CNF train/dev/test, then control train/dev/test.
tuple(len(part) for part in (train_cnfs, valid_cnfs, test_cnfs,
                             train_controls, valid_controls, test_controls))

(2241, 462, 462, 2269, 447, 449)

In [145]:
def make_dataset(cnfs, controls=None):
    """Build an EllipsesDataset from a CNF split, optionally mixed with controls.

    Args:
        cnfs: rows with ``raw_sentence`` (model input) and ``full_resolution``
            (target) columns.
        controls: optional rows with a ``raw_sentence`` column; each control
            sentence is used as both input and target, i.e. the model should
            reproduce it unchanged.

    Returns:
        ``EllipsesDataset`` over the concatenated inputs/targets, constructed
        with the notebook-level ``tokenizer``.
    """
    inputs = [cnfs.raw_sentence]
    targets = [cnfs.full_resolution]
    if controls is not None:
        inputs.append(controls.raw_sentence)
        targets.append(controls.raw_sentence)
    return EllipsesDataset(pd.concat(inputs), pd.concat(targets), tokenizer)


# CNF-only datasets; pass e.g. `controls=train_controls` to add control sentences.
train_data = make_dataset(train_cnfs)
val_data = make_dataset(valid_cnfs)
test_data = make_dataset(test_cnfs)

In [146]:
#train_data = EllipsesDataset(pd.concat([train_cnfs.raw_sentence, train_controls.raw_sentence]), pd.concat([train_cnfs.full_resolution, train_controls.raw_sentence]), tokenizer)
#val_data = EllipsesDataset(pd.concat([valid_cnfs.raw_sentence, valid_controls.raw_sentence]), pd.concat([valid_cnfs.full_resolution, valid_controls.raw_sentence]), tokenizer)

In [147]:
# Train for 10 epochs, overriding whatever get_training_args() set.
training_args.num_train_epochs = 10

In [148]:
# Load the pretrained model named by config.model_name and wrap it in a trainer
# together with the training and validation datasets (project helper).
trainer = get_trainer(config, tokenizer, training_args, train_data, val_data)

loading configuration file config.json from cache at /home/Florian.Borchert/.cache/huggingface/hub/models--google--mt5-base/snapshots/d86816880b5acc27e697e52bc237e816dc828b17/config.json
Model config MT5Config {
  "_name_or_path": "google/mt5-base",
  "architectures": [
    "MT5ForConditionalGeneration"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "mt5",
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "tokenizer_class": "T5Tokenizer",
  "transformers_version": "4.23.1",
  "use_cache": true,
  "vocab_size": 250112
}

loading weights fil

In [None]:
# Fine-tune; validation metrics (exact match, Google BLEU) are reported per epoch.
trainer.train()

***** Running training *****
  Num examples = 2241
  Num Epochs = 10
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2810


Epoch,Training Loss,Validation Loss,Exact Match,Google Bleu
1,4.0209,0.717223,0.0,0.374842
2,0.4349,0.116564,0.441558,0.792539
3,0.21,0.073871,0.65368,0.893547
4,0.1311,0.069113,0.707792,0.890707
5,0.1252,0.053545,0.731602,0.904029


***** Running Evaluation *****
  Num examples = 462
  Batch size = 8
***** Running Evaluation *****
  Num examples = 462
  Batch size = 8
***** Running Evaluation *****
  Num examples = 462
  Batch size = 8
***** Running Evaluation *****
  Num examples = 462
  Batch size = 8
***** Running Evaluation *****
  Num examples = 462
  Batch size = 8
***** Running Evaluation *****
  Num examples = 462
  Batch size = 8


In [13]:
# Score the fine-tuned model on the validation split (default "eval" metric
# prefix) and on the test split (metric_key_prefix='test').
# NOTE(review): the saved output of this cell is stale — it reports 909
# validation examples (= 462 CNF + 447 control dev sentences) and "eval_" keys
# for the test run, neither of which matches this code. Re-run the notebook top
# to bottom before citing these numbers.
eval_metrics = trainer.evaluate(val_data)
print(eval_metrics)

test_metrics = trainer.evaluate(test_data, metric_key_prefix='test')
print(test_metrics)

***** Running Evaluation *****
  Num examples = 909
  Batch size = 8


***** Running Evaluation *****
  Num examples = 462
  Batch size = 8


{'eval_loss': 0.0302598737180233, 'eval_exact_match': 0.8305830583058306, 'eval_google_bleu': 0.9631266857901272, 'eval_runtime': 137.5663, 'eval_samples_per_second': 6.608, 'eval_steps_per_second': 0.829, 'epoch': 10.0}
{'eval_loss': 0.05185458064079285, 'eval_exact_match': 0.7705627705627706, 'eval_google_bleu': 0.9249460749068567, 'eval_runtime': 89.1542, 'eval_samples_per_second': 5.182, 'eval_steps_per_second': 0.651, 'epoch': 10.0}


# Inspect model outputs on the validation set

In [23]:
from transformers import Text2TextGenerationPipeline

In [121]:
# Text-to-text generation with the fine-tuned model. device=0 is the first
# *visible* GPU, i.e. physical GPU 2 given CUDA_VISIBLE_DEVICES set earlier;
# max_length caps the length of each generated sequence.
pipeline = Text2TextGenerationPipeline(model=trainer.model, tokenizer=tokenizer, max_length=500, device=0)

In [122]:
# Qualitative analysis runs on the validation CNF split (an alias, not a copy).
my_sample = valid_cnfs

In [123]:
# First input sentence of the sample.
my_sample.raw_sentence.iloc[0]

'Hauptrisikofaktoren für das Auftreten eines Mundhöhlenkarzinoms sind chronischer Tabak- oder Alkoholabusus, wesentlich seltener auch andere Faktoren.'

In [124]:
# Gold resolution of the sample's first sentence.
my_sample.full_resolution.iloc[0]

'Hauptrisikofaktoren für das Auftreten eines Mundhöhlenkarzinoms sind chronischer Tabakabusus oder Alkoholabusus, wesentlich seltener auch andere Faktoren.'

In [126]:
%%time
out = pipeline(list(my_sample.raw_sentence.values))

CPU times: user 4min 23s, sys: 52.1 ms, total: 4min 23s
Wall time: 4min 23s


In [127]:
%%html
<!-- Disable line wrapping in JupyterLab output <pre> blocks (presumably so long
     sentences printed in this notebook stay on a single line). -->
<style>
div.jp-OutputArea pre {
    white-space: pre;
}
</style>

In [128]:
from collections import Counter, defaultdict
import difflib

def get_errors(predictions, gt_resolutions, original_sentences):
    """Categorise generated resolutions by how they differ from the gold resolution.

    Each prediction is compared with its gold resolution and its original
    (unresolved) sentence and counted in exactly one category:

    * ``'tp'`` - the prediction equals the gold resolution.
    * ``'fn'`` - the prediction equals the original sentence (and not the gold).
    * ``'replace'`` / ``'delete'`` / ``'insert'`` - the character-level diff
      from gold to prediction (``difflib.SequenceMatcher``) contains only that
      one kind of edit: ``'delete'`` = text missing from the prediction,
      ``'insert'`` = extra text in the prediction.
    * ``'complex'`` - the diff mixes more than one kind of edit.

    Args:
        predictions: generated sentences (any iterable of str).
        gt_resolutions: gold resolutions, aligned with ``predictions``.
        original_sentences: input sentences, aligned with ``predictions``.

    Returns:
        dict mapping category -> count. Categories that never occur are
        omitted, and the counts sum to the number of predictions.

    Raises:
        ValueError: if the three inputs have different lengths.
    """
    # Materialise once so generators work and lengths can be compared;
    # zip() would otherwise silently drop unmatched trailing items.
    predictions = list(predictions)
    gt_resolutions = list(gt_resolutions)
    original_sentences = list(original_sentences)
    if not len(predictions) == len(gt_resolutions) == len(original_sentences):
        raise ValueError(
            f'Input lengths differ: {len(predictions)} predictions, '
            f'{len(gt_resolutions)} gold resolutions, '
            f'{len(original_sentences)} original sentences'
        )

    errors = Counter()
    for pred_gen, true, sent in zip(predictions, gt_resolutions, original_sentences):
        if pred_gen == true:
            errors['tp'] += 1
        elif pred_gen == sent:
            errors['fn'] += 1
        else:
            # Opcodes describe how to turn the gold string into the prediction.
            op_codes = difflib.SequenceMatcher(None, true, pred_gen).get_opcodes()
            # pred_gen != true, so at least one non-'equal' opcode exists.
            edit_kinds = {tag for tag, *_ in op_codes if tag != 'equal'}
            errors['complex' if len(edit_kinds) > 1 else edit_kinds.pop()] += 1
    return dict(errors)

In [139]:
# Classify every generated resolution against its gold resolution and raw sentence.
gen_text = [prediction['generated_text'] for prediction in out]
errors = get_errors(gen_text, my_sample.full_resolution, my_sample.raw_sentence)
print(errors)

# The same counts expressed as fractions of the evaluated sample.
n_evaluated = len(my_sample)
print({category: count / n_evaluated for category, count in errors.items()})

{'tp': 354, 'replace': 20, 'delete': 39, 'complex': 16, 'insert': 15, 'fn': 18}
{'tp': 0.7662337662337663, 'replace': 0.04329004329004329, 'delete': 0.08441558441558442, 'complex': 0.03463203463203463, 'insert': 0.032467532467532464, 'fn': 0.03896103896103896}


In [133]:
from evaluate import load