In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25ldone
[?25h  Created wheel for torch_geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910459 sha256=71eebd85d6cc60e3e5c5d8cf48a66f9cb1629becaa850e9dcbae208ebefbacf5
  Stored in directory: /root/.cache/pip/wheels/ac/dc/30/e2874821ff308ee67dcd7a66dbde912411e19e35a1addda028
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.3.1
[0m

In [5]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from transformers import T5Model, T5TokenizerFast
import networkx as nx
from torch_geometric.data import Data
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch_geometric.data import Batch
from transformers import T5Model
from torch_geometric.nn import GCNConv
from datasets import load_dataset
from torch_geometric.data import Data
from transformers import T5TokenizerFast
import networkx as nx
import torch
import re
import matplotlib.pyplot as plt

# Hugging Face checkpoint identifier; it is passed to every `from_pretrained`
# call in this notebook (tokenizer and T5 model), so both always match.
# Alternative checkpoint kept for quick switching:
#model_name = "t5-small"
model_name = "google/flan-t5-small"


class WebNLGDataset(Dataset):
    """Expose a WebNLG split as PyG ``Data`` objects for graph-to-text generation.

    Every item contains:

    * ``x``              - token ids of the linearized triples, padded/truncated
                           to ``max_length``;
    * ``attention_mask`` - the matching padding mask;
    * ``y``              - token ids of the first reference sentence, padded/
                           truncated to ``max_length``;
    * ``edge_index``     - a ``(2, num_edges)`` long tensor of subject -> object
                           links between the entities of the triples.

    Only the first triple set and the first lexicalization of a sample are used.

    NOTE(review): ``edge_index`` numbers *entities* (0 .. num_entities - 1) while
    ``x`` has one entry per *token*; the two index spaces are unrelated. Confirm
    how a graph layer is supposed to consume them before relying on the edges.

    Args:
        dataset: Indexable collection of WebNLG samples (dict-like, with the
            ``original_triple_sets`` and ``lex`` fields read in ``__getitem__``).
        max_edges: Stored on the instance; not used by any method of this class.
        max_length: Token length both the input and the target are padded or
            truncated to.
    """

    def __init__(self, dataset, max_edges=512, max_length=512):
        self.dataset = dataset
        self.max_edges = max_edges
        self.max_length = max_length
        self.prefix = "translate graph to text:"

        # Initiate the tokenizer
        self.tokenizer = T5TokenizerFast.from_pretrained(model_name)
        # Entity name -> integer id of the graph built most recently by
        # `triples_to_graph`; `get_edge_index` reads it.
        self.node_to_idx = {}

        # Register the triple markers as special tokens so the tokenizer keeps
        # each of them as a single token.
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['<H>', '<R>', '<T>']})

    def __len__(self):
        """Number of samples in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the ``Data`` object for sample ``idx``.

        Raises:
            IndexError: if the sample has no triple set or no reference text.
            ValueError: if a triple does not have the ``subject | relation | object`` form.
        """
        sample = self.dataset[idx]
        triples = sample['original_triple_sets']['otriple_set'][0]
        target_text = sample['lex']['text'][0]

        input_text = self._linearize(triples)

        # graph creation
        graph_nx = self.triples_to_graph(triples)
        edge_index = self.get_edge_index(graph_nx)

        # encoding input and target texts; the tokenizer returns (1, max_length)
        # tensors, the leading batch dimension is dropped below.
        input_encoding = self._encode(input_text)
        target_encoding = self._encode(target_text)

        graph_data = Data(x=input_encoding['input_ids'].squeeze(dim=0), edge_index=edge_index)
        graph_data.attention_mask = input_encoding['attention_mask'].squeeze(dim=0)
        graph_data.y = target_encoding['input_ids'].squeeze(dim=0)
        return graph_data

    @staticmethod
    def _split_triple(triple):
        """Split ``"subject | relation | object"`` into three stripped fields."""
        # maxsplit=2 keeps any further '|' characters inside the object field.
        fields = [field.strip() for field in triple.split('|', 2)]
        if len(fields) != 3:
            raise ValueError(f"Expected 'subject | relation | object', got: {triple!r}")
        return fields

    def _linearize(self, triples):
        """Turn the triples into the prefixed ``<H> .. <R> .. <T> ..`` input string."""
        pieces = [self.prefix]
        for triple in triples:
            head, relation, tail = self._split_triple(triple)
            pieces.append(f"<H> {head} <R> {relation} <T> {tail}")
        return " ".join(pieces)

    def _encode(self, text):
        """Tokenize ``text`` to fixed-length PyTorch tensors."""
        return self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

    def triples_to_graph(self, triples):
        """Build a ``networkx.MultiDiGraph`` with one subject -> object edge per triple.

        The relation is stored as the multigraph edge *key*. Parenthesized
        parts of a triple are removed before splitting. As a side effect,
        ``self.node_to_idx`` is reset and refilled with ids for the entities of
        this graph, in order of first appearance.
        """
        self.node_to_idx = {}  # reset for each new graph
        graph_nx = nx.MultiDiGraph()
        for triple in triples:
            without_brackets = re.sub(r'\([^)]*\)', '', triple)
            subject, relation, obj = self._split_triple(without_brackets)

            for entity in (subject, obj):
                if entity not in self.node_to_idx:
                    self.node_to_idx[entity] = len(self.node_to_idx)

            graph_nx.add_edge(subject, obj, key=relation)
        return graph_nx

    def get_edge_index(self, graph_nx):
        """Return the edges of ``graph_nx`` as a ``(2, num_edges)`` long tensor.

        Entities are mapped through ``self.node_to_idx``, so the graph must be
        the one produced by the latest ``triples_to_graph`` call. A graph
        without edges yields an empty ``(2, 0)`` tensor.
        """
        # Iterating a multigraph edge view yields (u, v, key); only u and v are nodes.
        pairs = [[self.node_to_idx[node] for node in edge[:2]] for edge in graph_nx.edges]
        if not pairs:
            return torch.empty((2, 0), dtype=torch.long)
        return torch.tensor(pairs, dtype=torch.long).t().contiguous()

    def visualize_graph(self, graph_nx):
        """Draw ``graph_nx`` with entity names on nodes and relations on edges."""
        plt.figure(figsize=(8, 6))
        pos = nx.spring_layout(graph_nx)  # positions for all nodes
        nx.draw(graph_nx, pos, with_labels=True)

        # Relations live in the edge keys (see `triples_to_graph`), not in an
        # edge attribute, so they are collected from the keys. Parallel edges
        # between the same pair share one label.
        relations = {}
        for source, target, relation in graph_nx.edges(keys=True):
            relations.setdefault((source, target), []).append(str(relation))
        labels = {pair: ", ".join(names) for pair, names in relations.items()}

        nx.draw_networkx_edge_labels(graph_nx, pos, edge_labels=labels)
        plt.show()


    
