<a href="https://colab.research.google.com/github/felipemt267/sql-language-model/blob/main/sql_lstm_generator_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from collections import Counter

In [None]:
# Toy training corpus: 20 short SQL snippets. Some entries are complete
# statements (SELECT/INSERT/UPDATE/...), others are standalone clauses
# (JOIN ..., ORDER BY ..., WHERE ..., HAVING ...). Tokenization downstream is
# a plain whitespace split, so punctuation stays attached to the words.
corpus = [
    "SELECT * FROM customers WHERE active = 1;",
    "INSERT INTO products (name, price) VALUES ('laptop', 1200);",
    "UPDATE employees SET salary = 5000 WHERE department = 'HR';",
    "DELETE FROM orders WHERE order_date < '2023-01-01';",
    "CREATE TABLE users (id INT PRIMARY KEY, username VARCHAR(50));",
    "ALTER TABLE invoices ADD COLUMN total DECIMAL(10, 2);",
    "DROP TABLE IF EXISTS temp_data;",
    "SELECT name, COUNT(*) AS total FROM sales GROUP BY name;",
    "JOIN addresses ON customers.id = addresses.customer_id;",
    "ORDER BY created_at DESC LIMIT 10;",
    "WHERE status = 'pending' AND priority > 5;",
    "HAVING SUM(amount) > 1000;",
    "GRANT SELECT, INSERT ON database.* TO 'user'@'localhost';",
    "REVOKE ALL PRIVILEGES ON employees FROM 'intern'@'%';",
    "BEGIN TRANSACTION; COMMIT;",
    "ROLLBACK TO SAVEPOINT before_update;",
    "EXPLAIN SELECT * FROM logs WHERE type = 'error';",
    "INDEX idx_name ON employees (last_name);",
    "PRIMARY KEY (order_id, product_id);",
    "FOREIGN KEY (customer_id) REFERENCES customers(id);"
]

In [None]:
# Flatten the corpus into a single list of lowercase, whitespace-delimited tokens.
tokens = [word for sentence in corpus for word in sentence.lower().split()]

In [None]:
# Inspect the flattened token list (punctuation stays attached, e.g. '1;').
tokens

