# In[ ]:
from transformers.models.bert.modeling_bert import BertForTokenClassification, BertPooler, BertSelfAttention
from torch.nn import CrossEntropyLoss
import torch
import torch.nn as nn
import numpy as np
import copy

# Hard-coded configuration values (replace with your own).
max_seq_length = 128  # padded token length of every input sequence
dropout = 0.1  # dropout probability
SRD = 3  # semantic-relative distance for the CDM/CDW local-context features
use_unique_bert = False  # True: local and global context share a single BERT
local_context_focus = 'cdm'  # 'cdm', 'cdw', 'fusion' or None
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class SelfAttention(nn.Module):
    """A BERT self-attention layer followed by ``tanh``, applied without masking.

    Args:
        config: BERT configuration forwarded to ``BertSelfAttention``.
    """

    def __init__(self, config):
        super(SelfAttention, self).__init__()
        self.config = config
        self.SA = BertSelfAttention(config)
        self.tanh = torch.nn.Tanh()

    def forward(self, inputs):
        """Return ``tanh`` of the self-attention output for ``inputs``.

        ``inputs`` is indexed as (batch, seq_len, ...). The additive attention
        mask is all zeros, so every position may attend to every other one.
        """
        # Derive the mask from the input itself so that it always matches the
        # actual sequence length, device and dtype. A fixed global length or
        # device breaks for other batch shapes and for moved models.
        zero_mask = inputs.new_zeros((inputs.size(0), 1, 1, inputs.size(1)))
        sa_out = self.SA(inputs, attention_mask=zero_mask)
        return self.tanh(sa_out[0])


