In [13]:
import json
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from transformers import T5Tokenizer, MT5ForConditionalGeneration, AdamW, get_linear_schedule_with_warmup

# Google's Official Preprocess Codes
# https://github.com/google-research/language/blob/master/language/totto/baseline_preprocessing/preprocess_utils.py
from preprocess_utils import get_highlighted_subtable, linearize_subtable

In [4]:
# Train Config
device=torch.device('cpu')
lr=3e-1
batch_size=8 # 4(max 6) for 't5-large' and make 'accumulation_steps' larger
accumulation_steps=1
epochs=20

# Prompt Config
prompt_len=100
hidden_dim=768

In [22]:
# Pre-Trained T5 Tokenizer
tokenizer=T5Tokenizer.from_pretrained('google/mt5-base')
# Add Special Tokens: Table Tags
tokenizer.add_special_tokens({
    'additional_special_tokens': [
        '|',
        ':',
    ]
})
# Pre-Trained T5 Model
pretrained = MT5ForConditionalGeneration.from_pretrained('google/mt5-base').to(device)
# Resize PLM's Embedding Layer
pretrained.resize_token_embeddings(len(tokenizer))
# Freeze LM
for param in pretrained.parameters():
    param.requires_grad=False

In [23]:
class WebNLGDataset(Dataset):
    
    def __init__(self, tokenizer, raw_path='../webnlg_data/release_v3.0/ru', language='en', data_path='../webnlg_data/preprocessed', split='train'):
        
        if not os.path.exists(f'{data_path}/{split}.json'):
            b = Benchmark()
            files = select_files(raw_path)
            b.fill_benchmark(files)
            b.b2json(data_path, f'{split}.json')
        
        with open(f'{data_path}/{split}.json', 'r') as f:
            dataset = json.load(f)
            entries = dataset['entries']

        full_rela_lst = []
        full_src_lst = []
        full_tgt_lst = []
        for i, entry in enumerate(entries):
            sents = entry[str(i + 1)]['lexicalisations']
            triples = entry[str(i + 1)]['modifiedtripleset']
            
            rela_lst = []
            temp_triples = ''
            for j, tripleset in enumerate(triples):
                subj, rela, obj = tripleset['subject'], tripleset['property'], tripleset['object']
                rela_lst.append(rela)
                temp_triples += ' | '
                temp_triples += '{} : {} : {}'.format(subj, rela, obj)

            for sent in sents:
                if sent["lang"] == language:
                    full_tgt_lst.append(sent["lex"])
                    full_src_lst.append(temp_triples)
                    full_rela_lst.append(rela_lst)
                    if split == 'dev':
                        break
            
        assert len(full_rela_lst) == len(full_src_lst)
        assert len(full_rela_lst) == len(full_tgt_lst)

        self.examples = []
        self.targets = []
        for src, tgt in zip(full_src_lst, full_tgt_lst):
            print(src)
            src = tokenizer.encode(src)
            if len(src)>512:
                # Truncate
                encoded = src[:511] + [tokenizer.eos_token_id]
            self.examples.append(src)
    
            print(tgt)
            tgt = tokenizer.encode(tgt)
            self.targets.append(tgt)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx], self.targets[idx]


In [24]:
class ToTToDataset(Dataset):
    def __init__(self, path_data, tokenizer):
        #
        self.data=[]
        self.label=[]
        
        # Load Dataset
        with open(path_data, 'r') as f:
            dataset=f.read().splitlines()
            f.close()
            
        for _data in dataset:
            data=json.loads(_data)
            
            # Preprocess
            subtable=get_highlighted_subtable(table=data['table'], cell_indices=data['highlighted_cells'], with_heuristic_headers=True)
            cells_linearized=linearize_subtable(
                subtable=subtable,
                table_page_title=data['table_page_title'],
                table_section_title=data['table_section_title']
            )
            
            # Encode
            print(cells_linearized)
            encoded=tokenizer.encode(cells_linearized)
            if len(encoded)>512:
                # Truncate
                encoded=encoded[:511]+[tokenizer.eos_token_id]
            self.data.append(encoded)
            print(data['sentence_annotations'][0]['final_sentence'])
            self.label.append(tokenizer.encode(data['sentence_annotations'][0]['final_sentence']))
            
        print(len(self.data), 'datas')
        print(len(self.label), 'labels')
        
    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]
        
    def __len__(self):
        return len(self.data)

In [27]:
def collate_fn(batch):
    """
    Same Sequence Length on Same Batch
    """
    max_len_data=0
    max_len_label=0
    for data, label in batch:
        if len(data)>max_len_data: max_len_data=len(data)
        if len(label)>max_len_label: max_len_label=len(label)
            
    datas=[]
    attn_masks=[]
    labels=[]
    for data, label in batch:
        data.extend([tokenizer.pad_token_id]*(max_len_data-len(data)))
        datas.append(data)
        
        attn_mask=[int(e!=tokenizer.pad_token_id) for e in data]
        attn_masks.append(attn_mask)
        
        label.extend([-100]*(max_len_label-len(label)))
        labels.append(label)
        
    return torch.tensor(datas), torch.tensor(attn_masks), torch.tensor(labels)

