In [None]:
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
from crf import *


class CRFModel(nn.Module):
    """Emotion tagger for dialogues: pretrained encoder + linear emission head + CRF.

    Every utterance is encoded on its own; one hidden state per utterance is
    projected to ``num_classes`` emission scores, and the per-dialogue score
    sequences are handed to the CRF layer.
    """

    def __init__(self, config):
        """Build all sub-modules from a configuration mapping.

        Args:
            config: mapping providing ``'dropout'``, ``'num_classes'``,
                ``'pad_value'`` (token id treated as padding), ``'CLS'``
                (token id prepended to every utterance) and ``'bert_path'``
                (passed to ``AutoModel.from_pretrained``).
        """
        super().__init__()
        self.config = config
        # NOTE(review): stored only; no dropout is applied anywhere in this class.
        self.dropout = config['dropout']
        self.num_classes = config['num_classes']
        self.pad_value = config['pad_value']
        self.CLS = config['CLS']
        self.context_encoder = AutoModel.from_pretrained(
            config['bert_path'])
        # Width of the encoder's word-embedding table; used as the input size
        # of the emission head.
        self.dim = self.context_encoder.embeddings.word_embeddings.weight.data.shape[-1]
        # NOTE(review): not used by forward(); kept so the module's parameters
        # stay unchanged.
        self.spk_embeddings = nn.Embedding(300, self.dim)
        self.crf_layer = CRF(self.num_classes)
        self.emission = nn.Linear(self.dim, self.num_classes)
        # Labels equal to -1 are excluded from the cross-entropy term.
        self.loss_func = torch.nn.CrossEntropyLoss(ignore_index=-1)

    def device(self):
        """Return the device reported by the underlying encoder."""
        return self.context_encoder.device

    def forward(self, sentences, sentences_mask, speaker_ids, last_turns, emotion_idxs=None):
        '''
        Score every utterance and either compute the training loss or decode.

        sentences: batch * max_turns * max_length token ids
        sentences_mask: batch * max_turns, forwarded (transposed) to the CRF as its mask
        speaker_ids: batch * max_turns; only its shape/dtype are used, to build
            the column of CLS ids
        last_turns: accepted for interface compatibility; not used
        emotion[optional] : batch * max_turns labels

        Returns the negated CRF layer output plus the cross-entropy loss when
        ``emotion_idxs`` is given, otherwise the result of ``crf_layer.decode``.
        '''
        batch_size, max_turns = sentences.shape[0], sentences.shape[1]
        num_utterances = batch_size * max_turns

        # Flatten dialogues so every row is one utterance (row = b * max_turns + t).
        speaker_ids = speaker_ids.reshape(num_utterances, -1)
        sentences = sentences.reshape(num_utterances, -1)

        # Prepend the CLS id to every utterance.
        cls_id = torch.ones_like(speaker_ids) * self.CLS
        input_ids = torch.cat((cls_id, sentences), 1)
        # 1 for real tokens, 0 for padding.
        mask = 1 - (input_ids == (self.pad_value)).long()

        utterance_encoded = self.context_encoder(
            input_ids=input_ids,
            attention_mask=mask,
            output_hidden_states=True,
            return_dict=True
        )['last_hidden_state']

        # Second-to-last non-pad position of each row (assuming trailing padding);
        # presumably where the data pipeline puts the prediction token - TODO confirm.
        # A row holding only padding after the CLS id yields -1, i.e. the last position.
        mask_pos = mask.sum(1) - 2
        row_index = torch.arange(mask_pos.shape[0], device=mask_pos.device)
        features = utterance_encoded[row_index, mask_pos, :]

        # emissions: (batch * max_turns, num_classes), rows in batch-major order.
        emissions = self.emission(features)
        # The CRF layer is fed turn-major tensors: (max_turns, batch, num_classes).
        crf_emissions = emissions.reshape(batch_size, max_turns, -1).transpose(0, 1)
        sentences_mask = sentences_mask.transpose(0, 1)

        # test
        if emotion_idxs is None:
            return self.crf_layer.decode(crf_emissions, mask=sentences_mask)

        # train
        # Flatten the labels BEFORE transposing so they line up with the
        # batch-major rows of `emissions`. Flattening the transposed labels
        # with .view(-1) fails on the non-contiguous tensor whenever batch > 1
        # and max_turns > 1, and would pair scores with the wrong labels.
        flat_labels = emotion_idxs.reshape(-1)
        loss1 = -self.crf_layer(crf_emissions, emotion_idxs.transpose(0, 1), mask=sentences_mask)
        # Add a classification loss so that the CRF can focus on sequence information.
        loss2 = self.loss_func(emissions.view(-1, self.num_classes), flat_labels)
        return loss1 + loss2

In [None]:
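# features -> the encoder hidden state picked at index mask_pos for every row (utterance)
# mask_pos = mask.sum(1) - 2 -> sums each row across its columns -> one value per row, output shape [num_rows]
# change: transpose(0, 1) on sentences_mask -> .reshape(batch_size, -1)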

In [None]:
# Yunjeon + Changhyeon version
class CRFModel(nn.Module):
    """Emotion tagger for dialogues: pretrained encoder + linear emission head + CRF.

    Every utterance is encoded on its own; one hidden state per utterance is
    projected to ``numClasses`` emission scores, and the per-dialogue score
    sequences are handed to the CRF layer.
    """

    def __init__(self, numClasses, dropout, bert_path):
        """Build all sub-modules.

        Args:
            numClasses: number of emotion labels (size of the emission output).
            dropout: stored on the instance; NOTE(review): never applied here.
            bert_path: name or path given to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.numClasses = numClasses
        self.dropout = dropout
        # Hard-coded pad token id.
        # NOTE(review): assumes the vocabulary at bert_path pads with id 1 - confirm.
        self.padValue = 1
        # First id produced when tokenizing the empty string; presumably the
        # tokenizer's start-of-sequence / CLS id.
        tokenizer = AutoTokenizer.from_pretrained(bert_path)
        self.CLS = tokenizer('')['input_ids'][0]
        self.encoder = AutoModel.from_pretrained(bert_path)
        # Width of the encoder's word-embedding table; input size of the emission head.
        self.dimension = self.encoder.embeddings.word_embeddings.weight.data.shape[-1]
        # NOTE(review): not used by forward(); kept so the module's parameters
        # stay unchanged.
        self.spkEmbeddings = nn.Embedding(300, self.dimension)
        self.CRFlayer = CRF(self.numClasses)
        self.emission = nn.Linear(self.dimension, self.numClasses)
        # Labels equal to -1 are excluded from the cross-entropy term.
        self.lossFunc = torch.nn.CrossEntropyLoss(ignore_index=-1)

    def device(self):
        """Return the device reported by the underlying encoder."""
        return self.encoder.device

    def forward(self, sentences, sentencesMask, speakerIds, lastTurns, emotionIdxes=None):
        '''
        Score every utterance and either compute the training loss or decode.

        sentences: batch * max_turns * max_length token ids
        sentencesMask: batch * max_turns, forwarded (transposed) to the CRF as its mask
        speaker_ids: batch * max_turns; only its shape/dtype are used, to build
            the column of CLS ids
        lastTurns: accepted for interface compatibility; not used
        emotion[optional] : batch * max_turns labels

        Returns the negated CRF layer output plus the cross-entropy loss when
        ``emotionIdxes`` is given, otherwise the result of ``CRFlayer.decode``.
        '''
        batchSize, maxTurns = sentences.shape[0], sentences.shape[1]
        numUtterances = batchSize * maxTurns

        # Flatten dialogues so every row is one utterance (row = b * maxTurns + t).
        sentencesReshaped = sentences.reshape(numUtterances, -1)
        speakerIdsReshaped = speakerIds.reshape(numUtterances, -1)

        # Prepend the CLS id to every utterance.
        clsId = torch.ones_like(speakerIdsReshaped) * self.CLS
        inputIds = torch.cat((clsId, sentencesReshaped), dim=1)

        # The mask tells the encoder which positions to ignore:
        # 0 where the token is padding, 1 otherwise. It is built as a NEW
        # tensor; writing 0/1 into inputIds in place would destroy the token
        # ids and (because padValue is 1) end up with a mask of all ones.
        attentionMask = (inputIds != self.padValue).long()

        utteranceEncoded = self.encoder(
            input_ids=inputIds,
            attention_mask=attentionMask,
            output_hidden_states=True,
            return_dict=True
        )['last_hidden_state']

        # Second-to-last non-pad position of each row (assuming trailing padding);
        # presumably where the data pipeline puts the prediction token - TODO confirm.
        # A row holding only padding after the CLS id yields -1, i.e. the last position.
        maskPos = attentionMask.sum(dim=1) - 2
        rowIndex = torch.arange(maskPos.shape[0], device=maskPos.device)
        features = utteranceEncoded[rowIndex, maskPos, :]

        # emissions: (batch * max_turns, numClasses), rows in batch-major order.
        emissions = self.emission(features)
        # The CRF layer is fed turn-major tensors: (max_turns, batch, numClasses).
        crfEmissions = emissions.reshape(batchSize, maxTurns, -1).transpose(0, 1)
        sentencesMask = torch.transpose(sentencesMask, dim0=0, dim1=1)

        # test
        if emotionIdxes is None:
            return self.CRFlayer.decode(crfEmissions, mask=sentencesMask)

        # train
        # Flatten the labels BEFORE transposing so they line up with the
        # batch-major rows of `emissions`. Flattening the transposed labels
        # with .view(-1) fails on the non-contiguous tensor whenever batch > 1
        # and max_turns > 1, and would pair scores with the wrong labels.
        flatLabels = emotionIdxes.reshape(-1)
        crfLoss = -self.CRFlayer(crfEmissions, emotionIdxes.transpose(0, 1), mask=sentencesMask)
        classificationLoss = self.lossFunc(emissions.view(-1, self.numClasses), flatLabels)
        return crfLoss + classificationLoss

In [None]:
# Yunjeon version
class CRFModel(nn.Module):
    """Emotion tagger for dialogues: pretrained encoder + linear emission head + CRF.

    Every utterance is encoded on its own; one hidden state per utterance is
    projected to ``numClasses`` emission scores, and the per-dialogue score
    sequences are handed to the CRF layer.
    """

    def __init__(self, numClasses, dropout, bert_path):
        """Build all sub-modules.

        Args:
            numClasses: number of emotion labels (size of the emission output).
            dropout: stored on the instance; NOTE(review): never applied here.
            bert_path: name or path given to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.numClasses = numClasses
        self.dropout = dropout
        # Hard-coded pad token id.
        # NOTE(review): assumes the vocabulary at bert_path pads with id 1 - confirm.
        self.padValue = 1
        # First id produced when tokenizing the empty string; presumably the
        # tokenizer's start-of-sequence / CLS id.
        tokenizer = AutoTokenizer.from_pretrained(bert_path)
        self.CLS = tokenizer('')['input_ids'][0]
        self.encoder = AutoModel.from_pretrained(bert_path)
        # Width of the encoder's word-embedding table; input size of the emission head.
        self.dimension = self.encoder.embeddings.word_embeddings.weight.data.shape[-1]
        # NOTE(review): not used by forward(); kept so the module's parameters
        # stay unchanged.
        self.spkEmbeddings = nn.Embedding(300, self.dimension)
        self.CRFlayer = CRF(self.numClasses)
        self.emission = nn.Linear(self.dimension, self.numClasses)
        # Labels equal to -1 are excluded from the cross-entropy term.
        self.lossFunc = torch.nn.CrossEntropyLoss(ignore_index=-1)

    def device(self):
        """Return the device reported by the underlying encoder."""
        return self.encoder.device

    def forward(self, sentences, sentencesMask, speakerIds, lastTurns, emotionIdxes=None):
        """Score every utterance and either compute the training loss or decode.

        Args:
            sentences: token ids, batch * max_turns * max_length.
            sentencesMask: batch * max_turns; forwarded (transposed) to the CRF
                as its mask.
            speakerIds: batch * max_turns; only its shape/dtype are used, to
                build the column of CLS ids.
            lastTurns: accepted for interface compatibility; not used.
            emotionIdxes: optional labels, batch * max_turns.

        Returns:
            The negated CRF layer output plus the cross-entropy loss when
            ``emotionIdxes`` is given, otherwise the result of
            ``CRFlayer.decode``.
        """
        batchSize, maxTurns = sentences.shape[0], sentences.shape[1]
        numUtterances = batchSize * maxTurns

        # Flatten dialogues so every row is one utterance (row = b * maxTurns + t).
        speakerIdsReshaped = speakerIds.reshape(numUtterances, -1)
        sentencesReshaped = sentences.reshape(numUtterances, -1)

        # Prepend the CLS id to every utterance.
        clsId = torch.ones_like(speakerIdsReshaped) * self.CLS
        inputIds = torch.cat((clsId, sentencesReshaped), 1)
        # 1 for real tokens, 0 for padding.
        mask = 1 - (inputIds == (self.padValue)).long()

        utteranceEncoded = self.encoder(
            input_ids=inputIds,
            attention_mask=mask,
            output_hidden_states=True,
            return_dict=True
        )['last_hidden_state']

        # Second-to-last non-pad position of each row (assuming trailing padding);
        # presumably where the data pipeline puts the prediction token - TODO confirm.
        # A row holding only padding after the CLS id yields -1, i.e. the last position.
        maskPos = mask.sum(1) - 2
        rowIndex = torch.arange(maskPos.shape[0], device=maskPos.device)
        features = utteranceEncoded[rowIndex, maskPos, :]

        # emissions: (batch * max_turns, numClasses), rows in batch-major order.
        emissions = self.emission(features)
        # The CRF layer is fed turn-major tensors: (max_turns, batch, numClasses).
        crfEmissions = emissions.reshape(batchSize, maxTurns, -1).transpose(0, 1)
        sentencesMask = sentencesMask.transpose(0, 1)

        # test
        if emotionIdxes is None:
            return self.CRFlayer.decode(crfEmissions, mask=sentencesMask)

        # train
        # Flatten the labels BEFORE transposing so they line up with the
        # batch-major rows of `emissions`. Flattening the transposed labels
        # with .view(-1) fails on the non-contiguous tensor whenever batch > 1
        # and max_turns > 1, and would pair scores with the wrong labels.
        flatLabels = emotionIdxes.reshape(-1)
        crfLoss = -self.CRFlayer(crfEmissions, emotionIdxes.transpose(0, 1), mask=sentencesMask)
        classificationLoss = self.lossFunc(emissions.view(-1, self.numClasses), flatLabels)
        return crfLoss + classificationLoss

In [30]:
import torch

# Small float tensor used to try out element-wise comparisons.
a = torch.tensor([
    [2., -1., 0],
    [1., -1., 0],
])
print(a)
# Element-wise equality against the scalar 2: a boolean tensor shaped like `a`.
a.eq(2)

tensor([[ 2., -1.,  0.],
        [ 1., -1.,  0.]])


tensor([[ True, False, False],
        [False, False, False]])

In [12]:
# Cast to bool, then to int32: non-zero entries map to 1 and zeros map to 0.
a.bool().int()

tensor([[1, 1, 0],
        [1, 1, 0]], dtype=torch.int32)

In [14]:
# Element-wise comparison against 0; the functional spelling of `a == 0`.
torch.eq(a, 0)

tensor([[False, False,  True],
        [False, False,  True]])

In [33]:
# Build a 0/1 tensor: 0 where `a` equals 2, 1 everywhere else.
# NOTE: this rebinds `a`, so later cells see the 0/1 tensor, and running this
# cell a second time gives a different result (no entry equals 2 any more).
a = torch.where(a == 2, 0, 1)
a

tensor([[0, 1, 1],
        [1, 1, 1]])

In [11]:
# Sum along dim 1, i.e. across the columns of each row.
a_pos=a.sum(1)

In [12]:
# The summed dimension is dropped: one entry per row of `a`.
a_pos.shape

torch.Size([2])

In [14]:
# Functional spelling of `a.sum(1)`.
torch.sum(a,1)

tensor([1., 0.])

In [21]:
# Element-wise `1 - a`; this inverts a mask only when `a` holds nothing but 0s and 1s.
1-a

tensor([[0., 2., 1.],
        [0., 2., 1.]])

In [37]:
# `.long()` converts the boolean True into the integer 1.
torch.tensor(True).long()

tensor(1)