In [1]:
from typing import List, Optional, Any
import torch
import torch.utils.data
import numpy as np
import random

import datetime

import transformers
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

import json

In [None]:
# One seed shared by every RNG the notebook touches, so changing it in a
# single place re-seeds torch (CPU and CUDA), NumPy and the stdlib together.
SEED = 21
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)
# Ask cuDNN for deterministic algorithms and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on a machine without CUDA instead of failing at the first .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class DatasetNatural2CQL(torch.utils.data.Dataset):
    """Map-style dataset of (natural-language sentence, CQL query) pairs.

    A single CQL query may have several natural-language paraphrases. The
    dataset is indexed by paraphrase: ``len(self)`` is the number of sentences
    and item ``i`` is ``(sentence_i, cql_query_of_sentence_i)``.
    """

    def __init__(self, path: Optional[str] = None) -> None:
        """Create an empty dataset and, if ``path`` is given, load it as TSV."""
        # Per-query lists (all indexed by query position):
        self.sentence_freq = []               # frequency value given for the query
        self.cql2nl = []                      # indices into natural_language for the query
        self.natural_language_rulebased = []  # rule-based rendering of the query
        self.cql = []                         # the CQL query text
        # Per-sentence lists (all indexed by sentence position):
        self.nl2cql = []                      # index into cql for the sentence
        self.natural_language = []            # the sentence text

        if path is not None:
            self.load_tsv(path)

    def add_translation(self, freq: int, cql: str, natural_language_rulebased: str, natural_language: List[str]) -> None:
        """Register one CQL query together with all of its paraphrases.

        Args:
            freq: Frequency value stored alongside the query.
            cql: The CQL query text.
            natural_language_rulebased: Rule-based rendering of the query.
            natural_language: Paraphrases; each becomes one dataset item.
        """
        cql_index = len(self.sentence_freq)
        self.sentence_freq.append(freq)
        self.cql.append(cql)
        self.natural_language_rulebased.append(natural_language_rulebased)

        sentence_indices = []
        self.cql2nl.append(sentence_indices)
        for sentence in natural_language:
            self.nl2cql.append(cql_index)
            sentence_indices.append(len(self.natural_language))
            self.natural_language.append(sentence)

    def load_tsv(self, path: str) -> None:
        """Append every record of a tab-separated file to the dataset.

        Columns read per line: 0 = frequency (int), 2 = CQL query,
        3 = rule-based text, 4 = JSON whose
        ``data[0].content[0].text.value`` holds newline-separated paraphrases.
        Column 1 is not used. Blank lines are skipped.
        """
        # Explicit encoding: the text column is natural language and must not
        # depend on the platform's default codec.
        with open(path, "r", encoding="utf-8") as file_data:
            for line in file_data:
                line = line.strip()
                if not line:
                    # e.g. a trailing empty line at the end of the file
                    continue
                fields = line.split("\t")
                texts_json = json.loads(fields[4])
                texts_extracted = texts_json["data"][0]["content"][0]["text"]["value"].split("\n")
                self.add_translation(int(fields[0]), fields[2], fields[3], texts_extracted)

    def __len__(self):
        """Number of natural-language sentences (not the number of queries)."""
        return len(self.nl2cql)

    def __getitem__(self, idx):
        """Return ``(sentence, cql)`` for sentence ``idx``.

        Raises:
            IndexError: if ``idx`` is past the end. Raising (rather than
                returning ``None``) is what lets ``for item in dataset``
                terminate, because iteration falls back to ``__getitem__``.
        """
        if idx >= len(self.nl2cql):
            raise IndexError(f"index {idx} out of range for dataset of size {len(self.nl2cql)}")
        return self.natural_language[idx], self.cql[self.nl2cql[idx]]
        

