In [None]:
#from IPython.display import display, HTML
#display(HTML("<style>.container { width:100% !important; }</style>"))

**GPU 사용 설정**

In [None]:
# Restrict this kernel to GPUs 0-3. Takes effect only if set before CUDA is
# initialized in this process (i.e. before the first CUDA call).
%env CUDA_VISIBLE_DEVICES=0,1,2,3
# Required for deterministic cuBLAS kernels when torch.use_deterministic_algorithms(True) is enabled.
#%env CUBLAS_WORKSPACE_CONFIG=:4096:8

**Dataset 지정**

In [None]:
# Choose which preprocessed dataset to train/evaluate on:
#   "original" - the dataset from the paper
#   "ko"       - helper dataset
#   "1cycle"   - 1cycle dataset
dataset_name = "1cycle"

# The paper's dataset uses the unsuffixed embedding file; every other dataset
# has its own `embeddings_<name>.pkl`.
EMBEDDING_FILE = (
    "data/embeddings.pkl"
    if dataset_name == "original"
    else f"data/embeddings_{dataset_name}.pkl"
)

**Random seed**

In [None]:
import random

import numpy as np
import torch

random_seed = 42

# Seed every RNG this notebook draws from so runs are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(random_seed)
del _seed_fn

# Full determinism additionally needs the switches below (slower, and on CUDA
# also CUBLAS_WORKSPACE_CONFIG set before CUDA initializes):
#torch.use_deterministic_algorithms(True)
#torch.backends.cudnn.deterministic = True

**Model 정의**

In [None]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
import numpy as np
from torch.nn.parameter import Parameter
import pickle
from typing import Optional

class layer_normalization(nn.Module):
    '''Layer normalization over the last dimension (thin wrapper around ``nn.LayerNorm``).

    Args:
      features: Size of the last (normalized) dimension.
      epsilon: A small float added to the variance to prevent division by zero.
        Defaults to 1e-5, the value this layer has always used internally
        (previously the argument was accepted but ignored).
    '''

    def __init__(self, features, epsilon=1e-5):
        super(layer_normalization, self).__init__()
        self.epsilon = epsilon
        # Attribute name kept as ``layernorm`` so existing state_dict keys
        # (``...layernorm.weight`` / ``...layernorm.bias``) still load.
        self.layernorm = nn.LayerNorm(features, eps=epsilon)

    def forward(self, x):
        '''Normalize ``x`` over its last dimension and apply the learned affine transform.'''
        return self.layernorm(x)
    