['select',
 '*',
 'from',
 'customers',
 'where',
 'active',
 '=',
 '1;',
 'insert',
 'into',
 'products',
 '(name,',
 'price)',
 'values',
 "('laptop',",
 '1200);',
 'update',
 'employees',
 'set',
 'salary',
 '=',
 '5000',
 'where',
 'department',
 '=',
 "'hr';",
 'delete',
 'from',
 'orders',
 'where',
 'order_date',
 '<',
 "'2023-01-01';",
 'create',
 'table',
 'users',
 '(id',
 'int',
 'primary',
 'key,',
 'username',
 'varchar(50));',
 'alter',
 'table',
 'invoices',
 'add',
 'column',
 'total',
 'decimal(10,',
 '2);',
 'drop',
 'table',
 'if',
 'exists',
 'temp_data;',
 'select',
 'name,',
 'count(*)',
 'as',
 'total',
 'from',
 'sales',
 'group',
 'by',
 'name;',
 'join',
 'addresses',
 'on',
 'customers.id',
 '=',
 'addresses.customer_id;',
 'order',
 'by',
 'created_at',
 'desc',
 'limit',
 '10;',
 'where',
 'status',
 '=',
 "'pending'",
 'and',
 'priority',
 '>',
 '5;',
 'having',
 'sum(amount)',
 '>',
 '1000;',
 'grant',
 'select,',
 'insert',
 'on',
 'database.*',
 'to',
 

In [None]:
# Frequency table: each distinct token mapped to its number of occurrences.
vocab = Counter()
vocab.update(tokens)

In [None]:
# Inspect the per-token frequency counts.
vocab

Counter({'select': 3,
         '*': 2,
         'from': 5,
         'customers': 1,
         'where': 5,
         'active': 1,
         '=': 6,
         '1;': 1,
         'insert': 2,
         'into': 1,
         'products': 1,
         '(name,': 1,
         'price)': 1,
         'values': 1,
         "('laptop',": 1,
         '1200);': 1,
         'update': 1,
         'employees': 3,
         'set': 1,
         'salary': 1,
         '5000': 1,
         'department': 1,
         "'hr';": 1,
         'delete': 1,
         'orders': 1,
         'order_date': 1,
         '<': 1,
         "'2023-01-01';": 1,
         'create': 1,
         'table': 3,
         'users': 1,
         '(id': 1,
         'int': 1,
         'primary': 2,
         'key,': 1,
         'username': 1,
         'varchar(50));': 1,
         'alter': 1,
         'invoices': 1,
         'add': 1,
         'column': 1,
         'total': 2,
         'decimal(10,': 1,
         '2);': 1,
         'drop': 1,
         'if': 1

In [None]:
# Vocabulary as a list of distinct tokens, most frequent first; tokens with
# equal counts keep their first-seen order. It is rebuilt from `tokens`
# instead of from the previous value of `vocab`, so re-running this cell after
# `vocab` has already become a list no longer fails on the missing `.get`.
vocab = [word for word, _count in Counter(tokens).most_common()]
vocab

['=',
 'from',
 'where',
 'on',
 'select',
 'employees',
 'table',
 '*',
 'insert',
 'primary',
 'total',
 'by',
 '>',
 'to',
 'key',
 'customers',
 'active',
 '1;',
 'into',
 'products',
 '(name,',
 'price)',
 'values',
 "('laptop',",
 '1200);',
 'update',
 'set',
 'salary',
 '5000',
 'department',
 "'hr';",
 'delete',
 'orders',
 'order_date',
 '<',
 "'2023-01-01';",
 'create',
 'users',
 '(id',
 'int',
 'key,',
 'username',
 'varchar(50));',
 'alter',
 'invoices',
 'add',
 'column',
 'decimal(10,',
 '2);',
 'drop',
 'if',
 'exists',
 'temp_data;',
 'name,',
 'count(*)',
 'as',
 'sales',
 'group',
 'name;',
 'join',
 'addresses',
 'customers.id',
 'addresses.customer_id;',
 'order',
 'created_at',
 'desc',
 'limit',
 '10;',
 'status',
 "'pending'",
 'and',
 'priority',
 '5;',
 'having',
 'sum(amount)',
 '1000;',
 'grant',
 'select,',
 'database.*',
 "'user'@'localhost';",
 'revoke',
 'all',
 'privileges',
 "'intern'@'%';",
 'begin',
 'transaction;',
 'commit;',
 'rollback',
 'savepoi

In [None]:
# Number of distinct tokens; passed to SQLLM to size its embedding table and
# its output projection.
vocab_size = len(vocab)
vocab_size

103

In [None]:
# Token -> integer id lookup; ids follow the position of each token in `vocab`.
word_to_idx = dict(zip(vocab, range(len(vocab))))
word_to_idx

{'=': 0,
 'from': 1,
 'where': 2,
 'on': 3,
 'select': 4,
 'employees': 5,
 'table': 6,
 '*': 7,
 'insert': 8,
 'primary': 9,
 'total': 10,
 'by': 11,
 '>': 12,
 'to': 13,
 'key': 14,
 'customers': 15,
 'active': 16,
 '1;': 17,
 'into': 18,
 'products': 19,
 '(name,': 20,
 'price)': 21,
 'values': 22,
 "('laptop',": 23,
 '1200);': 24,
 'update': 25,
 'set': 26,
 'salary': 27,
 '5000': 28,
 'department': 29,
 "'hr';": 30,
 'delete': 31,
 'orders': 32,
 'order_date': 33,
 '<': 34,
 "'2023-01-01';": 35,
 'create': 36,
 'users': 37,
 '(id': 38,
 'int': 39,
 'key,': 40,
 'username': 41,
 'varchar(50));': 42,
 'alter': 43,
 'invoices': 44,
 'add': 45,
 'column': 46,
 'decimal(10,': 47,
 '2);': 48,
 'drop': 49,
 'if': 50,
 'exists': 51,
 'temp_data;': 52,
 'name,': 53,
 'count(*)': 54,
 'as': 55,
 'sales': 56,
 'group': 57,
 'name;': 58,
 'join': 59,
 'addresses': 60,
 'customers.id': 61,
 'addresses.customer_id;': 62,
 'order': 63,
 'created_at': 64,
 'desc': 65,
 'limit': 66,
 '10;': 67,
 '

In [None]:
# Inverse lookup (integer id -> token), used to decode predicted ids.
index_to_word = dict(enumerate(vocab))
index_to_word

{0: '=',
 1: 'from',
 2: 'where',
 3: 'on',
 4: 'select',
 5: 'employees',
 6: 'table',
 7: '*',
 8: 'insert',
 9: 'primary',
 10: 'total',
 11: 'by',
 12: '>',
 13: 'to',
 14: 'key',
 15: 'customers',
 16: 'active',
 17: '1;',
 18: 'into',
 19: 'products',
 20: '(name,',
 21: 'price)',
 22: 'values',
 23: "('laptop',",
 24: '1200);',
 25: 'update',
 26: 'set',
 27: 'salary',
 28: '5000',
 29: 'department',
 30: "'hr';",
 31: 'delete',
 32: 'orders',
 33: 'order_date',
 34: '<',
 35: "'2023-01-01';",
 36: 'create',
 37: 'users',
 38: '(id',
 39: 'int',
 40: 'key,',
 41: 'username',
 42: 'varchar(50));',
 43: 'alter',
 44: 'invoices',
 45: 'add',
 46: 'column',
 47: 'decimal(10,',
 48: '2);',
 49: 'drop',
 50: 'if',
 51: 'exists',
 52: 'temp_data;',
 53: 'name,',
 54: 'count(*)',
 55: 'as',
 56: 'sales',
 57: 'group',
 58: 'name;',
 59: 'join',
 60: 'addresses',
 61: 'customers.id',
 62: 'addresses.customer_id;',
 63: 'order',
 64: 'created_at',
 65: 'desc',
 66: 'limit',
 67: '10;',
 6

In [None]:
class SQLLM(nn.Module):
    """Next-token language model: embedding -> LSTM -> linear projection.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output
            logits per position).
        embedding_dim: Size of each token embedding vector.
        hidden_dim: Size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim=64, hidden_dim=128):
        super().__init__()
        self.embeddings = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        # batch_first=True: inputs/outputs are laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """Return per-position vocabulary logits and the LSTM's final state.

        Args:
            x: Integer token ids, batch dimension first.
            hidden: Optional ``(h, c)`` state to resume from; ``None`` lets
                the LSTM start from its default initial state.

        Returns:
            ``(logits, hidden)`` where ``logits`` has one ``vocab_size``-wide
            row per input position and ``hidden`` is the LSTM's ``(h, c)``.
        """
        embedded = self.embeddings(x)
        lstm_out, next_hidden = self.lstm(embedded, hidden)
        logits = self.fc(lstm_out)
        return logits, next_hidden

In [None]:
class SQLDataset(Dataset):
    """Sliding-window next-token dataset built from a list of SQL strings.

    Each sentence is lowercased and split on whitespace. Every window of
    ``seq_length`` consecutive token ids becomes an input, paired with the
    same window shifted one token to the right as its target. Sentences with
    ``seq_length`` tokens or fewer yield no samples.

    Args:
        corpus: Iterable of strings to tokenize.
        seq_length: Number of tokens in each input/target window.
        vocab_index: Optional token -> id mapping. When omitted, the
            notebook-level ``word_to_idx`` is used, as before.

    Raises:
        KeyError: If a token is missing from the mapping.
    """

    def __init__(self, corpus, seq_length=3, vocab_index=None):
        self.seq_length = seq_length
        self.data = []

        # Only reach for the notebook global when no explicit mapping is
        # supplied, so the class can be used (and tested) on its own.
        mapping = word_to_idx if vocab_index is None else vocab_index

        for sentence in corpus:
            indices = [mapping[word] for word in sentence.lower().split()]
            for start in range(len(indices) - seq_length):
                stop = start + seq_length
                self.data.append((
                    torch.tensor(indices[start:stop]),
                    # Target is the input window advanced by one token.
                    torch.tensor(indices[start + 1:stop + 1]),
                ))

    def __len__(self):
        """Number of (input, target) window pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(input_ids, target_ids)`` tensor pair."""
        return self.data[idx]

# Build the sliding-window training pairs and serve them in shuffled
# mini-batches of two.
dataset = SQLDataset(corpus=corpus, seq_length=3)
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

In [None]:
# Seed PyTorch's random number generators before any weights are created so
# that initialisation, DataLoader shuffling and sampling in later cells start
# from the same state on every Restart & Run All. (Bit-exact repeatability
# still depends on the backend being deterministic.)
torch.manual_seed(42)

# Prefer the GPU when one is available; the model lives on `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SQLLM(vocab_size).to(device)

In [None]:
# Show which device was selected.
device


device(type='cuda')

In [None]:
# Adam over all model parameters, trained against a cross-entropy objective
# on the predicted token logits.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [None]:
epochs = 200

# complete_text() puts the model in eval mode; switch back to training mode
# so this cell behaves the same when re-run after generating text.
model.train()

for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        output, _ = model(inputs)
        # Fold batch and sequence positions together: the loss is computed on
        # one row of logits per token and one target id per token.
        loss = criterion(output.view(-1, vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean of the per-batch losses. Printing only the last batch's
    # loss (two samples) makes the log jump around from epoch to epoch, and
    # would fail outright if the loader produced no batches.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f'Epoch {epoch+1}, Loss: {mean_loss:.4f}')

Epoch 1, Loss: 4.3983
Epoch 2, Loss: 4.2268
Epoch 3, Loss: 3.9085
Epoch 4, Loss: 1.6098
Epoch 5, Loss: 1.4766
Epoch 6, Loss: 2.0641
Epoch 7, Loss: 0.7522
Epoch 8, Loss: 2.2325
Epoch 9, Loss: 0.9636
Epoch 10, Loss: 0.3604
Epoch 11, Loss: 1.1304
Epoch 12, Loss: 0.3431
Epoch 13, Loss: 0.6175
Epoch 14, Loss: 0.6001
Epoch 15, Loss: 0.6997
Epoch 16, Loss: 0.2785
Epoch 17, Loss: 0.6299
Epoch 18, Loss: 0.0931
Epoch 19, Loss: 0.0699
Epoch 20, Loss: 0.0583
Epoch 21, Loss: 0.0772
Epoch 22, Loss: 0.5907
Epoch 23, Loss: 0.3349
Epoch 24, Loss: 0.6878
Epoch 25, Loss: 0.2328
Epoch 26, Loss: 0.0312
Epoch 27, Loss: 0.0536
Epoch 28, Loss: 0.4221
Epoch 29, Loss: 0.0332
Epoch 30, Loss: 0.6924
Epoch 31, Loss: 0.3716
Epoch 32, Loss: 0.0232
Epoch 33, Loss: 0.0387
Epoch 34, Loss: 0.5617
Epoch 35, Loss: 0.0163
Epoch 36, Loss: 0.0147
Epoch 37, Loss: 0.0135
Epoch 38, Loss: 0.6538
Epoch 39, Loss: 0.0079
Epoch 40, Loss: 0.0125
Epoch 41, Loss: 0.0208
Epoch 42, Loss: 0.0145
Epoch 43, Loss: 0.0203
Epoch 44, Loss: 0.00

In [None]:
def complete_text(seed_text, num_words=5, temperatura=0.7, context_size=3):
    """Extend ``seed_text`` with tokens sampled from the trained model.

    The seed is lowercased and split on whitespace. At each step the last
    ``context_size`` words are fed to the model and the next token is drawn
    from the distribution predicted for the final position.

    Args:
        seed_text: Text to continue; its trailing ``context_size`` words must
            all be in the vocabulary.
        num_words: Number of tokens to append. With ``0`` or less the
            normalised seed is returned unchanged.
        temperatura: Sampling temperature that divides the logits before the
            softmax. ``0`` selects the most likely token deterministically
            instead of dividing by zero.
        context_size: How many trailing words are fed to the model per step.
            The default of 3 matches the window length used for training in
            this notebook.

    Returns:
        The lowercased seed words followed by the generated tokens, joined by
        single spaces.

    Raises:
        ValueError: If tokens are requested from an empty seed, or if
            ``context_size`` is smaller than 1.
        KeyError: If a seed word in the context window is not in the
            vocabulary.
    """
    words = seed_text.lower().split()
    if num_words <= 0:
        return ' '.join(words)
    if not words:
        raise ValueError("seed_text must contain at least one word")
    if context_size < 1:
        raise ValueError("context_size must be at least 1")

    # Only the trailing window of the seed is ever looked up; everything the
    # model appends afterwards is in the vocabulary by construction.
    unknown = [word for word in words[-context_size:] if word not in word_to_idx]
    if unknown:
        raise KeyError(f"seed words not in the vocabulary: {unknown}")

    model.eval()
    with torch.no_grad():
        for _ in range(num_words):
            context = [word_to_idx[word] for word in words[-context_size:]]
            inputs = torch.tensor(context).unsqueeze(0).to(device)

            output, _ = model(inputs)
            # Logits for the token that follows the last context word.
            logits = output[0, -1]
            if temperatura == 0:
                # Greedy decoding: dividing by zero would yield NaN probabilities.
                next_idx = torch.argmax(logits).item()
            else:
                probabilities = torch.softmax(logits / temperatura, dim=0)
                next_idx = torch.multinomial(probabilities, 1).item()

            words.append(index_to_word[next_idx])
    return ' '.join(words)


In [None]:
# Ask the model to continue a common query prefix with five more tokens.
print(complete_text(seed_text="SELECT * FROM", num_words=5))


select * from logs where type = 'error';
