In [None]:
# Colab-only: mount the user's Google Drive at /content/drive so that the
# project folder (entered in the next cell) is reachable from the runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os

# Enter the project folder so that relative paths used below (data/...,
# keyword_prediction_dataset/..., the local global_planning module) resolve.
# The path is anchored at the mount point chosen in the previous cell: a
# relative path only works on the first execution, because after the cwd has
# moved a re-run of this cell would look for "drive/..." inside the project.
os.chdir("/content/drive/My Drive/target-guided-sat-chatbot")

In [None]:
!pip install torch_geometric

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
import networkx as nx
from torch_geometric.utils.convert import to_networkx, from_networkx
from torch_geometric.nn import GCNConv
import torch.optim as optim
from tqdm.auto import tqdm
import torchtext
import random
import time
import datetime
from nltk import word_tokenize
import json
from global_planning import GlobalPlanning
import pickle

In [None]:
# Fetch the 'punkt' tokenizer data; word_tokenize is used when the dataset is built.
# NOTE(review): recent NLTK releases may look for the 'punkt_tab' resource
# instead — confirm against the installed NLTK version.
import nltk
nltk.download('punkt')

Keyword Predictor Model

In [None]:
class KeywordPredictor(nn.Module):
    """Binary node classifier over a per-sample concept graph.

    For every node of a sample's graph the model concatenates three feature
    blocks and feeds them to an MLP head that emits two logits per node:

    * the GCN-encoded node embedding            (embed_size = 300)
    * the mean GloVe embedding of the target ids (embed_size = 300),
      repeated for every node
    * the final hidden states of a bidirectional GRU run over the context
      tokens (2 * context_enc_hidden_size = 512), repeated for every node

    The encoders are applied inside ``collate_fn`` and only the MLP head is
    applied in ``forward``/``predict``.
    """

    def __init__(self, global_planning):
        """Build all sub-modules and move them to the selected device.

        Args:
            global_planning: object exposing ``glove`` with a ``vectors``
                tensor (used to initialise the embedding layer).
        """
        super().__init__()
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        print(f"Using Cuda?: {torch.cuda.is_available()}")
        self.embed_size = 300
        self.context_enc_hidden_size = 256
        self.global_planning = global_planning
        self.glove = self.global_planning.glove
        self.fc = nn.Sequential(
            nn.Linear(self.embed_size + self.embed_size + self.context_enc_hidden_size * 2, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 2),
        ).to(self.device)
        self.embedding_layer = nn.Embedding.from_pretrained(self.glove.vectors).to(self.device)
        self.gcn_graph_encoder = GCN(self.embed_size).to(self.device)
        self.context_encoder = nn.GRU(self.embed_size, self.context_enc_hidden_size, 1, bidirectional=True).to(self.device)
        # Nodes labelled -100 contribute nothing to the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)

    def _node_logits(self, global_graph, target_id, context_embedding):
        """Concatenate the per-node feature blocks column-wise and run the MLP head."""
        inputs = torch.cat((global_graph, target_id, context_embedding), 1)
        return self.fc(inputs)  # (num_nodes, 2)

    def forward(self, nodes_ids, global_graphs, target_ids, context_embeddings, labels):
        """Score every node of every sample and accumulate the loss.

        All arguments are parallel lists with one entry per sample, exactly
        as produced by ``collate_fn``. ``nodes_ids`` is not used for scoring;
        it is accepted so that a collated batch can be passed as ``**batch``.

        Returns:
            dict with ``'loss'`` (SUM of the per-sample cross-entropy losses,
            ``None`` for an empty batch — callers normalise it themselves)
            and ``'logits'`` (list of ``(num_nodes, 2)`` tensors).
        """
        loss = None
        logits_list = []
        for _, global_graph, target_id, context_embedding, label in zip(
                nodes_ids, global_graphs, target_ids, context_embeddings, labels):
            logits = self._node_logits(global_graph, target_id, context_embedding)
            logits_list.append(logits)

            sample_loss = self.criterion(logits, label)  # label: (num_nodes,)
            loss = sample_loss if loss is None else loss + sample_loss
        return {'loss': loss, 'logits': logits_list}

    def predict(self, nodes_ids, global_graphs, target_ids, context_embeddings):
        """Inference-only scoring: same as ``forward`` without labels or loss.

        Puts the module in eval mode (it is left in eval mode afterwards) and
        runs the head under ``torch.no_grad()`` so that no autograd graph is
        recorded for the returned logits.

        Returns:
            dict with ``'logits'``: list of ``(num_nodes, 2)`` tensors.
        """
        self.eval()
        logits_list = []
        with torch.no_grad():
            for _, global_graph, target_id, context_embedding in zip(
                    nodes_ids, global_graphs, target_ids, context_embeddings):
                logits_list.append(self._node_logits(global_graph, target_id, context_embedding))

        return {'logits': logits_list}

    def collate_fn(self, batch):
        """Encode a list of dataset items into the keyword arguments of ``forward``.

        Each item is a dict with ``'global_graphs'`` (graph data object whose
        ``x`` holds one vocabulary id per node), ``'target_ids'`` and
        ``'context_ids'`` (lists of vocabulary ids) and ``'labels'``.
        Samples are kept as separate list entries (no padding) because every
        graph has its own number of nodes.

        The GCN and GRU encoders run here, i.e. while the DataLoader is being
        iterated, so whatever grad mode is active around the iteration also
        applies to them.

        NOTE(review): assumes ``target_ids`` and ``context_ids`` are non-empty;
        an empty list is turned into a float tensor by ``torch.tensor`` and is
        not a valid embedding index — confirm the dataset never yields one.
        """
        nodes_ids = []
        graph_embeddings = []
        target_embeddings = []
        context_embeddings = []
        labels = []

        for sample in batch:
            graph = sample['global_graphs']
            num_nodes = graph.num_nodes
            node_ids = graph.x.to(self.device)
            edge_index = graph.edge_index.to(self.device)

            nodes_ids.append(node_ids)

            # (num_nodes, embed_size): GloVe lookup followed by the GCN encoder.
            graph_embeddings.append(
                self.gcn_graph_encoder(self.embedding_layer(node_ids), edge_index))

            # Mean target embedding, tiled so every node row carries a copy.
            target_tokens = torch.tensor(sample['target_ids']).to(self.device)
            target_embeddings.append(
                self.embedding_layer(target_tokens).mean(0).repeat(num_nodes, 1))

            # [1] is the GRU's final hidden state (both directions); it is
            # flattened into one vector and tiled per node.
            context_tokens = torch.tensor(sample['context_ids']).to(self.device)
            context_state = self.context_encoder(self.embedding_layer(context_tokens))[1]
            context_embeddings.append(context_state.reshape(-1).repeat(num_nodes, 1))

            labels.append(sample['labels'].to(self.device))

        return {'nodes_ids': nodes_ids, 'global_graphs': graph_embeddings, 'target_ids': target_embeddings, 'context_embeddings': context_embeddings, 'labels': labels}

