# SNIPS + Embedding-Augmented FLAN-T5
Fine-tuning FLAN-T5 for real-world intent and slot detection using the SNIPS dataset with SBERT embeddings.

In [None]:
!pip install -q transformers datasets accelerate sentence-transformers seqeval

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import pandas as pd, torch, json

In [None]:
# 1. Load SNIPS dataset
# NOTE(review): later cells read the `intent`, `slots` and `text` fields and the
# `train` / `validation` splits -- confirm that this dataset id really exposes
# all of them before running the rest of the notebook.
ds = load_dataset('snips_built_in_intents')

In [None]:
# 2. Format to T5-style examples
def prepare(example):
    """Attach the JSON string the model should generate for one example.

    The target is a flat JSON object whose first key is ``intent``, followed
    by one key per slot.

    Args:
        example: Mapping with an ``intent`` entry and a ``slots`` entry. The
            latter is a mapping of slot name to value, or ``None`` / empty
            when the utterance has no slot annotations.

    Returns:
        The same ``example`` object, with a new ``target`` string added.

    Raises:
        KeyError: If ``intent`` or ``slots`` is missing from ``example``.
    """
    target = {'intent': example['intent']}
    # Unannotated utterances may carry ``None`` here; unpacking that with
    # ``**`` would raise TypeError, so treat it as "no slots".
    slots = example['slots'] or {}
    for name, value in slots.items():
        # A slot that happens to be called "intent" must not overwrite the
        # real intent label, hence setdefault instead of plain assignment.
        target.setdefault(name, value)
    example['target'] = json.dumps(target)
    return example

# Run `prepare` over the dataset so every example gains a `target` field.
ds = ds.map(prepare)

In [None]:
# 3. Add embeddings
# Sentence encoder shared by `add_embed` (dataset preparation) and `predict`
# (inference), so both build their input prefix from the same model.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
def add_embed(ex):
    """Store the sentence embedding of ``ex['text']`` under ``ex['emb']``.

    The encoder output is converted with ``.tolist()`` before it is stored
    on the example.

    Args:
        ex: Example mapping with a ``text`` entry.

    Returns:
        The same ``ex`` object with the ``emb`` entry added.
    """
    vector = embedder.encode(ex['text'])
    ex['emb'] = vector.tolist()
    return ex

# Run `add_embed` over the dataset so every example gains an `emb` field.
ds = ds.map(add_embed)

In [None]:
# 4. Tokenizer and FLAN-T5
# Tokenizer and seq2seq model are both loaded from the same
# `google/flan-t5-base` checkpoint id.
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')
model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base')

def preprocess(ex):
    """Tokenize one example into model inputs plus ``labels``.

    The model input is the first 16 embedding values (each rounded to four
    decimals and space-separated), then ``' | '``, then the raw text. The
    labels are the token ids of ``ex['target']``.

    NOTE(review): the numeric prefix comes first and shares the 128-token
    budget with the utterance, so any truncation presumably removes text
    from the end -- confirm typical inputs fit. Also worth verifying that
    the tokenizer round-trips every character used in the JSON target.

    Args:
        ex: Example mapping with ``emb``, ``text`` and ``target`` entries.

    Returns:
        The tokenizer output for the input string, with ``labels`` added.
    """
    leading_values = [str(round(value, 4)) for value in ex['emb'][:16]]
    model_input = ' '.join(leading_values) + ' | ' + ex['text']
    encoded = tokenizer(model_input, max_length=128, truncation=True)
    encoded_target = tokenizer(ex['target'], max_length=64, truncation=True)
    encoded['labels'] = encoded_target['input_ids']
    return encoded

# Tokenize every example; the fine-tuning cell reads its `train` and
# `validation` splits from this object.
tokenized = ds.map(preprocess)

In [None]:
# 5. Fine-tuning
import dataclasses
import inspect

# The notebook installs an unpinned `transformers`, and two keywords used
# here were renamed across releases (`evaluation_strategy` -> `eval_strategy`,
# `tokenizer` -> `processing_class`). Pick whichever spelling the installed
# version accepts instead of failing with a TypeError.
_arg_fields = {f.name for f in dataclasses.fields(Seq2SeqTrainingArguments)}
_eval_key = 'eval_strategy' if 'eval_strategy' in _arg_fields else 'evaluation_strategy'
_trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
_tok_key = 'processing_class' if 'processing_class' in _trainer_params else 'tokenizer'

# Use the dataset's own validation split when it exists; otherwise hold out
# 10% of the training data (fixed seed) so per-epoch evaluation can still run.
# NOTE(review): confirm which splits the loaded dataset actually provides.
if 'validation' in tokenized:
    train_split, eval_split = tokenized['train'], tokenized['validation']
else:
    _held_out = tokenized['train'].train_test_split(test_size=0.1, seed=42)
    train_split, eval_split = _held_out['train'], _held_out['test']

args = Seq2SeqTrainingArguments(
    output_dir='snips_augmented',
    per_device_train_batch_size=8,
    num_train_epochs=5,
    logging_steps=10,
    save_strategy='epoch',
    predict_with_generate=True,
    # Mixed precision only makes sense when a CUDA device is present.
    fp16=torch.cuda.is_available(),
    # Evaluate once per epoch (keyword name resolved above).
    **{_eval_key: 'epoch'},
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model),
    **{_tok_key: tokenizer},
)
trainer.train()

In [None]:
# 6. Inference
def predict(text):
    """Generate the intent/slot string for a single utterance.

    Builds the same ``"<16 rounded embedding values> | <text>"`` input that
    ``preprocess`` builds for training, runs generation, and decodes the
    first returned sequence.

    Args:
        text: Raw utterance string.

    Returns:
        The decoded model output with special tokens removed.
    """
    # Convert to plain Python floats exactly as `add_embed` does: array
    # scalars may stringify differently (e.g. scientific notation for tiny
    # values), which would make the inference prefix differ from training.
    emb = embedder.encode(text).tolist()
    prefix = ' '.join(str(round(x, 4)) for x in emb[:16])
    inp = prefix + ' | ' + text
    # Same length cap as `preprocess`, so inference inputs never exceed what
    # the model saw during fine-tuning.
    tokens = tokenizer(
        inp, max_length=128, truncation=True, return_tensors='pt'
    ).to(model.device)
    # Make sure dropout and similar layers are in inference mode; generation
    # does not switch the module out of training mode by itself.
    model.eval()
    out = model.generate(**tokens, max_length=64)
    return tokenizer.decode(out[0], skip_special_tokens=True)

# Smoke-test the fine-tuned model on two sample utterances.
for utterance in (
    'Play the last song from Coldplay',
    'Find me a restaurant in New York tomorrow night',
):
    print(predict(utterance))