# Data functions

In [3]:
# !pip install pandas

In [4]:
import os
import copy
import numpy as np
import torch
from torch.utils.data import Dataset
# !pip install pytorch_transformers
from pytorch_transformers import BertTokenizer


def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
    """Fit *sequence* into a numpy array of exactly *maxlen* elements.

    Args:
        sequence: list/array of values to place in the result.
        maxlen: length of the returned array.
        dtype: dtype of the returned array.
        padding: 'post' writes the values at the start (fill after them); any other value
            right-aligns them (fill before).
        truncating: 'pre' keeps the last *maxlen* values; any other value keeps the first.
        value: fill value for positions not covered by *sequence*.

    Returns:
        A 1-D numpy array of length *maxlen* and dtype *dtype*.
    """
    x = (np.ones(maxlen) * value).astype(dtype)
    if truncating == 'pre':
        # sequence[-0:] would be the WHOLE sequence, so maxlen == 0 needs an explicit empty slice.
        trunc = sequence[-maxlen:] if maxlen > 0 else sequence[:0]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if len(trunc) == 0:
        # Nothing to copy. Without this guard padding='pre' would evaluate x[-0:] (== x[0:],
        # the whole array) and fail to broadcast an empty array into it.
        return x
    if padding == 'post':
        x[:len(trunc)] = trunc
    else:
        x[-len(trunc):] = trunc
    return x


class Tokenizer4Bert:
    """Wrap a pretrained ``BertTokenizer`` and emit fixed-length numpy id arrays."""

    def __init__(self, max_seq_len, pretrained_bert_name):
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_name)
        self.max_seq_len = max_seq_len

    def text_to_sequence(self, text, reverse=False, padding='post', truncating='post'):
        """WordPiece-tokenize *text*, map tokens to ids and fit them to ``max_seq_len``."""
        wordpieces = self.tokenizer.tokenize(text)
        ids = self.tokenizer.convert_tokens_to_ids(wordpieces)
        return self.id_to_sequence(ids, reverse=reverse, padding=padding, truncating=truncating)

    def id_to_sequence(self, sequence, reverse=False, padding='post', truncating='post'):
        """Fit an id sequence to ``max_seq_len``.

        An empty sequence is replaced by ``[0]``; with *reverse* the ids are flipped before
        padding/truncation (see ``pad_and_truncate``).
        """
        ids = sequence if len(sequence) > 0 else [0]
        if reverse:
            ids = ids[::-1]
        return pad_and_truncate(ids, self.max_seq_len, padding=padding, truncating=truncating)

class DepInstanceParser():
    """Turn one sentence's dependency arcs into adjacency and relation-type matrices.

    ``basicDependencies`` is a list of ``{"governor", "dependent", "dep"}`` dicts with 1-based
    word indices (governor 0 marks the root), such as the per-sentence lists returned by
    ``ABSADataset.load_depfile``. ``tokens`` may be empty; when given, its entries (plain
    strings or dicts with a ``"word"`` key) are recorded in ``self.words``.
    """

    def __init__(self, basicDependencies, tokens):
        self.basicDependencies = basicDependencies
        self.tokens = tokens
        self.words = []
        self.dep_governed_info = []
        self.dep_parsing()

    def dep_parsing(self):
        """Fill ``self.dep_governed_info``: per word, its 0-based governor index and relation.

        The root word gets governor ``-1``. When tokens were supplied each entry also keeps
        the word under ``"word"``; otherwise one slot is created per dependency arc.
        """
        if len(self.tokens) > 0:
            # Accept plain strings or dicts carrying the surface form under "word".
            words = [self.change_word(token['word'] if isinstance(token, dict) else token)
                     for token in self.tokens]
            dep_governed_info = [{"word": word} for word in words]
            self.words = words
        else:
            # Distinct dicts per slot ([{}] * n would alias a single dict n times).
            dep_governed_info = [{} for _ in self.basicDependencies]
        for dep in self.basicDependencies:
            dependent_index = dep['dependent'] - 1
            # Merge rather than replace so a "word" entry set above survives.
            dep_governed_info[dependent_index] = dict(
                dep_governed_info[dependent_index],
                governor=dep['governor'] - 1,
                dep=dep['dep'],
            )
        self.dep_governed_info = dep_governed_info

    def change_word(self, word):
        """Restore PTB bracket escapes: ``-LRB-`` -> ``(`` and ``-RRB-`` -> ``)`` (both, if present)."""
        return word.replace("-RRB-", ")").replace("-LRB-", "(")

    def get_first_order(self, direct=False):
        """Return symmetric (adjacency, type) matrices linking each word with its governor.

        Types are the relation label, or ``<label>_in`` / ``<label>_out`` when *direct* is
        True; unlinked cells hold 0 / ``"none"``. Later arcs overwrite earlier ones.
        """
        n = len(self.dep_governed_info)
        dep_adj_matrix = [[0] * n for _ in range(n)]
        dep_type_matrix = [["none"] * n for _ in range(n)]
        for i, dep_info in enumerate(self.dep_governed_info):
            governor = dep_info["governor"]
            dep_type = dep_info["dep"]
            # NOTE(review): the root's governor is -1, so these writes wrap around via negative
            # indexing and link the root word with the LAST word. Kept to preserve existing
            # features; confirm whether the root arc should be skipped instead.
            dep_adj_matrix[i][governor] = 1
            dep_adj_matrix[governor][i] = 1
            dep_type_matrix[i][governor] = dep_type if direct is False else "{}_in".format(dep_type)
            dep_type_matrix[governor][i] = dep_type if direct is False else "{}_out".format(dep_type)
        return dep_adj_matrix, dep_type_matrix

    def get_next_order(self, dep_adj_matrix, dep_type_matrix):
        """Extend the graph by one hop.

        Every node reachable in two steps (target -> first -> second, second != target) that
        is not already linked gets an edge whose type is the (first, second) edge type.
        Returns new matrices; the inputs are not modified.
        """
        new_dep_adj_matrix = copy.deepcopy(dep_adj_matrix)
        new_dep_type_matrix = copy.deepcopy(dep_type_matrix)
        for target_index in range(len(dep_adj_matrix)):
            for first_order_index in range(len(dep_adj_matrix[target_index])):
                if dep_adj_matrix[target_index][first_order_index] == 0:
                    continue
                for second_order_index in range(len(dep_adj_matrix[first_order_index])):
                    if dep_adj_matrix[first_order_index][second_order_index] == 0:
                        continue
                    if second_order_index == target_index:
                        continue
                    if new_dep_adj_matrix[target_index][second_order_index] == 1:
                        continue
                    new_dep_adj_matrix[target_index][second_order_index] = 1
                    new_dep_type_matrix[target_index][second_order_index] = dep_type_matrix[first_order_index][second_order_index]
        return new_dep_adj_matrix, new_dep_type_matrix

    def get_second_order(self, direct=False):
        """First-order matrices extended by one hop."""
        dep_adj_matrix, dep_type_matrix = self.get_first_order(direct=direct)
        return self.get_next_order(dep_adj_matrix, dep_type_matrix)

    def get_third_order(self, direct=False):
        """Second-order matrices extended by one more hop."""
        dep_adj_matrix, dep_type_matrix = self.get_second_order(direct=direct)
        return self.get_next_order(dep_adj_matrix, dep_type_matrix)

    def search_dep_path(self, start_idx, end_idx, adj_max, dep_path_arr):
        """Depth-first search for a path to *end_idx* through cells of *adj_max* that are not "none".

        *dep_path_arr* is the path so far (nodes on it are not revisited). Returns
        ``(1, path)`` with the full path including both endpoints, or ``(0, [])``.
        """
        for next_id in range(len(adj_max[start_idx])):
            if next_id in dep_path_arr or adj_max[start_idx][next_id] in ["none"]:
                continue
            if next_id == end_idx:
                return 1, dep_path_arr + [next_id]
            stat, dep_arr = self.search_dep_path(next_id, end_idx, adj_max, dep_path_arr + [next_id])
            if stat == 1:
                return stat, dep_arr
        return 0, []

    def get_dep_path(self, start_index, end_index, direct=False):
        """Return the word-index path from *start_index* to *end_index* in the first-order graph ([] if none)."""
        _, dep_type_matrix = self.get_first_order(direct=direct)
        _, dep_path = self.search_dep_path(start_index, end_index, dep_type_matrix, [start_index])
        return dep_path

