In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25ldone
[?25h  Created wheel for torch_geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910459 sha256=b140590f97b5a32637bb66b88a37f33d92f778069a92823c498c89de44e27ce6
  Stored in directory: /root/.cache/pip/wheels/ac/dc/30/e2874821ff308ee67dcd7a66dbde912411e19e35a1addda028
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.3.1
[0m

In [16]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')
model = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
input_ids = tokenizer.encode("answer: hello, how are you", return_tensors="pt")  # Batch size 1
outputs = model.generate(input_ids)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs])

['good']


In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from transformers import T5Model, T5TokenizerFast
import networkx as nx
from torch_geometric.data import Data
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch_geometric.data import Batch
from transformers import T5Model
from torch_geometric.nn import GCNConv
from datasets import load_dataset
from torch_geometric.data import Data
from transformers import T5TokenizerFast
import networkx as nx
import torch
import re
import matplotlib.pyplot as plt

#model_name = "t5-small"
model_name = "google/flan-t5-small"


class WebNLGDataset(Dataset):
    def __init__(self, dataset, max_edges=512):
        self.dataset = dataset
        self.tokenizer = T5TokenizerFast.from_pretrained(model_name)
        self.node_to_idx = {}  # Node to index mapping
        self.max_edges = max_edges

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data_dict = self.dataset[idx]
        text = data_dict['lex']['text'][0]
        triples = data_dict['original_triple_sets']['otriple_set'][0]

        graph_nx = self.triples_to_graph(triples)
        edge_index = self.get_edge_index(graph_nx)

        encoding = self.tokenizer.encode_plus(text, truncation=True, padding='max_length', max_length=512, return_tensors='pt')

        graph_data = Data(x=encoding['input_ids'].squeeze(dim=0), edge_index=edge_index)
        graph_data.attention_mask = encoding['attention_mask'].squeeze(dim=0)
        graph_data.y = encoding['input_ids'].squeeze(dim=0)
        
        #print("Original Sample: \n", text, "\n", triples)  # Print the original sample      
        #print("NetworkX Graph: \n", graph_nx.edges) # Print the NetworkX graph
        #self.visualize_graph(graph_nx)

        return graph_data

    def triples_to_graph(self, triples):
        self.node_to_idx = {}  # reset for each new graph
        graph_nx = nx.MultiDiGraph()
        for triple in triples:
            triple = re.sub(r'\([^)]*\)', '', triple).split('|')  # remove brackets and split by '|'
            subject, relation, obj = map(str.strip, triple)

            # Add string node names to the graph
            if subject not in self.node_to_idx:
                self.node_to_idx[subject] = len(self.node_to_idx)
            if obj not in self.node_to_idx:
                self.node_to_idx[obj] = len(self.node_to_idx)

            graph_nx.add_edge(subject, obj, key=relation)
        return graph_nx

    def get_edge_index(self, graph_nx):
        edge_index = torch.tensor([[self.node_to_idx[n] for n in edge[:2]] for edge in graph_nx.edges]).t().contiguous()
        return edge_index
    
    def visualize_graph(self, graph_nx):
        plt.figure(figsize=(8, 6))
        pos = nx.spring_layout(graph_nx)  # positions for all nodes
        nx.draw(graph_nx, pos, with_labels=True)
        labels = nx.get_edge_attributes(graph_nx, 'key')
        nx.draw_networkx_edge_labels(graph_nx, pos, edge_labels=labels)
        plt.show()

    
class AdapterBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(AdapterBlock, self).__init__()
        self.layer_norm = nn.LayerNorm(input_dim, eps=1e-6)
        self.gcn = GCNConv(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(hidden_dim, input_dim)  # Restore dimension to input_dim
        self.res_fc = nn.Linear(input_dim, input_dim)  # Transform the residual

    def forward(self, x, edge_index):
        x_res = self.res_fc(x)  # Transform the residual tensor
        x = self.layer_norm(x)
        x = self.gcn(x, edge_index)
        x = self.relu(x)
        x = self.fc(x)
        return x + x_res  # Now you can add the tensors


from transformers import T5ForConditionalGeneration

class TransformerGCN(nn.Module):
    def __init__(self, vocab_size, adapter_dim):
        super(TransformerGCN, self).__init__()
        self.transformer = T5ForConditionalGeneration.from_pretrained(model_name)
        self.hidden_size = self.transformer.config.hidden_size  # Get the hidden size from the config
        self.reduce_dim = nn.Linear(32128, adapter_dim)  # Use the hidden size instead of 768

        # Freeze the parameters of the T5 model
        for param in self.transformer.parameters():
            param.requires_grad = False

        self.adapter_blocks = nn.ModuleList([
        AdapterBlock(block.layer[1].DenseReluDense.wo.weight.size(0), adapter_dim) for block in self.transformer.encoder.block
        ])

        self.output_head = nn.Linear(adapter_dim, vocab_size)

    def forward(self, input_ids, attention_mask, edge_index):
        if input_ids.dim() == 1:  # If the input is 1D (batch size 1)
            input_ids = input_ids.unsqueeze(0)  # Add a batch dimension
        if attention_mask.dim() == 1:  # Same for the attention_mask
            attention_mask = attention_mask.unsqueeze(0)

        shifted_input_ids = torch.cat([torch.zeros((input_ids.size(0), 1), dtype=torch.long, device=input_ids.device), input_ids[:, :-1]], dim=-1)

        input_embeds = self.transformer.get_input_embeddings()(input_ids)
        hidden_states = input_embeds
        for block, adapter_block in zip(self.transformer.encoder.block, self.adapter_blocks):
            hidden_states, _ = block(hidden_states, attention_mask=attention_mask, encoder_hidden_states=None, encoder_attention_mask=None)
            hidden_states = adapter_block(hidden_states, edge_index)

        transformer_outputs = self.transformer(inputs_embeds=hidden_states, attention_mask=attention_mask, decoder_input_ids=shifted_input_ids)
        transformer_outputs = self.reduce_dim(transformer_outputs[0])
        return self.output_head(transformer_outputs)


class ModifiedT5Block(nn.Module):
    def __init__(self, original_block, adapter_dim):
        super(ModifiedT5Block, self).__init__()
        self.original_block = original_block
        self.adapter = AdapterBlock(original_block.layer[1].DenseReluDense.wi.weight.size(-1), adapter_dim)

    def forward(self, x, edge_index, **kwargs):
        x, _ = self.original_block(x, **kwargs)
        return self.adapter(x, edge_index)



from torch_geometric.data import DataLoader as GeometricDataLoader

def train(model, dataloader, epochs, device):
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) # Lower learning rate
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') # Add learning rate scheduler
    tokenizer = T5TokenizerFast.from_pretrained(model_name)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    for epoch in range(epochs):
        model.train()
        i = 0
        for data in tqdm(dataloader):
            data = data.to(device) # Moving batch to device
            optimizer.zero_grad()

            outputs = model(input_ids=data.x, attention_mask=data.attention_mask, edge_index=data.edge_index)
            loss = criterion(outputs.view(-1, outputs.size(-1)), data.y.view(-1))
            loss.backward()
            optimizer.step()
            
            if i % 100 == 0:
                print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}")
            
            i += 1 

# Usage
dataset_dict = load_dataset('web_nlg', 'webnlg_challenge_2017')['train']
dataset = WebNLGDataset(dataset_dict)
vocab_size = len(dataset.tokenizer)
model = TransformerGCN(vocab_size=vocab_size, adapter_dim=512)
dataloader = GeometricDataLoader(dataset, batch_size=2)
train(model, dataloader, epochs=2, device=torch.device('cuda'))


caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


Downloading builder script:   0%|          | 0.00/3.51k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.11k [00:00<?, ?B/s]

Downloading and preparing dataset web_nlg/webnlg_challenge_2017 (download: 24.32 MiB, generated: 8.99 MiB, post-processed: Unknown size, total: 33.31 MiB) to /root/.cache/huggingface/datasets/web_nlg/webnlg_challenge_2017/0.0.0/28ffb892f7f42450dd9558684aa43bcaf44b1b3bf0d77cb8d73534646af88dda...


