<a href="https://colab.research.google.com/github/lucarinelli/conditional_text_generation/blob/main/notebooks/Conditional_Text_Generation_Skeleton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Print the NVIDIA GPU status for this runtime (quick check that a GPU is attached).
!nvidia-smi

# Configuration

In [None]:
experiment_parameters = dict(
    run_name = "exp1",  # String, experiment name
    use_control_codes = True,  # True/False, enable conditional text generation or do basic text generation
    force_dataset_update = False, # True/False, enable database updates even if it is already present on the file system
    control_codes_type = "special_token",  # "special_token"/"separators"
    use_supercategories = True,  # True/False, add supercategories as control codes 
    use_categories = False, # True/False, add categories as control codes    
    use_control_codes_powerset = False,  # True/False, use powerset of control codes for each caption to augment dataset
    max_control_codes_per_caption = 3,  # positive integer, maximum number of control codes to use with one caption during training
    limited_run = True, # if set to True, the datasets will be reduced in size
    max_train_set_len = 1500,  # positive integer, maximum number of items for the training set used
    max_val_set_len = 1000,  # positive integer, maximum number of items for the validation set used
    model="gpt2",  # we tested "distilgpt2" and "gpt2" for now
    #save_model_path = "OUTPUT",
    #random_seed = 42,  # integer, random seed used anywhere it could be useful to add some determinism
)

%env experiment_parameters = experiment_parameters

# Import utilities

In [None]:
!rm -r conditional_text_generation
!git clone https://github.com/lucarinelli/conditional_text_generation.git

In [None]:
# import-ipynb allows importing a .ipynb notebook as if it were a Python module.
!pip install import-ipynb

# Enter the cloned repo's notebooks directory before importing — presumably so
# that CtrlUtilities.ipynb is found on the import path; confirm.
%cd conditional_text_generation/notebooks

import import_ipynb
# Star-import: every public name of CtrlUtilities lands in this notebook's namespace.
# NOTE(review): later cells use names that are not defined in this notebook
# (load_or_setup_dataset, write_json_chunks, encode, MyTrainer, compute_metrics,
# chunk_size) — presumably they all come from here; verify.
from CtrlUtilities import *

# Go back up two levels, i.e. to the directory the notebook started in.
%cd ../..

# Weights & Biases (wandb)

In [None]:
import wandb

# Authenticate this session with Weights & Biases.
wandb.login()

# Configure the W&B integration through environment variables (project, entity,
# model logging, watch mode and console verbosity). See the W&B / Transformers
# integration documentation for the exact meaning of each value.
%env WANDB_PROJECT=ctrl_dry_runs
%env WANDB_ENTITY=polito_aiml2021_textgen
%env WANDB_LOG_MODEL=true
%env WANDB_WATCH=all
%env WANDB_SILENT=true

# Dataset

In [None]:
!mkdir data
DATA_PATH="./data"

data_path=DATA_PATH

dataset_train, _, categories = load_or_setup_dataset(data_path=data_path, split="train")
dataset_val, references, _ = load_or_setup_dataset(data_path=data_path, split="val")

print("There are "+str(len(dataset_train))+" captions considered in total (train)")
print("There are "+str(len(dataset_val))+" captions considered in total (val)")
 
print("The following "+str(len(categories))+" categories are present in the dataset:")
print(categories)

if experiment_parameters["use_control_codes"] and experiment_parameters["control_codes_type"] == "special_token":
    control_codes = []
    for category in categories:
        control_codes += ["<CTRL:"+category.replace(" ","_")+">"]

    print("Processed control codes:")
    print(control_codes)

In [None]:
# Write both splits to disk in chunks under data_path. The next cell reloads
# files matching ./data/captions_<split>_*.json — presumably the files produced
# here; confirm against write_json_chunks.
# NOTE(review): `chunk_size` is not defined anywhere in this notebook.
# Presumably it is provided by the CtrlUtilities star-import, like
# write_json_chunks itself — confirm, otherwise this cell raises NameError.
write_json_chunks(dataset_train, "train", data_path, chunk_size)
write_json_chunks(dataset_val, "val", data_path, chunk_size)

In [None]:
from datasets import load_dataset, Dataset
import glob

# Reload the JSON chunks written to ./data as `datasets` objects.
# glob returns files in filesystem-dependent order; sorting fixes the row order
# so that the seeded shuffle below selects the same examples on every machine.
data_files = {
    'train': sorted(glob.glob('./data/captions_train_*.json')),
    'val': sorted(glob.glob('./data/captions_val_*.json')),
}
dataset_train, dataset_val = load_dataset('json', data_files=data_files, split=['train', 'val'], field="data")
print("Augmented dataset has: "+str(len(dataset_train))+" train elements and "+str(len(dataset_val))+" validation elements")

