In [None]:
import os
from google.colab import drive

# Mount Google Drive so the annotation JSON and feature archives under MyDrive are reachable.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Copy the ECPE folder (holding the *_feats.tar.gz archives unpacked in the next cell)
# from Google Drive into the current working directory.
!cp -r /content/drive/MyDrive/ECPE/ . # takes ~ 1m 45 s

In [None]:
!tar -xvzf ECPE/audio_feats.tar.gz &> /dev/null
!tar -xvzf ECPE/video_feats.tar.gz &> /dev/null
!tar -xvzf ECPE/text_feats.tar.gz &> /dev/null
# ~ takes 3 mins

In [None]:
import json

# Subtask 2 training annotations; this path is also passed to build_loaders later.
anno_pth = "/content/drive/MyDrive/SemEval-2024_Task3/text/Subtask_2_train.json"
# `with` closes the file handle instead of leaking it until garbage collection.
with open(anno_pth, "r") as anno_file:
    anno = json.load(anno_file)

In [None]:
import torch
import pickle
import json
import numpy as np
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

class MultiModalDataset(Dataset):
    """One item per annotated conversation that has at least one emotion-cause pair.

    Labels come from the annotation JSON parsed by `read_data`; per-utterance video,
    audio and text features are loaded lazily from pickle files in `__getitem__`.
    """

    # Feature directories, relative to the working directory. NOTE(review): text features
    # are read from "text_features/" while video/audio use an "ECPE/" prefix — presumably
    # this mirrors how the tarballs unpack; confirm.
    video_dir = "ECPE/video_features/"
    audio_dir = "ECPE/audio_features/"
    text_dir = "text_features/"

    def __init__(self, data_pth):
        self.conv_utt_id_list, self.conv_couples_list, self.y_emotions_list, \
        self.y_causes_list, self.conv_len_list, self.conv_id_list \
        = self.read_data(data_pth)

    def __len__(self):
        return len(self.y_emotions_list)

    def __getitem__(self, idx):
        """Return labels plus padded per-utterance feature tensors for conversation `idx`.

        Returns:
            (conv_couples, y_emotions, y_causes, conv_len, conv_id,
             videos, audios, texts, v_lens, a_lens, t_lens)
        """
        conv_couples, y_emotions, y_causes = self.conv_couples_list[idx], self.y_emotions_list[idx], self.y_causes_list[idx]
        conv_len, conv_id = self.conv_len_list[idx], self.conv_id_list[idx]
        conv_utt_ids = self.conv_utt_id_list[idx]

        assert conv_len == len(y_emotions)
        assert conv_len == len(conv_utt_ids)

        v, a, t, v_lens, a_lens, t_lens = self.load_tensors(conv_id, conv_utt_ids)

        return conv_couples, y_emotions, y_causes, conv_len, conv_id, \
               v, a, t, v_lens, a_lens, t_lens

    @staticmethod
    def _load_pickle(path):
        """Unpickle one feature file, closing the handle afterwards."""
        # SECURITY: pickle.load can execute arbitrary code; only use trusted feature files.
        with open(path, "rb") as f:
            return pickle.load(f)

    def load_tensors(self, conv_id, utt_ids):
        """Load and pad every utterance's features for one conversation.

        Each modality's per-utterance sequences are padded with `pad_sequence(batch_first=True)`
        to the longest utterance; the unpadded lengths are returned alongside.

        Returns:
            (videos, audios, texts, v_lens, a_lens, t_lens)
        """
        videos, v_lens = [], []
        audios, a_lens = [], []
        texts, t_lens = [], []

        for utt in utt_ids:
            fname = 'dia' + str(conv_id) + 'utt' + str(utt) + '.pkl'

            v = torch.tensor(self._load_pickle(self.video_dir + fname), dtype=torch.float)
            a = self._load_pickle(self.audio_dir + fname).detach()
            t = self._load_pickle(self.text_dir + fname).squeeze().detach()

            videos.append(v); v_lens.append(len(v))
            audios.append(a); a_lens.append(len(a))
            texts.append(t); t_lens.append(len(t))

        videos = pad_sequence(videos, batch_first=True)
        audios = pad_sequence(audios, batch_first=True)
        texts = pad_sequence(texts, batch_first=True)

        return videos, audios, texts, v_lens, a_lens, t_lens

    def get_mask(self, seq, lens):
        """Boolean (n, max_len) mask, True for the first `lens[i]` positions of row i."""
        n, max_len, dim = seq.size()
        mask = np.zeros([n, max_len])
        for idx, seq_len in enumerate(lens):
            mask[idx][:seq_len] = 1

        return torch.BoolTensor(mask)

    def read_data(self, data_pth):
        """Parse the annotation JSON into parallel per-conversation lists.

        Conversations with an empty "emotion-cause_pairs" list are skipped. Pair entries
        look like ["<emotion_utt>_<label>", "<cause_utt>"] and become [emotion_utt, cause_utt]
        ints. y_emotions[i] / y_causes[i] is 1 when utterance i + 1 appears as an emotion /
        cause in any pair.

        Returns:
            (conv_utt_id_list, conv_couples_list, y_emotions_list, y_causes_list,
             conv_len_list, conv_id_list)
        """
        with open(data_pth, "r") as f:
            data = json.load(f)
        conv_id_list = []
        conv_len_list = []
        conv_utt_id_list = []
        conv_couples_list = []
        y_emotions_list, y_causes_list = [], []

        for conv in data:
            if len(conv["emotion-cause_pairs"]) != 0:
                conv_id_list.append(conv["conversation_ID"])
                utterances = conv["conversation"]
                conv_len = len(utterances)
                conv_len_list.append(conv_len)
                conv_utt_id_list.append([u['utterance_ID'] for u in utterances])

                couples = conv["emotion-cause_pairs"]

                conv_couples = [[int(e.split('_')[0]), int(c)] for e, c in couples]
                conv_emotions, conv_causes = zip(*conv_couples)
                conv_couples_list.append(conv_couples)

                # Utterance positions are 1-based, hence `i + 1`.
                y_emotions, y_causes = [], []
                for i in range(conv_len):
                    emotion_label = int(i + 1 in conv_emotions)
                    cause_label = int(i + 1 in conv_causes)
                    y_emotions.append(emotion_label)
                    y_causes.append(cause_label)

                y_emotions_list.append(y_emotions)
                y_causes_list.append(y_causes)

        return conv_utt_id_list, conv_couples_list, y_emotions_list, y_causes_list, conv_len_list, conv_id_list