Downloading data: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/6940 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/4615 [00:00<?, ? examples/s]

Dataset web_nlg downloaded and prepared to /root/.cache/huggingface/datasets/web_nlg/webnlg_challenge_2017/0.0.0/28ffb892f7f42450dd9558684aa43bcaf44b1b3bf0d77cb8d73534646af88dda. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/308M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

  0%|          | 2/3470 [00:01<48:09,  1.20it/s]  

Epoch: 0, Iteration: 0, Loss: 26.284656524658203


  3%|▎         | 102/3470 [00:21<10:46,  5.21it/s]

Epoch: 0, Iteration: 100, Loss: 8.19456672668457


  6%|▌         | 202/3470 [00:40<10:26,  5.21it/s]

Epoch: 0, Iteration: 200, Loss: 8.687564849853516


  9%|▊         | 302/3470 [01:00<10:06,  5.23it/s]

Epoch: 0, Iteration: 300, Loss: 6.294469833374023


 12%|█▏        | 402/3470 [01:19<09:54,  5.16it/s]

Epoch: 0, Iteration: 400, Loss: 4.097095489501953


 14%|█▍        | 502/3470 [01:39<09:43,  5.09it/s]

Epoch: 0, Iteration: 500, Loss: 5.831324577331543


 17%|█▋        | 602/3470 [01:58<09:09,  5.22it/s]

Epoch: 0, Iteration: 600, Loss: 6.2011566162109375


 20%|██        | 702/3470 [02:17<08:50,  5.22it/s]

Epoch: 0, Iteration: 700, Loss: 7.284809112548828


 23%|██▎       | 802/3470 [02:37<08:30,  5.23it/s]

Epoch: 0, Iteration: 800, Loss: 8.466331481933594


 26%|██▌       | 902/3470 [02:56<08:12,  5.22it/s]

Epoch: 0, Iteration: 900, Loss: 6.41180419921875


 29%|██▉       | 1002/3470 [03:16<07:51,  5.23it/s]

Epoch: 0, Iteration: 1000, Loss: 6.61683464050293


 32%|███▏      | 1102/3470 [03:35<07:33,  5.22it/s]

Epoch: 0, Iteration: 1100, Loss: 6.747409820556641


 35%|███▍      | 1202/3470 [03:55<07:14,  5.22it/s]

Epoch: 0, Iteration: 1200, Loss: 4.775228023529053


 38%|███▊      | 1302/3470 [04:14<06:55,  5.21it/s]

Epoch: 0, Iteration: 1300, Loss: 5.036314964294434


 40%|████      | 1402/3470 [04:34<06:38,  5.20it/s]

Epoch: 0, Iteration: 1400, Loss: 5.78169059753418


 43%|████▎     | 1502/3470 [04:53<06:17,  5.22it/s]

Epoch: 0, Iteration: 1500, Loss: 5.161454677581787


 46%|████▌     | 1602/3470 [05:12<05:56,  5.24it/s]

Epoch: 0, Iteration: 1600, Loss: 4.452596664428711


 49%|████▉     | 1702/3470 [05:32<05:39,  5.20it/s]

Epoch: 0, Iteration: 1700, Loss: 5.453591823577881


 52%|█████▏    | 1802/3470 [05:51<05:19,  5.21it/s]

Epoch: 0, Iteration: 1800, Loss: 3.9197356700897217


 55%|█████▍    | 1902/3470 [06:11<05:00,  5.22it/s]

Epoch: 0, Iteration: 1900, Loss: 4.28248405456543


 58%|█████▊    | 2002/3470 [06:30<04:40,  5.23it/s]

Epoch: 0, Iteration: 2000, Loss: 5.168229103088379


 61%|██████    | 2102/3470 [06:50<04:22,  5.22it/s]

Epoch: 0, Iteration: 2100, Loss: 3.1320719718933105


 63%|██████▎   | 2202/3470 [07:09<04:04,  5.18it/s]

Epoch: 0, Iteration: 2200, Loss: 4.894715785980225


 66%|██████▋   | 2302/3470 [07:28<03:44,  5.20it/s]

