## Column Embedding Extraction

This notebook extracts column embeddings for the 11 datasets of KE-TALENT. We use the following models:
- MPNet `sentence-transformers/all-mpnet-base-v2`
- MMTEB SOTA `Alibaba-NLP/gte-Qwen2-7B-instruct`
- MMTEB STS SOTA `Lajavaness/bilingual-embedding-large`
- (constant embeddings)

Please run `scripts/talent_data_preproc.ipynb` beforehand.

In [1]:
# Run this once at first to work on the project root
# NOTE: `%cd ..` is relative to the current working directory, so re-running this cell
# moves one more level up each time; restart the kernel before executing it again.
%cd ..

  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/usr2/juyongk/graph-concept-prior


In [2]:
import re
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import json

from sentence_transformers import SentenceTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Pre-processing functions
# def split_camel_case_numbers(name):
#     # Step 1: Insert space between lowercase->uppercase and uppercase->lowercase transitions
#     name = re.sub(r'([a-z])([A-Z])|([A-Z])([A-Z][a-z])', r'\1\3 \2\4', name)
#     # Step 2: Insert space before and after numbers
#     name = re.sub(r'(\d+)', r' \1 ', name)  # Adds spaces around numbers
#     return ' '.join(name.split())  # Removes any extra spaces

# def format_name(name):
#     name = name.replace('_', ' ')
#     name = split_camel_case_numbers(name)
#     name = name.lower()
#     return name

def get_full_desc_sentences(col_name_descs):
    """Turn (name, description) pairs into one sentence per column.

    For each pair:
      * the name alone is used when the description is empty or equal to the name;
      * the description alone is used when the name is empty or contains
        "attribute" (case-insensitive);
      * otherwise the sentence is "<name> : <description>".

    Args:
        col_name_descs: iterable of (name, desc) pairs.

    Returns:
        List of sentences, in input order.

    Raises:
        AssertionError: if a pair has both an empty name and an empty description.
    """
    def _to_sentence(name, desc):
        # At least one of the two fields must carry some text.
        assert name or desc
        if not desc or name == desc:
            return name
        if not name or 'attribute' in name.lower():
            return desc
        return name + " : " + desc

    return [_to_sentence(name, desc) for name, desc in col_name_descs]

In [4]:
# Dataset sub-directory names under data/talent/.
talent_dataset_names = [
    "Abalone_reg",
    "Diamonds",
    "Parkinsons_Telemonitoring",
    "archive_r56_Portuguese",
    "communities_and_crime",
    "Bank_Customer_Churn_Dataset",
    "statlog",
    "taiwanese_bankruptcy_prediction",
    "ASP-POTASSCO-classification",
    "internet_usage",
    "predict_students_dropout_and_academic_success",
]

# One task description per dataset, in the SAME order as talent_dataset_names
# (the two lists are paired positionally with zip when building the Qwen queries).
talent_task_descs = [
    "Predict the age of abalone from physical measurements.",
    "Predict diamond price based on features such as carat/cut/colour/clarity/size.",
    "Predict the motor UPDRS score from biomedical voice measurements of people with early-stage Parkinson's disease.",
    "Predict student performance in secondary education (high school).",
    "Predict per capita violent crimes.",
    "Predict if the client has left the bank.",
    "Predict if a customer has good or bad credit risk.",
    # Fixed: the next two entries were swapped relative to talent_dataset_names.
    "Predict whether a company will go bankrupt.",
    "Predict the best ASP solver (algorithm) for a given problem instance.",
    "Predict a user's occupation based on demographic and internet usage.",
    "Predict the student status (dropout, enrolled, and graduate) at the end of the normal duration of the course.",
]

# zip() silently truncates to the shorter list, so guard against the lists drifting apart.
assert len(talent_dataset_names) == len(talent_task_descs), "dataset names and task descriptions are misaligned"

- Load the models

In [5]:
# If you have a separate directory to keep Huggingface models. Specify here:
# The path can be overridden per machine through the GCP_HF_CACHE_DIR environment variable
# instead of editing the notebook. Setting the variable to an empty string yields
# cache_dir=None (the alternative the commented line below selects).
cache_dir = os.environ.get('GCP_HF_CACHE_DIR', '/usr1/data/LLM/huggingface/') or None
# cache_dir=None

In [6]:
# MPNet sentence encoder: tokenizer plus backbone, moved to the GPU and switched to eval mode.
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
mpnet_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
mpnet_model = AutoModel.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
mpnet_model = mpnet_model.cuda()
mpnet_model = mpnet_model.eval()

