In [None]:
!pip install torch-scatter

In [None]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from transformers import BertTokenizerFast, BertModel
import torch_scatter
import inspect
from transformers import get_cosine_schedule_with_warmup
from torch.cuda import amp
import ast

In [None]:
def get_param(shape):
    """
    Build a learnable parameter of size ``shape`` with Xavier-normal init.

    Args:
        shape: sequence of dimension sizes.

    Returns:
        nn.Parameter: freshly initialised parameter (requires_grad=True).
    """
    weight = torch.empty(*shape)
    nn.init.xavier_normal_(weight)
    return nn.Parameter(weight)
def com_mult(a, b):
    """
    Multiply two complex tensors stored as real (real, imag) pairs.

    Both inputs keep the real part in ``[..., 0]`` and the imaginary part in
    ``[..., 1]`` (the layout of the legacy ``torch.rfft`` output); the result
    uses the same layout. ``...`` indexing keeps the complex axis last for
    inputs of any rank -- ``[:, k]`` would wrongly index axis 1 instead.
    """
    r1, i1 = a[..., 0], a[..., 1]
    r2, i2 = b[..., 0], b[..., 1]
    return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1)

def conj(a):
    """
    Return the complex conjugate of ``a`` without modifying ``a``.

    ``a`` stores complex numbers as real pairs on its last axis
    (``[..., 0]`` real, ``[..., 1]`` imaginary); the result has the same
    shape with the imaginary part negated.
    """
    out = a.clone()  # do not mutate the caller's tensor
    out[..., 1] = -out[..., 1]
    return out
def ccorr(a, b):
    """
    Circular correlation of ``a`` and ``b`` along the last dimension.

    ``out[..., k] = sum_i a[..., i] * b[..., (i + k) % n]`` with
    ``n = a.shape[-1]``, computed in the frequency domain as
    ``irfft(conj(rfft(a)) * rfft(b))``. Uses the ``torch.fft`` module because
    ``torch.rfft`` / ``torch.irfft`` were removed in PyTorch 1.8.
    """
    n = a.shape[-1]
    spectrum = torch.conj(torch.fft.rfft(a, dim=-1)) * torch.fft.rfft(b, dim=-1)
    return torch.fft.irfft(spectrum, n=n, dim=-1)