class LCF_ATEPC(BertForTokenClassification):
    """LCF-ATEPC: joint aspect-term extraction (ATE), aspect polarity (APC) and
    token-level emotion classification on top of BERT.

    ``forward`` returns ``(ate_logits, apc_logits, emotion_logits)`` when no
    ``labels`` are given. Otherwise it returns a scalar loss tensor (the sum of
    the individual losses) that can be back-propagated.

    Args:
        bert_base_model: pretrained BERT encoder; its ``config`` sizes every
            auxiliary layer.
        max_seq_length: padded sequence length of the emotion label tensors.
        dropout: dropout probability.
        SRD: semantic-relative distance used by the CDM/CDW local-context features.
        use_unique_bert: if True the local-context branch reuses the global
            encoder, otherwise it gets an independent deep copy.
        local_context_focus: ``'cdm'``, ``'cdw'``, ``'fusion'`` or ``None``
            (global context only).
        device: device the helper methods move their tensors to; defaults to
            CUDA when available, else CPU.

    The keyword defaults are the module-level configuration constants as they
    were when the class was defined, so ``LCF_ATEPC(bert_model)`` keeps working.
    """

    def __init__(self, bert_base_model, max_seq_length=max_seq_length, dropout=dropout, SRD=SRD,
                 use_unique_bert=use_unique_bert, local_context_focus=local_context_focus, device=None):
        from types import SimpleNamespace

        super(LCF_ATEPC, self).__init__(config=bert_base_model.config)
        config = bert_base_model.config
        hidden_size = config.hidden_size
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Settings read by the helper methods (they all access self.args.*).
        self.args = SimpleNamespace(max_seq_length=max_seq_length, SRD=SRD,
                                    use_unique_bert=use_unique_bert,
                                    local_context_focus=local_context_focus,
                                    device=torch.device(device))
        self.bert_for_global_context = bert_base_model
        if use_unique_bert:
            self.bert_for_local_context = self.bert_for_global_context
        else:
            self.bert_for_local_context = copy.deepcopy(self.bert_for_global_context)
        self.pooler = BertPooler(config)
        self.num_emotion_labels = 6
        # 4-way APC head; forward trains it against the polarity labels.
        self.dense = torch.nn.Linear(hidden_size, 4)
        self.emotion_classifier = nn.Linear(hidden_size, self.num_emotion_labels)
        self.bert_global_focus = self.bert_for_global_context
        self.dropout = nn.Dropout(dropout)
        self.SA1 = SelfAttention(config)
        self.SA2 = SelfAttention(config)
        self.linear_double = nn.Linear(hidden_size * 2, hidden_size)
        self.linear_triple = nn.Linear(hidden_size * 3, hidden_size)

    @property
    def device(self):
        """Explicitly assigned device if any, otherwise the device of the parameters."""
        assigned = getattr(self, '_device', None)
        if assigned is not None:
            return assigned
        return next(self.parameters()).device

    @device.setter
    def device(self, value):
        self._device = value

    def get_batch_token_labels_bert_base_indices(self, labels):
        """Zero every label after the first label id 5 in each row.

        Id 5 is the ``[SEP]`` label under a 1-based label map over
        O/B-ASP/I-ASP/[CLS]/[SEP]. Returns None when ``labels`` is None.
        """
        if labels is None:
            return
        labels = labels.detach().cpu().numpy()
        for text_i in range(len(labels)):
            sep_index = np.argmax((labels[text_i] == 5))
            labels[text_i][sep_index + 1:] = 0
        return torch.tensor(labels).to(self.args.device)

    def get_batch_polarities(self, b_polarities):
        """Return, per row, the first polarity that is not -1 (0 if the row is all -1)."""
        b_polarities = b_polarities.detach().cpu().numpy()
        polarities = np.zeros((b_polarities.shape[0]))
        for i, polarity in enumerate(b_polarities):
            polarity_idx = np.flatnonzero(polarity + 1)
            if polarity_idx.size:
                polarities[i] = polarity[polarity_idx[0]]
        return torch.from_numpy(polarities).long().to(self.args.device)

    def get_batch_emotions(self, b_emotions):
        """Copy per-token emotion ids into a (batch, max_seq_length) long tensor, zero-padded on the right."""
        b_emotions = b_emotions.detach().cpu().numpy()
        emotions = np.zeros((b_emotions.shape[0], self.args.max_seq_length))
        emotions[:, :b_emotions.shape[1]] = b_emotions
        return torch.from_numpy(emotions).long().to(self.args.device)

    def feature_dynamic_weighted(self, text_local_indices, polarities):
        """CDW: weight each token by its distance from the aspect.

        Aspect positions are the indices where ``polarities`` is not -1. Tokens
        within ``SRD`` of the aspect keep weight 1; farther tokens get
        ``1 - (distance - SRD) / text_len``. Positions past the last non-zero
        id keep weight 1. Returns a (batch, seq, hidden) float tensor.
        """
        text_ids = text_local_indices.detach().cpu().numpy()
        asp_ids = polarities.detach().cpu().numpy()
        SRD = self.args.SRD
        weighted_text_raw_indices = np.ones(
            (text_local_indices.size(0), text_local_indices.size(1), self.config.hidden_size),
            dtype=np.float32)
        for text_i in range(min(len(text_ids), len(asp_ids))):
            a_ids = np.flatnonzero(asp_ids[text_i] + 1)
            text_len = np.flatnonzero(text_ids[text_i])[-1] + 1
            asp_len = len(a_ids)
            asp_begin = a_ids[0] if asp_len else 0
            asp_avg_index = (asp_begin * 2 + asp_len) / 2
            # Distance of each token from the aspect centre, widened by half the aspect length.
            offsets = np.abs(np.arange(text_len) - asp_avg_index) + asp_len / 2
            distances = np.where(offsets > SRD, 1 - (offsets - SRD) / text_len, 1).astype(np.float32)
            weighted_text_raw_indices[text_i, :text_len] *= distances[:, None]
        return torch.from_numpy(weighted_text_raw_indices).to(self.args.device)

    def feature_dynamic_mask(self, text_local_indices, polarities):
        """CDM: zero every token outside an ``SRD`` window around the aspect.

        Aspect positions are the indices where ``polarities`` is not -1. Zeroes
        positions ``[0, asp_begin - SRD)`` and ``[asp_begin + asp_len + SRD - 1, end)``.
        Returns a (batch, seq, hidden) float tensor.
        """
        text_ids = text_local_indices.detach().cpu().numpy()
        asp_ids = polarities.detach().cpu().numpy()
        SRD = self.args.SRD
        masked_text_raw_indices = np.ones(
            (text_local_indices.size(0), text_local_indices.size(1), self.config.hidden_size),
            dtype=np.float32)
        for text_i in range(min(len(text_ids), len(asp_ids))):
            a_ids = np.flatnonzero(asp_ids[text_i] + 1)
            asp_len = len(a_ids)
            asp_begin = a_ids[0] if asp_len else 0
            mask_begin = max(asp_begin - SRD, 0)
            mask_end = max(asp_begin + asp_len + SRD - 1, 0)
            masked_text_raw_indices[text_i, :mask_begin] = 0
            # Slice to the real sequence length rather than max_seq_length, so
            # shorter inputs cannot index out of range.
            masked_text_raw_indices[text_i, mask_end:] = 0
        return torch.from_numpy(masked_text_raw_indices).to(self.args.device)

    def get_ids_for_local_context_extractor(self, text_indices):
        """Zero every id after the first 102 in each row (presumably BERT's ``[SEP]`` id)."""
        text_ids = text_indices.detach().cpu().numpy()
        for text_i in range(len(text_ids)):
            sep_index = np.argmax((text_ids[text_i] == 102))
            text_ids[text_i][sep_index + 1:] = 0
        return torch.tensor(text_ids).to(self.args.device)

    def forward(self, input_ids_spc, token_type_ids=None, attention_mask=None, labels=None, polarities=None,
                valid_ids=None, attention_mask_label=None, emotions=None):
        """Run the model.

        Returns ``(ate_logits, apc_logits, emotion_logits)`` when ``labels`` is
        None. Otherwise returns the summed ATE + APC (+ emotion, when
        ``emotions`` is given) loss as a tensor.

        Raises:
            ValueError: if ``local_context_focus`` is set but contains none of
                'cdm', 'cdw' or 'fusion'.
        """
        # Pass masks by keyword: BertModel.forward's positional order is
        # (input_ids, attention_mask, token_type_ids).
        global_context_out = self.bert_for_global_context(
            input_ids_spc, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]

        if valid_ids is not None:
            # Left-align the positions flagged valid (== 1) and zero-pad the rest.
            seq_len = global_context_out.size(1)
            global_valid_output = global_context_out.new_zeros(global_context_out.shape)
            for i in range(global_context_out.size(0)):
                keep = (valid_ids[i, :seq_len] == 1).nonzero(as_tuple=True)[0]
                keep = keep.to(global_context_out.device)
                global_valid_output[i, :keep.numel()] = global_context_out[i, keep]
            global_context_out = global_valid_output
        global_context_out = self.dropout(global_context_out)
        ate_logits = self.classifier(global_context_out)
        emotion_logits = self.emotion_classifier(global_context_out)

        focus = self.args.local_context_focus
        if focus is not None:
            local_context_ids = input_ids_spc
            local_context_out = self.bert_for_local_context(
                input_ids_spc, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]
            if attention_mask is not None:
                local_context_out = torch.mul(local_context_out, attention_mask.unsqueeze(2))
            local_context_out = self.dropout(local_context_out)
            if 'cdm' in focus:
                cdm_vec = self.feature_dynamic_mask(local_context_ids, polarities)
                cdm_context_out = torch.mul(local_context_out, cdm_vec)
                cdm_context_out = self.SA1(cdm_context_out)
                cat_out = torch.cat((global_context_out, cdm_context_out), dim=-1)
                cat_out = self.linear_double(cat_out)
            elif 'cdw' in focus:
                cdw_vec = self.feature_dynamic_weighted(local_context_ids, polarities)
                cdw_context_out = torch.mul(local_context_out, cdw_vec)
                cdw_context_out = self.SA1(cdw_context_out)
                cat_out = torch.cat((global_context_out, cdw_context_out), dim=-1)
                cat_out = self.linear_double(cat_out)
            elif 'fusion' in focus:
                cdm_vec = self.feature_dynamic_mask(local_context_ids, polarities)
                cdm_context_out = torch.mul(local_context_out, cdm_vec)
                cdw_vec = self.feature_dynamic_weighted(local_context_ids, polarities)
                cdw_context_out = torch.mul(local_context_out, cdw_vec)
                cat_out = torch.cat((global_context_out, cdw_context_out, cdm_context_out), dim=-1)
                cat_out = self.linear_triple(cat_out)
            else:
                raise ValueError(
                    "local_context_focus must contain 'cdm', 'cdw' or 'fusion', got {!r}".format(focus))
            sa_out = self.SA2(cat_out)
            pooled_out = self.pooler(sa_out)
        else:
            pooled_out = self.pooler(global_context_out)
        pooled_out = self.dropout(pooled_out)
        apc_logits = self.dense(pooled_out)

        if labels is None:
            return ate_logits, apc_logits, emotion_logits

        polarity_labels = self.get_batch_polarities(polarities)
        ate_loss_fct = CrossEntropyLoss(ignore_index=0)
        apc_loss_fct = CrossEntropyLoss()
        loss_ate = ate_loss_fct(ate_logits.view(-1, self.num_labels), labels.view(-1))
        loss_apc = apc_loss_fct(apc_logits, polarity_labels)
        # Keep the total as a tensor so callers can back-propagate through it.
        total_loss = loss_ate + loss_apc
        if emotions is not None:
            emotion_labels = self.get_batch_emotions(emotions)
            emo_loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss_emo = emo_loss_fct(emotion_logits.view(-1, self.num_emotion_labels), emotion_labels.view(-1))
            total_loss = total_loss + loss_emo
        return total_loss