def mean_pooling(token_embeddings, attention_mask):
    """
    Average the token embeddings of each sequence, ignoring masked-out positions.

    The attention mask is broadcast to the shape of `token_embeddings`, the masked
    embeddings are summed over dimension 1, and the sum is divided by the number of
    unmasked tokens (clamped to at least 1e-9 so a fully masked row yields zeros
    instead of a division by zero).
    """
    # Same shape as the embeddings, as floats, so it can weight them elementwise.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_sum = (token_embeddings * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return masked_sum / token_counts

def get_mpnet_embeddings(sentences):
    """
    Embed a list of sentences with MPNet.

    All sentences are tokenized as one padded/truncated batch, run through
    `mpnet_model` on the GPU without gradients, mean-pooled over the tokens and
    L2-normalized. The result is returned as a CPU tensor with one row per sentence.
    """
    encoded = mpnet_tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    # The model lives on the GPU, so every input tensor has to follow it.
    encoded = {key: tensor.cuda() for key, tensor in encoded.items()}

    with torch.no_grad():
        model_out = mpnet_model(**encoded)

    pooled = mean_pooling(model_out.last_hidden_state, encoded["attention_mask"])
    return F.normalize(pooled, p=2, dim=1).cpu()

In [7]:
# Qwen2-based GTE embedding model: tokenizer plus backbone in eval mode on the GPU.
qwen_tokenizer = AutoTokenizer.from_pretrained('Alibaba-NLP/gte-Qwen2-7B-instruct', cache_dir=cache_dir, trust_remote_code=True)
qwen_model = AutoModel.from_pretrained('Alibaba-NLP/gte-Qwen2-7B-instruct', cache_dir=cache_dir, trust_remote_code=True).eval().cuda()

# Instruction prepended to every query prompt. The wording (including the "retrive"
# spelling) is model input: editing it changes the resulting query embeddings.
qwen_instruction = "Given the task description of a tabular dataset and a column description, retrive semantically relevant columns."
# Upper bound on the tokenized prompt length (passed as max_length to the tokenizer).
qwen_max_length = 8192

def last_token_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Return, for each sequence, the hidden state of its last attended token.

    If the final position is attended in every row, the batch is treated as
    left-padded and the last position is returned directly. Otherwise the index of
    the last attended token is derived per row from the attention-mask sum.
    """
    n_rows = attention_mask.shape[0]
    if attention_mask[:, -1].sum() == n_rows:
        return last_hidden_states[:, -1]

    # Right padding: the last real token sits at (number of attended tokens - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_ids, last_positions]

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Build the three-line query prompt: instruction, task description, then the query."""
    prompt_lines = (
        f'Instruct: {qwen_instruction}',
        f'Task: {task_description}',
        f'Query: {query}',
    )
    return '\n'.join(prompt_lines)

def get_qwen_embeddings(sentences, query=False, task_description=None):
    """Embed sentences with the Qwen model, one sentence per forward pass.

    Args:
        sentences: list of sentences to embed.
        query: when True, each sentence is wrapped in the instruction/task prompt
            built by `get_detailed_instruct` before being embedded.
        task_description: task text inserted into the prompt; required when
            `query` is True, unused otherwise.

    Returns:
        CPU tensor with one L2-normalized row per sentence (last-token pooling).

    Raises:
        ValueError: if `query` is True but no `task_description` is given.
    """
    if query:
        # Without this check the prompt would silently contain the text "Task: None".
        if task_description is None:
            raise ValueError("task_description is required when query=True")
        sentences = [get_detailed_instruct(task_description, s) for s in sentences]
    ret = []
    for s in sentences:
        # Each sentence is tokenized and run through the model on its own (batch of one).
        batch_dict = qwen_tokenizer([s], max_length=qwen_max_length, padding=True, truncation=True, return_tensors='pt')
        batch_dict = {k: v.cuda() for k, v in batch_dict.items()}
        with torch.no_grad():
            outputs = qwen_model(**batch_dict)
        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        # Move each result off the GPU right away; only the CPU copies are accumulated.
        ret.append(embeddings.cpu())
    ret = torch.cat(ret, dim=0)
    ret = F.normalize(ret, p=2, dim=1)
    return ret

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  7.55it/s]


In [8]:
# STS sentence encoder, loaded through sentence-transformers; eval mode, on the GPU.
sts_model = SentenceTransformer('Lajavaness/bilingual-embedding-large', cache_folder=cache_dir, trust_remote_code=True)
sts_model = sts_model.eval()
sts_model = sts_model.cuda()