In [None]:
import scipy.sparse as sp

TORCH_SEED = 42
# Module-level generator kept for backward compatibility; build_loaders seeds its own.
gen = torch.Generator().manual_seed(TORCH_SEED)

def build_loaders(configs, anno_pth, shuffle=True, val_ratio=0.2, seed=TORCH_SEED):
    """Build train/validation DataLoaders from a single annotation file.

    Args:
        configs: object exposing `batch_size`.
        anno_pth: annotation JSON path passed to `MultiModalDataset`.
        shuffle: whether the training loader shuffles.
        val_ratio: fraction of conversations held out for validation.
        seed: seed for the train/validation split. A fresh generator is created per call,
            so calling this again (e.g. re-running the cell) yields the same split instead
            of drawing a new one from an already-advanced shared generator.

    Returns:
        (train_loader, val_loader), both collated with `batch_preprocessing`.
    """
    dataset = MultiModalDataset(anno_pth)
    split_gen = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [1-val_ratio, val_ratio], split_gen)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=configs.batch_size,
                                               shuffle=shuffle, collate_fn=batch_preprocessing)

    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=configs.batch_size,
                                             shuffle=False, collate_fn=batch_preprocessing)

    return train_loader, val_loader

def batch_preprocessing(batch):
    """DataLoader collate function for MultiModalDataset items.

    Pads labels and token tensors to the longest conversation in the batch, builds the
    block adjacency matrices, and derives boolean token masks (True = real token).

    Returns:
        (conv_len, adj, y_emotions, y_causes, y_mask) as NumPy arrays, followed by
        conv_couples, conv_ids, the padded video/audio/text token tensors and their masks.
    """
    (conv_couples_b, y_emotions_b, y_causes_b, conv_len_b, conv_id_b,
     v_token_b, a_token_b, t_token_b, v_lens_b, a_lens_b, t_lens_b) = zip(*batch)

    y_mask_b, y_emotions_b, y_causes_b = pad_convs(conv_len_b, y_emotions_b, y_causes_b)
    adj_b = pad_matrices(conv_len_b)

    v_token_b, a_token_b, t_token_b = (
        pad_conversations(list(tokens), conv_len_b)
        for tokens in (v_token_b, a_token_b, t_token_b)
    )
    v_mask, a_mask, t_mask = (
        get_mask(tokens, lens)
        for tokens, lens in ((v_token_b, v_lens_b), (a_token_b, a_lens_b), (t_token_b, t_lens_b))
    )

    label_arrays = tuple(np.array(x) for x in (conv_len_b, adj_b, y_emotions_b, y_causes_b, y_mask_b))
    return label_arrays + (conv_couples_b, conv_id_b,
                           v_token_b, a_token_b, t_token_b, v_mask, a_mask, t_mask)

def pad_conversations(seq_tokens, conv_lens):
    """Zero-pad per-conversation token tensors to a common size and stack them.

    Args:
        seq_tokens: sequence of tensors, each (utterances, tokens, dim). The sequence
            itself is not modified.
        conv_lens: utterance count per conversation; its max sets the padded utterance axis.

    Returns:
        Tensor of shape (len(seq_tokens), max(conv_lens), max tokens, dim).
    """
    # Local import: at notebook level `F` is only imported in a later cell.
    import torch.nn.functional as F

    num_utt = max(conv_lens)
    num_tokens = max(s.size(1) for s in seq_tokens)

    padded = []
    for seq in seq_tokens:
        cur_utt, cur_len, _ = seq.size()
        # F.pad pairs run from the last dim backwards: (dim), (tokens), (utterances).
        pad = (0, 0, 0, num_tokens - cur_len, 0, num_utt - cur_utt)
        padded.append(F.pad(seq, pad, "constant", 0))

    return torch.stack(padded)

def get_mask(seq, lens):
    """Boolean (num_conv, num_utt, max_len) mask; True marks the real (unpadded) tokens.

    Args:
        seq: 4-D padded token tensor (num_conv, num_utt, max_len, dim).
        lens: lens[c][u] is the true token count of utterance u in conversation c.
    """
    num_conv, num_utt, max_len, _ = seq.size()
    mask = torch.zeros((num_conv, num_utt, max_len), dtype=torch.bool)
    for c_idx, utt_lens in enumerate(lens):
        for u_idx, n_tokens in enumerate(utt_lens):
            mask[c_idx, u_idx, :n_tokens] = True
    return mask

def pad_convs(conv_len_b, y_emotions_b, y_causes_b):
    """Pad per-conversation label lists with -1 to the longest conversation.

    Returns:
        (y_mask_b, y_emotions_b_, y_causes_b_) where y_mask is 1 for real utterances and
        0 for padding (positions whose padded emotion label is -1).
    """
    target_len = max(conv_len_b)

    y_mask_b, y_emotions_b_, y_causes_b_ = [], [], []
    for emotions, causes in zip(y_emotions_b, y_causes_b):
        padded_e = pad_list(emotions, target_len, -1)
        padded_c = pad_list(causes, target_len, -1)
        y_emotions_b_.append(padded_e)
        y_causes_b_.append(padded_c)
        y_mask_b.append([int(label != -1) for label in padded_e])

    return y_mask_b, y_emotions_b_, y_causes_b_


def pad_matrices(conv_len_b):
    """Per-conversation float32 adjacency: an all-ones conv_len x conv_len block in the
    top-left corner of an N x N zero matrix, N = longest conversation in the batch."""
    size = max(conv_len_b)
    matrices = []
    for n in conv_len_b:
        adj = np.zeros((size, size), dtype=np.float32)
        adj[:n, :n] = 1
        matrices.append(adj)
    return matrices


def pad_list(element_list, max_len, pad_mark):
    """Return a copy of `element_list` right-padded with `pad_mark` up to `max_len`.

    Lists already at or beyond `max_len` are copied unchanged (never truncated).
    """
    shortfall = max_len - len(element_list)
    padded = element_list[:]
    padded.extend(pad_mark for _ in range(shortfall))
    return padded