Epoch: 0, Iteration: 2300, Loss: 4.2428507804870605


 69%|██████▉   | 2402/3470 [07:48<03:25,  5.20it/s]

Epoch: 0, Iteration: 2400, Loss: 4.1129631996154785


 72%|███████▏  | 2502/3470 [08:07<03:04,  5.24it/s]

Epoch: 0, Iteration: 2500, Loss: 4.916940689086914


 75%|███████▍  | 2602/3470 [08:27<02:46,  5.22it/s]

Epoch: 0, Iteration: 2600, Loss: 5.902958393096924


 78%|███████▊  | 2702/3470 [08:46<02:26,  5.24it/s]

Epoch: 0, Iteration: 2700, Loss: 5.421992778778076


 81%|████████  | 2802/3470 [09:06<02:08,  5.20it/s]

Epoch: 0, Iteration: 2800, Loss: 3.4193310737609863


 84%|████████▎ | 2902/3470 [09:25<01:49,  5.20it/s]

Epoch: 0, Iteration: 2900, Loss: 4.30832052230835


 87%|████████▋ | 3002/3470 [09:45<01:29,  5.21it/s]

Epoch: 0, Iteration: 3000, Loss: 4.683661937713623


 89%|████████▉ | 3102/3470 [10:04<01:10,  5.20it/s]

Epoch: 0, Iteration: 3100, Loss: 5.6105122566223145


 92%|█████████▏| 3202/3470 [10:24<00:51,  5.23it/s]

Epoch: 0, Iteration: 3200, Loss: 5.874905109405518


 95%|█████████▌| 3302/3470 [10:43<00:32,  5.22it/s]

Epoch: 0, Iteration: 3300, Loss: 2.9546291828155518


 98%|█████████▊| 3402/3470 [11:02<00:13,  5.21it/s]

Epoch: 0, Iteration: 3400, Loss: 2.3941736221313477


100%|██████████| 3470/3470 [11:16<00:00,  5.13it/s]
  0%|          | 2/3470 [00:00<10:52,  5.31it/s]

Epoch: 1, Iteration: 0, Loss: 2.8929851055145264


  3%|▎         | 102/3470 [00:19<10:44,  5.23it/s]

Epoch: 1, Iteration: 100, Loss: 5.401103496551514


  6%|▌         | 202/3470 [00:39<10:26,  5.22it/s]

Epoch: 1, Iteration: 200, Loss: 7.0125412940979


  9%|▊         | 302/3470 [00:58<10:10,  5.19it/s]

Epoch: 1, Iteration: 300, Loss: 3.633180856704712


 12%|█▏        | 402/3470 [01:18<09:44,  5.25it/s]

Epoch: 1, Iteration: 400, Loss: 3.374939203262329


 14%|█▍        | 502/3470 [01:37<09:26,  5.23it/s]

Epoch: 1, Iteration: 500, Loss: 6.556642055511475


 17%|█▋        | 602/3470 [01:57<09:11,  5.20it/s]

Epoch: 1, Iteration: 600, Loss: 6.159722805023193


 20%|██        | 702/3470 [02:16<08:58,  5.14it/s]

Epoch: 1, Iteration: 700, Loss: 5.892805576324463


 23%|██▎       | 802/3470 [02:36<08:31,  5.21it/s]

Epoch: 1, Iteration: 800, Loss: 6.814853191375732


 26%|██▌       | 902/3470 [02:55<08:13,  5.20it/s]

Epoch: 1, Iteration: 900, Loss: 6.195101261138916


 29%|██▉       | 1002/3470 [03:15<08:04,  5.10it/s]

Epoch: 1, Iteration: 1000, Loss: 5.049078941345215


 32%|███▏      | 1102/3470 [03:34<07:35,  5.20it/s]

Epoch: 1, Iteration: 1100, Loss: 7.34785795211792


 35%|███▍      | 1202/3470 [03:54<07:15,  5.21it/s]

Epoch: 1, Iteration: 1200, Loss: 4.701527118682861


 38%|███▊      | 1302/3470 [04:13<06:57,  5.19it/s]