# In[ ]:
# -*- coding: utf-8 -*-
# file: data_utils.py
# author: yangheng <yangheng@m.scnu.edu.cn>
# Copyright (C) 2019. All Rights Reserved.

import os
import random

from sklearn.preprocessing import LabelEncoder

# Default padded sequence length for this cell (replace with your value).
max_seq_length: int = 128

class InputExample(object):
    """One raw (untokenized) example of the ATEPC data set.

    Args:
        guid: unique identifier of the example.
        text_a: the first sequence (in this file, the list of sentence words).
        text_b: (optional) the second sequence (in this file, the aspect words).
        sentence_label: (optional) one token tag per element of ``text_a``.
        aspect_label: (optional) one token tag per element of ``text_b``.
        polarity: (optional) per-token polarity values.
        emotion: (optional) per-token emotion values.
    """

    def __init__(self, guid, text_a, text_b=None, sentence_label=None, aspect_label=None, polarity=None, emotion=None):
        self.guid, self.text_a, self.text_b = guid, text_a, text_b
        self.sentence_label, self.aspect_label = sentence_label, aspect_label
        self.polarity, self.emotion = polarity, emotion


class InputFeatures(object):
    """Tensor-ready features of one example (plain Python lists of ids/flags).

    Args:
        input_ids_spc: token ids of ``[CLS] sentence [SEP] aspect [SEP]``.
        input_mask: 1 for real tokens, 0 for padding.
        segment_ids: segment (token type) ids.
        label_id: token label ids.
        polarities: (optional) per-token polarity values.
        valid_ids: (optional) 1 for positions whose representation is kept.
        label_mask: (optional) 1 where ``label_id`` holds a real label.
        emotions: (optional) per-token emotion ids.
    """

    def __init__(self, input_ids_spc, input_mask, segment_ids, label_id, polarities=None, valid_ids=None,
                 label_mask=None, emotions=None):
        self.input_ids_spc, self.input_mask, self.segment_ids = input_ids_spc, input_mask, segment_ids
        self.label_id = label_id
        self.valid_ids, self.label_mask = valid_ids, label_mask
        self.polarities, self.emotions = polarities, emotions


def readfile(filename):
    '''
    Read a space-separated ATEPC data file.

    Each non-blank line is expected to be ``word tag polarity emotion``
    (a warning is printed for lines with a different number of fields).
    Blank lines and ``-DOCSTART`` lines separate sentences.

    Returns:
        list of ``(words, tags, polarities, emotions)`` tuples, one per sentence.
        Polarities are ints; emotions are the raw strings of the last column.
    '''
    data = []
    sentence, tag, polarity, emotion = [], [], [], []
    # The context manager guarantees the file handle is closed.
    with open(filename, encoding='utf8') as f:
        for line in f:
            if not line or line.startswith('-DOCSTART') or line[0] == "\n":
                if sentence:
                    data.append((sentence, tag, polarity, emotion))
                    sentence, tag, polarity, emotion = [], [], [], []
                continue
            splits = line.split(' ')
            if len(splits) != 4:
                print('warning! detected error line(s) in input file:{}'.format(line))
            sentence.append(splits[0])
            tag.append(splits[-3])
            polarity.append(int(splits[-2]))
            # Strip only the line terminator. Slicing off the last character
            # would drop a real character on a final line without a newline.
            emotion.append(splits[-1].rstrip('\n'))

    if sentence:
        data.append((sentence, tag, polarity, emotion))
    return data