class AdapterBlock(nn.Module):
    """Placeholder adapter: a parameter-free module that hands its input back unchanged."""

    def __init__(self, input_dim, hidden_dim):
        # Both sizes are accepted so call sites already have the final
        # signature; neither is stored or used.
        super().__init__()

    def forward(self, x, edge_index):
        """Return ``x`` itself; ``edge_index`` is accepted but ignored."""
        passthrough = x
        return passthrough


from transformers import T5ForConditionalGeneration

class TransformerGCN(nn.Module):
    """T5 sequence-to-sequence model with one adapter slot per encoder block.

    The pretrained ``T5ForConditionalGeneration`` checkpoint named by the
    module-level ``model_name`` is loaded as ``self.transformer`` and stays
    trainable. One ``AdapterBlock`` is created for every encoder block and
    applied to that block's output in ``forward``.

    Args:
        vocab_size: Output size of ``output_head``.
        adapter_dim: Second constructor argument of every ``AdapterBlock`` and
            input size of ``output_head``.
    """

    def __init__(self, vocab_size, adapter_dim):
        super().__init__()
        self.transformer = T5ForConditionalGeneration.from_pretrained(model_name)
        self.hidden_size = self.transformer.config.hidden_size

        # One adapter per encoder block, sized from that block's feed-forward
        # output projection.
        self.adapter_blocks = nn.ModuleList([
            AdapterBlock(block.layer[1].DenseReluDense.wo.weight.size(0), adapter_dim)
            for block in self.transformer.encoder.block
        ])

        # Not used by `forward`; kept so attribute access and saved state dicts
        # stay compatible. Its parameters are still returned by `parameters()`.
        self.output_head = nn.Linear(adapter_dim, vocab_size)

    @staticmethod
    def _shift_right(token_ids):
        """Prepend the decoder start token and drop the last position.

        Id 0 is the T5 pad token, which T5 also uses as the decoder start
        token. Any ``-100`` (the usual "ignore" label value) is replaced by
        that id so the result is always a valid embedding index.
        """
        start = torch.zeros((token_ids.size(0), 1), dtype=torch.long, device=token_ids.device)
        shifted = torch.cat([start, token_ids[:, :-1]], dim=-1)
        return shifted.masked_fill(shifted == -100, 0)

    def forward(self, input_ids, attention_mask, edge_index, labels=None):
        """Run the adapter-augmented encoder pass followed by the full T5 model.

        Args:
            input_ids: Source token ids, ``(batch, seq)`` or 1-D for one sequence.
            attention_mask: Padding mask with the same layout as ``input_ids``.
            edge_index: Graph connectivity handed to every adapter block.
            labels: Optional target token ids with the same layout rules. When
                given, the decoder is teacher-forced with the right-shifted
                targets. When omitted, the decoder is fed the right-shifted
                *source* ids, which is the behaviour this method had before
                ``labels`` existed.

        Returns:
            The output object of ``T5ForConditionalGeneration``.
        """
        # 1-D tensors are treated as a single sequence.
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if attention_mask.dim() == 1:
            attention_mask = attention_mask.unsqueeze(0)
        if labels is not None and labels.dim() == 1:
            labels = labels.unsqueeze(0)

        decoder_input_ids = self._shift_right(input_ids if labels is None else labels)

        hidden_states = self.transformer.get_input_embeddings()(input_ids)

        # NOTE(review): the padding mask is forwarded to each block as-is and
        # no position bias is carried between blocks; confirm this matches what
        # the T5 block implementation expects.
        for block, adapter_block in zip(self.transformer.encoder.block, self.adapter_blocks):
            hidden_states, _ = block(hidden_states, attention_mask=attention_mask, encoder_hidden_states=None, encoder_attention_mask=None)
            hidden_states = adapter_block(hidden_states, edge_index)

        # NOTE(review): `hidden_states` already went through every encoder
        # block above and is passed here as `inputs_embeds`, so the encoder
        # stack of `self.transformer` processes it a second time.
        return self.transformer(inputs_embeds=hidden_states, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)