def get_sts_embeddings(input_sentences):
    """Encode the sentences with `sts_model` and return the result as a torch tensor."""
    encoded = sts_model.encode(input_sentences)
    return torch.tensor(encoded)

In [9]:
# MPNet embeddings for every dataset: always for the original column layout, and
# additionally for the one-hot layout when its col_desc.csv exists.
for dataset_subdir in talent_dataset_names:
    for variant in ('', '/onehot'):
        desc_path = 'data/talent/' + dataset_subdir + variant + '/col_desc.csv'
        # Only the one-hot variant is optional; the original file must be present.
        if variant and not os.path.exists(desc_path):
            continue
        df_desc = pd.read_csv(desc_path, keep_default_na=False)
        col_name_descs = list(zip(df_desc.name, df_desc.desc))
        col_descs = get_full_desc_sentences(col_name_descs)
        embeds = get_mpnet_embeddings(col_descs)
        torch.save(embeds, 'data/talent/' + dataset_subdir + variant + '/col_embed_mpnet.pt')

In [10]:
# STS embeddings for every dataset: always for the original column layout, and
# additionally for the one-hot layout when its col_desc.csv exists.
for dataset_subdir in talent_dataset_names:
    for variant in ('', '/onehot'):
        desc_path = 'data/talent/' + dataset_subdir + variant + '/col_desc.csv'
        # Only the one-hot variant is optional; the original file must be present.
        if variant and not os.path.exists(desc_path):
            continue
        df_desc = pd.read_csv(desc_path, keep_default_na=False)
        col_name_descs = list(zip(df_desc.name, df_desc.desc))
        col_descs = get_full_desc_sentences(col_name_descs)
        embeds = get_sts_embeddings(col_descs)
        torch.save(embeds, 'data/talent/' + dataset_subdir + variant + '/col_embed_sts.pt')

In [11]:
# Qwen embeddings for every dataset. Each column sentence is embedded twice: as a
# "query" (wrapped in the instruction + task prompt) and as a plain "doc".
# The one-hot layout is processed as well when its col_desc.csv exists.
for dataset_subdir, task_desc in zip(talent_dataset_names, talent_task_descs):
    for variant in ('', '/onehot'):
        desc_path = 'data/talent/' + dataset_subdir + variant + '/col_desc.csv'
        # Only the one-hot variant is optional; the original file must be present.
        if variant and not os.path.exists(desc_path):
            continue
        df_desc = pd.read_csv(desc_path, keep_default_na=False)
        col_name_descs = list(zip(df_desc.name, df_desc.desc))
        col_descs = get_full_desc_sentences(col_name_descs)
        query_embeds = get_qwen_embeddings(col_descs, query=True, task_description=task_desc)
        doc_embeds = get_qwen_embeddings(col_descs, query=False)
        torch.save({'query': query_embeds, 'doc': doc_embeds}, 'data/talent/' + dataset_subdir + variant + '/col_embed_qwen.pt')

In [12]:
# Near-constant baseline embeddings: an all-ones vector per column plus small Gaussian
# noise (std 0.01), L2-normalized. Written for the original layout and, when present,
# for the one-hot layout.
embed_dim = 100
# A dedicated seeded generator makes the saved files identical on every run of this
# cell without touching the global torch RNG state.
noise_generator = torch.Generator().manual_seed(0)
for dataset_subdir in talent_dataset_names:
    for variant in ('', '/onehot'):
        desc_path = 'data/talent/' + dataset_subdir + variant + '/col_desc.csv'
        # Only the one-hot variant is optional; the original file must be present.
        if variant and not os.path.exists(desc_path):
            continue
        df_desc = pd.read_csv(desc_path, keep_default_na=False)
        shape = (len(df_desc), embed_dim)
        embeds = torch.ones(shape, dtype=torch.float32) + torch.randn(shape, generator=noise_generator) * 0.01
        embeds = F.normalize(embeds, p=2, dim=1)
        torch.save(embeds, 'data/talent/' + dataset_subdir + variant + '/col_embed_unif.pt')

### Visualize the concept kernel

In [None]:
# Choose which dataset to visualize (index into talent_dataset_names).
dataset_subdir = talent_dataset_names[0]
# False: load the one-hot files when they exist; True: always load the original-format files.
use_orig_format = False

# Show the selected dataset name as the cell output.
dataset_subdir

In [None]:
# Load the saved embeddings, dataset info and column descriptions of the selected dataset.
# The one-hot files are used when they exist and use_orig_format is False; otherwise the
# original-format files are used.
onehot_dir = f'data/talent/{dataset_subdir}/onehot'
if os.path.exists(onehot_dir + '/col_desc.csv') and (not use_orig_format):
    embed_dir = onehot_dir