class ABSADataset(Dataset):
    """Aspect-level sentiment dataset with per-example dependency-graph features.

    Reads ``datafile`` (groups of three lines: sentence with ``$T$`` marking the aspect,
    aspect text, polarity label) and ``datafile + ".dep"`` (tab-separated
    governor / dependent / relation lines, a blank line between sentences), and pre-builds
    one feature dict per example in ``self.feature``.
    """

    def __init__(self, datafile, tokenizer, opt, deptype2id=None, dep_order="first"):
        """
        Args:
            datafile: path of the text file; the dependency file is ``datafile + ".dep"``.
            tokenizer: object exposing ``.tokenizer`` (the BERT tokenizer),
                ``.max_seq_len`` and ``id_to_sequence`` (e.g. ``Tokenizer4Bert``).
            opt: options; only ``opt.print_sent`` is read here.
            deptype2id: mapping from relation label (including "none") to id, e.g. from
                ``load_deptype_map``.
            dep_order: "first", "second" or "third" - how many dependency hops to connect.
        """
        self.datafile = datafile
        self.depfile = "{}.dep".format(datafile)
        self.tokenizer = tokenizer
        self.opt = opt
        self.deptype2id = deptype2id
        self.dep_order = dep_order
        self.textdata = ABSADataset.load_datafile(self.datafile)
        self.depinfo = ABSADataset.load_depfile(self.depfile)
        self.polarity2id = self.get_polarity2id()
        self.feature = [self.create_feature(sentence, depinfo, opt.print_sent)
                        for sentence, depinfo in zip(self.textdata, self.depinfo)]
        print(self.feature[:1])

    def __getitem__(self, index):
        return self.feature[index]

    def __len__(self):
        return len(self.feature)

    def ws(self, text):
        """WordPiece-tokenize a list of words.

        Returns ``(tokens, token_ids, valid_ids)`` where ``valid_ids`` is 1 on the first
        sub-token of each word and 0 on continuation sub-tokens. Empty words are skipped.
        """
        tokens = []
        valid_ids = []
        for word in text:
            if not word:  # e.g. produced by splitting on consecutive spaces
                continue
            sub_tokens = self.tokenizer.tokenizer.tokenize(word)
            tokens.extend(sub_tokens)
            valid_ids.extend(1 if m == 0 else 0 for m in range(len(sub_tokens)))
        token_ids = self.tokenizer.tokenizer.convert_tokens_to_ids(tokens)
        return tokens, token_ids, valid_ids

    def create_feature(self, sentence, depinfo, print_sent = False):
        """Build the model-input dict for one ``[text_left, text_right, aspect, polarity]`` example.

        Input ids are ``[CLS] left aspect right [SEP] aspect [SEP]``, padded/truncated by
        ``tokenizer.id_to_sequence``. Dependency matrices are written with row ``i + 1``
        (i.e. after [CLS]) holding word ``i``'s links.
        """
        text_left, text_right, aspect, polarity = sentence

        cls_id = self.tokenizer.tokenizer.vocab["[CLS]"]
        sep_id = self.tokenizer.tokenizer.vocab["[SEP]"]

        doc = text_left + " " + aspect + " " + text_right

        left_tokens, left_token_ids, left_valid_ids = self.ws(text_left.split(" "))
        right_tokens, right_token_ids, right_valid_ids = self.ws(text_right.split(" "))
        aspect_tokens, aspect_token_ids, aspect_valid_ids = self.ws(aspect.split(" "))
        tokens = left_tokens + aspect_tokens + right_tokens
        input_ids = [cls_id] + left_token_ids + aspect_token_ids + right_token_ids + [sep_id] + aspect_token_ids + [sep_id]
        valid_ids = [1] + left_valid_ids + aspect_valid_ids + right_valid_ids + [1] + aspect_valid_ids + [1]
        # Aspect-terms mask in sub-token positions.
        # NOTE(review): the model averages this mask over word positions gathered by
        # valid_ids; when left-context words split into several sub-tokens the two index
        # spaces differ - confirm the alignment.
        mem_valid_ids = [0] + [0] * len(left_tokens) + [1] * len(aspect_tokens) + [0] * len(right_tokens)
        segment_ids = [0] * (len(tokens) + 2) + [1] * (len(aspect_tokens) + 1)

        dep_instance_parser = DepInstanceParser(basicDependencies=depinfo, tokens=[])
        if self.dep_order == "first":
            dep_adj_matrix, dep_type_matrix = dep_instance_parser.get_first_order()
        elif self.dep_order == "second":
            dep_adj_matrix, dep_type_matrix = dep_instance_parser.get_second_order()
        elif self.dep_order == "third":
            dep_adj_matrix, dep_type_matrix = dep_instance_parser.get_third_order()
        else:
            raise ValueError("unknown dep_order: {}".format(self.dep_order))

        # Ids of word-initial sub-tokens of the first segment (between [CLS] and the first [SEP]).
        token_head_list = []
        for input_id, valid_id in zip(input_ids, valid_ids):
            if input_id == cls_id:
                continue
            if input_id == sep_id:
                break
            if valid_id == 1:
                token_head_list.append(input_id)

        input_ids = self.tokenizer.id_to_sequence(input_ids)
        valid_ids = self.tokenizer.id_to_sequence(valid_ids)
        segment_ids = self.tokenizer.id_to_sequence(segment_ids)
        mem_valid_ids = self.tokenizer.id_to_sequence(mem_valid_ids)

        size = input_ids.shape[0]

        if print_sent:
            print(doc)
            print(len(dep_adj_matrix[0]))

        final_dep_adj_matrix = [[0] * size for _ in range(size)]
        final_dep_value_matrix = [[0] * size for _ in range(size)]
        # Rows are shifted by one for [CLS]; cap at size - 1 so sentences truncated to
        # max_seq_len do not index past the final matrix.
        # NOTE(review): columns are NOT shifted while rows are - confirm this alignment.
        for i in range(min(len(token_head_list), size - 1)):
            for j in range(min(len(dep_adj_matrix[i]), size)):
                final_dep_adj_matrix[i + 1][j] = dep_adj_matrix[i][j]
                final_dep_value_matrix[i + 1][j] = self.deptype2id[dep_type_matrix[i][j]]

        return {
            "input_ids": torch.tensor(input_ids),
            "valid_ids": torch.tensor(valid_ids),
            "segment_ids": torch.tensor(segment_ids),
            "mem_valid_ids": torch.tensor(mem_valid_ids),
            "dep_adj_matrix": torch.tensor(final_dep_adj_matrix),
            "dep_value_matrix": torch.tensor(final_dep_value_matrix),
            "polarity": self.polarity2id[polarity],
            "raw_text": doc,
            "aspect": aspect
        }

    @staticmethod
    def load_depfile(filename):
        """Read a dependency file into one list of arc dicts per sentence.

        Each non-empty line is ``governor<TAB>dependent<TAB>relation`` (indices as ints);
        blank lines separate sentences, and runs of blank lines do not create empty entries.
        """
        data = []
        with open(filename, 'r') as f:
            dep_info = []
            for line in f:
                line = line.strip()
                if len(line) > 0:
                    items = line.split("\t")
                    dep_info.append({
                        "governor": int(items[0]),
                        "dependent": int(items[1]),
                        "dep": items[2],
                    })
                elif dep_info:
                    data.append(dep_info)
                    dep_info = []
            if dep_info:
                data.append(dep_info)
        return data

    @staticmethod
    def load_datafile(filename):
        """Read examples stored as three lines each: sentence, aspect, polarity.

        The sentence is lower-cased and split around its first ``$T$`` into left and right
        context; further ``$T$`` markers in the right context are replaced with the aspect.
        Returns ``[text_left, text_right, aspect, polarity]`` lists. An incomplete trailing
        group (e.g. from trailing blank lines) is ignored.
        """
        data = []
        with open(filename, 'r') as f:
            lines = f.readlines()
        for i in range(0, len(lines) - 2, 3):
            text_left, _, text_right = [s.lower().strip() for s in lines[i].partition("$T$")]
            aspect = lines[i + 1].lower().strip()
            # The context was lower-cased above, so remaining markers now read "$t$".
            text_right = text_right.replace("$t$", aspect)
            polarity = lines[i + 2].strip()
            data.append([text_left, text_right, aspect, polarity])
        return data

    @staticmethod
    def load_deptype_map(opt):
        """Map every relation label in the train/test/val ``.dep`` files to an id.

        "none" is always id 0; the other labels follow in sorted order. Missing files are
        skipped.
        """
        deptype_set = set()
        for filename in [opt.train_file, opt.test_file, opt.val_file]:
            filename = "{}.dep".format(filename)
            if not os.path.exists(filename):
                continue
            for dep_info in ABSADataset.load_depfile(filename):
                for item in dep_info:
                    deptype_set.add(item['dep'])
        deptype_map = {"none": 0}
        for deptype in sorted(deptype_set):
            deptype_map[deptype] = len(deptype_map)
        return deptype_map

    @staticmethod
    def get_polarity2id():
        """Map the polarity labels "-1", "0", "1" to class ids 0, 1, 2."""
        return {label: idx for idx, label in enumerate(["-1", "0", "1"])}

# TGCN Model

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from pytorch_transformers import BertPreTrainedModel,BertModel

class GraphConvolution(nn.Module):
    """
    Simple GCN layer: ``(adj @ (text @ W)) / (rowsum(adj) + 1) + b``.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        # torch.FloatTensor(...) allocates uninitialised memory; give the parameters real values.
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the weight from U(-1/sqrt(out_features), 1/sqrt(out_features)); zero the bias."""
        stdv = 1.0 / (self.out_features ** 0.5)
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, text, adj):
        """Aggregate transformed node features over *adj*, normalised by (row degree + 1)."""
        hidden = torch.matmul(text, self.weight)
        denom = torch.sum(adj, dim=2, keepdim=True) + 1
        output = torch.matmul(adj, hidden) / denom
        if self.bias is not None:
            return output + self.bias
        return output

