In [1]:
import pandas as pd
import torch
import torch_geometric as pyg
from torch_geometric.data import HeteroData
import transformers
from transformers import BartForConditionalGeneration, BartTokenizer
import copy

In [2]:
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base", forced_bos_token_id=0)
tok = BartTokenizer.from_pretrained("facebook/bart-base")

In [86]:
bart

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05,

# 1. Dataset

In [3]:
from dataset import Vocab, S2SDataset

In [4]:
vocabs = dict()
vocabs["node"] = Vocab("data/webnlg-few/node.pkl")
vocabs["relation"] = Vocab("data/webnlg-few/relation.pkl")

In [5]:
dataset = S2SDataset(data_dir='data/',
                     dataset='webnlg-few',
                     tokenizer=tok,
                     node_vocab=vocabs['node'],
                     relation_vocab=vocabs['relation'],
                     num_samples='all',
                     usage='train')

In [6]:
from utils import build_optimizer, init_seed, init_logger, init_device, read_configuration, collate_fn_graph_text, \
    format_time

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate_fn_graph_text)
batch = next(iter(dataloader))

In [7]:
nodes, edges, types, node_masks, kd_description, kd_description_masks, kd_positions, \
        recon_relations, recon_positions, recon_masks, gen_outputs, gen_masks, pointer, pointer_masks = batch

# 2. Model

In [62]:
import math

In [12]:
# model

import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn.conv.gcn_conv import GCNConv
from torch_geometric.nn.conv.rgcn_conv import RGCNConv
from transformers import RobertaTokenizer, RobertaForMaskedLM
import math
import numpy as np
import torch.backends.cudnn as cudnn


class ListModule(nn.Module):
    """
    Abstract list layer class.
    """
    def __init__(self, *args):
        """
        Model initializing.
        """
        super(ListModule, self).__init__()
        idx = 0
        for module in args:
            self.add_module(str(idx), module)
            idx += 1

    def __getitem__(self, idx):
        """
        Getting the indexed layer.
        """
        if idx < 0 or idx >= len(self._modules):
            raise IndexError('index {} is out of range'.format(idx))
        it = iter(self._modules.values())
        for i in range(idx):
            next(it)
        return next(it)

    def __iter__(self):
        """
        Iterating on the layers.
        """
        return iter(self._modules.values())

    def __len__(self):
        """
        Number of layers.
        """
        return len(self._modules)


class GraphEncoder(nn.Module):
    def __init__(self, num_nodes, num_relations, gnn_layers, embedding_size, initilized_embedding, dropout_ratio=0.3):
        super(GraphEncoder, self).__init__()
        self.num_nodes = num_nodes
        self.num_relations = num_relations
        self.gnn_layers = gnn_layers
        self.embedding_size = embedding_size
        self.dropout_ratio = dropout_ratio

        self.node_embedding = nn.Embedding(num_nodes, embedding_size)
        self.node_embedding.from_pretrained(torch.from_numpy(np.load(initilized_embedding)), freeze=False)

        self.dropout = nn.Dropout(dropout_ratio)

        self.gnn = []
        for layer in range(gnn_layers):
            self.gnn.append(RGCNConv(embedding_size, embedding_size, num_relations))  # if rgcn is too slow, you can use gcn
        self.gnn = ListModule(*self.gnn)
        # learnable prompt for decoder cross attention (graph prompt)
        self.gprompt = nn.Parameter(torch.randn(4, 768)/math.sqrt(768), requires_grad = True)

    def forward(self, nodes, edges, types):
        """
        :param nodes: tensor, shape [batch_size, num_nodes]
        :param edges: List(List(edge_idx)) : list length of batch size
        :param types: List(type_idx)
        """
        batch_size = nodes.size(0)
        device = nodes.device

        # (batch_size, num_nodes, output_size)
        node_embeddings = []
        for bid in range(batch_size):
            embed = self.node_embedding(nodes[bid, :])
            edge_index = torch.as_tensor(edges[bid], dtype=torch.long, device=device)
            edge_type = torch.as_tensor(types[bid], dtype=torch.long, device=device)
            for lidx, rgcn in enumerate(self.gnn):
                if lidx == len(self.gnn) - 1:
                    embed = rgcn(embed, edge_index=edge_index, edge_type=edge_type)
                else:
                    embed = self.dropout(F.relu(rgcn(embed, edge_index=edge_index, edge_type=edge_type)))
            node_embeddings.append(embed)
        node_embeddings = torch.stack(node_embeddings, 0)  # [batch_size, num_node, embedding_size]
        node_embeddings2 = torch.cat([node_embeddings, self.gprompt.repeat(batch_size, 1,1)], dim=1)

        return node_embeddings, node_embeddings2
    
    