class DataProcessor(object):
    """Interface for turning a data-set directory into lists of ``InputExample``."""

    def get_train_examples(self, data_dir):
        """Return the ``InputExample`` objects of the training split found in *data_dir*."""
        raise NotImplementedError

    def get_dev_examples(self, data_dir):
        """Return the ``InputExample`` objects of the development split found in *data_dir*."""
        raise NotImplementedError

    def get_labels(self):
        """Return the list of token labels used by this data set."""
        raise NotImplementedError

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Parse *input_file* with ``readfile``; *quotechar* is accepted but ignored.

        Despite the name, ``readfile`` splits lines on single spaces.
        """
        parsed = readfile(input_file)
        return parsed


# Label encoder fitted on every emotion name (LabelEncoder.fit returns the encoder).
# NOTE(review): nothing in this section uses `le`; emotion columns are parsed
# with int() downstream. Confirm whether it is still needed.
le = LabelEncoder().fit(["Joy", "Anger", "Fear", "Sadness", "Surprise", "Disgust", "None"])  # Add all possible emotions


class ATEPCProcessor(DataProcessor):
    """Processor for the restaurant ATEPC data set (``Restaurants.atepc.*.dat``).

    Every sentence becomes an ``InputExample`` whose ``text_b`` holds the aspect
    words, i.e. the words whose polarity is not -1.
    """

    def get_train_examples(self, data_dir):
        """Return the examples of ``Restaurants.atepc.train.dat`` in *data_dir*."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "Restaurants.atepc.train.dat")), "train")

    def get_test_examples(self, data_dir):
        """Return the examples of ``Restaurants.atepc.test.dat`` in *data_dir*."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "Restaurants.atepc.test.dat")), "test")

    def get_labels(self):
        """Return the token (ATE) label set."""
        # return ["O", "B-ASP", "I-ASP", "[CLS]", "[SEP]", "Anger", "Disgust", "Fear", "Joy", "Sadness", "Surprise"]
        return ["O", "B-ASP", "I-ASP", "[CLS]", "[SEP]"]

    def _create_examples(self, lines, set_type):
        """Build ``InputExample`` objects from ``(words, tags, polarities, emotions)`` tuples.

        Args:
            lines: iterable of per-sentence tuples; they are not modified.
            set_type: prefix of each example's ``guid`` (``"<set_type>-<index>"``).

        Returns:
            list of ``InputExample``. ``polarity`` is the per-word polarity list
            followed by one -1 for the separator and one -1 per aspect word.
            ``emotion`` holds one entry per sentence word.
        """
        examples = []
        for i, (sentence, tag, polarity, emotion) in enumerate(lines):
            aspect = []
            aspect_tag = []
            aspect_emotion = []
            for w, t, p, e in zip(sentence, tag, polarity, emotion):
                if p != -1:
                    aspect.append(w)
                    aspect_tag.append(t)
                aspect_emotion.append(e)
            # One -1 for the separator token plus one for every aspect word.
            aspect_polarity = [-1] * (len(aspect) + 1)
            examples.append(InputExample(guid="%s-%s" % (set_type, i),
                                         # Copies keep later in-place edits of an
                                         # example from leaking back into `lines`.
                                         text_a=list(sentence), text_b=aspect,
                                         sentence_label=list(tag), aspect_label=aspect_tag,
                                         polarity=list(polarity) + aspect_polarity,
                                         emotion=aspect_emotion))
        return examples


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
    """Convert ``InputExample`` objects into fixed-length ``InputFeatures``.

    Each example is laid out as ``[CLS] text_a [SEP] text_b [SEP]``. Only the
    first sub-token of every word carries its label, polarity and emotion.
    Sequences are truncated to ``max_seq_length - 2`` tokens and then padded.

    Args:
        examples: list of ``InputExample``; they are not modified.
        label_list: token labels; ``label_list[k]`` gets id ``k + 1`` (0 = padding).
        max_seq_length: length every output sequence is padded to.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.

    Returns:
        list of ``InputFeatures``, one per example.
    """
    label_map = {label: i for i, label in enumerate(label_list, 1)}

    features = []
    for example in examples:
        # Work on copies. Extending example.text_a / sentence_label / emotion in
        # place would corrupt the examples if they are converted a second time.
        enum_tokens = list(example.text_a) + ['[SEP]'] + list(example.text_b)
        label_lists = list(example.sentence_label) + ['[SEP]'] + list(example.aspect_label)
        polaritiylist = example.polarity
        emotionlist = list(example.emotion)

        # Repeat the emotion sequence cyclically until it covers every token.
        diff = len(enum_tokens) - len(emotionlist)
        if diff > 0:
            n = len(emotionlist)
            emotionlist = emotionlist + emotionlist * (diff // n) + emotionlist[:diff % n]
        emotionlist = [int(e) for e in emotionlist]

        tokens = []
        labels = []
        polarities = []
        emotions = []
        valid = []
        for i, word in enumerate(enum_tokens):
            word_pieces = tokenizer.tokenize(word)
            tokens.extend(word_pieces)
            label_1 = label_lists[i]
            polarity_1 = polaritiylist[i]
            emotion_1 = emotionlist[i]
            for m in range(len(word_pieces)):
                if m == 0:
                    labels.append(label_1)
                    polarities.append(polarity_1)
                    emotions.append(emotion_1)
                    valid.append(1)
                else:
                    valid.append(0)

        if len(tokens) >= max_seq_length - 1:
            keep = max_seq_length - 2  # room for [CLS] and the final [SEP]
            tokens = tokens[:keep]
            polarities = polarities[:keep]
            emotions = emotions[:keep]
            labels = labels[:keep]
            valid = valid[:keep]

        ntokens = ["[CLS]"] + tokens + ["[SEP]"]
        segment_ids = [0] * len(ntokens)
        valid = [1] + valid + [1]
        label_ids = ([label_map["[CLS]"]]
                     + [label_map[label] for label in labels[:len(tokens)]]
                     + [label_map["[SEP]"]])
        input_ids_spc = list(tokenizer.convert_tokens_to_ids(ntokens))
        input_mask = [1] * len(input_ids_spc)
        label_mask = [1] * len(label_ids)

        # Pad the token-level sequences together up to max_seq_length.
        n_pad = max(max_seq_length - len(input_ids_spc), 0)
        input_ids_spc += [0] * n_pad
        input_mask += [0] * n_pad
        segment_ids += [0] * n_pad
        label_ids += [0] * n_pad
        valid += [1] * n_pad
        label_mask += [0] * n_pad
        # Word-level label ids can be shorter than the token ids; pad them too.
        n_label_pad = max(max_seq_length - len(label_ids), 0)
        label_ids += [0] * n_label_pad
        label_mask += [0] * n_label_pad
        polarities += [-1] * max(max_seq_length - len(polarities), 0)
        emotions += [-1] * max(max_seq_length - len(emotions), 0)

        assert len(input_ids_spc) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
        assert len(valid) == max_seq_length
        assert len(label_mask) == max_seq_length
        assert len(emotions) == max_seq_length
        features.append(
            InputFeatures(input_ids_spc=input_ids_spc,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label_ids,
                          polarities=polarities,
                          emotions=emotions,
                          valid_ids=valid,
                          label_mask=label_mask))

    return features


# In[ ]:
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 22 19:41:55 2017

@author: Quantum Liu
"""
'''
Example:
gm=GPUManager()
with torch.cuda.device(gm.auto_choice()):
    blabla

Or:
gm=GPUManager()
torch.cuda.set_device(gm.auto_choice())
'''

