In [2]:
import csv
import numpy as np
import torch
from torch import optim
import random 
from pytorch_transformers.tokenization_distilbert import DistilBertTokenizer
from pytorch_transformers.modeling_distilbert import DistilBertModel

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
special_tokens_dict = {'additional_special_tokens': ['<PLH>', '<s>','</s>']}
tokenizer.add_special_tokens(special_tokens_dict)
encoder = DistilBertModel.from_pretrained('distilbert-base-uncased').cuda()
encoder.resize_token_embeddings(len(tokenizer))

for param in encoder.parameters():
    param.requires_grad = False

pad_token = tokenizer.pad_token
train_data = []
test_data = []
    
with open("memory_conversation.tsv") as csvDataFile:
    csvreader = csv.reader(csvDataFile, delimiter='\t')
    next(csvreader)
    
    data = []
    conversation = {
                "conv_id":None,
                "exchange_ids":[],
                "user_responses":[],
                "num_states":1,
                "state_updates":[],
                "agent_utterances":[],
                "state_embeddings":[]
            }
    max_states = 0
    max_utterances = 0
    
    for conv_id, exchange_id, agent, user, action1, action2, action3 in csvreader:
        if conversation["conv_id"] is None:
            conversation["conv_id"] = conv_id
        elif conv_id != conversation["conv_id"]:
            data.append(conversation)
            max_utterances = max(
                max_utterances, 
                len(conversation["exchange_ids"])
            )
            conversation = {
                "conv_id":conv_id,
                "exchange_ids":[],
                "user_responses":[],
                "num_states":1,
                "state_updates":[],
                "agent_utterances":[],
                "state_embeddings":[]
            }
        
        conversation["agent_utterances"].append(agent)
        conversation["user_responses"].append(user)
        
        state_updates = []
        for action in (action1, action2, action3):
            if action != "NULL":
                tokenized_action = torch.cuda.LongTensor([
                    tokenizer.encode(
                        action.split(", ")[1]
                    )
                ])
                embedding = encoder(tokenized_action)[0]
                idx = conversation["num_states"]
                conversation["num_states"] += 1
                state_updates.append((idx, embedding))
        conversation["state_updates"] = state_updates
        conversation["exchange_ids"].append(exchange_id)
        max_states = max(max_states, len(conversation["exchange_ids"]))
    if len(conversation) > 0:
        data.append(conversation)
    max_utterances = max(
        max_utterances, 
        len(conversation["agent_utterances"])
    )

In [3]:
class Attention(torch.nn.Module):
    def __init__(self, esz=768, hsz=768):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(esz, 8)
        
    def forward(self, input, hidden=None):
        attn_mask, _ = self.attn(input, input, input)
        return torch.sum(input * attn_mask, 1)
        
class UtterancePredictor(torch.nn.Module):    
    def __init__(self, esz=768, hsz=768):
        super().__init__()
        self.hsz = hsz
        self.attn = torch.nn.MultiheadAttention(esz, 8)
        self.dense = torch.nn.Linear(
            esz, hsz,
        )
        self.relu = torch.nn.ReLU()
        self.out = torch.nn.Linear(
            hsz, esz,
        )
    
    # accepts seq_len x num_values  returns num_keys
    def forward(self, input, hidden=None):
        input = input.unsqueeze(0).unsqueeze(0)        
        attended = input
        out = self.out(self.relu(self.dense(attended)))
        return out

class StateUpdater(torch.nn.Module):
    def __init__(self, num_values, hsz=768, esz=768):
        super().__init__()
        
        self.dense = torch.nn.Linear(
            num_values * 3, hsz,
        )
        self.relu = torch.nn.ReLU()
        self.out = torch.nn.Linear(
            hsz, esz,
        )
        self.attn = torch.nn.MultiheadAttention(esz, 8)
        
    # accepts tuple of (num_keys x esz, esz x pad_len, esz x pad_len)  returns num_keys x esz
    def forward(self, input, hidden=None):
        state, agent, user = input
        agent_attention, _ = self.attn(agent, agent, agent)
        agent_attention = torch.sum(agent_attention * agent, 1)
        #print("agent.size() : %s" % str(agent.size()))
        user_attention, _ = self.attn(user, user, user)
        user_attention = torch.sum(user_attention * user, 1)
        #print("user.size() : %s" % str(user.size()))
        dialog = torch.cat([agent_attention, user_attention], 1)
        #print("dialog.size() : %s" % str(dialog.size()))
        dialog = dialog.repeat(state.size(0), 1)
        
        inp = torch.cat([state, dialog], 1)
        return self.out(self.relu(self.dense(inp)))

