# 基于MindSpore的RingFormer实现

## HeteroTransformer

In [1]:
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common.initializer import initializer, Normal, XavierUniform
from mindspore.nn import Cell, SequentialCell, ReLU
from mindspore import context
from copy import deepcopy
import math
import numpy as np
from ogb.graphproppred.mol_encoder import AtomEncoder
from mindspore_gl import BatchedGraph, BatchedGraphField
from torch_geometric.nn import GINConv, global_add_pool, GCNConv, global_mean_pool, dense_diff_pool, DenseGINConv, GPSConv
from torch_geometric.nn.models import MLP, AttentiveFP
from torch_geometric.utils import remove_self_loops

class SparseEdgeConv(Cell):
    def __init__(self, in_channels, out_channels, heads=1, concat=True, beta=False, 
                 dropout=0., bias=True, root_weight=True, combine='add', clip_attn=False, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.combine = combine
        self.clip_attn = clip_attn
        
        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_key = nn.Dense(in_channels[0], heads * out_channels)
        self.lin_query = nn.Dense(in_channels[1], heads * out_channels)
        self.lin_value = nn.Dense(in_channels[0], heads * out_channels)

        if self.combine.startswith('cat'):
            if self.combine[-1] == '1':
                self.lin_combine = nn.SequentialCell([nn.Dense(in_channels[0]*2, in_channels[0])])
            elif self.combine[-1] == '2':
                self.lin_combine = nn.SequentialCell([
                    nn.Dense(in_channels[0]*2, in_channels[0]), 
                    nn.Dropout(dropout)
                ])
            else:
                self.lin_combine = nn.SequentialCell([
                    nn.Dense(in_channels[0]*2, in_channels[0]), 
                    nn.ReLU()
                ])
        elif self.combine.startswith('add_lin'):
            self.lin_combine = nn.SequentialCell([
                nn.Dense(in_channels[0], in_channels[0]), 
                nn.ReLU()
            ])
        elif self.combine.startswith('lin_add'):
            self.lin_combine = nn.Dense(in_channels[0], in_channels[0])
        elif self.combine.startswith('dual_lin_add'):
            self.lin_combine0 = nn.Dense(in_channels[0], in_channels[0])
            self.lin_combine1 = nn.Dense(in_channels[0], in_channels[0])
            
        if concat:
            self.lin_skip = nn.Dense(in_channels[1], heads * out_channels, has_bias=bias)
            if self.beta:
                self.lin_beta = nn.Dense(3 * heads * out_channels, 1, has_bias=False)
            else:
                self.lin_beta = None
        else:
            self.lin_skip = nn.Dense(in_channels[1], out_channels, has_bias=bias)
            if self.beta:
                self.lin_beta = nn.Dense(3 * out_channels, 1, has_bias=False)
            else:
                self.lin_beta = None
                
        self.softmax = ops.Softmax(axis=-1)
        self.dropout_op = nn.Dropout(dropout)
        self._alpha = None

    def construct(self, x, edge_index, edge_attr=None, return_attention_weights=None):
        if isinstance(x, Tensor):
            x = (x, x)

        # Propagate
        out = self.propagate(x, edge_index, edge_attr)
        
        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(axis=1)

        if self.root_weight:
            x_r = self.lin_skip(x[1])
            out = out + x_r
        else:
            out = out + x[1]

        if return_attention_weights:
            return out, (edge_index, alpha)
        else:
            return out

    def message(self, query_i, key_j, value_j, edge_attr, index, ptr, size_i):
        assert edge_attr is not None
        
        H, C = self.heads, self.out_channels
        
        if self.combine == 'add':
            key_j = value_j = key_j + edge_attr
        elif self.combine.startswith('cat'):
            key_j = value_j = self.lin_combine(ops.concat([key_j, edge_attr], axis=-1))
        elif self.combine == 'add_lin':
            key_j = value_j = self.lin_combine(key_j + edge_attr)
        elif self.combine == 'lin_add':
            edge_attr = self.lin_combine(edge_attr)
            key_j = value_j = ops.relu(key_j + edge_attr)
        elif self.combine.startswith('dual_lin_add'):
            if self.combine[-1] == '1':
                edge_attr = self.lin_combine0(edge_attr)
                key_j, value_j = self.lin_combine0(key_j), self.lin_combine0(value_j)
                key_j = ops.relu(key_j + edge_attr)
                value_j = ops.relu(value_j + edge_attr)
            elif self.combine[-1] == '2':
                edge_attr = self.lin_combine0(edge_attr)
                key_j, value_j = self.lin_combine0(key_j), self.lin_combine0(value_j)
                key_j = key_j + edge_attr
                value_j = value_j + edge_attr
            elif self.combine[-1] == '3':
                edge_attr = self.lin_combine0(edge_attr)
                key_j, value_j = self.lin_combine0(key_j), self.lin_combine0(value_j)
                key_j = self.dropout_op(key_j + edge_attr)
                value_j = self.dropout_op(value_j + edge_attr)
            elif self.combine[-1] == '4':
                edge_attr = self.lin_combine0(edge_attr)
                key_j = value_j = self.lin_combine1(key_j) + edge_attr
            elif self.combine[-1] == '5':
                edge_attr = self.lin_combine0(edge_attr)
                key_j = value_j = ops.relu(self.lin_combine1(key_j) + edge_attr)
        
        query_i = self.lin_query(query_i).view(-1, H, C)
        key_j = self.lin_key(key_j).view(-1, H, C)
        value_j = self.lin_value(value_j).view(-1, H, C)

        alpha = (query_i * key_j).sum(axis=-1) / math.sqrt(self.out_channels)
        if self.clip_attn:
            alpha = ops.clip_by_value(alpha, -5, 5)
        alpha = self.softmax(alpha)
        self._alpha = alpha
        alpha = self.dropout_op(alpha)

        out = value_j * alpha.view(-1, self.heads, 1)
        return out

class SparseEdgeFullLayer(Cell):
    def __init__(self, in_dim, out_dim, num_heads, dropout=0.0, dim_edge=None, 
                 layer_norm=True, activation='relu', root_weight=True, residual=True, 
                 use_bias=False, combine='add', clip_attn=False, **kwargs):
        super().__init__()
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm

        self.attention = SparseEdgeConv(in_dim, out_dim//num_heads, heads=num_heads, 
                                      root_weight=root_weight, dropout=dropout, concat=True, 
                                      use_bias=use_bias, combine=combine, clip_attn=clip_attn)

        self.O_h = nn.Dense(out_dim, out_dim)

        if self.layer_norm:
            self.layer_norm1_h = nn.LayerNorm((out_dim,))

        # FFN for h
        self.FFN_h_layer1 = nn.Dense(out_dim, out_dim * 2)
        self.activation_fn = self._get_activation(activation)
        self.FFN_h_layer2 = nn.Dense(out_dim * 2 if activation != 'glu' else out_dim, out_dim)

        if self.layer_norm:
            self.layer_norm2_h = nn.LayerNorm((out_dim,))

        self.dropout_op = nn.Dropout(dropout)

    def _get_activation(self, activation):
        if activation == 'relu':
            return ops.ReLU()
        elif activation == 'gelu':
            return ops.GeLU()
        elif activation == 'silu':
            return ops.SiLU()
        elif activation == 'glu':
            return ops.GLU()
        else:
            raise ValueError(f'activation function {activation} is not valid!')

    def construct(self, x, edge_index, edge_attr, **kwargs):
        h = x
        h_in1 = h  # for first residual connection

        # multi-head attention out
        h = self.attention(x, edge_index, edge_attr)
        h = self.dropout_op(h)
        h = self.O_h(h)

        if self.residual:
            h = h_in1 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm1_h(h)

        h_in2 = h  # for second residual connection

        # FFN for h
        h = self.FFN_h_layer1(h)
        h = self.activation_fn(h)
        h = self.dropout_op(h)
        h = self.FFN_h_layer2(h)

        if self.residual:
            h = h_in2 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm2_h(h)

        return h

class Het_Transfomer(Cell):
    def __init__(self, metadata, dim, num_gc_layers, gnn='GINE', inter_gnn='GINE', ring_gnn='GPS', 
                 norm=None, transformer_norm=None, aggr='sum', jk='cat', dropout=0.0, attn_dropout=0.0, 
                 pool='add', first_residual=False, residual=False, heads=4, use_bias=False, padding=True, 
                 init_embs=False, mask_non_edge=False, add_mol=False, combine_mol='add', root_weight=True, 
                 combine_edge='add', clip_attn=False, **kwargs):
        super().__init__()
        self.num_gc_layers = num_gc_layers
        self.convs = nn.CellList()
        self.jk = jk
        self.dropout = dropout
        self.residual = residual
        self.first_residual = first_residual
        self.aggr = aggr
        self.use_edge_attr = True
        self.ring_gnn = ring_gnn
        self.add_mol = add_mol
        self.combine_mol = combine_mol
        
        assert norm is None
        
        if 'mol' in metadata[0]:
            self.use_mol = True
            print('Adding Mol node to heterogenous graph!')
        else:
            self.use_mol = False
        if 'pair' in metadata[0]:
            self.use_pair = True
            print('Adding Pair node to heterogenous graph!')
        else:
            self.use_pair = False
            
        self.pool = global_add_pool
            
        if 'cat' in aggr:
            self.lin_atom = nn.CellList()
            self.lin_ring = nn.CellList()
            if self.use_mol:
                self.lin_mol = nn.CellList()
            if self.use_pair:
                self.lin_pair = nn.CellList()
                
        # Initialize GNN layers based on configuration
        # (Implementation of specific GNN layers would be needed here)
        
        num_atom_messages = 0
        num_ring_messages = 0
        num_pair_messages = 0
        for rel in metadata[1]:
            if rel[-1] == 'atom':
                num_atom_messages += 1
            elif rel[-1] == 'ring':
                num_ring_messages += 1
            elif rel[-1] == 'pair':
                num_pair_messages += 1
                
        for _ in range(num_gc_layers):
            # Initialize conv_dict with appropriate GNN layers
            conv_dict = {}
            conv = HeteroConv(conv_dict, aggr='cat' if 'cat' in aggr else aggr)
            self.convs.append(conv)
            
            if aggr == 'cat':
                self.lin_atom.append(nn.SequentialCell([
                    nn.Dense(num_atom_messages*dim, dim), 
                    nn.ReLU()
                ]))
                self.lin_ring.append(nn.SequentialCell([
                    nn.Dense(num_ring_messages*dim, dim), 
                    nn.ReLU()
                ]))
                if self.use_mol:
                    self.lin_mol.append(nn.SequentialCell([
                        nn.Dense(dim, dim), 
                        nn.ReLU()
                    ]))
                if self.use_pair:
                    self.lin_pair.append(nn.SequentialCell([
                        nn.Dense(dim*num_pair_messages, dim), 
                        nn.ReLU()
                    ]))
            elif aggr == 'cat_self':
                self.lin_atom.append(nn.SequentialCell([
                    nn.Dense((num_atom_messages+1)*dim, dim), 
                    nn.ReLU()
                ]))
                self.lin_ring.append(nn.SequentialCell([
                    nn.Dense((num_ring_messages+1)*dim, dim), 
                    nn.ReLU()
                ]))
                if self.use_mol:
                    self.lin_mol.append(nn.SequentialCell([
                        nn.Dense(2*dim, dim), 
                        nn.ReLU()
                    ]))
                if self.use_pair:
                    self.lin_pair.append(nn.SequentialCell([
                        nn.Dense(2*dim, dim), 
                        nn.ReLU()
                    ]))
    
    def construct(self, x_dict, edge_index_dict, batch_dict, edge_attr_dict=None, edge_type_dict=None, data=None):
        x_atom = [x_dict['atom']] if self.first_residual else []
        x_ring = [x_dict['ring']] if self.first_residual else []
        
        for i, conv in enumerate(self.convs):
            if self.use_edge_attr:
                x_dict = conv(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
            else:
                x_dict = conv(x_dict, edge_index_dict)
                
            x_dict = {key: self.dropout_op(ops.relu(x)) for key, x in x_dict.items()}
            
            if self.aggr == 'cat':
                x_dict['atom'] = self.dropout_op(self.lin_atom[i](x_dict['atom']))
                x_dict['ring'] = self.dropout_op(self.lin_ring[i](x_dict['ring']))
            elif self.aggr == 'cat_self':
                x_dict['atom'] = self.dropout_op(self.lin_atom[i](ops.concat((x_atom[-1], x_dict['atom']), -1)))
                x_dict['ring'] = self.dropout_op(self.lin_ring[i](ops.concat((x_ring[-1], x_dict['ring']), -1)))
                
            x_atom.append(x_dict['atom'])
            x_ring.append(x_dict['ring'])
            
        if self.jk == 'cat':
            x_atom = ops.concat(x_atom, 1)
            x_ring = ops.concat(x_ring, 1)
        elif self.jk == 'last':
            x_atom = x_atom[-1]
            x_ring = x_ring[-1]
            
        x_atom = self.pool(x_atom, batch_dict['atom'])
        
        if self.add_mol:
            x_ring_out = self.pool(x_ring[data['ring'].ring_mask], batch_dict['ring'][data['ring'].ring_mask])
            x_mol = self.pool(x_ring[~data['ring'].ring_mask], batch_dict['ring'][~data['ring'].ring_mask])
            
            if self.combine_mol == 'add':
                x_ring_out = x_ring_out + x_mol
            elif self.combine_mol == 'cat':
                x_ring_out = ops.concat((x_ring_out, x_mol), -1)
            elif self.combine_mol == 'drop':
                pass
        else:
            x_ring_out = self.pool(x_ring, batch_dict['ring'])
            x_mol = None
            
        return x_atom, x_ring_out, None, x_mol
    

class BondEncoder(nn.Cell):
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        
        self.bond_embedding_list = nn.CellList()
        full_bond_feature_dims = [22, 6, 2]
        for i, dim in enumerate(full_bond_feature_dims):
            emb = nn.Embedding(dim, emb_dim)
            emb.embedding_table.set_data(initializer(XavierUniform(), emb.embedding_table.shape))
            self.bond_embedding_list.append(emb)

    def construct(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:,i])

        return bond_embedding

class RingEncoder(nn.Cell):
    def __init__(self, emb_dim, pe=False):
        super(RingEncoder, self).__init__()
        
        self.ring_embedding_list = nn.CellList()
        full_ring_feature_dims = [60]
        for i, dim in enumerate(full_ring_feature_dims):
            emb = nn.Embedding(dim+1, emb_dim)
            emb.embedding_table.set_data(initializer(XavierUniform(), emb.embedding_table.shape))
            self.ring_embedding_list.append(emb)

    def construct(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.ring_embedding_list[i](x[:,i])

        return x_embedding
    
class RingBondDegreeEncoder(nn.Cell):
    def __init__(self, emb_dim, num_edge_types=17):
        super(RingBondDegreeEncoder, self).__init__()
        
        self.ring_embedding_list = nn.CellList()
        full_ring_feature_dims = [7]*num_edge_types
        for i, dim in enumerate(full_ring_feature_dims):
            emb = nn.Embedding(dim+1, emb_dim, padding_idx=0)
            emb.embedding_table.set_data(initializer(XavierUniform(), emb.embedding_table.shape))
            emb.embedding_table[0] = 0.0
            self.ring_embedding_list.append(emb)

    def construct(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.ring_embedding_list[i](x[:,i])

        return x_embedding

class Attention(nn.Cell):
    def __init__(self, in_size, hidden_size=16):
        super(Attention, self).__init__()

        self.project = SequentialCell(
            nn.Dense(in_size, hidden_size),
            nn.Tanh(),
            nn.Dense(hidden_size, 1, has_bias=False)
        )

    def construct(self, z):
        w = self.project(z)
        beta = ops.softmax(w, axis=1)
        return (beta * z).sum(1).squeeze(), beta

class HeteroTransformer(nn.Cell):
    def __init__(self, metadata,
                 nclass,
                 nhid=128, 
                 nlayer=5,
                 dropout=0, 
                 attn_dropout=0.0,
                 norm=None, 
                 transformer_norm=None,
                 heads=4,
                 pool='add',
                 conv='GINE',
                 inter_conv='GINE',
                 ring_conv='GINE',
                 jk='cat',
                 final_jk='cat',
                 intra_jk='cat',
                 aggr='cat',
                 criterion='MSE',
                 normalize=False,
                 residual=False,
                 target_task=None,
                 ring_init='atom_deepset',
                 mol_init='atom_deepset',
                 pair_init='random',
                 pe_dim=0,
                 pe_emb_dim=128,
                 num_lin_layer=1,
                 model='Het',
                 contrastive=False,
                 num_deepset_layer=1,
                 init_embs=False,
                 padding=True,
                 mask_non_edge=False,
                 cat_pe=False,
                 use_bias=False,
                 add_mol=False,
                 combine_mol='add',
                 float_pe=False,
                 combine_edge='add',
                 root_weight=True,
                 num_ring_edge_types=1,
                 clip_attn=False,
                 **kwargs):
        super().__init__()
        
        self.dropout = dropout
        self.normalize = normalize
        self.target_task = target_task
        self.pe_dim = pe_dim
        self.final_jk = final_jk
        self.ring_init = ring_init
        self.mol_init = mol_init
        self.contrastive = contrastive
        self.cat_pe = cat_pe
        self.add_mol = add_mol
        self.float_pe = float_pe
        self.num_ring_edge_types = num_ring_edge_types
            
        first_residual = True
        Encoder = Het_Transfomer  # You'll need to implement this in MindSpore

        self.encoder = Encoder(metadata, dim=nhid, gnn=conv, inter_gnn=inter_conv, ring_gnn=ring_conv, 
                             num_gc_layers=nlayer, heads=heads, norm=norm, transformer_norm=transformer_norm, 
                             dropout=dropout, attn_dropout=attn_dropout, pool=pool,
                             aggr=aggr, jk=jk, intra_jk=intra_jk, first_residual=first_residual, 
                             init_embs=init_embs, padding=padding, mask_non_edge=mask_non_edge, 
                             residual=residual, use_bias=use_bias, add_mol=add_mol, combine_mol=combine_mol, 
                             root_weight=root_weight, combine_edge=combine_edge, clip_attn=clip_attn)
        
        # AtomEncoder needs to be implemented in MindSpore
        self.atom_encoder = AtomEncoder(nhid)  
        self.ring_encoder = RingEncoder(nhid-pe_emb_dim) if cat_pe else RingEncoder(nhid)
        
        # Edge attr encoder
        ring_bond_encoder = nn.Embedding(42, nhid)
        ring_bond_encoder.embedding_table.set_data(initializer(XavierUniform(), ring_bond_encoder.embedding_table.shape))
        ar_bond_encoder = nn.Embedding(2, nhid)
        ar_bond_encoder.embedding_table.set_data(initializer(XavierUniform(), ar_bond_encoder.embedding_table.shape))
        ra_bond_encoder = nn.Embedding(2, nhid)
        ra_bond_encoder.embedding_table.set_data(initializer(XavierUniform(), ra_bond_encoder.embedding_table.shape))
          
        self.bond_encoder = nn.CellDict({'a2a': BondEncoder(nhid), 'a2r': ar_bond_encoder, 
                                       'r2r': ring_bond_encoder, 'r2a': ra_bond_encoder})
        
        if self.add_mol:
            self.mol_encoder = nn.Embedding(2, nhid-pe_emb_dim) if cat_pe else nn.Embedding(2, nhid)
            self.mol_encoder.embedding_table.set_data(initializer(XavierUniform(), self.mol_encoder.embedding_table.shape))
                    
        if pe_dim > 0:
            if float_pe:
                self.pe_encoder = SequentialCell(nn.Dense(pe_dim, pe_emb_dim)) if cat_pe else SequentialCell(nn.Dense(pe_dim, nhid))
            else:
                if num_ring_edge_types == 1:
                    self.pe_encoder = nn.Embedding(pe_dim+1, pe_emb_dim, padding_idx=0) if cat_pe else nn.Embedding(pe_dim+1, nhid, padding_idx=0)
                    self.pe_encoder.embedding_table.set_data(initializer(XavierUniform(), self.pe_encoder.embedding_table.shape))
                    self.pe_encoder.embedding_table[0] = 0.0
                else:
                    self.pe_encoder = RingBondDegreeEncoder(pe_emb_dim, num_ring_edge_types) if cat_pe else RingBondDegreeEncoder(nhid, num_ring_edge_types)
        
        if ring_init.startswith('atom_deepset') or ring_init == 'deepset_random':
            self.ring_deepset = SequentialCell(nn.Dense(nhid, nhid-pe_emb_dim), ReLU()) if cat_pe else SequentialCell(nn.Dense(nhid, nhid), ReLU())
            
        penultimate_dim = (nlayer+1)*nhid if jk == 'cat' else nhid
        if final_jk == 'cat':
            final_dim = penultimate_dim * 2
            if combine_mol == 'cat':
                final_dim = final_dim + penultimate_dim
        else:
            final_dim = penultimate_dim
            
        if num_lin_layer == 1:
            self.lin = nn.Dense(final_dim, nclass)
        else:
            self.lin = SequentialCell(
                nn.Dense(final_dim, penultimate_dim),
                ReLU(),
                nn.Dropout(p=dropout),
                nn.Dense(penultimate_dim, nclass)
            )
        
        if criterion == 'MSE':
            self.criterion = nn.MSELoss()
        elif criterion == 'MAE':
            self.criterion = nn.L1Loss()
        else:
            raise NameError(f"{criterion} is not implemented!")

    def construct(self, data):
        # Initialize node embeddings
        # Atom
        x_atom = self.atom_encoder(data.x_dict['atom'].astype(ms.int32))
        
        # Ring
        if self.ring_init == 'random':
            x_ring = self.ring_encoder(data.x_dict['ring'].astype(ms.int32))
        elif self.ring_init == 'zero':
            x_ring = ops.zeros((data['ring'].ptr[-1].item(), x_atom.shape[1]), x_atom.dtype)
        elif self.ring_init.startswith('atom_deepset') or self.ring_init == 'deepset_random':
            ringatoms_batch = [ops.full((n,), i, ms.int32) for i, n in enumerate(data.num_ringatoms)]
            ringatoms_batch = ops.concat(ringatoms_batch, axis=0)
            ringatoms_ptr = data['atom'].ptr[ringatoms_batch]
            ringatoms = data.ring_atoms + ringatoms_ptr
            ring_atoms_map = data.ring_atoms_map + data['ring'].ptr[ringatoms_batch]
            
            x_ring = global_add_pool(x_atom[ringatoms], ring_atoms_map)
            x_ring = ops.dropout(self.ring_deepset(x_ring), p=self.dropout, training=self.training)
            
            if self.add_mol:
                x_ring = ops.concat((x_ring, self.mol_encoder(Tensor([0], ms.int32))), 0)
        elif self.ring_init == 'add' or self.ring_init == 'mean':
            ringatoms_batch = [ops.full((n,), i, ms.int32) for i, n in enumerate(data.num_ringatoms)]
            ringatoms_batch = ops.concat(ringatoms_batch, axis=0)
            ringatoms_ptr = data['atom'].ptr[ringatoms_batch]
            ringatoms = data.ring_atoms + ringatoms_ptr
            ring_atoms_map = data.ring_atoms_map + data['ring'].ptr[ringatoms_batch]
            
            if self.ring_init == 'add':
                x_ring = global_add_pool(x_atom[ringatoms], ring_atoms_map)
            elif self.ring_init == 'mean':
                x_ring = global_mean_pool(x_atom[ringatoms], ring_atoms_map)
        
        if self.pe_dim > 0:
            if self.cat_pe:
                if not self.float_pe:
                    if self.num_ring_edge_types == 1:
                        x_ring = ops.concat((x_ring, self.pe_encoder(data['ring'].ring_pe.reshape(-1).astype(ms.int32))), -1)
                    else:
                        x_ring = ops.concat((x_ring, self.pe_encoder(data['ring'].ring_pe.astype(ms.int32))), -1)
                else:
                    x_ring = ops.concat((x_ring, self.pe_encoder(data['ring'].ring_pe)), -1)
            else:
                if not self.float_pe:
                    if self.num_ring_edge_types == 1:
                        x_ring = x_ring + self.pe_encoder(data['ring'].ring_pe.reshape(-1).astype(ms.int32))
                    else:
                        x_ring = x_ring + self.pe_encoder(data['ring'].ring_pe.astype(ms.int32))
                else:
                    x_ring = x_ring + self.pe_encoder(data['ring'].ring_pe)
        
        x_dict = {'atom': x_atom, 'ring': x_ring}
        edge_attr_dict = {edge_type: self.bond_encoder[edge_type[1]](edge_attr) 
                         for edge_type, edge_attr in data.edge_attr_dict.items()}
        
        atom_embs, ring_embs, pair_embs, mol_embs = self.encoder(
            x_dict, data.edge_index_dict, data.batch_dict, edge_attr_dict, 
            edge_type_dict=data.edge_attr_dict, data=data)
        
        return atom_embs, ring_embs, pair_embs, mol_embs
    
    def get_embs(self, data):
        atom_embs, ring_embs, pair_embs, mol_embs = self(data)
        if self.final_jk == 'cat':
            graph_embs = ops.concat([atom_embs, ring_embs], axis=1)
        elif self.final_jk == 'add':
            graph_embs = atom_embs + ring_embs
        elif self.final_jk == 'attention':
            graph_embs = [atom_embs, ring_embs]
            graph_embs = ops.stack(graph_embs, axis=1)
            graph_embs, attn_values = self.final_attn(graph_embs)
        elif self.final_jk == 'attention_param':
            graph_embs = [atom_embs, ring_embs]
            graph_embs = ops.stack(graph_embs, axis=1)
            graph_embs = (graph_embs * ops.softmax(self.final_attn, axis=1)).sum(1)
        elif self.final_jk == 'atom':
            graph_embs = atom_embs
        elif self.final_jk == 'ring':
            graph_embs = ring_embs
        elif self.final_jk == 'mol':
            graph_embs = mol_embs
        else:
            raise NameError(f"{self.final_jk} is not implemented!")
        return graph_embs
    
    def predict_score(self, data):
        atom_embs, ring_embs, pair_embs, mol_embs = self(data)
        if self.final_jk == 'cat':
            graph_embs = ops.concat([atom_embs, ring_embs], axis=1)
        elif self.final_jk == 'add':
            graph_embs = atom_embs + ring_embs
        elif self.final_jk == 'attention':
            graph_embs = [atom_embs, ring_embs]
            graph_embs = ops.stack(graph_embs, axis=1)
            graph_embs, attn_values = self.final_attn(graph_embs)
        elif self.final_jk == 'attention_param':
            graph_embs = [atom_embs, ring_embs]
            graph_embs = ops.stack(graph_embs, axis=1)
            graph_embs = (graph_embs * ops.softmax(self.final_attn, axis=1)).sum(1)
        elif self.final_jk == 'atom':
            graph_embs = atom_embs
        elif self.final_jk == 'ring':
            graph_embs = ring_embs
        elif self.final_jk == 'mol':
            graph_embs = mol_embs
        else:
            raise NameError(f"{self.final_jk} is not implemented!")

        scores = self.lin(graph_embs)
        return scores
    
    def calc_contra_loss(self, data):
        atom_embs, ring_embs, pair_embs, mol_embs = self(data)
        g1, g2 = [self.project(g) for g in [atom_embs, ring_embs]]
        loss = self.ssl_criterion(g1=g1, g2=g2)
        return loss
    
    def calc_loss(self, data):
        scores = self.predict_score(data)
        mask = (data.y != 0).astype(ms.float32)
        scores = scores * mask
        loss = self.criterion(scores, data.y)
        return loss
    
def global_add_pool(x: Tensor, batch: Tensor = None, size: int = None) -> Tensor:
    """MindSpore implementation of global additive pooling.
    
    Args:
        x (Tensor): Node feature matrix with shape [N, F] or [N]
        batch (Tensor, optional): Batch vector assigning each node to a graph. Shape [N]
        size (int, optional): Number of graphs in batch. If None, inferred from batch.
        
    Returns:
        Tensor: Graph-level outputs with shape [B, F] or [B]
    """
    if x.ndim == 1:
        dim = -1
    else:
        dim = -2
    
    if batch is None:
        return x.sum(axis=dim, keepdims=x.ndim <= 2)
    
    if size is None:
        size = int(batch.max().asnumpy().item()) + 1
    
    # MindSpore's unsorted_segment_sum requires int32 batch indices
    batch = batch.astype(ms.int32)
    
    if x.ndim == 1:
        return ops.unsorted_segment_sum(x, batch, size)
    else:
        return ops.unsorted_segment_sum(x, batch, size)

def global_mean_pool(x: Tensor, batch: Tensor = None, size: int = None) -> Tensor:
    """MindSpore implementation of global mean pooling.
    
    Args:
        x (Tensor): Node feature matrix with shape [N, F] or [N]
        batch (Tensor, optional): Batch vector assigning each node to a graph. Shape [N]
        size (int, optional): Number of graphs in batch. If None, inferred from batch.
        
    Returns:
        Tensor: Graph-level outputs with shape [B, F] or [B]
    """
    if x.ndim == 1:
        dim = -1
    else:
        dim = -2
    
    if batch is None:
        return x.mean(axis=dim, keepdims=x.ndim <= 2)
    
    if size is None:
        size = int(batch.max().asnumpy().item()) + 1
    
    # MindSpore's unsorted_segment_sum requires int32 batch indices
    batch = batch.astype(ms.int32)
    
    if x.ndim == 1:
        sum_pool = ops.unsorted_segment_sum(x, batch, size)
        counts = ops.unsorted_segment_sum(ops.ones_like(x), batch, size)
    else:
        sum_pool = ops.unsorted_segment_sum(x, batch, size)
        counts = ops.unsorted_segment_sum(ops.ones_like(x[:, 0]), batch, size)
    
    # Avoid division by zero
    counts = ops.maximum(counts, ops.ones_like(counts))
    
    if x.ndim == 1:
        return sum_pool / counts
    else:
        return sum_pool / counts.reshape(-1, 1)

  from .autonotebook import tqdm as notebook_tqdm


## 模型训练及评估

In [2]:
import warnings
import deepchem as dc
import warnings
import deepchem as dc
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import mean_absolute_error, mean_absolute_percentage_error, mean_squared_error
import matplotlib.pyplot as plt
import seaborn as sns
import mindspore as ms
import mindspore.nn as nn
from mindspore.train.callback import Callback
from mindspore.common.initializer import initializer, Normal

from data_loader_het import get_dataset_het
from copy import deepcopy
import mindspore.dataset as ds
from mindspore import context

import random
import numpy as np

def setup_seed(seed):
    ms.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    context.set_context(mode=context.GRAPH_MODE)

class SchedulerCallback(Callback):
    def __init__(self, scheduler):
        self.scheduler = scheduler
        
    def on_train_epoch_end(self, run_context):
        cb_params = run_context.original_args()
        self.scheduler.step()

def train(args, filename=None):
    # pytorch
    # device = torch.device('cuda:%d' % args.gpu)
    
    # mindspore
    context.set_context(device_target="GPU" if args.gpu >= 0 else "CPU", device_id=args.gpu)
    
    maes, mapes, mses = [], [], []
    best_vals = []

    float_pe = False
    pe_dim = 7
    num_ring_edge_types = 1
    add_mol = False

    transform = T.Compose([AddHetRingDegreePE(pe_dim), AddVirtualMol()])
    add_mol = True

    dataloader, dataloader_test, dataloader_val, transformer, meta = get_dataset_het(args, transform)
    num_classes = meta['num_classes']
    n_train = len(dataloader.get_dataset_size())
    n_val = len(dataloader_val.get_dataset_size())
    n_test = len(dataloader_test.get_dataset_size())
        
    for trial in range(args.num_trial):
        setup_seed(trial)
        # Model initialization
        model = HeteroTransformer(dataloader.dataset[0].metadata(), num_classes, args.hidden_dim, args.num_layer, 
                                heads=args.heads, conv=args.model, ring_conv=args.ring_conv, pool=args.pool, 
                                norm='BatchNorm' if args.bn else args.norm, 
                                transformer_norm=args.transformer_norm, l2norm=args.l2norm, 
                                dropout=args.dropout, attn_dropout=args.attn_dropout, 
                                criterion=args.criterion, jk=args.jk, final_jk=args.final_jk, 
                                aggr=args.aggr, normalize=args.normalize, 
                                first_residual=args.first_residual, residual=args.residual, 
                                ring_init=args.ring_init, pe_dim=pe_dim, cat_pe=args.cat_pe, 
                                use_bias=args.use_bias, add_mol=add_mol, combine_mol=args.combine_mol, 
                                float_pe=float_pe, combine_edge=args.combine_edge, 
                                root_weight=args.root_weight, num_ring_edge_types=num_ring_edge_types, 
                                clip_attn=args.clip_attn, model='Transformer')
        
        optimizer = nn.Adam(model.trainable_params(), learning_rate=args.lr, weight_decay=args.weight_decay)
        
        if args.scheduler.startswith('step'):
            step_size, gamma = args.scheduler.split('-')[1:]
            scheduler = nn.StepLR(optimizer, step_size=int(step_size), gamma=float(gamma))
        elif args.scheduler == 'cosine':
            scheduler = nn.CosineDecayLR(args.lr, args.num_epoch)
        elif args.scheduler.startswith('onecycle'):
            pct_start = float(args.scheduler.split('-')[1]) if '-' in args.scheduler else 0.1
            scheduler = nn.OneCycleLR(args.lr, args.num_epoch, steps_per_epoch=len(dataloader), pct_start=pct_start)
        else:
            scheduler = None
            
        scheduler_cb = SchedulerCallback(scheduler) if scheduler else None
        
        # Training & Validation
        best_val = float("Inf")
        best_epoch = 0
        best_model_params = None
        
        for epoch in range(1, args.num_epoch + 1):
            model.set_train()
            loss_all = 0
            
            for data in dataloader.create_dict_iterator():
                loss = model.calc_loss(data)
                optimizer.clear_grad()
                loss.backward()
                optimizer.step()
                loss_all += loss.asnumpy() * data['num_graphs'].asnumpy()
                
            if scheduler_cb:
                scheduler_cb.on_train_epoch_end(None)
                
            print('[TRAIN] Epoch:{:03d} | Loss:{:.4f}'.format(epoch, loss_all / n_train))
            
            # Validation
            model.set_train(False)
            loss_all_val = 0.0
            
            for data in dataloader_val.create_dict_iterator():
                loss = model.calc_loss(data)
                loss_all_val += loss.asnumpy() * data['num_graphs'].asnumpy()
                
            if loss_all_val < best_val:
                best_val = loss_all_val
                best_model_params = deepcopy(model.parameters_dict())
                best_epoch = epoch

            if epoch % args.eval_freq == 0:
                model.set_train(False)
                y_true = []
                y_preds = []
                
                for data in dataloader_test.create_dict_iterator():
                    y_true.extend(data['y'].asnumpy().reshape(-1, num_classes).tolist())
                    y_preds.extend(model.predict_score(data).asnumpy().reshape(-1, num_classes).tolist())
                    
                y_true = np.array(y_true)
                y_preds = np.array(y_preds)
                test_mask = y_true != 0
                y_true = y_true[test_mask].reshape(-1, 1).tolist()
                y_preds = y_preds[test_mask].reshape(-1, 1).tolist()
                
                if args.normalize:
                    y_true = transformer.inverse_transform(y_true)
                    y_preds = transformer.inverse_transform(y_preds)
                    
                mae = mean_absolute_error(y_true, y_preds)
                mape = mean_absolute_percentage_error(y_true, y_preds)
                mse = mean_squared_error(y_true, y_preds)
                
        # Test on best validation
        ms.load_param_into_net(model, best_model_params)
        model.set_train(False)
        y_true = []
        y_preds = []
        
        for data in dataloader_test.create_dict_iterator():
            y_true.extend(data['y'].asnumpy().reshape(-1, num_classes).tolist())
            y_preds.extend(model.predict_score(data).asnumpy().reshape(-1, num_classes).tolist())
            
        assert len(y_true) == n_test and len(y_preds) == n_test
        y_true = np.array(y_true)
        y_preds = np.array(y_preds)
        test_mask = y_true != 0
        y_true = y_true[test_mask].reshape(-1, 1).tolist()
        y_preds = y_preds[test_mask].reshape(-1, 1).tolist()
        
        if args.normalize:
            y_true = transformer.inverse_transform(y_true)
            y_preds = transformer.inverse_transform(y_preds)
            
        mae = mean_absolute_error(y_true, y_preds)
        mape = mean_absolute_percentage_error(y_true, y_preds)
        mse = mean_squared_error(y_true, y_preds)

        maes.append(mae)
        mapes.append(mape)
        mses.append(mse)
        best_vals.append(best_val)

    avg_val = np.mean(maes)
    std_val = np.std(maes)
    print('MAE: {:.4f}+-{:.4f}'.format(avg_val, std_val))
    
def get_dataset_het_ms(args, transform=None):
    """MindSpore版本的数据加载函数"""
    meta = {}
    transformer = None
    
    # 特征化器选择（与原始代码相同）
    if args.featurizer == 'MACCS':
        featurizer = MACCSKeysFingerprint()
    elif args.featurizer == 'ECFP6':
        featurizer = CircularFingerprint(size=1024, radius=6)
    elif args.featurizer == 'Mordred':
        featurizer = MordredDescriptors(ignore_3D=True)
    elif args.featurizer is None:
        featurizer = None
    else:
        raise NotImplementedError
        
    # 数据集加载（需要适配MindSpore）
    if args.dataset == 'HOPV':
        dataset_pyg = HOPVHetDataset(transform=transform, version=args.dataset_version)
        index_dir = './data/HOPV/'
    elif args.dataset == 'PFD':
        dataset_pyg = PolymerFAHetDataset(transform=transform, version=args.dataset_version)
        index_dir = './data/PFD/'
    elif args.dataset == 'PD':
        dataset_pyg = pNFAHetDataset(transform=transform, version=args.dataset_version)
        index_dir = './data/Polymer_NFA_p/'
    elif args.dataset == 'NFA':
        dataset_pyg = nNFAHetDataset(transform=transform, version=args.dataset_version)
        index_dir = './data/Polymer_NFA_n/'
    else:
        raise NotImplementedError
    
    # 特征处理
    X = featurizer.featurize(dataset_pyg.data.smiles) if featurizer else np.arange(len(dataset_pyg)).reshape(-1,1)
    meta['fingerprint_dim'] = X.shape[1]
    
    # 目标数据处理
    if args.target_mode == 'single':
        nonzero_mask = dataset_pyg.data.y[:, args.target_task] > -100 if (args.dataset == 'HOPV' and args.target_task == 0) else dataset_pyg.data.y[:, args.target_task] != 0
        smiles = np.array(dataset_pyg.data.smiles)[nonzero_mask].tolist()
        dataset = dc_data.DiskDataset.from_numpy(
            X[nonzero_mask], 
            dataset_pyg.data.y.numpy()[nonzero_mask, args.target_task], 
            None, 
            smiles
        )
        meta['num_classes'] = 1
    elif args.target_mode == 'multi':
        dataset = dc_data.DiskDataset.from_numpy(
            X, 
            dataset_pyg.data.y, 
            None, 
            dataset_pyg.data.smiles
        )
        meta['num_classes'] = dataset.y.shape[1]
    else:
        raise NotImplementedError
        
    # 数据集划分
    splitter = RandomSplitter() if args.splitter == 'random' else ScaffoldSplitter()
    train_index, valid_index, test_index = splitter.split(
        dataset, 
        frac_train=args.frac_train, 
        frac_valid=(1-args.frac_train)/2, 
        frac_test=(1-args.frac_train)/2
    )
    
    # 数据标准化
    if args.normalize:
        if args.scaler == 'standard':
            transformer = StandardScaler()
        elif args.scaler == 'minmax':
            transformer = MinMaxScaler()
        transformer.fit(train_dataset.y.reshape(-1, meta['num_classes']))
        y_train = transformer.transform(train_dataset.y.reshape(-1, meta['num_classes']))
        y_valid = transformer.transform(valid_dataset.y.reshape(-1, meta['num_classes']))
        y_test = transformer.transform(test_dataset.y.reshape(-1, meta['num_classes']))
    else:
        y_train = train_dataset.y.reshape(-1, meta['num_classes'])
        y_valid = valid_dataset.y.reshape(-1, meta['num_classes'])
        y_test = test_dataset.y.reshape(-1, meta['num_classes'])
    
    # 创建MindSpore数据集
    def data_generator(indices, y_data, mode='train'):
        for idx, y in zip(indices, y_data):
            data = dataset_pyg[idx]
            data.y = ms.Tensor(np.array([y]), dtype=ms.float32).reshape(1, -1)
            yield self._convert_to_mindspore_format(data)
    
    # 使用GeneratorDataset创建数据集
    train_ds = ds.GeneratorDataset(
        source=lambda: data_generator(train_index, y_train),
        column_names=["data"],
        shuffle=True
    )
    
    val_ds = ds.GeneratorDataset(
        source=lambda: data_generator(valid_index, y_valid),
        column_names=["data"],
        shuffle=False
    )
    
    test_ds = ds.GeneratorDataset(
        source=lambda: data_generator(test_index, y_test),
        column_names=["data"],
        shuffle=False
    )
    
    # 批量处理
    train_ds = train_ds.batch(args.batch_size)
    val_ds = val_ds.batch(1024)
    test_ds = test_ds.batch(1024)
    
    return train_ds, test_ds, val_ds, transformer, meta
    
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-dataset', type=str, default='HOPV', choices=['HOPV', 'PFD', 'NFA', 'PD', 'CEPDB'])
    parser.add_argument('-dataset_version', type=str, default='V1')
    parser.add_argument('-featurizer', type=str, default=None, choices=[None, 'MACCS', 'ECFP6', 'Mordred'])
    parser.add_argument('-normalize', type=bool, default=False)
    parser.add_argument('-scaler', type=str, default='standard', choices=['minmax', 'standard'])
    parser.add_argument('-frac_train', type=float, default=0.6)
    parser.add_argument('-target_mode', type=str, default='single')
    parser.add_argument('-target_task', type=int, default=0, help='0: PCE, 1: HOMO, 2: LUMO, 3: band gap, 4: Voc, 5: Jsc, 6: FF')
    parser.add_argument('-splitter', type=str, default='scaffold')
    parser.add_argument('-model', type=str, default='GINE')
    parser.add_argument('-ring_conv', type=str, default='SparseEdge')
    parser.add_argument('-num_trial', type=int, default=1)
    parser.add_argument('-gpu', type=int, default=1)

    parser.add_argument('-num_epoch', type=int, default=10)
    parser.add_argument('-eval_freq', type=int, default=10)
    parser.add_argument('-batch_size', type=int, default=128)
    parser.add_argument('-bn', type=bool, default=False)
    parser.add_argument('-norm', type=str, default=None, choices=[None, 'BatchNorm', 'LayerNorm'])
    parser.add_argument('-transformer_norm', type=str, default='LayerNorm', choices=[None, 'BatchNorm', 'LayerNorm'])


    parser.add_argument('-lr', type=float, default=0.001)
    parser.add_argument('-weight_decay', type=float, default=5e-4)
    parser.add_argument('-dropout', type=float, default=0.0)
    parser.add_argument('-attn_dropout', type=float, default=0.0)
    parser.add_argument('-criterion', type=str, default='MAE')
    parser.add_argument('-scheduler', type=str, default='onecycle-0.05')

    parser.add_argument('-num_layer', type=int, default=5)
    parser.add_argument('-hidden_dim', type=int, default=128)
    parser.add_argument('-heads', type=int, default=4)
    parser.add_argument('-l2norm', type=bool, default=False)
    parser.add_argument('-pool', type=str, default='add')
    parser.add_argument('-jk', type=str, default='cat')
    parser.add_argument('-final_jk', type=str, default='cat')
    parser.add_argument('-aggr', type=str, default='cat')
    parser.add_argument('-ring_init', type=str, default='random')
    parser.add_argument('-first_residual', type=bool, default=True)
    parser.add_argument('-residual', type=bool, default=True)
    parser.add_argument('-use_bias', type=bool, default=False)

    parser.add_argument('-transform', type=str, default=None, choices=[None, 'VirtualNode'])
    parser.add_argument('-best_val', type=bool, default=True)

    parser.add_argument('-PE', type=str, default='RingDegree', choices=['RingDegree', 'RingBondDegree', 'RandomWalk']) # 
    parser.add_argument('-pe_dim', type=int, default=7)
    parser.add_argument('-cat_pe', type=bool, default=True)

    parser.add_argument('-combine_mol', type=str, default='add')
    parser.add_argument('-root_weight', type=bool, default=True)
    parser.add_argument('-combine_edge', type=str, default='cat', choices=['add', 'add_lin', 'cat','add_lin'])
    parser.add_argument('-clip_attn', type=bool, default=True)
    parser.add_argument('-add_cross', type=bool, default=True)
    parser.add_argument('-add_mol', type=bool, default=True)

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    train(args, None)




_StoreAction(option_strings=['-add_mol'], dest='add_mol', nargs=None, const=None, default=True, type=<class 'bool'>, choices=None, help=None, metavar=None)

## PFD数据集复现

In [18]:
import sys
sys.argv = [
    "train.py",          
    "-dataset", "PFD",
    "-num_layer", "8",
    "-hidden_dim", "512",
    "-heads", "4",
    "-num_epoch", "40",
    "-batch_size", "32",
    "-lr", "0.01",
    "-num_trial", "3"
]

args = parser.parse_args()                
train(args, None)



[TRAIN] Epoch:010 | Loss:1.9042
[TRAIN] Epoch:020 | Loss:1.6362
[TRAIN] Epoch:030 | Loss:1.4894
[TRAIN] Epoch:040 | Loss:1.2977
[EVAL] trial:0 | MAE:1.8805
[TRAIN] Epoch:010 | Loss:2.0529
[TRAIN] Epoch:020 | Loss:1.7544
[TRAIN] Epoch:030 | Loss:1.7681
[TRAIN] Epoch:040 | Loss:1.3465
[EVAL] trial:1 | MAE:1.7271
[TRAIN] Epoch:010 | Loss:2.1798
[TRAIN] Epoch:020 | Loss:1.6379
[TRAIN] Epoch:030 | Loss:1.3286
[TRAIN] Epoch:040 | Loss:1.7871
[EVAL] trial:2 | MAE:1.7265
[FINAL EVAL] MAE: 1.7781+-0.0725


## HOPV数据集复现

In [3]:
import sys
sys.argv = [
    "train.py",          
    "-dataset", "HOPV",
    "-num_layer", "10",
    "-hidden_dim", "512",
    "-heads", "4",
    "-num_epoch", "150",
    "-batch_size", "16",
    "-lr", "0.01",
    "-num_trial", "3"
]

args = parser.parse_args()                
train(args, None)



[TRAIN] Epoch:010 | Loss:2.3964
[TRAIN] Epoch:020 | Loss:1.7668
[TRAIN] Epoch:030 | Loss:1.5432
[TRAIN] Epoch:040 | Loss:1.3757
[TRAIN] Epoch:050 | Loss:1.1448
[TRAIN] Epoch:060 | Loss:1.1641
[TRAIN] Epoch:070 | Loss:1.1050
[TRAIN] Epoch:080 | Loss:1.1042
[TRAIN] Epoch:090 | Loss:1.2098
[TRAIN] Epoch:100 | Loss:1.0693
[TRAIN] Epoch:110 | Loss:1.0667
[TRAIN] Epoch:120 | Loss:1.0903
[TRAIN] Epoch:130 | Loss:1.0695
[TRAIN] Epoch:140 | Loss:1.0076
[TRAIN] Epoch:150 | Loss:0.9961
[EVAL] trial:0 | MAE:1.6005
[TRAIN] Epoch:010 | Loss:2.3770
[TRAIN] Epoch:020 | Loss:1.6895
[TRAIN] Epoch:030 | Loss:1.5439
[TRAIN] Epoch:040 | Loss:1.4042
[TRAIN] Epoch:050 | Loss:1.3753
[TRAIN] Epoch:060 | Loss:1.3639
[TRAIN] Epoch:070 | Loss:1.3974
[TRAIN] Epoch:080 | Loss:1.0698
[TRAIN] Epoch:090 | Loss:1.0379
[TRAIN] Epoch:100 | Loss:1.1293
[TRAIN] Epoch:110 | Loss:1.0921
[TRAIN] Epoch:120 | Loss:1.0248
[TRAIN] Epoch:130 | Loss:1.3059
[TRAIN] Epoch:140 | Loss:1.0535
[TRAIN] Epoch:150 | Loss:0.9787
[EVAL] trial