import os
import torch


def check_gpus():
    '''
    Check that NVIDIA GPUs and the ``nvidia-smi`` tool are available.

    Reference: http://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-cuda/

    return:
        True when CUDA is available and the output of ``nvidia-smi -h`` contains
        "NVIDIA System Management". Otherwise prints a message and returns False.
    '''
    if not torch.cuda.is_available():
        print('This script could only be used to manage NVIDIA GPUs,but no GPU found in your device')
        return False
    # Use the pipe as a context manager so it is closed and the child is reaped.
    with os.popen('nvidia-smi -h') as pipe:
        help_text = pipe.read()
    if 'NVIDIA System Management' not in help_text:
        print("'nvidia-smi' tool not found.")
        return False
    return True


if check_gpus():
    def parse(line, qargs):
        '''
        Parse one CSV line returned by ``nvidia-smi --format=csv,noheader``.

        line:
            a line of text
        qargs:
            query arguments, in the same order as the CSV columns
        return:
            a dict mapping every query argument to its value
        '''
        numberic_args = ['memory.free', 'memory.total', 'power.draw', 'power.limit']  # numeric fields

        def power_manage_enable(v):
            # Readings can be unavailable (e.g. on laptops); nvidia-smi then prints
            # "[Not Supported]" or "[N/A]" instead of a number.
            return 'Not Support' not in v and 'N/A' not in v

        def to_numberic(v):
            # Strip the unit suffix ("MiB" / "W") and convert to float.
            return float(v.upper().strip().replace('MIB', '').replace('W', ''))

        def process(k, v):
            # Numeric fields become ints (1 marks "not available"); others are stripped strings.
            if k in numberic_args:
                return int(to_numberic(v)) if power_manage_enable(v) else 1
            return v.strip()

        return {k: process(k, v) for k, v in zip(qargs, line.strip().split(','))}


    def query_gpu(qargs=None):
        '''
        Query information about every GPU through nvidia-smi.

        qargs:
            extra query arguments appended to the default ones
        return:
            a list of dicts, one per GPU
        '''
        qargs = ['index', 'gpu_name', 'memory.free', 'memory.total', 'power.draw', 'power.limit'] + list(qargs or [])
        cmd = 'nvidia-smi --query-gpu={} --format=csv,noheader'.format(','.join(qargs))
        # Close the pipe so the nvidia-smi process is reaped.
        with os.popen(cmd) as pipe:
            results = pipe.readlines()
        return [parse(line, qargs) for line in results]


    def by_power(d):
        '''
        Sort key for GPUs by power usage: power.draw / power.limit.
        Returns 1 when either reading is unavailable.
        '''
        power_infos = (d['power.draw'], d['power.limit'])
        if any(v == 1 for v in power_infos):
            print('Power management unable for GPU {}'.format(d['index']))
            return 1
        return float(d['power.draw']) / d['power.limit']


    class GPUManager():
        '''
        Lists the available GPU devices, sorts them and picks the freest one.
        A manager remembers which GPUs it has already handed out and prefers
        GPUs that have not been chosen yet.

        qargs:
            extra query arguments passed to nvidia-smi
        '''

        def __init__(self, qargs=None):
            self.qargs = list(qargs or [])
            self.gpus = query_gpu(self.qargs)
            for gpu in self.gpus:
                gpu['specified'] = False
            self.gpu_num = len(self.gpus)

        def _sort_by_memory(self, gpus, by_size=False):
            if by_size:
                print('Sorted by free memory size')
                return sorted(gpus, key=lambda d: d['memory.free'], reverse=True)
            else:
                print('Sorted by free memory rate')
                return sorted(gpus, key=lambda d: float(d['memory.free']) / d['memory.total'], reverse=True)

        def _sort_by_power(self, gpus):
            return sorted(gpus, key=by_power)

        def _sort_by_custom(self, gpus, key, reverse=False, qargs=None):
            # A string key is accepted when it was passed in qargs or when every
            # GPU record has that field. Any callable is accepted as a sort key.
            if isinstance(key, str) and (key in (qargs or []) or (gpus and all(key in gpu for gpu in gpus))):
                return sorted(gpus, key=lambda d: d[key], reverse=reverse)
            if callable(key):
                return sorted(gpus, key=key, reverse=reverse)
            raise ValueError(
                "The argument 'key' must be a function or a key in query args,please read the documention of nvidia-smi")

        def auto_choice(self, mode=0):
            '''
            Pick the freest GPU, preferring ones this manager has not handed out yet.

            mode:
                0: (default) largest free memory size
                1: highest free memory rate
                2: lowest power.draw / power.limit
                anything else: free memory rate
            return:
                the chosen GPU index as an int
            '''
            for old_infos, new_infos in zip(self.gpus, query_gpu(self.qargs)):
                old_infos.update(new_infos)
            unspecified_gpus = [gpu for gpu in self.gpus if not gpu['specified']] or self.gpus

            if mode == 0:
                print('Choosing the GPU device has largest free memory...')
                chosen_gpu = self._sort_by_memory(unspecified_gpus, True)[0]
            elif mode == 1:
                print('Choosing the GPU device has highest free memory rate...')
                chosen_gpu = self._sort_by_memory(unspecified_gpus)[0]
            elif mode == 2:
                print('Choosing the GPU device by power...')
                chosen_gpu = self._sort_by_power(unspecified_gpus)[0]
            else:
                print('Given an unaviliable mode,will be chosen by memory')
                chosen_gpu = self._sort_by_memory(unspecified_gpus)[0]
            chosen_gpu['specified'] = True
            index = chosen_gpu['index']
            print('Using GPU {i}:\n{info}'.format(i=index, info='\n'.join(
                [str(k) + ':' + str(v) for k, v in chosen_gpu.items()])))
            return int(index)