class multihead_attention(nn.Module):
    '''Multi-head scaled dot-product attention with a residual connection.

    Q, K and V are each produced by ``Linear + ReLU``. Key positions whose
    input vector sums to exactly zero (treated as padding) are masked out
    before the softmax, and query positions whose input sums to zero get
    all-zero attention weights. The queries are added back as a residual.
    A layer-normalization module is constructed but not applied in ``forward``.
    '''

    def __init__(self, num_units, num_heads=8, dropout_rate=0, causality=False):
        '''Applies multihead attention.
        Args:
            num_units: A scalar. Attention size (channel dimension of queries/keys/values).
            num_heads: An int. Number of heads; must divide ``num_units`` evenly.
            dropout_rate: A floating point number. Dropout applied to the attention weights.
            causality: Boolean. If true, units that reference the future are masked.
        '''
        super(multihead_attention, self).__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.causality = causality
        self.Q_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())
        self.K_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())
        self.V_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())

        # TensorFlow-compatible initializer: Glorot-uniform weights, zero biases.
        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.zeros_(m.bias)
        self.Q_proj.apply(init_weights)
        self.K_proj.apply(init_weights)
        self.V_proj.apply(init_weights)

        self.output_dropout = nn.Dropout(p=self.dropout_rate)

        # Built but not applied in forward; kept so existing state_dict keys still load.
        self.normalization = layer_normalization(self.num_units)

    def forward(self, queries, keys, values):
        '''Attend from ``queries`` to ``keys``/``values``.

        Args:
            queries: Tensor of shape (N, T_q, C) with C == num_units.
            keys: Tensor of shape (N, T_k, C).
            values: Tensor of shape (N, T_k, C), same shape as ``keys``.

        Returns:
            Tensor of shape (N, T_q, C): the attention output plus ``queries``.
        '''
        # Linear projections
        Q = self.Q_proj(queries)  # (N, T_q, C)
        K = self.K_proj(keys)  # (N, T_k, C)
        V = self.V_proj(values)  # (N, T_k, C)

        # Split channels into heads and stack the heads along the batch axis
        Q_ = torch.cat(torch.chunk(Q, self.num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.chunk(K, self.num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.chunk(V, self.num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)

        # Scaled dot products
        outputs = torch.bmm(Q_, K_.transpose(1, 2))  # (h*N, T_q, T_k)
        outputs = outputs / (K_.size(-1) ** 0.5)

        # Large negative fill value for masked scores (-2**32 + 1, as in the
        # original TF code), clamped so it stays finite in low-precision dtypes.
        pad_value = max(float(-2 ** 32 + 1), torch.finfo(outputs.dtype).min)

        # Key masking: keys whose vector sums to zero are treated as padding.
        # Created on the scores' device, so this works on CPU and any GPU.
        key_masks = torch.sign(torch.abs(torch.sum(keys, dim=-1)))  # (N, T_k)
        key_masks = key_masks.repeat(self.num_heads, 1)  # (h*N, T_k), matches head-major stacking
        outputs = outputs.masked_fill(key_masks.eq(0.).unsqueeze(1), pad_value)  # broadcast over T_q

        # Causality = Future blinding
        if self.causality:
            tril = torch.tril(torch.ones(outputs.size(1), outputs.size(2), device=outputs.device))  # (T_q, T_k)
            outputs = outputs.masked_fill(tril.eq(0.), pad_value)  # broadcast over h*N

        # Activation
        outputs = F.softmax(outputs, dim=-1)  # (h*N, T_q, T_k)

        # Query masking: padded queries get all-zero attention weights
        query_masks = torch.sign(torch.abs(torch.sum(queries, dim=-1)))  # (N, T_q)
        query_masks = query_masks.repeat(self.num_heads, 1)  # (h*N, T_q)
        outputs = outputs * query_masks.unsqueeze(2)  # broadcast over T_k

        # Dropouts
        outputs = self.output_dropout(outputs)  # (h*N, T_q, T_k)

        # Weighted sum
        outputs = torch.bmm(outputs, V_)  # (h*N, T_q, C/h)

        # Restore shape
        outputs = torch.cat(torch.chunk(outputs, self.num_heads, dim=0), dim=2)  # (N, T_q, C)

        # Residual connection
        outputs += queries

        return outputs

class feedforward(nn.Module):
    '''Point-wise feed-forward network with a residual connection and layer normalization.

    Computes ``LayerNorm(W2 · ReLU(W1 · x + b1) + b2 + x)`` independently at
    every position. The residual addition requires ``num_units[1] == in_channels``.
    '''

    def __init__(self, in_channels, num_units=(2048, 512)):
        '''Point-wise feed forward net.
        Args:
          in_channels: Number of channels (last dimension) of the inputs.
          num_units: Two integers: the hidden width and the output width.
            Copied into a list stored as ``self.num_units``.
        '''
        super(feedforward, self).__init__()
        self.in_channels = in_channels
        # Fresh list copy: no shared mutable default, same attribute type as before.
        self.num_units = list(num_units)

        # nn.Linear is faster than nn.Conv1d; the kernel-size-1 Conv1d variant
        # is kept for reference but disabled.
        self.conv = False
        if self.conv:
            params = {'in_channels': self.in_channels, 'out_channels': self.num_units[0],
                      'kernel_size': 1, 'stride': 1, 'bias': True}
            self.conv1 = nn.Sequential(nn.Conv1d(**params), nn.ReLU())
            params = {'in_channels': self.num_units[0], 'out_channels': self.num_units[1],
                      'kernel_size': 1, 'stride': 1, 'bias': True}
            self.conv2 = nn.Conv1d(**params)
        else:
            self.conv1 = nn.Sequential(nn.Linear(self.in_channels, self.num_units[0]), nn.ReLU())
            self.conv2 = nn.Linear(self.num_units[0], self.num_units[1])

        # TensorFlow-compatible initializer: Glorot-uniform weights, zero biases.
        def init_weights(m):
            if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.zeros_(m.bias)
        self.conv1.apply(init_weights)
        self.conv2.apply(init_weights)

        self.normalization = layer_normalization(self.in_channels)

    def forward(self, inputs):
        '''Apply the two layers, add the residual and layer-normalize.

        Args:
          inputs: Tensor whose last dimension is ``in_channels``
            (e.g. (N, T, C)).

        Returns:
          Tensor of the same shape as ``inputs``.
        '''
        if self.conv:
            # Conv1d expects channels first: (N, C, T)
            inputs = inputs.permute(0, 2, 1)
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)

        # Residual connection
        outputs += inputs

        # Layer normalization (over channels, so restore channels-last first)
        if self.conv:
            outputs = self.normalization(outputs.permute(0, 2, 1))
        else:
            outputs = self.normalization(outputs)

        return outputs

class ScriptWriter_cpre(nn.Module):
    '''ScriptWriter-CPRE response-selection model.

    Scores a candidate ``response`` against the context turns (``utterance``)
    and a ``narrative``. It also compares two attention distributions over the
    (progressively down-weighted) narrative: one induced by the last context
    turn (p1) and one by the ground-truth response (p2). The training loss is
    ``eta * BCE(logits, labels) + (1 - eta) * mean_KL(p1 || p2)``.
    '''

    def __init__(
        self,
        eta=0.7,
        max_sentence_len = 50,
        max_num_utterance = 11,
        embedding_file = EMBEDDING_FILE,
    ):
        '''
        Args:
          eta: Weight of the response-selection (BCE) loss; ``1 - eta`` weighs the KL loss.
          max_sentence_len: Token length that every utterance turn, response and
            narrative must have (the dense/conv sizes below are derived from it).
          max_num_utterance: Number of context turns along dim 1 of ``utterance``.
          embedding_file: Path of a pickled word-embedding matrix. Assumed to
            have ``hidden_units`` (200) columns — TODO confirm. Note that the
            default is bound when this class is defined, not when it is instantiated.
        '''
        super().__init__()
        self.max_num_utterance = max_num_utterance
        self.negative_samples = 1
        self.max_sentence_len = max_sentence_len
        self.hidden_units = 200  # word embedding size
        self.dropout_rate = 0
        self.num_heads = 1
        self.num_blocks = 3
        self.eta = eta
        self.gamma = nn.Parameter(torch.tensor(0.5), requires_grad=True)

        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        # The context manager closes the file handle (it used to be leaked).
        with open(embedding_file, 'rb') as f:
            word_emb = pickle.load(f, encoding="bytes")
        word_emb = torch.FloatTensor(word_emb)
        self.embedding = nn.Embedding.from_pretrained(word_emb, freeze=True)

        def new_attention():
            return multihead_attention(
                num_units=self.hidden_units,
                num_heads=self.num_heads,
                dropout_rate=self.dropout_rate,
                causality=False)

        def new_feedforward():
            return feedforward(self.hidden_units, [self.hidden_units, self.hidden_units])

        # Self-attention encoder blocks shared by responses, narrative and turns.
        for i in range(self.num_blocks):
            self.__setattr__('self_multihead_attention_%d' % i, new_attention())
            self.__setattr__('self_feedforward_%d' % i, new_feedforward())

        # Cross-attention blocks for every direction (r<->u, n<->u, n<->r) and
        # every encoder depth (raw embeddings + num_blocks outputs). Registration
        # order is identical to the original hand-unrolled code, so seeded
        # initialization and state_dict ordering are unchanged.
        for i in range(self.num_blocks + 1):
            for prefix in ('ru', 'ur', 'nu', 'un', 'nr', 'rn'):
                self.__setattr__('%s_multihead_attention_%d' % (prefix, i), new_attention())
                self.__setattr__('%s_feedforward_%d' % (prefix, i), new_feedforward())

        self.n_dense = nn.Linear(self.hidden_units, self.hidden_units)
        torch.nn.init.xavier_uniform_(self.n_dense.weight)
        torch.nn.init.zeros_(self.n_dense.bias)
        # Collapse the token axis of the last turn / ground-truth response to one vector.
        self.lastu_dense = nn.Linear(self.max_sentence_len, 1)
        torch.nn.init.xavier_uniform_(self.lastu_dense.weight)
        torch.nn.init.zeros_(self.lastu_dense.bias)
        self.lastur_dense = nn.Linear(self.max_sentence_len, 1)
        torch.nn.init.xavier_uniform_(self.lastur_dense.weight)
        torch.nn.init.zeros_(self.lastur_dense.bias)

        # Pool padding is 1 when size % 3 != 0 and 0 otherwise, so every
        # kernel-3/stride-3 pool yields (size + 2) // 3 -- the formula used for
        # the flatten sizes below.
        depth = self.max_num_utterance  # 11 with the defaults
        height = self.max_sentence_len  # 50
        width = self.max_sentence_len  # 50
        padding = ((depth%3 + 1)//2, (height%3 + 1)//2, (width%3 + 1)//2,)
        # Input channels: 2 * (num_blocks + 1) stacked similarity maps.
        conv3d_1_layer = nn.Conv3d((self.num_blocks+1)*2, 32, 3, padding='same')
        nn.init.uniform_(conv3d_1_layer.weight, -0.01, 0.01)
        nn.init.zeros_(conv3d_1_layer.bias)
        self.conv3d_1 = torch.nn.Sequential(conv3d_1_layer, torch.nn.ELU())
        self.maxpool3d_1 = torch.nn.MaxPool3d(3, padding=padding)

        depth = (self.max_num_utterance+2)//3  # 4 with the defaults
        height = (self.max_sentence_len+2)//3  # 17
        width = (self.max_sentence_len+2)//3  # 17
        padding = ((depth%3 + 1)//2, (height%3 + 1)//2, (width%3 + 1)//2,)
        conv3d_2_layer = nn.Conv3d(32, 32, 3, padding='same')
        nn.init.uniform_(conv3d_2_layer.weight, -0.01, 0.01)
        nn.init.zeros_(conv3d_2_layer.bias)
        self.conv3d_2 = torch.nn.Sequential(conv3d_2_layer, torch.nn.ELU())
        self.maxpool3d_2 = torch.nn.MaxPool3d(3, padding=padding)
        mur_flatten_size = ((depth+2)//3)*((height+2)//3)*((width+2)//3)*32  # 2*6*6*32 with the defaults

        height = self.max_sentence_len  # 50
        width = self.max_sentence_len  # 50
        padding = ((height%3 + 1)//2, (width%3 + 1)//2)
        conv2d_1_layer = nn.Conv2d((self.num_blocks+1)*2, 32, 3, padding='same')
        nn.init.uniform_(conv2d_1_layer.weight, -0.01, 0.01)
        nn.init.zeros_(conv2d_1_layer.bias)
        self.conv2d_1 = torch.nn.Sequential(conv2d_1_layer, torch.nn.ELU())
        self.maxpool2d_1 = torch.nn.MaxPool2d(3, padding=padding)

        height = (self.max_sentence_len+2)//3  # 17
        width = (self.max_sentence_len+2)//3  # 17
        padding = ((height%3 + 1)//2, (width%3 + 1)//2)
        conv2d_2_layer = nn.Conv2d(32, 32, 3, padding='same')
        nn.init.uniform_(conv2d_2_layer.weight, -0.01, 0.01)
        nn.init.zeros_(conv2d_2_layer.bias)
        self.conv2d_2 = torch.nn.Sequential(conv2d_2_layer, torch.nn.ELU())
        self.maxpool2d_2 = torch.nn.MaxPool2d(3, padding=padding)

        # Two 3D branches (Mur, Mun) plus the 2D branch (Mrn).
        total_flatten_size = mur_flatten_size*2 + ((height+2)//3)*((width+2)//3)*32

        self.logits_dense = nn.Linear(total_flatten_size, 1)
        nn.init.orthogonal_(self.logits_dense.weight)
        nn.init.zeros_(self.logits_dense.bias)

        self.bcewithlogitsloss = nn.BCEWithLogitsLoss()

    def _encode(self, token_ids):
        '''Embed ``token_ids`` and run the shared self-attention encoder.

        Returns a list of ``num_blocks + 1`` tensors: the raw embeddings
        followed by the output of each encoder block.
        '''
        x = self.embedding(token_ids)
        stack = [x]
        for i in range(self.num_blocks):
            x = self.__getattr__('self_multihead_attention_%d' % i)(x, x, x)
            x = self.__getattr__('self_feedforward_%d' % i)(x)
            stack.append(x)
        return stack

    def _cross_attend(self, a_to_b, b_to_a, a_stack, b_stack):
        '''Cross-attend two representation stacks in both directions, level by level.

        ``a_to_b`` / ``b_to_a`` are the name prefixes of the attention and
        feed-forward modules used for "a attends to b" and "b attends to a".
        Each returned tensor stacks, on a new last axis, the ``num_blocks + 1``
        attended representations followed by the ``num_blocks + 1`` inputs.
        '''
        a_att_b, b_att_a = [], []
        for i in range(self.num_blocks + 1):
            x = self.__getattr__('%s_multihead_attention_%d' % (a_to_b, i))(
                a_stack[i], b_stack[i], b_stack[i])
            a_att_b.append(self.__getattr__('%s_feedforward_%d' % (a_to_b, i))(x))
            y = self.__getattr__('%s_multihead_attention_%d' % (b_to_a, i))(
                b_stack[i], a_stack[i], a_stack[i])
            b_att_a.append(self.__getattr__('%s_feedforward_%d' % (b_to_a, i))(y))
        a_att_b.extend(a_stack)
        b_att_a.extend(b_stack)
        return torch.stack(a_att_b, dim=-1), torch.stack(b_att_a, dim=-1)

    def _match_3d(self, M):
        '''Conv3d/MaxPool3d features of a [batch, turns, len, len, channels] matching tensor, flattened.'''
        x = self.conv3d_1(M.permute(0, 4, 1, 2, 3))  # channels-first for Conv3d
        x = self.maxpool3d_1(x)
        x = self.conv3d_2(x)
        x = self.maxpool3d_2(x)
        return torch.flatten(x.permute(0, 2, 3, 4, 1), start_dim=1)  # flatten in channels-last order

    def _match_2d(self, M):
        '''Conv2d/MaxPool2d features of a [batch, len, len, channels] matching tensor, flattened.'''
        x = self.conv2d_1(M.permute((0, 3, 1, 2)))  # channels-first for Conv2d
        x = self.maxpool2d_1(x)
        x = self.conv2d_2(x)
        x = self.maxpool2d_2(x)
        return torch.flatten(x.permute(0, 2, 3, 1), start_dim=1)  # flatten in channels-last order

    def forward(
        self,
        idx: Optional[torch.Tensor] = None,
        response: Optional[torch.Tensor] = None,
        gt_response: Optional[torch.Tensor] = None,
        narrative: Optional[torch.Tensor] = None,
        utterance: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        '''Score ``response`` and compute the combined training loss.

        Args:
          idx: Unused by this method.
          response: Candidate response token ids, padded to ``max_sentence_len``.
          gt_response: Ground-truth response token ids (used only for the KL term).
          narrative: Narrative token ids, padded to ``max_sentence_len``.
          utterance: Context token ids; dim 1 indexes the ``max_num_utterance`` turns.
          labels: Float binary targets, one per batch element (BCE-with-logits).
          return_dict: Unused by this method.

        Returns:
          dict with ``'loss'`` (shape ``(1,)``) and ``'y_pred'`` (sigmoid scores, shape ``(batch,)``).
        '''
        # sqrt(hidden_units) = sqrt(200): keeps similarity magnitudes (and gradients) in check.
        scale = torch.sqrt(torch.tensor(float(self.hidden_units)))

        Hr_stack = self._encode(response)
        Hgtr_stack = self._encode(gt_response)
        Hn_stack = self._encode(narrative)

        Mur, Mun = [], []
        self.decay_factor = []
        last_u_reps = []

        for turn_id, turn in enumerate(torch.unbind(utterance, dim=1)):
            Hu_stack = self._encode(turn)
            if turn_id == self.max_num_utterance - 1:
                last_u_reps = Hu_stack

            r_a_u, u_a_r = self._cross_attend('ru', 'ur', Hr_stack, Hu_stack)
            n_a_u, u_a_n = self._cross_attend('nu', 'un', Hn_stack, Hu_stack)

            # sim shape [batch, max_sent_len, max_sent_len, 2 * (num_blocks + 1)]
            sim_ur = torch.einsum('biks,bjks->bijs', u_a_r, r_a_u) / scale
            sim_un = torch.einsum('biks,bjks->bijs', u_a_n, n_a_u) / scale

            # Down-weight the narrative by its similarity to this turn before the next turn.
            Hn_stack_tensor = torch.stack(Hn_stack, dim=-1)  # [batch, o_len, embedding_size, stack]
            # dim=None: L2-normalize over the whole tensor, not per row.
            self_n = nn.functional.normalize(Hn_stack_tensor, p=2, dim=None)
            self_u = nn.functional.normalize(torch.stack(Hu_stack, dim=-1), p=2, dim=None)
            self_sim = torch.einsum('biks,bjks->bijs', self_u, self_n)  # [batch, u_len, o_len, stack]
            self_sim = 1 - self.gamma * torch.sum(self_sim, dim=1)  # [batch, o_len, stack]
            Hn_stack = torch.unbind(
                torch.einsum('bjkl,bjl->bjkl', Hn_stack_tensor, self_sim), dim=-1)

            Mur.append(sim_ur)
            Mun.append(sim_un)

        # Narrative tracking: attention over the updated narrative induced by
        # the last context turn (p1) and by the ground-truth response (p2).
        Hn_stack_for_tracking = self.n_dense(torch.stack(Hn_stack, dim=2))  # [batch, o_len, stack, embedding_size]
        Hn_stack_for_tracking = Hn_stack_for_tracking.permute((0, 1, 3, 2))  # [batch, o_len, embedding_size, stack]
        Hlastu_stack_for_tracking = torch.stack(last_u_reps, dim=-1)  # [batch, u_len, embedding_size, stack]
        Hr_stack_for_tracking = torch.stack(Hgtr_stack, dim=-1)  # [batch, r_len, embedding_size, stack]
        Hlastu = Hlastu_stack_for_tracking.permute((0, 2, 3, 1))  # [batch, embedding_size, stack, u_len]
        Hlastu = torch.squeeze(self.lastu_dense(Hlastu), dim=-1)  # [batch, embedding_size, stack]
        p1_tensor = nn.functional.softmax(
            torch.einsum('bnds,bds->bns', Hn_stack_for_tracking, Hlastu), dim=1)  # [batch, o_len, stack]
        Hlastur = Hr_stack_for_tracking.permute((0, 2, 3, 1))  # [batch, embedding_size, stack, r_len]
        Hlastur = torch.squeeze(self.lastur_dense(Hlastur), dim=-1)  # [batch, embedding_size, stack]
        p2_tensor = nn.functional.softmax(
            torch.einsum('bnds,bds->bns', Hn_stack_for_tracking, Hlastur), dim=1)  # [batch, o_len, stack]
        p1 = torch.unbind(p1_tensor, dim=-1)
        p2 = torch.unbind(p2_tensor, dim=-1)

        n_a_r, r_a_n = self._cross_attend('nr', 'rn', Hn_stack, Hr_stack)
        Mrn = torch.einsum('biks,bjks->bijs', n_a_r, r_a_n) / scale
        self.rosim = Mrn
        Mur = torch.stack(Mur, dim=1)  # [batch, turns, len, len, channels]
        Mun = torch.stack(Mun, dim=1)

        mur = self._match_3d(Mur)
        mun = self._match_3d(Mun)
        mrn = self._match_2d(Mrn)

        # KL(p1 || p2) per encoder depth, averaged; eps guards log(0).
        eps = 1e-7
        KL_loss = 0.0
        for i in range(self.num_blocks + 1):
            KL_loss += torch.mean(nn.functional.kl_div((p2[i]+eps).log(), p1[i], reduction='batchmean'))
        KL_loss /= (self.num_blocks + 1)

        all_vector = torch.cat([mur, mun, mrn], dim=-1)
        logits = torch.reshape(self.logits_dense(all_vector), shape=(-1,))
        y_pred = torch.sigmoid(logits)

        y_true = labels
        RS_loss = torch.mean(torch.clip(self.bcewithlogitsloss(logits, y_true), -10, 10))
        loss = self.eta * RS_loss + (1 - self.eta) * KL_loss
        # Shape (1,) so per-device losses can be concatenated under multi-GPU
        # (the trainer then averages them with loss.mean()).
        loss = torch.unsqueeze(loss, dim=0)

        return {
            'loss': loss,
            'y_pred': y_pred,
        }

In [None]:
# Pass the embedding path explicitly: the constructor's default was bound when
# the class cell ran and would go stale if only the dataset cell is re-run.
model = ScriptWriter_cpre(embedding_file=EMBEDDING_FILE)

In [None]:
from transformers import Trainer, TrainingArguments, TrainerCallback
from typing import Dict, Union, Any
from transformers.utils import  is_apex_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled

class ScriptwriterTrainer(Trainer):
    '''``transformers.Trainer`` that uses the ``'loss'`` computed by the model itself.

    ``training_step`` mirrors the Trainer implementation of the transformers
    version this notebook targets (it relies on ``do_grad_scaling``,
    ``use_apex`` and ``deepspeed``). The optional SageMaker-MP and Apex helpers
    are imported lazily on their own code paths, so those branches no longer
    raise NameError.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Not used by compute_loss (the model computes its own loss); kept for
        # code that may still reference these attributes.
        self.bcewithlogitsloss = nn.BCEWithLogitsLoss()
        self.kl_loss = nn.KLDivLoss(reduction = 'batchmean')

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        '''Return the loss produced by ``model``.

        Args:
          model: The (possibly wrapped) model.
          inputs: Batch dict, unpacked as keyword arguments into ``model``.
          return_outputs: If True, return ``(loss, outputs)`` instead of ``loss``.
          num_items_in_batch: Accepted for compatibility with newer Trainer
            versions that pass it; ignored.

        Returns:
          ``outputs['loss']`` for dict outputs, otherwise ``outputs[0]``.
        '''
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.
        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. The dictionary is unpacked
                before being fed to the model; targets go under `labels`.
            num_items_in_batch:
                Accepted for compatibility with newer Trainer versions; ignored.
        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            # Only defined by transformers when SageMaker model parallelism is available.
            from transformers.trainer_pt_utils import smp_forward_backward
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            from apex import amp  # optional dependency, only needed on this path
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()

import Evaluate

def compute_metrics(evalpred):
    """Convert Trainer eval predictions into the ranking metrics dict.

    Args:
        evalpred: a `(predictions, labels)` pair, forwarded to `Evaluate.evaluate_all`.

    Returns:
        Dict with the first six entries of `evaluate_all`'s result under the names
        below, plus "AvgScore": the mean of entries 1..5 (accuracy excluded).
    """
    preds, labels = evalpred
    result = Evaluate.evaluate_all(preds, labels)

    metric_names = ("accuracy", "r2@1", "r10@1", "r10@2", "r10@5", "mrr")
    metrics = {name: result[pos] for pos, name in enumerate(metric_names)}
    metrics["AvgScore"] = (result[1] + result[2] + result[3] + result[4] + result[5]) / 5.0
    return metrics

In [None]:
from datasets import load_dataset
# Load all splits (e.g. "train", "validation") of the selected `story_data` config.
datasets = load_dataset(path='story_data', name=dataset_name)

In [None]:
model_checkpoint = "checkpoint"
training_args = TrainingArguments(
    output_dir=f"{model_checkpoint}-{dataset_name}",

    # Optimisation. The Adam optimizer and LR scheduler are built explicitly and
    # handed to the Trainer via `optimizers=`, so only gradient clipping is set here.
    max_grad_norm=5.0,  # from the original source
    num_train_epochs=4,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128 * 8,

    # Evaluation / checkpointing. For load_best_model_at_end the save strategy must
    # match the evaluation strategy, and save_steps must be a multiple of eval_steps.
    # NOTE: newer transformers releases rename `evaluation_strategy` to `eval_strategy`.
    evaluation_strategy="steps",
    eval_steps=50 if dataset_name == "original" else 100,  # 25000/(batch*GPU) ~= 50 for original; 100 for the 246578-sample helper set
    save_strategy="steps",
    save_steps=50 if dataset_name == "original" else 100,
    save_total_limit=2,  # with load_best_model_at_end the best checkpoint is kept in addition to the latest
    load_best_model_at_end=True,
    metric_for_best_model="AvgScore",

    # Logging / misc.
    logging_steps=80,  # NOTE(review): reportedly had no visible effect - TODO investigate
    logging_first_step=True,  # NOTE(review): same as logging_steps
    report_to="none",
    push_to_hub=False,
    seed=random_seed,
    log_level='error',
)

In [None]:
from transformers import DefaultDataCollator
from datasets import load_dataset

# `datasets` is already loaded with exactly these arguments in the dataset cell above
# (load_dataset('story_data', dataset_name)); re-loading it here only repeated the work.
data_collator = DefaultDataCollator()

In [None]:
class patience_scheduler(torch.optim.lr_scheduler.LambdaLR):
    """LambdaLR that only advances when an LR callback explicitly asks it to.

    The Trainer calls ``scheduler.step()`` after every optimizer step; those calls
    are ignored. Each ``step(from_callback=True)`` advances the schedule by one.
    The lambda is ``0.5 ** (step + 1)``, so with the default ``last_epoch=-1`` the
    k-th callback-driven step sets every param group's lr to ``base_lr * 0.5 ** k``.

    Args:
        optimizer: the wrapped optimizer.
        last_epoch: forwarded to LambdaLR (it was previously ignored).
        verbose: forwarded to LambdaLR only when truthy (it was previously ignored).
            It defaults to off, and newer torch releases deprecate the argument.
    """

    def __init__(self, optimizer, last_epoch=-1, verbose=False):
        # Kept for backward compatibility: lr of the first param group at construction.
        self.lr = optimizer.param_groups[0]['lr']
        # The base constructor performs an initial step(), which the override below
        # swallows. Without this, _last_lr would never be set and get_last_lr() would
        # fail. Record every param group, not just the first.
        self._last_lr = [group['lr'] for group in optimizer.param_groups]

        def lr_lambda(step):
            # The initial construction-time step is suppressed, so the first
            # callback-driven step arrives with step == 0; the +1 makes it halve the lr.
            return 0.5 ** (step + 1)

        if verbose:
            super().__init__(optimizer, lr_lambda, last_epoch=last_epoch, verbose=verbose)
        else:
            super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)

    def step(self, from_callback=False):
        """Advance the schedule only when called from the LR callback; otherwise no-op."""
        if not from_callback:
            return
        super().step()
        print('lr changed to:', self.optimizer.param_groups[0]['lr'])

class LrCallback(TrainerCallback):
    """After each evaluation, step the LR scheduler once the metric stops improving.

    Tracks ``metrics['eval_' + metric_for_best_model]`` (higher is better, starting
    from a best score of -1). After ``patience_limit`` or more consecutive
    non-improving evaluations it calls ``lr_scheduler.step(from_callback=True)``,
    which is the hook patience_scheduler advances on.

    NOTE(review): ``patience`` is not reset after a decay, so every further
    non-improving evaluation steps the scheduler again. Confirm this is intended.
    """

    def __init__(self, patience_limit=3):
        self.best_score = -1
        self.patience = 0
        # Number of consecutive non-improving evaluations before the LR is decayed.
        self.patience_limit = patience_limit

    def on_evaluate(self, args, state, control, **kwargs):
        # Trainer reports eval metrics with an "eval_" prefix. Accept
        # metric_for_best_model with or without it, as Trainer itself does;
        # prepending unconditionally would look up "eval_eval_...".
        name = args.metric_for_best_model
        if not name.startswith('eval_'):
            name = 'eval_' + name
        score = kwargs['metrics'][name]
        if score > self.best_score:
            self.best_score = score
            self.patience = 0
        else:
            self.patience += 1
            if self.patience >= self.patience_limit:
                lr_scheduler = kwargs['lr_scheduler']
                lr_scheduler.step(from_callback=True)

# Passed to the Trainer below via `optimizers=(optimizer, scheduler)`. The scheduler
# only changes the LR when LrCallback calls step(from_callback=True).
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.98),
    eps=1e-8,
)
scheduler = patience_scheduler(optimizer=optimizer)

In [None]:
# NOTE(review): compute_metrics and LrCallback are commented out, presumably because this
# notebook only loads a checkpoint and runs trainer.predict(). Training with
# metric_for_best_model="AvgScore" needs compute_metrics (it is what produces AvgScore),
# and without LrCallback the patience scheduler never decays the LR.
trainer = ScriptwriterTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=datasets["train"],
    eval_dataset=datasets["validation"],
    optimizers=(optimizer, scheduler),
    #compute_metrics=compute_metrics,
    #callbacks=[LrCallback],
)

**Loading model**

In [None]:
# Load the checkpoint for the currently selected dataset. The directory matches the
# TrainingArguments output dir f"{model_checkpoint}-{dataset_name}"; it was previously
# hard-coded to the 1cycle run, so switching dataset_name silently loaded the wrong model.
# NOTE(review): Trainer itself writes `checkpoint-<step>` directories, so `checkpoint-best`
# is presumably created by hand. `_load_from_checkpoint` is a private Trainer API.
trainer._load_from_checkpoint(f"{model_checkpoint}-{dataset_name}/checkpoint-best")

In [None]:
#metrics = trainer.evaluate(eval_dataset=datasets["test"])
#print(metrics)

In [None]:
from transformers import AutoTokenizer
# Used by ids_to_text to join decoded tokens back into a string.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='monologg/kobigbird-bert-base')

In [None]:
from gensim.models import Word2Vec

# Vocabulary used to map between tokens and ids (see EOS_ID/UNK_ID and ids_to_text).
word2vec = Word2Vec.load('data/word2vec_{}.model'.format(dataset_name))

In [None]:
# Ids in this notebook are word2vec vocabulary indices shifted by +1 (ids_to_text undoes
# the shift with `i - 1`); 0 is the padding value used by the pad helpers.
EOS_ID = 1 + word2vec.wv.key_to_index['[SEP]']
UNK_ID = 1 + word2vec.wv.key_to_index['[UNK]']

In [None]:
# NOTE: pickle.load can execute arbitrary code - only load files produced by this project.
# positive_data[i] unpacks as (utterances, narrative, _) of token ids;
# positive_str[i]['script'] holds the matching utterance strings.
with open(f'data/positive_{dataset_name}.pkl', "rb") as fh:
    positive_data = pickle.load(fh)
with open(f'data/positive_str_{dataset_name}.pkl', "rb") as fh:
    positive_str = pickle.load(fh)

In [None]:
# Candidate pool for retrieval: every utterance of every sample (ids and strings),
# followed by a single end-of-script candidate.
all_utterances = []
all_utterances_str = []
for i, (utterances, narrative, _) in enumerate(positive_data):
    # Drop a trailing end-of-script marker; one [EOS_ID] candidate is appended once below.
    # The emptiness check avoids an IndexError on a sample with no utterances.
    if utterances and utterances[-1] == [EOS_ID]:
        utterances = utterances[:-1]
    all_utterances += utterances
    all_utterances_str += positive_str[i]['script']
all_utterances += [[EOS_ID]]
all_utterances_str += ['[SEP]']

# The prediction loop maps ranked candidate indices of all_utterances onto
# all_utterances_str, so the two lists must stay index-aligned.
# NOTE(review): the trailing [EOS_ID] is dropped from the ids above, but nothing is
# dropped from the strings. If the scripts also end with '[SEP]', this assertion fires.
assert len(all_utterances) == len(all_utterances_str), (
    f"utterance ids ({len(all_utterances)}) and strings ({len(all_utterances_str)}) are misaligned"
)

In [None]:
# Sequential 90% / 5% / remainder split in file order (no shuffling).
# NOTE(review): presumably this must match the split used when the model was trained - confirm.
train_num = int(len(positive_data) * 0.9)
dev_test_num = int(len(positive_data) * 0.05)
train = positive_data[:train_num]
dev = positive_data[train_num:train_num + dev_test_num]
test = positive_data[train_num + dev_test_num:]

In [None]:
from torch.utils.data import Dataset
class StoryDataset(Dataset):
    """Map-style dataset over the padded 5-tuple built by pad_process (or by hand).

    Args:
        x: ``(utterance, response, narrative, gt_response, label)`` array-likes that
            share the same first (sample) axis. utterance becomes an int32 tensor,
            label a float32 tensor, and the other three keep the dtype that
            ``torch.tensor`` infers.

    Items are dicts keyed by the five field names, each indexed at the sample.
    """

    def __init__(self, x):
        utterance, response, narrative, gt_response, label = x
        self.utterance = torch.IntTensor(utterance)
        self.response = torch.tensor(response)
        self.narrative = torch.tensor(narrative)
        self.gt_response = torch.tensor(gt_response)
        self.label = torch.FloatTensor(label)
        self.n = len(self.label)

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        field_names = ("utterance", "response", "narrative", "gt_response", "label")
        return {field: getattr(self, field)[index] for field in field_names}

In [None]:
def make_list(context, narrative, all_utterances):
    """Build one sample per candidate utterance for a fixed context and narrative.

    Each sample is ``[context, candidate, narrative, candidate, 0]``: the candidate
    fills both the response and gt_response slots, and the label is 0.
    """
    return [[context, candidate, narrative, candidate, 0] for candidate in all_utterances]

In [None]:
def get_numpy_from_nonfixed_2d_array(aa, max_sentence_len=50, max_num_utterance=10, padding_value=0):
    """Pack the last turns of a dialogue into a fixed-size integer id matrix.

    Args:
        aa: sequence of token-id sequences (one per turn).
        max_sentence_len: every turn is right-padded or truncated to this length.
        max_num_utterance: only the last ``max_num_utterance`` turns are kept.
        padding_value: id used for padded positions and for the padding rows.
            Previously the padding rows always used 0.

    Returns:
        Integer array of shape ``(max_num_utterance + 1, max_sentence_len)``. The kept
        turns come first, followed by padding rows. The last row is always padding
        (the extra "+1" empty sentence).
    """
    # Preallocate instead of growing with np.append (which copies on every call). This
    # also keeps the dtype integer even when a turn is empty (np.pad([]) gives float64).
    rows = np.full((max_num_utterance + 1, max_sentence_len), padding_value, dtype='int')
    # `aa[-0:]` would keep every turn, so handle max_num_utterance == 0 explicitly.
    kept = aa[-max_num_utterance:] if max_num_utterance > 0 else []
    for row, sentence in zip(rows, kept):
        truncated = sentence[:max_sentence_len]
        row[:len(truncated)] = truncated
    return rows

def get_numpy_from_nonfixed_1d_array(a, max_sentence_len=50, padding_value=np.int_(0)):
    """Right-pad or truncate one token-id sequence to exactly ``max_sentence_len`` ids.

    Args:
        a: sequence of token ids (list or 1-D array).
        max_sentence_len: output length.
        padding_value: id used to fill the tail of short sequences.

    Returns:
        A new integer numpy array of length ``max_sentence_len``.
    """
    # Convert with an explicit integer dtype. np.pad on an empty Python list yields a
    # float64 array, which would silently upcast every row it is later np.stack-ed with.
    ids = np.array(a[:max_sentence_len], dtype='int')
    if len(ids) < max_sentence_len:
        return np.pad(ids, (0, max_sentence_len - len(ids)), 'constant', constant_values=padding_value)
    return ids

#from tqdm.notebook import tqdm

def pad_process(data, max_sentence_len=50, max_num_utterance=10):
    """Pad a list of samples into five stacked numpy arrays.

    Each sample ``unit`` is indexed as ``(context, response, narrative, gt_response, label)``
    (the layout make_list produces). Contexts go through get_numpy_from_nonfixed_2d_array;
    response, narrative and gt_response go through get_numpy_from_nonfixed_1d_array.

    Args:
        data: non-empty sequence of samples (np.stack raises on an empty list).
        max_sentence_len: length every sentence is padded/truncated to. Previously this
            was accepted but not forwarded, so the helpers' default of 50 always applied.
        max_num_utterance: number of context turns kept. Previously also not forwarded
            (the default of 10 always applied).

    Returns:
        Tuple ``(utterance, response, narrative, gt_response, y_true)`` of arrays whose
        first axis is the sample axis.
    """
    utterance, response, narrative, gt_response, y_true = [], [], [], [], []
    for unit in data:
        utterance.append(get_numpy_from_nonfixed_2d_array(
            unit[0], max_sentence_len=max_sentence_len, max_num_utterance=max_num_utterance))
        response.append(get_numpy_from_nonfixed_1d_array(unit[1], max_sentence_len=max_sentence_len))
        narrative.append(get_numpy_from_nonfixed_1d_array(unit[2], max_sentence_len=max_sentence_len))
        gt_response.append(get_numpy_from_nonfixed_1d_array(unit[3], max_sentence_len=max_sentence_len))
        y_true.append(unit[4])
    return (np.stack(utterance), np.stack(response), np.stack(narrative),
            np.stack(gt_response), np.stack(y_true))

def pad_process_response(all_responses, max_sentence_len=50, max_num_utterance=10):
    """Pad/truncate every candidate response to ``max_sentence_len`` ids and stack them.

    Returns an array of shape ``(len(all_responses), max_sentence_len)``.
    ``max_sentence_len`` was previously accepted but not forwarded, so 50 always applied.
    ``max_num_utterance`` is unused; it is kept for signature symmetry with pad_process.
    """
    return np.stack([get_numpy_from_nonfixed_1d_array(candidate, max_sentence_len=max_sentence_len)
                     for candidate in all_responses])

def pad_process_utterance(utterances, max_sentence_len=50, max_num_utterance=10):
    """Pad one dialogue context and add a leading batch axis of size 1.

    Returns an array of shape ``(1, max_num_utterance + 1, max_sentence_len)``.
    Both size arguments were previously accepted but not forwarded (defaults always applied).
    """
    return np.stack([get_numpy_from_nonfixed_2d_array(
        utterances, max_sentence_len=max_sentence_len, max_num_utterance=max_num_utterance)])

def pad_process_narrative(narrative, max_sentence_len=50, max_num_utterance=10):
    """Pad/truncate one narrative and add a leading batch axis of size 1.

    Returns an array of shape ``(1, max_sentence_len)``. ``max_sentence_len`` was
    previously accepted but not forwarded; ``max_num_utterance`` is unused and kept
    for signature symmetry.
    """
    return np.stack([get_numpy_from_nonfixed_1d_array(narrative, max_sentence_len=max_sentence_len)])

In [None]:
def ids_to_text(ids):
    """Decode a sequence of shifted word2vec ids back into a string.

    Ids are word2vec vocabulary indices + 1, so ``i - 1`` recovers the index. Id 0 is
    the padding value used by the pad helpers and is skipped; otherwise
    ``index_to_key[-1]`` would silently decode it as the last vocabulary entry.
    """
    tokens = [word2vec.wv.index_to_key[i - 1] for i in ids if i != 0]
    return tokenizer.convert_tokens_to_string(tokens)

In [None]:
import pandas as pd

**이전 저장 결과 불러오기**

In [None]:
import os

pred_name = f"{model_checkpoint}-{dataset_name}/pred.xlsx"
# Resume from a previous run's predictions if the file exists. Only a missing file
# means "start fresh". Any other read error (corrupt file, missing Excel engine, ...) is
# raised rather than swallowed, because add_to_df() rewrites pred_name from `df` and
# would otherwise overwrite the earlier results with an empty table.
if os.path.exists(pred_name):
    df = pd.read_excel(pred_name)
else:
    df = pd.DataFrame(columns =['pos_id', 'res_id', 'text'])
#df = pd.DataFrame(columns =['pos_id', 'res_id', 'text'])
#df.to_excel(pred_name, index=False)

In [None]:
#df[pd.all(df.pos_id==0, df.res_id==1)]
#df.loc[(df.pos_id==0)&(df.res_id==1)].values[0][2]
#df.loc[(df.pos_id==0)&(df.res_id==1)].text.values[0]

In [None]:
def add_to_df(j, i, narrative_str, utterances_str, context_str, response_str):
    """Append the rows for one generated response to the global `df` and save it.

    For the first generated turn (``i == 1``) the sample's gold utterances ('T' rows),
    its narrative ('N') and the seed context ('C') are recorded first. The response row
    ``[j, i, response_str]`` is always appended last. The whole table is then rewritten
    to ``pred_name``, so a run can be resumed from the file.
    """
    pending = []
    if i == 1:
        pending.extend([j, 'T', text] for text in utterances_str)
        pending.append([j, 'N', narrative_str])
        pending.append([j, 'C', context_str])
    pending.append([j, i, response_str])
    for new_row in pending:
        df.loc[len(df.index)] = new_row
    df.to_excel(pred_name, index=False)

**Predict**

In [None]:
import ipywidgets as widgets
from IPython.display import display

# NOTE(review): pandas documents None as "unlimited" for display.max_colwidth; confirm
# that 0 shows full-width text on the pandas version in use.
pd.set_option('display.max_colwidth', 0)
progress_bar = widgets.FloatProgress(min=0, max=1)
display(progress_bar)
output_widgets = widgets.Output()
display(output_widgets)

# resume: continue right after the last row written by add_to_df(), which always
# appends the generated-response row (pos_id, res_id, text) last.
max_utterance_num_pred = 10
start_pos_id = 0
start_res_id = 1
if len(df) > 0:
    row = df.iloc[-1]
    start_pos_id = row['pos_id']
    if row['res_id'] >= max_utterance_num_pred-1 or row['text'] == '[SEP]':
        # That sample is finished (last turn reached or EOS generated): go to the next one.
        start_pos_id += 1
    else:
        start_res_id = row['res_id']+1

# Every corpus utterance is a retrieval candidate. These arrays do not depend on the
# test sample, so they are built once and reused for every trainer.predict() call.
x_response = pad_process_response(all_utterances)
x_gt_response = x_response
n_data = len(x_response)  # number of candidate responses (NOT the number of test samples)
x_y_true = np.zeros((n_data,))  # dummy labels; only y.predictions is used below
for j, data in enumerate(test[start_pos_id:]):
    utterances = data[0]
    narrative = data[1]
    # Generation is seeded with the first gold utterance only.
    context = [utterances[0]]
    narrative_str = ids_to_text(narrative)
    context_str = ids_to_text(context[0])
    utterances_str = [ids_to_text(utterance) for utterance in utterances]
    x_narrative = pad_process_narrative(narrative)
    x_narrative = np.tile(x_narrative, (n_data, 1))
    if j == 0 and start_res_id > 1:
        # Resuming mid-sample: rebuild the context from the responses already saved in df.
        i_start = start_res_id
        for k in range(1, start_res_id):
            u_str = df.loc[(df.pos_id==start_pos_id)&(df.res_id==k)].text.values[0]
            u_i = all_utterances_str.index(u_str)
            context.append(all_utterances[u_i])
    else:
        i_start = 1
    for i in range(i_start, max_utterance_num_pred):
        output_widgets.clear_output()
        # Fraction of test samples processed (this used to divide by the candidate count).
        progress_bar.value = (j+start_pos_id)/float(len(test))
        with output_widgets:
            display(df.iloc[-50:])
        x_context = pad_process_utterance(context)
        x_context = np.tile(x_context, (n_data, 1, 1))
        x = (x_context, x_response, x_narrative, x_gt_response, x_y_true)
        predict_dataset = StoryDataset(x)
        y = trainer.predict(test_dataset=predict_dataset)
        # Pick the best-scoring candidate that is not already in the context.
        # NOTE(review): assumes y.predictions holds one score per candidate (1-D) - confirm.
        y_idxs = np.argsort(y.predictions)[::-1]
        response = None
        for y_i in y_idxs:
            if all_utterances[y_i] not in context:
                response = all_utterances[y_i]
                break
        if response is None:
            # Previously the previous turn's response was silently reused here.
            raise RuntimeError('every candidate response is already in the context')
        response_str = all_utterances_str[y_i]
        add_to_df(j+start_pos_id, i, narrative_str, utterances_str, context_str, response_str)
        context.append(response)
        if response == [EOS_ID]:
            break

In [None]:
# Deliberate stop: keeps "Run All" from continuing into the sample/scratch cells below.
# (The message reads "stop the full run here".)
raise Exception('전체 실행 여기서 멈추기')

**HTML display sample**

In [None]:
# IPython.display is the public import path (the prediction cell above already uses it).
# Importing `display` from IPython.core.display is deprecated in recent IPython releases.
from IPython.display import display, HTML
display(HTML('<h3>Hello, world!</h3>'))
display(HTML('Hello, world!'))

**Widget `clear_output` sample**

In [None]:
import ipywidgets as widgets
from IPython.display import display
import pandas as pd
import numpy as np
# Sample data for the widget demos below (downloaded over the network).
url = "https://data.london.gov.uk/download/number-international-visitors-london/b1e0f953-4c8a-4b45-95f5-e0d143d5641e/international-visitors-london-raw.csv"
df_london = pd.read_csv(filepath_or_buffer=url, encoding='latin_1')

In [None]:
ALL = 'ALL'


def unique_sorted_values_plus_ALL(array):
    """Return the sorted unique values of a pandas Series, preceded by the ALL sentinel."""
    return [ALL] + sorted(array.unique().tolist())

# Year selector: choosing a year re-renders output_year with the matching rows.
dropdown_year = widgets.Dropdown(options=unique_sorted_values_plus_ALL(df_london.year))
display(dropdown_year)
output_year = widgets.Output()


def dropdown_year_eventhandler(change):
    """Replace output_year's contents with the rows for the new selection (or all rows)."""
    output_year.clear_output()
    with output_year:
        selected = change.new
        display(df_london if selected == ALL else df_london[df_london.year == selected])


dropdown_year.observe(dropdown_year_eventhandler, names='value')

display(output_year)

In [None]:
import time

# Demo: move a progress bar and swap the contents of an Output widget in place.
progress_bar = widgets.FloatProgress(min=0, max=1)
display(progress_bar)
output_year = widgets.Output()
display(output_year)

with output_year:
    display(df_london)

time.sleep(2)  # pause so the first table is visible before it is replaced
progress_bar.value = 0.5

output_year.clear_output()

# NOTE(review): compares `year` with the string '2012'; if the column is numeric
# this selects no rows - confirm the column dtype.
with output_year:
    display(df_london[df_london.year == '2012'])


In [None]:
# List the names exported by ipywidgets (dir() already returns them sorted).
print(sorted(dir(widgets)))