In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
flatten = lambda l: [item for sublist in l for item in sublist]

USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

FloatTensor = torch.FloatTensor
LongTensor = torch.LongTensor
ByteTensor = torch.ByteTensor

In [2]:
class DMN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout_p=0.1):
        super(DMN, self).__init__()
        
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(input_size, hidden_size, padding_idx=0) #sparse=True)
        self.input_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.question_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        
        self.gate = nn.Sequential(
                            nn.Linear(hidden_size * 4, hidden_size),
                            nn.Tanh(),
                            nn.Linear(hidden_size, 1),
                            nn.Sigmoid()
                        )
        
        self.attention_grucell =  nn.GRUCell(hidden_size, hidden_size)
        self.memory_grucell = nn.GRUCell(hidden_size, hidden_size)
        self.answer_grucell = nn.GRUCell(hidden_size * 2, hidden_size)
        self.answer_fc = nn.Linear(hidden_size, output_size)
        
        self.dropout = nn.Dropout(dropout_p)
        
    def init_hidden(self, inputs):
        hidden = Variable(torch.zeros(1, inputs.size(0), self.hidden_size))
        return hidden.cuda() if USE_CUDA else hidden
    
    def init_weight(self):
        nn.init.xavier_uniform(self.embed.state_dict()['weight'])
        
        for name, param in self.input_gru.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        for name, param in self.question_gru.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        for name, param in self.gate.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        for name, param in self.attention_grucell.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        for name, param in self.memory_grucell.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        for name, param in self.answer_grucell.state_dict().items():
            if 'weight' in name: nn.init.xavier_normal(param)
        
        nn.init.xavier_normal(self.answer_fc.state_dict()['weight'])
        self.answer_fc.bias.data.fill_(0)
        
    def forward(self, facts, fact_masks, questions, question_masks, num_decode, episodes=3, is_training=False):
        """
        facts : (B,T_C,T_I) / LongTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)
        fact_masks : (B,T_C,T_I) / ByteTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)
        questions : (B,T_Q) / LongTensor # batch_size, question_length
        question_masks : (B,T_Q) / ByteTensor # batch_size, question_length
        """
        # Input Module
        C = [] # encoded facts
        for fact, fact_mask in zip(facts, fact_masks):
            embeds = self.embed(fact)
            if is_training:
                embeds = self.dropout(embeds)
            hidden = self.init_hidden(fact)
            outputs, hidden = self.input_gru(embeds, hidden)
            real_hidden = []

            for i, o in enumerate(outputs): # B,T,D
                real_length = fact_mask[i].data.tolist().count(0) 
                real_hidden.append(o[real_length - 1])

            C.append(torch.cat(real_hidden).view(fact.size(0), -1).unsqueeze(0))
        
        encoded_facts = torch.cat(C) # B,T_C,D
        
        # Question Module
        embeds = self.embed(questions)
        if is_training:
            embeds = self.dropout(embeds)
        hidden = self.init_hidden(questions)
        outputs, hidden = self.question_gru(embeds, hidden)
        
        if isinstance(question_masks, torch.autograd.Variable):
            real_question = []
            for i, o in enumerate(outputs): # B,T,D
                real_length = question_masks[i].data.tolist().count(0) 
                real_question.append(o[real_length - 1])
            encoded_question = torch.cat(real_question).view(questions.size(0), -1) # B,D
        else: # for inference mode
            encoded_question = hidden.squeeze(0) # B,D
            
        # Episodic Memory Module
        memory = encoded_question
        T_C = encoded_facts.size(1)
        B = encoded_facts.size(0)
        for i in range(episodes):
            hidden = self.init_hidden(encoded_facts.transpose(0, 1)[0]).squeeze(0) # B,D
            for t in range(T_C):
                #TODO: fact masking
                #TODO: gate function => softmax
                z = torch.cat([
                                    encoded_facts.transpose(0, 1)[t] * encoded_question, # B,D , element-wise product
                                    encoded_facts.transpose(0, 1)[t] * memory, # B,D , element-wise product
                                    torch.abs(encoded_facts.transpose(0,1)[t] - encoded_question), # B,D
                                    torch.abs(encoded_facts.transpose(0,1)[t] - memory) # B,D
                                ], 1)
                g_t = self.gate(z) # B,1 scalar
                hidden = g_t * self.attention_grucell(encoded_facts.transpose(0, 1)[t], hidden) + (1 - g_t) * hidden
                
            e = hidden
            memory = self.memory_grucell(e, memory)
        
        # Answer Module
        answer_hidden = memory
        start_decode = Variable(LongTensor([[word2index['<s>']] * memory.size(0)])).transpose(0, 1)
        y_t_1 = self.embed(start_decode).squeeze(1) # B,D
        
        decodes = []
        for t in range(num_decode):
            answer_hidden = self.answer_grucell(torch.cat([y_t_1, encoded_question], 1), answer_hidden)
            decodes.append(F.log_softmax(self.answer_fc(answer_hidden),1))
        return torch.cat(decodes, 1).view(B * num_decode, -1)

In [3]:
import copy