class CompGcnBasis(nn.Module):
    """
    First CompGCN layer, with basis-decomposed relation embeddings.

    The ``2 * num_relations`` relation embeddings (presumably forward relations
    followed by their inverses -- TODO confirm against preprocessing) are built
    as ``rel_weight @ basis_vector``; one learnable self-loop relation is then
    appended. Messages are composed with the "sub" operator
    (tail feature - relation embedding) and aggregated separately for the
    "in" edges, "out" edges and self-loops.
    """
    nodes_dim = 0
    head_dim = 0
    tail_dim = 1

    def __init__(self, in_channels, out_channels, num_relations, num_basis_vector,act = torch.tanh,cache = True,dropout = 0.2):
        """
        Args:
            in_channels: size of the input node features (and relation embeddings).
            out_channels: size of the output node / relation embeddings.
            num_relations: number of relation types before adding inverses.
            num_basis_vector: number of basis vectors in the decomposition.
            act: activation applied to the aggregated node embeddings.
            cache: if True, the edge split, self-loop indices and normalisation
                computed on the first forward are reused (and recomputed only if
                the graph's device changes).
            dropout: dropout rate applied to the "in" and "out" messages.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_basis_vector = num_basis_vector
        self.act = act
        self.device = None
        self.cache = cache

        # Learnable basis vectors, shape (num_basis_vector, in_channels).
        self.basis_vector = get_param((num_basis_vector, in_channels))
        # Basis coefficients of every relation (including inverses), shape (2 * num_relations, num_basis_vector).
        self.rel_weight = get_param((num_relations*2, self.num_basis_vector))
        # Projection applied to the relation embeddings to produce this layer's relation output.
        self.weight_rel = get_param((in_channels,out_channels))
        # Learnable embedding of the self-loop relation.
        self.loop_rel = get_param((1,in_channels))
        # Message weights for the three edge kinds: in, out and self-loop.
        self.w_in = get_param((in_channels,out_channels))
        self.w_out = get_param((in_channels,out_channels))
        self.w_loop = get_param((in_channels,out_channels))

        # Graph-dependent tensors, filled in by forward().
        self.in_norm, self.out_norm = None, None
        self.in_index, self.out_index = None, None
        self.in_type, self.out_type = None, None
        self.loop_index, self.loop_type =None, None

        self.drop = nn.Dropout(dropout)
        # NOTE(review): defined but not used by forward().
        self.bn = nn.BatchNorm1d(out_channels)

    def relation_transform(self, entity_embedding, relation_embedding,type_):
        '''
        Compose entity and relation embeddings with a non-parameterised operator.

        type_: "mul" (element-wise product), "sub" (subtraction) or
        "corr" (circular correlation via ccorr).
        '''
        assert type_ in ["mul","sub","corr"], "not implemented now"
        if type_ == "mul":
            out = entity_embedding*relation_embedding
        elif type_ == "sub":
            out = entity_embedding - relation_embedding
        else:
            out = ccorr(entity_embedding,relation_embedding)
        return out

    def normalization(self, edge_index, num_entity):
        '''
        GCN-style symmetric edge weights ``deg^-1/2[head] * deg^-1/2[tail]``.

        ``deg`` counts the edges of each node as a head (row 0 of edge_index).
        Nodes with zero degree get factor 0 instead of inf.
        '''
        head, tail = edge_index
        edge_weight = torch.ones_like(head).float()
        degree = torch_scatter.scatter_add(edge_weight,head,dim_size=num_entity,dim = self.nodes_dim)
        degree_inv = degree.pow(-0.5)
        # 0 ** -0.5 is inf; zero it so it cannot turn into nan downstream.
        degree_inv[degree_inv == float("inf")] = 0
        norm = degree_inv[head] * edge_weight * degree_inv[tail]
        return norm

    def scatter_function(self,type_, src, index, dim_size = None):
        '''
        Reduce rows of ``src`` into ``dim_size`` buckets given by ``index``.

        type_ is the reduction: "sum", "mean" or "max".
        '''
        assert type_.lower() in ["sum","mean","max"]
        return torch_scatter.scatter(src, index, dim=0,out=None,dim_size = dim_size, reduce= type_)

    def propogating_message(self, method, node_features,edge_index,edge_type, rel_embedding, edge_norm,mode,type_):
        '''
        Compute and aggregate the messages of one edge kind.

        Args:
            method: reduction used to aggregate messages per head node.
            node_features: node feature matrix.
            edge_index: (2, E) tensor of (head, tail) node ids.
            edge_type: relation id of every edge.
            rel_embedding: relation embedding table indexed by edge_type.
            edge_norm: per-edge weights, or None for no weighting.
            mode: "in", "out" or "loop"; selects w_in / w_out / w_loop.
            type_: composition operator passed to relation_transform.
        '''
        assert method in ["sum", "mean", "max"]
        assert mode in ["in","out","loop"]
        size = node_features.shape[0]
        coresponding_weight = getattr(self, 'w_{}'.format(mode))
        # Relation embedding of every edge.
        relation_embedding = torch.index_select(rel_embedding,dim = 0, index = edge_type)
        # Messages are built from the tail node and delivered to the head node.
        node_features = node_features[edge_index[1]]
        out = self.relation_transform(node_features, relation_embedding,type_)
        out = torch.matmul(out,coresponding_weight)
        out = out if edge_norm is None else out * edge_norm.view(-1, 1)
        out = self.scatter_function(method,out,edge_index[0],  size)
        return out

    def forward(self, nodes_features, edge_index,edge_type):
        '''
        Run one message-passing step.

        Args:
            nodes_features: node features, one row per node.
            edge_index: (2, E) tensor of (head, tail) ids; the first E // 2
                columns are treated as "in" edges and the rest as "out" edges.
            edge_type: relation id of every column of edge_index.

        Returns:
            (act(node_embeddings), relation_embeddings @ weight_rel). The
            relation output keeps the appended self-loop row, so it has
            2 * num_relations + 1 rows (CompGcn_non_first_layer drops that row).
        '''
        with amp.autocast():
            # Always follow the current graph's device so a model moved with
            # .to() after its first call does not use stale device tensors.
            device = edge_index.device
            self.device = device
            # Basis decomposition: (2R, B) @ (B, in) -> (2R, in).
            relation_embedding = torch.mm(self.rel_weight,self.basis_vector)
            # Append the self-loop relation as the last row.
            relation_embedding = torch.cat([relation_embedding,self.loop_rel],dim = 0)
            num_edges = edge_index.shape[1]//2
            num_nodes = nodes_features.shape[self.nodes_dim]
            stale = self.in_norm is None or self.in_norm.device != device
            if not self.cache or stale:
                # First half of the columns are "in" edges, second half "out" edges.
                self.in_index, self.out_index = edge_index[:,:num_edges], edge_index[:,num_edges:]
                self.in_type, self.out_type = edge_type[:num_edges], edge_type[num_edges:]
                # Self-loop edges (i, i), all using the last (self-loop) relation row.
                self.loop_index = torch.stack([torch.arange(num_nodes), torch.arange(num_nodes)]).to(device)
                self.loop_type = torch.full((num_nodes,), relation_embedding.shape[0]-1, dtype = torch.long).to(device)
                self.in_norm = self.normalization(self.in_index, num_nodes)
                self.out_norm = self.normalization(self.out_index, num_nodes)
            in_res = self.propogating_message('sum',nodes_features,self.in_index,self.in_type, relation_embedding,self.in_norm,"in","sub")
            loop_res = self.propogating_message('sum',nodes_features,self.loop_index,self.loop_type, relation_embedding,None,"loop","sub")
            out_res = self.propogating_message('sum',nodes_features,self.out_index,self.out_type, relation_embedding,self.out_norm,"out","sub")
            # Average the three message kinds (following the reference CompGCN
            # code); dropout is applied to the in/out messages only.
            out = self.drop(in_res)*(1/3) + self.drop(out_res)*(1/3) + loop_res*(1/3)
            # Project the relation embeddings for the next layer.
            out_2 = torch.matmul(relation_embedding,self.weight_rel)
            return self.act(out),out_2

In [None]:
class CompGcn_non_first_layer(nn.Module):
    """
    CompGCN layer that receives its relation embeddings from the caller.

    There is no basis decomposition: the relation embeddings passed to
    ``forward`` get one learnable self-loop row appended. Messages are composed
    with the "sub" operator (tail feature - relation embedding) and aggregated
    separately for "in" edges, "out" edges and self-loops. Graph indices and
    normalisation are recomputed on every call.
    """
    nodes_dim = 0
    head_dim = 0
    tail_dim = 1

    def __init__(self, in_channels, out_channels, num_relations,act = torch.tanh,dropout = 0.2):
        """
        Args:
            in_channels: size of the input node features and relation embeddings.
            out_channels: size of the output node / relation embeddings.
            num_relations: number of relation types (stored only).
            act: activation applied to the aggregated node embeddings.
            dropout: dropout rate applied to the "in" and "out" messages.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.act = act
        self.device = None

        # Projection applied to the relation embeddings to produce this layer's relation output.
        self.weight_rel = get_param((in_channels,out_channels))
        # Learnable embedding of the self-loop relation.
        self.loop_rel = get_param((1,in_channels))
        # Message weights for the three edge kinds: in, out and self-loop.
        self.w_in = get_param((in_channels,out_channels))
        self.w_out = get_param((in_channels,out_channels))
        self.w_loop = get_param((in_channels,out_channels))
        self.drop = nn.Dropout(dropout)
        # NOTE(review): defined but not used by forward().
        self.bn = nn.BatchNorm1d(out_channels)

    def relation_transform(self, entity_embedding, relation_embedding,type_):
        '''
        Compose entity and relation embeddings with a non-parameterised operator.

        type_: "mul" (element-wise product), "sub" (subtraction) or
        "corr" (circular correlation via ccorr).
        '''
        assert type_ in ["mul","sub","corr"], "not implemented now"
        if type_ == "mul":
            out = entity_embedding*relation_embedding
        elif type_ == "sub":
            out = entity_embedding - relation_embedding
        else:
            out = ccorr(entity_embedding,relation_embedding)
        return out

    def normalization(self, edge_index, num_entity):
        '''
        GCN-style symmetric edge weights ``deg^-1/2[head] * deg^-1/2[tail]``.

        ``deg`` counts the edges of each node as a head (row 0 of edge_index).
        Nodes with zero degree get factor 0 instead of inf.
        '''
        head, tail = edge_index
        edge_weight = torch.ones_like(head).float()
        degree = torch_scatter.scatter_add(edge_weight,head,dim_size=num_entity,dim = self.nodes_dim)
        degree_inv = degree.pow(-0.5)
        # 0 ** -0.5 is inf; zero it so it cannot turn into nan downstream.
        degree_inv[degree_inv == float("inf")] = 0
        norm = degree_inv[head] * edge_weight * degree_inv[tail]
        return norm

    def scatter_function(self,type_, src, index, dim_size = None):
        '''
        Reduce rows of ``src`` into ``dim_size`` buckets given by ``index``.

        type_ is the reduction: "sum", "mean" or "max".
        '''
        assert type_.lower() in ["sum","mean","max"]
        return torch_scatter.scatter(src, index, dim=0,out=None,dim_size = dim_size, reduce= type_)

    def propogating_message(self, method, node_features,edge_index,edge_type, rel_embedding, edge_norm,mode,type_):
        '''
        Compute and aggregate the messages of one edge kind.

        Args:
            method: reduction used to aggregate messages per head node.
            node_features: node feature matrix.
            edge_index: (2, E) tensor of (head, tail) node ids.
            edge_type: relation id of every edge.
            rel_embedding: relation embedding table indexed by edge_type.
            edge_norm: per-edge weights, or None for no weighting.
            mode: "in", "out" or "loop"; selects w_in / w_out / w_loop.
            type_: composition operator passed to relation_transform.
        '''
        assert method in ["sum", "mean", "max"]
        assert mode in ["in","out","loop"]
        size = node_features.shape[0]
        coresponding_weight = getattr(self, 'w_{}'.format(mode))
        # Relation embedding of every edge.
        relation_embedding = torch.index_select(rel_embedding,dim = 0, index = edge_type)
        # Messages are built from the tail node and delivered to the head node.
        node_features = node_features[edge_index[1]]
        out = self.relation_transform(node_features, relation_embedding,type_)
        out = torch.matmul(out,coresponding_weight)
        out = out if edge_norm is None else out * edge_norm.view(-1, 1)
        out = self.scatter_function(method,out,edge_index[0],  size)
        return out

    def forward(self, nodes_features, edge_index,edge_type,relation_embedding):
        '''
        Run one message-passing step.

        Args:
            nodes_features: node features, one row per node.
            edge_index: (2, E) tensor of (head, tail) ids; the first E // 2
                columns are treated as "in" edges and the rest as "out" edges.
            edge_type: relation id of every column of edge_index.
            relation_embedding: relation embeddings from the previous layer.

        Returns:
            (act(node_embeddings), projected relation embeddings) -- the
            appended self-loop row is dropped from the relation output.
        '''
        with amp.autocast():
            # Always follow the current graph's device so a model moved with
            # .to() after its first call builds its self-loop tensors correctly.
            device = edge_index.device
            self.device = device
            # Append the self-loop relation as the last row.
            relation_embedding = torch.cat([relation_embedding,self.loop_rel],dim = 0)
            num_edges = edge_index.shape[1]//2
            num_nodes = nodes_features.shape[self.nodes_dim]
            # First half of the columns are "in" edges, second half "out" edges.
            self.in_index, self.out_index = edge_index[:,:num_edges], edge_index[:,num_edges:]
            self.in_type, self.out_type = edge_type[:num_edges], edge_type[num_edges:]
            # Self-loop edges (i, i), all using the last (self-loop) relation row.
            self.loop_index = torch.stack([torch.arange(num_nodes), torch.arange(num_nodes)]).to(device)
            self.loop_type = torch.full((num_nodes,), relation_embedding.shape[0]-1, dtype = torch.long).to(device)
            self.in_norm = self.normalization(self.in_index, num_nodes)
            self.out_norm = self.normalization(self.out_index, num_nodes)
            in_res = self.propogating_message('sum',nodes_features,self.in_index,self.in_type, relation_embedding,self.in_norm,"in","sub")
            loop_res = self.propogating_message('sum',nodes_features,self.loop_index,self.loop_type, relation_embedding,None,"loop","sub")
            out_res = self.propogating_message('sum',nodes_features,self.out_index,self.out_type, relation_embedding,self.out_norm,"out","sub")
            # Average the three message kinds (following the reference CompGCN
            # code); dropout is applied to the in/out messages only.
            out = self.drop(in_res)*(1/3) + self.drop(out_res)*(1/3) + loop_res*(1/3)
            # Project the relation embeddings for the next layer.
            out_2 = torch.matmul(relation_embedding,self.weight_rel)
            return self.act(out),out_2[:-1]  # drop the self-loop row appended above