In [None]:
class DatasetNatural2CQLTokenized(DatasetNatural2CQL):
    """Tokenized view of ``DatasetNatural2CQL``.

    Items are ``(input_ids, attention_mask, cql_ids)`` tensors of dtype
    ``torch.long``. Inputs are tokenized per sentence, targets per CQL query;
    ``nl2cql`` maps a sentence index to its query.
    """

    def __init__(self, tokenizer: Any, path: Optional[str] = None) -> None:
        """Load ``path`` (if given) and tokenize everything that was loaded.

        Args:
            tokenizer: Object providing ``batch_encode_plus(list, return_tensors="pt")``
                whose result exposes ``input_ids`` and ``attention_mask``.
            path: Optional TSV file handed to the base class.
        """
        super().__init__(path)
        self.tokenizer = tokenizer
        self.natural_language_tokenized = []  # one 1-D id tensor per sentence
        self.natural_language_mask = []       # one 1-D attention mask per sentence
        self.cql_tokenized = []               # one 1-D id tensor per CQL query
        if len(self) > 0:
            self.tokenize()

    def tokenize(self) -> None:
        """(Re)build the token tensors from the stored sentences and queries."""
        # Start from empty lists so a second call replaces, not duplicates, data.
        self.natural_language_tokenized = []
        self.natural_language_mask = []
        self.cql_tokenized = []

        for sentence in self.natural_language:
            encoded = self.tokenizer.batch_encode_plus(
                ["translate: " + sentence.replace("/", "//")],
                return_tensors="pt",
            )
            # squeeze(0) removes only the batch axis. A bare squeeze() would
            # also collapse a one-token sequence into a 0-dim tensor.
            self.natural_language_tokenized.append(encoded.input_ids.squeeze(0).to(dtype=torch.long))
            self.natural_language_mask.append(encoded.attention_mask.squeeze(0).to(dtype=torch.long))

        for query in self.cql:
            encoded = self.tokenizer.batch_encode_plus(
                [query],
                return_tensors="pt",
            )
            self.cql_tokenized.append(encoded.input_ids.squeeze(0).to(dtype=torch.long))

    @staticmethod
    def _pad_to(sequence, length, pad_value):
        """Return ``sequence`` right-padded with ``pad_value`` to ``length`` (long dtype)."""
        if sequence.dim() == 0:
            # Tolerate scalar tensors that may have been stored externally.
            sequence = sequence.unsqueeze(0)
        padded = torch.full((length,), pad_value, dtype=torch.long)
        padded[:sequence.shape[0]] = sequence
        return padded

    def apply_padding(self, mx0, mx2) -> None:
        """Right-pad inputs and masks to ``mx0`` and CQL targets to ``mx2``.

        ``mx0`` / ``mx2`` must be at least the longest stored input / target;
        a shorter value makes the slice assignment fail.
        Token ids are padded with the tokenizer's ``pad_token_id`` (0 when the
        tokenizer does not define one); masks are padded with 0.
        """
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = 0

        for i in range(len(self.natural_language_tokenized)):
            self.natural_language_mask[i] = self._pad_to(self.natural_language_mask[i], mx0, 0)
            self.natural_language_tokenized[i] = self._pad_to(self.natural_language_tokenized[i], mx0, pad_id)

        for i in range(len(self.cql_tokenized)):
            self.cql_tokenized[i] = self._pad_to(self.cql_tokenized[i], mx2, pad_id)

    def cut_data(self, mx0, mx2):
        """Truncate inputs and masks to ``mx0`` positions and targets to ``mx2``.

        Positions past the limit are dropped, including any end-of-sequence
        token that sat beyond it.
        """
        for i in range(len(self.natural_language_tokenized)):
            # Ids and mask are cut together so they stay the same length.
            self.natural_language_tokenized[i] = self.natural_language_tokenized[i][:mx0]
            self.natural_language_mask[i] = self.natural_language_mask[i][:mx0]

        for i in range(len(self.cql_tokenized)):
            self.cql_tokenized[i] = self.cql_tokenized[i][:mx2]

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, cql_ids)`` for sentence ``idx``.

        Raises:
            IndexError: if ``idx`` is past the end, so that plain iteration
                over the dataset terminates.
        """
        if idx >= len(self.nl2cql):
            raise IndexError(f"index {idx} out of range for dataset of size {len(self.nl2cql)}")
        return self.natural_language_tokenized[idx], self.natural_language_mask[idx], self.cql_tokenized[self.nl2cql[idx]]
        

In [None]:
# Checkpoint identifier shared by the tokenizer and the model loaded below.
model_name = "google/flan-t5-large"
tokenizer = T5Tokenizer.from_pretrained(model_name)
# device_map="auto" leaves weight placement to the library; there is no
# explicit model.to(device) call in this notebook.
model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")

In [None]:
# String-level view of the corpus: items are (sentence, CQL query) pairs.
dataset = DatasetNatural2CQL("expand_natural_texts_0004.res.tsv")

In [None]:
# Tokenized view of the same TSV file; the constructor tokenizes every
# sentence and query as soon as the file has been loaded.
dataset_tokenized = DatasetNatural2CQLTokenized(tokenizer, "expand_natural_texts_0004.res.tsv")

In [None]:
# Spot-check one record (input ids, attention mask, CQL ids) before padding.
dataset_tokenized[106102]

In [None]:
# Scan every record for the longest input (mx0) and longest CQL target (mx2).
mx0 = 0
mx2 = 0
for i in range(len(dataset_tokenized)):
    rec = dataset_tokenized[i]
    input_ids = rec[0]
    cql_ids = rec[2]
    # 0-dim tensors carry no length and are ignored.
    if input_ids.dim() > 0:
        mx0 = max(mx0, input_ids.shape[0])
    if cql_ids.dim() > 0:
        mx2 = max(mx2, cql_ids.shape[0])
print("mx0 =", mx0, "mx2 =", mx2)

In [None]:
# Pad every input/mask to mx0 and every CQL target to mx2, the maxima
# measured in the previous cell, so all records share one length per field.
dataset_tokenized.apply_padding(mx0, mx2)

In [None]:
# Same sample record as before, inspected again after padding.
dataset_tokenized[106102]

In [None]:
# Ask the dataset to cut records down to 120 input positions and 90 CQL
# positions. NOTE(review): both limits are hand-picked literals — confirm.
dataset_tokenized.cut_data(120, 90)

In [None]:
# Length of the sample record's input-id tensor after cut_data.
dataset_tokenized[106102][0].shape

In [None]:
# inspired by: https://github.com/Shivanandroy/T5-Finetuning-PyTorch
def train(epoch, tokenizer, model, device, loader, optimizer):
    """Run one optimisation pass over ``loader``.

    Args:
        epoch: Epoch number; only used in the progress line.
        tokenizer: Supplies ``pad_token_id`` for masking padded targets.
        model: Seq2seq model called with ``input_ids``, ``attention_mask``
            and ``labels``; element 0 of its output is taken as the loss.
        device: Device each batch is moved to.
        loader: Sized iterable of batches ``(input_ids, attention_mask,
            target_ids)``.
        optimizer: Optimizer stepped once per batch.
    """
    epoch_start = datetime.datetime.utcnow()
    model.train()
    for step, data in enumerate(loader):
        ids = data[0].to(device, dtype=torch.long)
        mask = data[1].to(device, dtype=torch.long)
        y = data[2].to(device, dtype=torch.long)

        # Hand the full target sequence over as labels and let the model
        # derive the decoder inputs by shifting them right behind its start
        # token. Slicing y[:, :-1] / y[:, 1:] by hand assumes the targets
        # begin with a start token, which T5 sequences do not, so the first
        # target token would never be supervised.
        lm_labels = y.clone()
        # -100 marks positions the loss must ignore (padding).
        lm_labels[lm_labels == tokenizer.pad_token_id] = -100

        outputs = model(
            input_ids=ids,
            attention_mask=mask,
            labels=lm_labels,
        )
        loss = outputs[0]

        if step % 100 == 0:
            now = datetime.datetime.utcnow()
            # total_seconds() keeps counting past 24 h; .seconds wraps around.
            elapsed_seconds = int((now - epoch_start).total_seconds())
            print("time: ", now.isoformat(), elapsed_seconds, "sec | epoch: ", str(epoch), "| batch: ", str(step), "/", len(loader), "|", str(loss.item()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


In [None]:
# Adam over every model parameter with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(
    params=model.parameters(), lr=1e-4
)

In [None]:
# One pass (epoch index 0) over the tokenized dataset in shuffled
# mini-batches of 4; num_workers=0 loads the data in the main process.
train(0, tokenizer, model, device, torch.utils.data.DataLoader(dataset_tokenized, batch_size=4, shuffle=True, num_workers=0), optimizer)

In [None]:
# Smoke test: translate one hand-written prompt with the fine-tuned model.
model.eval()
sentence_tokenized = tokenizer(
    "translate: Aby word police.",
    return_tensors="pt",
)
print(sentence_tokenized)
# Send the inputs to wherever the model actually lives instead of assuming
# "cuda", and pass the attention mask explicitly.
generated_ids = model.generate(
    input_ids=sentence_tokenized.input_ids.to(model.device),
    attention_mask=sentence_tokenized.attention_mask.to(model.device),
)
print(generated_ids)

In [None]:
# Map a hand-written list of token ids back to token strings for inspection.
# NOTE(review): presumably copied from the generate() output above — confirm.
tokenizer.convert_ids_to_tokens([ 0,  891,   63, 1448, 2095,    1])