<a href="https://colab.research.google.com/github/e-caste/aiml-project/blob/master/AIML_project_with_more_batches.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **User Intention Prediction via Finetuned Transformer Models** 

Run the following cell to automatically create the needed directories in your Google Drive (if you're using Google Colab) or locally, depending on where your notebook is running at the moment.

In [None]:
is_in_colab = 'google.colab' in str(get_ipython())

if is_in_colab:
    from google.colab import drive
    drive.mount('/content/drive')
    ! [[ -d /content/drive/MyDrive ]] && echo "Your Google Drive is mounted" || echo "Please try re-running this cell"; \
    PROJECT_DIR=drive/MyDrive/aiml-project; \
    \
    [[ ! -d $PROJECT_DIR ]] && mkdir $PROJECT_DIR && echo "Created directory in your Google Drive ($PROJECT_DIR in the left panel)" || "$PROJECT_DIR already exists"; \
    [[ ! -d $PROJECT_DIR/RESULTS ]] && mkdir $PROJECT_DIR/RESULTS && mkdir $PROJECT_DIR/MODELS && echo "Created RESULTS and MODELS directories" || echo "RESULTS and MODELS directories already exist, skipping creation"; \
    [[ ! -d $PROJECT_DIR/data ]] && mkdir -p $PROJECT_DIR/data/simmc_fashion && echo -n "Created data directory" || echo -n "Data directory already exists, skipping creation"; \
    echo ", please upload the dataset into $PROJECT_DIR/data/simmc_fashion"; \
    echo; echo "Installing transformers Python library:"; echo; pip install transformers
else:
    import os
    if os.path.isdir("MODELS") and os.path.isdir("RESULTS"):
        print("RESULTS and MODELS directories already exist, skipping creation")
    else:
        if not os.path.isdir("RESULTS"):
            os.mkdir("RESULTS")
            print("Created RESULTS directory")
        if not os.path.isdir("MODELS"):
            os.mkdir("MODELS")
            print("Created MODELS directory")
    if not os.path.isdir("data"):
        os.mkdir("data")
        os.mkdir("data/simmc_fashion")
        print("Created data directory, please load the dataset into data/simmc_fashion")
    else:
        print("Data directory already exists, skipping creation. Check that your dataset is in data/simmc_fashion")


In [None]:
%%bash
# Print the GPU assigned to this runtime. No leading "!": that is IPython shell-escape syntax,
# and inside %%bash it would instead negate the exit status of the first command.
echo; echo "Printing assigned GPU stats: "; echo; nvidia-smi

# DATASET PREPROCESSING

To feed the correct input to our model, we first need to preprocess the dataset. If you need to do so:
1. please check that the Tokenizer imported into `tokenize_user_utterance.py` corresponds to your model (e.g. BertTokenizer for BertModel, DistilBertTokenizer for DistilBertModel, etc.)
2. check that `NUM_ATTRIBUTES` is set to the correct value in `enums/Attribute.py`: with `DISCARD_ATTRIBUTES_BELOW_COUNT` (`DABC` below) set to 0 it should be 33, with `DABC` set to 40 it should be 18, and with `DABC` set to 70 it should be 14
3. check that `CONT` in `preprocess_dataset.sh` loops over your desired value(s) for the model's `HISTORY` parameter
4. then _uncomment the following cell_

In [None]:
%%bash
# Uncomment the command below to preprocess the dataset.
# Do NOT prefix it with "!": inside %%bash "!" negates the exit status, so a successful run
# would report exit code 1 (likely why Jupyter seemed to think Bash was returning 1).
# ./preprocess_dataset.sh

# IMPORTS



In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
import collections
from enum import Enum
import math
import torch
import os
import transformers
from transformers import AdamW
from tqdm import tqdm
import torch.nn as nn
from transformers import BertModel, BertConfig
from sklearn import metrics
import warnings

warnings.filterwarnings('ignore')

# NAME FILES FOR ANALYSIS

In [None]:
# Project root: the mounted Google Drive folder on Colab, the current working directory otherwise.
PROJECT_DIR = "" if not is_in_colab else "/content/drive/MyDrive/aiml-project/"
DATA_DIR = PROJECT_DIR + "data/simmc_fashion/"

# Cleaned (not yet tokenized) splits, read by the dataset analysis cell.
TRAIN_FILE = DATA_DIR + "fashion_train_dials_clean_info.json"
TEST_FILE = DATA_DIR + "fashion_devtest_dials_clean_info.json"
DEV_FILE = DATA_DIR + "fashion_dev_dials_clean_info.json"

## ACTION ENUM

In [None]:
NUM_ACTIONS = 5

class Action(Enum):
    """Dialog actions to predict; each member's value is its class index."""
    SpecifyInfo = 0
    SearchDatabase = 1
    AddToCart = 2
    SearchMemory = 3
    Nothing = 4

    @classmethod
    def length(cls):
        """Number of action classes."""
        return NUM_ACTIONS

    @classmethod
    def from_str(cls, action):
        """Action label string -> class index; any unrecognised label maps to Nothing."""
        recognised = ("SearchDatabase", "SpecifyInfo", "AddToCart", "SearchMemory")
        member = cls[action] if action in recognised else cls.Nothing
        return member.value

    @classmethod
    def from_number(cls, value):
        """Class index -> action label string; Nothing and unknown indices map to "None"."""
        for member in (cls.SpecifyInfo, cls.SearchDatabase, cls.AddToCart, cls.SearchMemory):
            if value == member.value:
                return member.name
        return "None"

# DATASET ANALYSIS

In [None]:
files = [TRAIN_FILE, TEST_FILE, DEV_FILE]
for file in files:
    print(f"Analyse {file} file")
    with open(file, "r") as file_id:
        dials_clean = json.load(file_id)

    actions = []
    actions_counts = []
    actions_num = []
    attributes = []
    attributes_counts = []

    # count_turns[i] = number of dialogs with exactly i + 1 turns
    # (dialogs longer than MAX_TURNS are not shown in the turn-count plot)
    MAX_TURNS = 13
    dial_length_counter = collections.Counter(len(dial) for dial in dials_clean)
    count_turns = [dial_length_counter[i + 1] for i in range(MAX_TURNS)]

    for dial in dials_clean:
        for turn in dial:
            if turn["action"] not in actions:
                actions.append(turn["action"])
                actions_counts.append(1)
                actions_num.append(Action.from_str(turn["action"]))
            else:
                actions_counts[actions.index(turn["action"])] += 1

            for attr in turn["attributes"]:
                if attr not in attributes:
                    attributes.append(attr)
                    attributes_counts.append(1)
                else:
                    attributes_counts[attributes.index(attr)] += 1

    # per-action counts ordered by Action value
    actions_num, counts_by_action_num = (list(tup) for tup in zip(*sorted(zip(actions_num, actions_counts))))
    if file == TRAIN_FILE:
        # class-weight candidate for the optional weighted CrossEntropyLoss: taken from the TRAIN
        # split only (previously the last analysed split, DEV, overwrote it).
        # NOTE: these are raw counts in a list, not a float tensor of (inverse-frequency) weights.
        weight_action = counts_by_action_num

    # for every attribute, how many turns of each action mention it (a turn counts once per attribute)
    attrs_voc = {
        attr: {"name": attr,
               "SearchDatabase": 0,
               "SpecifyInfo": 0,
               "AddToCart": 0,
               "SearchMemory": 0,
               "None": 0
               }
        for attr in attributes
    }
    for dial in dials_clean:
        for turn in dial:
            for attr in set(turn["attributes"]):
                attrs_voc[attr][turn["action"]] += 1

    # sort actions by ASCENDING frequency: barh draws the first entry at the bottom,
    # so the most frequent action ends up at the top of the plot
    actions_counts, actions = (list(tup) for tup in zip(*sorted(zip(actions_counts, actions))))

    # sort attributes by ASCENDING frequency (ties broken by name), same reason as above
    attributes_counts, attributes = (list(tup) for tup in zip(*sorted(zip(attributes_counts, attributes))))

    # list of tuples [(frequency, attribute), ...] in descending frequency, TRAIN split only
    if file == TRAIN_FILE:
        attr_attr_counts = sorted(zip(attributes_counts, attributes), reverse=True)
        count = 70
        print(attr_attr_counts, len(attr_attr_counts))
        attr_filtered = list(filter(lambda elem: elem[0] < count, attr_attr_counts))
        print(f"Attributes with count < {count}: {attr_filtered} with length {len(attr_filtered)}")

    # barplot for dial length distribution
    y_act_pos = np.arange(len(count_turns))
    plt.figure(dpi=300, tight_layout=True)
    plt.barh(y_act_pos, count_turns)
    plt.yticks(y_act_pos, range(1,MAX_TURNS+1))
    plt.grid(alpha=0.2)
    plt.ylabel('Dialog turn count')
    plt.xlabel('Number of occurrences')
    plt.show()


    # barplot for actions
    y_act_pos = np.arange(len(actions))
    plt.figure(dpi=300, tight_layout=True)
    plt.barh(y_act_pos, actions_counts)
    plt.yticks(y_act_pos, actions)
    plt.grid(alpha=0.2)
    plt.ylabel('Action label')
    plt.xlabel('Number of occurrences')
    plt.show()


    # barplot for attributes
    y_attr_pos = np.arange(len(attributes))
    plt.figure(dpi=300, tight_layout=True, figsize=(10,6))
    plt.barh(y_attr_pos, attributes_counts)
    plt.yticks(y_attr_pos, attributes)
    plt.grid(alpha=0.2)
    plt.ylabel('Attribute label')
    plt.xlabel('Number of occurrences')
    plt.semilogx()
    plt.show()

    # barplot for actions distribution for each attribute.
    # Counts are looked up BY NAME so each stacked bar lines up with its y-tick label
    # (sorting attrs_voc separately broke count ties differently and misaligned rows).
    a0 = [attrs_voc[attr]["SearchDatabase"] for attr in attributes]
    a1 = [attrs_voc[attr]["SpecifyInfo"] for attr in attributes]
    a2 = [attrs_voc[attr]["AddToCart"] for attr in attributes]
    a3 = [attrs_voc[attr]["SearchMemory"] for attr in attributes]
    a4 = [attrs_voc[attr]["None"] for attr in attributes]

    # running left offsets for the stacked segments
    sum1 = np.add(a0, a1).tolist()
    sum2 = np.add(sum1, a2).tolist()
    sum3 = np.add(sum2, a3).tolist()

    plt.figure(dpi=300, tight_layout=True, figsize=(10,6))
    plt.barh(y_attr_pos, a0, color='blue', edgecolor='black', label='SearchDatabase')
    plt.barh(y_attr_pos, a1, left=a0, color='red', edgecolor='black', label='SpecifyInfo')
    plt.barh(y_attr_pos, a2, left=sum1, color='yellow', edgecolor='black', label='AddToCart')
    plt.barh(y_attr_pos, a3, left=sum2, color='green', edgecolor='black', label='SearchMemory')
    plt.barh(y_attr_pos, a4, left=sum3, color='orange', edgecolor='black', label='None')
    plt.yticks(y_attr_pos, attributes)
    plt.ylabel('Attribute label')
    plt.xlabel('Number of occurrences')
    plt.grid(alpha=0.2)
    plt.semilogx()
    plt.legend()
    plt.show()

## SET HYPERPARAMETERS

In [None]:
LEARNING_RATE = 9e-7
DISCARD_ATTRIBUTES_BELOW_COUNT = 70  # you should re-run the pre-process script if you modify this
FEATURES = 768  # 768 for bert-base-uncased and 1024 for bert-large-uncased
ACTION_DROPOUT = .3
ATTRIBUTES_DROPOUT = .1
MAX_EPOCHS = 13
TRAIN_BATCH_SIZE = 4
TEST_BATCH_SIZE = 4
HISTORY = 6  # number of preceding utterances to consider
SEED = 42  # seeds torch (all devices) and numpy: data shuffling, dropout and head init become repeatable

torch.manual_seed(SEED)
np.random.seed(SEED)

# tokenized splits produced by the preprocessing script for the chosen HISTORY
HISTORY_PAD = f"_history{HISTORY}" if HISTORY else ""
TRAIN_FILE = f"{DATA_DIR}fashion_train_dials_info_tokenized{HISTORY_PAD}.json"
TEST_FILE = f"{DATA_DIR}fashion_devtest_dials_info_tokenized{HISTORY_PAD}.json"
DEV_FILE = f"{DATA_DIR}fashion_dev_dials_info_tokenized{HISTORY_PAD}.json"
DEMO_FILE = f"{DATA_DIR}train_demo_small.json"

# CREATE DATALOADER AND ENUMS

## DATALOADER

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Turn-level dataset over a tokenized SIMMC dialog JSON file.

    The file holds a list of dialogs, each a list of turn dicts. Dialog boundaries are
    dropped: every turn becomes one sample, in file order.
    """

    def __init__(self, file_dials):
        """Load every turn of the file and pre-build one tensor per field.

        Args:
            file_dials: path of the JSON file with the tokenized dialogs
        """
        super().__init__()

        with open(file_dials, "r") as source:
            self.dials = json.load(source)

        # flatten dialogs -> turns, keeping the original order
        self.turns = [turn for dialog in self.dials for turn in dialog]

        def column(getter, dtype):
            # gather one field across all turns into a single tensor
            return torch.tensor([getter(turn) for turn in self.turns], dtype=dtype)

        self.phrases = {
            "dialog_id": column(lambda t: t["dialog_id"], torch.long),
            "turn_idx": column(lambda t: t["turn_idx"], torch.long),
            "input_ids": column(lambda t: t["user_utterance"]["input_ids"], torch.long),
            "attention_mask": column(lambda t: t["user_utterance"]["attention_mask"], torch.long),
        }
        self.labels = {
            "action": column(lambda t: t["action"], torch.long),
            "attributes": column(lambda t: t["attributes"], torch.float),
        }

    def __len__(self):
        """Number of turns (samples) across all dialogs."""
        return len(self.turns)

    def __getitem__(self, index_turn):
        """Return one turn as (phrase, label).

        Args:
            index_turn: index of the turn in the flattened (dialog-agnostic) order
        Returns:
            phrase: dict with turn_idx, dialog_id, input_ids and attention_mask tensors
            label: dict with the action and attributes target tensors
        """
        phrase_keys = ("turn_idx", "dialog_id", "input_ids", "attention_mask")
        label_keys = ("action", "attributes")
        phrase = {key: self.phrases[key][index_turn] for key in phrase_keys}
        label = {key: self.labels[key][index_turn] for key in label_keys}
        return phrase, label

## ATTRIBUTE ENUM

In [None]:
# Attributes seen strictly more than DISCARD_ATTRIBUTES_BELOW_COUNT times in the TRAIN split,
# plus one extra "other" bucket (index 0) for everything else.
NUM_ATTRIBUTES = sum(1 for count, _ in attr_attr_counts if count > DISCARD_ATTRIBUTES_BELOW_COUNT) + 1
print(f"Using the {NUM_ATTRIBUTES-1} most frequent attributes.")

class Attribute(Enum):
    """Attribute labels; each member's value is its class index (``other`` = 0).

    NOTE(review): from_str keeps only members whose value is < NUM_ATTRIBUTES, which assumes the
    member values are ordered by TRAIN frequency — confirm against the preprocessing script.
    """
    other = 0
    price = 1
    availableSizes = 2
    customerRating = 3
    brand = 4
    info = 5
    color = 6
    embellishment = 7
    pattern = 8
    hemLength = 9
    skirtStyle = 10
    dressStyle = 11
    material = 12
    clothingStyle = 13
    necklineStyle = 14
    size = 15
    jacketStyle = 16
    sweaterStyle = 17
    hemStyle = 18
    sleeveStyle = 19
    waistStyle = 20
    sleeveLength = 21
    clothingCategory = 22
    skirtLength = 23
    soldBy = 24
    madeIn = 25
    ageRange = 26
    waterResistance = 27
    warmthRating = 28
    sequential = 29
    hasPart = 30
    forOccasion = 31
    forGender = 32
    amountInStock = 33

    @classmethod
    def length(cls):
        """Number of attribute classes used by the model (kept attributes + "other")."""
        return NUM_ATTRIBUTES

    @classmethod
    def from_str(cls, attribute):
        """Attribute name -> class index.

        Names whose member value is >= NUM_ATTRIBUTES, and names that are not members at all,
        collapse to ``other`` (0). Unknown names used to raise UnboundLocalError.
        """
        value = cls.__members__.get(attribute, cls.other).value
        if value < NUM_ATTRIBUTES:
            return value
        return cls.other.value

    @classmethod
    def from_number(cls, value):
        """Class index -> attribute name; "other" when no member has that value."""
        # compare with == (not a hash lookup) so values like 0-d tensors still match
        for member in cls:
            if value == member.value:
                return member.name
        return "other"

# CREATE MODEL

## MODEL

In [None]:
# Use the first GPU when available, otherwise the CPU; cuDNN autotuning is enabled only with CUDA.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = use_cuda

class Assistant(nn.Module):
    """BERT encoder with two heads: an action head (one predicted action per sample, via argmax)
    and an attributes head (one logit per attribute, binarised at prediction time).

    The pretrained checkpoint is chosen from the global FEATURES: 768 -> 'bert-base-uncased',
    any other value -> 'bert-large-uncased'. Dropout rates come from the globals
    ACTION_DROPOUT and ATTRIBUTES_DROPOUT.
    """
    def __init__(self, num_actions, num_attributes):
        """
        :param num_actions: number of possible actions to be predicted
        :param num_attributes: number of possible attributes to be predicted
        """
        super(Assistant, self).__init__()
        self.model = 'bert-base-uncased' if FEATURES == 768 else 'bert-large-uncased'
        self.bert = BertModel.from_pretrained(self.model)
        
        # layer for action: Linear(FEATURES->FEATURES) -> ReLU -> Dropout -> Linear(FEATURES->num_actions)
        self.action_pre_classifier = nn.Linear(FEATURES, FEATURES)
        self.action_activation = nn.ReLU()
        self.action_dropout = nn.Dropout(ACTION_DROPOUT)  # avoid overfitting
        
        self.action_classifier = nn.Linear(FEATURES, num_actions)

        # layer for attributes: Dropout -> Linear(FEATURES->num_attributes), no hidden layer
        self.attributes_dropout = nn.Dropout(ATTRIBUTES_DROPOUT)  # avoid overfitting
        self.attributes_classifier = nn.Linear(FEATURES, num_attributes)
        # only used in forward() when predict=True, to turn attribute logits into probabilities
        self.sigmoid = nn.Sigmoid()
           

    def forward(self, batch: dict, label, predict):
        """Encode the batch with BERT and run both heads.

        :param batch: dict with 'input_ids' and 'attention_mask' tensors, passed straight to BERT
        :param label: dict of targets; only label['attributes'] is read, and only when predict is True
        :param predict: when True also compute discrete predictions, otherwise only logits
        :return: (action_logits, action_pred, attributes_logits, attributes_pred); action_pred is the
                 argmax action index per sample and attributes_pred a 0/1 vector per sample, both
                 replaced by empty lists when predict is False
        """
        output = self.bert(input_ids=batch['input_ids'],attention_mask=batch['attention_mask'])
        hidden_state= output[0]
        # hidden state of the FIRST token of every sequence (presumably [CLS]); despite the
        # variable name, this is the raw first-token state, not a separately pooled output
        pooler_output = hidden_state[:, 0]
        #predict action
        pooler_action = self.action_pre_classifier(pooler_output)
        pooler_action = self.action_activation(pooler_action)
        pooler_action = self.action_dropout(pooler_action)
        action_logits = self.action_classifier(pooler_action)
        if predict:
            # softmax does not change the argmax; it only turns logits into probabilities
            action_pred_tensor = torch.nn.functional.softmax(action_logits, dim=1)
            _, action_pred = torch.max(action_pred_tensor.data, dim=1) 
        else:
            action_pred = []

        # Predict attributes
        pooler_attributes = self.attributes_dropout(pooler_output)
        attributes_logits = self.attributes_classifier(pooler_attributes)
        # predict best predicted attrs vector
        if predict:
            attributes_prob = self.sigmoid(attributes_logits) # used to calculate the probability of every attributes in prediction, not as an activation function
            # NOTE(review): the threshold is tuned per sample against label['attributes'] (ground truth),
            # so attributes_pred leaks the labels and any metric computed from it is optimistic.
            attributes_pred = self.search_best_predict_attributes(label['attributes'], attributes_prob)
        else:
            attributes_pred = []
        return action_logits, action_pred, attributes_logits, attributes_pred

    def search_best_predict_attributes(self, true_attributes, pred_attributes_prob):
        """Binarise each sample's attribute probabilities with the threshold that maximises the
        weighted F1 score against that sample's TRUE attributes.

        :param true_attributes: ground-truth attribute vectors, iterated row by row (one per sample)
        :param pred_attributes_prob: per-attribute probabilities, indexed by the same row position
        :return: torch.stack of one 0/1 vector per sample; rows whose true vector is all zeros are
                 returned as the true vector itself (no prediction is made for them)

        Thresholds are swept from about max_prob + 0.01 down to min_prob - 0.01 in steps of 0.001;
        thresholds giving an all-zero prediction are skipped, and the sweep stops at F1 == 1.
        NOTE(review): the ground truth drives both the threshold and the all-zero shortcut, so this
        is an oracle rather than a real prediction. All-zero rows also keep the label dtype while the
        other rows are integer tensors built from np.where, so torch.stack mixes dtypes — confirm.
        """
        # per-sample best predictions, stacked at the end
        opt_pred_attrs = [] 
        i=0
        for true_attr in true_attributes:
          pred_attrs = []
          scores = []
          tr_attr = np.array(true_attr.cpu())
          is_all_zero = np.all(tr_attr == 0.0)
          if is_all_zero: #not predict attributes if label attributes is empty
            best_attrs_choose = true_attr
          else: #select the best attr prediction using a dynamic threshold
            pred_prob = pred_attributes_prob[i]
            step_increment = 0.001
            min_threshold = torch.min(pred_prob).item() 
            max_threshold = torch.max(pred_prob).item()
            thresholds = np.arange(min_threshold-0.01, max_threshold+0.01, step_increment)
            thresholds = thresholds[::-1] #reverse array: try the highest threshold first
            # classes for each threshold
            for thresh in thresholds:
                # convert the probability vector pred_prob into a multi-label 0/1 output
                pred_attr = np.where(pred_prob.cpu() > thresh, 1, 0)
                is_all_zero_pred = np.all(pred_attr == 0) # an empty prediction is never a candidate
                if not is_all_zero_pred:
                  pred_attr = torch.tensor(pred_attr).to(device=device)
                  pred_attrs.append(pred_attr) 
                  # score this candidate against the true target (weighted F1)
                  score = metrics.f1_score(true_attr.cpu(), pred_attr.cpu(), average='weighted')
                  scores.append(score)
                  if score == 1: #best score found!
                     break
            # select the pred_attrs with max score (first one on ties)
            best_attrs_choose = pred_attrs[scores.index(max(scores))]
          opt_pred_attrs.append(best_attrs_choose)
          i = i + 1
        opt_pred_attrs = torch.stack(opt_pred_attrs)
        return opt_pred_attrs


# TRAIN/TEST MODEL

## TRAINING STEP

In [None]:
print(f"Using device: {device}")

min_epochs = 0
models_dir = f"{PROJECT_DIR}MODELS"
ckpt_prefix, ckpt_suffix = "model_epoch", ".pth"

# checkpoints is the sorted list of epoch numbers N for which models_dir/model_epochN.pth exists;
# any other .pth file in the folder is ignored instead of breaking the parsing
checkpoints = sorted(
    int(fname[len(ckpt_prefix):-len(ckpt_suffix)])
    for fname in os.listdir(models_dir)
    if fname.startswith(ckpt_prefix)
    and fname.endswith(ckpt_suffix)
    and fname[len(ckpt_prefix):-len(ckpt_suffix)].isdigit()
)
if checkpoints:
    ans = input(f"The following saved checkpoints already exist:\n{checkpoints}\n"
                "Enter the number of the checkpoint to resume from, or enter to train a model from scratch: [1,2,3... | Enter]\n").strip()
    # anything that is not a positive number (including "0") means: train from scratch
    if not ans.isdigit() or int(ans) == 0:
        print("Instantiating default pretrained model to train from scratch...")
        model = Assistant(Action.length(), Attribute.length())
    else:
        min_epochs = int(ans)
        if min_epochs > MAX_EPOCHS:
            print(f"WARNING: model checkpoint to load is {ans}, but MAX_EPOCHS is set to {MAX_EPOCHS}.\n"
                  "Training is already finished, or some parameters are wrongly set up. Stopping...")
            # raise instead of exit(): inside a Jupyter kernel exit() is not guaranteed to stop the cell
            raise RuntimeError(f"checkpoint {min_epochs} is beyond MAX_EPOCHS={MAX_EPOCHS}")
        if min_epochs not in checkpoints:
            raise FileNotFoundError(f"{models_dir}/{ckpt_prefix}{min_epochs}{ckpt_suffix} does not exist")
        print(f"Loading checkpoint {min_epochs} into model...")
        # Checkpoints are whole pickled nn.Modules (see torch.save(model) below), so weights_only=False
        # is required on PyTorch versions that default to weights_only=True. Only load your own files.
        model = torch.load(f"{models_dir}/{ckpt_prefix}{min_epochs}{ckpt_suffix}",
                           map_location=device, weights_only=False)
        with open(f"{models_dir}/losses_epoch{min_epochs}.json", "r") as f:
            loss_act_array_tr, loss_attrs_array_tr = json.load(f)
else:
    print(f"No previous checkpoints found in {PROJECT_DIR}MODELS.\n"
          "Instantiating default pretrained model to train from scratch...")
    model = Assistant(Action.length(), Attribute.length())

model.to(device)
model.train()
# torch.optim.AdamW configured with the defaults of the deprecated transformers.AdamW
# (eps=1e-6, no weight decay, bias correction always on)
optim = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)
loss_function_action = torch.nn.CrossEntropyLoss()
# for a weighted CrossEntropy loss, replace the line above with one taking a float tensor of
# per-class weights on `device`, e.g.:
# loss_function_action = torch.nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float, device=device))
loss_function_attributes = torch.nn.BCEWithLogitsLoss()
dataset = Dataset(TRAIN_FILE)
max_epochs = MAX_EPOCHS
params = {
    'shuffle': True,
    'batch_size': TRAIN_BATCH_SIZE,
}
dataloader = torch.utils.data.DataLoader(dataset, **params)
print(f"Our dataset has {len(dataset)} entries.")
print(f"Training for {max_epochs - min_epochs} epochs, each with {len(dataloader)} steps.")
if min_epochs == 0:
    loss_act_array_tr = []
    loss_attrs_array_tr = []

for epoch in range(min_epochs, max_epochs):
    loss_act_tot = 0
    loss_attrs_tot = 0
    counter_loss = 0  # number of batches, to average the losses over the epoch
    for batch, label in tqdm(dataloader, position=0, leave=True):
        for key in batch:
            batch[key] = batch[key].to(device, dtype=torch.long)
        label['action'] = label['action'].to(device, dtype=torch.long)
        label['attributes'] = label['attributes'].to(device, dtype=torch.float)
        action_logits, action_pred, attributes_logits, attributes_pred = model(batch, label, predict=False)
        # action, attributes metrics loss
        loss_action = loss_function_action(action_logits.view(-1, Action.length()), label['action'].view(-1))
        loss_act_tot += loss_action.item()
        loss_attributes = loss_function_attributes(attributes_logits.view(-1, Attribute.length()), label['attributes'].view(-1, Attribute.length()))
        loss_attrs_tot += loss_attributes.item()
        # both heads are trained jointly with equal weight
        loss = (loss_action + loss_attributes)/2
        counter_loss += 1
        optim.zero_grad()
        loss.backward()
        optim.step()

    # save results: per-epoch mean losses, then the loss history and the whole model
    loss_act_array_tr.append(loss_act_tot/counter_loss)
    loss_attrs_array_tr.append(loss_attrs_tot/counter_loss)
    print(f"Epoch {epoch + 1}:\nAction loss: {loss_act_tot/counter_loss} | Attributes loss: {loss_attrs_tot/counter_loss}")
    with open(f"{PROJECT_DIR}MODELS/losses_epoch{epoch + 1}.json", "w") as f:
        json.dump([loss_act_array_tr, loss_attrs_array_tr], f)
    torch.save(model, f"{PROJECT_DIR}MODELS/model_epoch{epoch + 1}.pth")

# losses of the last trained (or loaded) epoch
loss_action = loss_act_array_tr[-1]
print(f'Evaluate TRAIN Loss action: {loss_action}')
loss_attributes = loss_attrs_array_tr[-1]
print(f'Evaluate TRAIN Loss attributes: {loss_attributes}')


## TESTING STEP


In [None]:
SPLITS_TEST = [TEST_FILE, DEV_FILE]
# one list of per-epoch mean losses per split, in SPLITS_TEST order
losses_act_array = []
losses_attrs_array = []
accs_act_array = []  # NOTE(review): not filled by this cell
accs_attrs_array = []

for file in SPLITS_TEST:
    split = "devtest" if "devtest" in file else "dev"

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print(f"Using device: {device}")
    torch.backends.cudnn.benchmark = use_cuda

    loss_function_action = torch.nn.CrossEntropyLoss()
    # for a weighted CrossEntropy loss, pass a float tensor of per-class weights on `device`:
    # loss_function_action = torch.nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float, device=device))
    loss_function_attributes = torch.nn.BCEWithLogitsLoss()
    dataset = Dataset(file)
    max_epochs = MAX_EPOCHS
    params = {
        # no shuffling for evaluation: the mean of per-batch means depends on which samples end up
        # in the smaller last batch, so a fixed order keeps the reported losses reproducible
        'shuffle': False,
        'batch_size': TEST_BATCH_SIZE,
    }
    dataloader = torch.utils.data.DataLoader(dataset, **params)
    print(f"Evaluate Our dataset with {len(dataset)} entries.")
    print(f"Evaluate for {max_epochs} epochs, each with {len(dataloader)} steps.")
    loss_act_array = []
    loss_attrs_array = []
    acc_act_array = []
    acc_attrs_array = []
    for epoch in range(max_epochs):
        total_loss_action = 0
        total_loss_attrs = 0
        counter_loss = 0  # number of batches, to average the losses
        # whole pickled nn.Module saved by the training cell; weights_only=False is required on
        # PyTorch versions defaulting to weights_only=True. Only load checkpoints you created.
        model = torch.load(f"{PROJECT_DIR}MODELS/model_epoch{epoch+1}.pth", map_location=device, weights_only=False)
        model.to(device)
        model.eval()
        for batch, label in tqdm(dataloader, position=0, leave=True):
            for key in batch:
                batch[key] = batch[key].to(device, dtype=torch.long)

            label['action'] = label['action'].to(device, dtype=torch.long)
            label['attributes'] = label['attributes'].to(device, dtype=torch.float)
            with torch.no_grad():
                action_logits, action_pred, attributes_logits, attributes_pred = model(batch, label, predict=False)

            # action, attributes metrics loss
            loss_action = loss_function_action(action_logits.view(-1, Action.length()), label['action'].view(-1))
            total_loss_action += loss_action.item()
            loss_attrs = loss_function_attributes(attributes_logits.view(-1, Attribute.length()), label['attributes'].view(-1,Attribute.length()))
            total_loss_attrs += loss_attrs.item()
            counter_loss += 1
        # save results
        loss_action = total_loss_action/counter_loss
        loss_attributes = total_loss_attrs/counter_loss
        loss_act_array.append(loss_action)
        loss_attrs_array.append(loss_attributes)
        print(f"Epoch {epoch + 1}:\nAction loss: {loss_action} | Attributes loss: {loss_attributes}")
    # append vector in list for split
    losses_act_array.append(loss_act_array)
    losses_attrs_array.append(loss_attrs_array)
    # losses of the last evaluated epoch
    loss_action = loss_act_array[-1]
    print(f'Evaluate {split} Loss action: {loss_action}')
    loss_attributes = loss_attrs_array[-1]
    print(f'Evaluate {split} Loss attributes: {loss_attributes}')

# RESULT PLOTS

Use these plots to check for the overfitting point.

In [None]:
x = np.arange(1, MAX_EPOCHS + 1, 1)
x_minor = np.arange(1, MAX_EPOCHS + 1, 0.5)

def make_plot(y_train, y_test, y_dev, title: str, ylabel: str, fig_path: str):
    """Plot per-epoch TRAIN/TEST/DEV curves on one figure, save it and show it.

    Each series is plotted against epochs 1..len(series). Major x ticks mark
    every epoch of the longest series; minor grid lines are drawn every half
    epoch.

    Args:
        y_train: one value per epoch for the training set.
        y_test: one value per epoch for the test set.
        y_dev: one value per epoch for the dev set.
        title: figure title.
        ylabel: label of the y axis.
        fig_path: path the figure is saved to (300 dpi, tight bounding box).
    """
    # Derive the epoch axis from the data instead of the global MAX_EPOCHS, so
    # a run with a different number of recorded epochs still plots instead of
    # failing with an x/y length mismatch.
    num_epochs = max(len(y_train), len(y_test), len(y_dev))
    major_ticks = np.arange(1, num_epochs + 1)
    minor_ticks = np.arange(1, num_epochs + 1, 0.5)

    fig, ax = plt.subplots(figsize=(7.5, 7.5))
    ax.plot(np.arange(1, len(y_train) + 1), y_train, label='TRAIN set')
    ax.plot(np.arange(1, len(y_test) + 1), y_test, label='TEST set', linestyle="dashed")
    ax.plot(np.arange(1, len(y_dev) + 1), y_dev, label='DEV set', linestyle="dotted")
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.set_xticks(major_ticks)
    ax.set_xticks(minor_ticks, minor=True)
    ax.grid(which='minor', alpha=0.2)
    ax.grid(which='major', alpha=0.5)
    fig.savefig(fig_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Close this specific figure (not just the "current" one) to free memory.
    plt.close(fig)

# Compact "name=value" summary of the run's hyperparameters, embedded in the
# title of every loss plot so saved figures are self-describing.
hyper_params = " - ".join(
    f"{name}={value}"
    for name, value in (
        ("lr", LEARNING_RATE),
        ("epochs", MAX_EPOCHS),
        ("tr_btch_sz", TRAIN_BATCH_SIZE),
        ("tst_btch_sz", TEST_BATCH_SIZE),
        ("hdn_state", FEATURES),
        ("act_drop", ACTION_DROPOUT),
        ("attr_drop", ATTRIBUTES_DROPOUT),
        ("disc_attrs", DISCARD_ATTRIBUTES_BELOW_COUNT),
        ("history", HISTORY),
    )
)

# Action loss (CrossEntropyLoss): TRAIN curve plus the two evaluated splits.
print("ACTION LOSS PLOTS")
make_plot(
    y_train=loss_act_array_tr,
    y_test=losses_act_array[0],
    y_dev=losses_act_array[1],
    title=f"CrossEntropyLoss ACTION {hyper_params}",
    ylabel="CrossEntropyLoss",
    fig_path=f"{PROJECT_DIR}RESULTS/losses_action{HISTORY}.png",
)

# Attributes loss (BCEWithLogitsLoss): same layout as the action plot.
print("ATTRIBUTES LOSS PLOTS")
make_plot(
    y_train=loss_attrs_array_tr,
    y_test=losses_attrs_array[0],
    y_dev=losses_attrs_array[1],
    title=f"BCEWithLogitsLoss ATTRIBUTES {hyper_params}",
    ylabel="BCEWithLogitsLoss",
    fig_path=f"{PROJECT_DIR}RESULTS/losses_attributes{HISTORY}.png",
)


# GENERATE OUTPUT 

In the following cell we generate the actual output of the model from its best checkpoint, chosen by looking at the graphs above (by default the last one, i.e. epoch `MAX_EPOCHS`; pass a different epoch as the first argument of `save_output_model` to use another checkpoint).  
If the TRAIN and DEV lines cross in the action or attributes loss graph above, pick the highest epoch after the crossing at which the DEV loss is still approximately the same as at the crossing point: this leaves time for the other loss (typically the attributes one) to keep decreasing.

In [None]:
def save_output_model(num_epoch, dataloader, split):
    """Run a saved checkpoint over `dataloader` and dump its predictions as JSON.

    Loads ``{PROJECT_DIR}MODELS/model_epoch{num_epoch}.pth``, runs it without
    gradients (``predict=True``), groups the per-turn predictions by dialog and
    writes them to ``{PROJECT_DIR}RESULTS/output_{split}{HISTORY}.json`` as::

        [{"dialog_id": ...,
          "predictions": [{"action": ..., "action_log_prob": {...},
                           "attributes": {"attributes": [...]},
                           "turn_id": ...}, ...]},
         ...]

    Args:
        num_epoch: epoch number of the checkpoint to load.
        dataloader: yields ``(batch, label)`` pairs. ``batch`` must contain
            ``"dialog_id"`` and ``"turn_idx"`` tensors (one value per sample);
            ``label`` must contain ``"action"`` and ``"attributes"``.
        split: split name used in the output file name.
    """
    print("Creating output of model in RESULTS folder...")
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    # map_location lets a checkpoint written on a GPU machine be loaded on a
    # CPU-only one.
    # NOTE(review): torch.load unpickles the whole model object, so only load
    # trusted files. Recent PyTorch releases default to weights_only=True,
    # which refuses whole-model pickles; pass weights_only=False there if
    # loading fails.
    model = torch.load(f"{PROJECT_DIR}MODELS/model_epoch{num_epoch}.pth", map_location=device)
    torch.backends.cudnn.benchmark = True
    model.to(device)
    model.eval()

    num_actions = Action.length()
    # dialog_id (int) -> list of per-turn predictions. A dict gives O(1)
    # lookups instead of scanning a list of ids for every turn, and Python
    # dicts keep insertion order, so dialogs are written in first-seen order.
    predictions_by_dialog = {}
    for batch, label in tqdm(dataloader, position=0, leave=True):
        for key in batch:
            batch[key] = batch[key].to(device, dtype=torch.long)
        label['action'] = label['action'].to(device, dtype=torch.long)
        label['attributes'] = label['attributes'].to(device, dtype=torch.float)
        with torch.no_grad():
            action_logits, action_pred, attributes_logits, attributes_pred = model(batch, label, predict=True)

        # One device->host transfer per batch instead of one .item() per value.
        dialog_ids = batch["dialog_id"].reshape(-1).tolist()
        turn_ids = batch["turn_idx"].reshape(-1).tolist()
        for i, dialog_id in enumerate(dialog_ids):
            row_logits = action_logits[i].tolist()
            attributes = [
                Attribute.from_number(index)
                for index in range(len(attributes_pred[i]))
                if attributes_pred[i][index] == 1
            ]
            predictions_by_dialog.setdefault(dialog_id, []).append({
                "action": Action.from_number(action_pred[i]),
                # NOTE(review): these are the model's action_logits in
                # predict=True mode; evaluate_action_prediction treats them as
                # log-probabilities -- confirm the model applies log_softmax.
                "action_log_prob": {
                    Action.from_number(k): row_logits[k] for k in range(num_actions)
                },
                "attributes": {
                    "attributes": attributes
                },
                "turn_id": turn_ids[i],
            })

    output = [
        {"dialog_id": dialog_id, "predictions": predictions}
        for dialog_id, predictions in predictions_by_dialog.items()
    ]
    print("Saving output of model in RESULTS folder")
    with open(f"{PROJECT_DIR}RESULTS/output_{split}{HISTORY}.json", "w") as f:
        json.dump(output, f)

SPLITS_FILE = [TEST_FILE]
# uncomment the following line if you want to generate the output for all the dataset splits
# keep in mind it takes a very long time, in the hours order of magnitude
# SPLITS_FILE = [TRAIN_FILE, TEST_FILE, DEV_FILE]

for file in SPLITS_FILE:
    # Map the data file to its split name. "devtest" is checked before the
    # "dev" fallback because a devtest file name also contains "dev".
    if "train" in file:
        split = "train"
    elif "devtest" in file:
        split = "devtest"
    else:
        split = "dev"

    dataset = Dataset(file)
    params = {
        # Inference only: shuffling brings no benefit here and made the order
        # of dialogs (and of turns inside each dialog) in the output JSON
        # change from run to run.
        'shuffle': False,
        'batch_size': TRAIN_BATCH_SIZE if split == "train" else TEST_BATCH_SIZE
    }
    dataloader = torch.utils.data.DataLoader(dataset, **params)
    print(f"Generate output of model for {split} dataset with {len(dataset)} entries.")
    save_output_model(MAX_EPOCHS, dataloader, split)

# EVALUATE ACCURACY AND PERPLEXITY OF OUTPUTS

In [None]:
"""Script evaluates action prediction along with attributes.
Author(s): Satwik Kottur
"""
# NOTE(review): the string above is a bare expression statement inside this
# cell, not a module docstring; it presumably credits the SIMMC baseline
# evaluation script this cell is adapted from -- confirm.


# Keys of a turn's action supervision that evaluate_action_prediction never
# scores: they are skipped both when the predicted action is wrong and when it
# is right.
# NOTE(review): several entries (e.g. decorStyle, intendedRoom) look like
# furniture-domain fields; "focus" is marked as fashion -- confirm.
IGNORE_ATTRIBUTES = [
    "minPrice",
    "maxPrice",
    "furniture_id",
    "material",
    "decorStyle",
    "intendedRoom",
    "raw_matches",
    "focus",  # fashion
]

def _multilabel_f1(gt_vals, model_vals):
    """F1 score between a ground-truth label list and a predicted label list.

    Precision is 0.0 when nothing is predicted; the 1e-6 in the denominator
    avoids 0/0 when recall and precision are both zero.
    """
    recall = np.mean([(ii in model_vals) for ii in gt_vals])
    if len(model_vals):
        precision = np.mean([(ii in gt_vals) for ii in model_vals])
    else:
        precision = 0.0
    return (2 * recall * precision) / (recall + precision + 1e-6)


def evaluate_action_prediction(
    gt_actions,
    model_actions,
    single_round_eval=False,
    compute_std_err=False,
    record_instance_results=None,
):
    """Evaluates action prediction using the raw data and model predictions.

    Args:
        gt_actions: Ground truth actions + action attributes: a list of
            ``{"dialog_id": ..., "actions": [{"action": ...,
            "action_supervision": ...}, ...]}`` entries, one action per turn.
        model_actions: Actions + attributes predicted by the model: a list of
            ``{"dialog_id": ..., "predictions": [{"turn_id": ..., "action": ...,
            "action_log_prob": {action: value}, "attributes": {...}}, ...]}``.
        single_round_eval: Evaluate only the last turn of each dialog.
        compute_std_err: Also compute standard errors of the metrics.
        record_instance_results: Optional path. When set, every evaluated
            prediction in ``model_actions`` is annotated in place with
            ``action_result`` and ``gt_action`` and the whole list is dumped
            there as JSON.

    Returns:
        ``metrics`` dict with:
          * ``action_accuracy``: fraction of turns whose action matches;
          * ``action_perplexity``: exp of minus the mean ``action_log_prob``
            value of the ground-truth action;
          * ``attribute_accuracy``: mean of per-key scores (F1 for list-valued
            attributes, exact match otherwise, False when the action is wrong);
          * ``confusion_matrix``: counts with rows = predicted action and
            columns = ground-truth action;
          * ``confusion_matrix_labels``: the sorted action names indexing both
            axes of ``confusion_matrix``.
        With ``compute_std_err`` a ``(metrics, metrics_std_err)`` tuple.

    Raises:
        KeyError: a predicted dialog is missing from ``gt_actions``, or a
            scored supervision key is missing from the prediction's attributes.
        IndexError: a ``turn_id`` is out of range for its dialog.
        AssertionError: a list-valued ground truth gets a non-list prediction.
    """
    gt_actions_pool = {ii["dialog_id"]: ii for ii in gt_actions}
    matches = {"action": [], "attributes": [], "perplexity": []}
    confusion_dict = collections.defaultdict(list)
    for model_datum in model_actions:
        gt_rounds = gt_actions_pool[model_datum["dialog_id"]]["actions"]
        num_gt_rounds = len(gt_rounds)
        for round_datum in model_datum["predictions"]:
            round_id = round_datum["turn_id"]
            # Skip if single_round_eval and this is not the last round.
            if single_round_eval and round_id != num_gt_rounds - 1:
                continue

            gt_datum = gt_rounds[round_id]
            gt_action = gt_datum["action"]
            action_match = gt_action == round_datum["action"]
            # Record matches and confusion.
            matches["action"].append(action_match)
            matches["perplexity"].append(round_datum["action_log_prob"][gt_action])
            confusion_dict[gt_action].append(round_datum["action"])

            # Add the result to datum and save it back.
            if record_instance_results:
                round_datum["action_result"] = action_match
                round_datum["gt_action"] = gt_action

            # Get supervision for action attributes.
            supervision = gt_datum["action_supervision"]
            if supervision is not None and "args" in supervision:
                supervision = supervision["args"]
            if supervision is None:
                continue
            scored_keys = [key for key in supervision.keys() if key not in IGNORE_ATTRIBUTES]

            # Case 1: Action mismatch -- record False for all attributes.
            if not action_match:
                matches["attributes"].extend([False] * len(scored_keys))
                continue

            # Case 2: Action matches -- use model predictions for attributes.
            for key in scored_keys:
                gt_key_vals = supervision[key]
                model_key_vals = round_datum["attributes"][key]
                if not len(gt_key_vals):
                    continue
                if isinstance(gt_key_vals, list):
                    # For fashion, this is a list -- multi label prediction.
                    assert isinstance(
                        model_key_vals, list
                    ), "Model should also predict a list for attributes"
                    matches["attributes"].append(_multilabel_f1(gt_key_vals, model_key_vals))
                else:
                    # For furniture, this is a string -- single label prediction.
                    matches["attributes"].append(gt_key_vals == model_key_vals)

    print("#Instances evaluated API: {}".format(len(matches["action"])))
    # Record and save per instance results.
    if record_instance_results:
        print("Saving per instance result: {}".format(record_instance_results))
        with open(record_instance_results, "w") as file_id:
            json.dump(model_actions, file_id)

    # Confusion matrix over every action seen as ground truth or prediction:
    # matrix[predicted, ground_truth] = count.
    all_actions = sorted(
        set(confusion_dict.keys()).union(
            {jj for ii in confusion_dict.values() for jj in ii}
        )
    )
    position = {action: idx for idx, action in enumerate(all_actions)}
    matrix = np.zeros((len(all_actions), len(all_actions)))
    for gt_index, action in enumerate(all_actions):
        predicted, counts = np.unique(confusion_dict.get(action, []), return_counts=True)
        for label, count in zip(predicted, counts):
            matrix[position[label], gt_index] += count

    action_perplexity = np.exp(-1 * np.mean(matches["perplexity"]))
    metrics = {
        "action_accuracy": np.mean(matches["action"]),
        "action_perplexity": action_perplexity,
        "attribute_accuracy": np.mean(matches["attributes"]),
        "confusion_matrix": matrix,
        "confusion_matrix_labels": all_actions,
    }
    if not compute_std_err:
        return metrics

    metrics_std_err = {
        "action_accuracy": (
            np.std(matches["action"]) / np.sqrt(len(matches["action"]))
        ),
        # Delta method: SE[exp(-m)] ~= exp(-m) * SE[m], with SE[m] = std / sqrt(n).
        # exp(-std) / sqrt(n) would not be a standard error: it stays non-zero
        # even when every log-probability is identical.
        "action_perplexity": (
            action_perplexity
            * np.std(matches["perplexity"])
            / np.sqrt(len(matches["perplexity"]))
        ),
        "attribute_accuracy": (
            np.std(matches["attributes"]) / np.sqrt(len(matches["attributes"]))
        ),
    }
    return metrics, metrics_std_err
    

for file in SPLITS_FILE:
    # Map the data file to its split name, as in the output-generation cell.
    if "train" in file:
        split = "train"
    elif "devtest" in file:
        split = "devtest"
    else:
        split = "dev"

    # NOTE(review): HISTORY_PAD is not used below; this assignment is kept only
    # in case another cell reads it.
    if HISTORY != 0:
        HISTORY_PAD = HISTORY

    gt_path = f"{DATA_DIR}fashion_{split}_dials_api_calls.json"
    print("Reading: {}".format(gt_path))
    with open(gt_path, "r") as file_id:
        gt_actions = json.load(file_id)

    # Must match the name save_output_model writes (output_{split}{HISTORY}.json).
    # Building it from HISTORY_PAD would raise NameError (or pick a stale value)
    # when HISTORY == 0, because HISTORY_PAD is only assigned when HISTORY != 0.
    output_path = f"{PROJECT_DIR}RESULTS/output_{split}{HISTORY}.json"
    print("Reading: {}".format(output_path))
    with open(output_path, "r") as file_id:
        model_actions = json.load(file_id)

    action_metrics = evaluate_action_prediction(
        gt_actions,
        model_actions,
    )

    if split == "devtest":
        results_json = f"{PROJECT_DIR}RESULTS/metrics.json"
        print(f"Saving metrics result as {results_json}")
        metrics = {
            "hyperparameters": {
                "learning_rate": LEARNING_RATE,
                "discard_attributes_below_count": DISCARD_ATTRIBUTES_BELOW_COUNT,
                "features": FEATURES,
                "action_dropout": ACTION_DROPOUT,
                "attributes_dropout": ATTRIBUTES_DROPOUT,
                "max_epochs": MAX_EPOCHS,
                "train_batch_size": TRAIN_BATCH_SIZE,
                "test_batch_size": TEST_BATCH_SIZE,
                "history": HISTORY,
            },
            # numpy scalars are not JSON-serialisable; convert to Python floats.
            "action_accuracy": action_metrics['action_accuracy'].astype(float),
            "action_perplexity": action_metrics['action_perplexity'].astype(float),
            "attribute_accuracy": action_metrics['attribute_accuracy'].astype(float),
        }
        with open(results_json, "w") as f:
            json.dump(metrics, f)

    # Summary table of the three headline metrics.
    plt.rcParams["figure.autolayout"] = True
    fig, ax = plt.subplots(1, 1)
    data = [[action_metrics['action_accuracy'], action_metrics['action_perplexity'], action_metrics['attribute_accuracy']]]
    column_labels = ["Action Accuracy", "Action Perplexity", "Attribute Accuracy"]
    ax.axis('tight')
    ax.axis('off')
    table = ax.table(cellText=data, colLabels=column_labels, loc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(24)
    table.scale(5, 5)
    fig.savefig(f"{PROJECT_DIR}RESULTS/metrics_{split}{HISTORY}.png", dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    # evaluate_action_prediction fills confusion_matrix[predicted, ground_truth];
    # transpose so ground truth runs down the y axis and predictions along the
    # x axis, which is what the axis labels below say.
    conf_mat = action_metrics['confusion_matrix'].T
    # Both axes are indexed by the sorted action names seen in this split. Fall
    # back to all action names sorted the same way (valid when every action
    # occurs), and to plain indices if even that does not fit the matrix.
    labels = action_metrics.get(
        'confusion_matrix_labels',
        sorted(Action.from_number(a.value) for a in Action),
    )
    if len(labels) != conf_mat.shape[0]:
        labels = [str(k) for k in range(conf_mat.shape[0])]

    fig, ax = plt.subplots(figsize=(10, 8))
    mat = ax.matshow(conf_mat, alpha=0.5)
    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            # Cells hold integer counts stored as floats; print them as ints.
            ax.text(x=j, y=i, s=str(int(conf_mat[i, j])), va='center', ha='center', size='xx-large')
    # Pin exactly one tick per action before labelling; setting tick labels
    # alone depends on the auto-chosen tick positions.
    tick_positions = list(range(len(labels)))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(labels)
    ax.tick_params('x', labelrotation=45)
    ax.xaxis.set_ticks_position("bottom")
    fig.colorbar(mat, ax=ax)
    ax.set_xlabel('Action Predictions', fontsize=18)
    ax.set_ylabel('Action Ground Truth', fontsize=18)
    ax.set_title('Actions Confusion Matrix', fontsize=22)
    fig.tight_layout()
    fig.savefig(f"{PROJECT_DIR}RESULTS/metrics_actions_conf_mat_{split}{HISTORY}.png", dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)