<a href="https://colab.research.google.com/github/lucarinelli/conditional_text_generation/blob/main/notebooks/COCO_Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import utilities

In [None]:
!rm -r conditional_text_generation
!git clone https://github.com/lucarinelli/conditional_text_generation.git

In [None]:
!pip install import-ipynb

%cd conditional_text_generation/notebooks

import import_ipynb

%load_ext autoreload
%autoreload 2

from CtrlUtilities import *

%cd ../..

# Configuration

In [None]:
experiment_parameters["run_name"] = "exp1"  # String, experiment name
experiment_parameters["use_control_codes"] = True  # True/False, enable conditional text generation or do basic text generation
# NOTE(review): True presumably forces load_or_setup_dataset to rebuild the dataset on every call,
# so each analysis cell below sees its own category settings rather than a cached copy — confirm in CtrlUtilities.
experiment_parameters["force_dataset_update"] = True  # True/False, rebuild the dataset even if it is already present on the file system
experiment_parameters["control_codes_type"] = "special_token"  # "special_token"/"separators"
experiment_parameters["use_supercategories"] = True  # True/False, add supercategories as control codes
experiment_parameters["use_categories"] = False  # True/False, add categories as control codes
experiment_parameters["use_control_codes_powerset"] = False  # True/False, use powerset of control codes for each caption to augment dataset
experiment_parameters["max_control_codes_per_caption"] = 3  # positive integer, maximum number of control codes to use with one caption during training
experiment_parameters["limited_run"] = True  # if set to True, the datasets will be reduced in size
experiment_parameters["max_train_set_len"] = 1500  # positive integer, maximum number of items for the training set used
experiment_parameters["max_val_set_len"] = 1000  # positive integer, maximum number of items for the validation set used
experiment_parameters["model"] = "gpt2"  # we tested "distilgpt2" and "gpt2" for now
# TODO: settings not set in this notebook, kept for reference:
#   save_model_path = "OUTPUT"
#   random_seed = 42  (integer, random seed used anywhere it could be useful to add some determinism)


In [None]:
import os

# Directory where load_or_setup_dataset reads/writes the dataset files.
DATA_PATH = "./data"
# exist_ok=True makes this cell safe to re-run (plain `mkdir` errors if the folder exists).
os.makedirs(DATA_PATH, exist_ok=True)
# Lower-case alias used by the analysis cells below.
data_path = DATA_PATH

# COCO Analysis

In [None]:
def computeAverageOnDataset(dataset, fieldExtractor):
  """Return (mean, min, max) of len(fieldExtractor(item)) over the items of dataset.

  Args:
    dataset: any iterable of items. It is iterated exactly once, so generators work too.
    fieldExtractor: callable mapping an item to something with a len(), e.g. a caption
      string (length in characters) or a list of words/categories (length in elements).

  Returns:
    A tuple (mean_length, min_length, max_length). The mean is a float; min and max are ints.

  Raises:
    ValueError: if dataset yields no items (the mean would be undefined).
  """
  lengths = [len(fieldExtractor(item)) for item in dataset]
  if not lengths:
    raise ValueError("computeAverageOnDataset: dataset is empty")
  return sum(lengths) / len(lengths), min(lengths), max(lengths)

In [None]:
def logControlCodeAnalysis(type):
  """Load the training split with the current control-code settings and print statistics.

  Calls load_or_setup_dataset for the "train" split. The calling cells change
  experiment_parameters["use_categories"/"use_supercategories"] just before calling this,
  so the loader presumably reads those global settings — confirm in CtrlUtilities.
  Prints the mean, min and max of len(e["categories"]) per caption (the number of
  control codes), plus the percentile rank at which the mean sits.

  Args:
    type: human-readable label of the configuration, used only in the printed messages.
      The name shadows the builtin but is kept so keyword callers keep working.
  """
  print("Creating dataset using {}".format(type))
  dataset_train, _, _ = load_or_setup_dataset(data_path=data_path, split="train")
  codes_per_caption = [len(e["categories"]) for e in dataset_train]
  # min_codes/max_codes instead of min/max so the builtins are not shadowed.
  average, min_codes, max_codes = computeAverageOnDataset(dataset_train, lambda e: e["categories"])
  # Percentile rank of the mean: the share of captions whose control-code count
  # does not exceed the average.
  above_average = sum(1 for n in codes_per_caption if n > average)
  percentile = 100 - above_average / len(dataset_train) * 100

  print("For {} the average number of control codes per caption is {}.\nIt's the {:.0f}th percentile. Minimum is {}. Maximum is {}".format(type, average, percentile, min_codes, max_codes))

In [None]:
# Analysis run 1: only supercategories act as control codes.
experiment_parameters["use_supercategories"], experiment_parameters["use_categories"] = True, False
logControlCodeAnalysis("supercategories only")

In [None]:
# Analysis run 2: only fine-grained categories act as control codes.
experiment_parameters["use_supercategories"], experiment_parameters["use_categories"] = False, True
logControlCodeAnalysis("categories only")

In [None]:
# Analysis run 3: categories and supercategories are both used as control codes.
experiment_parameters["use_supercategories"], experiment_parameters["use_categories"] = True, True
logControlCodeAnalysis("categories and supercategories")

In [None]:
# Caption-length statistics, computed on the supercategory-only build of the training split.
experiment_parameters["use_supercategories"], experiment_parameters["use_categories"] = True, False
dataset_train, _, categories = load_or_setup_dataset(data_path=data_path, split="train")

# Length in characters: len() of the raw caption string.
averageChar, minC, maxC = computeAverageOnDataset(dataset_train, lambda entry: entry["caption"])
print("Average length of captions is {} chars. Min {} and max {}".format(averageChar, minC, maxC))

# Length in words: len() of the whitespace-split caption.
averageWords, minW, maxW = computeAverageOnDataset(dataset_train, lambda entry: entry["caption"].split())
print("Average length of captions is {} words. Min {} and max {}".format(averageWords, minW, maxW))

In [None]:
# Size of the training split loaded in the previous cell (it is a dataset, not a database).
print("Training dataset has {} entries.".format(len(dataset_train)))