In [1]:
import torch
import torch.nn as nn
from tokenizers import AddedToken
from transformers import T5TokenizerFast, T5ForConditionalGeneration
from transformers import BertModel, T5ForConditionalGeneration, T5Tokenizer

## Need to replace BERT as GNN module

In [2]:
# Define the BERT-T5 model
class BertT5(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.t5 = T5ForConditionalGeneration.from_pretrained('t5-small')

    def forward(self, input_ids, attention_mask,
                        decoder_input_ids,
                        decoder_attention_mask,
                        labels=None):
        
        bert_outputs = self.bert(input_ids=input_ids,
                                 attention_mask=attention_mask)
        
        print((bert_outputs.last_hidden_state).size())
        
        t5_outputs = self.t5(decoder_input_ids=decoder_input_ids,
                             decoder_attention_mask=decoder_attention_mask,
                             encoder_outputs=bert_outputs.last_hidden_state,
#                              encoder_hidden_states=bert_outputs.last_hidden_state
                            )
        
        if labels is not None:
            loss = t5_outputs.loss(labels=labels)
            return loss
        else:
            return t5_outputs.logits

# Instantiate the BERT-T5 model and move it to the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertT5().to(device)

# Load the data and tokenizer
# tokenizer = T5Tokenizer.from_pretrained('t5-small')
text2sql_tokenizer = T5TokenizerFast.from_pretrained(
    't5-small',
    add_prefix_space = True
)
# TODO: Load and preprocess the data

# Define the optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
from load_dataset import Text2SQLDataset
from torch.utils.data import DataLoader

train_filepath = "/Users/aishwarya/Downloads/spring23/cs685-NLP/project/data/resdsql_pre/preprocessed_dataset_test.json" 
batch_size = 2 #'input batch size.')

train_dataset = Text2SQLDataset(
        dir_ = train_filepath,
        mode = "train")

train_dataloder = DataLoader(
        train_dataset, 
        batch_size = batch_size, 
        shuffle = True,
        collate_fn = lambda x: x,
        drop_last = True
    )

if isinstance(text2sql_tokenizer, T5TokenizerFast):
    text2sql_tokenizer.add_tokens([AddedToken(" <="), AddedToken(" <")])

In [6]:
num_epochs = 1

# Train the model
model.train()
for epoch in range(num_epochs):
    for idx, batch in enumerate(train_dataloder):
        
        batch_inputs = [data[0] for data in batch]
        batch_sqls = [data[1] for data in batch]

        if epoch == 0 and idx == 0:
            for batch_id in range(len(batch_inputs)):
                print(f"batch_inputs - {batch_inputs[batch_id]}")
                print(f"batch_sqls - {batch_sqls[batch_id]}")
#                 print("----------------------")

        tokenized_inputs = text2sql_tokenizer(
            batch_inputs, 
            padding = "max_length",
            return_tensors = "pt",
            max_length = 512, #512, max_encoder_len
            truncation = True
        )

        with text2sql_tokenizer.as_target_tokenizer():
            tokenized_outputs = text2sql_tokenizer(
                batch_sqls, 
                padding = "max_length", 
                return_tensors = 'pt',
                max_length = 256, #256, max_decoder_len
                truncation = True
            )
            
        encoder_input_ids = tokenized_inputs["input_ids"].to(device)
        encoder_input_attention_mask = tokenized_inputs["attention_mask"].to(device)
        
        decoder_input_ids = tokenized_outputs["input_ids"].to(device)
        decoder_attention_mask = tokenized_outputs["attention_mask"].to(device)
        labels = None #tokenized_outputs["attention_mask"].to(device)

        optimizer.zero_grad()
        loss = model(input_ids=encoder_input_ids,
                     attention_mask=encoder_input_attention_mask,
                     decoder_input_ids=decoder_input_ids,
                     decoder_attention_mask=decoder_attention_mask,
                     labels=labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        break

# Save the trained model
# torch.save(model.state_dict(), 'bert_t5_model.pt')

batch_inputs - Show all locations which don't have a train station with at least 15 platforms.
batch_sqls - select location from station except select location from station where number_of_platforms >= 15
batch_inputs - What is the name of the party form that is most common?
batch_sqls - select forms.form_name from forms join party_forms on forms.form_id = party_forms.form_id group by party_forms.form_id order by count ( * ) desc limit 1
torch.Size([2, 512, 768])
encoder_hidden_states.size() torch.Size([512, 768])


ValueError: not enough values to unpack (expected 3, got 2)