In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25ldone
[?25h  Created wheel for torch_geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910459 sha256=83f364f8a3b163e8198c762f1ed13fcb572dc862ac188563ef3e9c7c2cca9267
  Stored in directory: /root/.cache/pip/wheels/ac/dc/30/e2874821ff308ee67dcd7a66dbde912411e19e35a1addda028
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.3.1
[0m

In [2]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from transformers import T5Model, T5TokenizerFast
import networkx as nx
from torch_geometric.data import Data
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch_geometric.data import Batch
from transformers import T5Model
from torch_geometric.nn import GCNConv
from datasets import load_dataset
from torch_geometric.data import Data
from transformers import T5TokenizerFast
import networkx as nx
import torch
import re
import matplotlib.pyplot as plt

model_name = "t5-small"


class WebNLGDataset(Dataset):
    def __init__(self, dataset, max_edges=512):
        self.dataset = dataset
        self.tokenizer = T5TokenizerFast.from_pretrained(model_name)
        self.node_to_idx = {}  # Node to index mapping
        self.max_edges = max_edges

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data_dict = self.dataset[idx]
        text = data_dict['lex']['text'][0]
        triples = data_dict['original_triple_sets']['otriple_set'][0]

        graph_nx = self.triples_to_graph(triples)
        edge_index = self.get_edge_index(graph_nx)

        encoding = self.tokenizer.encode_plus(text, truncation=True, padding='max_length', max_length=512, return_tensors='pt')

        graph_data = Data(x=encoding['input_ids'].squeeze(dim=0), edge_index=edge_index)
        graph_data.attention_mask = encoding['attention_mask'].squeeze(dim=0)
        graph_data.y = encoding['input_ids'].squeeze(dim=0)
        
        #print("Original Sample: \n", text, "\n", triples)  # Print the original sample      
        #print("NetworkX Graph: \n", graph_nx.edges) # Print the NetworkX graph
        #self.visualize_graph(graph_nx)

        return graph_data

    def triples_to_graph(self, triples):
        self.node_to_idx = {}  # reset for each new graph
        graph_nx = nx.MultiDiGraph()
        for triple in triples:
            triple = re.sub(r'\([^)]*\)', '', triple).split('|')  # remove brackets and split by '|'
            subject, relation, obj = map(str.strip, triple)

            # Add string node names to the graph
            if subject not in self.node_to_idx:
                self.node_to_idx[subject] = len(self.node_to_idx)
            if obj not in self.node_to_idx:
                self.node_to_idx[obj] = len(self.node_to_idx)

            graph_nx.add_edge(subject, obj, key=relation)
        return graph_nx

    def get_edge_index(self, graph_nx):
        edge_index = torch.tensor([[self.node_to_idx[n] for n in edge[:2]] for edge in graph_nx.edges]).t().contiguous()
        return edge_index
    
    def visualize_graph(self, graph_nx):
        plt.figure(figsize=(8, 6))
        pos = nx.spring_layout(graph_nx)  # positions for all nodes
        nx.draw(graph_nx, pos, with_labels=True)
        labels = nx.get_edge_attributes(graph_nx, 'key')
        nx.draw_networkx_edge_labels(graph_nx, pos, edge_labels=labels)
        plt.show()

    
class AdapterBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(AdapterBlock, self).__init__()

    def forward(self, x, edge_index):
        return x  # Return the input tensor as-is


class TransformerGCN(nn.Module):
    def __init__(self, vocab_size, adapter_dim):
        super(TransformerGCN, self).__init__()
        self.transformer = T5Model.from_pretrained(model_name)
        self.hidden_size = self.transformer.config.hidden_size  # Get the hidden size from the config
        self.reduce_dim = nn.Linear(self.hidden_size, adapter_dim)  # Use the hidden size instead of 768

        # Freeze the parameters of the T5 model
      #  for param in self.transformer.parameters():
      #      param.requires_grad = False

        self.adapter_blocks = nn.ModuleList([
        AdapterBlock(block.layer[1].DenseReluDense.wo.weight.size(0), adapter_dim) for block in self.transformer.encoder.block
        ])

        self.output_head = nn.Linear(adapter_dim, vocab_size)

    def forward(self, input_ids, attention_mask, edge_index):
        if input_ids.dim() == 1:  # If the input is 1D (batch size 1)
            input_ids = input_ids.unsqueeze(0)  # Add a batch dimension
        if attention_mask.dim() == 1:  # Same for the attention_mask
            attention_mask = attention_mask.unsqueeze(0)

        shifted_input_ids = torch.cat([torch.zeros((input_ids.size(0), 1), dtype=torch.long, device=input_ids.device), input_ids[:, :-1]], dim=-1)

        input_embeds = self.transformer.get_input_embeddings()(input_ids)
        hidden_states = input_embeds
        for block, adapter_block in zip(self.transformer.encoder.block, self.adapter_blocks):
            hidden_states, _ = block(hidden_states, attention_mask=attention_mask, encoder_hidden_states=None, encoder_attention_mask=None)
            hidden_states = adapter_block(hidden_states, edge_index)

        transformer_outputs = self.transformer(inputs_embeds=hidden_states, attention_mask=attention_mask, decoder_input_ids=shifted_input_ids)
        transformer_outputs = self.reduce_dim(transformer_outputs.last_hidden_state)
        return self.output_head(transformer_outputs)


class ModifiedT5Block(nn.Module):
    def __init__(self, original_block, adapter_dim):
        super(ModifiedT5Block, self).__init__()
        self.original_block = original_block
        self.adapter = AdapterBlock(original_block.layer[1].DenseReluDense.wi.weight.size(-1), adapter_dim)

    def forward(self, x, edge_index, **kwargs):
        x, _ = self.original_block(x, **kwargs)
        return self.adapter(x, edge_index)



from torch_geometric.data import DataLoader as GeometricDataLoader

def train(model, dataloader, epochs, device):
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters())
    tokenizer = T5TokenizerFast.from_pretrained(model_name)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    for epoch in range(epochs):
        model.train()
        for data in tqdm(dataloader):
            data = data.to(device) # Moving batch to device
            optimizer.zero_grad()

            outputs = model(input_ids=data.x, attention_mask=data.attention_mask, edge_index=data.edge_index)
            loss = criterion(outputs.view(-1, outputs.size(-1)), data.y.view(-1))
            loss.backward()
            optimizer.step()

# Usage
dataset_dict = load_dataset('web_nlg', 'webnlg_challenge_2017')['train']
dataset = WebNLGDataset(dataset_dict)
vocab_size = len(dataset.tokenizer)
model = TransformerGCN(vocab_size=vocab_size, adapter_dim=512)
dataloader = GeometricDataLoader(dataset, batch_size=4)
train(model, dataloader, epochs=4, device=torch.device('cuda'))


caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


Downloading builder script:   0%|          | 0.00/3.51k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.11k [00:00<?, ?B/s]

Downloading and preparing dataset web_nlg/webnlg_challenge_2017 (download: 24.32 MiB, generated: 8.99 MiB, post-processed: Unknown size, total: 33.31 MiB) to /root/.cache/huggingface/datasets/web_nlg/webnlg_challenge_2017/0.0.0/28ffb892f7f42450dd9558684aa43bcaf44b1b3bf0d77cb8d73534646af88dda...


Downloading data: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/6940 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/4615 [00:00<?, ? examples/s]

Dataset web_nlg downloaded and prepared to /root/.cache/huggingface/datasets/web_nlg/webnlg_challenge_2017/0.0.0/28ffb892f7f42450dd9558684aa43bcaf44b1b3bf0d77cb8d73534646af88dda. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

100%|██████████| 1735/1735 [16:14<00:00,  1.78it/s]
100%|██████████| 1735/1735 [16:13<00:00,  1.78it/s]
100%|██████████| 1735/1735 [16:13<00:00,  1.78it/s]
100%|██████████| 1735/1735 [16:12<00:00,  1.78it/s]


In [3]:
def test(model, dataloader, device):
    model = model.to(device)
    model.eval()
    
    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)
            outputs = model(input_ids=data.x, attention_mask=data.attention_mask, edge_index=data.edge_index)
            
            # Get the predicted token ids by taking the argmax over the token dimension
            predicted_ids = outputs.argmax(-1)
            
            # Convert the tensor outputs to text using the tokenizer
            output_text = [dataset.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in predicted_ids]

            # print input and output
            print(f"Input: {dataset.tokenizer.decode(data.x, skip_special_tokens=True)}")
            print(f"Output: {output_text}")

# Usage
dataset_dict = load_dataset('web_nlg', 'webnlg_challenge_2017')['test']
dataset_dict = [sample for sample in dataset_dict if sample['lex']['text']] # filter out samples with empty targets 
dataset = WebNLGDataset(dataset_dict)
dataloader = GeometricDataLoader(dataset, batch_size=2)
test(model, dataloader, device=torch.device('cuda'))


  0%|          | 0/3 [00:00<?, ?it/s]

Input: Aaron S Daggett was awarded the Purple Heart. The Battle of Mine Run was one fought by Aaron S Daggett.
Output: ['William Adlayv wast was born the ville with He William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William William Buzz Buzz William Willia