if experiment_parameters["limited_run"]: # shuffle and cut the datasets
  # Clamp the requested sizes to what is available: select() fails when asked
  # for indices beyond the end of the dataset.
  train_len = min(experiment_parameters["max_train_set_len"], len(dataset_train))
  val_len = min(experiment_parameters["max_val_set_len"], len(dataset_val))
  dataset_train = dataset_train.shuffle(seed=42).select(range(train_len))
  dataset_val = dataset_val.shuffle(seed=42).select(range(val_len))
  print("We take only a small part of that: "+str(len(dataset_train))+" train elements and "+str(len(dataset_val))+" validation elements")
else: # just shuffle them
  dataset_train = dataset_train.shuffle(seed=42)
  dataset_val = dataset_val.shuffle(seed=42)
  print("Train elements: "+str(len(dataset_train))+"\nValidation elements: "+str(len(dataset_val)))

# Tokenization

In [None]:
from transformers import GPT2TokenizerFast

# Load the pretrained tokenizer matching the configured model and reuse its
# end-of-sequence token as the padding token.
tokenizer = GPT2TokenizerFast.from_pretrained(experiment_parameters['model'])
tokenizer.pad_token = tokenizer.eos_token
print(f"Tokenizer before added special tokens {len(tokenizer)}")

# Control codes are registered as extra special tokens only in "special_token" mode.
register_control_codes = (
    experiment_parameters["use_control_codes"]
    and experiment_parameters["control_codes_type"] == "special_token"
)
if register_control_codes:
    special_tokens_dict = {'additional_special_tokens': control_codes}
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    print(f"added {num_added_toks} tokens to the pretrained tokenizer")

In [None]:
# Apply `encode` to both splits, batch by batch.
# NOTE(review): `encode` is not defined in this notebook — presumably it comes
# from the CtrlUtilities star-import. Judging by the columns selected before
# training, it is expected to produce input_ids, attention_mask and labels; confirm.
dataset_train_encoded = dataset_train.map(encode, batched=True)
dataset_val_encoded = dataset_val.map(encode, batched=True)

# Model

In [None]:
from transformers import GPT2LMHeadModel

# Load the pretrained weights named in the config; the EOS token id doubles as
# the pad token id, mirroring the tokenizer setup (pad_token = eos_token).
model = GPT2LMHeadModel.from_pretrained(experiment_parameters['model'], pad_token_id=tokenizer.eos_token_id)
# Resize the embedding matrix to the tokenizer's current vocabulary size, which
# includes any control-code tokens added to the tokenizer.
model.resize_token_embeddings(len(tokenizer))

# Training

In [None]:
from transformers import TrainingArguments

# Hyper-parameters and bookkeeping options for the Trainer.
training_args = TrainingArguments(
    output_dir="./data/results",  # output directory
    save_total_limit=3,  # keep at most this many checkpoints on disk
    num_train_epochs=3,  # total # of training epochs
    per_device_train_batch_size=64,  # batch size per device during training
    per_device_eval_batch_size=1,  # batch size for evaluation
    warmup_steps=500,  # number of warmup steps for learning rate scheduler
    weight_decay=0.01,
    logging_dir='./data/logs',  # directory for storing logs
    evaluation_strategy="epoch",
    # load_best_model_at_end needs a checkpoint at every evaluation point, so
    # the save strategy must match the evaluation strategy. The default is
    # step-based saving, which current transformers releases reject as a mismatch.
    save_strategy="epoch",
    report_to="wandb",
    load_best_model_at_end=True,
    # Keep dataset columns that are not model inputs (the validation set is
    # formatted with an extra image_id column).
    remove_unused_columns=False
)

In [None]:
import random
import torch
import numpy as np

# Single seed shared by every random number generator used in this notebook.
seed_val = 42

# Seed Python, NumPy, and PyTorch (CPU and all CUDA devices) in that order.
for seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(seed_val)

In [None]:
# Expose the model inputs as torch tensors. The validation split additionally
# keeps image_id — presumably so predictions can be matched to reference
# captions in compute_metrics; confirm.
tensor_columns = ['input_ids', 'attention_mask', 'labels']
dataset_train_encoded.set_format(type='torch', columns=tensor_columns)
dataset_val_encoded.set_format(type='torch', columns=tensor_columns + ['image_id'])

# MyTrainer and compute_metrics are not defined in this notebook
# (presumably provided by the CtrlUtilities star-import).
trainer = MyTrainer(
    model=model,  # the instantiated Transformers model to be trained
    args=training_args,  # training arguments, defined above
    train_dataset=dataset_train_encoded,  # training dataset
    eval_dataset=dataset_val_encoded,  # validation dataset
    compute_metrics=compute_metrics,
)

In [None]:
# Run the fine-tuning loop configured above.
trainer.train()

# Record the experiment parameters in the W&B run configuration.
# NOTE(review): this happens after train() — presumably because the W&B run is
# only created by the Trainer's reporting integration once training starts;
# confirm before moving it earlier.
config = wandb.config
config.update(experiment_parameters)

In [None]:
# Save the trained model to the same directory used as the TrainingArguments output_dir.
trainer.save_model("./data/results")
# Mark the W&B run as finished.
wandb.finish()