In [1]:
import csv
import torch
from torch import optim
import random 
from pytorch_transformers.tokenization_distilbert import DistilBertTokenizer
from pytorch_transformers.modeling_distilbert import DistilBertModel

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
special_tokens_dict = {'additional_special_tokens': ['<PLH>', '<s>','</s>']}
tokenizer.add_special_tokens(special_tokens_dict)
encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')
encoder.resize_token_embeddings(len(tokenizer))

pad_token = tokenizer.pad_token
train_data = []
test_data = []

with open("C:\\Users\\nickh\\OneDrive\\Desktop\\inferences.csv") as csvDataFile:
    csvreader = csv.reader(csvDataFile, delimiter='\t')
    next(csvreader)
    for context1, context2, sample1, sample2, inf1, inf2, inf3, inf4 in csvreader:
        target = test_data if random.random() > 0.8 else train_data
        target.append({
            "context1":context1,
            "context2":context2,
            "sample1":sample1,
            "sample2":sample2,
            "inf1":inf1,
            "inf2":inf2,
            "inf3":inf3,
        })

def sample_train():
    positive = random.choice(train_data)
    while "inf_random" not in positive or positive["inf1"] == positive["inf_random"]:
        negative = random.choice(train_data)
        positive["inf_random"] = negative["inf1"]
    return positive

def sample_test():
    positive = random.choice(test_data)
    while "inf_random" not in positive or positive["inf1"] == positive["inf_random"]:
        negative = random.choice(train_data)
        positive["inf_random"] = negative["inf1"]
    return positive

In [10]:
class Classifier(torch.nn.Module):    
    def __init__(self, esz=1536):
        super().__init__()
        self.dense = torch.nn.Linear(
            esz, esz,
        )
        self.out = torch.nn.Linear(
            esz, 2,
        )
        self.relu = torch.nn.ReLU()
    
    def forward(self, input, hidden=None):
        #return self.relu(self.dense(input))
        return self.out(self.relu(self.dense(input)))

class Attention(torch.nn.Module):    
    def __init__(self, esz=768, seq_len=100):
        super().__init__()
        self.dense = torch.nn.Linear(
            esz, 1
        )        
        self.relu = torch.nn.ReLU()
        
    def forward(self, input, hidden=None):
        return self.relu(self.dense(input))
    