class TypeGraphConvolution(nn.Module):
    """
    TGCN Layer: a GCN whose messages add a projected dependency-type embedding to the
    neighbour features before the linear transform:

        out[b, i] = sum_j adj[b, i, j] * (text[b, j] + dense(dep_embed[b, j, i])) @ W  (+ bias)
    """
    def __init__(self, in_features, out_features, embedding_dim, bias=True):
        super(TypeGraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.dense = nn.Linear(embedding_dim, in_features, bias=False)
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        # torch.FloatTensor(...) allocates uninitialised memory; give the parameters real values.
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the weight from U(-1/sqrt(out_features), 1/sqrt(out_features)); zero the bias."""
        stdv = 1.0 / (self.out_features ** 0.5)
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, text, adj, dep_embed):
        """
        Args:
            text: (batch, len, in_features) node features.
            adj: (batch, len, len) edge weights.
            dep_embed: (batch, len, len, embedding_dim) dependency-type embeddings.

        Returns:
            (batch, len, out_features)
        """
        max_len = text.size(1)
        # val_sum[b, i, j] = text[b, i] + dense(dep_embed[b, i, j])
        val_sum = text.unsqueeze(dim=2).expand(-1, -1, max_len, -1) + self.dense(dep_embed)
        hidden = torch.matmul(val_sum, self.weight)
        # adj is broadcast over the feature axis, so out_features may differ from in_features.
        output = torch.sum(hidden.transpose(1, 2) * adj.unsqueeze(dim=-1), dim=2)
        if self.bias is not None:
            return output + self.bias
        return output
        
class SemGraphConvolution(nn.Module):
    """
    Semantic GCN layer with attention adjacency matrix:
    ``(adj @ (text @ W)) / (rowsum(adj) + 1) + b``.
    """
    def __init__(self, in_features, out_features, attention_heads = 1, bias=True):
        super(SemGraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.attention_heads = attention_heads  # stored for reference; not used by forward
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        # torch.FloatTensor(...) allocates uninitialised memory; give the parameters real values.
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the weight from U(-1/sqrt(out_features), 1/sqrt(out_features)); zero the bias."""
        stdv = 1.0 / (self.out_features ** 0.5)
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, text, adj):
        """Aggregate transformed node features over the (attention) adjacency *adj*."""
        hidden = torch.matmul(text, self.weight)
        denom = torch.sum(adj, dim=2, keepdim=True) + 1
        output = torch.matmul(adj, hidden) / denom
        if self.bias is not None:
            return output + self.bias
        return output
        
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention that returns only the attention weights.

    ``self.value`` is kept so existing state dicts still load, but forward does not use it.
    """
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        """Return softmax(Q K^T / sqrt(input_dim)) for x of shape (batch, len, input_dim)."""
        queries = self.query(x)
        keys = self.key(x)
        # The value projection was computed and discarded before; skip it.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        return self.softmax(scores)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention that returns only the attention weights.

    Queries and keys are projected and split into ``h`` heads of size ``d_model // h``;
    the result has shape (batch, h, len_q, len_k). There is no value projection.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0

        self.d_k = d_model // h
        self.h = h
        self.linears = self.clones(nn.Linear(d_model, d_model), 2)  # query, key projections
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, mask=None):
        """
        Args:
            query, key: (batch, len, d_model) tensors.
            mask: optional tensor; its last axis is cut to the query length, a head axis is
                inserted at dim 1, and positions where it equals 0 are masked out.
        """
        if mask is not None:
            mask = mask[:, :, :query.size(1)].unsqueeze(1)

        nbatches = query.size(0)
        query, key = [proj(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                      for proj, x in zip(self.linears, (query, key))]
        return self.attention(query, key, mask=mask, dropout=self.dropout)

    def attention(self, query, key, mask=None, dropout=None):
        """Softmax(Q K^T / sqrt(d_k)) over the last axis, with optional masking and dropout."""
        d_k = query.size(-1)
        # ``d_k ** 0.5`` rather than math.sqrt: this code must not depend on a ``math`` import.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return p_attn

    def clones(self, module, N):
        """Return an nn.ModuleList of N independent deep copies of *module*."""
        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
        

class AsaTgcnSem(BertPreTrainedModel):
    """Aspect sentiment classifier: BERT + up to three GCN branches, fused, then a linear head.

    Branches (switched on by ``opt.modules``):
      * ``tgcn``   - TypeGraphConvolution layers; adjacency is the type-aware attention from
                     ``get_attention`` restricted to dependency edges.
      * ``semgcn`` - GraphConvolution layers; adjacency is a single-head MultiHeadAttention map.
      * ``lexgcn`` - GraphConvolution layers; adjacency is built from a token-id
                     co-occurrence DataFrame by ``get_lex_adj``.
    Each branch's layer outputs are averaged over the aspect positions, then combined with a
    learned softmax layer ensemble (``use_ensemble``) or reduced to the last layer. Branch
    vectors are concatenated, or gated when exactly two branches are enabled.
    """

    def __init__(self, config, modules, tokenizer, opt):
        """
        Args:
            config: BERT config that additionally carries ``num_labels`` and ``num_types``.
            modules: ignored - branch switches are read from ``opt.modules``; kept so
                existing call sites keep working.
            tokenizer: unused; kept for signature compatibility.
            opt: namespace with ``modules`` and ``num_layers`` (dicts keyed 'tgcn', 'semgcn',
                'lexgcn'), ``use_ensemble``, ``fusion_type`` ('concat' or 'gate'),
                ``dropout``, ``concat_dropout`` and, when lexgcn is on, ``cooc`` (a DataFrame
                or None) and ``cooc_path`` (CSV read with ``index_col=0`` when cooc is None).

        Raises:
            ValueError: if ``fusion_type`` is 'gate' with all three branches enabled; the gate
                only fuses two vectors, so the classifier input size would not match.
        """
        super(AsaTgcnSem, self).__init__(config)
        self.opt = opt
        # Stored under a new name: assigning ``self.modules`` would shadow
        # nn.Module.modules(), which torch calls internally (e.g. DataParallel replication).
        self.module_flags = opt.modules
        self.use_tgcn = opt.modules['tgcn']
        self.use_semgcn = opt.modules['semgcn']
        self.use_lexgcn = opt.modules['lexgcn']
        self.num_modules = sum((self.use_tgcn, self.use_semgcn, self.use_lexgcn))
        self.use_ensemble = opt.use_ensemble
        self.layer_number_tgcn = opt.num_layers['tgcn']
        self.layer_number_sem = opt.num_layers['semgcn']
        self.layer_number_lex = opt.num_layers['lexgcn']
        assert self.use_tgcn or self.use_semgcn or self.use_lexgcn
        assert opt.fusion_type == 'concat' or opt.fusion_type == 'gate'
        if opt.fusion_type == 'gate' and self.num_modules > 2:
            raise ValueError("gate fusion supports at most two enabled modules")
        self.fusion_type = opt.fusion_type

        self.num_labels = config.num_labels
        self.num_types = config.num_types

        self.bert = BertModel(config)

        hidden = config.hidden_size
        if self.use_tgcn:
            self.TGCNLayers = nn.ModuleList([TypeGraphConvolution(hidden, hidden, hidden)
                                             for _ in range(self.layer_number_tgcn)])
        if self.use_semgcn:
            self.SemGCNLayers = nn.ModuleList([GraphConvolution(hidden, hidden)
                                               for _ in range(self.layer_number_sem)])
            # One attention module per layer, registered here so its weights are trained and
            # saved, follow .to(device), and respect .eval() (dropout off at inference).
            self.SemAttnLayers = nn.ModuleList([MultiHeadAttention(1, hidden)
                                                for _ in range(self.layer_number_sem)])
        if self.use_lexgcn:
            self.LexGCNLayers = nn.ModuleList([GraphConvolution(hidden, hidden)
                                               for _ in range(self.layer_number_lex)])
            self.cooc = opt.cooc if opt.cooc is not None else pd.read_csv(opt.cooc_path, index_col=0)
            # Token ids are used as labels on both axes.
            self.cooc.index = self.cooc.index.astype(int)
            self.cooc.columns = self.cooc.columns.astype(int)

        # 'concat' feeds all branch vectors side by side; 'gate' feeds one hidden-size vector.
        if self.fusion_type == 'concat':
            self.fc_single = nn.Linear(hidden * self.num_modules, self.num_labels)
        else:
            self.fc_single = nn.Linear(hidden, self.num_labels)

        self.gate_weight = nn.Parameter(torch.FloatTensor(hidden, hidden * 2))
        self.gate_bias = nn.Parameter(torch.FloatTensor(hidden))

        self.dropout = nn.Dropout(opt.dropout)
        self.concat_dropout = nn.Dropout(opt.concat_dropout)
        self.ensemble_linear_tgcn = nn.Linear(1, self.layer_number_tgcn)
        self.ensemble_linear_semgcn = nn.Linear(1, self.layer_number_sem)
        self.ensemble_linear_lexgcn = nn.Linear(1, self.layer_number_lex)
        self.ensemble = nn.Parameter(torch.FloatTensor(3, 1))  # not used by forward
        self.dep_embedding = nn.Embedding(self.num_types, hidden, padding_idx=0)

        # torch.FloatTensor(...) allocates uninitialised memory; give these real values.
        nn.init.xavier_uniform_(self.gate_weight)
        nn.init.zeros_(self.gate_bias)
        nn.init.xavier_uniform_(self.ensemble)

    def get_attention(self, val_out, dep_embed, adj):
        """Type-aware attention over the edges of *adj*, normalised per row.

        Args:
            val_out: (batch, len, dim) token features.
            dep_embed: (batch, len, len, emb_dim) dependency-type embeddings.
            adj: (batch, len, len) adjacency; cells equal to 0 receive zero weight.

        Returns:
            (batch, len, len) weights; a row with no neighbours is all zeros.
        """
        batch_size, max_len, feat_dim = val_out.shape
        val_us = val_out.unsqueeze(dim=2).repeat(1, 1, max_len, 1)
        val_cat = torch.cat((val_us, dep_embed), -1).float()
        atten_expand = val_cat * val_cat.transpose(1, 2)
        attention_score = torch.sum(atten_expand, dim=-1) / np.power(feat_dim, 0.5)

        adj = adj.float()
        # Subtract each row's max over its neighbours before exp() so large scores cannot
        # overflow to inf (inf / inf = NaN). Non-neighbours are -inf and exp() to 0.
        masked_score = attention_score.masked_fill(adj == 0, float('-inf'))
        row_max = masked_score.max(dim=-1, keepdim=True)[0].detach()
        row_max = torch.where(torch.isinf(row_max), torch.zeros_like(row_max), row_max)
        exp_attention_score = torch.exp(masked_score - row_max) * adj
        sum_attention_score = exp_attention_score.sum(dim=-1, keepdim=True)

        attention_score = exp_attention_score / (sum_attention_score + 1e-10)
        if 'HalfTensor' in val_out.type():
            attention_score = attention_score.half()
        return attention_score

    def get_lex_adj(self, input_ids, batch_size, max_len):
        """Build a (batch, max_len, max_len) co-occurrence adjacency on the CPU.

        For each sentence the first ``n`` positions (n = number of non-zero ids; padding is
        assumed to be trailing zeros) get off-diagonal weight ``cooc[id_j][id_k]`` (column
        id_j, row id_k), or 0 when either id is missing from the table. The result is
        ``adj + where(eye, row_sum + col_sum, adj)``, and each diagonal entry is then divided
        by ``n`` (empty sequences keep a zero diagonal).
        """
        adj_tensor = torch.zeros((batch_size, max_len, max_len))
        cooc_columns, cooc_rows = self.cooc.columns, self.cooc.index

        num_words = []
        for i, id_sequence in enumerate(input_ids):
            n = int(torch.sum(id_sequence != 0))
            num_words.append(n)
            ids = id_sequence[:n].tolist()
            for j, id_j in enumerate(ids):
                if id_j not in cooc_columns:
                    continue
                column = self.cooc[id_j]
                for k, id_k in enumerate(ids):
                    if j != k and id_k in cooc_rows:
                        adj_tensor[i, j, k] = float(column[id_k])

        row_sums = adj_tensor.sum(dim=2, keepdim=True).repeat(1, 1, max_len)
        column_sums = adj_tensor.sum(dim=1, keepdim=True).repeat(1, max_len, 1)
        diagonal_mask = torch.eye(max_len).bool().unsqueeze(0).repeat(batch_size, 1, 1)
        # Diagonal := row sum + column sum; off-diagonal cells of ``res`` equal adj_tensor.
        res = torch.where(diagonal_mask, row_sums + column_sums, adj_tensor)
        # NOTE(review): adding ``res`` to ``adj_tensor`` doubles every off-diagonal weight;
        # confirm this is intended and not meant to be ``adj_tensor = res``.
        adj_tensor = adj_tensor + res

        for i, num in enumerate(num_words):
            if num == 0:
                continue  # all-padding row: diagonal is already 0, avoid 0 / 0 = NaN
            diagonal = adj_tensor[i].diagonal()
            diagonal.copy_(diagonal / num)

        return adj_tensor

    def get_avarage(self, aspect_indices, x):
        """Sum of ``x * aspect_indices`` over positions divided by the count of non-zero mask entries.

        For a 0/1 mask this is the mean of *x* over the aspect positions (NaN if there are none).
        """
        aspect_indices_us = torch.unsqueeze(aspect_indices, 2)
        x_mask = x * aspect_indices_us
        aspect_len = (aspect_indices_us != 0).sum(dim=1)
        x_sum = x_mask.sum(dim=1)
        return torch.div(x_sum, aspect_len)

    def set_dropout(self, dropout):
        """Replace ``self.dropout`` with a new nn.Dropout of probability *dropout*."""
        self.dropout = nn.Dropout(dropout)

    def _combine_layers(self, layer_outputs, mem_valid_ids, ensemble_linear):
        """Aspect-average each layer's output, then ensemble the layers or keep the last one."""
        pooled = [self.get_avarage(mem_valid_ids, x_out) for x_out in layer_outputs]
        if not self.use_ensemble:
            return pooled[-1]
        stacked = torch.stack(pooled, -1)
        weights = F.softmax(ensemble_linear.weight, dim=0)  # (num_layers, 1), sums to 1
        return self.dropout(torch.matmul(stacked, weights).squeeze(dim=-1))

    def forward(self, input_ids, segment_ids, valid_ids, mem_valid_ids, dep_adj_matrix, dep_value_matrix):
        """Return classification logits from ``fc_single``.

        ``valid_ids`` selects the positions (value 1) whose BERT outputs are packed to the
        front of each sequence before the GCNs; ``mem_valid_ids`` is the aspect mask used for
        pooling; ``dep_adj_matrix`` / ``dep_value_matrix`` are dependency edges and type ids.
        """
        sequence_output, _ = self.bert(input_ids, segment_ids)
        dep_embed = self.dep_embedding(dep_value_matrix)

        # Pack the selected representations to the front; remaining rows stay zero.
        batch_size, max_len, feat_dim = sequence_output.shape
        valid_output = torch.zeros(batch_size, max_len, feat_dim, device=input_ids.device).type_as(sequence_output)
        for i in range(batch_size):
            temp = sequence_output[i][valid_ids[i] == 1]
            valid_output[i][:temp.size(0)] = temp
        valid_output = self.dropout(valid_output)

        all_outputs = []  # one vector per enabled branch, in tgcn / semgcn / lexgcn order
        if self.use_tgcn:
            seq_out, layer_outputs = valid_output, []
            for tgcn in self.TGCNLayers:
                attention_score = self.get_attention(seq_out, dep_embed, dep_adj_matrix)
                # Each layer consumes the previous layer's output.
                seq_out = F.relu(tgcn(seq_out, attention_score, dep_embed))
                layer_outputs.append(seq_out)
            all_outputs.append(self._combine_layers(layer_outputs, mem_valid_ids, self.ensemble_linear_tgcn))

        if self.use_semgcn:
            seq_out, layer_outputs = valid_output, []
            for semgcn, attn in zip(self.SemGCNLayers, self.SemAttnLayers):
                attn_tensor = attn(seq_out, seq_out).squeeze(1)  # drop the single head axis
                seq_out = F.relu(semgcn(seq_out, attn_tensor))
                layer_outputs.append(seq_out)
            all_outputs.append(self._combine_layers(layer_outputs, mem_valid_ids, self.ensemble_linear_semgcn))

        if self.use_lexgcn:
            # The co-occurrence adjacency depends only on input_ids: build it once, on the
            # same device/dtype as the features.
            adj_tensor = self.get_lex_adj(input_ids, batch_size, max_len).to(
                device=valid_output.device, dtype=valid_output.dtype)
            seq_out, layer_outputs = valid_output, []
            for lexgcn in self.LexGCNLayers:
                seq_out = F.relu(lexgcn(seq_out, adj_tensor))
                layer_outputs.append(seq_out)
            all_outputs.append(self._combine_layers(layer_outputs, mem_valid_ids, self.ensemble_linear_lexgcn))

        if self.fusion_type == 'gate' and self.num_modules == 2:
            # g = sigmoid(W_g [h0 ; h1] + b_g); works for any pair of enabled branches.
            first, second = all_outputs
            g = torch.sigmoid(torch.matmul(torch.cat((first, second), dim=1), self.gate_weight.t()) + self.gate_bias)
            ensemble_out = g * first + (1 - g) * second
        else:
            ensemble_out = torch.cat(all_outputs, dim=1)

        if (self.num_modules == 2 and self.fusion_type == 'concat') or self.num_modules == 3:
            ensemble_out = self.concat_dropout(ensemble_out)
        return self.fc_single(ensemble_out)
    