In [5]:
class Model():
    def __init__(self, lr=0.0001, hsz=768, pad_len=20):
        self.pad_len = pad_len
        self.utterance_selector = UtterancePredictor().cuda()
        self.state_updater = StateUpdater(hsz).cuda()
        self.encoder_attention = Attention().cuda()
        self.optims = {
            'utterance_selector': optim.Adam(self.utterance_selector.parameters(), lr=lr),
            'state_updater': optim.Adam(self.state_updater.parameters(), lr=lr),
            'encoder_attention': optim.Adam(self.encoder_attention.parameters(), lr=lr),
        }
        self.loss = 0
        self.step = 0
        self.utterance_criterion = torch.nn.MSELoss()
        self.state_criterion = torch.nn.MSELoss()
    
    def zero_grad(self):
        for optimizer in self.optims.values():
            optimizer.zero_grad()
            
    def update_params(self):
        for optimizer in self.optims.values():
            optimizer.step()
            
    def encode(self, agent_utterance, user_response):
        agent_utterance = random.choice(agent_utterance.split("|"))
        encoded = encoder(
            torch.LongTensor(
                [
                    tokenizer.encode(
                            "[CLS] " + agent_utterance + " [SEP] " + user_response
                    )
                ]
            ).cuda()
        )[0].cuda()
        return self.encoder_attention(encoded)
    
    def encode_option(self, i, conversation, utterance_matrix, dialog):
        prefix = conversation["exchange_ids"][i][:-1] 
        ids, utterances, responses = [], [], []
        level = conversation["exchange_ids"][i].count("_")
        while conversation["exchange_ids"][i].count("_") == level:
            utterances.append(conversation["agent_utterances"][i])
            responses.append(conversation["user_responses"][i])
            ids.append(conversation["exchange_ids"][i])
            i += 1

        idx = random.choice(range(len(ids)))
        id = ids[idx]
        utterances = [utterances[idx]]
        responses = [responses[idx]]
        
#        print("Encoding %s" % id)

        utterance_encoded = self.encode(
            utterances[0],
            responses[0]
        )
        utterance_matrix.append((utterance_encoded + utterance_matrix[-1]) / 2)
        
        while prefix in conversation["exchange_ids"][i]:
            #print("Prefix %s in %s " % (prefix,conversation["exchange_ids"][i] ))
            if id not in conversation["exchange_ids"][i]:
                #print("ID %s not in %s " % (id,conversation["exchange_ids"][i] ))
                i += 1
                continue
            #print("ID %s IS in %s " % (id,conversation["exchange_ids"][i] ))

            if conversation["exchange_ids"][i].count("_") == level + 1:
                #print("Encoding conditional %s" % (conversation["exchange_ids"][i]))
                utterances.append(conversation["agent_utterances"][i])
                responses.append(conversation["user_responses"][i])
                utterance_encoded = self.encode(
                    utterances[-1], 
                    responses[-1]
                )
                utterance_matrix.append((utterance_encoded + utterance_matrix[-1]) / 2)
                i += 1
            elif conversation["exchange_ids"][i].count("_") == level + 2:
                i = self.encode_option(i, conversation, utterance_matrix, dialog)
            else:
                raise Exception()
        for j in range(len(utterances)):
            dialog.append((utterances[j], responses[j]))
        return i
                        
    def encode_utterance_matrix(self, conversation, esz=768):
        dialog = []
        utterance_matrix = []
        i = 0
        while i < len(conversation["exchange_ids"]):
            exchange_id = conversation["exchange_ids"][i]
            if not exchange_id.endswith("_A"):
                agent_utterance = conversation["agent_utterances"][i]
                user_response = conversation["user_responses"][i]
                dialog.append((agent_utterance, user_response))
                utterance_encoded = self.encode(agent_utterance, user_response)
                if i > 0:
                    utterance_matrix.append((utterance_encoded + utterance_matrix[-1]) / 2)
                else:
                    utterance_matrix.append(utterance_encoded)
                i += 1
            else:
                i = self.encode_option(i, conversation, utterance_matrix, dialog)

        return torch.cat(utterance_matrix, 0).cuda(), dialog
   
    def eval_step(self):
        self.utterance_selector.eval()
        self.state_updater.eval()
        
        accuracy = 0
        
        unset_token = tokenizer.encode("UNSET")
        unset_embedding = encoder(torch.cuda.LongTensor([unset_token]))[0].mean(1)
        
        conversation = random.choice(data)
        utterance_matrix, dialog = self.encode_utterance_matrix(conversation)

        # initialize state matrix with every entry set to "UNSET"
        state_matrix = unset_embedding.repeat(conversation["num_states"], 1)

        for i in range(utterance_matrix.size(0) - 1):   
            predicted_utterance = self.utterance_selector(utterance_matrix[i])
            dists = torch.nn.PairwiseDistance()(predicted_utterance.squeeze(0), utterance_matrix[:,:768])
            dists, indices = torch.min(dists, 0)
            if dialog[i+1][0] == dialog[indices.item()][0]:
                accuracy += 1
            print("Agent: %s User: %s"  % (dialog[i][0], dialog[i][1]))
            print("Predicted Agent: %s"  % (dialog[indices.item()][0]))
            print("Ground Agent: %s"  % (dialog[i+1][0]))
            print("#####")
        print("Evaluation Accuracy : %f" % (accuracy / (utterance_matrix.size(0) - 1)))
        
    def train_step(self):
        
        utterance_loss = None
        state_loss = None

        unset_token = tokenizer.encode("UNSET")
        unset_embedding = encoder(torch.cuda.LongTensor([unset_token]))[0].mean(1)
        
        # one conversation = one episode = one training loop iteration
        loss = 0
        for conversation in data:            
            utterance_matrix, dialog = self.encode_utterance_matrix(conversation)

            
            # initialize state matrix with every entry set to "UNSET"
            state_matrix = unset_embedding.repeat(conversation["num_states"], 1)
            
            for i in range(utterance_matrix.size(0) - 1):
                self.zero_grad()
                self.utterance_selector.train()
                self.state_updater.train()
                
                predicted_utterance = self.utterance_selector(utterance_matrix[i])
                utterance_loss = self.utterance_criterion(
                    predicted_utterance,
                    utterance_matrix[i+1,:768]
                )
                                
                if utterance_loss is not None:
                    loss += utterance_loss
                if state_loss is not None:
                    loss += state_loss
                
            if loss > 0:
                loss.backward()
                self.update_params()
            
            self.step += 1
            self.loss += loss
            loss = 0
            if self.step % 100 == 0:
                print("Step %d Loss %f" % (self.step, self.loss.item() / 50))

                self.loss = 0
                self.eval_step()

