<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/NEWSOLUTION_FINE_TUNING_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Libraries Installation

In [None]:
!pip install datasets networkx sqlparse -q
!pip install torch torchvision tensorboard -q
!pip install torch-geometric -q

# Install Hugging Face libraries
!pip install --upgrade transformers accelerate evaluate bitsandbytes -q
!pip install peft trl==0.8.6 ninja packaging -q

# FlashAttention only supports Ampere or newer GPUs (e.g. A100 or L4 in Google Colab)
!pip install -U flash-attn --no-build-isolation -q

!pip install diffusers safetensors colab-env mistral_inference sentence-transformers -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset

from peft import LoraConfig, get_peft_model, TaskType

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

#from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import spacy
import numpy as np

from torch_geometric.nn import GAT

from trl import setup_chat_format

import colab_env
import evaluate

Mounted at /content/gdrive


## Original Code

https://github.com/frank-morales2020/MLxDL/blob/main/FineTuning_LLM_Mistral_7B_Instruct_v0_1_for_text_to_SQL_EVALDATA.ipynb

* Import Main Components

* https://stackoverflow.com/questions/70950706/assertionerror-in-torch-geometric-nn-gatconv

The dataflow in the code below, broken down step by step:

1. **Data Input and Preparation:**
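   - The process begins by loading the `b-mc2/sql-create-context` dataset. Each row contains a natural-language `question`, the `CREATE TABLE` statements that give the schema `context`, and the SQL `answer`. The code below uses only `question` and `answer`.
   - After shuffling (seed 42), the dataset is split 70/15/15 into training, validation, and test subsets.
   - The Mistral-7B-Instruct-v0.3 model is loaded with 4-bit NF4 quantization (bitsandbytes) and FlashAttention 2. Its tokenizer is configured to pad with the EOS token on the left.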

2. **Data Transformation with `TextToSQLDataset`:**
   - The `TextToSQLDataset` class is responsible for converting raw data into a format suitable for model training and evaluation.
   - For each data sample, the following transformations occur:
      - Tokenization: The question (input) and the SQL query (answer) are tokenized separately and padded or truncated to 1024 tokens.
      - Dependency Parsing: spaCy parses the question, and each word-to-head relation becomes an edge of a dependency graph. The edge indices are spaCy word positions. They are used directly as positions in the left-padded subword token sequence, so they do not point at the same tokens.
      - Dictionary Creation: A dictionary stores the input IDs, the attention mask, the labels (the tokenized SQL query), the dependency edges as an `edge_index` tensor, and the sample ID.

3. **Batching and Shuffling with `DataLoader`:**
   - The `Trainer` wraps each dataset in a PyTorch `DataLoader`. It uses the custom `GraphDataCollatorForSeq2Seq` to stack input IDs, attention masks, labels, and sample IDs into batches, while keeping each sample's edge tensor in a list.
   - The training set is reshuffled at the start of every epoch.

4. **Forward Pass through `GraphModel`:**
   - **Mistral Encoder:** The input IDs and attention mask are fed through the Mistral model, a decoder-only transformer used here as an encoder. Its last hidden layer provides the contextualized token embeddings.
   - **GATv2 Layer:** The GATv2 layer (Graph Attention Network) takes the token embeddings and the dependency edges as input. It applies graph attention to incorporate the structural information from the dependency graph into the token representations.
   - **Pooling:** The node representations (the output of GATv2) are max-pooled into one fixed-size vector for the input sequence.
   - **LM Head:** The pooled vector passes through a two-layer feed-forward network and then a linear LM head, which produces logits (unnormalized log-probabilities) over the vocabulary. Because pooling collapses the sequence to a single vector, the head yields one next-token distribution per sample rather than a full SQL sequence.
   - **Loss Calculation:** During training, the cross-entropy loss is computed between the logits and the first label position (`labels[:, 0]`). Targets are left-padded to 1024 tokens with the EOS token as padding, so that position almost always holds a padding token rather than part of the SQL query.

5. **Model Optimization with PEFT (LoRA):**
   - The `GraphModel` is wrapped with PEFT's LoRA (Low-Rank Adaptation) configuration to enable parameter-efficient fine-tuning.
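   - PEFT freezes every parameter outside the LoRA adapters. The config lists GAT sub-modules and `lm_head` as targets. However, the names `gat.gat.convs.*` do not match the `gatv2` attribute of `GraphModel`, so only the LoRA adapters on layers named `lm_head` are trained. The GATv2 and feed-forward weights stay frozen at their initial values. A later call, `model.encoder.model.embed_tokens.requires_grad_(True)`, also unfreezes Mistral's input embeddings.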
   - The Hugging Face Trainer manages the training process, iterating over the dataset, computing gradients, and updating the model's parameters based on the calculated loss.

6. **Evaluation:**
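   - The model is evaluated at the end of every epoch (`eval_strategy="epoch"`) on the validation set. In this proof of concept, the validation and test sets are each reduced to a single sample.
   - The `compute_metrics` function decodes the predictions and labels back into text and scores them:
      - Exact Match: This checks whether the predicted SQL query matches the reference SQL query exactly (`evaluate.load("exact_match")`). It is the only metric the current `compute_metrics` returns, and it is the `metric_for_best_model` (`eval_exact_match`).
      - Semantic Similarity: This is intended to measure how semantically close the predicted SQL query is to the reference, using SentenceTransformer (`all-mpnet-base-v2`) embeddings. The model is loaded, but `compute_metrics` does not compute this score yet.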

7. **Output:**
   - The final output is the evaluation results (currently the exact-match score), which indicate how accurately the model generates SQL queries from natural-language questions.
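Step 4 merges the per-sample dependency graphs into one disjoint graph before the GATv2 layer. The pooling step then reduces it back to one vector per sample. The sketch below illustrates both steps on toy tensors. `batch_edge_indices` and `per_graph_max_pool` are hypothetical helper names; PyTorch Geometric's `global_max_pool(x, batch)` performs the same per-graph pooling. Like `GraphModel.forward`, it offsets each sample's node ids by the node count of the samples before it.

```python
import torch


def batch_edge_indices(edge_lists, num_nodes_per_graph):
    """Merge per-sample edge_index tensors (each of shape [2, E_i]) into one disjoint graph.

    Node ids of sample i are shifted by the number of nodes in samples 0..i-1,
    so no edge connects tokens from different samples.
    """
    shifted, offset = [], 0
    for edge_index, num_nodes in zip(edge_lists, num_nodes_per_graph):
        shifted.append(edge_index + offset)  # out-of-place: the input tensors are not modified
        offset += num_nodes
    return torch.cat(shifted, dim=1)


def per_graph_max_pool(x, batch, num_graphs):
    """Max-pool node features x [N, F] into one row per graph, using batch [N] as graph ids."""
    return torch.stack([x[batch == g].max(dim=0).values for g in range(num_graphs)])


# Toy batch: two "sentences" with 3 and 2 nodes (tokens) and 4-dim node embeddings
edges_a = torch.tensor([[0, 2], [1, 1]])  # edges 0 -> 1 and 2 -> 1
edges_b = torch.tensor([[1], [0]])        # edge 1 -> 0
edge_index = batch_edge_indices([edges_a, edges_b], [3, 2])

x = torch.arange(20, dtype=torch.float32).view(5, 4)
batch = torch.tensor([0, 0, 0, 1, 1])     # which sample each node belongs to
pooled = per_graph_max_pool(x, batch, num_graphs=2)

print(edge_index.tolist())  # [[0, 2, 4], [1, 1, 3]]
print(pooled.tolist())      # [[8.0, 9.0, 10.0, 11.0], [16.0, 17.0, 18.0, 19.0]]
```

With a batch size of 1 (as in the `TrainingArguments` below), pooling over all nodes and pooling per graph give the same result. They differ only once several samples share a batch.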



The same dataflow as a diagram. Note that the arrow from the evaluation step points *into* the `Trainer`:

```
+-------------------+          +----------------------+
| Dataset Loading   |          | TextToSQLDataset     |
|  - sql-create...  | -------> | - Tokenization       |
|  - Split: train...|          | - Dependency Parsing |
+-------------------+          | - Dict Creation      |
                               +----------------------+
                                         |
                                         v
                         +---------------------------+
                         | DataLoader                |
                         | - Batches, Shuffling (opt)|
                         +---------------------------+
                                         |
                                         v
                         +---------------------------+
                         | GraphModel                |
                         | - Mistral Encoder         |
                         | - GATv2 Layer             |
                         | - Pooling                 |
                         | - LM Head                 |
                         | - (Loss Calculation)      |
                         +---------------------------+
                                         |
                                         v
            +--------------+          +---------+         +-----------------+
            | PEFT (LoRA)  | -------> | Trainer | <------ |Evaluation       |
            +--------------+          |         |         |(compute_metrics)|
                                      +---------+         +-----------------+
                                         |
                                         v
                                   +-----------------------+
                                   | - Semantic Similarity |
                                   | - Exact Match         |
                                   +-----------------------+
```

In this diagram, the metric returned by `compute_metrics` flows back into the `Trainer`. With `metric_for_best_model="eval_exact_match"` and `load_best_model_at_end=True`, the `Trainer` uses that metric to select the best checkpoint. The same metric can drive early stopping if an `EarlyStoppingCallback` is passed to the `Trainer`.
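The following is a simplified, hypothetical sketch of patience-based early stopping on a higher-is-better metric such as `eval_exact_match`. It follows the general idea behind `EarlyStoppingCallback` and its `early_stopping_patience` and `early_stopping_threshold` arguments, but it is not the Transformers implementation. The metric history is made up for illustration.

```python
def early_stopping_summary(history, metric="eval_exact_match", patience=2, threshold=0.0):
    """Hypothetical helper: scan per-epoch eval logs and report (best_epoch, stop_epoch).

    A new value counts as an improvement only if it beats the best so far by more
    than `threshold` (higher is better). After `patience` evaluations in a row
    without improvement, training would stop; stop_epoch is None if that never happens.
    """
    best_value, best_epoch, evals_without_improvement = None, None, 0
    for epoch, logs in enumerate(history, start=1):
        value = logs[metric]
        if best_value is None or value - best_value > threshold:
            best_value, best_epoch, evals_without_improvement = value, epoch, 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return best_epoch, epoch
    return best_epoch, None


# Made-up evaluation history, one entry per epoch (eval_strategy="epoch")
history = [
    {"eval_exact_match": 0.10},
    {"eval_exact_match": 0.25},
    {"eval_exact_match": 0.24},
    {"eval_exact_match": 0.25},
    {"eval_exact_match": 0.30},
]
print(early_stopping_summary(history, patience=2))  # (2, 4)
```

In this made-up history, training would stop after epoch 4, even though epoch 5 would have improved. A larger patience avoids that at the cost of extra epochs.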




**PEFT (Parameter-Efficient Fine-Tuning)**

The output of `model.print_trainable_parameters()` reports the parameter counts of the PEFT model. Let's break it down:

* **trainable params: 589,824** - These are the LoRA adapter weights that are updated during fine-tuning. The pre-trained weights themselves are not modified; the adapters learn a small low-rank update on top of them.
* **all params: 7,718,522,880** - This is the total number of parameters in the model, trainable and frozen. Besides the 4-bit Mistral weights, it includes the `GraphModel`'s own GATv2, feed-forward, and LM-head layers.
* **trainable%: 0.0076** - This is the share of trainable parameters (589,824 / 7,718,522,880 ≈ 0.0076%), so only a tiny fraction of the model is being fine-tuned.
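These numbers can be reproduced with LoRA arithmetic. A LoRA adapter on a linear layer with input size d_in and output size d_out adds matrices A (r × d_in) and B (d_out × r), which is r(d_in + d_out) parameters. The sketch below rests on three assumptions:

* Mistral-7B-Instruct-v0.3 has a hidden size of 4096 and a vocabulary of 32768.
* The code uses `r=8` and `heads=8`.
* PyG's `GATv2Conv` uses its default biases.

Under those assumptions, 589,824 is exactly two adapters on 4096 → 32768 heads. That is consistent with the target name `lm_head` matching both `GraphModel.lm_head` and the Mistral model's own `encoder.lm_head`. This is an inference from the counts, not something the notebook prints.

```python
def linear_params(d_in, d_out, bias=True):
    """Parameters of a d_in -> d_out linear layer: a d_out x d_in weight plus an optional bias."""
    return d_in * d_out + (d_out if bias else 0)


def lora_params(d_in, d_out, r):
    """Parameters LoRA adds to one d_in -> d_out layer: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


# Assumed sizes: Mistral-7B-Instruct-v0.3 hidden size and vocabulary; r and heads from the code
HIDDEN, VOCAB, RANK, HEADS = 4096, 32768, 8, 8

# Assumption: "lm_head" matches two 4096 -> 32768 layers (GraphModel.lm_head and encoder.lm_head)
trainable = 2 * lora_params(HIDDEN, VOCAB, RANK)

# GraphModel's own layers, using the shapes from GraphModel.__init__
gatv2 = (2 * linear_params(HIDDEN, HEADS * HIDDEN)  # lin_l and lin_r (output size heads * hidden)
         + HEADS * HIDDEN                            # attention vector `att`
         + HIDDEN)                                   # output bias (concat=False)
ffn = linear_params(HIDDEN, 2 * HIDDEN) + linear_params(2 * HIDDEN, HIDDEN)
head = linear_params(HIDDEN, VOCAB)

all_params = 7_718_522_880  # "all params" from print_trainable_parameters()
base_model = all_params - trainable - (gatv2 + ffn + head)
trainable_pct = 100 * trainable / all_params

print(f"trainable={trainable:,} graph_layers={gatv2 + ffn + head:,} base={base_model:,}")
print(f"trainable%={trainable_pct:.4f}")
```

Under these assumptions, the GraphModel layers account for about 0.47 billion of the total, all frozen after `get_peft_model`. That leaves about 7.25 billion for the base Mistral model.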

**Key Takeaway**

This PEFT model updates only a small portion of its parameters, which offers several benefits:

* **Reduced computational resources** - Gradients and optimizer states are needed only for the trainable parameters, which lowers GPU memory requirements.
* **Faster training** - Each step is usually somewhat cheaper, although the forward and backward passes still run through the full frozen base model.
* **Mitigates overfitting** - Updating fewer parameters can reduce the risk of overfitting to the new task, which may improve generalization to unseen data.

Overall, this PEFT model exemplifies a strategy for efficiently adapting large pre-trained models to new tasks while minimizing resource requirements and training time.
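The trainable and total counts come from summing `numel()` over the model's parameters, split by `requires_grad`. The minimal sketch below shows that bookkeeping on a hypothetical toy LoRA-style layer: a frozen base weight plus trainable low-rank factors `A` and `B`. The `lora_alpha / r` scaling that real LoRA applies is omitted. Only the adapter weights require gradients, so only they need optimizer state.

```python
import torch
import torch.nn as nn


def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts, split by requires_grad."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


class ToyLoRALinear(nn.Module):
    """Hypothetical toy layer: a frozen d_in -> d_out linear plus a rank-r update B(A(x))."""

    def __init__(self, d_in: int, d_out: int, r: int):
        super().__init__()
        self.base = nn.Linear(d_in, d_out)
        self.base.requires_grad_(False)                # pre-trained weights stay frozen
        self.lora_A = nn.Linear(d_in, r, bias=False)   # trainable, r x d_in
        self.lora_B = nn.Linear(r, d_out, bias=False)  # trainable, d_out x r
        nn.init.zeros_(self.lora_B.weight)             # the update starts at zero, as in LoRA

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_B(self.lora_A(x))


layer = ToyLoRALinear(d_in=64, d_out=256, r=8)
trainable, total = count_parameters(layer)
print(trainable, total, f"{100 * trainable / total:.2f}%")
```

Because `lora_B` starts at zero, the wrapped layer initially behaves exactly like the frozen base layer. Training then moves only the 2,560 adapter weights.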


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    BitsAndBytesConfig, Trainer, TrainingArguments,
    set_seed,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq
)

from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    PeftModel,
    PeftConfig,
    PrefixTuningConfig,
    PromptEncoderConfig,
)


from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from torch_geometric.nn import GAT
import spacy
import numpy as np
import torch.nn as nn
import evaluate


# Suppress warnings
import warnings
warnings.simplefilter('ignore')

warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support.")


# Load spaCy English model
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Download if not already downloaded
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

# 1. Load and Prepare Data
dataset = load_dataset("b-mc2/sql-create-context")["train"].shuffle(seed=42)

# Manually define splits
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset = dataset.select(range(train_size))
val_dataset = dataset.select(range(train_size, train_size + val_size))
test_dataset = dataset.select(range(train_size + val_size, len(dataset)))

# Optionally, load augmented data (if you have it)
# train_dataset = load_dataset("json", data_files="your_augmented_dataset.json", split="train")

# 2. Mistral Model and Tokenizer
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Model Configuration (AutoConfig downloads only config.json, not the full-precision weights)
from transformers import AutoConfig
config = AutoConfig.from_pretrained(model_id)
config.output_hidden_states = True
config.use_cache = False
config.torch_dtype = torch.bfloat16

# Load Model with Quantization
mistral_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",  # Optimization
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    config=config
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Device Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3. PyTorch Datasets

from torch_geometric.data import Data, Batch

class TextToSQLDataset(Dataset):
    def __init__(self, dataset, tokenizer):
        self.dataset = dataset
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]


        text = item['question']
        target_text = item['answer']


        # 1. Tokenization (with padding and truncation)
        tokenized_input = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=1024,  # Increase as needed
            return_tensors="pt"
        )
        tokenized_target = self.tokenizer(
            target_text,
            truncation=True,
            padding="max_length",
            max_length=1024,  # Increase as needed
            return_tensors="pt"
        )

        # Drop the leading batch dimension added by return_tensors="pt"
        tokenized_input = {k: v.squeeze(0) for k, v in tokenized_input.items()}
        tokenized_target = {k: v.squeeze(0) for k, v in tokenized_target.items()}


        # Print input text and target text for debugging
        #print('\n')
        #print(f"Sample ID: {idx}")
        #print(f"  - Text: {text}")
        #print(f"  - Target Text: {target_text}")
        #print('\n')

        sample_ids = torch.tensor([idx])  # Create a tensor with the sample ID

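        # 2. Dependency Parsing for Edge Extraction
        # NOTE: token.i is a spaCy word index; below it is used as a position in the
        # left-padded subword input sequence, so the two indexings do not coincide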
        doc = nlp(text)
        edges = []
        for token in doc:
            head_i = token.head.i
            if 0 <= head_i < len(doc) and token.dep_ != "ROOT" and token.i != head_i:
                edges.append([token.i, head_i])

        # Edge Index Extraction and Validation
        edges = item.get("edges", edges)  # If "edges" is already present in data, use that

        if not edges:  # Handle empty graphs
            num_nodes = len(tokenized_input["input_ids"])
            edges = [[i, i] for i in range(num_nodes)]  # Self-loops for isolated nodes
        else:
            max_index = len(tokenized_input["input_ids"]) - 1
            edges = [(src, tgt) for src, tgt in edges if 0 <= src <= max_index and 0 <= tgt <= max_index]

        # Create edge index tensor
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

        # Copy the tensors (the tokenizer has already padded/truncated them to max_length)
        input_ids = tokenized_input["input_ids"].clone().detach()
        attention_mask = tokenized_input["attention_mask"].clone().detach()
        labels = tokenized_target["input_ids"].clone().detach()

        # Handle potentially empty target sequences
        if len(labels) == 0:
            labels = torch.tensor([self.tokenizer.pad_token_id], dtype=torch.long)  # Create a single-element tensor with pad token

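        # Safety net for consistent shapes: the tokenizer call above already pads/truncates
        # to 1024 tokens, so the truncation and padding below normally change nothing
        max_length = 1024  # keep equal to the tokenizer max_length above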

        # Ensure that ALL tensors are truncated/padded to the SAME max_length
        input_ids = input_ids[:max_length]
        attention_mask = attention_mask[:max_length]
        labels = labels[:max_length]

        # Add padding if necessary
        if len(input_ids) < max_length:
            pad_length = max_length - len(input_ids)
            pad_tensor = torch.full((pad_length,), self.tokenizer.pad_token_id)
            input_ids = torch.cat((input_ids, pad_tensor))
            attention_mask = torch.cat((attention_mask, torch.zeros(pad_length, dtype=torch.long)))

        if len(labels) < max_length:
            pad_length = max_length - len(labels)
            labels = torch.cat((labels, torch.full((pad_length,), -100)))  # Pad labels with -100

        # (Optional) Print statements for debugging
        #print("\n")
        #print("Original text:", text)
        #print("Target text:", target_text)
        #print("Tokenized input IDs:", tokenized_input["input_ids"])
        #print("Tokenized target IDs:", tokenized_target["input_ids"])
        #print("Labels:", labels)
        #print("Edge index:", edge_index)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "edges": edge_index,
            "sample_ids": torch.tensor([idx]),  # Add sample_id here
        }

#train_dataset = load_dataset("json", data_files="/content/gdrive/MyDrive/datasets/train_dataset.json", split="train")
#val_dataset   = load_dataset("json", data_files="/content/gdrive/MyDrive/datasets/test_dataset.json", split="train")



import numpy as np
#train_dataset = train_dataset.select(range(100))

#Reduce train_dataset size for POC
POC_sample=500
train_dataset = train_dataset.select(np.random.choice(len(train_dataset), POC_sample, replace=False))


#Minimum: Start with at least 2000-3000 samples.
#This should be enough to provide a good initial assessment of your model's performance and potential.

#Medium: If your computational resources allow, try 5000-7000 samples.
#This could provide a more robust evaluation and potentially lead to better performance.

#Maximum: If you have ample resources, consider using the entire training split
#(roughly 55,000 samples with the 70% split above). This would give you the most
#comprehensive training data possible and potentially lead to the best model performance.



POC_valsample=1
val_dataset = val_dataset.select(np.random.choice(len(val_dataset), POC_valsample, replace=False))
test_dataset = test_dataset.select(np.random.choice(len(test_dataset), POC_valsample, replace=False))


train_dataset = TextToSQLDataset(train_dataset, tokenizer)
val_dataset = TextToSQLDataset(val_dataset, tokenizer)
test_dataset = TextToSQLDataset(test_dataset, tokenizer)

# The Trainer builds its own DataLoader from these datasets, using the custom data collator defined below
#train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 4. GAT Layer and GraphModel
class GATLayer(torch.nn.Module):
    def __init__(self, in_features, out_features, num_heads=8, num_layers=1):
        super(GATLayer, self).__init__()
        self.gat = GAT(in_channels=in_features, hidden_channels=out_features, heads=num_heads,
                        concat=False, num_layers=num_layers)

    def forward(self, x, edge_index):
        return self.gat(x, edge_index)

    # Make the internal linear layers accessible
    def get_lora_target_modules(self):
        # Access the linear layers within the GAT convolutions
        return [module for module in self.gat.modules() if isinstance(module, torch.nn.Linear)]


#  GraphModel
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, global_max_pool

# transformers.modeling_outputs.CausalLMOutputWithPast is the library equivalent;
# a lightweight local dataclass is used here instead
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class MyCausalLMOutputWithPast:
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class GraphModel(nn.Module):
    def __init__(self, encoder, tokenizer):
        super(GraphModel, self).__init__()
        self.encoder = encoder
        self.config = encoder.model.config
        self.gatv2 = GATv2Conv(
            in_channels=self.config.hidden_size,
            out_channels=self.config.hidden_size,
            heads=8,
            concat=False,
        )

        # Max pooling per graph: `batch` maps each node to its sample, giving one vector per sample
        self.pool = global_max_pool

        # Additional Feedforward Layer
        self.ffn = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.hidden_size * 2),
            nn.ReLU(),
            nn.Linear(self.config.hidden_size * 2, self.config.hidden_size),
        )

        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size)
        self.tokenizer = tokenizer

        # Add generation config (you might need to adjust this based on your specific needs)
        self.generation_config = encoder.generation_config

    # Forward Pass
    def forward(self, input_ids, attention_mask, edges, labels=None, inputs_embeds=None, sample_ids=None, output_attentions=False, output_hidden_states=False, return_dict=False):



        # Print vocabulary sizes
        #print('\n')
        #print("Mistral Model Vocab Size:", self.encoder.config.vocab_size)
        #print("Tokenizer Vocab Size:", self.tokenizer.vocab_size)
        #print('\n')



        # 1. Token Embeddings (Encoder)
        if input_ids is not None:
            encoder_outputs = self.encoder(
                input_ids.to(self.encoder.model.device),
                attention_mask=attention_mask.to(self.encoder.model.device),
                output_hidden_states=True
            )
            embeddings = encoder_outputs.hidden_states[-1]
        elif inputs_embeds is not None:
            embeddings = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Ensure correct shape for GATv2Conv
        if embeddings.dim() > 2:
            embeddings = embeddings.view(-1, embeddings.shape[-1])

        # 2. Edge Index Creation (batched: node ids of sample i are offset by i * seq_len)
        edge_index = []
        node_offset = 0
        for i, graph_edges in enumerate(edges):
            if graph_edges is None or graph_edges.numel() == 0:
                # No edges: fall back to self-loops for every node of this sample
                num_nodes = input_ids.size(1)
                graph_edges = torch.arange(node_offset, node_offset + num_nodes, device=embeddings.device)
                graph_edges = graph_edges.repeat(2, 1)  # shape [2, num_nodes]
            else:
                if not isinstance(graph_edges, torch.Tensor):
                    graph_edges = torch.tensor(graph_edges, dtype=torch.long)
                # Out-of-place shift, so the collated batch tensors are not modified
                graph_edges = graph_edges.to(embeddings.device) + node_offset
            edge_index.append(graph_edges)
            node_offset += input_ids.size(1)
        edge_index = torch.cat(edge_index, dim=1)

        # 3. GATv2 Layer
        graph_out = self.gatv2(embeddings, edge_index)

        # 4. Pooling (using max pooling)
        batch = torch.arange(len(edges), device=graph_out.device).repeat_interleave(input_ids.size(1))
        pooled = self.pool(graph_out, batch).unsqueeze(1)

        # Additional Feedforward Layer
        pooled = self.ffn(pooled)

        # 5. LM Head
        logits = self.lm_head(pooled)

        # 6. Loss Calculation (if labels provided)
        loss = None
        if labels is not None:
            # Debug output; note that this reads the global `trainer` defined further below
            print(f"\nIteration/Step: {trainer.state.global_step}")

            # Decode and print the input and target text for each sample
            with torch.no_grad():
                input_text = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
                # -100 is not a valid token id, so map it to the pad id before decoding
                label_ids = labels.masked_fill(labels == -100, self.tokenizer.pad_token_id)
                target_text = self.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
                for i, sample_id in enumerate(sample_ids):
                    print(f"Sample ID: {sample_id.item()}")
                    print("Decoded Input:", input_text[i])
                    print("Decoded Target (Labels):", target_text[i])

            # CrossEntropyLoss applies log_softmax itself, so it takes the raw logits;
            # ignore_index matches the -100 label padding used by the dataset and collator
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

            # The pooled model emits one prediction per sample, shape (batch, 1, vocab),
            # so only the first label position is scored
            loss = loss_fct(logits.squeeze(1), labels[:, 0])
            print(f"Loss: {loss.item()}")

            # 7. Return a tuple so the Trainer can read the loss as outputs[0]
            return (loss, logits, None)

        return MyCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


    def prepare_inputs_for_generation(self, input_ids, edges, attention_mask=None, **kwargs):
        if isinstance(self, PeftModel):
            return self.base_model.prepare_inputs_for_generation(input_ids, edges, attention_mask, **kwargs)

        batch_size = input_ids.size(0)
        if batch_size > 1:
            batched_edges = []
            node_offset = 0
            for i in range(batch_size):
                graph_edges = edges[i]
                batched_edges.extend([(src + node_offset, dst + node_offset) for src, dst in graph_edges])
                node_offset += input_ids.size(1)
            edge_index = torch.tensor(batched_edges, dtype=torch.long).t().contiguous().to(input_ids.device)
        else:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous().to(input_ids.device)

        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "edges": edge_index,
            "past_key_values": kwargs.get("past_key_values", None),
        }
        return model_inputs

# END GraphModel

from peft import prepare_model_for_kbit_training

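# Prepare the already 4-bit-quantized Mistral model for k-bit training before wrapping it
# (freezes its weights, casts non-quantized layers to fp32, enables gradient checkpointing)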
mistral_model = prepare_model_for_kbit_training(mistral_model, use_gradient_checkpointing=True)


# 5. Model Setup (Define model first)
model = GraphModel(mistral_model, tokenizer)  # Pass both mistral_model and tokenizer
model.to(device)


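# 6. PEFT Configuration (explicit target module names)
# Only names that exist in GraphModel get adapters: "lm_head" matches both GraphModel.lm_head
# and the wrapped Mistral head (encoder.lm_head), while the "gat.gat.convs.*" names match
# no module (GraphModel's graph layer is called `gatv2`)
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "gat.gat.convs.0.lin_l",
        "gat.gat.convs.0.lin_r",
        "gat.gat.convs.1.lin_l",
        "gat.gat.convs.1.lin_r",
        "lm_head",
    ],
)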


# 7. Apply PEFT
model = get_peft_model(model, peft_config)
print('\n\n')
print('PEFT-Model')
model.print_trainable_parameters() # To see the trainable parameters
print('\n')

# Access the config of the encoder (Mistral model) within your GraphModel
model.encoder.config.use_cache = False
model.encoder.gradient_checkpointing_enable()  # Enable gradient checkpointing for memory optimization on the Mistral model
model.encoder.model.embed_tokens.requires_grad_(True)  # also unfreezes Mistral's input embeddings (not in the trainable count printed above)
torch.cuda.empty_cache()

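# 8. Evaluation Metrics: exact match (used in compute_metrics) and a SentenceTransformer
# for semantic similarity (loaded here, not yet used in compute_metrics)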
metric = evaluate.load("exact_match")
sentence_transformer_model = SentenceTransformer('all-mpnet-base-v2')


def compute_metrics(eval_pred):
    all_preds, all_labels = eval_pred

    # Drop None entries first (torch.tensor(None) would fail), then convert the rest to tensors
    predictions = [torch.as_tensor(pred) for pred in all_preds if pred is not None]
    labels = [torch.as_tensor(label) for label in all_labels if label is not None]

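    # Convert to tensors and stack (only if there are predictions/labels)
    if predictions:
        predictions = torch.stack(predictions).squeeze()
        # Predictions are logits over the vocabulary; decode the most likely token ids
        if predictions.is_floating_point():
            predictions = predictions.argmax(dim=-1)
    else:
        predictions = torch.tensor([])  # Empty tensor if no predictions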

    if labels:
        labels = torch.stack(labels).squeeze()
    else:
        labels = torch.tensor([])  # Empty tensor if no labels

    # Handle cases where only one prediction/label is present (avoid squeezing to a scalar)
    if predictions.dim() == 0:
        predictions = predictions.unsqueeze(0)
    if labels.dim() == 0:
        labels = labels.unsqueeze(0)

    # Print shapes for debugging
    #print('\n')
    #print(f"Shape of logits in compute_metrics: {predictions.shape}")
    #print(f"Shape of labels in compute_metrics: {labels.shape}")
    #print('\n')


    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    em = metric.compute(predictions=decoded_preds, references=decoded_labels)["exact_match"]

    return {"exact_match": em}


#/content/gdrive/MyDrive/model

# 9. Training Arguments and Trainer
training_args = TrainingArguments(
    "/content/gdrive/MyDrive/model/GNN-T2SQL",
    logging_dir="/content/gdrive/MyDrive/model/GNN-T2SQL",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    push_to_hub=False,
    dataloader_pin_memory=False,
    load_best_model_at_end=True,
    #gradient_checkpointing=False,
    use_legacy_prediction_loop=False,
    metric_for_best_model="eval_exact_match",
    report_to="tensorboard",
    #generation_max_length=2048, # Add this line to increase the generation max length

)


optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)


import torch
import torch_geometric.data
from torch.nn.utils.rnn import pad_sequence
from torch_geometric.data import Batch as GraphBatch  # PyG batch of graphs

class GraphDataCollatorForSeq2Seq:
    def __init__(self, tokenizer, model=None, label_pad_token_id=-100, pad_to_multiple_of=None):
        self.tokenizer = tokenizer
        self.model = model
        self.label_pad_token_id = label_pad_token_id
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, features):
        # Separate standard features from graph edges
        # Extract labels before padding and handle potentially empty sequences
        labels = [feature["labels"] if feature["labels"].numel() > 0
                  else torch.tensor([self.label_pad_token_id], dtype=torch.long)
                  for feature in features]

        # Extract sample_ids
        sample_ids = [feature["sample_ids"] for feature in features]

        standard_features = [{k: v for k, v in feature.items() if k != "edges" and k != "labels"} for feature in features]
        edges = [feature["edges"] for feature in features]

        # Collate standard features (input_ids, attention_mask) using default collator
        collated_standard_features =  DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer,
            model=self.model,
            label_pad_token_id=self.label_pad_token_id,
            pad_to_multiple_of=self.pad_to_multiple_of
        )(standard_features)

        # Pad input_ids and attention_mask
        input_ids = pad_sequence([f['input_ids'] for f in standard_features], batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence([f['attention_mask'] for f in standard_features], batch_first=True, padding_value=0)

         # Pad labels separately
        labels = pad_sequence(labels, batch_first=True, padding_value=self.label_pad_token_id)

        # Create batch for graph data
        graph_data_list = []
        for i in range(len(edges)):
            # Convert to PyTorch Geometric Data
            graph_data_list.append(torch_geometric.data.Data(
                x=collated_standard_features['input_ids'][i].unsqueeze(1),  # Node features (input_ids)
                edge_index=edges[i],                   # Edge index
                # Use num_edges for batch index to ensure correct batching in PyG
                batch=torch.tensor([i] * edges[i].size(1))
            ))
        batched_graph = GraphBatch.from_data_list(graph_data_list)  # Batch graphs

        #print(f"Sample GraphDataCollatorForSeq2Seq ID: {sample_ids}")

         # Include sample_ids in the collated features
        collated_features = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "edges": edges,
            "sample_ids": sample_ids  # Add sample_ids here
        }

        return collated_features
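
# Sketch of what PyG's Batch.from_data_list does with edge indices (stdlib-only
# illustration; `sketch_batch_graphs` is a hypothetical helper, and the real
# implementation works on tensors and handles many more attributes). Each graph's
# node ids are shifted by the number of nodes in the graphs before it, and the
# `batch` vector stores one graph id per node, not per edge.
def sketch_batch_graphs(graphs):
    """graphs: list of (num_nodes, [(src, dst), ...]) with 0-based local node ids."""
    all_edges, batch, offset = [], [], 0
    for graph_id, (num_nodes, edges) in enumerate(graphs):
        all_edges.extend((src + offset, dst + offset) for src, dst in edges)
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return all_edges, batch

# Usage: two graphs with 2 and 3 nodes
# sketch_batch_graphs([(2, [(0, 1)]), (3, [(0, 2), (1, 2)])])
# -> ([(0, 1), (2, 4), (3, 4)], [0, 0, 1, 1, 1])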


# 10a.  Data Collator
data_collator = GraphDataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8
)
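
# Sketch of the label handling in GraphDataCollatorForSeq2Seq (stdlib-only;
# `sketch_pad_labels` is a hypothetical helper): empty label sequences become a
# single label_pad_token_id, and every sequence is right-padded with -100 to the
# longest one. Padded positions do not contribute to the loss because PyTorch's
# CrossEntropyLoss uses ignore_index=-100 by default.
def sketch_pad_labels(label_seqs, label_pad_token_id=-100):
    """Right-pad lists of label ids to a common length."""
    if not label_seqs:
        return []
    seqs = [list(s) if len(s) > 0 else [label_pad_token_id] for s in label_seqs]
    max_len = max(len(s) for s in seqs)
    return [s + [label_pad_token_id] * (max_len - len(s)) for s in seqs]

# Usage: sketch_pad_labels([[5, 6, 7], [], [8]])
# -> [[5, 6, 7], [-100, -100, -100], [8, -100, -100]]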


# 10AA. Training Arguments and Trainer

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from transformers import Trainer, TrainerCallback
from transformers.trainer_pt_utils import nested_concat, nested_detach, nested_numpify, nested_truncate
from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, PredictionOutput, has_length

def estimated_num_samples(dataloader: DataLoader):
    """
    This function will attempt to determine the number of samples in a DataLoader.

    Args:
        dataloader (DataLoader): The DataLoader to estimate the number of samples from.

    Returns:
        int: The estimated number of samples, or 0 if estimation is not possible.
    """
    if hasattr(dataloader, "dataset") and hasattr(dataloader.dataset, "__len__"):
        return len(dataloader.dataset)  # Use dataset length if available
    elif hasattr(dataloader, "batch_sampler") and hasattr(dataloader.batch_sampler, "sampler") and hasattr(dataloader.batch_sampler.sampler, "__len__"):
        return len(dataloader.batch_sampler.sampler)  # Use sampler length if available
    else:
        # If neither is available, return 0
        warnings.warn("Could not estimate the number of samples in the dataloader. Returning 0.")
        return 0

class CustomTrainer(Trainer):

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        inputs = self._prepare_inputs(inputs)
        # Read the labels without removing them, so the forward pass can still return a loss
        labels = inputs.get("labels")

        with torch.no_grad():
            outputs = model(**inputs)

            loss = outputs.loss
            logits = outputs.logits
            if isinstance(loss, torch.Tensor):
                loss = loss.mean().detach()

            if prediction_loss_only:
                return (loss, None, None)

            # Debugging logs (labels can be absent, e.g. at prediction time)
            print("Shape of logits in prediction_step:", logits.shape)
            if labels is not None:
                print("Shape of labels in prediction_step:", labels.shape)

            prompt_len = inputs["input_ids"].shape[1]
            max_new_tokens = 1024 - prompt_len
            if max_new_tokens <= 0:
                # Input is already at the 1024-token budget; nothing to generate
                generated_ids = inputs["input_ids"]
            else:
                generated_ids = model.encoder.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=max_new_tokens,
                    num_beams=5,
                )

            # For a decoder-only model, generate() returns prompt + new tokens;
            # keep only the continuation so it can be compared with the SQL labels
            predictions = generated_ids[:, prompt_len:]

            # Both are already tensors: move them with .to() instead of re-wrapping in torch.tensor()
            predictions = predictions.to(logits.device)
            if labels is not None:
                labels = labels.to(logits.device)

            return (loss, predictions, labels)

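    def _prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        # Manual evaluation loop built on prediction_step. evaluate() and predict()
        # in this class call the inherited `prediction_loop`, not this method.
        if prediction_loss_only is None:
            prediction_loss_only = self.args.prediction_loss_only

        self.model.eval()
        all_losses: List[torch.Tensor] = []
        all_preds, all_labels = None, None

        for inputs in dataloader:
            loss, preds, labels = self.prediction_step(
                self.model, inputs, prediction_loss_only, ignore_keys=ignore_keys
            )
            if loss is not None:
                all_losses.append(loss.detach().float().mean())
            # Sequence lengths differ between batches, so pad with -100 while concatenating
            if preds is not None:
                preds = nested_detach(preds)
                all_preds = preds if all_preds is None else nested_concat(all_preds, preds, padding_index=-100)
            if labels is not None:
                labels = nested_detach(labels)
                all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # The dataset may not support __len__; in that case estimate the number of samples
        if has_length(dataloader):
            num_samples = len(dataloader.dataset)
        else:
            num_samples = estimated_num_samples(dataloader)

        # Distributed samplers may round up to a multiple of batch_size, so truncate
        if all_preds is not None:
            all_preds = nested_numpify(nested_truncate(all_preds, num_samples))
        if all_labels is not None:
            all_labels = nested_numpify(nested_truncate(all_labels, num_samples))

        # Compute metrics, prefix their keys (e.g. "eval_exact_match") and add the average loss
        metrics: Dict[str, float] = {}
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        metrics = {
            key if key.startswith(f"{metric_key_prefix}_") else f"{metric_key_prefix}_{key}": value
            for key, value in metrics.items()
        }
        if all_losses:
            metrics[f"{metric_key_prefix}_loss"] = torch.stack(all_losses).mean().item()

        self.log(metrics)
        return PredictionOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics)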


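    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "test",
    ) -> PredictionOutput:
        test_dataloader = self.get_test_dataloader(test_dataset)
        output = self.prediction_loop(
            test_dataloader,
            description="Prediction",
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )
        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        # Reduced version of Trainer.evaluate built on the inherited (legacy) prediction_loop
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        output = self.prediction_loop(
            eval_dataloader,
            description="Evaluation",
            # Without compute_metrics only the loss can be reported
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        # As in Trainer.evaluate: log the metrics and notify the callbacks,
        # so that callbacks such as EarlyStoppingCallback receive them
        self.log(output.metrics)
        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        return output.metrics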


######
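
# Sketch of two steps in the evaluation path above (stdlib-only; both helpers are
# hypothetical): (1) for a decoder-only model, generate() returns the prompt
# followed by the new tokens, so the first prompt_len ids are cut off before the
# predictions are compared with the SQL labels; (2) predictions from different
# batches have different lengths, so rows are right-padded with -100 before being
# stacked, similar in spirit to nested_concat(..., padding_index=-100).
def sketch_strip_prompt(generated_ids, prompt_len):
    """Drop the prompt tokens from each generated row."""
    return [list(row[prompt_len:]) for row in generated_ids]

def sketch_concat_batches(batches, padding_index=-100):
    """Stack rows from several batches, right-padding them to a common width."""
    rows = [list(row) for batch in batches for row in batch]
    if not rows:
        return []
    width = max(len(row) for row in rows)
    return [row + [padding_index] * (width - len(row)) for row in rows]

# Usage: sketch_strip_prompt([[1, 2, 3, 4]], 2) -> [[3, 4]]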

# 10A. Trainer (Modified)
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None)
)


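from transformers import EarlyStoppingCallback, TrainerCallback

class LossLoggingCallback(TrainerCallback):
    # TrainerState has no `loss` attribute; the training loss arrives in the `logs`
    # dict passed to on_log, which fires at each logging step (see logging_strategy)
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and "loss" in logs:
            print(f"Step {state.global_step} - Loss: {logs['loss']}")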




# 11. Train the model

# Add the Callback to the Trainer
#trainer.add_callback(LossLoggingCallback())

# Add the Early Stopping to the Trainer
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=2))
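
# Simplified sketch of the patience rule behind EarlyStoppingCallback (hypothetical
# helper, not the library code): after each evaluation the counter resets when the
# monitored metric beats the best value so far by more than `threshold`, and
# increments otherwise; training stops once the counter reaches `patience`.
# Higher is better here, as for eval_exact_match. With num_train_epochs=1 and
# eval_strategy="epoch", training runs only one evaluation, so a patience of 2
# cannot be reached in this run.
def sketch_early_stopping_index(history, patience=2, threshold=0.0):
    """Return the 1-based evaluation index at which training would stop, or None."""
    best, bad_evals = None, 0
    for i, value in enumerate(history, start=1):
        if best is None or value > best + threshold:
            best, bad_evals = value, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return i
    return None

# Usage: sketch_early_stopping_index([0.1, 0.2, 0.2, 0.15]) -> 4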

trainer.train()


# 12. Evaluate on the test set
test_results = trainer.evaluate(test_dataset)
print('\n\n')
#print(f'Test Semantic Similarity: {test_results["eval_semantic_similarity"]:.4f}')
print(f'Test exact match: {test_results["eval_exact_match"]:.4f}')
print('\n\n')


`low_cpu_mem_usage` was None, now set to True since model is quantized.






PEFT-Model
trainable params: 589,824 || all params: 7,718,522,880 || trainable%: 0.0076




The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.



Iteration/Step: 0
Sample ID: 42
Decoded Input: On what date was there a friendly game against Wales?
Decoded Target (Labels): SELECT date FROM table_name_61 WHERE type_of_game = "friendly" AND opponent = "wales"
Loss per sample: 10.475870132446289

Iteration/Step: 0
Sample ID: 32
Decoded Input: Who had the fastest lap in the race where Patrick Tambay was on the pole?
Decoded Target (Labels): SELECT fastest_lap FROM table_1140073_2 WHERE pole_position = "Patrick Tambay"
Loss per sample: 10.483648300170898





Iteration/Step: 1
Sample ID: 120
Decoded Input: What is the lowest position for a driver with 2 points?
Decoded Target (Labels): SELECT MAX(pos) FROM table_23385853_19 WHERE points = 2
Loss per sample: 10.600207328796387

Iteration/Step: 1
Sample ID: 29
Decoded Input: Name the most serial number for feb 1994
Decoded Target (Labels): SELECT MAX(Serial) AS number FROM table_29002641_1 WHERE scrapped = "Feb 1994"
Loss per sample: 10.622447967529297

Iteration/Step: 2
Sample ID: 478
Decoded Input: What was Collingwood's score when they played against North Melbourne at home?
Decoded Target (Labels): SELECT home_team AS score FROM table_name_12 WHERE away_team = "north melbourne"
Loss per sample: 10.514888763427734

Iteration/Step: 2
Sample ID: 70
Decoded Input: What was the date when the attendance was 77,918?
Decoded Target (Labels): SELECT date FROM table_name_90 WHERE attendance = "77,918"
Loss per sample: 10.51118278503418

Iteration/Step: 3
Sample ID: 496
Decoded Input: How many year