In [None]:
class CompGcn_total(nn.Module):
    """Two stacked CompGCN layers operating on one fixed graph."""

    def __init__(self, in_channel, out_channel,num_relation, num_basis_vector, edge_idx, edge_type, basis = False):
        '''
        Notice that in preprocessing, we assume that the node number will not be changed on the graph, only change relation.
        input params:
            1. in_channel, out_channel: feature sizes of the first layer; the
               second layer maps out_channel -> out_channel.
            2. num_relation, the number of relation types of the graph (after
               preprocessing, should be the same for each graph).
            3. num_basis_vector, number of basis vectors; only used when basis is True.
            4. edge_idx, (2, E) edge-index tensor of this graph.
            5. edge_type, relation id of every edge in edge_idx.
            6. basis, if True the first layer is a CompGcnBasis that builds the
               relation embeddings itself; otherwise relation embeddings must be
               passed to forward().
        '''
        super().__init__()
        self.edge_idx = edge_idx
        self.edge_type = edge_type
        self.basis = basis
        if basis:
            self.conv1 = CompGcnBasis(in_channels = in_channel, out_channels= out_channel, num_relations=num_relation, num_basis_vector= num_basis_vector)
        else:
            self.conv1 = CompGcn_non_first_layer(in_channel, out_channel, num_relation)
        self.conv2 = CompGcn_non_first_layer(out_channel, out_channel, num_relation)

    def forward(self, init_features = None, node_embeding = None,rel_embeding = None,device = None):
        '''
        Args:
            init_features: node features (used when basis is True).
            node_embeding, rel_embeding: node and relation embeddings from the
                previous block (used when basis is False).
            device: device to move the graph tensors to. When None, the device
                of the supplied node features is used (or the graph's own
                device if no features are given).

        Returns:
            (node_embeddings, relation_embeddings) produced by the second layer.
        '''
        with amp.autocast():
            features = init_features if self.basis else node_embeding
            if device is None:
                device = features.device if features is not None else self.edge_idx.device
            # Move the graph once per call and share it between both layers.
            edge_idx = self.edge_idx.to(device)
            edge_type = self.edge_type.to(device)
            if self.basis:
                node_embd, rel_embd = self.conv1(init_features, edge_idx, edge_type)
            else:
                node_embd, rel_embd = self.conv1(node_embeding, edge_idx, edge_type, rel_embeding)
            node_embd, rel_embd = self.conv2(node_embd, edge_idx, edge_type, rel_embd)
            return node_embd, rel_embd

