In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
import numpy as np
from torch.nn.parameter import Parameter
import pickle
from typing import Optional

class layer_normalization(nn.Module):

    def __init__(self, features, epsilon=1e-8):
        '''Applies layer normalization.
        Args:
          epsilon: A floating number. A very small number for preventing ZeroDivision Error.
        '''
        super(layer_normalization, self).__init__()
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.epsilon) + self.beta
    
class multihead_attention(nn.Module):

    def __init__(self, num_units, num_heads=8, dropout_rate=0, causality=False):
        '''Applies multihead attention.
        Args:
            num_units: A scalar. Attention size.
            dropout_rate: A floating point number.
            causality: Boolean. If true, units that reference the future are masked.
            num_heads: An int. Number of heads.
        '''
        super(multihead_attention, self).__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.causality = causality
        self.Q_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())
        self.K_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())
        self.V_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())

        self.output_dropout = nn.Dropout(p=self.dropout_rate)

        self.normalization = layer_normalization(self.num_units)

    def forward(self, queries, keys, values):
        # keys, values: same shape of [N, T_k, C_k]
        # queries: A 3d Variable with shape of [N, T_q, C_q]

        print('q shape:', queries.shape)
        # Linear projections
        Q = self.Q_proj(queries)  # (N, T_q, C)
        print('Q shape:', Q.shape)
        K = self.K_proj(keys)  # (N, T_q, C)
        V = self.V_proj(values)  # (N, T_q, C)

        # Split and concat
        Q_ = torch.cat(torch.chunk(Q, self.num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.chunk(K, self.num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
        V_ = torch.cat(torch.chunk(V, self.num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)

        # Multiplication
        outputs = torch.bmm(Q_, K_.permute(0, 2, 1))  # (h*N, T_q, T_k)

        # Scale
        outputs = outputs / (K_.size()[-1] ** 0.5)

        # Key Masking
        key_masks = torch.sign(torch.abs(torch.sum(keys, dim=-1)))  # (N, T_k)
        key_masks = key_masks.repeat(self.num_heads, 1)  # (h*N, T_k)
        key_masks = torch.unsqueeze(key_masks, 1).repeat(1, queries.size()[1], 1)  # (h*N, T_q, T_k)

        padding = Variable(torch.ones(*outputs.size()).cuda() * (-2 ** 32 + 1))
        condition = key_masks.eq(0.).float()
        outputs = padding * condition + outputs * (1. - condition)

        # Causality = Future blinding
        if self.causality:
            diag_vals = torch.ones(*outputs[0, :, :].size()).cuda()  # (T_q, T_k)
            tril = torch.tril(diag_vals, diagonal=0)  # (T_q, T_k)
            # print(tril)
            masks = Variable(torch.unsqueeze(tril, 0).repeat(outputs.size()[0], 1, 1))  # (h*N, T_q, T_k)

            padding = Variable(torch.ones(*masks.size()).cuda() * (-2 ** 32 + 1))
            condition = masks.eq(0.).float()
            outputs = padding * condition + outputs * (1. - condition)

        # Activation
        outputs = F.softmax(outputs, dim=-1)  # (h*N, T_q, T_k)

        # Query Masking
        query_masks = torch.sign(torch.abs(torch.sum(queries, dim=-1)))  # (N, T_q)
        query_masks = query_masks.repeat(self.num_heads, 1)  # (h*N, T_q)
        query_masks = torch.unsqueeze(query_masks, 2).repeat(1, 1, keys.size()[1])  # (h*N, T_q, T_k)
        outputs = outputs * query_masks

        # Dropouts
        outputs = self.output_dropout(outputs)  # (h*N, T_q, T_k)

        # Weighted sum
        outputs = torch.bmm(outputs, V_)  # (h*N, T_q, C/h)

        # Restore shape
        outputs = torch.cat(torch.chunk(outputs, self.num_heads, dim=0), dim=2)  # (N, T_q, C)

        # Residual connection
        outputs += queries

        # Normalize
        outputs = self.normalization(outputs)  # (N, T_q, C)

        return outputs

class feedforward(nn.Module):

    def __init__(self, in_channels, num_units=[2048, 512]):
        '''Point-wise feed forward net.
        Args:
          in_channels: a number of channels of inputs
          num_units: A list of two integers.
        '''
        super(feedforward, self).__init__()
        self.in_channels = in_channels
        self.num_units = num_units

        # nn.Linear is faster than nn.Conv1d
        self.conv = False
        if self.conv:
            params = {'in_channels': self.in_channels, 'out_channels': self.num_units[0],
                      'kernel_size': 1, 'stride': 1, 'bias': True}
            self.conv1 = nn.Sequential(nn.Conv1d(**params), nn.ReLU())
            params = {'in_channels': self.num_units[0], 'out_channels': self.num_units[1],
                      'kernel_size': 1, 'stride': 1, 'bias': True}
            self.conv2 = nn.Conv1d(**params)
        else:
            self.conv1 = nn.Sequential(nn.Linear(self.in_channels, self.num_units[0]), nn.ReLU())
            self.conv2 = nn.Linear(self.num_units[0], self.num_units[1])
        self.normalization = layer_normalization(self.in_channels)

    def forward(self, inputs):
        if self.conv:
            inputs = inputs.permute(0, 2, 1)
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)

        # Residual connection
        outputs += inputs

        # Layer normalization
        if self.conv:
            outputs = self.normalization(outputs.permute(0, 2, 1))
        else:
            outputs = self.normalization(outputs)

        return outputs

class ScriptWriter_cpre(nn.Module):
    
    def __init__(
        self,
        eta=0.5,
        max_sentence_len = 50,
        max_num_utterance = 11,
        embedding_file = 'data/embeddings_ko.pkl',
    ):
        super().__init__()
        self.max_num_utterance = max_num_utterance
        self.negative_samples = 1
        self.max_sentence_len = max_sentence_len
        self.emb_size = 200 # word_embedding_size 
        self.hidden_units = 200
        #self.total_words = 43514
        self.total_words = 11883
        #self.batch_size = batch_size
        #self.eval_batch_size = eval_batch_size
        #self.learning_rate_ph = tf.compat.v1.placeholder(tf.float32, shape=[], name='learning_rate')
        self.dropout_rate = 0
        self.num_heads = 1
        self.num_blocks = 3 
        self.eta = eta
        #self.gamma = tf.compat.v1.get_variable('gamma', shape=1, dtype=tf.float32, trainable=True, initializer=tf.constant_initializer(0.5))
        self.gamma = nn.Parameter(torch.tensor(0.5), requires_grad=True)
        word_emb = pickle.load(open(embedding_file, 'rb'), encoding="bytes")
        word_emb = torch.FloatTensor(word_emb)
        self.embedding = nn.Embedding.from_pretrained(word_emb, freeze=True)
        
        for i in range(self.num_blocks):
            self.__setattr__('self_multihead_attention_%d' % i, multihead_attention(
                     num_units=self.hidden_units,
                     num_heads=self.num_heads,
                     dropout_rate=self.dropout_rate,
                     causality=False))
            self.__setattr__('self_feedforward_%d' % i, feedforward(
                     self.hidden_units,
                     [self.hidden_units, self.hidden_units]))
            
        for i in range(self.num_blocks+1):
            self.__setattr__('ru_multihead_attention_%d' % i, multihead_attention(
                     num_units=self.hidden_units,
                     num_heads=self.num_heads,
                     dropout_rate=self.dropout_rate,
                     causality=False))
            self.__setattr__('ru_feedforward_%d' % i, feedforward(
                     self.hidden_units,
                     [self.hidden_units, self.hidden_units]))
            self.__setattr__('ur_multihead_attention_%d' % i, multihead_attention(
                     num_units=self.hidden_units,
                     num_heads=self.num_heads,
                     dropout_rate=self.dropout_rate,
                     causality=False))
            self.__setattr__('ur_feedforward_%d' % i, feedforward(
                     self.hidden_units,
                     [self.hidden_units, self.hidden_units]))
            self.__setattr__('nu_multihead_attention_%d' % i, multihead_attention(
                     num_units=self.hidden_units,
                     num_heads=self.num_heads,
                     dropout_rate=self.dropout_rate,
                     causality=False))
            self.__setattr__('nu_feedforward_%d' % i, feedforward(
                     self.hidden_units,
                     [self.hidden_units, self.hidden_units]))
            self.__setattr__('un_multihead_attention_%d' % i, multihead_attention(
                     num_units=self.hidden_units,
                     num_heads=self.num_heads,
                     dropout_rate=self.dropout_rate,
                     causality=False))
            self.__setattr__('un_feedforward_%d' % i, feedforward(
                     self.hidden_units,
                     [self.hidden_units, self.hidden_units]))
            self.__setattr__('nr_multihead_attention_%d' % i, multihead_attention(
                     num_units=self.hidden_units,
                     num_heads=self.num_heads,
                     dropout_rate=self.dropout_rate,
                     causality=False))
            self.__setattr__('nr_feedforward_%d' % i, feedforward(
                     self.hidden_units,
                     [self.hidden_units, self.hidden_units]))
            self.__setattr__('rn_multihead_attention_%d' % i, multihead_attention(
                     num_units=self.hidden_units,
                     num_heads=self.num_heads,
                     dropout_rate=self.dropout_rate,
                     causality=False))
            self.__setattr__('rn_feedforward_%d' % i, feedforward(
                     self.hidden_units,
                     [self.hidden_units, self.hidden_units]))
                                       
                                       
        self.n_dense = nn.Linear(self.hidden_units, self.hidden_units)
        self.lastu_dense = nn.Linear(self.max_sentence_len, 1) 
        self.lastur_dense = nn.Linear(self.max_sentence_len, 1)
        
        depth = self.max_num_utterance # 11
        height = self.max_sentence_len # 50
        width = self.max_sentence_len # 50
        padding = ((depth%3 + 1)//2, (height%3 + 1)//2, (width%3 + 1)//2,)
        conv3d_1_layer = nn.Conv3d((self.num_blocks+1)*2, 32, 3, padding='same') 
        nn.init.uniform_(conv3d_1_layer.weight, -0.01, 0.01) 
        self.conv3d_1 = torch.nn.Sequential(conv3d_1_layer, torch.nn.ELU())
        self.maxpool3d_1 = torch.nn.MaxPool3d(3, padding=padding)
        
        depth = (self.max_num_utterance+2)//3 # 11
        height = (self.max_sentence_len+2)//3 # 50
        width = (self.max_sentence_len+2)//3 # 50
        padding = ((depth%3 + 1)//2, (height%3 + 1)//2, (width%3 + 1)//2,)
        conv3d_2_layer = nn.Conv3d(32, 32, 3, padding='same') 
        nn.init.uniform_(conv3d_2_layer.weight, -0.01, 0.01)
        self.conv3d_2 = torch.nn.Sequential(conv3d_2_layer, torch.nn.ELU())
        self.maxpool3d_2 = torch.nn.MaxPool3d(3, padding=padding)
        mur_flatten_size = ((depth+2)//3)*((height+2)//3)*((width+2)//3)*32
        #print('mur_flatten_size =', mur_flatten_size)
        
        height = self.max_sentence_len # 50
        width = self.max_sentence_len # 50
        padding = ((height%3 + 1)//2, (width%3 + 1)//2)
        conv2d_1_layer = nn.Conv2d((self.num_blocks+1)*2, 32, 3, padding='same')
        nn.init.uniform_(conv2d_1_layer.weight, -0.01, 0.01)
        self.conv2d_1 = torch.nn.Sequential(conv2d_1_layer, torch.nn.ELU())
        self.maxpool2d_1 = torch.nn.MaxPool2d(3, padding=padding)
        
        height = (self.max_sentence_len+2)//3 # 50
        width = (self.max_sentence_len+2)//3 # 50
        padding = ((height%3 + 1)//2, (width%3 + 1)//2)
        conv2d_2_layer = nn.Conv2d(32, 32, 3, padding='same') 
        nn.init.uniform_(conv2d_2_layer.weight, -0.01, 0.01)
        self.conv2d_2 = torch.nn.Sequential(conv2d_2_layer, torch.nn.ELU())
        self.maxpool2d_2 = torch.nn.MaxPool2d(3, padding=padding)
        
        total_flatten_size = mur_flatten_size*2 + ((height+2)//3)*((width+2)//3)*32
        #print('total_flatten_size =', total_flatten_size)
        
        self.logits_dense = nn.Linear(total_flatten_size, 1)  
        nn.init.orthogonal_(self.logits_dense.weight)
        
    def forward(
        self,
        response: Optional[torch.Tensor] = None,
        gt_response: Optional[torch.Tensor] = None,
        narrative: Optional[torch.Tensor] = None,
        utterance: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        print('response.shape =', response.shape)
        print('response type =', type(response))
        print('utterance.shape =', utterance.shape)
        all_utterances = torch.unbind(utterance, dim=1)
        #print(all_utterances[0].shape)
        
        response_embeddings = self.embedding(response)
        Hr_stack = [response_embeddings]
        for i in range(self.num_blocks):
            response_embeddings = self.__getattr__('self_multihead_attention_%d' % i)(
                response_embeddings, response_embeddings, response_embeddings)
            response_embeddings = self.__getattr__('self_feedforward_%d' % i)(response_embeddings)
            Hr_stack.append(response_embeddings)
        
        gt_response_embeddings = self.embedding(gt_response)
        Hgtr_stack = [gt_response_embeddings]
        for i in range(self.num_blocks):
            gt_response_embeddings = self.__getattr__('self_multihead_attention_%d' % i)(
                gt_response_embeddings, gt_response_embeddings, gt_response_embeddings)
            gt_response_embeddings = self.__getattr__('self_feedforward_%d' % i)(gt_response_embeddings)
            Hgtr_stack.append(response_embeddings)
            
        narrative_embeddings = self.embedding(narrative)
        Hn_stack = [narrative_embeddings]
        for i in range(self.num_blocks):
            narrative_embeddings = self.__getattr__('self_multihead_attention_%d' % i)(
                narrative_embeddings, narrative_embeddings, narrative_embeddings)
            narrative_embeddings = self.__getattr__('self_feedforward_%d' % i)(narrative_embeddings)
            Hn_stack.append(response_embeddings)
        
        Mur, Mun = [], []
        self.decay_factor = []
        last_u_reps = []
        turn_id = 0
        
        for utterance in all_utterances:
            utterance_embeddings = self.embedding(utterance)
            Hu_stack = [utterance_embeddings]
            for i in range(self.num_blocks):
                utterance_embeddings = self.__getattr__('self_multihead_attention_%d' % i)(
                    utterance_embeddings, utterance_embeddings, utterance_embeddings)
                utterance_embeddings = self.__getattr__('self_feedforward_%d' % i)(utterance_embeddings)
                Hu_stack.append(utterance_embeddings)
                
            if turn_id == self.max_num_utterance - 1:
                last_u_reps = Hu_stack
            
            r_a_u_stack = []
            u_a_r_stack = []
            for i in range(self.num_blocks + 1):
                r_a_u = self.__getattr__('ru_multihead_attention_%d' % i)(
                    Hr_stack[i], Hu_stack[i], Hu_stack[i])
                r_a_u = self.__getattr__('ru_feedforward_%d' % i)(r_a_u)
                r_a_u_stack.append(r_a_u)
                u_a_r = self.__getattr__('ur_multihead_attention_%d' % i)(
                    Hu_stack[i], Hr_stack[i], Hr_stack[i])
                u_a_r = self.__getattr__('ur_feedforward_%d' % i)(u_a_r)
                u_a_r_stack.append(u_a_r)
            r_a_u_stack.extend(Hr_stack)
            u_a_r_stack.extend(Hu_stack)
            
            n_a_u_stack = []
            u_a_n_stack = []
            for i in range(self.num_blocks + 1):
                n_a_u = self.__getattr__('nu_multihead_attention_%d' % i)(
                    Hn_stack[i], Hu_stack[i], Hu_stack[i])
                n_a_u = self.__getattr__('nu_feedforward_%d' % i)(n_a_u)
                n_a_u_stack.append(n_a_u)
                u_a_n = self.__getattr__('un_multihead_attention_%d' % i)(
                    Hu_stack[i], Hn_stack[i], Hn_stack[i])
                u_a_n = self.__getattr__('un_feedforward_%d' % i)(u_a_n)
                u_a_n_stack.append(u_a_n)
            n_a_u_stack.extend(Hn_stack)
            u_a_n_stack.extend(Hu_stack)
            
            r_a_u = torch.stack(r_a_u_stack, dim=-1)
            u_a_r = torch.stack(u_a_r_stack, dim=-1)
            n_a_u = torch.stack(n_a_u_stack, dim=-1)
            u_a_n = torch.stack(u_a_n_stack, dim=-1)
            
            # sim shape [batch, max_sent_len, max_sent_len, 2 * (stack_num + 1)]
            # divide sqrt(200) to prevent gradient explosion
            # (-1, 50, 50, 8)
            sim_ur = torch.einsum('biks,bjks->bijs', u_a_r, r_a_u) / torch.sqrt(torch.tensor(200.0))  # for no rp and normal
            sim_un = torch.einsum('biks,bjks->bijs', u_a_n, n_a_u) / torch.sqrt(torch.tensor(200.0))  # for no rp and normal
            
            self_n = torch.nn.functional.normalize(torch.stack(Hn_stack, dim=-1))  # for no rp
            self_u = torch.nn.functional.normalize(torch.stack(Hu_stack, dim=-1))  # for no rp
            Hn_stack_tensor = torch.stack(Hn_stack, dim=-1)  # [batch, o_len, embedding_size, stack]
            
            self_sim = torch.einsum('biks,bjks->bijs', self_u, self_n)  # [batch, u_len, o_len, stack]
            self_sim = 1 - self.gamma * torch.sum(self_sim, dim=1)  # [batch, (1), o_len, stack]
            Hn_stack = torch.einsum('bjkl,bjl->bjkl', Hn_stack_tensor, self_sim)
            Hn_stack = torch.unbind(Hn_stack, dim=-1)
            
            Mur.append(sim_ur)
            Mun.append(sim_un)
            turn_id += 1
            
        # Hn_stack ( (-1,50,200), ... ) len = block_num
        #print('narrative updated final len(Hn_stack) =', len(Hn_stack), ', Hn_stack[0].shape =', Hn_stack[0].shape)
        
        #print('stack shape = ', torch.stack(Hn_stack, dim=2).shape)
        Hn_stack_for_tracking = self.n_dense(torch.stack(Hn_stack, dim=2))  # [batch, o_len, stack, embedding_size]
        #print('Hn_stack_for_tracking.shape after dense =', Hn_stack_for_tracking.shape)
        Hn_stack_for_tracking = Hn_stack_for_tracking.permute((0, 1, 3, 2))  # [batch, o_len, embedding_size, stack]
        #print('Hn_stack_for_tracking.shape after permute =', Hn_stack_for_tracking.shape)
        Hlastu_stack_for_tracking = torch.stack(last_u_reps, dim=-1)  # [batch, u_len, embedding_size, stack]
        Hr_stack_for_tracking = torch.stack(Hgtr_stack, dim=-1)  # [batch, r_len, embedding_size, stack]
        Hlastu = Hlastu_stack_for_tracking.permute((0, 2, 3, 1)) # [batch, embedding_size, stack, u_len]
        Hlastu = torch.squeeze(self.lastu_dense(Hlastu), dim=-1)  # [batch, embedding_size, stack]
        p1_tensor = nn.functional.softmax(torch.einsum('bnds,bds->bns', Hn_stack_for_tracking, Hlastu), dim=1)  # [batch, o_len, stack]
        Hlastur = Hr_stack_for_tracking.permute((0, 2, 3, 1))
        Hlastur = torch.squeeze(self.lastur_dense(Hlastur), dim=-1)  # [batch, embedding_size, stack]
        p2_tensor = nn.functional.softmax(torch.einsum('bnds,bds->bns', Hn_stack_for_tracking, Hlastur), dim=1)  # [batch, o_len, stack]
        p1 = torch.unbind(p1_tensor, dim=-1)
        p2 = torch.unbind(p2_tensor, dim=-1)
        #print('len(p1) =', len(p1), ', p1[0].shape =', p1[0].shape)
        
        n_a_r_stack = []
        r_a_n_stack = []
        for i in range(self.num_blocks + 1):
            n_a_r = self.__getattr__('nr_multihead_attention_%d' % i)(
                Hn_stack[i], Hr_stack[i], Hr_stack[i])
            n_a_r = self.__getattr__('nr_feedforward_%d' % i)(n_a_r)
            n_a_r_stack.append(n_a_r)
            r_a_n = self.__getattr__('rn_multihead_attention_%d' % i)(
                Hr_stack[i], Hn_stack[i], Hn_stack[i])
            r_a_n = self.__getattr__('rn_feedforward_%d' % i)(r_a_n)
            r_a_n_stack.append(r_a_n)
        n_a_r_stack.extend(Hn_stack)
        r_a_n_stack.extend(Hr_stack)

        n_a_r = torch.stack(n_a_r_stack, dim=-1)
        r_a_n = torch.stack(r_a_n_stack, dim=-1)

        Mrn = torch.einsum('biks,bjks->bijs', n_a_r, r_a_n) / torch.sqrt(torch.tensor(200.0))
        self.rosim = Mrn
        Mur = torch.stack(Mur, dim=1)
        Mun = torch.stack(Mun, dim=1)
        
        #print('Mur.shape =', Mur.shape)
        #print('Mun.shape =', Mun.shape)
        #print('Mrn.shape =', Mrn.shape)
        
        conv3d = self.conv3d_1(Mur.permute(0,4,1,2,3)) # (-1, 11, 50, 50, 8) -> (-1, 8, 11, 50, 50)
        #print('conv3d.shape =', conv3d.shape)
        pool3d = self.maxpool3d_1(conv3d)              # (-1, 32, 4, 17, 17)
        #print('pool3d.shape =', pool3d.shape)
        conv3d2 = self.conv3d_2(pool3d)
        pool3d2 = self.maxpool3d_2(conv3d2)            # (-1, 32, 2, 6, 6)
        #print('pool3d2.shape =', pool3d2.shape)
        mur = torch.flatten(pool3d2.permute(0,2,3,4,1), start_dim=1)
        #print('mur.shape =', mur.shape)
        
        conv3d = self.conv3d_1(Mun.permute(0,4,1,2,3))
        pool3d = self.maxpool3d_1(conv3d)
        conv3d2 = self.conv3d_2(pool3d)
        pool3d2 = self.maxpool3d_2(conv3d2)
        mun = torch.flatten(pool3d2.permute(0,2,3,4,1), start_dim=1)
        #print('mun.shape =', mun.shape)
        
        conv2d = self.conv2d_1(Mrn.permute((0,3,1,2)))
        pool2d = self.maxpool2d_1(conv2d)
        conv2d2 = self.conv2d_2(pool2d)
        pool2d2 = self.maxpool2d_2(conv2d2)
        mrn = torch.flatten(pool2d2.permute(0,2,3,1), start_dim=1) 
        #print('mrn.shape =', mrn.shape)
        
        all_vector = torch.cat([mur, mun, mrn], dim=-1)
        #print('all_vector.shape =', all_vector.shape)
        logits = torch.reshape(self.logits_dense(all_vector), shape=(-1,))
        y_pred = torch.sigmoid(logits)
        #print('y_pred.shape =', y_pred.shape)
        
        return {
            'y_pred': y_pred,
            'logits': logits,
            'p1': p1,
            'p2': p2
        }

In [3]:
import numpy as np
N = 4
device = torch.device("cuda")
response = torch.from_numpy(np.random.randint(1, 11400, size=(N, 50))).to(device)
gt_response = torch.from_numpy(np.random.randint(1, 11400, size=(N, 50))).to(device)
narrative = torch.from_numpy(np.random.randint(1, 11400, size=(N, 50))).to(device)
utterance = torch.from_numpy(np.random.randint(1, 11400, size=(N, 11, 50))).to(device)

model = ScriptWriter_cpre().to(device)

_ = model(response=response, gt_response=gt_response, narrative=narrative, utterance=utterance)

response.shape = torch.Size([4, 50])
response type = <class 'torch.Tensor'>
utterance.shape = torch.Size([4, 11, 50])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200]

q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])
Q shape: torch

In [8]:
model.embedding(torch.tensor([3]).to(device))[:, 0:30]

tensor([[-1.0511,  0.5404, -0.8509,  0.4603, -1.9017,  1.5315, -0.0478, -1.3589,
          0.5049,  0.8584, -0.4572, -0.0399, -0.8955, -0.9557, -0.9209, -1.0791,
          2.2756, -0.3729,  0.5335,  1.4092, -0.1845, -0.8261, -0.9601,  1.4598,
         -0.4093,  0.6887,  0.1956,  0.3117, -0.0178, -0.6985]],
       device='cuda:0')

In [None]:
#!pip uninstall numpy -y
#!pip install datasets

!pip install numpy==1.19.5
!pip install pandas==1.4.0
!pip install datasets==1.18.4

In [2]:
from transformers import Trainer, TrainingArguments

class ScriptwriterTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bcewithlogitsloss = nn.BCEWithLogitsLoss()

    def compute_loss(self, model, inputs, return_outputs=False):
        y_true = inputs.pop('labels')
        
        outputs = model(**inputs)
        logits = outputs.pop('logits')
        p1 = outputs.pop('p1')
        p2 = outputs.pop('p2')
        
        KL_loss = 0.0
        for i in range(model.num_blocks + 1):
            KL_loss += torch.mean(nn.functional.kl_div(p1[i], p2[i]))
        KL_loss /= (model.num_blocks + 1)
        RS_loss = torch.mean(torch.clip(self.bcewithlogitsloss(labels=y_true, logits=logits), -10, 10))
        loss = model.eta * RS_loss + (1 - model.eta) * KL_loss
        
        return (loss, outputs) if return_outputs else loss
        

In [3]:
from datasets import load_dataset
datasets = load_dataset('story_data', 'original')

model_checkpoint = "script-writer-cpre"
training_args = TrainingArguments(
    f"{model_checkpoint}-dev",
    evaluation_strategy = "epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=False,
)

Found cached dataset story_data (/home/kotech/.cache/huggingface/datasets/story_data/original/1.0.0/ec7f2c2a5c010e3fa36891e870388b66436f5a5edd79bce7b1ec2f4808991faa)


  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()

In [5]:
model = ScriptWriter_cpre()

In [6]:
trainer = ScriptwriterTrainer(
    model=model,
    args=training_args,
    train_dataset=datasets["train"],
    eval_dataset=datasets["validation"],
    data_collator=data_collator,
)

In [7]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `ScriptWriter_cpre.forward` and have been ignored: id. If id are not expected by `ScriptWriter_cpre.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 136524
  Num Epochs = 3
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 51198


response.shape =response.shape = torch.Size([4, 50])
response type = <class 'torch.Tensor'>
utterance.shape = torch.Size([4, 11, 50])
 torch.Size([4, 50])
response type = <class 'torch.Tensor'>
utterance.shape = torch.Size([4, 11, 50])
q shape: torch.Size([4, 50, 200])
q shape: torch.Size([4, 50, 200])


/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [4,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [4,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [4,0,0], thread: [66,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [4,0,0], thread: [67,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [4,0,0], thread: [68,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [4,0,0], thread: [69,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectL

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/kotech/venv-torch1102/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/kotech/venv-torch1102/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/ipykernel_31870/125454331.py", line 316, in forward
    response_embeddings = self.__getattr__('self_multihead_attention_%d' % i)(
  File "/home/kotech/venv-torch1102/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/ipykernel_31870/125454331.py", line 57, in forward
    Q = self.Q_proj(queries)  # (N, T_q, C)
  File "/home/kotech/venv-torch1102/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kotech/venv-torch1102/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/kotech/venv-torch1102/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kotech/venv-torch1102/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/kotech/venv-torch1102/lib/python3.9/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasCreate(handle)`
