
Testing the pretrained model directly in new inputs #17

Closed
mellahysf opened this issue Aug 12, 2020 · 14 comments

Comments

@mellahysf

Hi,

Can you give us the necessary files and pretrained model to run it directly (for testing/evaluation) ? How to use them directly to test the pretrained model in new inputs?

Thank you

@gouldju1

I'd really like to see this, too.

@alexpolozov
Contributor

alexpolozov commented Aug 15, 2020

I'll push an end-to-end sample sometime next week. Here's an approximate untested snippet (adapting the code from infer.py) if you want to run inference against one of the preprocessed Spider schemas:

import json
import os

import _jsonnet
import torch

from ratsql.commands.infer import Inferer
from ratsql.datasets.spider import SpiderItem
from ratsql.utils import registry

root_dir = '<path to rat-sql>'
exp_config_path = '<path to experiment config, e.g. experiments/spider-glove-run.jsonnet>'
model_dir = '<path to model logdir>'
checkpoint_step = '<checkpoint step to load>'

exp_config = json.loads(_jsonnet.evaluate_file(exp_config_path))
model_config_path = os.path.join(root_dir, exp_config["model_config"])
model_config_args = exp_config.get("model_config_args")
infer_config = json.loads(_jsonnet.evaluate_file(model_config_path, tla_codes={'args': json.dumps(model_config_args)}))

inferer = Inferer(infer_config)
inferer.device = torch.device("cpu")
model = inferer.load_model(model_dir, checkpoint_step)
dataset = registry.construct('dataset', inferer.config['data']['val'])

for _, schema in dataset.schemas.items():
    model.preproc.enc_preproc._preprocess_schema(schema)

def question(q, db_id):
    spider_schema = dataset.schemas[db_id]
    data_item = SpiderItem(
        text=None,  # intentionally None -- should be ignored when the tokenizer is set correctly
        code=None,
        schema=spider_schema,
        orig_schema=spider_schema.orig,
        orig={"question": q}
    )
    model.preproc.clear_items()
    enc_input = model.preproc.enc_preproc.preprocess_item(data_item, None)
    preproc_data = enc_input, None
    with torch.no_grad():
        return inferer._infer_one(model, data_item, preproc_data, beam_size=1, use_heuristic=True)
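A minimal sketch of consuming the result: assuming each beam serializes to the same `{"inferred_code", "score"}` shape seen in the `step-xxxxx.infer` dumps later in this thread (the helper name below is mine, not part of rat-sql):

```python
# Sketch: pick the best prediction from a beam list whose entries look like
# the {"inferred_code": ..., "score": ...} dicts in step-xxxxx.infer.
# That dict shape is an assumption about _infer_one's serialized output.

def best_inferred_code(beams):
    """Return the inferred SQL of the highest-scoring beam (scores are log-probs, so less negative is better)."""
    best = max(beams, key=lambda b: b["score"])
    return best["inferred_code"]

beams = [
    {"inferred_code": "SELECT singer.Name FROM singer", "score": -2.18},
    {"inferred_code": "SELECT * FROM singer", "score": -3.40},
]
print(best_inferred_code(beams))  # SELECT singer.Name FROM singer
```

With `beam_size=1` there is only one candidate anyway, so this only matters once you raise the beam size.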

@gouldju1

Awesome, I came up with something similar over the weekend. Thanks! So I used what I wrote on the climbing DB:
Input question: How tall is the mountain from Kenya
Output query: SELECT mountain.Height FROM mountain WHERE mountain.Country = 'terminal'

This happens when using other DBs, too. How can I get the WHERE clause to fill in the correct value instead of 'terminal'? Thank you!

@gouldju1

gouldju1 commented Aug 17, 2020

Here's an example running python3 run.py eval experiments/spider-glove-run.jsonnet:

 "inferred_code": "SELECT singer.Name FROM singer JOIN singer_in_concert ON singer.Singer_ID = singer_in_concert.Singer_ID WHERE singer.Age = 'terminal'"

Here's the full output in step-xxxxx.infer:

{
    "index": 38,
    "beams": [
        {
            "orig_question": "What are the names of the singers who performed in a concert in 2014?",
            "model_output": {
                "_type": "sql",
                "select": {
                    "_type": "select",
                    "is_distinct": false,
                    "aggs": [
                        {
                            "_type": "agg",
                            "agg_id": {
                                "_type": "NoneAggOp"
                            },
                            "val_unit": {
                                "_type": "Column",
                                "col_unit1": {
                                    "_type": "col_unit",
                                    "agg_id": {
                                        "_type": "NoneAggOp"
                                    },
                                    "col_id": 9,
                                    "is_distinct": false
                                }
                            }
                        }
                    ]
                },
                "sql_where": {
                    "_type": "sql_where",
                    "where": {
                        "_type": "Eq",
                        "val_unit": {
                            "_type": "Column",
                            "col_unit1": {
                                "_type": "col_unit",
                                "agg_id": {
                                    "_type": "NoneAggOp"
                                },
                                "col_id": 13,
                                "is_distinct": false
                            }
                        },
                        "val1": {
                            "_type": "Terminal"
                        }
                    }
                },
                "sql_groupby": {
                    "_type": "sql_groupby"
                },
                "sql_orderby": {
                    "_type": "sql_orderby",
                    "limit": false
                },
                "sql_ieu": {
                    "_type": "sql_ieu"
                },
                "from": {
                    "_type": "from",
                    "table_units": [
                        {
                            "_type": "Table",
                            "table_id": 1
                        },
                        {
                            "_type": "Table",
                            "table_id": 3
                        }
                    ],
                    "conds": {
                        "_type": "Eq",
                        "val_unit": {
                            "_type": "Column",
                            "col_unit1": {
                                "_type": "col_unit",
                                "agg_id": {
                                    "_type": "NoneAggOp"
                                },
                                "col_id": 8,
                                "is_distinct": false
                            }
                        },
                        "val1": {
                            "_type": "ColUnit",
                            "c": {
                                "_type": "col_unit",
                                "agg_id": {
                                    "_type": "NoneAggOp"
                                },
                                "col_id": 21,
                                "is_distinct": false
                            }
                        }
                    }
                }
            },
            "inferred_code": "SELECT singer.Name FROM singer JOIN singer_in_concert ON singer.Singer_ID = singer_in_concert.Singer_ID WHERE singer.Age = 'terminal'",
            "score": -2.1817409903378575
        }
    ]
}
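If you want to post-process these dumps, a small stdlib-only sketch (assuming each record in `step-xxxxx.infer` is stored as one JSON object per line; the pretty-printed version above is expanded for readability):

```python
import json

def load_infer_results(text):
    """Map example index -> (top beam's inferred_code, score) for an
    inference dump with one JSON object per line."""
    results = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        obj = json.loads(line)
        top = obj["beams"][0]  # beams are already sorted best-first in the dump
        results[obj["index"]] = (top["inferred_code"], top["score"])
    return results

sample = json.dumps({
    "index": 38,
    "beams": [{"inferred_code": "SELECT singer.Name FROM singer", "score": -2.1817}],
})
print(load_infer_results(sample)[38][0])  # SELECT singer.Name FROM singer
```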

@DevanshChoubey

@gouldju1

Hi, I don't think any of the major models on the Spider leaderboard supports WHERE-value prediction. After I get this running, maybe I will publish my end-to-end model.

@mellahysf
Author

How tall is the mountain from Kenya

@gouldju1 can you share with us your code that takes a question and returns the SQL query?

Thank you.

@Akshaysharma29

Akshaysharma29 commented Sep 16, 2020

Hi @gouldju1 @DevanshChoubey @mellahysf, are you able to directly generate SQL queries from text?

@DevanshChoubey

Yes @Akshaysharma29

@mellahysf
Author

Thank you @alexpolozov.
I used your code and it works for me.

@mellahysf
Author

mellahysf commented Oct 6, 2020

Hi,

How can I run inference (adapting this code: #17 (comment)) against a new database schema? (The "question" method should take, in this case, a .sql file as the second argument.)

Thanks

@mellahysf mellahysf reopened this Oct 6, 2020
@kalleknast

@mellahysf

I adapted the code above to "work" with a new (non-Spider) schema. It runs without errors, but the predicted SQL is, at this point, not very impressive.

I had to separate loading the model from loading the dataset, since inferer.load() infers model dims from the dataset, which results in errors when using a new single-schema dataset.

import json
import os
import _jsonnet
import torch
from ratsql.commands.infer import Inferer
from ratsql.datasets.spider import SpiderItem
from ratsql.utils import registry

db_id = '<id of new database>'
root_dir = '<path to rat-sql>'
exp_config_path = '<path to spider-bert-run.jsonnet>'
model_dir = '<path to model>'
checkpoint_step = 16100  # whatever checkpoint you want to use
data_conf = {'db_path': '<path to new database>',
             'name': '<name of new dataset>',
             'paths': ['<path to dev.json for new dataset>'],  # dev.json has to be generated using spider's preprocessing methods
             'tables_paths': ['<path to tables.json for new dataset>']}  # tables.json has to be generated using spider's preprocessing methods

exp_config = json.loads(_jsonnet.evaluate_file(exp_config_path))  # data_path: '<path to spider/>',
model_config_path = os.path.join(root_dir, exp_config["model_config"])
model_config_args = exp_config.get("model_config_args")
infer_config = json.loads(_jsonnet.evaluate_file(model_config_path, tla_codes={'args': json.dumps(model_config_args)}))

inferer = Inferer(infer_config)
inferer.device = torch.device("cpu")

model = inferer.load_model(model_dir, checkpoint_step)  # load the model according to the spider dataset
dataset = registry.construct('dataset', data_conf)  # load the new dataset (not spider)

for _, schema in dataset.schemas.items():
    model.preproc.enc_preproc._preprocess_schema(schema)

def question(q, db_id):
    schema = dataset.schemas[db_id]
    data_item = SpiderItem(
        text=None,  # intentionally None -- should be ignored when the tokenizer is set correctly
        code=None,
        schema=schema,
        orig_schema=schema.orig,
        orig={"question": q}
    )
    model.preproc.clear_items()
    enc_input = model.preproc.enc_preproc.preprocess_item(data_item, None)
    preproc_data = enc_input, None
    with torch.no_grad():
        return inferer._infer_one(model, data_item, preproc_data, beam_size=1, use_heuristic=True)

@mellahysf
Author

Thank you @kalleknast, it works for me.

@arrtvv852

Awesome, I came up with something similar over the weekend. Thanks! So I used what I wrote on the climbing DB:
Input question: How tall is the mountain from Kenya
Output query: SELECT mountain.Height FROM mountain WHERE mountain.Country = 'terminal'

This happens when using other DBs, too. How can I get the where clause to insert the correct where value? Thank you!

@gouldju1 I found there is a parameter called include_literals in ./configs/spider/nl2code-base.libsonnet. By default it is set to false, which causes the WHERE-clause value val1 to become Terminal instead of the real value from the input question. It should be set to true if you want WHERE-clause values included in your training data.
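For reference, the override would look roughly like this in jsonnet. The exact nesting of include_literals inside nl2code-base.libsonnet is an assumption here, so search the file rather than trusting this sketch:

```jsonnet
// Sketch of a config override; the real nesting in
// configs/spider/nl2code-base.libsonnet may differ.
{
    model+: {
        decoder_preproc+: {
            grammar+: {
                include_literals: true,  // keep real WHERE literals instead of 'terminal'
            },
        },
    },
}
```

Note that this flag affects preprocessing and training data, so flipping it on a checkpoint that was trained with include_literals: false will presumably not make that model emit literals; retraining would be needed.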

@FruVirus

FruVirus commented Jan 27, 2023

Awesome, I came up with something similar over the weekend. Thanks! So I used what I wrote on the climbing DB:
Input question: How tall is the mountain from Kenya
Output query: SELECT mountain.Height FROM mountain WHERE mountain.Country = 'terminal'
This happens when using other DBs, too. How can I get the where clause to insert the correct where value? Thank you!

@gouldju1 I found there is a parameter called include_literals in ./configs/spider/nl2code-base.libsonnet. By default it is set to false, which causes the WHERE-clause value val1 to become Terminal instead of the real value from the input question. It should be set to true if you want WHERE-clause values included in your training data.

This doesn't seem to work.

It seems that the original rat-sql does not fill in the terminal values at all. The closest one can get is to use NatSQL to fill in the values. However, the authors of NatSQL only generated the files for the Spider dataset, and it's unclear how one would go about generating corresponding NatSQL files for their own dataset.

Unfortunately, this makes RAT-SQL not very useful in practice, which is a shame since the model otherwise performs quite decently.