def bAbI_data_load(path):
    try:
        data = open(path, 'r', encoding='utf8').readlines()
    except:
        print("Such a file does not exist at %s".format(path))
        return None
    
    data = [d[:-1] for d in data]
    data_p = []
    fact = []
    qa = []
    try:
        for d in data:
            index = d.split(' ')[0]
            if index == '1':
                fact = []
                qa = []
            if '?' in d:
                temp = d.split('\t')
                q = temp[0].strip().replace('?', '').split(' ')[1:] + ['?']
                a = temp[1].split() + ['</s>']
                stemp = copy.deepcopy(fact)
                data_p.append([stemp, q, a])
            else:
                tokens = d.replace('.', '').split(' ')[1:] + ['</s>']
                fact.append(tokens)
    except Exception as e:
        print(e)
        print("Please check the data is right")
        return None
    return data_p

In [28]:
path = "./data/train/[최종]졸업이수학점.txt"
#path = "C:/Users/82104/Desktop/4_1/산학연계/_산학연계/data/en/qa1_single-supporting-fact_train.txt"
test_data = bAbI_data_load('./data/test/[TEST]교양교과목안내.txt')
train_data = bAbI_data_load(path)

In [29]:
fact,q,a = list(zip(*train_data))

In [30]:
vocab = list(set(flatten(flatten(fact)) + flatten(q) + flatten(a)))

In [31]:
word2index={'<PAD>': 0, '<UNK>': 1, '<s>': 2, '</s>': 3}
for vo in vocab:
    if word2index.get(vo) is None:
        word2index[vo] = len(word2index)
index2word = {v:k for k, v in word2index.items()}

In [32]:
len(word2index)
word2index

{'<PAD>': 0,
 '<UNK>': 1,
 '<s>': 2,
 '</s>': 3,
 '': 4,
 '영역': 5,
 '들': 6,
 '영어회화': 7,
 '필수+균형': 8,
 '이수해야': 9,
 '못한': 10,
 '응': 11,
 '3학점': 12,
 '과목만': 13,
 '못하면': 14,
 '균형교양,': 15,
 '들어야하는': 16,
 '수': 17,
 '융합적사고와글쓰기가': 18,
 '영어,': 19,
 '22학점': 20,
 '해': 21,
 '돼': 22,
 '21학점': 23,
 '3영역x3학점': 24,
 '우리가': 25,
 '18학번': 26,
 '몇': 27,
 '사람만': 28,
 '정보가': 29,
 '인정돼': 30,
 '필수': 31,
 '의무': 32,
 '?': 33,
 '아니': 34,
 '있어': 35,
 '때': 36,
 '못': 37,
 '뭐야': 38,
 '정보,': 39,
 '17학번': 40,
 '소프트웨어학부': 41,
 '야': 42,
 '광운인되기,': 43,
 '들어도': 44,
 '19~22학점': 45,
 '24학점': 46,
 '영어회화가': 47,
 '5영역': 48,
 '9학점': 49,
 '이수학점': 50,
 '학점': 51,
 '교양': 52,
 '정보': 53,
 '필수교양': 54,
 '필수과목': 55,
 '졸업': 56,
 '통과하면': 57,
 '영어': 58,
 '20학번': 59,
 '정보융합학부': 60,
 '광운인되기': 61,
 '기초교양': 62,
 '융합적사고와글쓰기': 63,
 '대학영어는': 64,
 '들으면': 65,
 '는': 66,
 '통과하지': 67,
 '듣는거야': 68,
 '중': 69,
 '균형교양': 70,
 '대학영어': 71,
 '133학점': 72,
 '대학영어,': 73,
 '들어야': 74,
 '27학점': 75,
 '필수야': 76,
 '영어레벨테스트': 77,
 '뭐가': 78,
 '19학번': 79,
 '통과할': 80,
 '필

In [9]:
def pad_to_fact(fact, x_to_ix): # this is for inference
    
    max_x = max([s.size(1) for s in fact])
    x_p = []
    for i in range(len(fact)):
        if fact[i].size(1) < max_x:
            x_p.append(torch.cat([fact[i], Variable(LongTensor([x_to_ix['<PAD>']] * (max_x - fact[i].size(1)))).view(1, -1)], 1))
        else:
            x_p.append(fact[i])
        
    fact = torch.cat(x_p)
    fact_mask = torch.cat([Variable(ByteTensor(tuple(map(lambda s: s ==0, t.data))), volatile=False) for t in fact]).view(fact.size(0), -1)
    return fact, fact_mask

In [10]:
def prepare_sequence(seq, to_index):
    idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index["<UNK>"], seq))
    return Variable(LongTensor(idxs))

In [11]:
for t in train_data:
    for i,fact in enumerate(t[0]):
        t[0][i] = prepare_sequence(fact, word2index).view(1, -1)
    
    t[1] = prepare_sequence(t[1], word2index).view(1, -1)
    t[2] = prepare_sequence(t[2], word2index).view(1, -1)