class Model():
    def __init__(self, lr=0.0001, esz=768, pad_len=20):
        self.pad_len = pad_len
        self.classifier = Classifier(esz=esz*3) 
        self.context_attention = Attention(esz=esz, seq_len=pad_len)
        self.inference_attention = Attention(esz=esz, seq_len=pad_len)
        self.optims = {
            'classifier': optim.SGD(self.classifier.parameters(), lr=lr),
            'context_attention': optim.SGD(self.context_attention.parameters(), lr=lr),
            'inference_attention': optim.SGD(self.inference_attention.parameters(), lr=lr),
        }
        self.loss = 0
        self.step = 0
        self.c_loss = torch.nn.CrossEntropyLoss()
    
    def zero_grad(self):
        for optimizer in self.optims.values():
            optimizer.zero_grad()
            
    def update_params(self):
        for optimizer in self.optims.values():
            optimizer.step()
            
    def get_and_attend_context(self, embedded):
        context_attention_mask1 = self.context_attention(embedded["context1"])
        context_attention_mask2 = self.context_attention(embedded["context2"])
        context = torch.sum((context_attention_mask1 * embedded["context1"]) + (context_attention_mask2 * embedded["context2"]), 1)
        return context
    
    def attend_inference(self, inf):
        inference_attention_mask = self.inference_attention(inf)
        return torch.sum(inference_attention_mask * inf, 1)
        return inf.mean(1)

    def attend_utterance(self, utterance):
        utterance_attention_mask = self.utterance_attention(utterance)
        return torch.sum(utterance_attention_mask * utterance, 1)
        return utterance.mean(1)
    
    def merge_utterance_with_inference_and_context(self, context, utterance, inference):
        return torch.cat([context,utterance,inference], 1)
        
    def eval_step(self):
        self.classifier.eval()
        self.context_attention.eval()
        self.inference_attention.eval()
        
        accuracy = 0
        num_steps = 10
        for _ in range(num_steps):
            xs = sample_test()
            #print("Sample: %s" % xs["sample1"])
            #print("Positive Inference 1: %s" % xs["inf1"])
            embedded = self.embed_sample(xs)
            
            context = self.get_and_attend_context(embedded)
            inf1 = self.attend_inference(embedded["inf1"])
            utterance1 = self.attend_inference(embedded["sample1"])
            merged_positive = self.merge_utterance_with_inference_and_context(context, utterance1, inf1)

            pred = self.classifier(merged_positive) 
            #print(pred)
            pred = torch.argmax(pred, 1)
            #if pred.item() == 0:
                #print("Inference does not follow")
                #accuracy += .append(0)
            if pred.item() == 1:
                #print("Inference follows")
                accuracy += 1

            inf_random = self.attend_inference(embedded["inf_random"])
            merged_negative = self.merge_utterance_with_inference_and_context(context, utterance1, inf_random)

            pred = self.classifier(merged_negative) 
            #print(pred)
            pred = torch.argmax(pred, 1)
            #print("Negative Inference 1: %s" % xs["inf_random"])
            if pred.item() == 0:
            #    print("Inference does not follow")
                #accuracy.append(1)
                accuracy += 1
            #else:
                #accuracy.append(0)
            #    print("Inference follows")

            #context = self.get_and_attend_context(embedded)
            #inp = self.attend_inference_and_merge_context(embedded["inf1"], context)
        print("Accuracy : %f" % (accuracy / num_steps))
                
    def embed_sample(self, xs):
        embedded = {}
        for k in ["context1","context2","sample1","sample2","inf1","inf2","inf3","inf_random"]:
            if k in xs and len(xs[k]) > 0:
                tokens = tokenizer.encode(xs[k])
                padded = torch.full((1, self.pad_len), tokenizer.pad_token_id, dtype=torch.long)
                padded[0,:len(tokens)] = torch.LongTensor([tokens])
                embedded[k] = encoder(padded)[0]
        return embedded
        
    def train_step(self):
        loss = 0
        self.zero_grad()
        self.classifier.train()
        self.context_attention.train()
        self.inference_attention.train()

        xs = sample_train()
        #print(xs)
        embedded = self.embed_sample(xs)
        
        context = self.get_and_attend_context(embedded)
        inf1 = self.attend_inference(embedded["inf1"])
        utterance1 = self.attend_inference(embedded["sample1"])
        merged_positive = self.merge_utterance_with_inference_and_context(context, utterance1, inf1)

        #print("Context shape : %s" % str(context.size()))
        #print("inf1 shape : %s" % str(inf1.size()))
        #print("inp1 shape : %s" % str(inp.size()))
        
        inf_random = self.attend_inference(embedded["inf_random"])
        merged_negative = self.merge_utterance_with_inference_and_context(context, utterance1, inf_random)
        inp = torch.cat([merged_positive, merged_negative], 0).unsqueeze(1)
        outs = [[1],[0]]
        
        #if "inf2" in embedded:
        #    inf2 = self.attend_inference_and_merge_context(embedded["inf2"], context)
        #    inp = torch.cat([inp, inf2, inf_random], 0)
        #    outs += [1,0]
        #    if "inf3" in embedded:
        #        inf3 = self.attend_inference_and_merge_context(embedded["inf3"], context)
        #        inp = torch.cat([inp, inf3, inf_random], 0)
        #        outs += [1,0]
        
        outs = torch.LongTensor(outs)
        #print("inp shape : %s" % str(inp.size()))
        #print(inp.size())
        #print(inp)
        pred = self.classifier(inp)
        #print(pred)
        #print(outs)
        #print("pred shape : %s" % str(pred.size()))
        #print("outs shape : %s" % str(outs.size()))
        loss = self.c_loss(torch.transpose(pred, 1, 2), outs)

        self.step += 1
        self.loss += loss
        if self.step % 5 == 0:
            print(self.step)
            print(self.loss / 50)
            self.loss = 0
            self.eval_step()

        loss.backward()
        self.update_params()
model = Model()
for i in range(1000):
    model.train_step()

5
tensor(0.0684, grad_fn=<DivBackward0>)
Accuracy : 1.000000
10
tensor(0.0702, grad_fn=<DivBackward0>)
Accuracy : 1.000000
15
tensor(0.0711, grad_fn=<DivBackward0>)
Accuracy : 1.000000
20
tensor(0.0697, grad_fn=<DivBackward0>)
Accuracy : 1.000000
25
tensor(0.0668, grad_fn=<DivBackward0>)
Accuracy : 1.000000
30
tensor(0.0680, grad_fn=<DivBackward0>)
Accuracy : 1.000000
35
tensor(0.0664, grad_fn=<DivBackward0>)
Accuracy : 1.000000
40
tensor(0.0698, grad_fn=<DivBackward0>)
Accuracy : 1.000000
45
tensor(0.0683, grad_fn=<DivBackward0>)


KeyboardInterrupt: 

In [None]:
print(encoder(torch.LongTensor([tokenizer.encode("I don't have a favourite author.")]))[0])
print(encoder(torch.LongTensor([tokenizer.encode("You play sport at least three times per week.")]))[0])