In [None]:
!pip -q install einops transformers

In [None]:
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat, pack
from enum import Enum

class TokenTypes(Enum):
    """Integer tags for the source stream of each token in Zorro_AVT's packed sequence.

    Zorro_AVT compares these `.value`s pairwise to build its attention mask: tokens with
    equal tags may attend to each other, and FUSION tokens may attend to every token.
    """
    VIDEO = 0
    AUDIO = 1
    TEXT = 2
    FUSION = 3

class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learnable scale and a fixed zero bias.

    `beta` is a buffer rather than a parameter, so it is saved in the state dict but
    never updated by the optimiser.
    """
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight = self.gamma, bias = self.beta)

class Attention(nn.Module):
    """Pre-norm multi-head self-attention.

    forward(x, pad_mask=None, attn_mask=None):
        x: (b, n, dim) tokens.
        pad_mask: boolean, broadcast against the (b, heads, n, n) score tensor; positions
            where it is True are suppressed.
        attn_mask: boolean, broadcast the same way; positions where it is False are
            suppressed.
    Returns a (b, n, dim) tensor.
    """
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 1
    ):
        super().__init__()
        hidden_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        # Submodules are created in this order so parameter initialisation draws from the
        # RNG in the same sequence.
        self.norm = LayerNorm(dim)
        self.to_q = nn.Linear(dim, hidden_dim, bias = False)
        self.to_kv = nn.Linear(dim, hidden_dim * 2, bias = False)
        self.to_out = nn.Linear(hidden_dim, dim, bias = False)

    def forward(
        self,
        x,
        pad_mask = None,
        attn_mask = None
    ):
        normed = self.norm(x)

        key, value = self.to_kv(normed).chunk(2, dim = -1)
        query = self.to_q(normed)
        split_heads = lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads)
        query, key, value = split_heads(query), split_heads(key), split_heads(value)

        scores = einsum('b h i d, b h j d -> b h i j', query * self.scale, key)
        fill_value = -torch.finfo(scores.dtype).max
        if pad_mask is not None:
            scores = scores.masked_fill(pad_mask, fill_value)
        if attn_mask is not None:
            scores = scores.masked_fill(~attn_mask, fill_value)

        weights = scores.softmax(dim = -1)
        mixed = einsum('b h i j, b h j d -> b h i d', weights, value)
        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))

class GEGLU(nn.Module):
    """GELU-gated linear unit: split the last axis in two and gate the first half with
    GELU of the second half."""
    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim = -1)
        return F.gelu(gate) * value

def FeedForward(dim, mult = 4):
    """Pre-norm GEGLU feed-forward block mapping `dim` -> `dim`.

    The hidden width is int(dim * mult * 2 / 3); the first projection doubles it so
    GEGLU can split value and gate halves.
    """
    hidden = int(dim * mult * 2 / 3)
    stages = [
        LayerNorm(dim),
        nn.Linear(dim, hidden * 2, bias = False),
        GEGLU(),
        nn.Linear(hidden, dim, bias = False),
    ]
    return nn.Sequential(*stages)

class Zorro_AVT(nn.Module):
    """Zorro-style audio/video/text fusion encoder.

    Video, audio and text token sequences are packed together with learned fusion tokens
    and passed through `depth` pre-norm attention + feed-forward layers. The Zorro mask lets
    each modality attend only to its own tokens while fusion tokens attend to everything.
    The mean of the final fusion tokens is projected to `out_dim`.

    forward(video_tokens, audio_tokens, text_tokens, video_mask, audio_mask, text_mask):
        *_tokens: (batch, ..., dim) per modality (middle axes are flattened).
        *_mask: boolean (batch, tokens), True for REAL tokens and False for padding, as
            produced by `get_mask`.
    """
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        num_fusion_tokens = 8,
        out_dim = 768
    ):
        super().__init__()

        # fusion tokens. Created on CPU; the parent module's `.to(device)` moves them, so
        # the model also builds on machines without CUDA.
        self.num_fusion_tokens = num_fusion_tokens
        self.fusion_tokens = nn.Parameter(torch.randn(num_fusion_tokens, dim))
        # Fusion tokens are never padding (True = real token, same convention as get_mask).
        # A non-persistent buffer follows `.to(device)` without adding a state_dict entry.
        self.register_buffer("fusion_mask", torch.ones(num_fusion_tokens, dtype = torch.bool), persistent = False)

        # transformer
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.norm = LayerNorm(dim)
        self.out_layer = nn.Linear(dim, out_dim, bias = False)

    def forward(
        self,
        video_tokens,
        audio_tokens,
        text_tokens,
        video_mask,
        audio_mask,
        text_mask
    ):
        batch, device = video_tokens.shape[0], video_tokens.device

        fusion_tokens = repeat(self.fusion_tokens, 'n d -> b n d', b = batch)
        fusion_mask = repeat(self.fusion_mask, 'n -> b n', b = batch)

        # construct all tokens
        video_tokens, audio_tokens, text_tokens, fusion_tokens = \
        map(lambda t:
            rearrange(t, 'b ... d -> b (...) d'),
            (video_tokens, audio_tokens, text_tokens, fusion_tokens))

        tokens, ps = pack((
            video_tokens,
            audio_tokens,
            text_tokens,
            fusion_tokens
        ), 'b * d')

        # construct mask (thus zorro)
        token_types = torch.tensor(list((
            *((TokenTypes.VIDEO.value,) * video_tokens.shape[-2]),
            *((TokenTypes.AUDIO.value,) * audio_tokens.shape[-2]),
            *((TokenTypes.TEXT.value,) * text_tokens.shape[-2]),
            *((TokenTypes.FUSION.value,) * fusion_tokens.shape[-2]),
        )), device = device, dtype = torch.long)

        token_types_attend_from = rearrange(token_types, 'i -> i 1')
        token_types_attend_to = rearrange(token_types, 'j -> 1 j')

        # every modality, including fusion, can attend to itself
        zorro_mask = token_types_attend_from == token_types_attend_to

        # fusion can attend to everything
        zorro_mask = zorro_mask | (token_types_attend_from == TokenTypes.FUSION.value)

        # construct padding mask. The modality masks and fusion_mask are True for REAL
        # tokens, but Attention suppresses positions where `pad_mask` is True, so invert.
        # Without the inversion every real token was masked out and only padding attended.
        keep_mask = torch.cat((video_mask, audio_mask, text_mask, fusion_mask), -1)
        pad_mask = (~keep_mask)[:, None, None, :]

        # attend and feedforward (pre-norm residual blocks)
        for attn, ff in self.layers:
            tokens = attn(tokens, pad_mask = pad_mask, attn_mask = zorro_mask) + tokens
            tokens = ff(tokens) + tokens

        tokens = self.norm(tokens)
        fusion_tokens = tokens[:, -self.num_fusion_tokens:, :]
        pooled_tokens = torch.mean(fusion_tokens, 1)

        return self.out_layer(pooled_tokens)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

# Compute device used by the model and loss helpers: first CUDA device when present, else CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class GraphAttentionLayer(nn.Module):
    """
    Gated multi-head graph attention layer over utterance nodes.

    reference: https://github.com/xptree/DeepInf

    `in_dim` must equal `out_dim * att_head`: the heads are concatenated back to `in_dim`
    and mixed with the input through a sigmoid gate (a highway-style residual).
    """
    def __init__(self, att_head, in_dim, out_dim, dp_gnn, leaky_alpha=0.2):
        super(GraphAttentionLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dp_gnn = dp_gnn

        self.att_head = att_head
        self.W = nn.Parameter(torch.Tensor(self.att_head, self.in_dim, self.out_dim))
        self.b = nn.Parameter(torch.Tensor(self.out_dim))

        self.w_src = nn.Parameter(torch.Tensor(self.att_head, self.out_dim, 1))
        self.w_dst = nn.Parameter(torch.Tensor(self.att_head, self.out_dim, 1))
        self.leaky_alpha = leaky_alpha
        self.init_gnn_param()

        assert self.in_dim == self.out_dim*self.att_head
        self.H = nn.Linear(self.in_dim, self.in_dim)
        init.xavier_normal_(self.H.weight)

    def init_gnn_param(self):
        init.xavier_uniform_(self.W.data)
        init.zeros_(self.b.data)
        init.xavier_uniform_(self.w_src.data)
        init.xavier_uniform_(self.w_dst.data)

    def forward(self, feat_in, adj=None):
        """
        Args:
            feat_in: (batch, N, in_dim) node features.
            adj: array-like (batch, N, N); attention from node i to node j is allowed only
                where adj[b, i, j] == 1.

        Returns:
            (batch, N, in_dim) updated node features.
        """
        batch, N, in_dim = feat_in.size()
        assert in_dim == self.in_dim

        feat_in_ = feat_in.unsqueeze(1)
        h = torch.matmul(feat_in_, self.W)                 # (batch, heads, N, out_dim)

        h_act = torch.tanh(h)                              # computed once, used for both scores
        attn_src = torch.matmul(h_act, self.w_src)         # (batch, heads, N, 1)
        attn_dst = torch.matmul(h_act, self.w_dst)
        attn = attn_src.expand(-1, -1, -1, N) + attn_dst.expand(-1, -1, -1, N).permute(0, 1, 3, 2)
        attn = F.leaky_relu(attn, self.leaky_alpha)

        # Follow the input's device instead of a global; out-of-place masked_fill keeps the
        # mask visible to autograd (the old `.data.masked_fill_` bypassed it).
        adj = torch.as_tensor(adj, dtype=torch.float, device=feat_in.device)
        mask = (1 - adj.unsqueeze(1)).bool()
        attn = attn.masked_fill(mask, -999)

        attn = F.softmax(attn, dim=-1)
        feat_out = torch.matmul(attn, h) + self.b

        feat_out = feat_out.transpose(1, 2).contiguous().view(batch, N, -1)
        feat_out = F.elu(feat_out)

        gate = torch.sigmoid(self.H(feat_in))
        feat_out = gate * feat_out + (1 - gate) * feat_in

        feat_out = F.dropout(feat_out, self.dp_gnn, training=self.training)

        return feat_out

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_dim) + ' -> ' + str(self.out_dim*self.att_head) + ')'

class Network(nn.Module):
    """Multimodal emotion-cause pair extraction model.

    forward: per-utterance Zorro_AVT fusion of video/audio/text tokens -> GraphNN over the
    utterances -> per-utterance emotion/cause logits (Pre_Predictions) and candidate-pair
    ranking logits (RankNN).
    """
    def __init__(self, configs):
        super(Network, self).__init__()

        # NOTE(review): Zorro_AVT(1024, 8) means dim=1024, depth=8; its default out_dim (768)
        # presumably has to match configs.feat_dim consumed by GraphNN — confirm.
        self.multimodal = Zorro_AVT(1024, 8)
        self.gnn = GraphNN(configs)
        self.pred = Pre_Predictions(configs)
        self.rank = RankNN(configs)
        self.pairwise_loss = configs.pairwise_loss

    def forward(self, conv_len, adj, v_tokens, a_tokens, t_tokens, v_mask, a_mask, t_mask):
        """
        Args:
            conv_len: utterance count per conversation.
            adj: adjacency matrices for GraphNN.
            v_tokens, a_tokens, t_tokens: 4-D token tensors indexed [:, utterance, :, :].
            v_mask, a_mask, t_mask: 3-D boolean masks indexed [:, utterance, :].

        Returns:
            (couples_pred, emo_cau_pos, pred_e, pred_c)
        """
        num_utt = v_tokens.size()[1]
        conv_utt_h = []
        for i in range(num_utt):
            v, a, t = v_tokens[:,i,:,:], a_tokens[:,i,:,:], t_tokens[:,i,:,:]
            v_m, a_m, t_m = v_mask[:,i,:], a_mask[:,i,:], t_mask[:,i,:]
            out = self.multimodal(v, a, t, v_m, a_m, t_m)
            conv_utt_h.append(out)
        conv_utt_h = torch.stack(conv_utt_h, 1)

        conv_utt_h = self.gnn(conv_utt_h, conv_len, adj)
        pred_e, pred_c = self.pred(conv_utt_h)

        couples_pred, emo_cau_pos = self.rank(conv_utt_h)

        return couples_pred, emo_cau_pos, pred_e, pred_c

    def loss_rank(self, couples_pred, emo_cau_pos, doc_couples, y_mask, test=False):
        """Pair-ranking loss (BCE, or margin ranking when `pairwise_loss`) and, when
        `test` is True, the top-ranked pairs per conversation."""
        couples_true, couples_mask, doc_couples_pred = self.output_util(couples_pred, emo_cau_pos, doc_couples, y_mask, test)

        if not self.pairwise_loss:
            couples_mask = torch.BoolTensor(couples_mask).to(DEVICE)
            couples_true = torch.FloatTensor(couples_true).to(DEVICE)
            criterion = nn.BCEWithLogitsLoss(reduction='mean')
            couples_true = couples_true.masked_select(couples_mask)
            couples_pred = couples_pred.masked_select(couples_mask)
            loss_couple = criterion(couples_pred, couples_true)
        else:
            x1, x2, y = self.pairwise_util(couples_pred, couples_true, couples_mask)
            criterion = nn.MarginRankingLoss(margin=1.0, reduction='mean')
            loss_couple = criterion(torch.tanh(x1), torch.tanh(x2), y)

        return loss_couple, doc_couples_pred

    def output_util(self, couples_pred, emo_cau_pos, doc_couples, y_mask, test=False, top_k=3):
        """Build per-conversation pair targets/masks and (if `test`) top-k predictions.

        A candidate pair is valid when both positions are <= the conversation's real length
        (sum of its y_mask). When `test` is True, up to `top_k` valid pairs are returned per
        conversation as (pair, score) tuples, best first; fewer when the conversation has
        fewer valid pairs, and padding pairs are never returned.

        TODO: combine this function to data_loader
        """
        batch, n_couple = couples_pred.size()

        couples_true, couples_mask = [], []
        doc_couples_pred = []
        for i in range(batch):
            y_mask_i = y_mask[i]
            max_doc_idx = sum(y_mask_i)

            doc_couples_i = doc_couples[i]
            couples_true_i = []
            couples_mask_i = []
            for couple_idx, emo_cau in enumerate(emo_cau_pos):
                if emo_cau[0] > max_doc_idx or emo_cau[1] > max_doc_idx:
                    couples_mask_i.append(0)
                    couples_true_i.append(0)
                else:
                    couples_mask_i.append(1)
                    couples_true_i.append(1 if emo_cau in doc_couples_i else 0)

            couples_pred_i = couples_pred[i]
            doc_couples_pred_i = []
            if test:
                valid = torch.tensor(couples_mask_i, dtype=torch.bool, device=couples_pred_i.device)
                # topk fails if k exceeds the number of candidates (e.g. a 1-utterance conversation).
                k = min(top_k, int(valid.sum()))
                if torch.isnan(couples_pred_i).any():
                    k_idx = [0] * k
                else:
                    scores = couples_pred_i.masked_fill(~valid, float('-inf'))
                    _, k_idx = torch.topk(scores, k=k, dim=0)
                doc_couples_pred_i = [(emo_cau_pos[idx], couples_pred_i[idx].tolist()) for idx in k_idx]

            couples_true.append(couples_true_i)
            couples_mask.append(couples_mask_i)
            doc_couples_pred.append(doc_couples_pred_i)
        return couples_true, couples_mask, doc_couples_pred

    def loss_pre(self, pred_e, pred_c, y_emotions, y_causes, y_mask):
        """Masked BCE losses for the per-utterance emotion and cause logits."""
        y_mask = torch.BoolTensor(y_mask).to(DEVICE)
        y_emotions = torch.FloatTensor(y_emotions).to(DEVICE)
        y_causes = torch.FloatTensor(y_causes).to(DEVICE)

        criterion = nn.BCEWithLogitsLoss(reduction='mean')
        pred_e = pred_e.masked_select(y_mask)
        true_e = y_emotions.masked_select(y_mask)
        loss_e = criterion(pred_e, true_e)

        pred_c = pred_c.masked_select(y_mask)
        true_c = y_causes.masked_select(y_mask)
        loss_c = criterion(pred_c, true_c)
        return loss_e, loss_c

    def pairwise_util(self, couples_pred, couples_true, couples_mask):
        """Pair every positive score with every negative score per conversation, for
        MarginRankingLoss.

        TODO: efficient re-implementation; combine this function to data_loader
        NOTE(review): raises if a conversation has no valid positive or no valid negative pair.
        """
        batch, n_couple = couples_pred.size()
        x1, x2 = [], []
        for i in range(batch):
            x1_i_tmp = []
            x2_i_tmp = []
            couples_mask_i = couples_mask[i]
            couples_pred_i = couples_pred[i]
            couples_true_i = couples_true[i]
            for pred_ij, true_ij, mask_ij in zip(couples_pred_i, couples_true_i, couples_mask_i):
                if mask_ij == 1:
                    if true_ij == 1:
                        x1_i_tmp.append(pred_ij.reshape(-1, 1))
                    else:
                        x2_i_tmp.append(pred_ij.reshape(-1))
            m = len(x2_i_tmp)
            n = len(x1_i_tmp)
            x1_i = torch.cat([torch.cat(x1_i_tmp, dim=0)] * m, dim=1).reshape(-1)
            x1.append(x1_i)
            x2_i = []
            for _ in range(n):
                x2_i.extend(x2_i_tmp)
            x2_i = torch.cat(x2_i, dim=0)
            x2.append(x2_i)

        x1 = torch.cat(x1, dim=0)
        x2 = torch.cat(x2, dim=0)
        y = torch.FloatTensor([1] * x1.size(0)).to(DEVICE)
        return x1, x2, y


class GraphNN(nn.Module):
    """Stack of GraphAttentionLayer modules applied to per-utterance features.

    configs.gnn_dims / configs.att_heads are comma-separated per-layer output widths and
    head counts; layer i's input width is configs.feat_dim for i == 0, otherwise
    gnn_dims[i] * att_heads[i - 1].
    """
    def __init__(self, configs):
        super(GraphNN, self).__init__()
        layer_dims = [int(d) for d in configs.gnn_dims.strip().split(',')]
        self.gnn_dims = [configs.feat_dim] + layer_dims
        self.gnn_layers = len(layer_dims)
        self.att_heads = [int(h) for h in configs.att_heads.strip().split(',')]

        self.gnn_layer_stack = nn.ModuleList()
        for i in range(self.gnn_layers):
            if i == 0:
                layer_in = self.gnn_dims[0]
            else:
                layer_in = self.gnn_dims[i] * self.att_heads[i - 1]
            self.gnn_layer_stack.append(
                GraphAttentionLayer(self.att_heads[i], layer_in, self.gnn_dims[i + 1], configs.dp)
            )

    def forward(self, doc_sents_h, doc_len, adj):
        """Run every GAT layer over (batch, max_doc_len, dim) features with adjacency `adj`."""
        _, max_doc_len, _ = doc_sents_h.size()
        assert max(doc_len) == max_doc_len

        for layer in self.gnn_layer_stack:
            doc_sents_h = layer(doc_sents_h, adj)
        return doc_sents_h


class RankNN(nn.Module):
    """Scores candidate (emotion utterance, cause utterance) pairs.

    Candidates are ordered utterance pairs with |cause - emotion| <= K (all pairs when the
    conversation has at most K + 1 utterances). Each pair's feature is
    [h_emotion; h_cause; kernel-smoothed relative-position embedding], scored by a
    two-layer MLP.
    """
    def __init__(self, configs):
        super(RankNN, self).__init__()
        self.K = configs.K
        self.pos_emb_dim = configs.pos_emb_dim
        self.pos_layer = nn.Embedding(2*self.K + 1, self.pos_emb_dim)
        nn.init.xavier_uniform_(self.pos_layer.weight)

        self.feat_dim = int(configs.gnn_dims.strip().split(',')[-1]) * int(configs.att_heads.strip().split(',')[-1])
        self.rank_feat_dim = 2*self.feat_dim + self.pos_emb_dim
        self.rank_layer1 = nn.Linear(self.rank_feat_dim, self.rank_feat_dim)
        self.rank_layer2 = nn.Linear(self.rank_feat_dim, 1)

    def forward(self, doc_sents_h):
        """Return (pair logits of shape (batch, n_couple), list of [emotion, cause] 1-based positions)."""
        batch, _, _ = doc_sents_h.size()
        couples, rel_pos, emo_cau_pos = self.couple_generator(doc_sents_h, self.K)

        rel_pos = rel_pos + self.K                       # shift into [0, 2K] for the embedding
        rel_pos_emb = self.pos_layer(rel_pos)
        kernel = self.kernel_generator(rel_pos)
        kernel = kernel.unsqueeze(0).expand(batch, -1, -1)
        rel_pos_emb = torch.matmul(kernel, rel_pos_emb)
        couples = torch.cat([couples, rel_pos_emb], dim=2)

        couples = F.relu(self.rank_layer1(couples))
        couples_pred = self.rank_layer2(couples)
        return couples_pred.squeeze(2), emo_cau_pos

    def couple_generator(self, H, k):
        """Enumerate candidate pairs (emotion-major order) and their concatenated features.

        Returns:
            P: (batch, n_couple, 2 * feat_dim) pair features.
            rel_pos: (batch, n_couple) long tensor of cause - emotion offsets.
            emo_cau_pos: list of [emotion, cause] 1-based positions.
        """
        batch, seq_len, feat_dim = H.size()
        device = H.device
        P_left = torch.cat([H] * seq_len, dim=2)
        P_left = P_left.reshape(-1, seq_len * seq_len, feat_dim)
        P_right = torch.cat([H] * seq_len, dim=1)
        P = torch.cat([P_left, P_right], dim=2)

        base_idx = np.arange(1, seq_len + 1)
        emo_pos = np.concatenate([base_idx.reshape(-1, 1)] * seq_len, axis=1).reshape(1, -1)[0]
        cau_pos = np.concatenate([base_idx] * seq_len, axis=0)

        rel_pos = torch.as_tensor(cau_pos - emo_pos, dtype=torch.long, device=device)
        emo_pos = torch.as_tensor(emo_pos, dtype=torch.long, device=device)
        cau_pos = torch.as_tensor(cau_pos, dtype=torch.long, device=device)

        if seq_len > k + 1:
            # Built directly as a bool tensor: the former `dtype=np.int` was removed in NumPy 1.24.
            rel_mask = (rel_pos >= -k) & (rel_pos <= k)
            rel_pos = rel_pos.masked_select(rel_mask)
            emo_pos = emo_pos.masked_select(rel_mask)
            cau_pos = cau_pos.masked_select(rel_mask)

            rel_mask = rel_mask.unsqueeze(1).expand(-1, 2 * feat_dim)
            rel_mask = rel_mask.unsqueeze(0).expand(batch, -1, -1)
            P = P.masked_select(rel_mask)
            P = P.reshape(batch, -1, 2 * feat_dim)
        assert rel_pos.size(0) == P.size(1)
        rel_pos = rel_pos.unsqueeze(0).expand(batch, -1)

        emo_cau_pos = [[emo, cau] for emo, cau in zip(emo_pos.tolist(), cau_pos.tolist())]
        return P, rel_pos, emo_cau_pos

    def kernel_generator(self, rel_pos):
        """Gaussian similarity exp(-(r_i - r_j)^2) between the candidates' relative positions."""
        rel_pos_ = rel_pos[0].float()
        kernel = rel_pos_.unsqueeze(1) - rel_pos_.unsqueeze(0)
        return torch.exp(-(torch.pow(kernel, 2)))