# GCN to encode graph
class GCN(nn.Module):
    """Two-layer graph convolutional encoder.

    Maps node features of width ``embed_size`` to a 512-wide hidden
    representation and back to ``embed_size``, with a ReLU after each
    convolution.
    """

    def __init__(self, embed_size):
        super().__init__()
        hidden_size = 512
        self.conv1 = GCNConv(embed_size, hidden_size)
        self.conv2 = GCNConv(hidden_size, embed_size)

    def forward(self, x, edge_index):
        """Encode node features ``x`` using the connectivity in ``edge_index``."""
        hidden = F.relu(self.conv1(x, edge_index))
        return F.relu(self.conv2(hidden, edge_index))




Dataset

In [None]:
# ========================================
#               Dataset
# ========================================

class KeywordPredictionDataset(Dataset):
    """One (context, graph, target, labels) training sample per usable dialog.

    For each dialog a random "bridge" turn ``idx`` is drawn. The turns before
    it form the context, the concepts of turn ``idx - 1`` are the start
    concepts, those of turn ``idx`` the bridge concepts (positives) and those
    of turn ``idx + 1`` the target concepts. A graph is assembled by calling
    ``global_planning.find_path`` for every start/target concept pair.

    Per-node labels: ``1`` for bridge concepts, ``0`` for nodes within two
    hops of a start concept, ``-100`` (ignored by the loss) for all others.

    Dialogs are skipped when they have fewer than three aligned turns, when
    any of the three concept lists is empty after filtering, when the context
    has no in-vocabulary token, when the graph has fewer than ``topk_num``
    nodes, when no target concept ended up in the graph, or when no bridge
    concept is a graph node.
    """

    def __init__(self, data_json, global_planning, topk_num, max_context_length=3):
        """Build all samples eagerly.

        Args:
            data_json: iterable of dicts with ``'dialog'`` (list of utterance
                strings) and ``'concepts'`` (one list of concept strings per
                utterance).
            global_planning: object providing ``glove`` (with ``stoi``),
                ``word_exists_in_conceptnet``, ``word_embedding_exists`` and
                ``find_path(start, target, graph)``.
            topk_num: minimum number of graph nodes a sample must have.
            max_context_length: maximum number of preceding utterances used
                as context.
        """
        self.context_ids = []
        self.global_graphs = []
        self.target_ids = []
        self.labels = []
        self.global_planning = global_planning
        self.glove = self.global_planning.glove
        self.topk_num = topk_num

        for data in tqdm(data_json):
            dialog = data['dialog']
            concepts = data['concepts']

            # idx - 1, idx and idx + 1 must all be valid turns with concepts;
            # without this guard randint raises on dialogs shorter than 3.
            last_bridge_idx = min(len(dialog), len(concepts)) - 2
            if last_bridge_idx < 1:
                continue

            idx = random.randint(1, last_bridge_idx)
            start_idx = max(0, idx - max_context_length)
            context = dialog[start_idx:idx]

            start_concept = self._usable_concepts(concepts[idx - 1])
            bridge_concept = self._usable_concepts(concepts[idx])
            target_concept = self._usable_concepts(concepts[idx + 1])
            if not (start_concept and bridge_concept and target_concept):
                continue

            # Context: vocabulary ids of all in-vocabulary tokens.
            context_id = [
                self.glove.stoi[word]
                for utterance in context
                for word in word_tokenize(utterance)
                if word in self.glove.stoi
            ]
            if not context_id:
                continue  # the context encoder cannot consume an empty sequence

            # Global graph: find_path is expected to add nodes/edges in place.
            global_graph = nx.Graph()
            for s in start_concept:
                for t in target_concept:
                    self.global_planning.find_path(s, t, global_graph)

            global_graph_nodes = list(global_graph.nodes)
            if len(global_graph_nodes) < self.topk_num:
                continue  # not enough nodes

            # NOTE(review): assumes every node added by find_path has a GloVe
            # entry; otherwise this lookup raises KeyError — confirm.
            for n in global_graph_nodes:
                global_graph.nodes[n]['x'] = self.glove.stoi[n]
            node_to_idx = {node: i for i, node in enumerate(global_graph_nodes)}

            # Target: each target concept present in the graph plus its neighbours.
            target_id = []
            for t in target_concept:
                if t in node_to_idx:
                    target_id.append(self.glove.stoi[t])
                    for n in global_graph.neighbors(t):
                        target_id.append(self.glove.stoi[n])
            if not target_id:
                continue  # nothing to average into a target embedding

            # Labels of the classification task.
            bridge_idxs = [node_to_idx[n] for n in bridge_concept if n in node_to_idx]
            if len(bridge_idxs) == 0:
                print("no true labels", bridge_concept)
                continue  # no true labels

            # Negatives are limited to nodes within 2 hops of a start concept.
            candidate_nodes = set()
            for c0 in start_concept:
                if c0 in node_to_idx:
                    candidate_nodes.add(c0)
                    for c1 in global_graph.neighbors(c0):
                        candidate_nodes.add(c1)
                        candidate_nodes.update(global_graph.neighbors(c1))
            candidate_idxs = [node_to_idx[n] for n in candidate_nodes]

            label = torch.full((len(global_graph_nodes),), -100, dtype=torch.long)
            if candidate_idxs:
                label[candidate_idxs] = 0
            label[bridge_idxs] = 1  # positives override candidates

            self.context_ids.append(context_id)
            # Converts the graph to a torch_geometric.data.Data instance.
            self.global_graphs.append(from_networkx(global_graph))
            self.target_ids.append(target_id)
            self.labels.append(label)

    def _usable_concepts(self, turn_concepts):
        """Keep only concepts known to both ConceptNet and the embedding table."""
        return [
            c for c in turn_concepts
            if self.global_planning.word_exists_in_conceptnet(c)
            and self.global_planning.word_embedding_exists(c)
        ]

    def __len__(self):
        return len(self.context_ids)

    def __getitem__(self, idx):
        return {'context_ids' : self.context_ids[idx], 'global_graphs' : self.global_graphs[idx], 'target_ids' : self.target_ids[idx], 'labels' : self.labels[idx]}