class AsaTgcn(BertPreTrainedModel):
    """BERT encoder followed by a stack of TypeGraphConvolution (T-GCN) layers.

    Each layer attends over dependency arcs using dependency-type embeddings. The
    aspect representation of every layer is averaged over the aspect positions,
    the layers are mixed with learned softmax weights (layer ensemble), and a
    linear layer produces the polarity logits.
    """

    def __init__(self, config, dropout = 0.2):
        super(AsaTgcn, self).__init__(config)
        self.config = config
        self.layer_number_tgcn = 3
        self.num_labels = config.num_labels
        self.num_types = config.num_types

        self.bert = BertModel(config)
        self.TGCNLayers = nn.ModuleList([TypeGraphConvolution(config.hidden_size, config.hidden_size, config.hidden_size)
                                         for _ in range(self.layer_number_tgcn)])
        self.fc_single = nn.Linear(config.hidden_size, self.num_labels)
        self.dropout = nn.Dropout(dropout)
        # Only the weight is used: shape (layer_number_tgcn, 1); a softmax over dim 0
        # turns it into per-layer mixing weights.
        self.ensemble_linear_tgcn = nn.Linear(1, self.layer_number_tgcn)
        # NOTE(review): not used by forward(); kept so existing checkpoints keep
        # matching state_dict keys.
        self.ensemble = nn.Parameter(torch.FloatTensor(3, 1))
        self.dep_embedding = nn.Embedding(self.num_types, config.hidden_size, padding_idx=0)

    def get_attention(self, val_out, dep_embed, adj):
        """Masked, dependency-type-aware attention between token pairs.

        Args:
            val_out: (batch, max_len, feat_dim) token representations.
            dep_embed: (batch, max_len, max_len, d) embeddings of the dependency
                type between each token pair.
            adj: (batch, max_len, max_len) mask; the exponentiated scores are
                multiplied by it, so zero entries get zero attention.

        Returns:
            (batch, max_len, max_len) attention weights, normalised per row over the
            positions left by ``adj`` (rows with no allowed position stay ~0).
        """
        batch_size, max_len, feat_dim = val_out.shape
        val_us = val_out.unsqueeze(dim=2)
        val_us = val_us.repeat(1,1,max_len,1)
        # Entry [b, i, j] = [h_i ; type(i, j)]
        val_cat = torch.cat((val_us, dep_embed), -1).float()
        # Element-wise product of the (i, j) and (j, i) pair vectors
        atten_expand = (val_cat * val_cat.transpose(1,2))

        attention_score = torch.sum(atten_expand, dim=-1)
        attention_score = attention_score / np.power(feat_dim, 0.5)
        exp_attention_score = torch.exp(attention_score)
        exp_attention_score = torch.mul(exp_attention_score, adj.float()) # mask
        sum_attention_score = torch.sum(exp_attention_score, dim=-1).unsqueeze(dim=-1).repeat(1,1,max_len)

        # 1e-10 keeps fully-masked rows from dividing by zero
        attention_score = torch.div(exp_attention_score, sum_attention_score + 1e-10)
        if 'HalfTensor' in val_out.type():
            attention_score = attention_score.half()

        return attention_score

    def get_avarage(self, aspect_indices, x):
        """Average ``x`` over aspect positions.

        Args:
            aspect_indices: (batch, max_len) mask, non-zero at aspect tokens.
            x: (batch, max_len, dim) token representations.

        Returns:
            (batch, dim): sum of ``x * aspect_indices`` over positions divided by
            the number of non-zero mask entries. Rows with no aspect position give
            zeros instead of NaN.
        """
        aspect_indices_us = torch.unsqueeze(aspect_indices, 2)
        x_mask = x * aspect_indices_us
        # clamp(min=1) avoids 0 / 0 = NaN for rows without any aspect position
        aspect_len = (aspect_indices_us != 0).sum(dim=1).clamp(min=1)
        x_sum = x_mask.sum(dim=1)
        return torch.div(x_sum, aspect_len)

    def set_dropout(self, dropout):
        """Replace the dropout module with one using probability ``dropout``."""
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, segment_ids, valid_ids, mem_valid_ids, dep_adj_matrix, dep_value_matrix):
        """Return polarity logits of shape (batch, num_labels).

        Args:
            input_ids, segment_ids: BERT token ids and token-type ids.
            valid_ids: 1 at positions whose BERT vectors are kept.
            mem_valid_ids: non-zero at aspect positions; used for aspect averaging.
            dep_adj_matrix: dependency adjacency used to mask the attention.
            dep_value_matrix: dependency-type ids looked up in ``dep_embedding``.
        """
        # Contextual token representations from BERT (pooled output unused)
        sequence_output, _pooled_output = self.bert(input_ids, segment_ids)

        # Dependency-type embeddings, one vector per entry of dep_value_matrix
        dep_embed = self.dep_embedding(dep_value_matrix)

        # Keep only the vectors at valid_ids == 1, packed to the front of each row;
        # the remaining positions stay zero.
        batch_size, max_len, feat_dim = sequence_output.shape
        valid_output = sequence_output.new_zeros(batch_size, max_len, feat_dim)
        for i in range(batch_size):
            temp = sequence_output[i][valid_ids[i] == 1]
            valid_output[i][:temp.size(0)] = temp
        valid_output = self.dropout(valid_output)

        # Each T-GCN layer consumes the previous layer's output; every layer's
        # output is kept for the layer ensemble below.
        tgcn_layer_outputs = []
        seq_out = valid_output
        for tgcn in self.TGCNLayers:
            attention_score = self.get_attention(seq_out, dep_embed, dep_adj_matrix)
            seq_out = F.relu(tgcn(seq_out, attention_score, dep_embed))
            tgcn_layer_outputs.append(seq_out)

        # Average over aspect positions -> one (batch, hidden) vector per layer
        tgcn_layer_outputs_pool = [self.get_avarage(mem_valid_ids, x_out) for x_out in tgcn_layer_outputs]

        # Layer ensemble: softmax-weighted sum of the per-layer aspect vectors
        tgcn_pool = torch.stack(tgcn_layer_outputs_pool, -1)  # (batch, hidden, num_layers)
        ensemble_tgcn = torch.matmul(tgcn_pool, F.softmax(self.ensemble_linear_tgcn.weight, dim=0))
        ensemble_tgcn = ensemble_tgcn.squeeze(dim=-1)
        ensemble_tgcn = self.dropout(ensemble_tgcn)

        return self.fc_single(ensemble_tgcn)

# Main code

In [6]:
import logging
import argparse
import math
import os
import sys
from time import strftime, localtime
import random
import numpy as np
import subprocess

from pytorch_transformers import BertModel, BertConfig
# from data_utils import Tokenizer4Bert, ABSADataset
# from asa_tgcn_model import AsaTgcn

# !pip install scikit-learn
from sklearn import metrics
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split


CONFIG_NAME = 'config.json'
WEIGHTS_NAME = 'pytorch_model.bin'

# Configure the root logger at INFO level.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Re-running this notebook cell must not attach a second stdout handler;
# otherwise every message would be printed once per run of the cell.
if not any(getattr(handler, 'stream', None) is sys.stdout for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler(sys.stdout))