In [12]:
for t in test_data:
    for i, fact in enumerate(t[0]):
        t[0][i] = prepare_sequence(fact, word2index).view(1, -1)
    
    t[1] = prepare_sequence(t[1], word2index).view(1, -1)
    t[2] = prepare_sequence(t[2], word2index).view(1, -1)

In [13]:
import torch.optim as optim

model = DMN(len(word2index), 80, len(word2index))
#model.hidden = model.init_hidden()
loss_function = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.001)

'''
model = torch.load('C:/Users/82104/Desktop/데이터/모델 데이터/model/[0601]crayon.pt')
model.load_state_dict(torch.load('C:/Users/82104/Desktop/데이터/모델 데이터/model/[0602]crayon.pth'))
for parameter in model.parameters():
    parameter.requires_grad = False
'''

checkpoint = torch.load('C:/Users/82104/Desktop/데이터/모델 데이터/model/[졸업이수학점0603]crayon.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
loss_function.state_dict(checkpoint['criterion_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.to("cpu")
model.eval()

#print(model)
#print(model.state_dict())
print(model)

DMN(
  (embed): Embedding(84, 80, padding_idx=0)
  (input_gru): GRU(80, 80, batch_first=True)
  (question_gru): GRU(80, 80, batch_first=True)
  (gate): Sequential(
    (0): Linear(in_features=320, out_features=80, bias=True)
    (1): Tanh()
    (2): Linear(in_features=80, out_features=1, bias=True)
    (3): Sigmoid()
  )
  (attention_grucell): GRUCell(80, 80)
  (memory_grucell): GRUCell(80, 80)
  (answer_grucell): GRUCell(160, 80)
  (answer_fc): Linear(in_features=80, out_features=84, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


In [14]:
accuracy = 0

for t in test_data:  
    fact, fact_mask = pad_to_fact(t[0], word2index)
    question = t[1]
    question_mask = Variable(ByteTensor([0] * t[1].size(1)), volatile=False).unsqueeze(0)
    answer = t[2].squeeze(0)
    
    model.zero_grad()
    pred = model([fact], [fact_mask], question, question_mask, answer.size(0), 3, True)
    if pred.max(1)[1].data.tolist() == answer.data.tolist():
        accuracy += 1
    #print("Facts : ")
    #print('\n'.join([' '.join(list(map(lambda x: index2word[x],f))) for f in fact.data.tolist()]))
    
    print("")
    print("Question : ",' '.join(list(map(lambda x: index2word[x], question.data.tolist()[0]))))
    print("")
    print("Answer : ",' '.join(list(map(lambda x: index2word[x], answer.data.tolist()))))
    print("Prediction : ",' '.join(list(map(lambda x: index2word[x], pred.max(1)[1].data.tolist()))))
    
print(accuracy/len(test_data) * 100)


Question :  17학번  영어회화  필수  들어야 해   ?

Answer :  아니 </s>
Prediction :  정보융합학부 </s>

Question :  17학번  영어회화가 필수  ?

Answer :  아니 </s>
Prediction :  정보융합학부 </s>

Question :  17학번  영어회화가 필수과목 야   ?

Answer :  아니 </s>
Prediction :  정보융합학부 </s>

Question :  18학번  영어회화  필수 들어야 해   ?

Answer :  아니 </s>
Prediction :  19학번 </s>

Question :  18학번  영어회화가 필수야   ?

Answer :  아니 </s>
Prediction :  19학번 </s>

Question :  18학번  영어회화가 필수과목 야   ?

Answer :  아니 </s>
Prediction :  19학번 </s>

Question :  영어레벨테스트  통과하지 못하면 대학영어 들어야 돼  ?

Answer :  응 </s>
Prediction :  정보융합학부 </s>

Question :  영어레벨테스트  못 통과하면 대학영어 들어야 돼  ?

Answer :  응 </s>
Prediction :  해 </s>

Question :  대학영어는 영어레벨테스트 못 통과할 때 듣는거야  ?

Answer :  응 </s>
Prediction :  정보융합학부 </s>

Question :  우리가 들어야하는 교양 영역  뭐야  ?

Answer :  필수교양 균형교양 기초교양 </s>
Prediction :  통과하면 </s> </s> </s>

Question :  우리가 들어야하는 교양 영역  뭐가 있어  ?

Answer :  필수교양 균형교양 기초교양 </s>
Prediction :  통과하면 </s> </s> </s>

Question :  17학번 필수교양 는 뭐가 있어  ?

Answer :  광운인되기 영어 정보 </s

In [15]:
s = "17학번 컴퓨터정보공학부 졸업이수학점 ?"
sl = []
sl_mask = []
for i in s.split():
    sl.append(word2index[i])
    sl_mask.append(0)
sl = torch.tensor([sl])
sl_mask = torch.tensor([sl_mask])

KeyError: '졸업이수학점'

In [None]:
pred = model([fact], [fact_mask], sl, sl_mask, 3, 3)

In [None]:
for i in pred.max(1)[1].data.tolist():
    w = index2word[i]
    print(w, end=' ')

In [None]:
index2word

In [None]:
for param in model.state_dict():
    print(param)