from torch_geometric.data import DataLoader as GeometricDataLoader

def train(model, dataloader, epochs, device, lr=0.01):
    """Fine-tune ``model`` with token-level cross-entropy against ``data.y``.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...,
            edge_index=...)`` whose result exposes ``.logits``.
        dataloader: Iterable of batches providing ``x``, ``attention_mask``,
            ``edge_index`` and ``y`` plus a ``.to(device)`` method.
        epochs: Number of passes over ``dataloader``.
        device: Device the model and every batch are moved to.
        lr: Adam learning rate. The default keeps the value this function
            always used. NOTE(review): 0.01 is large for fine-tuning a
            pretrained transformer; consider something like 1e-4.

    Padding positions in the targets are excluded from the loss. The learning
    rate is reduced by ``ReduceLROnPlateau`` when the mean loss of an epoch
    stops improving.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    # The tokenizer is only needed for the id of the padding token.
    tokenizer = T5TokenizerFast.from_pretrained(model_name)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        batches_seen = 0
        for i, data in enumerate(tqdm(dataloader)):
            data = data.to(device)  # Moving batch to device
            optimizer.zero_grad()

            outputs = model(input_ids=data.x, attention_mask=data.attention_mask, edge_index=data.edge_index)
            logits = outputs.logits
            # Flatten to (tokens, vocab) vs (tokens,) as CrossEntropyLoss expects.
            loss = criterion(logits.view(-1, logits.size(-1)), data.y.view(-1))
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            loss_sum += loss_value
            batches_seen += 1

            if i % 100 == 0:
                print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss_value}, Learning Rate: {optimizer.param_groups[0]['lr']}")

        # Step once per epoch on the mean loss: a per-batch loss is too noisy
        # for a plateau detector. Skipped when the loader produced no batch.
        if batches_seen:
            scheduler.step(loss_sum / batches_seen)


# Usage
# Build the training split, the dataset wrapper, the model and a loader.
# These names are notebook globals that later cells read.
dataset_dict = load_dataset('web_nlg', 'webnlg_challenge_2017')['train']
dataset = WebNLGDataset(dataset_dict)
# Tokenizer size after WebNLGDataset registered its extra special tokens.
vocab_size = len(dataset.tokenizer)
model = TransformerGCN(vocab_size=vocab_size, adapter_dim=512)
dataloader = GeometricDataLoader(dataset, batch_size=1)
# Training is disabled here, so the T5 weights inside `model` stay exactly as
# loaded by `from_pretrained`.
#train(model, dataloader, epochs=1, device=torch.device('cuda'))


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
"""def test(model, dataloader, device):
    model = model.to(device)
    model.eval()
    
    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)
            outputs = model(input_ids=data.x, attention_mask=data.attention_mask, edge_index=data.edge_index)
            
            # Get the predicted token ids by taking the argmax over the token dimension
            predicted_ids = outputs.argmax(-1)
            
            # Convert the tensor outputs to text using the tokenizer
            output_text = [dataset.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in predicted_ids]

            # print input and output
            print(f"Input: {dataset.tokenizer.decode(data.x, skip_special_tokens=True)}")
            print(f"Output: {output_text}")