class Instructor:
    """Build the model and datasets described by ``opt`` and run training/evaluation.

    ``opt`` is the Namespace produced by ``get_args``, with ``opt.device`` already
    a ``torch.device`` (as set by ``main``).
    """

    # Keys of a collated batch fed, in this order, to the model's forward().
    _MODEL_INPUT_KEYS = ("input_ids", "segment_ids", "valid_ids", "mem_valid_ids",
                         "dep_adj_matrix", "dep_value_matrix")

    def __init__(self, opt):
        self.opt = opt
        logger.info(opt)
        deptype2id = ABSADataset.load_deptype_map(opt)
        polarity2id = ABSADataset.get_polarity2id()
        logger.info(deptype2id)
        logger.info(polarity2id)
        self.deptype2id = deptype2id
        self.polarity2id = polarity2id

        self.vocab_path = os.path.join(opt.bert_model, 'vocab.txt')
        self.tokenizer = Tokenizer4Bert(opt.max_seq_len, opt.bert_model)
        config = BertConfig.from_json_file(os.path.join(opt.bert_model, CONFIG_NAME))
        config.num_labels = opt.polarities_dim
        config.num_types = len(self.deptype2id)
        logger.info(config)
        if opt.model_type == 'tgcn':
            self.model = AsaTgcn.from_pretrained(opt.bert_model, config=config, dropout=opt.dropout)
        else:
            # 'tgcn+sem' and 'tri_gcn' share the hybrid model; opt.modules selects
            # which GCN branches (tgcn / semgcn / lexgcn) it uses.
            self.model = AsaTgcnSem.from_pretrained(opt.bert_model, config=config, modules=opt.modules,
                                                    tokenizer=self.tokenizer, opt=self.opt)
        self.model.set_dropout(opt.dropout)
        self.model.to(opt.device)

        # NOTE(review): fulltrainset is not used by train(); kept as a public attribute.
        self.fulltrainset = ABSADataset(opt.train_file, self.tokenizer, self.opt, deptype2id=deptype2id)
        self.trainset = ABSADataset(opt.train_file, self.tokenizer, self.opt, deptype2id=deptype2id)
        self.testset = ABSADataset(opt.test_file, self.tokenizer, self.opt, deptype2id=deptype2id)

        # Validation data: an explicit file if it exists, else a random split of the
        # training set, else (valset_ratio <= 0) the test set itself.
        if os.path.exists(opt.val_file):
            self.valset = ABSADataset(opt.val_file, self.tokenizer, self.opt, deptype2id=deptype2id)
        elif opt.valset_ratio > 0:
            valset_len = int(len(self.trainset) * opt.valset_ratio)
            self.trainset, self.valset = random_split(self.trainset, (len(self.trainset) - valset_len, valset_len))
        else:
            self.valset = self.testset

        if opt.device.type == 'cuda':
            logger.info('cuda memory allocated: {}'.format(torch.cuda.memory_allocated(device=opt.device.index)))

    def _print_args(self):
        """Log trainable / non-trainable parameter counts and every option in ``opt``."""
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.model.parameters():
            if p.requires_grad:
                n_trainable_params += p.numel()
            else:
                n_nontrainable_params += p.numel()
        logger.info('n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
        logger.info('> training arguments:')
        for arg in vars(self.opt):
            logger.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))

    def _reset_params(self):
        """Re-initialise trainable parameters of every direct child except BERT.

        Matrices (dim > 1) get Xavier-uniform init; 1-D tensors get
        U(-1/sqrt(n), 1/sqrt(n)). The pretrained BertModel child is left untouched.
        """
        for child in self.model.children():
            if isinstance(child, BertModel):  # keep pretrained BERT weights
                continue
            for p in child.parameters():
                if not p.requires_grad:
                    continue
                if len(p.shape) > 1:
                    torch.nn.init.xavier_uniform_(p)
                else:
                    stdv = 1. / math.sqrt(p.shape[0])
                    torch.nn.init.uniform_(p, a=-stdv, b=stdv)

    @staticmethod
    def _forward_batch(model, batch, device):
        """Move the model inputs of a collated ``batch`` to ``device`` and run ``model``."""
        return model(*(batch[key].to(device) for key in Instructor._MODEL_INPUT_KEYS))

    @staticmethod
    def _snapshot_state_dict(model):
        """Return a CPU copy of ``model``'s state_dict that later training cannot change."""
        model = model.module if hasattr(model, 'module') else model
        # state_dict() returns a fresh OrderedDict (keeping its _metadata) whose
        # tensors alias the live parameters, so replace each value with a copy.
        state = model.state_dict()
        for name, tensor in state.items():
            state[name] = tensor.detach().cpu().clone()
        return state

    def save_model(self, save_path, model, args, state_dict=None):
        """Write weights, config (with label maps), training args and vocab to ``save_path``.

        The files use the names ``from_pretrained`` expects, so the directory can be
        reloaded directly.

        Args:
            save_path: existing directory to write into.
            model: the model (or a wrapper exposing ``.module``) whose config is saved.
            args: options object stored as ``training_args.bin``.
            state_dict: weights to save; defaults to the model's current weights.
        """
        import shutil

        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        if state_dict is None:
            state_dict = model_to_save.state_dict()
        torch.save(state_dict, os.path.join(save_path, WEIGHTS_NAME))

        config = model_to_save.config
        config.__dict__["deptype2id"] = self.deptype2id
        config.__dict__["polarity2id"] = self.polarity2id
        with open(os.path.join(save_path, CONFIG_NAME), "w", encoding='utf-8') as writer:
            writer.write(config.to_json_string())
        torch.save(args, os.path.join(save_path, 'training_args.bin'))
        shutil.copyfile(self.vocab_path, os.path.join(save_path, 'vocab.txt'))

    def _train(self, criterion, optimizer, train_data_loader, val_data_loader, test_data_loader):
        """Train for ``opt.num_epoch`` epochs, testing whenever validation improves.

        Model selection uses validation accuracy, with validation macro-F1 as the
        tie-breaker. ``opt.save_models`` controls checkpointing:
        'all' saves every improving epoch, 'last' saves only the best epoch once
        training ends, and 'none' saves nothing.

        Returns:
            (max_val_acc, test_acc, test_f1); the test metrics belong to the best
            validation epoch (0.0 if no epoch ran).
        """
        max_val_acc = -1
        max_val_f1 = -1
        global_step = 0
        test_acc, test_f1 = 0.0, 0.0
        best_path, best_state = None, None
        save_mode = self.opt.save_models
        device = self.opt.device

        model_home = self.opt.model_path
        os.makedirs(model_home, exist_ok=True)

        results = {"bert_model": self.opt.bert_model, "batch_size": self.opt.batch_size,
                   "learning_rate": self.opt.learning_rate, "seed": self.opt.seed,
                   "num_epoch": self.opt.num_epoch, "l2reg": self.opt.l2reg,
                   "dropout": self.opt.dropout}
        for epoch in range(self.opt.num_epoch):
            logger.info('>' * 100)
            logger.info('epoch: {}'.format(epoch))
            n_correct, n_total, loss_total = 0, 0, 0
            self.model.train()
            for t_sample_batched in train_data_loader:
                global_step += 1
                optimizer.zero_grad()
                outputs = self._forward_batch(self.model, t_sample_batched, device)
                targets = t_sample_batched['polarity'].to(device)

                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
                n_total += len(outputs)
                loss_total += loss.item() * len(outputs)
                if global_step % self.opt.log_step == 0:
                    train_acc = n_correct / n_total
                    train_loss = loss_total / n_total
                    logger.info('epoch: {}, loss: {:.4f}, train acc: {:.4f}'.format(epoch, train_loss, train_acc))

            val_acc, val_f1 = self._evaluate_acc_f1(self.model, val_data_loader, device=device)
            logger.info('>epoch: {}, val_acc: {:.4f}, val_f1: {:.4f}'.format(epoch, val_acc, val_f1))
            results["{}_val_acc".format(epoch)] = val_acc
            results["{}_val_f1".format(epoch)] = val_f1
            epoch_dir = os.path.join(model_home, "epoch_{}".format(epoch))
            os.makedirs(epoch_dir, exist_ok=True)
            if val_acc > max_val_acc or (val_acc == max_val_acc and val_f1 > max_val_f1):
                max_val_acc = val_acc
                max_val_f1 = val_f1

                if save_mode == 'last':
                    best_path = epoch_dir
                    # Snapshot now: self.model keeps training, so holding only a
                    # reference would save the final epoch's weights.
                    best_state = self._snapshot_state_dict(self.model)
                elif save_mode == 'all':
                    self.save_model(epoch_dir, self.model, self.opt)

                eval_path = os.path.join(model_home, "epoch_{}_eval.txt".format(epoch))
                test_acc, test_f1 = self._evaluate_acc_f1(self.model, test_data_loader, device=device,
                                                          saving_path=eval_path)
                logger.info('>> epoch: {}, test_acc: {:.4f}, test_f1: {:.4f}'.format(epoch, test_acc, test_f1))

                results["max_val_acc"] = max_val_acc
                results["test_acc"] = test_acc
                results["test_f1"] = test_f1

            output_eval_file = os.path.join(model_home, "eval_results.txt")
            with open(output_eval_file, "w") as writer:
                for k, v in results.items():
                    writer.write("{}={}\n".format(k, v))

        if best_state is not None:
            self.save_model(best_path, self.model, self.opt, state_dict=best_state)
        acc_file = os.path.join(model_home, "acc-{:.4f}".format(test_acc))
        with open(acc_file, 'w') as f:
            f.write(f"accuracy: {test_acc}")
        return max_val_acc, test_acc, test_f1

    @staticmethod
    def _evaluate_acc_f1(model, data_loader, device, saving_path=None):
        """Return (accuracy, macro-F1) of ``model`` on ``data_loader``.

        Puts the model in eval mode. If ``saving_path`` is given, writes one
        tab-separated line per example: gold label, predicted label, raw text, aspect.

        Raises:
            ValueError: if ``data_loader`` yields no examples.
        """
        n_correct, n_total = 0, 0
        targets_per_batch, outputs_per_batch = [], []
        model.eval()

        saving_file = open(saving_path, 'w') if saving_path is not None else None
        try:
            with torch.no_grad():
                for t_sample_batched in data_loader:
                    t_targets = t_sample_batched['polarity'].to(device)
                    t_outputs = Instructor._forward_batch(model, t_sample_batched, device)
                    t_predictions = torch.argmax(t_outputs, -1)

                    n_correct += (t_predictions == t_targets).sum().item()
                    n_total += len(t_outputs)
                    targets_per_batch.append(t_targets)
                    outputs_per_batch.append(t_outputs)

                    if saving_file is not None:
                        for t_target, t_output, t_raw_text, t_aspect in zip(
                                t_targets.detach().cpu().numpy(), t_predictions.detach().cpu().numpy(),
                                t_sample_batched['raw_text'], t_sample_batched['aspect']):
                            saving_file.write("{}\t{}\t{}\t{}\n".format(t_target, t_output, t_raw_text, t_aspect))
        finally:
            if saving_file is not None:
                saving_file.close()

        if n_total == 0:
            raise ValueError('data_loader yielded no examples to evaluate')
        acc = n_correct / n_total
        t_targets_all = torch.cat(targets_per_batch, dim=0)
        t_outputs_all = torch.cat(outputs_per_batch, dim=0)
        # Score every class the classifier can emit (its output width).
        labels = list(range(t_outputs_all.size(-1)))
        f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu(), labels=labels, average='macro')
        return acc, f1

    def train(self):
        """Reset the non-BERT parameters, then train with Adam and cross-entropy.

        Returns:
            (max_val_acc, test_acc, test_f1) as returned by ``_train``.
        """
        criterion = nn.CrossEntropyLoss()
        _params = [p for p in self.model.parameters() if p.requires_grad]
        # NOTE(review): opt.optim is not consulted; Adam is always used.
        optimizer = torch.optim.Adam(_params, lr=self.opt.learning_rate, weight_decay=self.opt.l2reg)

        train_data_loader = DataLoader(dataset=self.trainset, batch_size=self.opt.batch_size, shuffle=True)
        test_data_loader = DataLoader(dataset=self.testset, batch_size=self.opt.batch_size, shuffle=False)
        val_data_loader = DataLoader(dataset=self.valset, batch_size=self.opt.batch_size, shuffle=False)

        # In-place re-initialisation, so the optimizer's parameter references stay valid.
        self._reset_params()
        return self._train(criterion, optimizer, train_data_loader, val_data_loader, test_data_loader)


