In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset, the extra site-packages
# directory and the model checkpoints used below all live under this mount point.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import sys
# Make the packages installed in the Drive-hosted virtualenv importable.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path.
SITE_PACKAGES_DIR = '/content/drive/MyDrive/SloganGenerator/venv/lib/python3.10/site-packages'
if SITE_PACKAGES_DIR not in sys.path:
    sys.path.append(SITE_PACKAGES_DIR)

# Create the dataset

In [None]:
import transformers
from datasets import load_dataset, load_metric

In [None]:
# Read the merged company/slogan CSV from Drive; the result is a DatasetDict
# holding a single "train" split.
slogan_dataset_raw = load_dataset(
    "/content/drive/MyDrive/SloganGenerator/dataset/",
    data_files="merged.csv",
)

Downloading and preparing dataset csv/dataset to /root/.cache/huggingface/datasets/csv/dataset-81ac2a11ea7f9045/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/dataset-81ac2a11ea7f9045/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Show the splits, column names and row counts of the freshly loaded dataset.
print(slogan_dataset_raw)

DatasetDict({
    train: Dataset({
        features: ['company', 'slogan'],
        num_rows: 11902
    })
})


In [None]:
# Hold out 10% of the rows for validation; the fixed seed keeps the split
# identical between runs.
slogan_dataset = slogan_dataset_raw["train"].train_test_split(
    train_size=0.9,
    seed=20,
)
# train_test_split calls the held-out part "test"; re-key it as "validation".
slogan_dataset["validation"] = slogan_dataset.pop("test")
print(slogan_dataset)

DatasetDict({
    train: Dataset({
        features: ['company', 'slogan'],
        num_rows: 10711
    })
    validation: Dataset({
        features: ['company', 'slogan'],
        num_rows: 1191
    })
})


In [None]:
import string

import nltk
from transformers import AutoTokenizer

# Sentence-splitting data needed by nltk.sent_tokenize in compute_metrics.
nltk.download('punkt')

# Maximum number of tokens kept for both inputs and target slogans.
MAX_LENGTH = 64
# Pretrained checkpoint that is fine-tuned below; the tokenizer must match it.
model_checkpoint = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


Downloading (…)lve/main/config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [None]:
# Drop every row where either the company text or the slogan is missing,
# in both the train and validation splits.
slogan_dataset = slogan_dataset.filter(
    lambda example: not (example['company'] is None or example['slogan'] is None)
)

Filter:   0%|          | 0/10711 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1191 [00:00<?, ? examples/s]

In [None]:
def preprocess_data(examples):
  """Tokenize a batch of company texts and their target slogans.

  Args:
    examples: batch mapping with a 'company' list (model inputs) and a
      'slogan' list (targets), as passed by ``Dataset.map(batched=True)``.

  Returns:
    The tokenizer output for the inputs, with the tokenized slogans stored
    under "labels". Both sides are truncated to MAX_LENGTH tokens; no padding
    is applied here.
  """
  # `text_target` replaces the deprecated `as_target_tokenizer()` context
  # manager: the tokenizer encodes the targets itself and puts their
  # input_ids under "labels", applying the same max_length/truncation.
  model_inputs = tokenizer(
      examples["company"],
      text_target=examples["slogan"],
      max_length=MAX_LENGTH,
      truncation=True,
  )
  return model_inputs

In [None]:
# Tokenize both splits in batches and drop the raw text columns so only the
# model-ready features remain.
tokenized_dataset = slogan_dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=slogan_dataset["train"].column_names,
)

print(tokenized_dataset['train'])

Map:   0%|          | 0/10711 [00:00<?, ? examples/s]



Map:   0%|          | 0/1191 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 10711
})


# Instantiate the model

In [None]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Name of this training run and the Drive directory its checkpoints go to.
model_name = "bart_merged"
model_dir = f"/content/drive/MyDrive/SloganGenerator/models/{model_name}"

# Base batch size; the training arguments scale it per device.
batch_size = 8

In [None]:
args = Seq2SeqTrainingArguments(
    model_dir,  # output directory for checkpoints
    # Optimisation
    num_train_epochs=3,
    learning_rate=4e-5,
    weight_decay=0.01,
    fp16=True,
    per_device_train_batch_size=batch_size*4,
    per_device_eval_batch_size=batch_size*8,
    # Evaluate, log and checkpoint once per epoch
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    # Score generated text (not teacher-forced logits) and reload the
    # checkpoint with the best ROUGE-1 when training ends
    predict_with_generate=True,
    load_best_model_at_end=True,
    metric_for_best_model="rouge1",
)