class Pre_Predictions(nn.Module):
    """Per-utterance emotion and cause logits from the last GNN layer's features.

    The feature width is the last entry of configs.gnn_dims times the last entry of
    configs.att_heads.
    """
    def __init__(self, configs):
        super(Pre_Predictions, self).__init__()
        last_dim = int(configs.gnn_dims.strip().split(',')[-1])
        last_heads = int(configs.att_heads.strip().split(',')[-1])
        self.feat_dim = last_dim * last_heads
        self.out_e = nn.Linear(self.feat_dim, 1)
        self.out_c = nn.Linear(self.feat_dim, 1)

    def forward(self, doc_sents_h):
        """Map (batch, num_utt, feat_dim) features to two (batch, num_utt) logit tensors."""
        logits_e, logits_c = (head(doc_sents_h).squeeze(2) for head in (self.out_e, self.out_c))
        return logits_e, logits_c

In [None]:
import pickle, json, decimal, math


def to_np(x):
    """Return tensor `x` as a NumPy array on the host, detached from autograd."""
    return x.detach().cpu().numpy()


def logistic(x):
    """Logistic sigmoid 1 / (1 + e^-x) of a Python scalar.

    Branches on the sign so `math.exp` only ever sees a non-positive argument; the naive
    form raised OverflowError for large negative x.
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    z = math.exp(x)
    return z / (1 + z)


def eval_func(doc_couples_all, doc_couples_pred_all):
    """Micro-averaged precision/recall/F1 for pairs ('ec'), emotions ('e') and causes ('c').

    Pairs are compared as comma-joined strings; the emotion / cause sets are the first /
    second component of each pair. A 1e-8 term in every denominator avoids division by zero.

    Returns:
        ((p, r, f) for pairs, (p, r, f) for emotions, (p, r, f) for causes)
    """
    def as_keys(couples):
        return {','.join(str(v) for v in couple) for couple in couples}

    def column(keys, idx):
        return {key.split(',')[idx] for key in keys}

    tasks = ('ec', 'e', 'c')
    hits = dict.fromkeys(tasks, 0)
    n_pred = dict.fromkeys(tasks, 0)
    n_gold = dict.fromkeys(tasks, 0)

    for gold, pred in zip(doc_couples_all, doc_couples_pred_all):
        gold_keys, pred_keys = as_keys(gold), as_keys(pred)
        per_task = (
            ('ec', gold_keys, pred_keys),
            ('e', column(gold_keys, 0), column(pred_keys, 0)),
            ('c', column(gold_keys, 1), column(pred_keys, 1)),
        )
        for task, gold_set, pred_set in per_task:
            hits[task] += len(gold_set & pred_set)
            n_pred[task] += len(pred_set)
            n_gold[task] += len(gold_set)

    scores = []
    for task in tasks:
        p = hits[task] / (n_pred[task] + 1e-8)
        r = hits[task] / (n_gold[task] + 1e-8)
        scores.append((p, r, 2 * p * r / (p + r + 1e-8)))
    return tuple(scores)

In [None]:
def inference_one_batch(batch, model):
    """Run the model and both losses on one collated batch without tracking gradients.

    The collate function yields CPU tensors, so the token tensors and masks are moved to
    the model's device first (the training loop does the same).

    Returns:
        (loss_couple, loss_e, loss_c) as NumPy values, then the gold couples, the top-ranked
        predicted couples, and the conversation ids.
    """
    conv_len, adj_b, y_emotions_b, y_causes_b, y_mask_b, conv_couples_b, conv_id_b, \
    v_token_b, a_token_b, t_token_b, v_mask, a_mask, t_mask = batch

    device = next(model.parameters()).device
    v_token_b, a_token_b, t_token_b, v_mask, a_mask, t_mask = (
        x.to(device) for x in (v_token_b, a_token_b, t_token_b, v_mask, a_mask, t_mask))

    with torch.no_grad():
        couples_pred, emo_cau_pos, pred_e, pred_c = model(conv_len, adj_b,
                                                          v_token_b, a_token_b, t_token_b,
                                                          v_mask, a_mask, t_mask)

        loss_e, loss_c = model.loss_pre(pred_e, pred_c, y_emotions_b, y_causes_b, y_mask_b)
        loss_couple, conv_couples_pred_b = model.loss_rank(couples_pred, emo_cau_pos, conv_couples_b, y_mask_b, test=True)

    return to_np(loss_couple), to_np(loss_e), to_np(loss_c), \
           conv_couples_b, conv_couples_pred_b, conv_id_b

def inference_one_epoch(batches, model):
    """Evaluate `model` over `batches` and score the filtered predictions with eval_func.

    The model is switched to eval mode (disabling GraphAttentionLayer's dropout) for the
    pass and restored to its previous mode afterwards.

    Returns:
        (metric_ec, metric_e, metric_c, conv_ids, gold couples, filtered predicted couples)
    """
    # Local import: at notebook level tqdm is only imported in a later cell.
    from tqdm import tqdm

    conv_id_all, conv_couples_all, conv_couples_pred_all = [], [], []
    was_training = model.training
    model.eval()
    try:
        for batch in tqdm(batches):
            _, _, _, conv_couples, conv_couples_pred, conv_id_b = inference_one_batch(batch, model)
            conv_id_all.extend(conv_id_b)
            conv_couples_all.extend(conv_couples)
            conv_couples_pred_all.extend(conv_couples_pred)
    finally:
        model.train(was_training)

    conv_couples_pred_all = lexicon_based_extraction(conv_id_all, conv_couples_pred_all)
    metric_ec, metric_e, metric_c = eval_func(conv_couples_all, conv_couples_pred_all)
    return metric_ec, metric_e, metric_c, conv_id_all, conv_couples_all, conv_couples_pred_all

def lexicon_based_extraction(conv_ids, couples_pred, conversations=None):
    """Filter the ranked pair predictions of each conversation.

    The top-ranked pair is always kept. Every further candidate is kept only when its
    emotion utterance is among the first elements of `conversations[conv_id]['pairs']`
    and its score is positive, which is mathematically equivalent to
    `logistic(score) > 0.5` but cannot overflow.

    Args:
        conv_ids: conversation ids aligned with `couples_pred`.
        couples_pred: per conversation, a list of (pair, score) tuples, best first.
        conversations: mapping conv_id -> dict with a 'pairs' entry. Defaults to a
            notebook-global named `conversations`, which no cell visible here defines.
            NOTE(review): if 'pairs' holds gold annotations, this filter leaks labels into
            evaluation — confirm.

    Returns:
        One list of kept pairs per conversation (empty when it had no predictions).
    """
    if conversations is None:
        try:
            conversations = globals()["conversations"]
        except KeyError:
            raise NameError(
                "lexicon_based_extraction needs a `conversations` mapping: pass it explicitly "
                "or define a global named `conversations`") from None

    couples_pred_filtered = []
    for conv_id, couples_pred_i in zip(conv_ids, couples_pred):
        if not couples_pred_i:
            couples_pred_filtered.append([])
            continue

        couples_pred_i_filtered = [couples_pred_i[0][0]]
        emotional_clauses_i = {p[0] for p in conversations[conv_id]['pairs']}
        for pair, score in couples_pred_i[1:]:
            if pair[0] in emotional_clauses_i and score > 0:
                couples_pred_i_filtered.append(pair)

        couples_pred_filtered.append(couples_pred_i_filtered)
    return couples_pred_filtered

In [None]:
class Config(object):
    """Hyper-parameters shared by the data loaders, model and training cells.

    K is the largest |cause - emotion| utterance offset RankNN considers; gnn_dims and
    att_heads are comma-separated per-GAT-layer widths and head counts.
    NOTE(review): `l2`, `l2_bert` and `adam_epsilon` are not read by any cell visible here.
    """
    def __init__(self):
        settings = {
            # model
            'feat_dim': 768,
            'gnn_dims': '192',
            'att_heads': '4',
            'K': 12,
            'pos_emb_dim': 50,
            'pairwise_loss': False,
            # optimisation
            'epochs': 15,
            'lr': 1e-5,
            'batch_size': 1,
            'gradient_accumulation_steps': 2,
            'dp': 0.1,
            'l2': 1e-5,
            'l2_bert': 0.01,
            'warmup_proportion': 0.1,
            'adam_epsilon': 1e-8,
        }
        for name, value in settings.items():
            setattr(self, name, value)

In [None]:
configs = Config()

# Seed the CPU and all CUDA RNGs, and force deterministic cuDNN kernels.
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(TORCH_SEED)
torch.backends.cudnn.deterministic = True

train_loader, val_loader = build_loaders(configs, anno_pth)

In [None]:
from transformers import get_linear_schedule_with_warmup

model = Network(configs).to(DEVICE)

# torch.optim.AdamW instead of transformers.AdamW, which is deprecated (and removed in
# recent transformers releases, while the install cell does not pin a version).
# eps=1e-6 and weight_decay=0.0 reproduce transformers.AdamW's defaults.
optimizer = torch.optim.AdamW(model.parameters(), lr=configs.lr, eps=1e-6, weight_decay=0.0)

# One optimiser step per `gradient_accumulation_steps` batches, for `configs.epochs` epochs.
num_steps_all = len(train_loader) // configs.gradient_accumulation_steps * configs.epochs
warmup_steps = int(num_steps_all * configs.warmup_proportion)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_steps_all)

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters in model: {total_params:,}")



Total trainable parameters in model: 88,385,229


In [None]:
from tqdm import tqdm

model.zero_grad()
max_ec, max_e, max_c = (-1, -1, -1), None, None
metric_ec, metric_e, metric_c = (-1, -1, -1), None, None
early_stop_flag = None
# Train for exactly configs.epochs epochs: the linear LR schedule is sized for that many and
# holds the LR at 0 once exhausted, so extra epochs would do no learning.
for epoch in range(1, configs.epochs + 1):
    print("=====Epoch {}=====".format(epoch))
    model.train()
    with tqdm(train_loader, unit='batch') as tepoch:
        # Running sums of the unscaled losses; the progress bar shows their means.
        sum_total, sum_couple, sum_e, sum_c = 0., 0., 0., 0.
        for train_step, batch in enumerate(tepoch, 1):
            conv_len, adj_b, y_emotions_b, y_causes_b, y_mask_b, conv_couples_b, conv_id_b, \
            v_token_b, a_token_b, t_token_b, v_mask, a_mask, t_mask = batch

            couples_pred, emo_cau_pos, pred_e, pred_c = model(
                conv_len, adj_b,
                v_token_b.to(DEVICE), a_token_b.to(DEVICE), t_token_b.to(DEVICE),
                v_mask.to(DEVICE), a_mask.to(DEVICE), t_mask.to(DEVICE))
            loss_e, loss_c = model.loss_pre(pred_e, pred_c, y_emotions_b, y_causes_b, y_mask_b)
            loss_couple, _ = model.loss_rank(couples_pred, emo_cau_pos, conv_couples_b, y_mask_b)
            loss = loss_couple + loss_e + loss_c

            # Scale so gradients accumulated over several batches are averaged, not summed.
            (loss / configs.gradient_accumulation_steps).backward()

            sum_total += loss.item()
            sum_couple += loss_couple.item()
            sum_e += loss_e.item()
            sum_c += loss_c.item()
            tepoch.set_postfix({'total_loss': sum_total / train_step,
                                'loss_couple': sum_couple / train_step,
                                'loss_e': sum_e / train_step,
                                'loss_c': sum_c / train_step})

            if train_step % configs.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()

=====Epoch 1=====


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  rel_mask = np.array(list(map(lambda x: -k <= x <= k, rel_pos.tolist())), dtype=np.int)
100%|██████████| 1028/1028 [06:48<00:00,  2.52batch/s, total_loss=0.000875, loss_couple=0.000392, loss_e=0.0008, loss_c=0.000558]


=====Epoch 2=====


100%|██████████| 1028/1028 [06:43<00:00,  2.55batch/s, total_loss=0.000748, loss_couple=0.000134, loss_e=0.000631, loss_c=0.000733]


=====Epoch 3=====


  7%|▋         | 70/1028 [00:26<05:45,  2.77batch/s, total_loss=0.0113, loss_couple=0.0029, loss_e=0.00927, loss_c=0.0104]

In [None]:
from tqdm import tqdm

# Sanity check: run only the multimodal encoder over every utterance of the first training
# batch. `model(...)` expects (conv_len, adj, ...), so call `model.multimodal` directly.
for batch in tqdm(train_loader):
    conv_len, adj_b, y_emotions_b, y_causes_b, y_mask_b, conv_couples_b, conv_id_b, \
    v_token_b, a_token_b, t_token_b, v_mask, a_mask, t_mask = batch

    num_utt = v_token_b.size()[1]
    outs = []
    with torch.no_grad():
        for i in range(num_utt):
            v, a, t = (x[:, i, :, :].to(DEVICE) for x in (v_token_b, a_token_b, t_token_b))
            v_m, a_m, t_m = (m[:, i, :].to(DEVICE) for m in (v_mask, a_mask, t_mask))
            outs.append(model.multimodal(v, a, t, v_m, a_m, t_m))
    outs = torch.stack(outs, 1)  # (batch, num_utt, out_dim)
    break