In [28]:
dataset_train=WebNLGDataset(tokenizer=tokenizer)
dataloader_train=DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

 | Aarhus_Airport : cityServed : "Aarhus, Denmark"
The Aarhus is the airport of Aarhus, Denmark.
 | Aarhus_Airport : cityServed : "Aarhus, Denmark"
Aarhus Airport serves the city of Aarhus, Denmark.
 | Aarhus_Airport : cityServed : Aarhus
Aarhus airport serves the city of Aarhus.
 | Aarhus_Airport : elevationAboveTheSeaLevel : 25.0
Aarhus Airport is 25 metres above sea level.
 | Aarhus_Airport : elevationAboveTheSeaLevel : 25.0
Aarhus airport is at an elevation of 25 metres above seal level.
 | Aarhus_Airport : elevationAboveTheSeaLevel : 25.0
Aarhus Airport is 25.0 metres above the sea level.
 | Aarhus_Airport : location : Tirstrup
Aarhus Airport is located in Tirstrup.
 | Aarhus_Airport : location : Tirstrup
The location of Aarhus Airport is Tirstrup.
 | Aarhus_Airport : operatingOrganisation : "Aarhus Lufthavn A/S"
Aarhus Airport is operated by Aarhus Lufthavn A/S.
 | Aarhus_Airport : operatingOrganisation : "Aarhus Lufthavn A/S"
Aarhus Lufthavn A/S is the operation organisation of 

 | Ashgabat_International_Airport : runwayName : "12R/30L"
Ashgabat International Airport has the runway name 12R/30L.
 | Ashgabat_International_Airport : runwayName : "12R/30L"
Ashgabat International Airport has a runway named 12R/30L.
 | Athens : country : Greece
Athens is located in the country of Greece.
 | Athens : country : Greece
Athens is a city in Greece.
 | Athens_International_Airport : cityServed : Athens
The Athens International Airport serves the city of Athens.
 | Athens_International_Airport : cityServed : Athens
Athens International Airport serves the city of Athens.
 | Athens_International_Airport : elevationAboveTheSeaLevelInMetres : 94
Athens International Airport is 94 metres above sea level.
 | Athens_International_Airport : elevationAboveTheSeaLevelInMetres : 94
The elevation above the sea level (in metres) of Athens International Airport is 94.
 | Athens_International_Airport : runwayLength : 3800.0
The runway length at Athens International Airport is 3,800.
 | 

 | Apollo_12 : backupPilot : Alfred_Worden
Alfred Worden was a backup pilot of Apollo 12.
 | Apollo_12 : backupPilot : Alfred_Worden
Apollo 12's backup pilot was Alfred Worden.
 | Apollo_12 : backupPilot : Alfred_Worden
Alfred Worden was the backup pilot of Apollo 12.
 | Apollo_12 : commander : David_Scott
Apollo 12 was commanded by David Scott.
 | Apollo_12 : commander : David_Scott
David Scott was the commander of Apollo 12.
 | Apollo_12 : operator : NASA
The operator of Apollo 12 was NASA.
 | Apollo_12 : operator : NASA
The Apollo 12 operator is NASA.
 | Apollo_12 : operator : NASA
NASA operated Apollo 12.
 | Apollo_12 : operator : NASA
The Apollo 12 mission was operated by NASA.
 | Apollo_12 : operator : NASA
Apollo 12 is operated by NASA.
 | Apollo_8 : commander : Frank_Borman
Frank Borman was an Apollo 8 Commander.
 | Apollo_8 : commander : Frank_Borman
Frank Borman was the Apollo 8 commander.
 | Apollo_8 : commander : Frank_Borman
The Apollo 8 commander was Frank Borman.
 | Apol

 | Adisham_Hall : address : "St. Benedict's Monastery, Adisham, Haputhale, Sri Lanka"
The address of Adisham Hall is, St Benedict's Monastery, Adisham, Haputhale, Sri Lanka.
 | Adisham_Hall : architecturalStyle : "Tudor and Jacabian"