Training loop

In [None]:
def format_time(elapsed):
    """Render a duration given in seconds as ``h:mm:ss``.

    The value is rounded to whole seconds before formatting.
    """
    whole_seconds = int(round(elapsed))
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)

In [None]:
# Build the shared planning helper once; it is passed to both the dataset and
# the model, which read its `glove` attribute.
# The flag is forwarded as-is to GlobalPlanning.
use_numberbatch = False
global_planning = GlobalPlanning(use_numberbatch) # around 2 min

Produce dataset:

In [None]:
# The file is in JSON-lines format: one JSON document per line.
with open('data/train/concepts_nv.json') as f:
    train_data_json = []
    for row in f:
        train_data_json.append(json.loads(row))
# Keep only the first 10 dialogs for a quick run.
train_data_json = train_data_json[:10]
print(f"train length: {len(train_data_json)}")

In [None]:
# Build the dataset eagerly; dialogs whose graph has fewer than topk_num nodes
# are dropped, so the resulting length can be smaller than the input length.
topk_num = 5
train_dataset = KeywordPredictionDataset(train_data_json, global_planning, topk_num)
print("Train dataset length: ", len(train_dataset))

In [None]:
# pickle dataset
with open("keyword_predictor_train_dataset.pickle", "wb") as f:
    pickle.dump(train_dataset, f)