model = Model()
for i in range(1000):
    model.train_step()

Step 100 Loss 0.492603
Agent: NULL User: NULL
Predicted Agent: NULL
Ground Agent: Hi! 
#####
Agent: Hi!  User: NULL
Predicted Agent: NULL
Ground Agent: Did you want to talk about your family?
#####
Agent: Did you want to talk about your family? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: Sure
Predicted Agent: NULL
Ground Agent: Let's start with your father.
#####
Agent: Let's start with your father. User: NULL
Predicted Agent: NULL
Ground Agent: Is your father still alive?
#####
Agent: Is your father still alive? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: No, he is dead.
Predicted Agent: NULL
Ground Agent: I'm sorry to hear that. 
#####
Agent: I'm sorry to hear that.  User: NULL
Predicted Agent: NULL
Ground Agent: What about your mother?
#####
Agent: What about your mother? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: She is still alive.
Predicted Agent: NULL
Ground Agent: That's good to hea

Step 500 Loss 0.016105
Agent: NULL User: NULL
Predicted Agent: Hello
Ground Agent: Hello
#####
Agent: Hello User: NULL
Predicted Agent: Do you want to talk about sports?
Ground Agent: Do you want to talk about sports?
#####
Agent: Do you want to talk about sports? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: No, I'd prefer to talk about family
Predicted Agent: OK, let's talk about family.
Ground Agent: OK, let's talk about family.
#####
Agent: OK, let's talk about family. User: NULL
Predicted Agent: How old is your father?
Ground Agent: How old is your father?
#####
Agent: How old is your father? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: He is 35.
Predicted Agent: That's so young!
Ground Agent: That's so young!
#####
Agent: That's so young! User: NULL
Predicted Agent: How about your mother?
Ground Agent: How about your mother?
#####
Agent: How about your mother? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Ag

Step 900 Loss 0.005300
Agent: NULL User: NULL
Predicted Agent: Hello
Ground Agent: Hello
#####
Agent: Hello User: NULL
Predicted Agent: Do you want to talk about sports?
Ground Agent: Do you want to talk about sports?
#####
Agent: Do you want to talk about sports? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: No, I'd prefer to talk about family
Predicted Agent: OK, let's talk about family.
Ground Agent: OK, let's talk about family.
#####
Agent: OK, let's talk about family. User: NULL
Predicted Agent: How old is your father?
Ground Agent: How old is your father?
#####
Agent: How old is your father? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: He is 35.
Predicted Agent: That's so young!
Ground Agent: That's so young!
#####
Agent: That's so young! User: NULL
Predicted Agent: How about your mother?
Ground Agent: How about your mother?
#####
Agent: How about your mother? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Ag