else:
    raise ImportError('GPU available check failed')


# In[ ]:
import json
import logging
import os, sys
import random
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score

from time import strftime, localtime

import numpy as np
import torch
from torch import device
from transformers.optimization import AdamW
from transformers.models.bert.modeling_bert import BertModel
from transformers import BertTokenizer
from seqeval.metrics import classification_report
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)

from utils.data_utils import ATEPCProcessor, convert_examples_to_features
from model.lcf_atepc import LCF_ATEPC

# Model / runtime configuration (hard-coded; edit as needed).
max_seq_length, dropout, SRD = 128, 0.1, 3
use_unique_bert = False
local_context_focus = 'cdm'
# NOTE: rebinds `device`, shadowing the `from torch import device` import.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training / evaluation settings.
dataset, output_dir = 'your_dataset', 'your_output_dir'
learning_rate = 0.001
num_train_epochs = 10
train_batch_size = eval_batch_size = 32
eval_steps = 20
gradient_accumulation_steps = 1
config_path = 'experiments.json'

# ...

def _features_to_dataset(features):
    """Stack converted ATEPC features into a ``TensorDataset``.

    Every field becomes a ``torch.long`` tensor, in the order the training and
    evaluation loops unpack a batch:
    ``input_ids_spc, input_mask, segment_ids, label_id, polarities,
    valid_ids, label_mask, emotions``.
    """
    field_names = ('input_ids_spc', 'input_mask', 'segment_ids', 'label_id',
                   'polarities', 'valid_ids', 'label_mask', 'emotions')
    return TensorDataset(*[
        torch.tensor([getattr(feature, name) for feature in features], dtype=torch.long)
        for name in field_names
    ])