In [None]:
batch_size = 2

# collate_fn is a bound method of the model (it runs the model's encoders), so
# the model has to exist before the loader is built. On a fresh kernel no
# earlier cell creates it, hence the guarded construction here.
if 'keyword_predictor_model' not in globals():
    keyword_predictor_model = KeywordPredictor(global_planning)

train_dataloader = DataLoader(train_dataset, sampler = RandomSampler(train_dataset), batch_size=batch_size, collate_fn=keyword_predictor_model.collate_fn)
# print(next(iter(train_dataloader)))

Use pickled data

In [None]:
# open pickled dataset
with open("keyword_prediction_dataset/keyword_predictor_train_dataset.pickle", "rb") as f:
    train_dataset_pickled = pickle.load(f)

In [None]:
# Load the previously pickled validation dataset (same caveats as any pickle:
# the dataset class must be defined, and the file must come from a trusted source).
with open("keyword_prediction_dataset/keyword_predictor_val_dataset.pickle", "rb") as f:
    val_dataset_pickled = pickle.load(f)

In [None]:
# When the cell is re-run, move the previous model off the accelerator before
# building a new one. On a fresh kernel the name does not exist yet, so the
# move is guarded instead of raising NameError.
if 'keyword_predictor_model' in globals():
    keyword_predictor_model.to('cpu')
keyword_predictor_model = KeywordPredictor(global_planning)
keyword_predictor_model.to(keyword_predictor_model.device)

In [None]:
# batch_size is also read by the training loop below.
batch_size = 8

# Both loaders use the model's collate_fn, which applies the model's encoders
# to every sample while the loader is iterated. Training batches are drawn in
# random order, validation batches in dataset order.
train_dataloader = DataLoader(train_dataset_pickled, sampler = RandomSampler(train_dataset_pickled), batch_size=batch_size, collate_fn=keyword_predictor_model.collate_fn)
validation_dataloader = DataLoader(val_dataset_pickled, sampler = SequentialSampler(val_dataset_pickled), batch_size=batch_size, collate_fn=keyword_predictor_model.collate_fn)

In [None]:
num_epochs = 10
# Total number of optimizer steps; not referenced by the cells visible in this notebook.
num_training_steps = len(train_dataloader) * num_epochs
# Adam over every parameter registered on the model, with a fixed learning rate.
optimizer = optim.Adam(keyword_predictor_model.parameters(), lr=1e-6)

In [None]:
def _print_sample_predictions(model, batch, sample_logits, top_k=5):
    """Print a qualitative snapshot for the first sample of ``batch``.

    Shows the nodes classified as positive, the class-1 probabilities of the
    strongest class-1 and class-0 nodes, the top-k predicted words, the
    candidate (label 0) and true (label 1) words, and how many of the top-k
    predictions fall into each group.
    """
    itos = model.global_planning.glove.itos
    node_ids = batch['nodes_ids'][0]
    sample_labels = batch['labels'][0]

    probs = sample_logits.softmax(1)
    preds_ones = probs[:, 1]
    k = min(top_k, probs.size(0))  # a graph may have fewer than top_k nodes
    _, topk_preds_idx_1 = preds_ones.topk(k)
    _, topk_preds_idx_0 = probs[:, 0].topk(k)
    all_preds_idx = (probs.argmax(1) == 1).nonzero().flatten()

    topk = node_ids[topk_preds_idx_1].tolist()
    preds = node_ids[all_preds_idx].tolist()
    candidates = node_ids[sample_labels == 0].tolist()
    true = node_ids[sample_labels == 1].tolist()

    topk_pred_words = [itos[node_id] for node_id in topk]
    print('Predicted: ', [itos[node_id] for node_id in preds])
    print('Softmax: ', preds_ones[topk_preds_idx_1], preds_ones[topk_preds_idx_0])
    print('Top 5 predictions: ', topk_pred_words)
    print([itos[node_id] for node_id in candidates])
    predicted_candidates_id = [node_id for node_id in topk if node_id in candidates]
    predicted_true_id = [node_id for node_id in topk if node_id in true]
    print('True labels: ', [itos[node_id] for node_id in true])
    print('Predicted candidates: ', len(predicted_candidates_id), [itos[node_id] for node_id in predicted_candidates_id])
    print('True candidates: ', len(predicted_true_id), [itos[node_id] for node_id in predicted_true_id])