In [None]:
class Score_func(nn.Module):
    """
        Func:
            Knowledge-graph score functions: TransE, TransH, DistMult and ConvE.

        Args:
            sub_emb: head (subject) input. Embeddings of shape (batch, dim) for
                "transE", "distMult" and "convE"; for "transH" an index tensor
                into the internally created entity table.
            rel_emb: relation input (embeddings, or indices for "transH").
            obj_emb: tail (object) input. "transE" pairs it row-wise with
                sub_emb; "distMult"/"convE" score every row against all rows of
                obj_emb (shape (num_candidates, dim)); indices for "transH".
            func_type: "transE" (default), "transH", "distMult" or "convE"
                (case-insensitive).
            kernel_size: ConvE only, a tuple (k_w, k_h) with k_w * k_h == dim.
                Head and relation embeddings are interleaved into a
                (2 * k_w, k_h) image.
            conv_drop: ConvE only, dropout rates (input, hidden, feature). One
                value is reused for all three; two values (a, b) become
                (a, a, b). If empty/None, ``dropout`` is used for all three.
                Defaults are the tuned parameters of the CompGCN paper.
            conv_bias: ConvE only, whether the convolution has a bias. Default True.
            gamma: margin hyperparameter (stored; forward_score does not use it).
                Default 40.0, the tuned best parameter in CompGCN.
            dropout: fallback dropout rate when ``conv_drop`` is empty.
            num_filt: ConvE only, number of convolution filters.
            ker_sz: ConvE only, side length of the square convolution kernel.
            num_entities, num_relations, embed_dim: TransH only, sizes of the
                internally created embedding tables.
    """

    _SUPPORTED = ("transe", "transh", "distmult", "conve")

    def __init__(self, sub_emb, rel_emb, obj_emb, func_type="transE",
                 kernel_size = None, conv_drop=(0.2, 0.3, 0.2),
                 conv_bias=True, gamma=40.0, dropout=0.2, num_filt=200,
                 ker_sz=7, num_entities=None, num_relations=None,
                 embed_dim=None):
        # we can't use self.__class__, because it may cause a recursive problem
        super(Score_func, self).__init__()

        self.func_type = func_type.lower()
        if self.func_type not in self._SUPPORTED:
            raise ValueError(
                "Unsupported func_type {!r}; expected one of transE, transH, "
                "distMult, convE".format(func_type))
        self.gamma     = gamma
        self.sub_emb   = sub_emb
        self.rel_emb   = rel_emb
        self.obj_emb   = obj_emb

        if self.func_type == "transh":
            if num_entities is None or num_relations is None or embed_dim is None:
                raise ValueError("transH needs num_entities, num_relations and embed_dim")
            self.dimension = embed_dim
            self.relation_norm_embedding  = nn.Embedding(num_embeddings=num_relations,
                                                         embedding_dim=self.dimension)
            self.relation_hyper_embedding = nn.Embedding(num_embeddings=num_relations,
                                                         embedding_dim=self.dimension)
            self.entity_embedding         = nn.Embedding(num_embeddings=num_entities,
                                                         embedding_dim=self.dimension)

        if self.func_type == "conve":
            assert kernel_size is not None  # to ensure that the kernel size is defined

            if not conv_drop:
                self.conv_drop = [dropout, dropout, dropout]
            else:
                conv_drop = list(conv_drop)
                l = len(conv_drop)
                assert l <= 3  # ensure the length of conv_drop is at most 3
                if l == 1:
                    self.conv_drop = [conv_drop[0], conv_drop[0], conv_drop[0]]
                elif l == 2:
                    self.conv_drop = [conv_drop[0], conv_drop[0], conv_drop[1]]
                else:
                    self.conv_drop = conv_drop

            self.kernel_size    = tuple(kernel_size)
            k_w, k_h            = self.kernel_size
            self.embed_dim      = k_w * k_h
            self.out_channels   = num_filt
            self.ker_sz         = ker_sz
            self.bias           = conv_bias

            self.bn0            = nn.BatchNorm2d(1)
            self.bn1            = nn.BatchNorm2d(self.out_channels)
            self.bn2            = nn.BatchNorm1d(self.embed_dim)

            # NOTE(review): hidden_drop is created but not applied in forward_score.
            self.hidden_drop    = nn.Dropout(self.conv_drop[0])
            self.hidden_drop2   = nn.Dropout(self.conv_drop[1])
            self.feature_drop   = nn.Dropout(self.conv_drop[2])
            self.m_conv1        = nn.Conv2d(1, out_channels=self.out_channels,
                                            kernel_size=(ker_sz, ker_sz),
                                            stride=1, padding=0, bias=conv_bias)

            # Output size of a valid (no padding) convolution over the (2*k_w, k_h) image.
            flat_sz_h           = 2 * k_w - ker_sz + 1
            flat_sz_w           = k_h - ker_sz + 1
            self.flat_sz        = flat_sz_h * flat_sz_w * self.out_channels
            self.fc             = nn.Linear(self.flat_sz, self.embed_dim)

    def concat(self, e1_embed, rel_embed):
        """Interleave head and relation embeddings into a (batch, 1, 2*k_w, k_h) image."""
        k_w, k_h    = self.kernel_size
        e1_embed    = e1_embed.view(-1, 1, self.embed_dim)
        rel_embed   = rel_embed.view(-1, 1, self.embed_dim)
        stack_inp   = torch.cat([e1_embed, rel_embed], 1)
        stack_inp   = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2 * k_w, k_h))
        return stack_inp

    def projected(self, ent, norm):
        """Project ``ent`` onto the hyperplane whose normal is ``norm`` (last axis)."""
        norm = torch.nn.functional.normalize(norm, p=2, dim=-1)
        return ent - torch.sum(ent * norm, dim=-1, keepdim=True) * norm

    def forward_score(self):
        """
        Score the stored (sub, rel, obj) inputs with the selected function.

        Returns:
            transE: sub + rel - obj (same shape as the inputs).
            transH: L2 distance on the relation hyperplanes, one value per triple.
            distMult / convE: (batch, num_candidates) scores against every obj row.
        """
        if self.func_type == "transe":
            return self.sub_emb + self.rel_emb - self.obj_emb

        if self.func_type == "transh":
            head       = self.entity_embedding(self.sub_emb)
            tail       = self.entity_embedding(self.obj_emb)
            r_norm     = self.relation_norm_embedding(self.rel_emb)
            r_hyper    = self.relation_hyper_embedding(self.rel_emb)
            head_hyper = self.projected(head, r_norm)
            tail_hyper = self.projected(tail, r_norm)
            return torch.norm(head_hyper + r_hyper - tail_hyper, p=2, dim=-1)

        if self.func_type == "distmult":
            # DistMult: <h, r, t> = (h * r) . t, evaluated against every candidate t.
            return torch.mm(self.sub_emb * self.rel_emb, self.obj_emb.transpose(1, 0))

        # convE
        x = self.bn0(self.concat(self.sub_emb, self.rel_emb))
        x = self.m_conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.feature_drop(x)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        return torch.mm(x, self.obj_emb.transpose(1, 0))