else:
    embed_dir = f'data/talent/{dataset_subdir}'

# weights_only=False unpickles arbitrary Python objects: only load files that were
# written by the extraction cells of this notebook.
orig_embed = torch.load(f'{embed_dir}/col_embed_mpnet.pt', weights_only=False)
sts_embed = torch.load(f'{embed_dir}/col_embed_sts.pt', weights_only=False)
qwen_embed = torch.load(f'{embed_dir}/col_embed_qwen.pt', weights_only=False)
# Use a context manager so the JSON file handle is closed deterministically.
with open(f'{embed_dir}/info.json') as info_file:
    info = json.load(info_file)
df_desc = pd.read_csv(f'{embed_dir}/col_desc.csv', keep_default_na=False)

col_name_descs = list(zip(df_desc.name, df_desc.desc))
col_sents = get_full_desc_sentences(col_name_descs)

In [None]:
# Inspect the loaded Qwen embeddings (a dict holding the 'query' and 'doc' tensors).
qwen_embed

In [None]:
# Doc-vs-query similarity matrix, rescaled so that every column sums to 1.
W = torch.matmul(qwen_embed['doc'], qwen_embed['query'].T)
# Reciprocal of each column sum, kept as a row vector so it broadcasts over the rows of W.
D = W.sum(dim=0, keepdim=True).pow(-1)
T = W * D

In [None]:
# Eigendecomposition of T, with eigenpairs reordered by decreasing eigenvalue magnitude.
eig_vals, eig_vecs = torch.linalg.eig(T)
eig_indices = eig_vals.abs().argsort(descending=True)
eig_vals = eig_vals[eig_indices]
eig_vecs = eig_vecs[:, eig_indices]

In [None]:
# The columns of T sum to 1, so its dominant eigenvalue is expected to be exactly 1 (and real).
assert torch.allclose(eig_vals[0].real, torch.tensor(1.0)), "First eigen value != 1.0"
# This check is about the imaginary part, so report that instead of repeating the message above.
assert torch.allclose(eig_vals[0].imag, torch.tensor(0.0)), "First eigen value is not real (non-zero imaginary part)"
# Leading eigenvector, made non-negative and scaled to unit L1 norm
# (presumably used as a stationary distribution over columns -- confirm with downstream use).
pi = eig_vecs[:, 0].real
pi = pi.abs() / pi.norm(p=1)

In [None]:
# Pairwise similarity matrices for the symmetric encoders.
orig_score = torch.matmul(orig_embed, orig_embed.T)
sts_score = torch.matmul(sts_embed, sts_embed.T)

# Qwen is asymmetric (query vs. doc embeddings), so three variants are computed:
qwen_query, qwen_doc = qwen_embed['query'], qwen_embed['doc']
# 1) raw query-vs-doc scores;
qwen_score = torch.matmul(qwen_query, qwen_doc.T)
# 2) symmetrized scores, rescaled so that the diagonal becomes 1;
qwen_score_norm = (qwen_score + qwen_score.T)/2
d = qwen_score_norm.diagonal() ** -0.5
qwen_score_norm = d.unsqueeze(1) * d.unsqueeze(0) * qwen_score_norm
# 3) scores of the re-normalized sum of query and doc embeddings.
qwen_embed_avg = F.normalize(qwen_query + qwen_doc, p=2, dim=1)
qwen_score2 = torch.matmul(qwen_embed_avg, qwen_embed_avg.T)

In [None]:
# Report the value range of every similarity matrix.
score_matrices = [
    ('Orig', orig_score),
    ('STS', sts_score),
    ('Qwen', qwen_score),
    ('Qwen norm', qwen_score_norm),
    ('Qwen2', qwen_score2),
]
for label, score in score_matrices:
    print(f'{label} min:', score.min(), 'max:', score.max())

In [None]:
# Show the sentences built from each column's name and description for the selected dataset.
col_sents

In [None]:
# Heatmap of the MPNet similarity matrix; rows are labelled with the column names.
fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(orig_score, vmin=-1.0, vmax=1.0, cmap='bwr')
ax.set_yticks(range(len(orig_score)))
ax.set_yticklabels(df_desc.name, fontsize=6)
# A title is needed to tell this figure apart from the other similarity heatmaps.
ax.set_title('MPNet similarity (orig_score)')
fig.colorbar(im, ax=ax);