def main(seed=1, use_bert_spc=False):
    """Train the LCF-ATEPC model (ATE, APC and emotion tasks) and return its best test scores.

    Seeds Python, NumPy and torch with ``seed``, loads the train/test splits
    from ``atepc_datasets/restaurant`` through ``ATEPCProcessor``, builds an
    ``LCF_ATEPC`` model on ``bert-base-uncased`` and trains it for
    ``num_train_epochs`` epochs. Every ``eval_steps`` optimizer steps within
    the last two epochs, the model is evaluated on the test split and the
    results are logged.

    Args:
        seed: seed for ``random``, ``numpy`` and ``torch`` (default 1, the value
            that used to be hard-coded).
        use_bert_spc: when True, aspect term extraction (ATE) is not evaluated.

    Returns:
        ``[max_apc_test_acc, max_apc_test_f1, max_ate_test_f1,
        max_emotion_test_acc, max_emotion_test_f1]``: the best value of each
        metric over all evaluations (0 if no evaluation ran), each in [0, 1].
    """
    # seqeval scores ATE on BIO tag sequences; aliased so it does not shadow
    # sklearn's f1_score, which is used for the APC and emotion classifiers.
    from seqeval.metrics import f1_score as seqeval_f1_score

    logger = logging.getLogger(__name__)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    os.makedirs(output_dir, exist_ok=True)

    processor = ATEPCProcessor()
    label_list = processor.get_labels()
    # Label ids are assigned from 1 (see label_map); id 0 is presumably padding.
    num_labels = len(label_list) + 1
    label_map = {i: label for i, label in enumerate(label_list, 1)}
    # NOTE(review): assumes the last label is '[SEP]' (upstream LCF-ATEPC
    # processor layout), so its id marks the end of a sentence's tags — confirm.
    sep_label_id = len(label_list)

    bert_model = "bert-base-uncased"
    data_dir = "atepc_datasets/restaurant"

    tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
    train_examples = processor.get_train_examples(data_dir)
    eval_examples = processor.get_test_examples(data_dir)
    num_train_optimization_steps = int(
        len(train_examples) / train_batch_size / gradient_accumulation_steps) * num_train_epochs
    bert_base_model = BertModel.from_pretrained(bert_model)
    bert_base_model.config.num_labels = num_labels

    # NOTE(review): the LCF_ATEPC defined in this notebook's first cell accepts
    # only `bert_base_model`; this call assumes the `model.lcf_atepc` version
    # accepts these keyword arguments — confirm.
    model = LCF_ATEPC(bert_base_model, max_seq_length=max_seq_length, dropout=dropout, SRD=SRD,
                      use_unique_bert=use_unique_bert, local_context_focus=local_context_focus, device=device)
    model.to(device)

    # Biases and LayerNorm weights are conventionally excluded from weight decay;
    # previously both groups used the same decay, which made the split pointless.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.00001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, weight_decay=0.00001)

    eval_features = convert_examples_to_features(eval_examples, label_list, max_seq_length, tokenizer)
    eval_data = _features_to_dataset(eval_features)
    # Metrics do not depend on evaluation order, so iterate sequentially.
    eval_dataloader = DataLoader(eval_data, sampler=SequentialSampler(eval_data),
                                 batch_size=eval_batch_size)

    def evaluate(eval_ATE=True, eval_APC=True, eval_emotion=True):
        """Evaluate ``model`` on the test split.

        Returns ``(apc_result, ate_result, emotion_result)``. The APC and emotion
        results are dicts of accuracy / macro-F1; their keys keep the historical
        ``max_`` prefix although they hold this evaluation's scores. The ATE
        result is the seqeval F1 over per-sentence tag sequences. Skipped tasks
        report 0.
        """
        apc_result = {'max_apc_test_acc': 0, 'max_apc_test_f1': 0}
        ate_result = 0
        emotion_result = {'max_emotion_test_acc': 0, 'max_emotion_test_f1': 0}
        y_true, y_pred = [], []  # one tag-string list per sentence, for seqeval
        apc_golds, apc_preds = [], []
        emotion_golds, emotion_preds = [], []
        model.eval()
        for (input_ids_spc, input_mask, segment_ids, label_ids, polarities,
             valid_ids, l_mask, emotions) in eval_dataloader:
            input_ids_spc = input_ids_spc.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            valid_ids = valid_ids.to(device)
            label_ids = label_ids.to(device)
            polarities = polarities.to(device)
            l_mask = l_mask.to(device)
            emotions = emotions.to(device)

            with torch.no_grad():
                ate_logits, apc_logits, emotion_logits = model(
                    input_ids_spc, segment_ids, input_mask, valid_ids=valid_ids,
                    polarities=polarities, attention_mask_label=l_mask, emotions=emotions)

            if eval_APC:
                apc_golds.append(model.get_batch_polarities(polarities).cpu())
                apc_preds.append(torch.argmax(apc_logits, -1).cpu())

            if eval_emotion:
                emotion_golds.append(model.get_batch_emotions(emotions).cpu())
                emotion_preds.append(torch.argmax(emotion_logits, -1).cpu())

            if eval_ATE:
                # sklearn cannot score 2-D (multiclass-multioutput) tag matrices,
                # so build one tag sequence per sentence and score with seqeval.
                gold_tags = label_ids.cpu().numpy()
                pred_tags = torch.argmax(ate_logits, -1).cpu().numpy()
                for gold_row, pred_row in zip(gold_tags, pred_tags):
                    true_seq, pred_seq = [], []
                    # Skip position 0 (presumably [CLS]); stop at the first [SEP] label.
                    for gold_id, pred_id in zip(gold_row[1:], pred_row[1:]):
                        if gold_id == sep_label_id:
                            break
                        true_seq.append(label_map.get(int(gold_id), 'O'))
                        pred_seq.append(label_map.get(int(pred_id), 'O'))
                    y_true.append(true_seq)
                    y_pred.append(pred_seq)

        if eval_APC and apc_golds:
            golds = torch.cat(apc_golds).numpy()
            preds = torch.cat(apc_preds).numpy()
            apc_result = {'max_apc_test_acc': accuracy_score(golds, preds),
                          'max_apc_test_f1': f1_score(golds, preds, average='macro')}

        if eval_emotion and emotion_golds:
            golds = torch.cat(emotion_golds).numpy()
            preds = torch.cat(emotion_preds).numpy()
            emotion_result = {'max_emotion_test_acc': accuracy_score(golds, preds),
                              'max_emotion_test_f1': f1_score(golds, preds, average='macro')}

        if eval_ATE and y_true:
            ate_result = seqeval_f1_score(y_true, y_pred)

        return apc_result, ate_result, emotion_result

    def save_model(path):
        """Save model weights, tokenizer and a training-config JSON under ``path``."""
        os.makedirs(path, exist_ok=True)
        model_to_save = model.module if hasattr(model, 'module') else model  # unwrap DataParallel
        model_to_save.save_pretrained(path)
        tokenizer.save_pretrained(path)
        model_config = {"bert_model": bert_model, "do_lower": True, "max_seq_length": max_seq_length,
                        "num_labels": len(label_list) + 1, "label_map": label_map}
        # Written to its own file: `config.json` is the model config produced by
        # save_pretrained and must not be overwritten.
        with open(os.path.join(path, "model_config.json"), "w") as config_file:
            json.dump(model_config, config_file)
        logger.info('save model to: {}'.format(path))

    def train():
        """Run the training loop; return the five best test metrics (see ``main``)."""
        train_features = convert_examples_to_features(
            train_examples, label_list, max_seq_length, tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        train_data = _features_to_dataset(train_features)
        # Shuffle training batches every epoch.
        train_dataloader = DataLoader(train_data, sampler=RandomSampler(train_data),
                                      batch_size=train_batch_size)
        num_batches = len(train_dataloader)

        max_apc_test_acc = 0
        max_apc_test_f1 = 0
        max_ate_test_f1 = 0
        max_emotion_test_acc = 0
        max_emotion_test_f1 = 0

        global_step = 0
        for epoch in range(int(num_train_epochs)):
            logger.info('#' * 80)
            logger.info('Train {} Epoch{} on {}'.format(seed, epoch + 1, data_dir))
            logger.info('#' * 80)
            for step, batch in enumerate(train_dataloader):
                model.train()  # evaluate() switches the model to eval mode
                batch = tuple(t.to(device) for t in batch)
                (input_ids_spc, input_mask, segment_ids, label_ids, polarities,
                 valid_ids, l_mask, emotions) = batch

                # Use the model's loss tensor directly: re-wrapping it with
                # torch.tensor(...) detaches it from the autograd graph, so no
                # model parameter would ever receive a gradient.
                loss = model(input_ids_spc, segment_ids, input_mask, label_ids, polarities,
                             valid_ids, l_mask, emotions)
                if isinstance(loss, (tuple, list)):
                    # NOTE(review): presumably one loss per task; summed into one objective.
                    loss = sum(loss)
                (loss / gradient_accumulation_steps).backward()

                # Step once per `gradient_accumulation_steps` batches (and at epoch end).
                if (step + 1) % gradient_accumulation_steps != 0 and step + 1 != num_batches:
                    continue
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                in_last_two_epochs = epoch >= num_train_epochs - 2 or num_train_epochs <= 2
                if global_step % eval_steps != 0 or not in_last_two_epochs:
                    continue

                apc_result, ate_result, emotion_result = evaluate(eval_ATE=not use_bert_spc,
                                                                  eval_emotion=True)
                current_apc_test_acc = apc_result['max_apc_test_acc']
                current_apc_test_f1 = apc_result['max_apc_test_f1']
                current_ate_test_f1 = round(ate_result, 4)
                current_emotion_test_acc = emotion_result['max_emotion_test_acc']
                current_emotion_test_f1 = emotion_result['max_emotion_test_f1']

                max_apc_test_acc = max(max_apc_test_acc, current_apc_test_acc)
                max_apc_test_f1 = max(max_apc_test_f1, current_apc_test_f1)
                max_ate_test_f1 = max(max_ate_test_f1, ate_result)
                max_emotion_test_acc = max(max_emotion_test_acc, current_emotion_test_acc)
                max_emotion_test_f1 = max(max_emotion_test_f1, current_emotion_test_f1)

                logger.info('*' * 80)
                logger.info('Train {} Epoch{}, Evaluate for {}'.format(seed, epoch + 1, data_dir))
                logger.info(f'APC_test_acc: {current_apc_test_acc}(max: {max_apc_test_acc})  '
                            f'APC_test_f1: {current_apc_test_f1}(max: {max_apc_test_f1})')
                if use_bert_spc:
                    logger.info('ATE_test_F1: not evaluated (`use_bert_spc` is "True")')
                else:
                    logger.info(f'ATE_test_f1: {current_ate_test_f1}(max:{max_ate_test_f1})')
                logger.info(
                    f'Emotion_test_acc: {current_emotion_test_acc}(max: {max_emotion_test_acc})  '
                    f'Emotion_test_f1: {current_emotion_test_f1}(max: {max_emotion_test_f1})')
                logger.info('*' * 80)
        return [max_apc_test_acc, max_apc_test_f1, max_ate_test_f1, max_emotion_test_acc,
                max_emotion_test_f1]

    return train()



# In [ ]:
if __name__ == "__main__":
    # Make the logger.info(...) progress messages visible on stdout
    # (no-op if the root logger already has handlers).
    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
    logger = logging.getLogger(__name__)
    device = torch.device("cpu")
    n = 5
    results = []
    max_apc_test_acc, max_apc_test_f1, max_ate_test_f1 = 0, 0, 0
    max_emotion_test_acc, max_emotion_test_f1 = 0, 0
    for i in range(n):
        # NOTE(review): `seed` is not passed to main(), which seeds the RNGs
        # itself, so the n runs are not seeded differently — confirm intent.
        seed = i + 1
        logger.info('No.{} training process of {}'.format(i + 1, n))
        # main() returns [apc_acc, apc_f1, ate_f1, emotion_acc, emotion_f1].
        run_metrics = main()
        results.append(run_metrics)
        apc_test_acc, apc_test_f1, ate_test_f1, emotion_test_acc, emotion_test_f1 = run_metrics
        max_apc_test_acc = max(max_apc_test_acc, apc_test_acc)
        max_apc_test_f1 = max(max_apc_test_f1, apc_test_f1)
        max_ate_test_f1 = max(max_ate_test_f1, ate_test_f1)
        max_emotion_test_acc = max(max_emotion_test_acc, emotion_test_acc)
        max_emotion_test_f1 = max(max_emotion_test_f1, emotion_test_f1)
        logger.info('max_ate_test_f1:{} max_apc_test_acc: {}\tmax_apc_test_f1: {} \t'
                    'max_emotion_test_acc: {}\tmax_emotion_test_f1: {}'
                    .format(max_ate_test_f1, max_apc_test_acc, max_apc_test_f1,
                            max_emotion_test_acc, max_emotion_test_f1))