Step 1200 Loss 0.002906
Agent: NULL User: NULL
Predicted Agent: Hello
Ground Agent: Hello
#####
Agent: Hello User: NULL
Predicted Agent: Do you want to talk about sports?
Ground Agent: Do you want to talk about sports?
#####
Agent: Do you want to talk about sports? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: No, I'd prefer to talk about family
Predicted Agent: OK, let's talk about family.
Ground Agent: OK, let's talk about family.
#####
Agent: OK, let's talk about family. User: NULL
Predicted Agent: How old is your father?
Ground Agent: How old is your father?
#####
Agent: How old is your father? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: He is 35.
Predicted Agent: That's so young!
Ground Agent: That's so young!
#####
Agent: That's so young! User: NULL
Predicted Agent: How about your mother?
Ground Agent: How about your mother?
#####
Agent: How about your mother? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
A

Step 1600 Loss 0.001582
Agent: NULL User: NULL
Predicted Agent: Hello
Ground Agent: Hello
#####
Agent: Hello User: NULL
Predicted Agent: Do you want to talk about sports?
Ground Agent: Do you want to talk about sports?
#####
Agent: Do you want to talk about sports? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: No, I'd prefer to talk about family
Predicted Agent: OK, let's talk about family.
Ground Agent: OK, let's talk about family.
#####
Agent: OK, let's talk about family. User: NULL
Predicted Agent: How old is your father?
Ground Agent: How old is your father?
#####
Agent: How old is your father? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: He is 35.
Predicted Agent: That's so young!
Ground Agent: That's so young!
#####
Agent: That's so young! User: NULL
Predicted Agent: How about your mother?
Ground Agent: How about your mother?
#####
Agent: How about your mother? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
A

Step 2000 Loss 0.002063
Agent: NULL User: NULL
Predicted Agent: Hi! 
Ground Agent: Hi! 
#####
Agent: Hi!  User: NULL
Predicted Agent: Did you want to talk about your family?
Ground Agent: Did you want to talk about your family?
#####
Agent: Did you want to talk about your family? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: Sure
Predicted Agent: Let's start with your father.
Ground Agent: Let's start with your father.
#####
Agent: Let's start with your father. User: NULL
Predicted Agent: Is your father still alive?
Ground Agent: Is your father still alive?
#####
Agent: Is your father still alive? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: No, he is dead.
Predicted Agent: I'm sorry to hear that. 
Ground Agent: I'm sorry to hear that. 
#####
Agent: I'm sorry to hear that.  User: NULL
Predicted Agent: What about your mother?
Ground Agent: What about your mother?
#####
Agent: What about your mother? User: NULL
Predicted Agent:

Step 2400 Loss 0.001389
Agent: NULL User: NULL
Predicted Agent: Hi! 
Ground Agent: Hi! 
#####
Agent: Hi!  User: NULL
Predicted Agent: Did you want to talk about your family?
Ground Agent: Did you want to talk about your family?
#####
Agent: Did you want to talk about your family? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: Sure
Predicted Agent: Let's start with your father.
Ground Agent: Let's start with your father.
#####
Agent: Let's start with your father. User: NULL
Predicted Agent: Is your father still alive?
Ground Agent: Is your father still alive?
#####
Agent: Is your father still alive? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: No, he is dead.
Predicted Agent: I'm sorry to hear that. 
Ground Agent: I'm sorry to hear that. 
#####
Agent: I'm sorry to hear that.  User: NULL
Predicted Agent: What about your mother?
Ground Agent: What about your mother?
#####
Agent: What about your mother? User: NULL
Predicted Agent:

Step 2800 Loss 0.000933
Agent: NULL User: NULL
Predicted Agent: Good evening!
Ground Agent: Good evening!
#####
Agent: Good evening! User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: Hello.
Predicted Agent: I feel like talking about my dad. Is that OK?
Ground Agent: I feel like talking about my dad. Is that OK?
#####
Agent: I feel like talking about my dad. Is that OK? User: NULL
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: Sure
Predicted Agent: NULL
Ground Agent: NULL
#####
Agent: NULL User: Can we talk about your mother?
Predicted Agent: My dad just turned 65
Ground Agent: My dad just turned 65
#####
Agent: My dad just turned 65 User: NULL
Predicted Agent: He used to be a doctor, but now he's retired.
Ground Agent: He used to be a doctor, but now he's retired.
#####
Agent: He used to be a doctor, but now he's retired. User: NULL
Predicted Agent: OK fine. How is your mother doing?
Ground Agent: OK fine. How is your mother doing?
#####
Agen

In [None]:
data[2]