def test(opt):
    """Evaluate the checkpoint stored in ``opt.model_path`` on ``opt.test_file``.

    The dependency-type map is read from the saved config, where
    ``Instructor.save_model`` stores it.

    Returns:
        (test_acc, test_f1); both are also logged.
    """
    logger.info(opt)
    config = BertConfig.from_json_file(os.path.join(opt.model_path, CONFIG_NAME))
    logger.info(config)

    tokenizer = Tokenizer4Bert(opt.max_seq_len, opt.model_path)
    if opt.model_type == 'tgcn':
        model = AsaTgcn.from_pretrained(opt.model_path)
    else:
        # 'tgcn+sem' and 'tri_gcn' both use the hybrid model, built with the same
        # extra arguments Instructor passes at training time.
        model = AsaTgcnSem.from_pretrained(opt.model_path, modules=opt.modules,
                                           tokenizer=tokenizer, opt=opt)
    model.set_dropout(opt.dropout)
    model.to(opt.device)

    deptype2id = config.deptype2id
    logger.info(deptype2id)
    testset = ABSADataset(opt.test_file, tokenizer, opt, deptype2id=deptype2id)
    test_data_loader = DataLoader(dataset=testset, batch_size=opt.batch_size, shuffle=False)
    test_acc, test_f1 = Instructor._evaluate_acc_f1(model, test_data_loader, device=opt.device)
    logger.info('>> test_acc: {:.4f}, test_f1: {:.4f}'.format(test_acc, test_f1))
    return test_acc, test_f1


def get_args(model_type='tgcn',  # tgcn, tgcn+sem, tri_gcn
             # Select which modules to use for hybrid model
             tgcn=True,
             semgcn=True,
             lexgcn=True,
             tgcn_layers=3,
             semgcn_layers=2,
             lexgcn_layers=2,
             path=None,
             year='2015',
             val_file='val.txt',
             log='log',
             bert_model='bert_large_uncased',
             cooc_path='cooc_matrix.csv',  # Path to co-occurrence matrix file
             cooc=None,  # Pandas DataFrame co-occurrence matrix. If not specified, it will be loaded from cooc_path
             learning_rate=2e-5,
             dropout=0.2,
             concat_dropout=0.5,
             bert_dropout=0.2,
             l2reg=0.01,
             num_epoch=50,
             batch_size=16,
             log_step=5,
             max_seq_len=100,
             polarities_dim=3,
             device='cuda',
             seed=50,
             valset_ratio=0.1,
             do_train=True,
             do_eval=True,
             eval_epoch_num=0,
             fusion_type='concat',  # 'concat' or 'gate'
             use_ensemble=True,
             save_models='last',
             print_sentences=False,
             optim='adam'
             ):
    """Build the run-configuration Namespace used by main(), Instructor and test().

    Dataset paths are derived from ``year``. ``model_path`` is derived from the main
    hyperparameters unless ``path`` overrides it. For evaluation-only runs
    (``do_eval`` and not ``do_train``), it points at the
    ``epoch_{eval_epoch_num}`` checkpoint directory.

    Raises:
        ValueError: if ``model_type`` is not 'tgcn', 'tgcn+sem' or 'tri_gcn'.
    """
    if model_type not in ('tgcn', 'tgcn+sem', 'tri_gcn'):
        raise ValueError(f"model_type must be 'tgcn', 'tgcn+sem' or 'tri_gcn', got {model_type!r}")

    # The fusion strategy only applies to hybrid models, so it is omitted from
    # the plain T-GCN model path.
    fusion = "" if model_type == 'tgcn' else "+" + fusion_type
    model_path = (f'test_models/{year}{model_type}{fusion}_seed{seed}_reg{l2reg}_drop{dropout}'
                  f'_cdrop{concat_dropout}_lr{learning_rate}_epochs{num_epoch}_{optim.lower()}')
    if do_eval and not do_train:
        model_path += f'/epoch_{eval_epoch_num}'
    if path:
        model_path = path

    return argparse.Namespace(
        model_type=model_type,
        modules={'tgcn': tgcn, 'semgcn': semgcn, 'lexgcn': lexgcn},
        num_layers={'tgcn': tgcn_layers, 'semgcn': semgcn_layers, 'lexgcn': lexgcn_layers},
        year=year,
        train_file=f'data/train{year}restaurant.txt',
        test_file=f'data/test{year}restaurant.txt',
        model_path=model_path,
        val_file=val_file,
        log=log,
        bert_model=bert_model,
        cooc_path=cooc_path,
        cooc=cooc,
        learning_rate=learning_rate,
        dropout=dropout,
        concat_dropout=concat_dropout,
        bert_dropout=bert_dropout,
        l2reg=l2reg,
        num_epoch=num_epoch,
        batch_size=batch_size,
        log_step=log_step,
        max_seq_len=max_seq_len,
        polarities_dim=polarities_dim,
        device=device,
        seed=seed,
        valset_ratio=valset_ratio,
        do_train=do_train,
        do_eval=do_eval,
        eval_epoch_num=eval_epoch_num,
        fusion_type=fusion_type,
        use_ensemble=use_ensemble,
        save_models=save_models,
        print_sent=print_sentences,
        optim=optim,
    )