Adisham Halls Architectural style`is Tudor and Jacabian.
 | Adisham_Hall : architecturalStyle : "Tudor and Jacabian"
The architecture style of Adisham Hall is Tudor and Jacabian.
 | Adisham_Hall : architecturalStyle : "Tudor and Jacabian"
Adisham Hall is in the architectural style of "Tudor and Jacobean".
 | Adisham_Hall : architecturalStyle : Tudor_Revival_architecture
Adisham Hall has the Tudor Revival architectural style.
 | Adisham_Hall : architecturalStyle : Tudor_Revival_architecture
Adisham Hall has the architectural style 'Tudor Revival'.
 | Adisham_Hall : architecturalStyle : Tudor_Revival_architecture
The Adisham Hall's style of architecture is Tudor Revival.
 | Adisham_Hall : buildingStartDate : "1927"
Construction of Adisham Hall began in 1927.

 | (19255)_1994_VK8 : escapeVelocity : 0.0925 (kilometrePerSeconds)
(19255) 1994 VK8 has an escape velocity of 0.0925 km/s.
 | (19255)_1994_VK8 : mass : 5.6 (kilograms)
19255 1994 VK8 has a mass of 5.6 kgs.
 | (19255)_1994_VK8 : mass : 5.6 (kilograms)
(19255) 1994 VK8 has a mass of 5.6 kilograms.
 | (19255)_1994_VK8 : mass : 5.6 (kilograms)
19255 1994 VK8 has a mass of 5.6kg.
 | (19255)_1994_VK8 : orbitalPeriod : 8788850000.0
(19255) 1994 VK8 has an orbital period of 8788850000.0.
 | (19255)_1994_VK8 : periapsis : 6155910000000.0
(19255) 1994 VK8 has a periapsis of 6155910000000.0.
 | (19255)_1994_VK8 : periapsis : 6155910000000.0
The periapsis of (19255) 1994 VK8 is 6155910000000.0.
 | (19255)_1994_VK8 : temperature : 43.0 (kelvins)
(19255) 1994 VK8 has a temperature of 43.0 kelvins.
 | (19255)_1994_VK8 : temperature : 43.0 (kelvins)
The temperature of 19255 1994 VK8 is 43 kelvins.
 | (29075)_1950_DA : density : 3.5 (kilograms)
(29075) 1950 DA has a density of 3.5 kilograms.
 | (29075

 | 10_Hygiea : apoapsis : 523951582.33968 (kilometres)
10 Hygiea has an apoapsis of 523951582.33968 (kilometres).
 | 10_Hygiea : apoapsis : 523951582.33968 (kilometres)
The asteroid called 10 Hygiea, has an apoapsis of 523951582.33968 kilometres.
 | 10_Hygiea : discoverer : Annibale_de_Gasparis
Annibale de Gasparis discovered 10 Hygiea.
 | 10_Hygiea : epoch : 2015-06-27
10 Hygiea has an epoch date of June 27th 2015.
 | 10_Hygiea : epoch : 2015-06-27
The asteroid called 10 Hygiea has an epoch date of 27th June 2015.
 | 10_Hygiea : escapeVelocity : 0.21 (kilometrePerSeconds)
10 Hygiea has an escape velocity of 0.21 kilometres per second.
 | 10_Hygiea : escapeVelocity : 0.21 (kilometrePerSeconds)
10 Hygiea has an escape velocity of 0.21 km per secs.
 | 10_Hygiea : escapeVelocity : 0.21 (kilometrePerSeconds)
The escape velocity of 10 Hygiea is 0.21 km per sec.
 | 10_Hygiea : formerName : "A900 GA"
The former name of 10 Hygiea was A900 GA.
 | 10_Hygiea : formerName : "A900 GA"
10 Hygiea was

Olive oil is an ingredient in Ajoblanco.
 | Ajoblanco : ingredient : Water
Ajoblanco contains water.
 | Ajoblanco : ingredient : Water
Water is an ingredient in Ajoblanco.
 | Ajoblanco : region : Andalusia
Ajoblanco is a food found in Andalusia.
 | Ajoblanco : region : Andalusia
Ajoblanco is from the Andalusia region.
 | Ajoblanco : region : Andalusia
Ajoblanco is from Andalusia.
 | Almond : division : Flowering_plant
Almond is classed as a flowering plant.
 | Almond : division : Flowering_plant
Almonds are in the division of flowering plants.
 | Almond : family : Rosaceae
Almond is part of the Rosaceae family.
 | Almond : family : Rosaceae
Almonds are from the Rosaceae family.
 | Almond : order : Rosales
Almond is one of the members of the Rosales order.
 | Almond : order : Rosids
Almond is part of the order of Rosids.
 | Almond : order : Rosids
Almonds are from the order Rosids.
 | Amatriciana_sauce : country : Italy
Italy is the country Amatriciana sauce comes from.
 | Amatriciana_s

one of the ingredients of Batchoy is Pork.
 | Batchoy : mainIngredient : "noodles, pork organs, vegetables, chicken, shrimp, beef"
The main ingredients of batchoy are noodles, pork organs, vegetables, chicken, shrimp, and beef.
 | Batchoy : mainIngredient : "noodles, pork organs, vegetables, chicken, shrimp, beef"
The main ingredients of Batchoy are noodles, pork organs, vegetables, chicken, shrimp and beef.
 | Batchoy : mainIngredient : "noodles, pork organs, vegetables, chicken, shrimp, beef"
The main ingredients of Batchoy are noodles, pork organs, vegetables, chicken, shrimp, beef.
 | Batchoy : region : La_Paz,_Iloilo_City
Batchoy is a food found in La Paz, Iloilo City.
 | Batchoy : region : La_Paz,_Iloilo_City
Batchoy originated from the region of La Paz, Iloilo City.
 | Beef_kway_teow : country : "Singapore and Indonesia"
Beef kway teow is found in the countries of Indonesia and Singapore.
 | Beef_kway_teow : country : "Singapore and Indonesia"
Beef kway teow is a dish commonly f

 | 14th_New_Jersey_Volunteer_Infantry_Monument : foundingDate : 1907-07-11
The 14th New Jersey Volunteer Infantry Monument was started on 11th July 1907.
 | 14th_New_Jersey_Volunteer_Infantry_Monument : foundingDate : 1907-07-11
The 14th New Jersey Volunteer Infantry Monument's founding date was on 1907-07-11.
 | 14th_New_Jersey_Volunteer_Infantry_Monument : owner : National_Park_Service
The 14th New Jersey Volunteer Infantry Monument is owned by the National Park Service.
 | 14th_New_Jersey_Volunteer_Infantry_Monument : owner : National_Park_Service
The National Park Service looks after the 14th New Jersey Volunteer Infantry Monument.
 | 14th_New_Jersey_Volunteer_Infantry_Monument : owner : National_Park_Service
14th New Jersey Volunteer Infantry Monument is owned by the National Park Service.
 | 14th_New_Jersey_Volunteer_Infantry_Monument : owner : National_Park_Service
The National Park Service is the owner of the 14th New Jersey Volunteer Infantry Monument.
 | 14th_New_Jersey_Volun

 | AEK_Athens_F.C. : ground : Olympic_Stadium_(Athens)
The Olympic Stadium (Athens) is the home ground of AEK Athens FC.
 | AEK_Athens_F.C. : ground : Olympic_Stadium_(Athens)
The ground for AEK Athens FC is the Olympic Stadium (Athens).
 | AEK_Athens_F.C. : league : Superleague_Greece
AEK Athens FC compete in the Superleague Greece.
 | AEK_Athens_F.C. : league : Superleague_Greece
AEK Athens F.C. is in the Superleague of Greece.
 | AEK_Athens_F.C. : league : Superleague_Greece
AEK Athens F.C. play in the Superleague Greece.
 | AEK_Athens_F.C. : manager : Gus_Poyet
AEK Athens FC had the manager Gus Poyet.
 | AEK_Athens_F.C. : manager : Gus_Poyet
AEK Athens F.C. manager is Gus Poyet.
 | AEK_Athens_F.C. : manager : Gus_Poyet
AEK Athens are managed by Gus Poyet.
 | AEK_Athens_F.C. : numberOfMembers : 69618
AEK Athens F.C. has 69618 members.
 | AEK_Athens_F.C. : season : 2014
AEK Athens FC played in the 2014 season.
 | AFC_Ajax : fullName : "Amsterdamsche Football Club Ajax"
The full name 

Barrow A.F.C. team won the championship before in the National League North.
 | National_League_North : champions : Barrow_A.F.C.
Barrow A.F.C. are the National League North Champions.
 | Olympic_Stadium_(Athens) : location : Athens
The Olympic Stadium (Athens) is located in the city of Athens.
 | Olympic_Stadium_(Athens) : location : Athens
The Olympic Stadium (Athens) is located in Athens.
 | Olympic_Stadium_(Athens) : location : Athens
The Olympic Stadium is in Athens.
 | Paulo_Sousa : club : ACF_Fiorentina
Paulo Sousa plays for ACF Fiorentina.
 | Paulo_Sousa : club : Inter_Milan
Paulo Sousa once played for Inter Milan.
 | Paulo_Sousa : club : Inter_Milan
Paulo Sousa is attached to the club Inter Milan.
 | Paulo_Sousa : club : Inter_Milan
Paulo Sousa plays at the Inter Milan club.
 | Paulo_Sousa : club : Juventus_F.C.
Paulo Sousa played for Juventus FC.
 | Paulo_Sousa : club : Juventus_F.C.
Paulo Sousa club is Juventus.
 | Paulo_Sousa : club : Juventus_F.C.
Paulo Sousa plays for Juv

 | Alderney_Airport : cityServed : Alderney | Alderney : capital : Saint_Anne,_Alderney
The capital of Alderney is Saint Anne and is served by the Alderney Airport.
 | Alderney_Airport : cityServed : Alderney | Alderney : capital : Saint_Anne,_Alderney
Alderney Airport serves Alderney, where the capital is Saint Anne.
 | Alderney_Airport : cityServed : Alderney | Alderney : language : English_language
Alderney Airport serves the English speaking city of Alderney.
 | Alderney_Airport : cityServed : Alderney | Alderney : language : English_language
The English language is spoken in Alderney which is served by Alderney airport.
 | Alderney_Airport : cityServed : Alderney | Alderney : leader : Elizabeth_II
The Alderney Airport serves Alderney whose leader is Elizabeth II.
 | Alderney_Airport : cityServed : Alderney | Alderney : leader : Elizabeth_II
Alderney Airport serves the city of Alderney, whose leader is Elizabeth II.
 | Alderney_Airport : cityServed : Alderney | Alderney : leader : 

Alan Bean was part of Apollo 12 which was operated by NASA.
 | Alan_Bean : mission : Apollo_12 | Apollo_12 : operator : NASA
Alan Bean was a crew member of the NASA operated Apollo 12.
 | Alan_Bean : mission : Apollo_12 | Apollo_12 : operator : NASA
Alan Bean was a crew member of NASA's Apollo 12.
 | Alan_Bean : mission : Apollo_12 | Apollo_12 : operator : NASA
Alan Bean served as a crew member of Apollo 12 which is operated by NASA.
 | Alan_Bean : mission : Apollo_12 | Apollo_12 : operator : NASA
Alan Bean was a crew member of Apollo 12, which was operated by nasa.
 | Alan_Shepard : almaMater : "NWC, M.A. 1957" | Alan_Shepard : award : Distinguished_Service_Medal_(United_States_Navy)
Alan Shepard, who was awarded the Distinguished Service Medal by the United States Navy, went to school at NWC and graduated with an MA in 1957.
 | Alan_Shepard : almaMater : "NWC, M.A. 1957" | Alan_Shepard : award : Distinguished_Service_Medal_(United_States_Navy)
1957 NWC M.A. graduate Alan Shepard is a

The construction of 250 Delaware Avenue, which has 12 floors, began in January 2014.
 | 250_Delaware_Avenue : floorArea : 30843.8 (square metres) | 250_Delaware_Avenue : floorCount : 12
250 Delaware Avenue has a floor area of 30853.8 square metres and has 12 floors.
 | 250_Delaware_Avenue : floorArea : 30843.8 (square metres) | 250_Delaware_Avenue : floorCount : 12
250 Delaware Avenue has 12 floors and a total floor area of 30843.8 square metres.
 | 250_Delaware_Avenue : floorArea : 30843.8 (square metres) | 250_Delaware_Avenue : floorCount : 12
250 Delaware Avenue has a floor area of 30843.8 square metres and a floor count of 12.
 | 250_Delaware_Avenue : location : Buffalo,_New_York | 250_Delaware_Avenue : architecturalStyle : Postmodern_architecture
250 Delaware Avenue is in Buffalo, New York, and it has the Postmodern style of architecture.
 | 250_Delaware_Avenue : location : Buffalo,_New_York | 250_Delaware_Avenue : architecturalStyle : Postmodern_architecture
250 Delaware Ave. in 

 | Ampara_Hospital : country : Sri_Lanka | Sri_Lanka : leader : Ranil_Wickremesinghe
Ampara Hospital is in Sri Lanka, whose leader is Ranil Wickremesinghe.
 | Ampara_Hospital : country : Sri_Lanka | Sri_Lanka : leader : Ranil_Wickremesinghe
Sri Lanka's leader is Ranil Wickremesinghe and it is home to the Ampara Hospital.
 | Ampara_Hospital : country : Sri_Lanka | Sri_Lanka : leader : Ranil_Wickremesinghe
Ampara Hospital is in Sri Lanka whose leader is Ranil Wickremesinghe.
 | Asher_and_Mary_Isabelle_Richardson_House : architect : Alfred_Giles_(architect) | Alfred_Giles_(architect) : birthPlace : England
The architect of Asher and Mary Isabelle Richardson House was Alfred Giles, who was born in England.
 | Asher_and_Mary_Isabelle_Richardson_House : architect : Alfred_Giles_(architect) | Alfred_Giles_(architect) : birthPlace : England
Asher and Mary Isabelle Richardson House was designed by architect Alfred Giles, who was born in England.
 | Asher_and_Mary_Isabelle_Richardson_House : arc

101 Helena was discovered by James Craig Watson, who comes from Canada.
 | 101_Helena : discoverer : James_Craig_Watson | James_Craig_Watson : stateOfOrigin : Canada
101 Helena was discovered by James Craig Watson from Canada.
 | 101_Helena : escapeVelocity : 0.0348 (kilometrePerSeconds) | 101_Helena : apoapsis : 441092000.0 (kilometres)
101 Helena has an escape velocity of 0.0348 km per second and an apoapsis of 441092000.0 km.
 | 101_Helena : escapeVelocity : 0.0348 (kilometrePerSeconds) | 101_Helena : apoapsis : 441092000.0 (kilometres)
101 Helena has an escape velocity of 0.0348 km/s and an apoapsis of 441092000.0 km.
 | 101_Helena : escapeVelocity : 0.0348 (kilometrePerSeconds) | 101_Helena : apoapsis : 441092000.0 (kilometres)
101 Helena has an apoapsis of 441092000.0 km and an escape velocity of 0.0348 km per sec.
 | 101_Helena : mass : 3.0 (kilograms) | 101_Helena : apoapsis : 441092000.0 (kilometres)
101 Helena has a mass of 3.0 kg and an apoapsis of 441092000.0 km.
 | 101_Hel

The comic character, Balder, was created by Jack Kirby, an American.
 | Balder_(comicsCharacter) : creator : Jack_Kirby | Jack_Kirby : nationality : Americans
The American Jack Kirby created the comic character Balder.
 | Ballistic_(comicsCharacter) : alternativeName : "Kelvin Mao" | Ballistic_(comicsCharacter) : creator : "Michael Manley"
Michael Manley created the comic character, Ballistic, who has the alternative name, Kelvin Mao.
 | Ballistic_(comicsCharacter) : alternativeName : "Kelvin Mao" | Ballistic_(comicsCharacter) : creator : "Michael Manley"
The comic character, Ballistic, has the alternative name, Kelvin Mao, and was created by Michael Manley.
 | Ballistic_(comicsCharacter) : creator : Doug_Moench | Ballistic_(comicsCharacter) : creator : "Michael Manley"
The creators of the comic character Ballistic were Michael Manley and Doug Moench.
 | Ballistic_(comicsCharacter) : creator : Doug_Moench | Ballistic_(comicsCharacter) : creator : "Michael Manley"
Doug Moench and Michae

Tomatoes, of the order Solanales, are found in Arrabiata sauce.
 | Arrabbiata_sauce : region : Rome | Arrabbiata_sauce : ingredient : Olive_oil
Olive oil is an ingredient used in the preparation of Arrabbiata sauce, a dish from Rome.
 | Arròs_negre : country : Spain | Arròs_negre : ingredient : Cephalopod_ink
Cephalopod ink is an ingredient in the dish Arros negre which is from Spain.
 | Arròs_negre : country : Spain | Arròs_negre : ingredient : Cubanelle
Arròs negre is a traditional dish from Spain, Cubanelle is an ingredient.
 | Arròs_negre : country : Spain | Arròs_negre : ingredient : Cubanelle
Cubanelle is an ingredient in the Spanish dish of Arros negre.
 | Arròs_negre : country : Spain | Arròs_negre : ingredient : Cubanelle
Arros negre is a traditional Spanish dish made with Cubanelle.
 | Arròs_negre : country : Spain | Arròs_negre : ingredient : Cuttlefish
Arròs negre is from Spain and it uses cuttlefish as an ingredient.
 | Arròs_negre : country : Spain | Arròs_negre : ingredi

Baked Alaska is thought to have originated in the United States, France or China and has meringue as an ingredient.
 | Baked_Alaska : country : "France, United States or China" | Baked_Alaska : ingredient : Meringue
Baked Alaska with meringue is from France, the US and China.
 | Baked_Alaska : country : "France, United States or China" | Baked_Alaska : ingredient : Sponge_cake
Sponge cake is an ingredient in baked alaska which is said to come from either France, United States or China.
 | Baked_Alaska : country : "France, United States or China" | Baked_Alaska : ingredient : Sponge_cake
France, United States and China all claim to have invented Baked Alaska, which is made from sponge cake.
 | Baked_Alaska : country : China | Baked_Alaska : ingredient : Ice_cream
Ice cream is an ingredient of Baked Alaska and is a dish in China.
 | Baked_Alaska : country : China | Baked_Alaska : ingredient : Ice_cream
Ice cream is an ingredient in Baked Alaska, which comes from China.
 | Baked_Alaska : 

Binignit is a dessert and a cookie is also a dessert.
 | Binignit : course : Dessert | Dessert : dishVariation : Cookie
Binignit should be served as the dessert course, as should cookies.
 | Binignit : course : Dessert | Dessert : dishVariation : Sandesh_(confectionery)
Binignit and sandesh are both dessert.
 | Binignit : course : Dessert | Dessert : dishVariation : Sandesh_(confectionery)
Sandesh (confectionery) is a dish that can be served as a dessert, as is Binignit.
 | Binignit : course : Dessert | Dessert : dishVariation : Sandesh_(confectionery)
Two types of deserts that can be served are Binignit and Sandesh.
 | Binignit : ingredient : Sweet_potato | Binignit : country : Philippines
Sweet potato is an ingredient in Binignit which comes from the Philippines.
 | Binignit : ingredient : Sweet_potato | Binignit : country : Philippines
Sweet potatoes are in binignit recipes, it is a dish of the Phillippines.
 | Binignit : ingredient : Sweet_potato | Binignit : country : Philippines


AFC Blackpool have 1500 members and their ground is located in Blackpool.
 | A.F.C._Blackpool : ground : Blackpool | A.F.C._Blackpool : numberOfMembers : 1500
AFC Blackpool is located in Blackpool and has 1500 members.
 | A.F.C._Blackpool : manager : Stuart_Parker_(footballer) | Stuart_Parker_(footballer) : club : Blackpool_F.C.
Start Parker is manager at A.F.C. Blackpool and plays for Blackpool F.C..
 | A.F.C._Blackpool : manager : Stuart_Parker_(footballer) | Stuart_Parker_(footballer) : club : Blackpool_F.C.
Stuart Parker, who once managed AFC Blackpool, plays for Blackpool F.C.
 | A.F.C._Blackpool : manager : Stuart_Parker_(footballer) | Stuart_Parker_(footballer) : club : Blackpool_F.C.
The manager of A.F.C. Blackpool is Stuart Parker (footballer) who plays for Blackpool F.C.
 | A.F.C._Blackpool : manager : Stuart_Parker_(footballer) | Stuart_Parker_(footballer) : club : Bury_F.C.
The manager of A.F.C. Blackpool is Stuart Parker, who played football for Bury FC.
 | A.F.C._Blackpoo

Agremiacao Sportiva Arapiraquense have 17000 members and were in Campeonato Brasileiro Serie C in 2015.
 | Akron_Summit_Assault : chairman : Dave_Laughlin | Akron_Summit_Assault : numberOfMembers : 3000
Akron Summit Assault has 3000 members and their chairman is Dave Laughlin.
 | Akron_Summit_Assault : chairman : Dave_Laughlin | Akron_Summit_Assault : numberOfMembers : 3000
Akron Summit Assault has got 3000 members and the chairman is Dave Laughlin.
 | Akron_Summit_Assault : chairman : Dave_Laughlin | Akron_Summit_Assault : numberOfMembers : 3000
Akrons Summit Assault has 3000 members and the chairman is Dave Laughlin.
 | Akron_Summit_Assault : ground : St._Vincent–St._Mary_High_School | St._Vincent–St._Mary_High_School : city : Akron,_Ohio
The ground of Akron Summit Assault is in St Vincent St Mary High School of Akron, Ohio.
 | Akron_Summit_Assault : ground : St._Vincent–St._Mary_High_School | St._Vincent–St._Mary_High_School : city : Akron,_Ohio
Akron Summit Assault's ground is St. 

 | Afonso_Pena_International_Airport : elevationAboveTheSeaLevel : 911.0 | Afonso_Pena_International_Airport : operatingOrganisation : Infraero | Afonso_Pena_International_Airport : location : São_José_dos_Pinhais
Afonso Pena International airport is located in Sao Jose dos Pinhais and is operated by Infraero. It is 911 metres above sea level.
 | Afonso_Pena_International_Airport : elevationAboveTheSeaLevel : 911.0 | Afonso_Pena_International_Airport : operatingOrganisation : Infraero | Afonso_Pena_International_Airport : location : São_José_dos_Pinhais
Afonso Pena International Airport, located in Sao Jose dos Pinhais, is 911 meters above sea level and is operated by Infraero.
 | Afonso_Pena_International_Airport : elevationAboveTheSeaLevel : 911.0 | Afonso_Pena_International_Airport : operatingOrganisation : Infraero | Afonso_Pena_International_Airport : location : São_José_dos_Pinhais
Afonso Pena International airport is located 911 metres above sea level in Sao Jose dos Pinhais and

 | Angola_International_Airport : location : Ícolo_e_Bengo | Angola_International_Airport : runwayLength : 4000.0 | Angola_International_Airport : elevationAboveTheSeaLevelInMetres : 159
Angola International Airport is in Icolo e Bengo and is 159 metres above sea level with a runway that's 4,000 feet long.
 | Angola_International_Airport : location : Ícolo_e_Bengo | Ícolo_e_Bengo : country : Angola | Angola_International_Airport : elevationAboveTheSeaLevelInMetres : 159
At 159 meters above sea level, Angola International airport is located in Icolo e Bengo, in Angola.
 | Angola_International_Airport : location : Ícolo_e_Bengo | Ícolo_e_Bengo : country : Angola | Angola_International_Airport : elevationAboveTheSeaLevelInMetres : 159
Angola International Airport is in Icolo e Bengo, Angola and is 159 metres above sea level.
 | Angola_International_Airport : location : Ícolo_e_Bengo | Ícolo_e_Bengo : country : Angola | Angola_International_Airport : elevationAboveTheSeaLevelInMetres : 159

Alan Shepard (died on 1998-07-21 in California) graduated from NWC in 1957.
 | Alan_Shepard : almaMater : "NWC, M.A. 1957" | Alan_Shepard : deathPlace : California | Alan_Shepard : deathDate : "1998-07-21"
Alan Shepard, who attended NWC, M.A. in 1957, died in California on July 21, 1998.
 | Alan_Shepard : birthPlace : New_Hampshire | Alan_Shepard : selectedByNasa : 1959 | Alan_Shepard : birthDate : "1923-11-18"
Alan Shepard who was born on Nov 18, 1923 in New Hampshire was selected by NASA in 1959.
 | Alan_Shepard : birthPlace : New_Hampshire | Alan_Shepard : selectedByNasa : 1959 | Alan_Shepard : birthDate : "1923-11-18"
Alan Shepard was born on November 18th, 1923 in New Hampshire and was selected by NASA in 1959.
 | Alan_Shepard : birthPlace : New_Hampshire | Alan_Shepard : selectedByNasa : 1959 | Alan_Shepard : birthDate : "1923-11-18"
Alan Shepard was born in New Hampshire, chosen by NASA in 1959 and was born on the 18th of November 1923.
 | Alan_Shepard : birthPlace : New_Hampshi

108 St Georges Terrace has a floor area of 39599.0 square metres and is located in Perth, Australia.
 | 108_St_Georges_Terrace : location : Perth | Perth : country : Australia | 108_St_Georges_Terrace : floorCount : 50
108 St. Georges Terrace boasts 50 floors and is located in Perth, Australia.
 | 108_St_Georges_Terrace : location : Perth | Perth : country : Australia | 108_St_Georges_Terrace : floorCount : 50
108 St. Georges Terrace in Perth, Australia has 50 floors.
 | 108_St_Georges_Terrace : location : Perth | Perth : country : Australia | 108_St_Georges_Terrace : floorCount : 50
The 108 St. Georges Terrace in Perth, Australia, has a floor count of 50.
 | 11_Diagonal_Street : floorCount : 20 | 11_Diagonal_Street : architect : Helmut_Jahn | 11_Diagonal_Street : completionDate : 1983
11 Diagonal Street, with 20 floors, was designed by Helmut Jahn and completed in 1983.
 | 11_Diagonal_Street : floorCount : 20 | 11_Diagonal_Street : architect : Helmut_Jahn | 11_Diagonal_Street : comple

Adare Manor was completed in 1862, opened in 1700 and was designed by James Pain and George Richard Pain.
 | Adare_Manor : architect : "James Pain and George Richard Pain," | Adare_Manor : completionDate : 1862 | Adare_Manor : owner : J._P._McManus
James Pain and George Richard Pain are the architects of Adare Manor, which was completed in 1862, and is owned by JP McManus.
 | Adare_Manor : architect : "James Pain and George Richard Pain," | Adare_Manor : completionDate : 1862 | Adare_Manor : owner : J._P._McManus
The architects for Adare Manor completed in 1862 were James Pain and George Richard Pain. The Manor is owned by JP McManus.
 | Adare_Manor : architect : "James Pain and George Richard Pain," | Adare_Manor : completionDate : 1862 | Adare_Manor : owner : J._P._McManus
Adare Manor, owned by J. P. McManus, was designed by James and George Richard Pain and its construction was completed in 1862.
 | Adare_Manor : architect : Lewis_Nockalls_Cottingham | Adare_Manor : completionDate :

 | (19255)_1994_VK8 : density : 2.0 (gramPerCubicCentimetres) | (19255)_1994_VK8 : escapeVelocity : 0.0925 (kilometrePerSeconds) | (19255)_1994_VK8 : apoapsis : 6603633000.0 (kilometres)
The celestial body known as (19255) 1994 VK8 has a density of 2 grams per cubic centimetres,an escape velocity of 0.0925 km/s and its apoapsis is 6603633000.0 kilometres.
 | (19255)_1994_VK8 : density : 2.0 (gramPerCubicCentimetres) | (19255)_1994_VK8 : escapeVelocity : 0.0925 (kilometrePerSeconds) | (19255)_1994_VK8 : apoapsis : 6603633000.0 (kilometres)
(19255) 1994 VK8 has a density of 2.0 grams per cubic centimetre, an escape velocity of 0.0925 km/s, and an apoapsis of 6603633000.0 kilometres.
 | (19255)_1994_VK8 : density : 2.0 (gramPerCubicCentimetres) | (19255)_1994_VK8 : escapeVelocity : 0.0925 (kilometrePerSeconds) | (19255)_1994_VK8 : apoapsis : 6603633000.0 (kilometres)
(19255) 1994 VK8 has a density of 2.0 grams per cu cm, an escape velocity of 0.0925 km per sec and an apoapsis of 660363300

KeyboardInterrupt: 

In [None]:
class PromptTuning(nn.Module):
    """
    """
    def __init__(self, pretrained_config, prompt_len=20, hidden_dim=256):
        super().__init__()
        
        # Config of Pre-Trained LM
        self.pretrained_config=pretrained_config
        
        # torch.tensor([0, 1, 2, .. , prompt_len-1])
        self.pre_prompt=torch.arange(prompt_len)
        # Embedding
        self.embd=nn.Embedding(num_embeddings=prompt_len, embedding_dim=pretrained_config.d_model)
        # Reparameterization
        self.reparam=nn.Sequential(
            nn.Linear(pretrained_config.d_model, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, pretrained_config.d_model)
        )
        
    def forward(self, batch_size, device):
        # Shape: batch_size, prompt_len
        prompt=self.pre_prompt.unsqueeze(0).expand(batch_size, -1).to(device)
        # Shape: batch_size, prompt_len, d_model
        prompt=self.embd(prompt)
        # Shape: batch_size, prompt_len, d_model
        prompt=self.reparam(prompt)
        
        return prompt

In [29]:
# Model: Prompt Tuning
model=PromptTuning(pretrained_config=pretrained.config, prompt_len=prompt_len, hidden_dim=hidden_dim)

# Optim, Scheduler
optimizer=AdamW(model.parameters(), lr=lr)
# NO Warm-Up
scheduler=get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=int(epochs*len(dataset_train)/(accumulation_steps*batch_size))
)

# TensorBoard: Logging
writer=SummaryWriter()
step_global=0

for epoch in range(epochs):
    # Train Phase
    model.train()
    model.to(device)
    
    loss_train=0
    optimizer.zero_grad()
    
    for step, (data, attn_mask, label) in enumerate(dataloader_train):
        data=data.to(device)
        attn_mask=attn_mask.to(device)
        label=label.to(device)
        
        prompt=model(batch_size=data.shape[0] , device=device)
        outputs=pretrained(input_ids=data, attention_mask=attn_mask, labels=label, prompt=prompt)
        
        loss=outputs[0]/accumulation_steps
        loss.backward()
        
        loss_train+=loss.item()
        
        if (step+1)%accumulation_steps==0:
            step_global+=1
            
            # TensorBoard
            writer.add_scalar(
                f'loss_train/MT5-base_Prompt-Tuning_prompt-len{prompt_len}_hidden-dim{hidden_dim}_lr{lr}_batch{int(accumulation_steps*batch_size)}_epoch{epochs}',
                loss_train,
                step_global
            )
            # Console
            if step_global%1000==0:
                print(f'epoch {epoch+1} step {step_global} loss_train {loss_train:.4f}')
            # Set Loss to 0
            loss_train=0
            
            optimizer.step()
            scheduler.step()
            
            optimizer.zero_grad()
            
    # Save Model
    model.to(torch.device('cpu'))
    torch.save(model, f'../model/MT5-base_Prompt-Tuning_prompt-len{prompt_len}_hidden-dim{hidden_dim}_lr{lr}_batch{int(accumulation_steps*batch_size)}_epoch{epoch+1}of{epochs}.pt')

NameError: name 'PromptTuning' is not defined