total_t0 = time.time()
training_stats = []

for epoch in tqdm(range(num_epochs)):
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, num_epochs))

    # ========================================
    #               Training
    # ========================================
    print('Training...')
    t0 = time.time()
    total_train_loss = 0

    keyword_predictor_model.train()

    for step, batch in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()

        outputs = keyword_predictor_model(**batch)
        # The model returns the SUM of per-sample losses. Normalise by the
        # number of samples actually in this batch: the last batch of an epoch
        # can be smaller than batch_size.
        loss = torch.div(outputs['loss'], len(batch['labels']))
        logits = outputs['logits']

        total_train_loss += loss.item()

        # Print a progress line and a sample of predictions every 50 batches.
        if step % 50 == 0:
            elapsed = format_time(time.time() - t0)
            print('\n  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), loss.item(), elapsed))
            _print_sample_predictions(keyword_predictor_model, batch, logits[0])

        loss.backward()
        optimizer.step()

    # Calculate the average loss over all of the batches.
    avg_train_loss = total_train_loss / len(train_dataloader)

    # Measure how long this epoch took.
    training_time = format_time(time.time() - t0)
    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(training_time))

    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    t0 = time.time()

    keyword_predictor_model.eval()

    total_eval_loss = 0
    # no_grad wraps the whole iteration: the encoders run inside the loader's
    # collate_fn, so they too must be outside autograd during evaluation.
    with torch.no_grad():
        for batch in tqdm(validation_dataloader):
            val_outputs = keyword_predictor_model(**batch)
            val_loss = torch.div(val_outputs['loss'], len(batch['labels']))
            total_eval_loss += val_loss.item()
    avg_val_loss = total_eval_loss / len(validation_dataloader)

    validation_time = format_time(time.time() - t0)

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")
print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))



In [None]:
# Move the weights to the CPU before saving so the checkpoint holds CPU tensors.
# NOTE(review): the model's own `device` attribute is not updated by this call,
# and collate_fn keeps moving inputs to that attribute's device — move the
# model back before iterating the dataloaders again.
keyword_predictor_model.to('cpu')
# NOTE(review): the suffix presumably encodes batch size 8, 10 epochs and
# lr 1e-6, matching the values set in the cells above — confirm.
model_state_dict_path = "keyword_predictor_state_dict_model_8_10_1e-6.pt"
torch.save(keyword_predictor_model.state_dict(), model_state_dict_path)
"""
To load:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
"""

In [None]:
import pandas as pd

# Tabulate the per-epoch statistics collected by the training loop, one row
# per epoch with the epoch number as the row index.
df_stats = pd.DataFrame(data=training_stats).set_index('epoch')

# Display the table.
df_stats

In [None]:
# Persist the raw per-epoch statistics (the list of dicts, not the DataFrame)
# so the learning curve can be re-plotted without re-training.
with open("predictor_stats.pickle", "wb") as f:
  pickle.dump(training_stats, f)

In [None]:
# Plot training vs. validation loss per epoch with plain matplotlib
# (seaborn styling is intentionally not applied).
import matplotlib.pyplot as plt

# Global default figure size; affects every later figure in this session.
plt.rcParams["figure.figsize"] = (12,6)

# Plot the learning curve; the x axis is the DataFrame index (epoch number).
plt.plot(df_stats['Training Loss'], 'b-o', label="Training")
plt.plot(df_stats['Valid. Loss'], 'g-o', label="Validation")

# Label the plot.
plt.title("Training & Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
# One tick per epoch, matching the 1-based epoch numbers stored in the stats.
plt.xticks(range(1, num_epochs + 1))

plt.show()

In [None]:
# Round-trip check: build a fresh model and restore the weights saved above.
state_dict_model = KeywordPredictor(global_planning)
state_dict_model.load_state_dict(torch.load(model_state_dict_path))
state_dict_model.eval() # must be set before inference