# Usage
dataset_dict = load_dataset('web_nlg', 'webnlg_challenge_2017')['test']
dataset_dict = [sample for sample in dataset_dict if sample['lex']['text']] # filter out samples with empty targets 
dataset = WebNLGDataset(dataset_dict)
dataloader = GeometricDataLoader(dataset, batch_size=2)
test(model, dataloader, device=torch.device('cuda'))"""


NameError: name 'load_dataset' is not defined

In [6]:
def test(model, dataloader, device):
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)

            data.x = data.x.unsqueeze(0)  # Add batch dimension
            data.attention_mask = data.attention_mask.unsqueeze(0) 

            outputs = model.transformer.generate(input_ids=data.x, attention_mask=data.attention_mask, decoder_start_token_id=model.transformer.config.pad_token_id, max_length=512, num_beams=4, early_stopping=True)

            # Convert the tensor outputs to text using the tokenizer
            output_text = dataset.tokenizer.decode(outputs[0].tolist(), skip_special_tokens=True)
            target_text = dataset.tokenizer.decode(data.y.tolist(), skip_special_tokens=True)

            # print input and output
            print(f"Input: {dataset.tokenizer.decode(data.x[0].tolist(), skip_special_tokens=True)}")  # Convert tensor to list
            print(f"target: {target_text}")  # Convert tensor to list
            print(f"Output: {output_text}")

# Usage
# Use the GPU when one is visible and fall back to the CPU otherwise, so the
# cell also runs on CPU-only kernels.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset_dict = load_dataset('web_nlg', 'webnlg_challenge_2017')['test']
dataset_dict = [sample for sample in dataset_dict if sample['lex']['text']] # filter out samples with empty targets
# NOTE: `dataset` and `dataloader` are rebound here to the filtered test split.
dataset = WebNLGDataset(dataset_dict)
dataloader = GeometricDataLoader(dataset, batch_size=1)
test(model, dataloader, device=device)


  0%|          | 0/3 [00:00<?, ?it/s]

Input: translate graph to text:  Aaron_S._Daggett  award  Purple_Heart
target: Aaron S Daggett was awarded the Purple Heart.
Output:          
Input: translate graph to text:  Aaron_S._Daggett  battle  Battle_of_Mine_Run
target: The Battle of Mine Run was one fought by Aaron S Daggett.
Output:        
Input: translate graph to text:  Ab_Klink  placeOfBirth  "Stellendam, Netherlands"@en
target: Stellendam, Netherlands is the birthplace of Ab Klink.
Output:  Ab_Klink  placeOfBirth  "Stellendam, Netherlands"@en
Input: translate graph to text:  Abdul_Rahman_Ya'kub  governor  Tuanku_Bujang_Tuanku_Othman
target: Abdul Rahman Ya'kub was in office while Tuanku Bujang Tuanku Othman was Vice President.
Output:       
Input: translate graph to text:  Abdul_Taib_Mahmud  party  "Parti Bumiputera Sarawak"@en
target: Abdul Taib Mahmud belongs to the party of Parti Bumiputera Sarawak.
Output:          
Input: translate graph to text:  Abdul_Taib_Mahmud  successor  Sulaiman_Abdul_Rahman_Taib
target: Ab