Epoch: 1, Iteration: 1300, Loss: 4.832247734069824


 40%|████      | 1402/3470 [04:33<06:41,  5.15it/s]

Epoch: 1, Iteration: 1400, Loss: 3.7869534492492676


 43%|████▎     | 1502/3470 [04:52<06:22,  5.14it/s]

Epoch: 1, Iteration: 1500, Loss: 4.177012920379639


 46%|████▌     | 1602/3470 [05:12<05:59,  5.20it/s]

Epoch: 1, Iteration: 1600, Loss: 3.381134510040283


 49%|████▉     | 1702/3470 [05:31<05:39,  5.21it/s]

Epoch: 1, Iteration: 1700, Loss: 4.6269378662109375


 52%|█████▏    | 1802/3470 [05:51<05:22,  5.18it/s]

Epoch: 1, Iteration: 1800, Loss: 3.7507476806640625


 55%|█████▍    | 1902/3470 [06:10<05:00,  5.22it/s]

Epoch: 1, Iteration: 1900, Loss: 3.545541524887085


 58%|█████▊    | 2002/3470 [06:30<04:40,  5.23it/s]

Epoch: 1, Iteration: 2000, Loss: 4.112112522125244


 61%|██████    | 2102/3470 [06:49<04:23,  5.20it/s]

Epoch: 1, Iteration: 2100, Loss: 2.4518654346466064


 63%|██████▎   | 2202/3470 [07:09<04:02,  5.23it/s]

Epoch: 1, Iteration: 2200, Loss: 3.874729633331299


 66%|██████▋   | 2302/3470 [07:28<03:49,  5.09it/s]

Epoch: 1, Iteration: 2300, Loss: 3.923366069793701


 68%|██████▊   | 2356/3470 [07:39<03:36,  5.15it/s]

In [3]:
def test(model, dataloader, device):
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        for data in dataloader:
            data = data.to(device) 

            data.x = data.x.unsqueeze(0)  # Add batch dimension
            data.attention_mask = data.attention_mask.unsqueeze(0) 

            outputs = model.transformer.generate(input_ids=data.x, attention_mask=data.attention_mask, decoder_start_token_id=model.transformer.config.pad_token_id)

            # Convert the tensor outputs to text using the tokenizer
            output_text = [dataset.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs]

            # print input and output
            print(f"Input: {dataset.tokenizer.decode(data.x[0].tolist(), skip_special_tokens=True)}")  # Convert tensor to list
            print(f"Output: {output_text}")

# Usage
dataset_dict = load_dataset('web_nlg', 'webnlg_challenge_2017')['test']
dataset_dict = [sample for sample in dataset_dict if sample['lex']['text']] # filter out samples with empty targets 
dataset = WebNLGDataset(dataset_dict)
dataloader = GeometricDataLoader(dataset, batch_size=5)
test(model, dataloader, device=torch.device('cuda'))


  0%|          | 0/3 [00:00<?, ?it/s]



