<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/FINAL_GNN_FINE_TUNINGT_T2SQL_DataAugmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Libraries Installation

In [None]:
!pip install datasets networkx -q

!pip install torch_geometric -q


# Install PyTorch & other libraries
!pip install torch tensorboard --quiet

# Install Hugging Face libraries
!pip install  --upgrade transformers accelerate evaluate bitsandbytes --quiet

# FlashAttention only supports Ampere GPUs or newer; in Google Colab this means an A100 or L4.
!pip install -U flash-attn --no-build-isolation --quiet


!pip install peft --quiet
!pip install trl ninja packaging --quiet
!pip install diffusers safetensors  --quiet
!pip install colab-env --quiet

!pip install mistral_inference -q

!pip install trl==0.8.6 -q


!pip install sqlparse -q

!pip install bitsandbytes -q

#!pip uninstall -y torchvision -q
!pip install torchvision --no-cache-dir -q
#import evaluate

!pip install sentence-transformers -q

!pip install nlpaug -q
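
Because FlashAttention needs an Ampere-class GPU or newer (CUDA compute capability 8.0 and up: the A100 reports 8.0, the L4 reports 8.9, while a T4 reports 7.5), it can help to check the runtime before installing `flash-attn`. The following is a minimal sketch of such a check. `supports_flash_attention` and `current_gpu_capability` are hypothetical helper names, not part of any library, and the `torch` lookup is wrapped so the snippet also runs where PyTorch is missing.

```python
def supports_flash_attention(capability):
    """Return True if a CUDA compute capability (major, minor) is Ampere or newer."""
    major, _minor = capability
    return major >= 8


def current_gpu_capability():
    """Return the (major, minor) capability of GPU 0, or None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability(0)


capability = current_gpu_capability()
if capability is None:
    print("No CUDA GPU detected; skip the flash-attn install.")
elif supports_flash_attention(capability):
    print(f"Compute capability {capability}: flash-attn can be used.")
else:
    print(f"Compute capability {capability}: fall back to the default attention.")
```

If the check fails, one option is to drop the `attn_implementation="flash_attention_2"` argument when loading the model so that the default attention implementation is used.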

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset

from peft import LoraConfig, get_peft_model, TaskType

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

#from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import spacy
import numpy as np

from torch_geometric.nn import GAT

from trl import setup_chat_format

import colab_env
import evaluate

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hugging Face model id
model_id = "mistralai/Mistral-7B-Instruct-v0.1"


# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)


tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load model and tokenizer
mistral_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)

tokenizer.padding_side = 'right' # to prevent warnings

# We redefine the pad_token and pad_token_id with out of vocabulary token (unk_token)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id
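
To see why the 4-bit `BitsAndBytesConfig` matters on a single Colab GPU, a back-of-the-envelope estimate of weight memory is useful. The sketch below is only arithmetic: it assumes roughly 7.2 billion parameters for the model (an approximation), and it ignores quantization constants, activations, gradients, and the KV cache. `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate weight storage in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


APPROX_PARAMS = 7.2e9  # assumption: approximate parameter count of a 7B model

for label, bits in [("bf16", 16), ("int8", 8), ("nf4", 4)]:
    print(f"{label:>5}: ~{weight_memory_gb(APPROX_PARAMS, bits):.1f} GB")
```

Under these assumptions the weights shrink from about 14.4 GB in bf16 to about 3.6 GB in 4-bit form, which leaves room for the graph layers and training state.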

## TRAINING

https://github.com/frank-morales2020/MLxDL/blob/main/FineTuning_LLM_Mistral_7B_Instruct_v0_1_for_text_to_SQL_EVALDATA.ipynb

* Import Main Components

In [None]:
!pip cache purge

Files removed: 3076


In [None]:
train_dataset = load_dataset("json", data_files="/content/gdrive/MyDrive/datasets/train_dataset.json", split="train")

In [None]:
train_dataset[0]['messages'][0]['content']

'You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\nSCHEMA:\nCREATE TABLE table_name_92 (total VARCHAR, finish VARCHAR)'

In [None]:
train_dataset[0]['messages'][1]['content']

'How many times was the finish t32?'

In [None]:
train_dataset[0]['messages'][2]['content']

'SELECT COUNT(total) FROM table_name_92 WHERE finish = "t32"'

In [None]:
len(train_dataset)

10000
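
The three outputs above suggest that each record holds a `messages` list with a schema-bearing instruction, the user question, and the target SQL. As a rough sketch of how such a record could be built from a `question` / `context` / `answer` sample: the role names (`system`, `user`, `assistant`) and the helper `to_messages` are assumptions for illustration, and the prompt wording is copied verbatim from the sample shown above.

```python
# Prompt wording copied verbatim from the first record above; {schema} is filled per sample.
SYSTEM_PROMPT = (
    "You are an text to SQL query translator. Users will ask you questions in English "
    "and you will generate a SQL query based on the provided SCHEMA.\nSCHEMA:\n{schema}"
)


def to_messages(sample):
    """Convert a question/context/answer sample into a chat-style record (assumed roles)."""
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT.format(schema=sample["context"])},
            {"role": "user", "content": sample["question"]},
            {"role": "assistant", "content": sample["answer"]},
        ]
    }


sample = {
    "question": "How many times was the finish t32?",
    "context": "CREATE TABLE table_name_92 (total VARCHAR, finish VARCHAR)",
    "answer": 'SELECT COUNT(total) FROM table_name_92 WHERE finish = "t32"',
}

record = to_messages(sample)
for message in record["messages"]:
    print(message["role"], "->", message["content"])
```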

* https://stackoverflow.com/questions/70950706/assertionerror-in-torch-geometric-nn-gatconv


1. **Data Input and Preparation:**
   - The process begins by loading the "sql-create-context" dataset, in which each record holds a natural language question, the `CREATE TABLE` schema context, and the corresponding SQL query.
   - The dataset is shuffled and divided into three distinct subsets: training (70%), validation (15%), and testing (15%).
   - The Mistral-7B-Instruct-v0.1 language model and its tokenizer are loaded and prepared.

2. **Data Transformation with `TextToSQLDataset`:**
   - The `TextToSQLDataset` class is responsible for converting raw data into a format suitable for model training and evaluation.
   - For each data sample, the following transformations occur:
      - Tokenization: The input question and the SQL query (answer) are tokenized using the loaded tokenizer.
      - Dependency Parsing: The question is parsed using spaCy to extract the grammatical relationships between words, generating a dependency graph.
      - Dictionary Creation: A dictionary is created to store the tokenized input IDs, attention masks, labels (tokenized SQL query), and the dependency edges extracted from parsing.

3. **Batching and Shuffling with `DataLoader`:**
   - `DataLoader` takes the processed dataset from `TextToSQLDataset` and creates an iterable object for efficient batching.
   - Optionally, shuffling is applied to randomize the order of samples within each epoch during training.

4. **Forward Pass through `GraphModel`:**
   - **Mistral Encoder:** The tokenized input IDs and attention masks are fed into the Mistral model, which is used here as an encoder: its final-layer hidden states serve as contextualized token embeddings.
   - **GATv2 Layer:** The GATv2 layer (a graph attention network) takes the token embeddings and the dependency edges as input and applies graph attention to incorporate the structural information from the dependency graph into the token representations.
   - **Pooling:** The node representations (the output of GATv2) are aggregated with a pooling operation (max pooling in the code) to obtain a fixed-size representation of the entire input sequence.
   - **LM Head:** The pooled representation is passed through a feedforward block and a linear layer, which produces logits: unnormalized scores for each token in the vocabulary.
   - **Loss Calculation:** During training, if labels (the correct SQL queries) are available, the cross-entropy loss between the predicted logits and the true labels is computed. This loss guides the optimization of the model's parameters.

5. **Model Optimization with PEFT (LoRA):**
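   - The `GraphModel` is wrapped with PEFT's LoRA (Low-Rank Adaptation) configuration to enable parameter-efficient fine-tuning.
   - During training, only a small subset of parameters is updated while the rest of the model remains frozen. Note that LoRA by default trains only the injected low-rank adapter weights; for the GATv2 layer and the LM head to be updated as well, they must be kept trainable explicitly (for example via `modules_to_save` in `LoraConfig`).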
   - The Hugging Face Trainer manages the training process, iterating over the dataset, computing gradients, and updating the model's parameters based on the calculated loss.

6. **Evaluation:**
   - After (or during) training, the model is evaluated on the validation and test sets.
   - The `compute_metrics` function decodes the predicted logits and labels back into text and assesses the model's performance using two metrics:
      - Semantic Similarity: This metric measures how semantically close the predicted SQL query is to the reference SQL query using SentenceTransformer embeddings.
      - Exact Match: This metric checks if the predicted SQL query matches the reference SQL query exactly.

7. **Output:**
    - The final output is the evaluation results, including semantic similarity and exact match scores, which provide insights into the model's ability to generate accurate SQL queries from natural language questions.
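
The graph-related steps (dependency edges in step 2 and the merging of per-sample graphs in step 4) can be illustrated without spaCy or PyTorch. The sketch below uses plain Python lists; the head indices are hand-written for illustration rather than produced by a real parser, and all three function names are hypothetical. It follows the spaCy convention that the root token is its own head, and the `2 x E` edge layout used by PyTorch Geometric.

```python
def dependency_edges(heads):
    """Build (child, head) edges from head indices; the root (its own head) is skipped."""
    return [(child, head) for child, head in enumerate(heads) if child != head]


def to_edge_index(edges):
    """Convert an edge list into the 2 x E layout: one row of sources, one of targets."""
    sources = [src for src, _ in edges]
    targets = [dst for _, dst in edges]
    return [sources, targets]


def batch_edge_indices(edge_indices, nodes_per_graph):
    """Merge per-sample graphs into one graph by shifting node ids by a running offset."""
    sources, targets = [], []
    for graph_id, (src, dst) in enumerate(edge_indices):
        offset = graph_id * nodes_per_graph
        sources.extend(s + offset for s in src)
        targets.extend(t + offset for t in dst)
    return [sources, targets]


# Illustrative heads for the sample question; token 3 ("was") is the root.
tokens = ["How", "many", "times", "was", "the", "finish", "t32", "?"]
heads = [1, 2, 3, 3, 5, 3, 3, 3]

edge_index = to_edge_index(dependency_edges(heads))
print(edge_index)

# Two copies of the same graph, batched with an offset of len(tokens) nodes.
batched = batch_edge_indices([edge_index, edge_index], nodes_per_graph=len(tokens))
print(batched)
```

In the notebook the offset is the padded sequence length rather than the number of words, since every sample is padded to the same number of positions before batching.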


The dataflow can be represented graphically as follows; note that the evaluation arrow points *into* the `Trainer`:

```
+-------------------+          +----------------------+
| Dataset Loading   |          | TextToSQLDataset     |
|  - sql-create...  | -------> | - Tokenization       |
|  - Split: train...|          | - Dependency Parsing |
+-------------------+          | - Dict Creation      |
                               +----------------------+
                                         |
                                         v
                         +---------------------------+
                         | DataLoader                |
                         | - Batches, Shuffling (opt)|
                         +---------------------------+
                                         |
                                         v
                         +---------------------------+
                         | GraphModel                |
                         | - Mistral Encoder         |
                         | - GATv2 Layer             |
                         | - Pooling                 |
                         | - LM Head                 |
                         | - (Loss Calculation)      |
                         +---------------------------+
                                         |
                                         v
            +--------------+          +---------+         +-----------------+
            | PEFT (LoRA)  | -------> | Trainer | <------ |Evaluation       |
            +--------------+          |         |         |(compute_metrics)|
                                      +---------+         +-----------------+
                                         |
                                         v
                                   +-----------------------+
                                   | - Semantic Similarity |
                                   | - Exact Match         |
                                   +-----------------------+
```

The dataflow shows that the evaluation metrics (semantic similarity and exact match) calculated by the `compute_metrics` function are fed back to the `Trainer`, which uses them to assess the model's performance and to make decisions during training (e.g., selecting the best checkpoint or early stopping).
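
To make the feedback loop between evaluation and training concrete, here is a simplified sketch of an exact-match score and a patience-based stopping rule. Both functions are hypothetical stand-ins written in plain Python; they illustrate the idea and are not the implementation used by the `evaluate` library or by the `Trainer`.

```python
def exact_match(predictions, references):
    """Fraction of predictions equal to their reference after stripping whitespace."""
    if not references:
        return 0.0
    hits = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return hits / len(references)


def should_stop(history, patience):
    """Stop when the last `patience` evaluations did not beat the earlier best score."""
    if len(history) <= patience:
        return False
    best_before = max(history[:-patience])
    return max(history[-patience:]) <= best_before


score = exact_match(
    ["SELECT COUNT(total) FROM t", "SELECT name FROM t "],
    ["SELECT COUNT(total) FROM t", "SELECT id FROM t"],
)
print(f"exact match: {score:.2f}")

# Scores recorded at successive evaluation steps (made-up numbers for illustration).
history = [0.10, 0.20, 0.20, 0.19, 0.20]
print("stop training:", should_stop(history, patience=3))
```

With the Hugging Face `Trainer`, comparable behavior is usually obtained by passing an `EarlyStoppingCallback` together with `metric_for_best_model` and `load_best_model_at_end=True`.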



In [None]:
import huggingface_hub
import shutil

# Get the cache directory information
cache_info = huggingface_hub.utils.scan_cache_dir()

# Iterate through the repositories and delete them
for repo_info in cache_info.repos:
    repo_path = repo_info.repo_path
    shutil.rmtree(repo_path)  # Delete the repository folder

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    BitsAndBytesConfig, Trainer, TrainingArguments,
    set_seed,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq
)

from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    PeftModel,
    PeftConfig,
    PrefixTuningConfig,
    PromptEncoderConfig,
)

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from torch_geometric.nn import GAT
import spacy
import numpy as np
import torch.nn as nn
import evaluate


# Suppress warnings
import warnings
warnings.simplefilter('ignore')
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support.")

# 1. Load and Prepare Data
dataset = load_dataset("b-mc2/sql-create-context")["train"].shuffle(seed=42)

# Manually define splits
train_size = int(0.7 * len(dataset))
eval_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - eval_size

train_dataset = dataset.select(range(train_size))
eval_dataset = dataset.select(range(train_size, train_size + eval_size))
test_dataset = dataset.select(range(train_size + eval_size, len(dataset)))


# Device Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3a. PyTorch Datasets - Data Augmentation
import random
from nlpaug.augmenter.word import SynonymAug, RandomWordAug

def random_insertion(sentence, aug_p=0.1, synonym_aug=None, max_attempts=3):  # Add max_attempts
    words = sentence.split()
    new_words = []
    for word in words:
        new_words.append(word)
        if random.random() < aug_p:
            attempts = 0
            while attempts < max_attempts:  # Try multiple times to find a synonym
                if new_words:
                    candidate_words = [new_words[-1]]
                if len(new_words) < len(words):
                    candidate_words.append(words[len(new_words)])
                if candidate_words and synonym_aug:
                    synonym = synonym_aug.augment(random.choice(candidate_words))
                    if synonym:
                        new_words.append(synonym[0])
                        break  # Exit the loop if a synonym is found
                attempts += 1
    return ' '.join(new_words)

def augment_data(dataset):
    augmented_data = []

    synonym_aug = SynonymAug(aug_src='wordnet')
    delete_aug = RandomWordAug(action="delete", aug_p=0.05)

    for item in dataset:
        question = item['question']
        answer = item['answer']
        context = item['context']

        augmented_questions = set()
        augmented_questions.add(question)  # Always include the original question

        # Attempt to generate one more unique augmented question
        while len(augmented_questions) < 2:  # Keep trying until we have 2 unique questions
            augmentation_method = random.choice(['synonym', 'insertion', 'deletion'])

            if augmentation_method == 'synonym':
                synonyms = synonym_aug.augment(question)
                if synonyms and synonyms[0] != question:
                    augmented_questions.add(synonyms[0])

            elif augmentation_method == 'insertion':
                augmented_question = random_insertion(question, synonym_aug=synonym_aug)
                if augmented_question != question:
                    augmented_questions.add(augmented_question)

            else:  # deletion
                deleted = delete_aug.augment(question)
                if deleted and deleted[0] != question:
                    augmented_questions.add(deleted[0])

        # Add the augmented examples to the dataset
        for aug_question in list(augmented_questions):
            augmented_data.append({'question': aug_question, 'answer': answer, 'context': context})

    return augmented_data


# 3. PyTorch Datasets
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
import spacy

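# Load the spaCy English transformer pipeline (dependency parser and POS tagger;
# it has no built-in SRL component, so SRL is approximated with dependency heuristics below)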
try:
    nlp = spacy.load("en_core_web_trf")
except OSError:
    spacy.cli.download("en_core_web_trf")
    nlp = spacy.load("en_core_web_trf")

class TextToSQLDataset(Dataset):
    def __init__(self, dataset, tokenizer):
        self.dataset = dataset
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        text = item['question']
        target_text = item['answer']

        # 1. Tokenization
        tokenized_input = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=1024,
            return_tensors="pt"
        )
        tokenized_target = self.tokenizer(
            target_text,
            truncation=True,
            padding="max_length",
            max_length=1024,
            return_tensors="pt"
        )

        # Flatten lists
        tokenized_input = {k: v.squeeze(0) for k, v in tokenized_input.items()}
        tokenized_target = {k: v.squeeze(0) for k, v in tokenized_target.items()}

        # 2. Dependency Parsing & SRL for Edge Extraction
        doc = nlp(text)
        edges = []
        edge_attrs = []
        for token in doc:
            # Dependency Parsing
            if token.dep_ != "ROOT" and token.i != token.head.i:
                edges.append([token.i, token.head.i])
                edge_attrs.append(self.tokenizer.vocab.get(token.dep_, self.tokenizer.unk_token_id))

            # SRL-style heuristic: link subjects/objects to the objects of their prepositions
            if token.dep_ in {"nsubj", "dobj", "nsubjpass"}:
                for child in token.children:
                    if child.dep_ == "prep":
                        for grandchild in child.children:
                            if grandchild.dep_ in {"pobj", "pcomp"}:
                                edges.append([token.i, grandchild.i])
                                edge_attrs.append(self.tokenizer.vocab.get("prep_" + grandchild.dep_, self.tokenizer.unk_token_id))

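        # Edge Index Extraction and Validation
        # Note: spaCy token indices are used directly as node ids; they are not
        # aligned with the tokenizer's subword positions.
        max_index = len(tokenized_input["input_ids"]) - 1
        # Filter edges and their attributes together so the two stay aligned
        kept = [(edge, attr) for edge, attr in zip(edges, edge_attrs)
                if 0 <= edge[0] <= max_index and 0 <= edge[1] <= max_index]
        edges = [edge for edge, _ in kept]
        edge_attrs = [attr for _, attr in kept]
        if not edges:
            # No usable edges: fall back to self-loops on every position
            edges = [[i, i] for i in range(max_index + 1)]

        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_attrs = torch.tensor(edge_attrs, dtype=torch.long)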

        # 3. Node Features with POS Tags
        pos_tags = [token.pos_ for token in doc]
        pos_tag_ids = [self.tokenizer.vocab.get(tag, self.tokenizer.unk_token_id) for tag in pos_tags]
        pos_tag_ids = torch.tensor(pos_tag_ids, dtype=torch.long)

        # Convert everything to tensors BEFORE padding/truncation
        input_ids = tokenized_input["input_ids"].clone().detach()
        attention_mask = tokenized_input["attention_mask"].clone().detach()
        labels = tokenized_target["input_ids"].clone().detach()

        # Handle potentially empty target sequences
        if len(labels) == 0:
            labels = torch.tensor([self.tokenizer.pad_token_id], dtype=torch.long)

        # Padding and Truncation
        max_length = 1024

        input_ids = input_ids[:max_length]
        attention_mask = attention_mask[:max_length]
        labels = labels[:max_length]
        pos_tag_ids = pos_tag_ids[:max_length]

        if len(input_ids) < max_length:
            pad_length = max_length - len(input_ids)
            pad_tensor = torch.full((pad_length,), self.tokenizer.pad_token_id)
            input_ids = torch.cat((input_ids, pad_tensor))
            attention_mask = torch.cat((attention_mask, torch.zeros(pad_length, dtype=torch.long)))
            pos_tag_ids = torch.cat((pos_tag_ids, torch.zeros(pad_length, dtype=torch.long)))

        if len(labels) < max_length:
            pad_length = max_length - len(labels)
            labels = torch.cat((labels, torch.full((pad_length,), -100)))

        if len(edge_attrs) < edge_index.size(1):
            pad_length = edge_index.size(1) - len(edge_attrs)
            edge_attrs = torch.cat((edge_attrs, torch.zeros(pad_length, dtype=torch.long)))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "edges": edge_index,
            "edge_attrs": edge_attrs,
            "pos_tag_ids": pos_tag_ids,
            "sample_ids": torch.tensor([idx])
        }

# Minimum: start with at least 2000-3000 samples.
# This should be enough for an initial assessment of the model's performance and potential.

# Medium: if computational resources allow, try 5000-7000 samples.
# This could provide a more robust evaluation and potentially better performance.

# Maximum: with ample resources, consider using the entire dataset (around 10,000 samples).
# This gives the most comprehensive training data and potentially the best model performance.

# Reduce train_dataset size for the proof of concept (POC)
#POC_sample=26000

POC_sample=16000
import numpy as np
train_dataset = train_dataset.select(np.random.choice(len(train_dataset), POC_sample, replace=False))

### data augmentation #######
train_dataset_augmented = augment_data(train_dataset)
train_dataset = TextToSQLDataset(train_dataset_augmented, tokenizer)
#############################


POC_valsample=1
#############################
eval_dataset = eval_dataset.select(np.random.choice(len(eval_dataset), POC_valsample, replace=False))
test_dataset = test_dataset.select(np.random.choice(len(test_dataset), POC_valsample, replace=False))

eval_dataset = TextToSQLDataset(eval_dataset, tokenizer)
test_dataset = TextToSQLDataset(test_dataset, tokenizer)
#############################


# 4. GAT Layer and GraphModel
import torch
import torch.nn as nn
from torch_geometric.nn import GAT

class GATLayer(torch.nn.Module):
    def __init__(self, in_features, out_features, num_heads=8, num_layers=3):
        super(GATLayer, self).__init__()
        self.gat = GAT(in_channels=in_features, hidden_channels=out_features, heads=num_heads,
                        concat=False, num_layers=num_layers)

    def forward(self, x, edge_index, edge_attr=None):
        return self.gat(x, edge_index, edge_attr=edge_attr)

    def get_lora_target_modules(self):
        return [module for module in self.gat.modules() if isinstance(module, torch.nn.Linear)]


# 4b.  GraphModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv


from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class MyCausalLMOutputWithPast:
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class GraphModel(nn.Module):
    def __init__(self, encoder, tokenizer):
        super(GraphModel, self).__init__()
        self.encoder = encoder
        self.config = encoder.model.config

        # Adjust in_channels to match the actual input dimensionality
        self.gatv2 = GATv2Conv(
            in_channels=self.config.hidden_size,  # Set to 4096
            out_channels=self.config.hidden_size,
            heads=8,
            concat=False,
        )

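        # Max pooling over all nodes. Note that the `batch` argument is ignored, so a
        # batch containing several graphs is pooled into a single vector.
        self.pool = lambda x, batch: torch.max(x, dim=0, keepdim=True)[0]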

        # Additional Feedforward Layer
        self.ffn = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.hidden_size * 2),
            nn.ReLU(),
            nn.Linear(self.config.hidden_size * 2, self.config.hidden_size),
        )

        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size)
        self.tokenizer = tokenizer

        # Add generation config
        self.generation_config = encoder.generation_config

    def forward(self, input_ids, attention_mask, edges, labels=None, inputs_embeds=None,
                pos_tag_ids=None, edge_attrs=None, sample_ids=None, output_attentions=False,
                output_hidden_states=False, return_dict=False):


        # 1. Token Embeddings (Encoder)
        if input_ids is not None:
            encoder_outputs = self.encoder(
                input_ids.to(self.encoder.model.device),
                attention_mask=attention_mask.to(self.encoder.model.device),
                output_hidden_states=True
            )
            embeddings = encoder_outputs.hidden_states[-1]
        elif inputs_embeds is not None:
            embeddings = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Ensure correct shape for GATv2Conv
        if embeddings.dim() > 2:
            embeddings = embeddings.view(-1, embeddings.shape[-1])

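        # 2. Obtain POS tag embeddings and concatenate
        # Caveat: concatenation doubles the feature size, so `in_channels` of the
        # GATv2 layer must be adjusted before `pos_tag_ids` is actually passed in.
        if pos_tag_ids is not None:
            # Mistral exposes its token embedding table as `embed_tokens`
            pos_tag_embeddings = self.encoder.model.embed_tokens(pos_tag_ids.to(self.encoder.model.device))
            pos_tag_embeddings = pos_tag_embeddings.view(-1, pos_tag_embeddings.shape[-1])
            embeddings = torch.cat([embeddings, pos_tag_embeddings], dim=-1)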


        # 3. Edge Index Creation (with potential optimization for memory)
        edge_index = []
        node_offset = 0
        for i, graph_edges in enumerate(edges):
            if graph_edges is None or graph_edges.numel() == 0:
                num_nodes = input_ids.size(1)
                graph_edges = torch.arange(node_offset, node_offset + num_nodes, device=embeddings.device)
                graph_edges = graph_edges.repeat(2, 1)
            else:
                if not isinstance(graph_edges, torch.Tensor):
                    graph_edges = torch.tensor(graph_edges, dtype=torch.long, device=embeddings.device)
                graph_edges += node_offset
            edge_index.append(graph_edges)
            node_offset += input_ids.size(1)
        edge_index = torch.cat(edge_index, dim=1)

        # 4. GATv2 Layer (pass edge_attrs)
        graph_out = self.gatv2(embeddings, edge_index, edge_attr=edge_attrs)


        # 5. Pooling
        batch = torch.arange(len(edges), device=graph_out.device).repeat_interleave(input_ids.size(1))
        pooled = self.pool(graph_out, batch).unsqueeze(1)

        # 5.1 - Additional Feedforward Layer
        pooled = self.ffn(pooled)

        # 6. LM Head
        logits = self.lm_head(pooled)

        # 7. Loss Calculation
        loss = None
        if labels is not None:
            from torch.nn import CrossEntropyLoss

            loss_fct = CrossEntropyLoss(ignore_index=-100)

            # Log-probabilities (not probabilities). CrossEntropyLoss applies
            # log_softmax internally, so this extra step leaves the loss unchanged.
            log_probs = F.log_softmax(logits, dim=-1)

            # Drop the singleton middle dimension so the shape is (N, vocab_size)
            log_probs = log_probs.squeeze(1)

            # The pooled representation yields a single prediction, so only the
            # first target token of each sample is used as the label.
            labels = labels[:, 0]

            loss = loss_fct(log_probs, labels)

        # 8. Return
        return {
            "loss": loss,
            "logits": logits,
            "past_key_values": encoder_outputs.past_key_values,
            "hidden_states": encoder_outputs.hidden_states,
            "attentions": encoder_outputs.attentions,
        }

    def prepare_inputs_for_generation(self, input_ids, edges, attention_mask=None,
                                      pos_tag_ids=None, edge_attrs=None, **kwargs):
        if isinstance(self, PeftModel):
            return self.base_model.prepare_inputs_for_generation(input_ids, edges, attention_mask, **kwargs)

        batch_size = input_ids.size(0)
        if batch_size > 1:
            batched_edges = []
            node_offset = 0
            for i in range(batch_size):
                graph_edges = edges[i]
                batched_edges.extend([(src + node_offset, dst + node_offset) for src, dst in graph_edges])
                node_offset += input_ids.size(1)
            edge_index = torch.tensor(batched_edges, dtype=torch.long).t().contiguous().to(input_ids.device)
        else:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous().to(input_ids.device)

        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "edges": edge_index,
            "pos_tag_ids": pos_tag_ids,
            "edge_attrs": edge_attrs,
            "past_key_values": kwargs.get("past_key_values", None),
        }
        return model_inputs

# END GraphModel

from peft import prepare_model_for_kbit_training

# Prepare the already 4-bit quantized Mistral model for training (before creating GraphModel)
mistral_model = prepare_model_for_kbit_training(mistral_model, use_gradient_checkpointing=True)


# 5. Model Setup (Define model first)
model = GraphModel(mistral_model, tokenizer)  # Pass both mistral_model and tokenizer


# 6. PEFT Configuration (no explicit target_modules)
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # With target_modules unset, PEFT falls back to its built-in default for the
    # model type found in model.config rather than adapting every linear layer
    task_type="CAUSAL_LM",
)


# 7. Apply PEFT
model = get_peft_model(model, peft_config)
print('\n\n')
print('PEFT-Model')
model.print_trainable_parameters() # To see the trainable parameters
print('\n')

# Access the config of the encoder (Mistral model) within your GraphModel
model.encoder.config.use_cache = False

# Ensure that LoRA layers are properly initialized and their dimensions are correctly set
for name, module in model.named_modules():
    if "lora" in name:
        module = module.to(device) # Move LoRA parameters to the correct device

model.encoder.gradient_checkpointing_enable()  # Enable gradient checkpointing for memory optimization on the Mistral model
#model.encoder.model.embed_tokens.requires_grad_(True)
torch.cuda.empty_cache()


# 8. Evaluation Metrics (exact match is computed below; the SentenceTransformer
#    model is loaded for semantic similarity)
metric = evaluate.load("exact_match")
sentence_transformer_model = SentenceTransformer('all-mpnet-base-v2')

def compute_metrics(eval_pred):
    all_preds, all_labels = eval_pred

    # Convert all elements to tensors, handling different data types
    predictions = [torch.tensor(pred) if not isinstance(pred, torch.Tensor) else pred for pred in all_preds]
    labels = [torch.tensor(label) if not isinstance(label, torch.Tensor) else label for label in all_labels]

    # Filter out any None values before stacking
    predictions = [pred for pred in predictions if pred is not None]
    labels = [label for label in labels if label is not None]

    # Convert to tensors and stack (only if there are predictions/labels)
    if predictions:
        predictions = torch.stack(predictions).squeeze()
    else:
        predictions = torch.tensor([])  # Empty tensor if no predictions

    if labels:
        labels = torch.stack(labels).squeeze()
    else:
        labels = torch.tensor([])  # Empty tensor if no labels

    # Handle cases where only one prediction/label is present (avoid squeezing to a scalar)
    if predictions.dim() == 0:
        predictions = predictions.unsqueeze(0)
    if labels.dim() == 0:
        labels = labels.unsqueeze(0)

    # Print shapes for debugging
    #print('\n')
    #print(f"Shape of logits in compute_metrics: {predictions.shape}")
    #print(f"Shape of labels in compute_metrics: {labels.shape}")
    #print('\n')


    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    em = metric.compute(predictions=decoded_preds, references=decoded_labels)["exact_match"]

    return {"exact_match": em}



#9. Training Arguments and Trainer
training_args = TrainingArguments(
    output_dir ="/content/gdrive/MyDrive/model/GNNT2SQL",
    logging_dir="/content/gdrive/MyDrive/model/GNN-T2SQL/logs",

    per_device_train_batch_size=1,  # Kept small to fit in GPU memory
    gradient_accumulation_steps=4,  # Effective batch size of 4

    # Number of epochs and early stopping
    num_train_epochs=1,  # Single epoch for this run; increase while monitoring validation loss
    # Early stopping is configured through EarlyStoppingCallback(early_stopping_patience=3),
    # not through TrainingArguments

    # Learning rate and scheduler
    learning_rate=5e-5,  # A reasonable starting point, adjust if needed
    lr_scheduler_type="linear",  # Or try other schedulers like "cosine"
    warmup_steps=500,  # Warmup is crucial, especially with a larger learning rate

    # Evaluation and saving
    eval_strategy="steps",
    eval_steps=50,  # Evaluate every 50 steps; use 500 to save time on longer runs
    save_strategy="steps",
    save_steps=50,  # Keep aligned with eval_steps (required by load_best_model_at_end)
    logging_strategy="steps",
    logging_steps=100,

    # Other settings
    push_to_hub=False,
    load_best_model_at_end=True,
    use_legacy_prediction_loop=False,
    metric_for_best_model="eval_exact_match",
    report_to="tensorboard",
    #generation_max_length=2048,  # Adjust if needed based on your data
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)

from torch.nn.utils.rnn import pad_sequence
from torch_geometric.data import Batch as GraphBatch  # Batches a list of PyG Data objects
import torch_geometric.data


class GraphDataCollatorForSeq2Seq:
    def __init__(self, tokenizer, model=None, label_pad_token_id=-100, pad_to_multiple_of=None):
        self.tokenizer = tokenizer
        self.model = model
        self.label_pad_token_id = label_pad_token_id
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, features):
        # Separate standard features from graph edges
        # Extract labels before padding and handle potentially empty sequences
        labels = [feature["labels"] if feature["labels"].numel() > 0
                  else torch.tensor([self.label_pad_token_id], dtype=torch.long)
                  for feature in features]

        # Extract sample_ids
        sample_ids = [feature["sample_ids"] for feature in features]

        standard_features = [{k: v for k, v in feature.items() if k != "edges" and k != "labels"} for feature in features]
        edges = [feature["edges"] for feature in features]

        # Collate standard features (input_ids, attention_mask) using default collator
        collated_standard_features =  DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer,
            model=self.model,
            label_pad_token_id=self.label_pad_token_id,
            pad_to_multiple_of=self.pad_to_multiple_of
        )(standard_features)

        # Pad input_ids and attention_mask
        input_ids = pad_sequence([f['input_ids'] for f in standard_features], batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence([f['attention_mask'] for f in standard_features], batch_first=True, padding_value=0)

        # Pad labels separately
        labels = pad_sequence(labels, batch_first=True, padding_value=self.label_pad_token_id)

        # Create batch for graph data
        graph_data_list = []
        for i in range(len(edges)):
            # Convert to PyTorch Geometric Data
            graph_data_list.append(torch_geometric.data.Data(
                x=collated_standard_features['input_ids'][i].unsqueeze(1),  # Node features (input_ids)
                edge_index=edges[i],                   # Edge index
                # Use num_edges for batch index to ensure correct batching in PyG
                batch=torch.tensor([i] * edges[i].size(1))
            ))
        batched_graph = GraphBatch.from_data_list(graph_data_list)  # Batch graphs

        #print(f"Sample GraphDataCollatorForSeq2Seq ID: {sample_ids}")

        # Include sample_ids in the collated features
        collated_features = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "edges": edges,
            "sample_ids": sample_ids  # Add sample_ids here
        }

        return collated_features


# 10a.  Data Collator
data_collator = GraphDataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8
)


# 10AA. Training Arguments and Trainer

import warnings  # Needed by estimated_num_samples
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from transformers import EarlyStoppingCallback, Trainer, TrainerCallback
from transformers.trainer_pt_utils import nested_truncate, nested_concat, nested_numpify, nested_detach
from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, PredictionOutput, has_length

def estimated_num_samples(dataloader: DataLoader):
    """
    This function will attempt to determine the number of samples in a DataLoader.

    Args:
        dataloader (DataLoader): The DataLoader to estimate the number of samples from.

    Returns:
        int: The estimated number of samples, or 0 if estimation is not possible.
    """
    if hasattr(dataloader, "dataset") and hasattr(dataloader.dataset, "__len__"):
        return len(dataloader.dataset)  # Use dataset length if available
    elif hasattr(dataloader, "batch_sampler") and hasattr(dataloader.batch_sampler, "sampler") and hasattr(dataloader.batch_sampler.sampler, "__len__"):
        return len(dataloader.batch_sampler.sampler)  # Use sampler length if available
    else:
        # If neither is available, return 0
        warnings.warn("Could not estimate the number of samples in the dataloader. Returning 0.")
        return 0

class CustomTrainer(Trainer):
    _id=0

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        self.tokenizer = tokenizer
        inputs = self._prepare_inputs(inputs)
        labels = inputs.pop("labels", None)

        print('\n\n')
        print("**** Prediction Step ****")


        with torch.no_grad():
            outputs = model(**inputs)

            # Access loss and logits based on the type of outputs
            if isinstance(outputs, MyCausalLMOutputWithPast):
                loss = outputs.loss
                logits = outputs.logits
            elif isinstance(outputs, tuple):
                # Check if the tuple has at least two elements before accessing
                if len(outputs) >= 2:
                    loss = outputs[0]  # Assuming loss is the first element
                    logits = outputs[1]  # Assuming logits is the second element
                else:
                    # Handle the case where the tuple has fewer than two elements
                    raise ValueError("Output tuple from model has fewer than two elements")
            else:
                #raise ValueError("Unexpected output type from model")
                # Access loss and logits as attributes for other types
                loss = outputs["loss"] # Access loss from the dictionary
                logits = outputs["logits"] # Access logits from the dictionary


            # Debugging Logging (ensure labels exist before accessing)
            if not prediction_loss_only:
                print('\n\n')
                print("Shape of logits in prediction_step:", logits.shape)
                if labels is not None:  # Only print if labels exist
                    print("Shape of labels in prediction_step:", labels.shape)

            if prediction_loss_only:
                if isinstance(loss, torch.Tensor):
                    loss = loss.mean().detach()
                return (loss, None, None)

            max_new_tokens = 1024 - inputs['input_ids'].shape[1]

            # Modify this part to handle the generated IDs
            if max_new_tokens <= 0:
                print('\n\n')
                print("No new tokens to generate. An increase in the sample size for training is required. ")
                #print("#TOKENS: ",max_new_tokens)
                # Input is already at max length, no need to generate
                generated_ids = inputs['input_ids']
            else:
                generated_ids = model.encoder.generate(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    max_new_tokens=max_new_tokens,
                    num_beams=5
                )

            # Check the shape of generated_ids
            #print("Shape of generated_ids:", generated_ids.shape)



        # Flatten the generated_ids to a 1D list before decoding
        flattened_generated_ids = generated_ids.view(-1).tolist()
        predictions = generated_ids

        # Decode the flattened ids (this concatenates every sequence in the batch)
        predictions_decoder = self.tokenizer.decode(flattened_generated_ids, skip_special_tokens=True)
        Q = self.tokenizer.decode(inputs['input_ids'].view(-1).tolist(), skip_special_tokens=True)

        # Debugging Logging
        if labels is not None:
            # Labels are padded with -100 for the loss; swap in the pad token id so they can be decoded
            decodable_labels = labels.masked_fill(labels == -100, self.tokenizer.pad_token_id)
            A = self.tokenizer.decode(decodable_labels.view(-1).tolist(), skip_special_tokens=True)

            print('\n\n')
            print("Sample IDs in prediction_step:", inputs['sample_ids'])

            print("Question:", Q)
            print("Decoded Original Answer BEFORE Predictions:", A)
            print('\n\n')

            print("Decoded Predictions:", predictions_decoder)

        # Always return a tuple, even when the batch has no labels
        return (loss, predictions, labels)

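    def _prediction_loop(self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval") -> EvalLoopOutput:
        if prediction_loss_only is None:
            prediction_loss_only = self.args.prediction_loss_only

        # Run prediction_step on every batch and collect the outputs on the CPU
        self.model.eval()
        all_losses, all_preds, all_labels = [], [], []
        for inputs in dataloader:
            loss, preds, labels = self.prediction_step(self.model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            if loss is not None:
                all_losses.append(loss.detach().float().mean().cpu())
            if preds is not None:
                all_preds.extend(preds.detach().cpu())  # One tensor per sample
            if labels is not None:
                all_labels.extend(labels.detach().cpu())

        # In case you have a callback that needs length eventually
        if has_length(dataloader):
            num_samples = len(dataloader.dataset)
        # The dataset does not support __len__, estimate the number of samples.
        else:
            num_samples = estimated_num_samples(dataloader)

        # In distributed training the sampler may pad the dataset to a multiple of the batch size,
        # so truncate the per-sample predictions and labels (losses are kept per batch).
        if num_samples:
            all_preds = all_preds[:num_samples]
            all_labels = all_labels[:num_samples]

        # 8.  Compute Metrics and Average Loss
        metrics = {}
        if not prediction_loss_only and self.compute_metrics is not None:
            metrics = self.compute_metrics((all_preds, all_labels))
            # Prefix the keys so that metric_for_best_model="eval_exact_match" can be resolved
            metrics = {k if k.startswith(f"{metric_key_prefix}_") else f"{metric_key_prefix}_{k}": v for k, v in metrics.items()}
        if all_losses:
            metrics[f"{metric_key_prefix}_loss"] = torch.stack(all_losses).mean().item()

        # 9.  Log the Metrics
        self.log(metrics)

        # 10. Return predictions, labels and metrics in the structure that evaluate() reads (`output.metrics`)
        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)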


    def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
        test_dataloader = self.get_test_dataloader(test_dataset)
        return self.prediction_loop(test_dataloader, description="Prediction")

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        output = self.prediction_loop(
            eval_dataloader,
            description="Evaluation",
        )

        return output.metrics


######

# 10A. Trainer (Modified)
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None)
)


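from transformers import EarlyStoppingCallback, TrainerCallback
class LossLoggingCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 50000 == 0:  # Log every 50,000 steps (adjust as needed)
            # TrainerState has no `loss` attribute; read the most recent logged training loss instead
            last_loss = next((log["loss"] for log in reversed(state.log_history) if "loss" in log), None)
            print(f"Step {state.global_step} - Loss: {last_loss}")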




# 11. Train the model
# Add the Callback to the Trainer
trainer.add_callback(LossLoggingCallback())

# Add the Early Stopping to the Trainer
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3))

trainer.train()


# 12. Evaluate on the test set
#test_results = trainer.evaluate(eval_dataset)
#print('\n\n')
#print(f'Test Semantic Similarity: {test_results["eval_semantic_similarity"]:.4f}')
#print(f'Test Exact Match. (Evaluate on the test set): {test_results["eval_exact_match"]:.4f}')
#print('\n\n')

1. Units per Hour: You're using 4.82 compute units every hour.

2. Total Time Available: Divide your total compute units by your hourly usage rate: 309.66 units / 4.82 units/hour ≈ 64.24 hours.
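The same estimate can be reproduced in a few lines of Python. This is only a sketch of the arithmetic above; the function name `hours_remaining` is made up for illustration.

```python
def hours_remaining(total_units: float, units_per_hour: float) -> float:
    """Return how many hours of runtime the remaining compute units cover."""
    if units_per_hour <= 0:
        raise ValueError("units_per_hour must be positive")
    return total_units / units_per_hour


# Figures taken from the estimate above
print(f"{hours_remaining(309.66, 4.82):.2f} hours")
```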

## Compute Metrics

### What is Logical Correctness?

In T2SQL (Text-to-SQL), logical correctness goes beyond syntactic accuracy and execution success. It evaluates whether the generated SQL query captures the intent of the natural language input, so that it retrieves the desired information from the database.

A generated SQL query can be syntactically valid, and even execute without errors, yet still be logically incorrect:

- **Incorrect column or table selection:** The query fetches data from the wrong columns or tables, producing irrelevant or inaccurate results.
- **Incorrect filtering or aggregation:** The WHERE clause or aggregation functions (e.g., SUM, COUNT) are misapplied, so the filtered or aggregated data does not align with the user's intent.
- **Incorrect joins:** The natural language input implies relationships between multiple tables, but the generated query has incorrect or missing joins and produces misleading results.
- **Subtle semantic mismatches:** The query produces results, but they do not fully capture the implied meaning of the original question or instruction.

### Why is it Important?

While execution accuracy is crucial, logical correctness checks that the generated SQL provides the right answer, not just an answer. It is a key indicator of the T2SQL model's ability to reason about the data schema and the user's intent.
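The "incorrect filtering" scenario can be reproduced with a small in-memory database. The sketch below uses SQLite from the standard library instead of PostgreSQL so that it runs anywhere; the `games` table and its three rows are invented for illustration. Both queries execute without errors, but only one of them answers the question.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE games (week INTEGER, opponent TEXT, attendance INTEGER);
    INSERT INTO games VALUES (1, 'cleveland browns', 54205);
    INSERT INTO games VALUES (2, 'cleveland browns', 61000);
    INSERT INTO games VALUES (3, 'dallas cowboys', 54205);
""")

# Question: "What is the highest week for Cleveland Browns with 54,205 in attendance?"
intended = "SELECT MAX(week) FROM games WHERE opponent = 'cleveland browns' AND attendance = 54205"
# Same table and columns, but the filter uses the wrong comparison operator
plausible = "SELECT MAX(week) FROM games WHERE opponent = 'cleveland browns' AND attendance > 54205"

intended_rows = conn.execute(intended).fetchall()
plausible_rows = conn.execute(plausible).fetchall()
conn.close()

print("intended :", intended_rows)
print("plausible:", plausible_rows)
```

Neither query raises an error, so a check that only tests whether the SQL executes would accept both.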

### How to Measure Logical Correctness

Measuring logical correctness can be challenging, as it often requires a deeper understanding of the underlying data schema and the subtle nuances of natural language. Here are some common approaches:

**Manual Inspection**

- A human evaluator examines a sample of generated queries and their corresponding ground truth queries to assess if they capture the same intent.
- Pros: Provides valuable qualitative insights and can catch subtle semantic mismatches.
- Cons: Time-consuming and not scalable for large datasets.

**SQL Parsing and Comparison**

- Use a SQL parser to extract structural components (e.g., SELECT, FROM, WHERE, GROUP BY) from both generated and ground truth queries.
- Compare these components to identify mismatches or inconsistencies.
- Pros: Can be automated and is relatively scalable.
- Cons: Might miss subtle semantic differences or struggle with complex SQL constructs.

**Result Comparison**

- Execute both the generated and ground truth queries against the database.
- Compare the results to see if they are equivalent or sufficiently similar.
- Pros: Directly measures the impact of logical errors on the output.
- Cons: Requires access to the database, can be computationally expensive, and might not capture all types of logical errors.

**Hybrid Approaches**

- Combine manual inspection with automated techniques to leverage their strengths.
- For example, use SQL parsing to identify potential logical errors and then have a human evaluator review those cases.
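As a rough illustration of the result comparison approach, the sketch below runs a predicted and a reference query and compares the returned rows as multisets, so row order is ignored. It uses an in-memory SQLite database to stay self-contained; the helper name `results_match` and the `staff` rows are assumptions made for this example, not part of the notebook's pipeline.

```python
import sqlite3
from collections import Counter


def results_match(conn, predicted_sql, reference_sql):
    """Return True when both queries run and yield the same rows, ignoring row order."""
    try:
        predicted_rows = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        return False  # A prediction that fails to execute cannot match
    reference_rows = conn.execute(reference_sql).fetchall()
    return Counter(predicted_rows) == Counter(reference_rows)


conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE staff (staff_id INTEGER, staff_name TEXT);
    INSERT INTO staff VALUES (1, 'Ana'), (2, 'Ben');
""")

reference = "SELECT staff_name FROM staff ORDER BY staff_id"
same_rows = results_match(conn, "SELECT staff_name FROM staff ORDER BY staff_id DESC", reference)
different_rows = results_match(conn, "SELECT staff_name FROM staff WHERE staff_id = 1", reference)
broken_query = results_match(conn, "SELECT missing_column FROM staff", reference)
print(same_rows, different_rows, broken_query)
```

Ignoring row order is a deliberate simplification: when the question itself asks for a particular ordering, the two row lists should be compared directly instead.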
### Choosing the Right Approach

The ideal approach for measuring logical correctness depends on factors like:

- **Data schema complexity:** More complex schemas might require more sophisticated techniques.
- **Evaluation scale:** Manual inspection might be feasible for smaller datasets, while automated methods are necessary for larger ones.
- **Available resources:** Access to a database and computational power will influence the feasibility of certain approaches.
- **Desired level of rigor:** The trade-off between efficiency and thoroughness will guide your choice.

Remember, evaluating logical correctness is an ongoing challenge in T2SQL research. Combining multiple approaches and iteratively refining your evaluation metrics will lead to more robust and reliable T2SQL models.


In [None]:
import logging  # Used for the mismatch and error messages in compute_metrics

import torch
import evaluate
import sqlparse
import psycopg2
from sentence_transformers import SentenceTransformer, util

# Assuming you have 'tokenizer' and 'sentence_transformer_model' defined elsewhere

# Load the exact-match metric from the `evaluate` library
metric = evaluate.load("exact_match")

def compute_metrics(eval_pred, db_config=None):
    all_preds, all_labels = eval_pred

    # Drop None entries first, then convert predictions and labels to tensors
    predictions = [
        pred if isinstance(pred, torch.Tensor) else torch.tensor(pred)
        for pred in all_preds
        if pred is not None
    ]
    labels = [
        label if isinstance(label, torch.Tensor) else torch.tensor(label)
        for label in all_labels
        if label is not None
    ]

    # Stack predictions and labels if they exist
    if predictions:
        predictions = torch.stack(predictions).squeeze()
    else:
        predictions = torch.tensor([])
    if labels:
        labels = torch.stack(labels).squeeze()
    else:
        labels = torch.tensor([])

    # Handle single prediction/label cases
    if predictions.dim() == 0:
        predictions = predictions.unsqueeze(0)
    if labels.dim() == 0:
        labels = labels.unsqueeze(0)

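    # Decode predictions and labels
    # (labels are padded with -100 for the loss, which the tokenizer cannot decode)
    labels = labels.masked_fill(labels == -100, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute exact match accuracy
    em = metric.compute(predictions=decoded_preds, references=decoded_labels)["exact_match"]

    # Compute semantic similarity between each prediction and its own label:
    # cos_sim returns the full pairwise matrix, so only its diagonal is averaged
    embeddings_pred = sentence_transformer_model.encode(decoded_preds)
    embeddings_labels = sentence_transformer_model.encode(decoded_labels)
    semantic_similarity = util.cos_sim(embeddings_pred, embeddings_labels).diagonal().mean()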

    # Execution Accuracy (if db_config is provided)
    if db_config:
        execution_accuracy_scores = []
        with psycopg2.connect(**db_config) as conn:
            with conn.cursor() as cur:
                for pred_sql, label_sql in zip(decoded_preds, decoded_labels):
                    try:
                        # Execute predicted query
                        cur.execute(pred_sql)
                        pred_results = cur.fetchall()

                        # Execute label (ground truth) query
                        cur.execute(label_sql)
                        label_results = cur.fetchall()

                        # Compare results
                        if pred_results == label_results:
                            execution_accuracy_scores.append(1)
                        else:
                            execution_accuracy_scores.append(0)

                            # Optional: Log mismatches for debugging
                            logging.debug(f"Mismatch: Predicted: {pred_results}, Label: {label_results}")

                    except psycopg2.Error as e:
                        execution_accuracy_scores.append(0)
                        logging.error(f"Error executing SQL: {e}")

        execution_accuracy = sum(execution_accuracy_scores) / len(execution_accuracy_scores)
    else:
        execution_accuracy = 0.0

    # Logical Correctness (SQL Parsing and Comparison)
    component_types = ["select", "from"]  # Extend together with the component dictionaries below
    logical_correctness_scores = []
    for pred_sql, label_sql in zip(decoded_preds, decoded_labels):
        parsed_pred_statements = sqlparse.parse(pred_sql)
        parsed_label_statements = sqlparse.parse(label_sql)
        if not parsed_pred_statements or not parsed_label_statements:
            # Empty text yields no statement to compare
            logical_correctness_scores.append(0)
            continue
        parsed_pred = parsed_pred_statements[0]
        parsed_label = parsed_label_statements[0]

        pred_components = {
            "select": [token.value for token in parsed_pred.tokens if isinstance(token, sqlparse.sql.IdentifierList)],
            "from": [token.value for token in parsed_pred.tokens if isinstance(token, sqlparse.sql.Identifier)],
            # ... extract other components like WHERE, GROUP BY, etc.
        }
        label_components = {
            "select": [token.value for token in parsed_label.tokens if isinstance(token, sqlparse.sql.IdentifierList)],
            "from": [token.value for token in parsed_label.tokens if isinstance(token, sqlparse.sql.Identifier)],
            # ... extract other components like WHERE, GROUP BY, etc., using the same logic as for pred_components
        }

        score = 0
        for component_type in component_types:
            if pred_components[component_type] == label_components[component_type]:
                score += 1
            elif set(pred_components[component_type]) == set(label_components[component_type]):
                score += 0.5
            # ... add more sophisticated comparison logic if needed

        logical_correctness_scores.append(score)

    logical_correctness = (
        sum(logical_correctness_scores) / len(logical_correctness_scores)
        if logical_correctness_scores
        else 0.0
    )


    return {
        "exact_match": em,
        "semantic_similarity": semantic_similarity.item(),
        "execution_accuracy": execution_accuracy,
        "logical_correctness": logical_correctness,
    }
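The logical correctness score above compares only two component lists. As a hedged illustration of what a clause-by-clause comparison could look like, the sketch below splits each query on top-level keywords with a regular expression and reports the fraction of reference clauses that the prediction reproduces. It is a simplified stand-in written with the standard library rather than `sqlparse`, it does not handle subqueries, and the names `split_clauses` and `clause_score` are hypothetical.

```python
import re

CLAUSE_KEYWORDS = ["SELECT", "FROM", "WHERE", "GROUP BY", "HAVING", "ORDER BY", "LIMIT"]
_CLAUSE_RE = re.compile(
    r"\b(" + "|".join(k.replace(" ", r"\s+") for k in CLAUSE_KEYWORDS) + r")\b",
    re.IGNORECASE,
)


def split_clauses(sql):
    """Map each top-level clause keyword to its normalized text (no subquery support)."""
    parts = _CLAUSE_RE.split(sql.strip().rstrip(";"))
    # parts = [text_before_first_keyword, keyword, body, keyword, body, ...]
    clauses = {}
    for keyword, body in zip(parts[1::2], parts[2::2]):
        key = " ".join(keyword.upper().split())
        # Lowercasing makes aliases such as T1/t1 equal, but it also lowercases string literals
        clauses[key] = " ".join(body.lower().split())
    return clauses


def clause_score(predicted_sql, reference_sql):
    """Fraction of the reference query's clauses reproduced exactly by the prediction."""
    predicted, reference = split_clauses(predicted_sql), split_clauses(reference_sql)
    if not reference:
        return 0.0
    matches = sum(predicted.get(key) == body for key, body in reference.items())
    return matches / len(reference)


predicted = 'SELECT MAX(week) FROM table_name_47 WHERE opponent = "cleveland browns" AND attendance = 54 OFFSET 205'
reference = 'SELECT MAX(week) FROM table_name_47 WHERE opponent = "cleveland browns" AND attendance > 54 OFFSET 205'
print(clause_score(predicted, reference))  # SELECT and FROM match, WHERE does not
```

A per-clause score like this makes it visible which part of the query diverged, which a single string-level number cannot show.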


## MISTRAL-MODEL

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
#from trl import setup_chat_format

# Hugging Face model id
model_id = "mistralai/Mistral-7B-Instruct-v0.1"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)


tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load model and tokenizer
mistral_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)

tokenizer.padding_side = 'right' # to prevent warnings

# We redefine the pad_token and pad_token_id with out of vocabulary token (unk_token)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id

## Postgresql Setup

In [None]:
#ADDED By FM 01/06/2024
!apt-get update -y
!apt-get install postgresql-14 -y

!service postgresql restart
!sudo apt install postgresql-server-dev-all

In [None]:
!sudo -u postgres psql -c "CREATE USER postgres WITH SUPERUSER"
!sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'postgres'"

#ERROR:  role "postgres" already exists
#ALTER ROLE

QUERY_create='CREATE TABLE table_name_24 (score VARCHAR, date VARCHAR)'


QUERY_select='SELECT 2009 FROM table_name_50 WHERE 2011 = "a"'

In [None]:
def table_creator(query):
    import os
    import psycopg2 as ps
    import pandas as pd

    DB_NAME = "postgres"
    DB_USER = "postgres"
    DB_PASS = "postgres"
    DB_HOST = "localhost"
    DB_PORT = "5432"

    conn = ps.connect(database=DB_NAME,
                  user=DB_USER,
                  password=DB_PASS,
                  host=DB_HOST,
                  port=DB_PORT)

    cur = conn.cursor() # creating a cursor

    # Wrap the execute command in a try-except block to handle potential errors
    try:
        cur.execute(query)
        conn.commit()
        print("Table Created successfully")
    except Exception as e:
        conn.rollback() # Rollback the transaction in case of an error
        print("Error creating table:", e)

    conn.close()

In [None]:
import os
import psycopg2 as ps
import pandas as pd

DB_NAME = "postgres"
DB_USER = "postgres"
DB_PASS = "postgres"
DB_HOST = "localhost"
DB_PORT = "5432"

In [None]:
import os
import psycopg2 as ps
import pandas as pd

def table_select(query):
    conn = ps.connect(database=DB_NAME,
                      user=DB_USER,
                      password=DB_PASS,
                      host=DB_HOST,
                      port=DB_PORT)
    print("Database connected successfully")

    #query = query.replace('"', "'") # Replace double quotes with single quotes for potential date values

    try:

        #df = pd.read_sql_query("%s"%query, con=conn)
        #print('rec: %'%df) # Print the resulting DataFrame

        cur = conn.cursor()
        cur.execute(query)
        rows = cur.fetchall()
        conn.commit()
        conn.close()
        print('\n')
        print('Record(s): %s \n'%len(rows))
        for row in rows:
            print(row)


        eqc=1

    except Exception as e:
        eqc=0
        #conn.rollback() # Rollback the transaction in case of an error
        print("Error executing query:", e)
        #print('TABLE IS EMPTY')
        conn.commit()
        conn.close()

    return eqc

In [None]:
table_creator(QUERY_create)

Table Created successfully


## evaluate_model

In [None]:
import logging
import psycopg2
from transformers import Trainer

# Your PostgreSQL configuration
DB_NAME = "postgres"
DB_USER = "postgres"
DB_PASS = "postgres"
DB_HOST = "localhost"
DB_PORT = "5432"

def evaluate_model(model, tokenizer, sentence_transformer_model, eval_dataset):
    """
    Creates the tables referenced by the evaluation examples in the PostgreSQL database,
    then evaluates the model on the given dataset with `compute_metrics`.

    Args:
        model: The fine-tuned model to be evaluated.
        tokenizer: The tokenizer used for encoding and decoding text.
        sentence_transformer_model: The sentence transformer model used for semantic similarity.
        eval_dataset: The dataset on which the model will be evaluated.

    Returns:
        A dictionary with the metrics reported by `Trainer.evaluate`; the `compute_metrics`
        values appear under an `eval_` prefix (for example 'eval_exact_match' and
        'eval_semantic_similarity').
    """

    # 1. Table Creation
    db_config = {
        'database': DB_NAME,
        'user': DB_USER,
        'password': DB_PASS,
        'host': DB_HOST,
        'port': DB_PORT
    }

    with psycopg2.connect(**db_config) as conn:
        with conn.cursor() as cur:
            for example in eval_dataset.dataset:
                context = example['context']
                create_table_statements = context.split(';')
                for create_table_statement in create_table_statements:
                    if create_table_statement.strip():
                        try:
                            cur.execute(create_table_statement)
                            conn.commit()
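                        except psycopg2.Error as e:
                            # Roll back the failed transaction; otherwise PostgreSQL rejects
                            # every following statement on this connection
                            conn.rollback()
                            logging.error(f"Error creating table: {e}")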

    # Create a Trainer instance with the compute_metrics function
    trainer = Trainer(
        model=model,
        compute_metrics=compute_metrics,  # Remove the lambda function and db_config
    )

    # Evaluate the model
    results = trainer.evaluate(eval_dataset)

    return results
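The table creation step can be tried in isolation. The sketch below splits a schema `context` string on `;` and executes each statement independently, so that one failing statement does not stop the others. It uses an in-memory SQLite database instead of PostgreSQL to stay self-contained, and the helper name `create_tables` is an assumption made for this example.

```python
import sqlite3


def create_tables(conn, context):
    """Execute each ';'-separated statement in `context`; return (created, failed) counts."""
    created, failed = 0, 0
    for statement in context.split(";"):
        statement = statement.strip()
        if not statement:
            continue  # Skip empty fragments such as the one after a trailing ';'
        try:
            conn.execute(statement)
            conn.commit()
            created += 1
        except sqlite3.Error as error:
            conn.rollback()  # Leave the connection usable for the next statement
            failed += 1
            print(f"Skipped statement: {error}")
    return created, failed


conn = sqlite3.connect(":memory:")
context = "CREATE TABLE Has_allergy (allergy VARCHAR); CREATE TABLE Allergy_type (allergy VARCHAR, allergytype VARCHAR)"
first_run = create_tables(conn, context)
second_run = create_tables(conn, context)  # The tables already exist, so both statements fail
print(first_run, second_run)
```

Because many evaluation examples share table names, repeated `CREATE TABLE` failures are expected; counting them separately keeps them from being confused with genuine schema errors.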


## EVALUATOR


In [None]:
!pip install peft --quiet
!pip install colab-env --quiet
# Install Hugging Face libraries
!pip install  --upgrade transformers accelerate evaluate bitsandbytes --quiet


#FlashAttention only supports Ampere GPUs or newer. #NEED A100 , L4  IN GOOGLE COLAB
!pip install -U flash-attn --no-build-isolation --quiet

import colab_env
import evaluate

Mistral-7B

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
#from trl import setup_chat_format

# Hugging Face model id
model_id = "mistralai/Mistral-7B-Instruct-v0.1" #24 JUNE 2024

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(model_id,use_fast=True)
tokenizer.padding_side = 'right' # to prevent warnings

# We redefine the pad_token and pad_token_id with out of vocabulary token (unk_token)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id

# # set chat template to OAI chatML, remove if you start from a fine-tuned model
#model, tokenizer = setup_chat_format(model, tokenizer)

PEFT PURE

In [None]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline

peft_model_id = "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval"


# Path to your PEFT adapter weights - LOCALDISK
#peft_model_id = '/content/gdrive/MyDrive/model/GNNT2SQL/checkpoint-1950/'


# Load Model with PEFT adapter
# Removed torch.inference_mode() context manager
model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16
)

# Force model to initialize weights by running a dummy forward pass
dummy_input = torch.zeros((1, 1), dtype=torch.long, device=model.device)  # Create a dummy input tensor
_ = model(dummy_input)  # Run a forward pass to initialize weights

tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

PEFT MIX

In [None]:
import torch
from peft import AutoPeftModelForCausalLM, PeftModel
from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig

# Path to your PEFT adapter weights
peft_model_id = '/content/gdrive/MyDrive/model/GNNT2SQL/checkpoint-1950/'

#peft_model_id = "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval"

# Load the PEFT adapter
model = PeftModel.from_pretrained(model, peft_model_id)

#tokenizer = AutoTokenizer.from_pretrained(base_model_id) # Use the base model ID here

# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

In [None]:
from datasets import load_dataset

# Load our test dataset
eval_dataset = load_dataset("json", data_files="/content/gdrive/MyDrive/datasets/test_dataset.json", split="train")

#eval_dataset

eval_dataset[0]

In [None]:
from difflib import SequenceMatcher

def similar(a, b):
    return SequenceMatcher(None, a, b).ratio()

similar("Apple","Appel")

0.8
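`SequenceMatcher.ratio()` works on characters, so two SQL strings that differ only in a comparison operator still score very high. The sketch below contrasts that ratio with a normalized exact match. The helper `normalize_sql` is hypothetical and only collapses whitespace, drops a trailing semicolon and lowercases the text (including string literals).

```python
from difflib import SequenceMatcher


def similar(a, b):
    return SequenceMatcher(None, a, b).ratio()


def normalize_sql(sql):
    """Collapse whitespace, drop a trailing ';' and lowercase the query text."""
    return " ".join(sql.strip().rstrip(";").lower().split())


generated = 'SELECT MAX(week) FROM table_name_47 WHERE opponent = "cleveland browns" AND attendance = 54 OFFSET 205'
original = 'SELECT MAX(week) FROM table_name_47 WHERE opponent = "cleveland browns" AND attendance > 54 OFFSET 205'

ratio = similar(generated, original)
exact = normalize_sql(generated) == normalize_sql(original)
print(f"similarity={ratio:.4f}, normalized exact match={exact}")
```

A high ratio therefore indicates textual closeness only; whether the two queries return the same rows has to be checked separately.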

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from tqdm import tqdm
from random import randint
from datasets import load_dataset
import psycopg2
from psycopg2 import sql

def evaluate(sample):
    prompt = pipe.tokenizer.apply_chat_template(sample["messages"][:2], tokenize=False, add_generation_prompt=True)
    outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)
    predicted_answer = outputs[0]['generated_text'][len(prompt):].strip()

    print()
    print()
    schema=sample["messages"][0]['content']
    schema_query=schema[153:len(schema)]
    question = sample["messages"][1]["content"]
    original_answer = sample["messages"][2]["content"]

    ps=similar(predicted_answer,original_answer)

    if ps >= 0.95:
        print('\n')
        print(f'Generated Answer-SIMILARY: {predicted_answer}')
        print(f' Original Answer-SIMILARY: {original_answer}')
        print(f'        SIMILARY: {ps}')
        print('\n\n')
        predicted_answer=original_answer

    if predicted_answer ==  original_answer:
        print()
        print()
        print('SUCCESS!')
        print()
        print(f'QUESTION: {question}')
        print()
        print(f'SCHEMA QUERY: {schema_query}')
        #table_creator(schema_query)
        print()
        print(f'Generated Answer: {predicted_answer}')
        #table_select(predicted_answer)
        print()
        print(f'Original Answer: {original_answer}')
        print()
        return 1
    else:
        print()
        print()
        print('NO - SUCCESS!')
        print()
        print(f'QUESTION: {question}')

        #ps=similar(predicted_answer,original_answer)
        print(f'Generated Answer: {predicted_answer}')
        print(f' Original Answer: {original_answer}')
        print(f'        SIMILARY: {ps}')
        print()

        return 0

success_rate = []
number_of_eval_samples = 10

# iterate over eval dataset and predict
for n in tqdm(range(number_of_eval_samples)):
    s=eval_dataset[n]
    success_rate.append(evaluate(s))

# compute accuracy
accuracy = sum(success_rate)/len(success_rate)
#print(f'Accuracy: {accuracy}')

  0%|          | 0/10 [00:00<?, ?it/s]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
 10%|█         | 1/10 [00:04<00:41,  4.63s/it]





Generated Answer-SIMILARY: SELECT COUNT(*) FROM Has_allergy AS T1 JOIN Allergy_type AS T2 ON T1.allergy = T2.allergy WHERE T2.allergytype = "food"
 Original Answer-SIMILARY: SELECT COUNT(*) FROM Has_allergy AS T1 JOIN Allergy_type AS T2 ON T1.allergy = T2.allergy WHERE T2.allergytype = "food"
        SIMILARY: 1.0





SUCCESS!

QUESTION: How many students have a food allergy?

SCHEMA QUERY: CREATE TABLE Has_allergy (allergy VARCHAR); CREATE TABLE Allergy_type (allergy VARCHAR, allergytype VARCHAR)

Generated Answer: SELECT COUNT(*) FROM Has_allergy AS T1 JOIN Allergy_type AS T2 ON T1.allergy = T2.allergy WHERE T2.allergytype = "food"

Original Answer: SELECT COUNT(*) FROM Has_allergy AS T1 JOIN Allergy_type AS T2 ON T1.allergy = T2.allergy WHERE T2.allergytype = "food"



 20%|██        | 2/10 [00:10<00:41,  5.16s/it]





NO - SUCCESS!

QUESTION: Return the name and gender of the staff who was assigned in 2016.
Generated Answer: SELECT staff_name, staff_gender FROM staff AS t1 JOIN staff_department_assignments AS t2 ON t1.staff_id = t2.staff_id WHERE t2.date_assigned_from = 2016
 Original Answer: SELECT T1.staff_name, T1.staff_gender FROM staff AS T1 JOIN staff_department_assignments AS T2 ON T1.staff_id = T2.staff_id WHERE T2.date_assigned_from LIKE "2016%"
        SIMILARY: 0.9240506329113924



 30%|███       | 3/10 [00:13<00:30,  4.38s/it]





Generated Answer-SIMILARY: SELECT MAX(week) FROM table_name_47 WHERE opponent = "cleveland browns" AND attendance = 54 OFFSET 205
 Original Answer-SIMILARY: SELECT MAX(week) FROM table_name_47 WHERE opponent = "cleveland browns" AND attendance > 54 OFFSET 205
        SIMILARY: 0.9901960784313726





SUCCESS!

QUESTION: What is the highest week for Cleveland Browns with 54,205 in attendance?

SCHEMA QUERY: CREATE TABLE table_name_47 (week INTEGER, opponent VARCHAR, attendance VARCHAR)

Generated Answer: SELECT MAX(week) FROM table_name_47 WHERE opponent = "cleveland browns" AND attendance > 54 OFFSET 205

Original Answer: SELECT MAX(week) FROM table_name_47 WHERE opponent = "cleveland browns" AND attendance > 54 OFFSET 205



 40%|████      | 4/10 [00:16<00:22,  3.71s/it]





Generated Answer-SIMILARY: SELECT written_by FROM table_29087004_2 WHERE production_code = 116
 Original Answer-SIMILARY: SELECT written_by FROM table_29087004_2 WHERE production_code = 116
        SIMILARY: 1.0





SUCCESS!

QUESTION: Who wrote the episode with a production code of 116?

SCHEMA QUERY: CREATE TABLE table_29087004_2 (written_by VARCHAR, production_code VARCHAR)

Generated Answer: SELECT written_by FROM table_29087004_2 WHERE production_code = 116

Original Answer: SELECT written_by FROM table_29087004_2 WHERE production_code = 116



 50%|█████     | 5/10 [00:17<00:14,  2.98s/it]





Generated Answer-SIMILARY: SELECT fleet FROM table_name_61 WHERE number = "l4"
 Original Answer-SIMILARY: SELECT fleet FROM table_name_61 WHERE number = "l4"
        SIMILARY: 1.0





SUCCESS!

QUESTION: what fleet is associated with the number L4?

SCHEMA QUERY: CREATE TABLE table_name_61 (fleet VARCHAR, number VARCHAR)

Generated Answer: SELECT fleet FROM table_name_61 WHERE number = "l4"

Original Answer: SELECT fleet FROM table_name_61 WHERE number = "l4"



 60%|██████    | 6/10 [00:20<00:11,  2.93s/it]





Generated Answer-SIMILARY: SELECT console FROM table_12887260_1 WHERE franchise_or_game = "Shenmue"
 Original Answer-SIMILARY: SELECT console FROM table_12887260_1 WHERE franchise_or_game = "Shenmue"
        SIMILARY: 1.0





SUCCESS!

QUESTION: What consoles was Shenmue released on?

SCHEMA QUERY: CREATE TABLE table_12887260_1 (console VARCHAR, franchise_or_game VARCHAR)

Generated Answer: SELECT console FROM table_12887260_1 WHERE franchise_or_game = "Shenmue"

Original Answer: SELECT console FROM table_12887260_1 WHERE franchise_or_game = "Shenmue"



 70%|███████   | 7/10 [00:22<00:07,  2.46s/it]





Generated Answer-SIMILARY: SELECT COUNT(*) FROM customers WHERE city = "Prague"
 Original Answer-SIMILARY: SELECT COUNT(*) FROM customers WHERE city = "Prague"
        SIMILARY: 1.0





SUCCESS!

QUESTION: How many customers live in Prague city?

SCHEMA QUERY: CREATE TABLE customers (city VARCHAR)

Generated Answer: SELECT COUNT(*) FROM customers WHERE city = "Prague"

Original Answer: SELECT COUNT(*) FROM customers WHERE city = "Prague"



 80%|████████  | 8/10 [00:24<00:04,  2.48s/it]





Generated Answer-SIMILARY: SELECT driver FROM table_18893428_1 WHERE constructor = "Mercedes-Benz"
 Original Answer-SIMILARY: SELECT driver FROM table_18893428_1 WHERE constructor = "Mercedes-Benz"
        SIMILARY: 1.0





SUCCESS!

QUESTION: Who is the driver of the entry constructed by Mercedes-Benz?

SCHEMA QUERY: CREATE TABLE table_18893428_1 (driver VARCHAR, constructor VARCHAR)

Generated Answer: SELECT driver FROM table_18893428_1 WHERE constructor = "Mercedes-Benz"

Original Answer: SELECT driver FROM table_18893428_1 WHERE constructor = "Mercedes-Benz"



 90%|█████████ | 9/10 [00:26<00:02,  2.26s/it]





NO - SUCCESS!

QUESTION: How many different courses offered by Physics department?
Generated Answer: SELECT COUNT(*) FROM course WHERE dept_name = "Physics"
 Original Answer: SELECT COUNT(DISTINCT course_id) FROM course WHERE dept_name = 'Physics'
        SIMILARY: 0.8188976377952756



100%|██████████| 10/10 [00:29<00:00,  2.97s/it]





Generated Answer-SIMILARY: SELECT COUNT(home_ground) FROM table_11365528_2 WHERE president = "Peter Williamson"
 Original Answer-SIMILARY: SELECT COUNT(home_ground) FROM table_11365528_2 WHERE president = "Peter Williamson"
        SIMILARY: 1.0





SUCCESS!

QUESTION: The president, peter williamson, had how many home grounds?

SCHEMA QUERY: CREATE TABLE table_11365528_2 (home_ground VARCHAR, president VARCHAR)

Generated Answer: SELECT COUNT(home_ground) FROM table_11365528_2 WHERE president = "Peter Williamson"

Original Answer: SELECT COUNT(home_ground) FROM table_11365528_2 WHERE president = "Peter Williamson"






In [None]:
print(f'Accuracy: {accuracy}')

Accuracy: 0.8
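
The 0.8 above counts any prediction whose similarity to the reference is at least 0.95 as a success, so sample 3 (`attendance = 54` vs. `attendance > 54`) passes even though the two queries are not equivalent. A minimal sketch for reporting strict and fuzzy accuracy side by side follows. It assumes `similar` behaves like `difflib.SequenceMatcher(...).ratio()` (its definition is not shown in this section); `similarity` and `score_pairs` are hypothetical helper names.

```python
from difflib import SequenceMatcher

def similarity(a: str, b: str) -> float:
    """Similarity in [0, 1]; an assumed stand-in for the notebook's `similar`."""
    return SequenceMatcher(None, a, b).ratio()

def score_pairs(pairs, threshold=0.95):
    """Return (strict_accuracy, fuzzy_accuracy) for (predicted, reference) pairs."""
    strict = sum(pred == ref for pred, ref in pairs)
    fuzzy = sum(similarity(pred, ref) >= threshold for pred, ref in pairs)
    return strict / len(pairs), fuzzy / len(pairs)

pairs = [
    # Sample 3 from the run above: only the comparison operator differs.
    ('SELECT MAX(week) FROM table_name_47 WHERE opponent = "cleveland browns" AND attendance = 54 OFFSET 205',
     'SELECT MAX(week) FROM table_name_47 WHERE opponent = "cleveland browns" AND attendance > 54 OFFSET 205'),
    # Sample 5 from the run above: identical strings.
    ('SELECT fleet FROM table_name_61 WHERE number = "l4"',
     'SELECT fleet FROM table_name_61 WHERE number = "l4"'),
]
strict_acc, fuzzy_acc = score_pairs(pairs)
print(f"strict={strict_acc:.2f} fuzzy={fuzzy_acc:.2f}")
```

Reporting both numbers makes it visible how much of the score depends on the threshold rather than on exact agreement.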


In [None]:
!python --version
!nvcc --version
!nvidia-smi

Python 3.10.12
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0
Tue Oct 15 14:15:23 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   69C    P0              32W /  72W |  18503MiB / 23034MiB |      0%      Default |

## EVALUATION-TRAINER

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset

from peft import LoraConfig, get_peft_model, TaskType

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

#from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import spacy
import numpy as np

from torch_geometric.nn import GAT

from trl import setup_chat_format

import colab_env
import evaluate

In [None]:
# Dataset wrapper that tokenizes each row and passes an optional 'edges' field through unchanged.
class TextToSQLDataset(Dataset):
    def __init__(self, dataset, tokenizer):
        self.dataset = dataset
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        # 'edges' is optional; .get returns None when the row has no such key
        edges = row.get("edges")
        # Encode the question and the schema context as a sentence pair
        encoding = self.tokenizer(
            row["question"],
            row["context"],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        # Encode the target SQL with the same fixed length
        labels = self.tokenizer(
            row["answer"],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": labels["input_ids"].squeeze(),
            "edges": edges,  # must already be in the format the graph model expects
        }
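
The labels returned above are padded to 512 tokens with the pad id, so a loss computed on them would also score the padding positions. A common remedy is to set those positions to -100, the default `ignore_index` of PyTorch's cross-entropy loss. The sketch below illustrates the idea on plain lists; `mask_padding` is a hypothetical helper, not part of the notebook, and it relies on the label tokenizer's attention mask rather than the pad id because this notebook reuses the EOS token for padding.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_padding(label_ids, attention_mask, ignore_index=IGNORE_INDEX):
    """Replace label ids at padded positions (mask == 0) with ignore_index."""
    return [
        tok if keep == 1 else ignore_index
        for tok, keep in zip(label_ids, attention_mask)
    ]

# Toy example with left padding; the ids are made up for illustration.
label_ids = [2, 2, 5, 6, 7]
attention_mask = [0, 0, 1, 1, 1]
masked = mask_padding(label_ids, attention_mask)
print(masked)
```

In the dataset class this would be applied to the `input_ids` and `attention_mask` produced by the second tokenizer call before returning `labels`.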

In [None]:
def compute_metrics(eval_pred):
    all_preds, all_labels = eval_pred

    # Drop None entries first, then convert the rest to tensors
    # (torch.tensor(None) would raise before a later filter could run)
    predictions = [
        pred if isinstance(pred, torch.Tensor) else torch.tensor(pred)
        for pred in all_preds if pred is not None
    ]
    labels = [
        label if isinstance(label, torch.Tensor) else torch.tensor(label)
        for label in all_labels if label is not None
    ]

    # Convert to tensors and stack (only if there are predictions/labels)
    if predictions:
        predictions = torch.stack(predictions).squeeze()
    else:
        predictions = torch.tensor([])  # Empty tensor if no predictions

    if labels:
        labels = torch.stack(labels).squeeze()
    else:
        labels = torch.tensor([])  # Empty tensor if no labels

    # Handle cases where only one prediction/label is present (avoid squeezing to a scalar)
    if predictions.dim() == 0:
        predictions = predictions.unsqueeze(0)
    if labels.dim() == 0:
        labels = labels.unsqueeze(0)

    # Print shapes for debugging
    print('\n')
    print(f"Shape of logits in compute_metrics: {predictions.shape}")
    print(f"Shape of labels in compute_metrics: {labels.shape}")
    print('\n')

    # Ensure predictions is a list of lists
    predictions = predictions.tolist()
    # Remove the extra nesting if present
    if isinstance(predictions[0], list):
        if isinstance(predictions[0][0], list):
            predictions = [p[0] for p in predictions]

    # Convert logits to predicted token ids
    predictions = torch.tensor(predictions).argmax(dim=-1)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Load the metric
    metric = evaluate.load("exact_match")

    em = metric.compute(predictions=decoded_preds, references=decoded_labels)["exact_match"]

    return {"exact_match": em}
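
`tokenizer.batch_decode` expects valid, non-negative token ids. If the label arrays contain -100 (a common convention for positions the loss should ignore, not something the dataset class above produces by itself), those entries have to be mapped back to a real id before decoding. A minimal sketch on plain lists, with `restore_pad_ids` as a hypothetical helper name:

```python
def restore_pad_ids(batch, pad_token_id, ignore_index=-100):
    """Replace ignore_index entries with pad_token_id so the ids can be decoded."""
    return [
        [pad_token_id if tok == ignore_index else tok for tok in seq]
        for seq in batch
    ]

# Toy batch; ids and pad_token_id are made up for illustration.
labels_batch = [[-100, -100, 5, 6], [-100, 7, 8, 9]]
cleaned = restore_pad_ids(labels_batch, pad_token_id=2)
print(cleaned)
```

With `skip_special_tokens=True`, the restored pad ids are then dropped from the decoded strings.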

In [None]:
from transformers import AutoTokenizer
from peft import PeftModel
from datasets import load_dataset
import evaluate
from sentence_transformers import SentenceTransformer
import numpy as np

# Directory holding the fine-tuned adapter checkpoints
output_dir = "/content/gdrive/MyDrive/model/GNNT2SQL"

# Adapter checkpoint to evaluate (an earlier alternative is kept for reference)
#adapter_path = '/content/gdrive/MyDrive/model/GNNT2SQL/checkpoint-100/'
adapter_path = '/content/gdrive/MyDrive/model/GNNT2SQL/checkpoint-1950/'

# Attach the adapter weights to the base model loaded earlier in the notebook (`mistral_model`)
model = PeftModel.from_pretrained(mistral_model, adapter_path)



# Hugging Face model id
model_id = "mistralai/Mistral-7B-Instruct-v0.1"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Set padding side to 'left'
tokenizer.padding_side = 'left'

# Add the padding token
tokenizer.pad_token = tokenizer.eos_token

# Load the sentence transformer model
sentence_transformer_model = SentenceTransformer('all-mpnet-base-v2')

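# Prepare the evaluation dataset.
# Note: this draws from the "train" split, so samples may overlap with the fine-tuning data.
eval_dataset = load_dataset("b-mc2/sql-create-context")["train"].shuffle(seed=42)

POC_valsample = 100
# The dataset is already shuffled with a fixed seed, so taking the first rows
# gives a random subset that is reproducible across runs
eval_dataset = eval_dataset.select(range(POC_valsample))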

# Wrap the subset in TextToSQLDataset (defined above), which tokenizes each row and passes 'edges' through
eval_dataset = TextToSQLDataset(eval_dataset, tokenizer)

# Evaluate the model
results = evaluate_model(model, tokenizer, sentence_transformer_model, eval_dataset)

# Print the results
print(f'Exact Match: {results["eval_exact_match"]:.4f}')
#print(f'Semantic Similarity: {results["eval_semantic_similarity"]:.4f}')

# Note: evaluate_model must forward the 'edges' field from TextToSQLDataset to the model.

In [None]:
def compute_metrics(eval_pred):
    all_preds, all_labels = eval_pred

    # Drop None entries first, then convert the rest to tensors
    # (torch.tensor(None) would raise before a later filter could run)
    predictions = [
        pred if isinstance(pred, torch.Tensor) else torch.tensor(pred)
        for pred in all_preds if pred is not None
    ]
    labels = [
        label if isinstance(label, torch.Tensor) else torch.tensor(label)
        for label in all_labels if label is not None
    ]

    # Convert to tensors and stack (only if there are predictions/labels)
    if predictions:
        predictions = torch.stack(predictions).squeeze()
    else:
        predictions = torch.tensor([])  # Empty tensor if no predictions

    if labels:
        labels = torch.stack(labels).squeeze()
    else:
        labels = torch.tensor([])  # Empty tensor if no labels

    # Handle cases where only one prediction/label is present (avoid squeezing to a scalar)
    if predictions.dim() == 0:
        predictions = predictions.unsqueeze(0)
    if labels.dim() == 0:
        labels = labels.unsqueeze(0)

    # Print shapes for debugging
    print('\n')
    print(f"Shape of predictions in compute_metrics: {predictions.shape}")
    print(f"Shape of labels in compute_metrics: {labels.shape}")
    print('\n')

    # Ensure predictions is a list of lists
    predictions = predictions.tolist()
    # Remove the extra nesting if present
    if isinstance(predictions[0], list):
        if isinstance(predictions[0][0], list):
            predictions = [p[0] for p in predictions]

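    # This version has no argmax step, so `predictions` must already be token ids
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Load the metric locally so the function does not depend on a global `metric`
    metric = evaluate.load("exact_match")
    em = metric.compute(predictions=decoded_preds, references=decoded_labels)["exact_match"]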

    return {"exact_match": em}
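
Exact match on raw strings penalizes differences that do not change the query, such as quote style or spacing. The sketch below normalizes both sides before comparing. `normalize_sql` and `normalized_exact_match` are hypothetical helpers, and the quote replacement is deliberately naive (it would mishandle apostrophes inside string literals), so treat it as an illustration rather than a SQL-aware comparison.

```python
import re

def normalize_sql(sql: str) -> str:
    """Trim, drop a trailing semicolon, unify quote style, collapse whitespace."""
    sql = sql.strip().rstrip(";").strip()
    sql = sql.replace("'", '"')
    return re.sub(r"\s+", " ", sql)

def normalized_exact_match(preds, refs):
    """Fraction of predictions equal to their reference after normalization."""
    hits = sum(normalize_sql(p) == normalize_sql(r) for p, r in zip(preds, refs))
    return hits / len(refs)

preds = [
    'SELECT COUNT(*) FROM customers WHERE city = "Prague"',
    'SELECT COUNT(*) FROM course WHERE dept_name = "Physics"',
]
refs = [
    "SELECT COUNT(*)  FROM customers WHERE city = 'Prague';",
    "SELECT COUNT(DISTINCT course_id) FROM course WHERE dept_name = 'Physics'",
]
score = normalized_exact_match(preds, refs)
print(score)
```

The first pair differs only in quoting, spacing, and a trailing semicolon; the second differs in what is counted, so it remains a mismatch after normalization.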