def set_seed(opt):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs from ``opt.seed``.

    Also makes cuDNN deterministic and disables its autotuner. Does nothing when
    ``opt.seed`` is None.
    """
    seed = opt.seed
    if seed is None:
        return
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

        

def main(opt):
    """Run one configuration: train (and test) or, if not training, evaluate only.

    Resolves ``opt.device`` to a ``torch.device``, records ``opt.n_gpu``, and logs
    to a timestamped file under ``opt.log`` for the duration of the call.

    Returns:
        (max_val_acc, test_acc, test_f1) from training. For evaluation-only runs,
        test() logs its metrics and all three values are None.
    """
    set_seed(opt)

    if opt.device is None:
        opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        opt.device = torch.device(opt.device)
    opt.n_gpu = torch.cuda.device_count()

    os.makedirs(opt.log, exist_ok=True)

    log_file = '{}/log-{}.log'.format(opt.log, strftime("%y%m%d-%H%M", localtime()))
    file_handler = logging.FileHandler(log_file)
    logger.addHandler(file_handler)

    max_val_acc = test_acc = test_f1 = None
    try:
        if opt.do_train:
            ins = Instructor(opt)
            max_val_acc, test_acc, test_f1 = ins.train()
        elif opt.do_eval:
            test(opt)
    finally:
        # Detach this run's handler. Otherwise repeated calls (e.g. a hyperparameter
        # loop) keep every earlier log file open and write each record to all of them.
        logger.removeHandler(file_handler)
        file_handler.close()

    return max_val_acc, test_acc, test_f1

### Load the co-occurrence matrix once so it is not reloaded for every model

In [7]:
# Load the co-occurrence matrix once (first CSV column becomes the row index) so it
# can be passed to get_args(cooc=...) instead of being read again for every model.
# NOTE(review): requires pandas imported as `pd` in an earlier cell — confirm.
cooc = pd.read_csv('cooc_matrix_ids.csv', index_col=0)

# Run program

Hyperparameter searching

In [None]:
import csv

csv_file = 'final_results.csv'

res = {'2015': [], '2016': []}

# main() -> set_seed() reseeds the global NumPy RNG with each trial's seed, so
# drawing from np.random here would make every trial depend on the previous
# trial's seed (and trials could repeat). Use an independent, OS-seeded generator.
search_rng = np.random.RandomState()

for _ in range(40):
    lr = search_rng.choice(np.logspace(-6, -3))
    d = search_rng.choice([0.1, 0.2, 0.4])
    cdrop = search_rng.choice([0.1, 0.2, 0.4])
    w_decay = search_rng.choice(np.logspace(-5, -3))
    seed = search_rng.randint(1000)
    year = '2016'
    opt = get_args(batch_size=16, seed=seed, dropout=d,
                   l2reg=w_decay, learning_rate=lr, year=year,
                   num_epoch=20, model_type='tri_gcn', save_models='none', fusion_type='concat',
                   concat_dropout=cdrop, cooc=cooc, tgcn=True, semgcn=True, lexgcn=True, use_ensemble=True,
                   tgcn_layers=2, semgcn_layers=2, lexgcn_layers=2, optim='adam')
    opt.device = torch.device('cuda')
    max_val_acc, test_acc, test_f1 = main(opt)
    res[year].append((opt, max_val_acc, test_acc, test_f1))
    try:
        with open(csv_file, 'a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow((year, max_val_acc, test_acc, test_f1, seed, lr, d, cdrop, w_decay))
    except OSError as err:
        # Keep the search running even if one result row cannot be written.
        print(f'Failed to write results to {csv_file}: {err}')

Training with best hyperparameters

In [None]:
import csv

csv_file = 'final_results.csv'

# Best hyperparameters found by the search; only the seed varies between runs.
lr = 25e-6
d = 0.4
cdrop = 0.4
w_decay = 5e-4
year = '2016'

# main() -> set_seed() reseeds the global NumPy RNG with each run's seed, so drawing
# the next seed from np.random would make it depend on the previous run's seed (runs
# could repeat). Use an independent, OS-seeded generator instead.
seed_rng = np.random.RandomState()

for _ in range(20):
    seed = seed_rng.randint(1000)
    opt = get_args(batch_size=16, seed=seed, dropout=d,
                   l2reg=w_decay, learning_rate=lr, year=year,
                   num_epoch=20, model_type='tri_gcn', save_models='none', fusion_type='concat',
                   concat_dropout=cdrop, cooc=cooc, tgcn=True, semgcn=True, lexgcn=True, use_ensemble=True,
                   tgcn_layers=2, semgcn_layers=2, lexgcn_layers=2, optim='adam')
    opt.device = torch.device('cuda')
    max_val_acc, test_acc, test_f1 = main(opt)

    with open(csv_file, 'a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow((year, max_val_acc, test_acc, test_f1, seed, lr, d, cdrop, w_decay))


Namespace(model_type='tri_gcn', modules={'tgcn': True, 'semgcn': True, 'lexgcn': True}, num_layers={'tgcn': 2, 'semgcn': 2, 'lexgcn': 2}, year='2016', train_file='data/train2016restaurant.txt', test_file='data/test2016restaurant.txt', model_path='test_models/2016tri_gcn+concat_seed925_reg0.0005_drop0.4_cdrop0.4_lr2.5e-05_epochs20_adam', val_file='val.txt', log='log', bert_model='bert_large_uncased', cooc_path='cooc_matrix.csv', cooc=       100      1000  10000  10003     10005     10007     10009      1001  \
100    0.0  0.000000    0.0    0.0  0.000000  0.000000  0.000000  0.000000   
1000   0.0  0.000000    0.0    0.0  0.000606  0.000303  0.000606  0.035120   
10000  0.0  0.000000    0.0    0.0  0.000000  0.000000  0.000000  0.333333   
10003  0.0  0.000000    0.0    0.0  0.000000  0.000000  0.000000  0.000000   
10005  0.0  0.333333    0.0    0.0  0.000000  0.166667  0.000000  0.166667   
...    ...       ...    ...    ...       ...       ...       ...       ...   
9992   0.0  0.000

In [None]:
import numpy as np

def log_sampler(min_lr, max_lr, num_samples):
    """Draw ``num_samples`` values log-uniformly between ``min_lr`` and ``max_lr``.

    Base-10 exponents are sampled uniformly with the global NumPy RNG and then
    mapped back to the linear scale.
    """
    exponent_low, exponent_high = np.log10(min_lr), np.log10(max_lr)
    exponents = np.random.uniform(exponent_low, exponent_high, num_samples)
    return np.power(10, exponents)

num_trials = 100

lr_space = np.linspace(5e-6, 2e-5, num=50)
cd_space = np.linspace(0, 0.6, 14)
d_space = np.linspace(0, 0.4, 8)
reg_space = np.logspace(-2.5, -1.2, num=100)
fusion_type = 'concat'

# main() -> set_seed() reseeds the global NumPy RNG, so sample from an independent
# generator to keep trials from depending on the previous trial's seed.
search_rng = np.random.RandomState()

for _ in range(num_trials):
    c_d = search_rng.choice(cd_space)
    d = search_rng.choice(d_space)
    reg = search_rng.choice(reg_space)
    lr = search_rng.choice(lr_space)
    # concat_dropout uses its own sampled value (c_d), not the general dropout d.
    opt = get_args(batch_size=16, seed=search_rng.randint(1000), dropout=d,
                   l2reg=reg, learning_rate=lr, year='2015',
                   num_epoch=25, model_type='tgcn+sem', save_models='none', fusion_type=fusion_type,
                   concat_dropout=c_d, tgcn=True, semgcn=True, lexgcn=False)
    main(opt)

# Exploration/testing

In [None]:
# Build an Instructor (model + datasets) for interactive exploration of the
# 'tgcn+sem' model, reusing the co-occurrence DataFrame loaded earlier.
opt = get_args(batch_size = 16, seed = 10, dropout = 0.2, 
              l2reg = 0.01, learning_rate = 2e-5, year = '2015',
              num_epoch = 30, model_type = 'tgcn+sem', save_models = 'none', cooc = cooc)
opt.device = torch.device('cuda')
Ins = Instructor(opt)

In [None]:
# List the attributes of the Namespace returned by get_args().
dir(opt)

In [None]:
# Take one shuffled training batch and run it through the BERT encoder.
train_data_loader = DataLoader(dataset=Ins.trainset, batch_size=Ins.opt.batch_size, shuffle=True)
# next(iter(...)) collates just one batch; list(loader)[0] would collate the whole dataset.
batch_0 = next(iter(train_data_loader))
input_ids = batch_0['input_ids'].to('cuda')
segment_ids = batch_0['segment_ids'].to('cuda')
valid_ids = batch_0['valid_ids'].to('cuda')
raw_text = batch_0['raw_text']

# Inspection only: no autograd graph is needed for the BERT forward pass.
with torch.no_grad():
    sequence_output, pooled_output = Ins.model.bert(input_ids, segment_ids)

In [None]:
# Inspect the lexical adjacency the model builds for the first batch and probe the
# co-occurrence DataFrame held by the model.
# NOTE(review): 16 and 100 presumably are batch_size and max_len (cf. lex_test) — confirm.
adj_example = Ins.model.get_lex_adj(input_ids, 16, 100)
print(adj_example[0,:15,:15])

print('datatype: ', Ins.model.cooc.dtypes)

# NOTE(review): `in` on a DataFrame tests column labels. If the CSV header was read
# as strings, the integer 2833 will not match and the lookup below raises KeyError.
print(2833 in Ins.model.cooc)
Ins.model.cooc[2833][2833]

# for i in input_ids[0]:
#     for j in input_ids[0]:
#         print(int(i), int(j), str(i.item()) in Ins.model.cooc, str(j.item()) in Ins.model.cooc)
        

# Ins.model.cooc[float(2833)][np.float(2833)]

In [None]:
# Print the per-token feature tensors and raw text of the first example in batch_0.
print(batch_0['segment_ids'][0])
print(batch_0['valid_ids'][0])
print(batch_0['mem_valid_ids'][0])
print(batch_0['input_ids'][0])
print(batch_0['raw_text'][0])

def non_zero(tensor):
    """Strip trailing zeros from ``tensor`` (e.g. the padding of a token-id row).

    For a 1-D tensor, returns the prefix up to and including the last non-zero
    element. For an N-D tensor, trailing slices along dim 0 that are entirely zero
    are removed. An all-zero or empty tensor yields an empty slice instead of raising.
    """
    nonzero_coords = torch.nonzero(tensor)  # (num_nonzero, tensor.dim())
    if nonzero_coords.numel() == 0:
        return tensor[:0]
    # Use only the dim-0 coordinate; the max over all coordinates would mix dimensions.
    last_nonzero_index = nonzero_coords[:, 0].max().item()
    return tensor[:last_nonzero_index + 1]

# Compare the trimmed valid_ids and input_ids of the first example.
print(non_zero(batch_0['valid_ids'][0]))
print(non_zero(batch_0['input_ids'][0]))

# Print non-padding token ids whose valid flag is 0. Judging by ws_test, these are
# word-piece continuations — confirm against ABSADataset.
for valid, inp in zip(batch_0['valid_ids'][0], batch_0['input_ids'][0]):
    if valid == 0 and inp != 0:
        print(inp)

Ins.tokenizer.tokenizer.tokenize(batch_0['raw_text'][0])

In [None]:
# Recompute the "valid output" the same way AsaTgcn.forward does: for each example,
# keep only the BERT vectors at valid_ids == 1, packed to the front, zeros after.
batch_size, max_len, feat_dim = sequence_output.shape
print(f"Batch size: {batch_size}\nMax length: {max_len}\nNumber of features: {feat_dim}")
valid_output = torch.zeros(batch_size, max_len, feat_dim, device=input_ids.device).type_as(sequence_output)
for i in range(batch_size):
    temp = sequence_output[i][valid_ids[i] == 1]
    valid_output[i][:temp.size(0)] = temp
valid_output = Ins.model.dropout(valid_output)

In [None]:
# Sanity check: feature 0 of the packed valid output, the valid mask and the raw text,
# then the number of non-zero input ids versus the whitespace word count.
print(valid_output[0,:,0])
print(valid_ids[0])
print(raw_text[0])

print(len(torch.nonzero(input_ids[0])))
print(len(raw_text[0].split()))

In [None]:
# Compare the lexical adjacency from lex_test (defined in the next cell, which must be
# executed first) with direct lookups in the `cooc` DataFrame.
import itertools

# CPU copy of the ids (the earlier cell moved them to CUDA); valid_ids is still the CUDA tensor.
input_ids = batch_0['input_ids']
print(input_ids[0])
print(valid_ids[0])
print(batch_0['raw_text'][0])
id_map = Ins.model.id_to_index_map
print('Index of word "do": ', id_map[2079])

cooc_matrix = Ins.model.cooc_matrix
id_to_index_map = Ins.model.id_to_index_map

adj = lex_test(input_ids, 16, 100, id_to_index_map, cooc_matrix)

sentence = batch_0['raw_text'][0]

print(f'Sentence has {len(sentence.split())} words')

print(adj[0,:23,:23])

# cooc[i][j] selects column i, then row j. When both words are missing, only i is reported.
# NOTE(review): `cooc` was loaded from cooc_matrix_ids.csv; if its labels are token ids
# rather than words, every lookup here will report "not in cooc matrix" — confirm.
for i, j in list(itertools.product(sentence.split(), sentence.split())):
    if i in cooc and j in cooc:
        print(f'{i}, {j}: {cooc[i][j]}')
    elif not i in cooc:
        print(f'{i} not in cooc matrix.')
    else:
        print(f'{j} not in cooc matrix.')

In [None]:
def lex_test(input_ids, batch_size, max_len, id_to_index_map, cooc_matrix):
        """Standalone copy of a lexical-adjacency construction, used for debugging.

        Args:
            input_ids: iterable of per-sentence token-id tensors.
            batch_size, max_len: size of the returned (batch_size, max_len, max_len) tensor.
            id_to_index_map: maps a token id to its index in cooc_matrix; ids that are
                missing get index -1.
            cooc_matrix: co-occurrence weights, read as cooc_matrix[row][col].

        Returns:
            A CPU float tensor of shape (batch_size, max_len, max_len). Off-diagonal
            entries are 2 * cooc weight, because res equals adj_tensor off the diagonal
            and the two are added. Diagonal entries equal that row's sum plus that
            column's sum of the off-diagonal weights.

        NOTE(review): the doubling, the zero diagonal before the sums (0 / (2 * num_words)),
        and index -1 for unknown ids (which reads the last row/column of an array-like
        cooc_matrix) all look unintended — confirm against the model's get_lex_adj.
        """
        # Initialize an empty adjacency tensor
        adj_tensor = torch.zeros((batch_size, max_len, max_len))
        
        for i, id_sequence in enumerate(input_ids):
            # Number of non-padding ids (assumes padding is 0 and comes after the tokens)
            num_words = int(torch.sum(id_sequence != 0))

            word_indices = []
            
            for word_id in id_sequence:
                word_id_int = int(word_id) # conver from torch.tensor to int
                index = id_to_index_map[word_id_int] if word_id_int in id_to_index_map else -1
                word_indices.append(index)
#             print('word indices: ', word_indices)
#             print('num words: ', num_words)
#             print('input ids: ', input_ids)
            
            if i==0:
                print('\nLex function test prints.')
                print('id sequence for sentence 0:', id_sequence)
                print(f'Num. words in lex function: {num_words}\n')
            
            for j in range(num_words):
                for k in range(num_words):
                    if j != k:
                        adj_tensor[i, j, k] = cooc_matrix[word_indices[j]][word_indices[k]]
                    else:
                        # adj_tensor is still zero here, so the diagonal stays 0
                        adj_tensor[i, j, k] = adj_tensor[i, j, k] / (2 * num_words)
                
        # Calculate the sums of rows for each matrix
        row_sums = adj_tensor.sum(dim=2, keepdim=True).repeat(1, 1, max_len)

        # Calculate the sums of columns for each matrix
        column_sums = adj_tensor.sum(dim=1, keepdim=True).repeat(1, max_len, 1)

        # Create a diagonal mask for each matrix
        diagonal_mask = torch.eye(adj_tensor.size(-1)).bool().unsqueeze(0).repeat(batch_size, 1, 1)

        total_sum = row_sums + column_sums

        # Set the diagonal entries to the sum of all the row and column entries (will be averaged later)
        res = torch.where(diagonal_mask, total_sum, adj_tensor)
        
#         print('Res: ', res)
        
        # Off the diagonal res == adj_tensor, so those entries end up doubled
        adj_tensor = adj_tensor + res
        
        return adj_tensor

In [None]:
# Run the model's own get_lex_adj on the CPU batch ids, then check which token ids of
# sentence 0 are present in the model's id_to_index_map.
input_ids = batch_0['input_ids']
print(input_ids)
max_len = 100
batch_size = 16

adj = Ins.model.get_lex_adj(input_ids, batch_size, max_len)
print(adj)

id_map = Ins.model.id_to_index_map

print(id_map[2320])

print('Batch 0, sentence 0: ')
for word_id in input_ids[0]:
    print(int(word_id), int(word_id) in id_map)

print(adj.shape)
print(adj[0,2])

In [None]:
print(Ins.tokenizer.tokenizer)
# Map the first word-piece id of each cooc column label to that column's position.
# NOTE(review): raises IndexError if a label tokenizes to nothing, and labels with
# several pieces keep only the first one — confirm this matches the model's mapping.
id_map = {Ins.tokenizer.tokenizer.convert_tokens_to_ids(Ins.tokenizer.tokenizer.tokenize(w))[0]: i for i, w in enumerate(Ins.model.cooc.columns)}

In [None]:
# print(input_ids)
for i, id_sequence in enumerate(input_ids):
    print(id_sequence)
    # Get word list
    num_words = int(torch.sum(id_sequence != 0))
    print(num_words)

    word_indices = []

    for word_id in id_sequence:
        print(word_id)
        word_id_int = int(word_id)
        index = Ins.model.id_to_index_map[word_id_int] if word_id_int in Ins.model.id_to_index_map else -1
        word_indices.append(index)
    print('word indices: ', word_indices)
    print('num words: ', num_words)
    print('input ids: ', input_ids)

In [None]:
# Check current CUDA memory availability
allocated_memory = torch.cuda.max_memory_allocated() / 1024**2  # Convert bytes to megabytes
cached_memory = torch.cuda.max_memory_cached() / 1024**2  # Convert bytes to megabytes

print(f"Peak allocated memory: {allocated_memory:.2f} MB")
print(f"Peak cached memory: {cached_memory:.2f} MB")

In [None]:
# Quick tokenizer check: the word-piece ids for "B" and the word pieces of a sample sentence.
tokenizer = Tokenizer4Bert(max_seq_len = 100, pretrained_bert_name = 'bert_large_uncased')
print(tokenizer.tokenizer.convert_tokens_to_ids(tokenizer.tokenizer.tokenize("B")))
tokenizer.tokenizer.tokenize("i ate a good sammy")

In [None]:
# Global tokenizer used by ws_test and create_feature_test below.
tokenizer = Tokenizer4Bert(max_seq_len = 100, pretrained_bert_name = 'bert_large_uncased')

def ws_test(text, bert_tokenizer=None):
    """Word-piece tokenize a list of words and flag word-initial pieces.

    Args:
        text: sequence of words, e.g. ``sentence.split(" ")``.
        bert_tokenizer: object providing ``tokenize(str)`` and
            ``convert_tokens_to_ids(list)``. Defaults to the notebook-level
            ``tokenizer.tokenizer``.

    Returns:
        ``(tokens, token_ids, valid_ids)``, where ``valid_ids[k]`` is 1 when
        ``tokens[k]`` is the first word piece of its word and 0 otherwise.
    """
    if bert_tokenizer is None:
        bert_tokenizer = tokenizer.tokenizer

    tokens = []
    valid_ids = []
    for word in text:
        # Skip empty strings (e.g. from consecutive spaces in split(" ")).
        # The previous check tested len(text), which is never <= 0 inside the loop.
        if not word:
            continue
        pieces = bert_tokenizer.tokenize(word)
        tokens.extend(pieces)
        if pieces:
            valid_ids.append(1)
            valid_ids.extend([0] * (len(pieces) - 1))

    token_ids = bert_tokenizer.convert_tokens_to_ids(tokens)
    return tokens, token_ids, valid_ids

def create_feature_test(sentence, depinfo, print_sent=False):
    """Build padded BERT input features for one aspect-sentiment sample (debug).

    The input layout is ``[CLS] left aspect right [SEP] aspect [SEP]``.

    Args:
        sentence: ``(text_left, text_right, aspect, polarity)`` tuple.
        depinfo: dependency information for the sentence. Currently unused;
            the dependency-matrix code below is commented out.
        print_sent: currently unused; the debug prints always run.

    Returns:
        dict with:
          * ``input_ids``, ``valid_ids``, ``segment_ids``, ``mem_valid_ids``:
            padded/truncated via ``tokenizer.id_to_sequence``;
          * ``token_head_list``: unpadded ids of word-initial pieces between
            ``[CLS]`` and the first ``[SEP]``;
          * ``polarity``: taken unchanged from *sentence*.
    """
    text_left, text_right, aspect, polarity = sentence

    cls_id = tokenizer.tokenizer.vocab["[CLS]"]
    sep_id = tokenizer.tokenizer.vocab["[SEP]"]

    left_tokens, left_token_ids, left_valid_ids = ws_test(text_left.split(" "))
    right_tokens, right_token_ids, right_valid_ids = ws_test(text_right.split(" "))
    aspect_tokens, aspect_token_ids, aspect_valid_ids = ws_test(aspect.split(" "))
    tokens = left_tokens + aspect_tokens + right_tokens
    input_ids = [cls_id] + left_token_ids + aspect_token_ids + right_token_ids + [sep_id] + aspect_token_ids + [sep_id]
    valid_ids = [1] + left_valid_ids + aspect_valid_ids + right_valid_ids + [1] + aspect_valid_ids + [1]
    mem_valid_ids = [0] + [0] * len(left_tokens) + [1] * len(aspect_tokens) + [0] * len(right_tokens)  # aspect terms mask
    segment_ids = [0] * (len(tokens) + 2) + [1] * (len(aspect_tokens) + 1)

    # NOTE: disabled dependency-matrix construction; it references
    # self.dep_order and must be adapted before use in this free function.
#     dep_instance_parser = DepInstanceParser(basicDependencies=depinfo, tokens=[])
#     if self.dep_order == "first":
#         dep_adj_matrix, dep_type_matrix = dep_instance_parser.get_first_order()
#     elif self.dep_order == "second":
#         dep_adj_matrix, dep_type_matrix = dep_instance_parser.get_second_order()
#     elif self.dep_order == "third":
#         dep_adj_matrix, dep_type_matrix = dep_instance_parser.get_third_order()
#     else:
#         raise ValueError()

    # Ids of word-initial pieces in the first segment: skip [CLS], stop at the
    # first [SEP].
    token_head_list = []
    for input_id, valid_id in zip(input_ids, valid_ids):
        if input_id == cls_id:
            continue
        if input_id == sep_id:
            break
        if valid_id == 1:
            token_head_list.append(input_id)

    print(input_ids)
    print(token_head_list)

    # Pad / truncate every sequence to tokenizer.max_seq_len.
    input_ids = tokenizer.id_to_sequence(input_ids)
    valid_ids = tokenizer.id_to_sequence(valid_ids)
    segment_ids = tokenizer.id_to_sequence(segment_ids)
    mem_valid_ids = tokenizer.id_to_sequence(mem_valid_ids)

    return {
        "input_ids": input_ids,
        "valid_ids": valid_ids,
        "segment_ids": segment_ids,
        "mem_valid_ids": mem_valid_ids,
        "token_head_list": token_head_list,
        "polarity": polarity,
    }
    
    

In [None]:
# Special-token ids that frame every BERT input.
bert_vocab = tokenizer.tokenizer.vocab
cls_id = bert_vocab["[CLS]"]
sep_id = bert_vocab["[SEP]"]
print(f'cls_id: {cls_id}\nsep_id: {sep_id}')

# Training samples as (text_left, text_right, aspect, polarity) tuples, plus
# their dependency info.
textdata = Ins.trainset.dataset.textdata
depinfo_all = Ins.trainset.dataset.depinfo

# Rebuild and print the token / id / valid-mask sequences for every sample.
# Loop variables intentionally stay at module level; the next cell reads them.
for sentence, depinfo in zip(textdata, depinfo_all):
    text_left, text_right, aspect, polarity = sentence

    left_tokens, left_token_ids, left_valid_ids = ws_test(text_left.split(" "))
    right_tokens, right_token_ids, right_valid_ids = ws_test(text_right.split(" "))
    aspect_tokens, aspect_token_ids, aspect_valid_ids = ws_test(aspect.split(" "))

    tokens = [*left_tokens, *aspect_tokens, *right_tokens]
    # Layout: [CLS] left aspect right [SEP] aspect [SEP]
    input_ids = [cls_id, *left_token_ids, *aspect_token_ids, *right_token_ids,
                 sep_id, *aspect_token_ids, sep_id]
    valid_ids = [1, *left_valid_ids, *aspect_valid_ids, *right_valid_ids,
                 1, *aspect_valid_ids, 1]
    # Aspect-terms mask: 1 only over the first occurrence of the aspect.
    mem_valid_ids = ([0] * (1 + len(left_tokens))
                     + [1] * len(aspect_tokens)
                     + [0] * len(right_tokens))
    segment_ids = [0] * (len(tokens) + 2) + [1] * (len(aspect_tokens) + 1)

    print(tokens)
    print(input_ids)
    print(valid_ids)

In [None]:
# Inspect one training sample in detail: rebuild its BERT features and keep
# only the ids whose valid flag is 1 (word-initial pieces and special tokens).
i = 1

sentence = list(textdata)[i]
depinfo = list(depinfo_all)[i]
print(sentence)
print(depinfo)

text_left, text_right, aspect, polarity = sentence
left_tokens, left_token_ids, left_valid_ids = ws_test(text_left.split(" "))
right_tokens, right_token_ids, right_valid_ids = ws_test(text_right.split(" "))
aspect_tokens, aspect_token_ids, aspect_valid_ids = ws_test(aspect.split(" "))
tokens = left_tokens + aspect_tokens + right_tokens
input_ids = [cls_id] + left_token_ids + aspect_token_ids + right_token_ids + [sep_id] + aspect_token_ids + [sep_id]
valid_ids = [1] + left_valid_ids + aspect_valid_ids + right_valid_ids + [1] + aspect_valid_ids + [1]
mem_valid_ids = [0] + [0] * len(left_tokens) + [1] * len(aspect_tokens) + [0] * len(right_tokens)  # aspect terms mask
segment_ids = [0] * (len(tokens) + 2) + [1] * (len(aspect_tokens) + 1)

# input_ids and valid_ids are plain lists, so `input_ids[valid_ids == 1]`
# evaluated input_ids[False] == input_ids[0] (just [CLS]). Filter element-wise.
temp = [tok_id for tok_id, valid in zip(input_ids, valid_ids) if valid == 1]
# Printed after recomputation; previously this showed the prior cell's stale value.
print(valid_ids)
print(input_ids)
print(temp)