Input: Aaron S Daggett was awarded the Purple Heart. The Battle of Mine Run was one fought by Aaron S Daggett. Stellendam, Netherlands is the birthplace of Ab Klink. Abdul Rahman Ya'kub was in office while Tuanku Bujang Tuanku Othman was Vice President. Abdul Taib Mahmud belongs to the party of Parti Bumiputera Sarawak.
Output: ["Ab Klink wurde von Abdul Rahman Ya'kub in Amt angewart"]
Input: Abdul Taib Mahmud's successor was Sulaiman Abdul Rahman Taib. Abdulsalami Abubakar ended his career on 1999-05-29. Abdulsalami Abubakar was born in Minna. Abdulsalami Abubakar's birthplace was Niger State. Abner W. Sibal ended his military career January 3, 1965.
Output: ['Abdulsalami Abubakar hat seine militärische Karriere Ende Ende 1999']
Input: Abner W Sibal died in Alexandria, Virginia. Adam Holloway was born in Kent. Adam Holloway's residence is Gravesend. The alma mater of Adenan Satem is the University of Adelaide. Adolf Schärf's place of birth was Mikulov.
Output: ['Adam Holloway wurde in

KeyboardInterrupt: 

In [8]:
!pip install sacrebleu 

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting sacrebleu
  Downloading sacrebleu-2.3.1-py3-none-any.whl (118 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m118.9/118.9 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-2.7.0-py2.py3-none-any.whl (15 kB)
Installing collected packages: portalocker, sacrebleu
Successfully installed portalocker-2.7.0 sacrebleu-2.3.1
[0m

In [33]:
from sacrebleu import corpus_bleu
from random import sample
from tqdm import tqdm
from torch_geometric.data import DataLoader as GeometricDataLoader

# load the WebNLG validation dataset
validation_dataset = load_dataset('web_nlg', 'webnlg_challenge_2017')['test']
validation_dataset = [sample for sample in validation_dataset if sample['lex']['text']]  # filter out samples with empty targets

validation_data = WebNLGDataset(validation_dataset)

# set up the validation data loader
validation_loader = GeometricDataLoader(validation_data, batch_size=1, shuffle=False)

# switch model to evaluation mode
model.eval()

device = 'cuda'

# generate predictions for the validation dataset
predictions = []
references = []
with torch.no_grad():
    for data in tqdm(validation_loader, desc='Validation Progress', leave=False):
        data.x = data.x.to(device).unsqueeze(0)  # add batch dimension
        data.attention_mask = data.attention_mask.to(device).unsqueeze(0)  # add batch dimension

        outputs = model.transformer.generate(input_ids=data.x, attention_mask=data.attention_mask, decoder_start_token_id=model.transformer.config.pad_token_id)
        # convert token IDs to strings
        predicted_texts = dataset.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        target_texts = dataset.tokenizer.batch_decode(data.y, skip_special_tokens=True)
        # append predicted and target texts for BLEU evaluation
        predictions.extend(predicted_texts)
        references.extend(target_texts)


  0%|          | 0/3 [00:00<?, ?it/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
                                                                        

In [35]:
# calculate BLEU scores
#bleu = corpus_bleu(predictions, [references])

multiple_references = []
for i in range(len(validation_dataset)):
    multiple_references.append(validation_dataset[i]['lex']['text'])
    
#bleu = corpus_bleu(predictions, references)
bleu_multiple = corpus_bleu(predictions, multiple_references)

#print(f"BLEU score: {bleu.score}")
print(f"BLEU score with multiple references: {bleu_multiple.score}")

BLEU score with multiple references: 100.00000000000004


In [36]:
len(predictions)

2753

In [37]:
len(validation_dataset)

2753

In [40]:
i=1860
print(validation_dataset[i])
print('------------------------')
print(validation_dataset[i]['original_triple_sets']['otriple_set'])
print('------------------------')
print(predictions[i])
print('------------------------')
print(multiple_references[i])

{'category': 'Astronaut', 'size': 7, 'eid': 'Id970', 'original_triple_sets': {'otriple_set': [['William_Anders | dateOfRet | "1969-09-01"^^xsd:date', 'William_Anders | mission | Apollo_8', 'William_Anders | nationality | United_States', 'William_Anders | birthPlace | British_Hong_Kong', 'Apollo_8 | crew2Up | Buzz_Aldrin', 'Apollo_8 | crewMembers | Frank_Borman', 'Apollo_8 | operator | NASA']]}, 'modified_triple_sets': {'mtriple_set': [['William_Anders | dateOfRetirement | "1969-09-01"', 'William_Anders | was a crew member of | Apollo_8', 'William_Anders | nationality | United_States', 'William_Anders | birthPlace | British_Hong_Kong', 'Apollo_8 | backup pilot | Buzz_Aldrin', 'Apollo_8 | crewMembers | Frank_Borman', 'Apollo_8 | operator | NASA']]}, 'shape': '', 'shape_type': '', 'lex': {'comment': ['good', 'good', 'good'], 'lid': ['Id1', 'Id2', 'Id3'], 'text': ["William Anders was born in British Hong Kong and is a U.S Citizen. William was a member of the Apollo 8 crew (along with Frank

In [None]:
# Print the postprocessed output for the first item in the dataset
dataset[-3]