In [None]:
class CompGcn_with_temporal(nn.Module):
    """Three CompGCN blocks over three graph snapshots, a score function and a linear head."""

    def __init__(self, conv_dim, num_layer ,num_relation, num_entity,node_dim, num_basis_vector, edge_idx, edge_type, num_class,score_func="TransE"):
        '''
        Notice that in preprocessing, we assume that the node number will not be changed on the graph, only change relation.
        input params:
            1. conv_dim, (in_channel, out_channel)
            2. num_layer, currently unused; every CompGcn_total block has two layers
            3. num_relation, the number of relation types for each graph (after preprocessing, should be the same for each graph)
            4. num_entity, number of entities, should be the same for each graph
            5. node_dim, dimension of the learnable initial node features; must equal conv_dim[0]
            6. num_basis_vector, number of basis vectors of the first graph's first layer
            7. edge_idx, a list of edge_idx tensors; entries 0, 1, 2 are used
            8. edge_type, a list of edge_type tensors aligned with edge_idx
            9. num_class, output size of the final linear layer
            10. score_func, func_type passed to Score_func
        '''
        super().__init__()
        # Validate before allocating any parameters.
        assert node_dim == conv_dim[0], "node_dim must equal conv_dim[0]"
        self.node_features = get_param(shape= (num_entity,node_dim))
        self.conv1 = CompGcn_total(conv_dim[0], conv_dim[1], num_relation, num_basis_vector, edge_idx[0],edge_type[0],True)
        self.conv2 = CompGcn_total(conv_dim[1], conv_dim[1], num_relation, num_basis_vector, edge_idx[1],edge_type[1],False)
        self.conv3 = CompGcn_total(conv_dim[1], conv_dim[1], num_relation, num_basis_vector, edge_idx[2],edge_type[2],False)
        # Linear head applied to the score-function output.
        self.ln1 = nn.Linear(conv_dim[1], num_class)
        self.dropout_node = nn.Dropout(0.2)
        self.dropout_rel = nn.Dropout(0.2)
        self.ln1 = self.func_init(self.ln1)
        self.score = score_func

    def func_init(self,m):
        """Xavier-uniform initialise the weight of ``m`` if it is a Linear layer; return ``m``."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
        return m

    def forward(self, head_index, tail_index, rel_index):
        '''
        Score a batch of (head, relation, tail) index triples.

        Since the embedding is tail - relation to head, the source will be tail, target will be head:
        the embedding at ``tail_index`` is passed to Score_func as the subject
        and the one at ``head_index`` as the object.

        NOTE(review): Score_func is constructed on every call, so any layers it
        creates (e.g. for "convE") are freshly initialised and never trained;
        "TransE" creates none.
        '''
        with amp.autocast():
            device = self.node_features.device
            node_embd1, rel_embd1 = self.conv1(self.node_features,device = device )
            node_embd2, rel_embd2 = self.conv2(node_embeding = node_embd1, rel_embeding = rel_embd1,device = device)
            node_embd2, rel_embd2 = self.dropout_node(node_embd2), self.dropout_rel(rel_embd2)
            node_embd3, rel_embd3 = self.conv3(node_embeding = node_embd2, rel_embeding = rel_embd2,device = device)
            # Gather the batch rows; each has shape (len(index), hidden_out).
            hidden_node_state = node_embd3[tail_index,:]
            hidden_rel_state  = rel_embd3[rel_index,:]
            hidden_target_state = node_embd3[head_index,:]
            score = Score_func(hidden_node_state, hidden_rel_state, hidden_target_state, func_type=self.score)
            score = score.forward_score()
            score = self.ln1(score)
        return score

In [None]:
class language_Dataset(torch.utils.data.Dataset):
    """Map-style dataset with one sample per DataFrame row.

    Each sample is ``(label_tensor, index_where)`` taken from the ``labels``
    and ``index_where`` columns.
    """

    def __init__(self, df):
        '''
        df: DataFrame holding at least the "labels" and "index_where" columns.
        '''
        self.df = df

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        '''
        Return (labels as a tensor, raw index_where value) of row ``idx``.
        '''
        row = self.df.iloc[idx]
        label = torch.tensor(row["labels"])
        return label, row["index_where"]

In [None]:
def try_gpu(i=0):
    """Pick ``cuda:i`` when at least ``i + 1`` GPUs are visible; fall back to the CPU."""
    if torch.cuda.device_count() < i + 1:
        return torch.device('cpu')
    return torch.device(f'cuda:{i}')

# In[ ]:
def train_with_amp(net, train_set, criterion, optimizer, epochs,batch_size, scheduler, gradient_accumulate_step, max_grad_norm , num_gpu):
    """
    Train ``net`` with automatic mixed precision and gradient accumulation.

    Args:
        net: model called as ``net(head, tail, rel)`` on batched index tensors.
        train_set: dataset yielding ``(label, index)``; after batching,
            ``index[0]``, ``index[1]`` and ``index[2]`` are the head, relation
            and tail indices.
        criterion: loss called as ``criterion(output, labels.view(-1, 1).float())``.
        optimizer: optimizer over ``net``'s parameters.
        epochs: number of passes over ``train_set``.
        batch_size: DataLoader batch size (no shuffling).
        scheduler: LR scheduler, stepped once per optimizer update.
        gradient_accumulate_step: batches accumulated per optimizer update
            (values below 1 are treated as 1).
        max_grad_norm: clipping threshold for the (unscaled) gradient norm,
            applied right before each update.
        num_gpu: number of devices probed with ``try_gpu``; only the first is
            used, and at least one device (possibly the CPU) is always probed.

    Every 1000 batches and on each epoch's last batch the loss is printed,
    ``./checkpoint.params`` is overwritten and the loss value is recorded.

    Returns:
        list of float: the recorded losses (also plotted when matplotlib is installed).
    """
    net.train()
    ls = []
    device = [try_gpu(i) for i in range(max(1, num_gpu))]
    print("train on " + str(device))
    main_device = device[0]
    enable_amp = "cuda" in main_device.type
    # The scaler multiplies the loss to avoid fp16 gradient underflow; it is a no-op when disabled.
    scaler = amp.GradScaler(enabled= enable_amp)
    net.to(main_device)
    accumulate = max(1, gradient_accumulate_step)
    gradient_norm = None
    train_iter = torch.utils.data.DataLoader(train_set, batch_size = batch_size)
    num_batches = len(train_iter)

    def _update():
        """Clip the unscaled gradients, step optimizer and scheduler, and reset gradients."""
        # Gradients still carry the loss scale here; unscale them first so the
        # clip threshold applies to the real gradient norm.
        scaler.unscale_(optimizer)
        norm = nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
        # scaler.step skips optimizer.step() when the gradients contain inf/NaN.
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()
        return norm

    for epoch in range(epochs):
        pending = False  # gradients accumulated but not yet applied
        for idx, (labels, index) in enumerate(train_iter):
            labels = labels.to(main_device)
            # as_tensor reuses the collated tensors instead of copying them.
            head_values = torch.as_tensor(index[0]).to(main_device)
            tail_values = torch.as_tensor(index[2]).to(main_device)
            rel_values = torch.as_tensor(index[1]).to(main_device)
            with amp.autocast(enabled= enable_amp):
                output = net(head_values,tail_values, rel_values)
            loss = criterion(output, labels.view(-1,1).float())
            if accumulate > 1:
                # Average over the accumulated batches (saves memory vs. a larger batch).
                loss = loss/accumulate
            scaler.scale(loss).backward()
            pending = True
            if (idx + 1) % accumulate == 0:
                gradient_norm = _update()
                pending = False
            # Log and checkpoint every 1000 batches and on the epoch's last batch.
            if idx % 1000 == 0 or idx == num_batches - 1:
                print("==============Epochs "+ str(epoch) + " ======================")
                print("loss: " + str(loss) + "; grad_norm: " + str(gradient_norm))
                torch.save({'epoch': epoch,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.detach()},"./checkpoint.params")
                ls.append(loss.item())
        if pending:
            # Apply the gradients left over when the epoch length is not a
            # multiple of the accumulation step, so they do not leak into the next epoch.
            gradient_norm = _update()

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not available; skipping loss plot")
    else:
        plt.figure()
        plt.plot(ls)
    return ls
    
            

In [None]:
if __name__ == "__main__":
    import math

    edge_idx_1 = torch.tensor(np.load("../input/3-graphs-info/edge_idx0.npz")["arr_0"])
    edge_idx_2 = torch.tensor(np.load("../input/3-graphs-info/edge_idx1.npz")["arr_0"])
    edge_idx_3 = torch.tensor(np.load("../input/3-graphs-info/edge_idx2.npz")["arr_0"])
    num_nodes_0 = torch.tensor(np.load("../input/3-graphs-info/graph_0_num_nodes.npz")["arr_0"])
    num_nodes_1 = torch.tensor(np.load("../input/3-graphs-info/graph_1_num_nodes.npz")["arr_0"])
    num_nodes_2 = torch.tensor(np.load("../input/3-graphs-info/graph_2_num_nodes.npz")["arr_0"])
    num_relation_0 = torch.tensor(np.load("../input/3-graphs-info/graph_0_num_edges.npz")["arr_0"])
    num_relation_1 = torch.tensor(np.load("../input/3-graphs-info/graph_1_num_edges.npz")["arr_0"])
    num_relation_2 = torch.tensor(np.load("../input/3-graphs-info/graph_2_num_edges.npz")["arr_0"])
    edge_type_0 = torch.tensor(np.load("../input/3-graphs-info/edge_type0.npz")["arr_0"])
    edge_type_1 = torch.tensor(np.load("../input/3-graphs-info/edge_type1.npz")["arr_0"])
    edge_type_2 = torch.tensor(np.load("../input/3-graphs-info/edge_type2.npz")["arr_0"])
    head2idx = np.load('../input/3-graphs-info/graph_2entity2index.npy', allow_pickle=True).item()
    rel2idx =  np.load('../input/3-graphs-info/graph_2rel2index.npy', allow_pickle=True).item()
    conv_dim, num_layer, node_dim, num_basis, edge_idx, edge_type = [64, 128], 2, 64, 37, [edge_idx_1,edge_idx_2,edge_idx_3],[edge_type_0,edge_type_1,edge_type_2]
    tmp = CompGcn_with_temporal(conv_dim,num_layer, num_relation_2, num_nodes_2+1, node_dim, num_basis,edge_idx, edge_type, 1)
    train = pd.read_csv("../input/train-valid-test-dataset/train.csv").drop("Unnamed: 0", axis = 1)
    train["index_where"] = train["index_where"].apply(ast.literal_eval)
    train_set = language_Dataset(train)
    loss = nn.BCEWithLogitsLoss()
    batch_size = 2
    lr = 2e-6
    num_gpu = 1
    epochs = 3
    gradient_accumulate_step = 1
    max_grad_norm = 1000
    optimizer = torch.optim.AdamW(tmp.parameters(), lr = lr)
    # train_with_amp steps the scheduler once per optimizer update, so the cosine
    # schedule must span all epochs. Sizing it for a single epoch makes the
    # half-cosine reach its minimum after epoch 1 and then rise again.
    batches_per_epoch = math.ceil(len(train_set) / batch_size)  # == len(DataLoader) without drop_last
    updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulate_step)
    scheduler = get_cosine_schedule_with_warmup(optimizer= optimizer, num_warmup_steps = 0, num_training_steps= updates_per_epoch * epochs, num_cycles = 0.5)
    train_with_amp(tmp, train_set, loss,optimizer,epochs,batch_size, scheduler,gradient_accumulate_step,max_grad_norm, num_gpu)