def compute_kd_loss(node_embeddings, desc_embeddings, node_masks, kd_masks):
    assert node_embeddings.size() == desc_embeddings.size()
    mse_loss = nn.MSELoss(reduction='none')
    loss = mse_loss(node_embeddings, desc_embeddings)
    loss = loss.mean(dim=-1)
    masks = node_masks * kd_masks
    loss = loss.masked_select(masks).mean()
    return loss

def compute_ce_loss(logits, labels, masks):
    ce_loss = nn.CrossEntropyLoss(ignore_index=0, reduction="none")
    loss = ce_loss(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss = loss.reshape_as(labels)
    loss = loss.masked_select(masks).mean()
    return loss

def compute_alignment_loss(batch, bart, graph_enc):
    nodes, edges, types, node_masks, kd_description, kd_description_masks, kd_positions, \
        recon_relations, recon_positions, recon_masks, gen_outputs, gen_masks, pointer, pointer_masks = batch
    with torch.no_grad():
        output_dict = bart(input_ids=kd_description,
                          attention_mask=kd_description_masks,
                          output_hidden_states=True,
                          return_dict=True)
    positions = kd_positions.unsqueeze(-1).expand(-1, -1, output_dict["encoder_last_hidden_state"].size(-1))
    teacher_embeddings = torch.gather(output_dict["encoder_last_hidden_state"], dim=1, index=positions)
    teacher_embeddings = teacher_embeddings.detach()
    student_embeddings, student_embeddings2 = graph_enc(nodes, edges, types)

    node_masks = node_masks
    kd_masks = torch.ne(kd_positions, 0)
    kd_loss = compute_kd_loss(student_embeddings, teacher_embeddings, node_masks, kd_masks)
    return kd_loss, student_embeddings2

In [13]:
graph_enc = GraphEncoder(num_nodes = vocabs['node'].size(), # 존재하는 node의 총 개수
                         num_relations = vocabs['relation'].size(), # 존재하는 relation의 총 개수 
                         gnn_layers = 2,
                         embedding_size = 768,
                         initilized_embedding='data/webnlg-few/node_embeddings.npy',
                         dropout_ratio=0.3)

In [14]:
loss, student_emb = compute_alignment_loss(batch, bart, graph_enc)

In [18]:
from transformers.modeling_outputs import BaseModelOutput
encoder_outputs = BaseModelOutput()
encoder_outputs.last_hidden_state = student_emb
out = bart.generate(inputs_embeds = torch.randn(4, 10, 768), encoder_outputs=encoder_outputs)

torch.Size([4, 30, 768])

In [78]:
def batch_loss(bart, graph_enc, batch):
    # compute alignment loss
    align_loss, student_emb = compute_alignment_loss(batch, bart, graph_enc)
    
    nodes, edges, types, node_masks, kd_description, kd_description_masks, kd_positions, \
        recon_relations, recon_positions, recon_masks, gen_outputs, gen_masks, pointer, pointer_masks = batch
    
    output = bart(encoder_outputs=[student_emb],
                  decoder_input_ids=gen_outputs[:,:-1],
                  labels=gen_outputs[:,1:].contiguous(),
                 decoder_attention_mask = gen_masks[:,:-1])
    gen_loss = output[0]
    
    return align_loss, gen_loss
    

In [79]:
batch_loss(bart, graph_enc, batch)

(tensor(2.0582, grad_fn=<MeanBackward0>),
 tensor(37.6058, grad_fn=<NllLossBackward0>))

In [77]:
gen_masks.shape

torch.Size([4, 38])

In [84]:
graph_enc.parameters

<bound method Module.parameters of GraphEncoder(
  (node_embedding): Embedding(1469, 1024)
  (dropout): Dropout(p=0.3, inplace=False)
  (gnn): ListModule(
    (0): RGCNConv(1024, 1024, num_relations=108)
    (1): RGCNConv(1024, 1024, num_relations=108)
  )
)>