In [None]:
# Batch collator handed to the trainer below, built from the tokenizer only.
# preprocess_data tokenizes without padding, so batching presumably relies on
# this collator to pad inputs and labels per batch — confirm against the
# DataCollatorForSeq2Seq documentation.
data_collator = DataCollatorForSeq2Seq(tokenizer)

In [None]:
# ROUGE metric object used by compute_metrics.
# NOTE(review): load_metric is deprecated in recent `datasets` releases in
# favour of the separate `evaluate` package — confirm the pinned version
# still provides it before upgrading.
metric = load_metric("rouge")

In [None]:
import numpy as np

def compute_metrics(eval_pred):
    """Compute ROUGE F1 scores and the mean generated length for an eval pass.

    Args:
        eval_pred: pair ``(predictions, labels)`` of token-id arrays as
            supplied by the trainer; ``predictions`` may also be a tuple
            whose first element holds the generated token ids.

    Returns:
        dict mapping each ROUGE key to its mid F-measure scaled to 0-100,
        plus ``"gen_len"`` (mean number of non-padding tokens generated),
        all rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    # Some models return extra outputs alongside the generated ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is a padding sentinel that cannot be decoded. It is always present
    # in the labels and may also appear in the predictions when batches of
    # different generated lengths are concatenated, so replace it in both.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip()))
                      for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip()))
                      for label in decoded_labels]

    # Compute ROUGE scores
    result = metric.compute(predictions=decoded_preds, references=decoded_labels,
                            use_stemmer=True)

    # Extract ROUGE f1 scores
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Add mean generated length to metrics (padding excluded; the -100
    # sentinel has already been mapped to the pad token above)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id)
                      for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [None]:
def model_init():
    """Return a fresh seq2seq model loaded from ``model_checkpoint``."""
    return AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Wire together model factory, training arguments, data and metric function.
trainer = Seq2SeqTrainer(
    args=args,
    model_init=model_init,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    compute_metrics=compute_metrics,
)

# Training

In [None]:
# Run fine-tuning with the settings in `args`: evaluation, logging and
# checkpointing each happen once per epoch.
trainer.train()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,2.9981,2.534993,18.6389,7.4846,17.9621,18.1984,10.9555
2,2.5207,2.464668,18.0012,7.1387,17.366,17.5796,10.5256
3,2.3274,2.456346,18.796,7.5338,18.0531,18.1863,10.293


TrainOutput(global_step=1005, training_loss=2.615373450132152, metrics={'train_runtime': 238.348, 'train_samples_per_second': 134.815, 'train_steps_per_second': 4.217, 'total_flos': 799948629043200.0, 'train_loss': 2.615373450132152, 'epoch': 3.0})

# Testing

In [None]:
import tqdm

In [None]:
# Inputs longer than this many tokens are truncated at inference time.
max_input_length = 64

# Checkpoint written at global step 1005, i.e. the end of the 3-epoch run.
model_name = "bart_merged/checkpoint-1005"
model_dir = f"/content/drive/MyDrive/SloganGenerator/models/{model_name}"

# Reload the fine-tuned weights and the tokenizer from that checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

In [None]:
# Number of candidate slogans to sample for the description below.
NUM_SLOGANS = 20

input_texts = ["Olivetti is an Italian manufacturer of computers, tablets, smartphones, printers and other such business products as calculators and fax machines."]

inputs = tokenizer(input_texts, max_length=max_input_length, truncation=True, return_tensors="pt")

# Sample all candidates in one batched call instead of NUM_SLOGANS separate
# generate() calls: the input is encoded once and the samples are decoded
# together. With a single input text, `output` holds NUM_SLOGANS rows.
output = model.generate(
    **inputs,
    num_beams=1,
    do_sample=True,
    min_length=10,
    max_length=64,
    num_return_sequences=NUM_SLOGANS,
)
slogan = tokenizer.batch_decode(output, skip_special_tokens=True)

print()
for s in slogan:
  print(s)

100%|██████████| 20/20 [00:07<00:00,  2.52it/s]


Olivetti. Where everything counts.
Olivetti. We're the People
Olivetti. A more balanced future.
Olivetti. The power of machines.
Olivetti. Technology to help you succeed.
Olivetti. Innovative technology.
Olivetti. Making technology easier.
Olivetti. The Italian specialist in information technology.
Olivetti. Smart business. Smart future.
Olivetti. Tools for business success.
Olivetti. All things. Everything else.
Olivetti. Smart machines for all.
Olivetti. The power of technology.
Olivetti. Technology where a value is found.
Olivetti. More intelligent designs, more powerful customers.
Olivetti. Computers for the global communications market.
Olivetti. Real business solutions. Real value. Real experience.
Olivetti. Made to do.
Olivetti. The modern technology manufacturer.
Olivetti. The Next Generation of Computers.