In [None]:
# Heatmap of the STS similarity matrix; rows are labelled with the column names.
fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(sts_score, vmin=-1.0, vmax=1.0, cmap='bwr')
# Tick positions come from the matrix being plotted (previously taken from orig_score).
ax.set_yticks(range(len(sts_score)))
ax.set_yticklabels(df_desc.name, fontsize=6)
# A title is needed to tell this figure apart from the other similarity heatmaps.
ax.set_title('STS similarity (sts_score)')
fig.colorbar(im, ax=ax);

In [None]:
# Heatmap of the raw Qwen query-vs-doc similarity matrix; rows are labelled with the column names.
fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(qwen_score, vmin=-1.0, vmax=1.0, cmap='bwr')
# Tick positions come from the matrix being plotted (previously taken from orig_score).
ax.set_yticks(range(len(qwen_score)))
ax.set_yticklabels(df_desc.name, fontsize=6)
# A title is needed to tell this figure apart from the other similarity heatmaps.
ax.set_title('Qwen query-vs-doc similarity (qwen_score)')
fig.colorbar(im, ax=ax);

In [None]:
# Heatmap of the symmetrized, diagonal-normalized Qwen similarity matrix; rows are labelled with the column names.
fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(qwen_score_norm, vmin=-1.0, vmax=1.0, cmap='bwr')
# Tick positions come from the matrix being plotted (previously taken from orig_score).
ax.set_yticks(range(len(qwen_score_norm)))
ax.set_yticklabels(df_desc.name, fontsize=6)
# A title is needed to tell this figure apart from the other similarity heatmaps.
ax.set_title('Qwen symmetrized, diagonal-normalized similarity (qwen_score_norm)')
fig.colorbar(im, ax=ax);

In [None]:
# Heatmap of the similarity between re-normalized (query + doc) Qwen embeddings; rows are labelled with the column names.
fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(qwen_score2, vmin=-1.0, vmax=1.0, cmap='bwr')
# Tick positions come from the matrix being plotted (previously taken from orig_score).
ax.set_yticks(range(len(qwen_score2)))
ax.set_yticklabels(df_desc.name, fontsize=6)
# A title is needed to tell this figure apart from the other similarity heatmaps.
ax.set_title('Qwen query+doc averaged-embedding similarity (qwen_score2)')
fig.colorbar(im, ax=ax);

## Extract preprocessed input
- Save preprocessed input (cat and num columns concatenated) into a single file

In [47]:
from gcp.datasets import load_dataset
from gcp.models.base import BaseDeepModel
from omegaconf import OmegaConf

In [48]:
# For every dataset: load the train/val/test splits with one-hot encoding requested,
# run the model's data pre-processing on each split, and save the numeric and
# categorical feature blocks concatenated column-wise as X_<split>.npy.
for dataset_subdir in talent_dataset_names:
    conf = OmegaConf.load('configs/dataset/talent/' + dataset_subdir + '.yaml')
    conf.data_common.use_onehot = True
    # Every split uses the same (common) dataset parameters.
    conf.params.train = conf.data_common
    conf.params.val = conf.data_common
    conf.params.test = conf.data_common

    train_dataset = load_dataset(conf, 'train')
    val_dataset = load_dataset(conf, 'val')
    test_dataset = load_dataset(conf, 'test')

    metadata = train_dataset.metadata

    # NOTE(review): num_classes receives metadata['task_type'], the same value as task_type.
    # This looks like a copy-paste slip -- confirm the intended metadata key. The model is
    # only used for data_preproc in this cell, so the value may not matter here.
    model = BaseDeepModel(task_type=metadata['task_type'], num_classes=metadata['task_type'], n_num_features=metadata['n_num_features'], n_cat_features=metadata['n_cat_features'], kg=None, metadata=metadata)
    model.data_preproc(train_dataset)
    model.data_preproc(val_dataset)
    model.data_preproc(test_dataset)

    for split, dataset in [('train', train_dataset), ('val', val_dataset), ('test', test_dataset)]:
        # Either feature block may be None; keep only the ones that are present.
        parts = [part for part in (dataset.X_num, dataset.X_cat) if part is not None]
        # np.concatenate works on every NumPy version; the np.concat alias needs NumPy >= 2.0.
        X = np.concatenate(parts, axis=1)
        np.save('data/talent/' + dataset_subdir + f'/X_{split}.npy', X)

One-hot converted dataset does not exists. Fall back to the original
One-hot converted dataset does not exists. Fall back to the original
One-hot converted dataset does not exists. Fall back to the original
One-hot converted dataset does not exists. Fall back to the original
One-hot converted dataset does not exists. Fall back to the original
One-hot converted dataset does not exists. Fall back to the original
