In [1]:
import os
import random
import time
import yaml

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler

from transformers import AutoTokenizer, AutoModel


os.chdir('/home/mrsergazinov/TabLLM/initial_exp/')
from retrieval_models import *

In [2]:
# Define a custom dataset to keep track of indices
class IndexedTensorDataset(Dataset):
    def __init__(self, tensors_num, tensors_cat, targets):
        self.tensors_num = tensors_num
        self.tensors_cat = tensors_cat
        self.targets = targets
        self.indices = torch.arange(len(tensors_num), dtype=torch.long)

    def __getitem__(self, index):
        return (self.tensors_num[index], self.tensors_cat[index], self.targets[index], self.indices[index])

    def __len__(self):
        return len(self.tensors_num)

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def load_config(config_file):
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)
    return config

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def label_encode(X, categorical_columns):
    le_categorical = {}
    for col in categorical_columns:
        le = LabelEncoder()
        X[col] = le.fit_transform(X[col].astype(str))
        le_categorical[col] = le  # Save each encoder if needed for inverse transformation later
    return X, le_categorical

def llm_encoder(X, categorical_columns, batch_size=256):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: Prepare text embeddings
    embeddings = []
    for idx in range(X.shape[0]):
        string = ''
        for column in categorical_columns:
            string += f"{column}: {X[column][idx]}. "
        embeddings.append(string)

    # Load tokenizer and model, move model to device
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-MiniLM-L6-v2')
    model = AutoModel.from_pretrained('sentence-transformers/paraphrase-MiniLM-L6-v2').to(device)

    # Step 2: Process embeddings in batches
    batch_embeddings = []
    with torch.no_grad():
        for i in range(0, len(embeddings), batch_size):
            batch_texts = embeddings[i:i+batch_size]
            encoded_input = tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt').to(device)
            model_output = model(**encoded_input)
            batch_embeddings.append(mean_pooling(model_output, encoded_input['attention_mask']).cpu())
            print(f'Processed {i+batch_size}/{len(embeddings)} embeddings')
    
    # Concatenate all batch embeddings
    embeddings_tensor = torch.cat(batch_embeddings, dim=0)
    embeddings_df = pd.DataFrame(embeddings_tensor.numpy(), columns=[f'embedding_{i}' for i in range(embeddings_tensor.shape[1])])

    # Clean up memory
    del encoded_input, model_output, batch_embeddings, embeddings_tensor, tokenizer, model
    torch.cuda.empty_cache()

    # Step 3: Concatenate embeddings with the original DataFrame and drop categorical columns
    X = pd.concat([X.reset_index(drop=True), embeddings_df], axis=1)
    X = X.drop(columns=categorical_columns)
    
    return X

def onehot_encode(X, categorical_columns):
    X = pd.get_dummies(X, columns=categorical_columns, drop_first = True)
    return X

def load_dataset(config):
    # Fetch the dataset
    data = fetch_openml(config['dataset']['name'], version=config['dataset']['version'], as_frame=True)
    X = data['data'].copy()
    y = data['target']

    # Identify categorical and numerical columns
    categorical_columns = X.select_dtypes(include=['category', 'object']).columns.tolist()
    numerical_columns = X.select_dtypes(include=['int64', 'float64']).columns.tolist()

    if config['dataset']['cat_encode'] == 'label':
        X, _ = label_encode(X, categorical_columns)
    elif config['dataset']['cat_encode'] == 'onehot':
        X = onehot_encode(X, categorical_columns)
    elif config['dataset']['cat_encode'] == 'llm':
        X = llm_encoder(X, categorical_columns)
    
    # Scale numerical columns
    numerical_transformer = StandardScaler()
    X[numerical_columns] = numerical_transformer.fit_transform(X[numerical_columns])

    # Encode the target variable
    le_target = LabelEncoder()
    y = le_target.fit_transform(y)

    # Convert X and y to tensor
    y = torch.tensor(y, dtype=torch.long)
    if config['dataset']['all_num']:
        X_num = torch.tensor(X.values, dtype=torch.float32)
        X_cat = X_num.clone()
        d_in_num = X_num.shape[1]
        d_in_cat = 0
    else:
        X_num = torch.tensor(X[numerical_columns].values, dtype=torch.float32)
        X_cat = torch.tensor(X.drop(columns=numerical_columns).values, dtype=torch.float32)
        d_in_num = X_num.shape[1]
        d_in_cat = X_cat.shape[1]

    return (X_num, X_cat, y, d_in_num, d_in_cat, le_target)


def evaluate(config, model, test_loader, criterion, device, X_num_train, X_cat_train, y_train):
    model.eval()
    correct = 0
    total = 0
    test_loss = 0.0

    with torch.no_grad():
        for X_num_batch, X_cat_batch, y_batch, idx_batch in test_loader:
            X_num_batch = X_num_batch.to(device)
            X_cat_batch = X_cat_batch.to(device) if not config['dataset']['all_num'] else None
            y_batch = y_batch.to(device)

            # Use entire training data as candidates during evaluation
            candidate_x_num = X_num_train.to(device) 
            candidate_x_cat = X_cat_train.to(device) if not config['dataset']['all_num'] else None
            candidate_y = y_train.to(device)

            # Forward pass with separate categorical and numerical features
            logits = model(
                x_num=X_num_batch,
                x_cat=X_cat_batch,
                y=None,
                candidate_x_num=candidate_x_num,
                candidate_x_cat=candidate_x_cat,
                candidate_y=candidate_y
            )

            # Convert logits to predictions
            _, predicted = torch.max(logits, 1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()

            # Compute loss
            loss = criterion(logits, y_batch)
            test_loss += loss.item() * y_batch.size(0)

    accuracy = 100 * correct / total
    test_loss = test_loss / total

    print(f'Test Loss: {test_loss:.4f} | Test Accuracy: {accuracy:.2f}%')

def train(config):
    set_seed(config['dataset']['random_state'])

    X_num, X_cat, y, d_in_num, d_in_cat, le_target = load_dataset(config)
    output_classes = len(le_target.classes_)

    X_num_train, X_num_test, X_cat_train, X_cat_test, y_train, y_test = train_test_split(
        X_num, X_cat, y, test_size=config['dataset']['test_size'], random_state=config['dataset']['random_state']
    )

    # Use IndexedTensorDataset for separate numerical and categorical data
    train_dataset = IndexedTensorDataset(X_num_train, X_cat_train, y_train)
    test_dataset = IndexedTensorDataset(X_num_test, X_cat_test, y_test)

    train_loader = DataLoader(
        train_dataset, batch_size=config['dataset']['batch_size'], shuffle=True,
        pin_memory=True, num_workers=4
    )
    test_loader = DataLoader(
        test_dataset, batch_size=config['dataset']['batch_size'], shuffle=False,
        pin_memory=True, num_workers=4
    )

    # Initialize the ModernNCA model
    model = ModernNCA(
        d_in_num=d_in_num,
        d_in_cat=d_in_cat,
        d_out=output_classes,
        dim=config['model']['dim'],
        dropout=config['model']['dropout'],
        n_frequencies=config['model']['n_frequencies'],
        frequency_scale=config['model']['frequency_scale'],
        d_embedding=config['model']['d_embedding'],
        lite=config['model']['lite'],
        temperature=config['model']['temperature'],
        sample_rate=config['model']['sample_rate'],
        use_llama=config['model']['use_llama'],
        llama_model_name=config['model']['llama_model_name'],
        start_layer=config['model']['start_layer'],
        end_layer=config['model']['end_layer']
    )
    model = model.float()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config['training']['learning_rate'],
        weight_decay=config['training']['weight_decay']
    )

    # Training loop
    for epoch in range(config['training']['epochs']):
        model.train()
        start_time = time.time()
        epoch_loss = 0.0
        correct = 0
        total = 0

        for itr, (X_num_batch, X_cat_batch, y_batch, idx_batch) in enumerate(train_loader):
            X_num_batch = X_num_batch.to(device)
            X_cat_batch = X_cat_batch.to(device) if not config['dataset']['all_num'] else None
            y_batch = y_batch.to(device)

            # Exclude current batch for candidates
            mask = ~torch.isin(torch.arange(X_num_train.shape[0]), idx_batch)
            true_indices = torch.arange(X_num_train.shape[0])[mask]
            num_samples = int(len(true_indices) * config['model']['sample_rate'])
            sampled_indices = true_indices[torch.randperm(len(true_indices))[:num_samples]]
            sampled_mask = torch.zeros_like(mask, dtype=torch.bool)
            sampled_mask[sampled_indices] = True

            # Use the new sampled_mask to filter out elements
            candidate_x_num = X_num_train[sampled_mask].to(device)
            candidate_x_cat = X_cat_train[sampled_mask].to(device) if not config['dataset']['all_num'] else None
            candidate_y = y_train[sampled_mask].to(device)

            optimizer.zero_grad()
            # Forward pass with separate categorical and numerical features
            logits = model(
                x_num=X_num_batch,
                x_cat=X_cat_batch,
                y=y_batch,
                candidate_x_num=candidate_x_num,
                candidate_x_cat=candidate_x_cat,
                candidate_y=candidate_y
            )
            loss = criterion(logits, y_batch)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * y_batch.size(0)
            _, predicted = torch.max(logits, 1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()

            if itr % 50 == 0:
                print(f'Epoch [{epoch+1}/{config["training"]["epochs"]}]: Batch [{itr+1}/{len(train_loader)}] | Accuracy: {correct/total:.2f}')

        epoch_loss = epoch_loss / total
        epoch_acc = 100 * correct / total
        epoch_time = time.time() - start_time

        print(f'Epoch [{epoch+1}/{config["training"]["epochs"]}] | Loss: {epoch_loss:.4f} | '
                f'Accuracy: {epoch_acc:.2f}% | Time: {epoch_time:.2f}s')

    evaluate(config, model, test_loader, criterion, device, X_num_train, X_cat_train, y_train)

    # save the model
    path = config['training']['model_path']
    torch.save(model.state_dict(), path)
    model.load_state_dict(torch.load(path))

    evaluate(config, model, test_loader, criterion, device, X_num_train, X_cat_train, y_train)

In [3]:
config = load_config('retrieval_config.yaml')

In [4]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
train(config)

Epoch [1/10]: Batch [1/1222] | Accuracy: 0.81
Epoch [1/10]: Batch [51/1222] | Accuracy: 0.81
Epoch [1/10]: Batch [101/1222] | Accuracy: 0.83
Epoch [1/10]: Batch [151/1222] | Accuracy: 0.83
Epoch [1/10]: Batch [201/1222] | Accuracy: 0.83
Epoch [1/10]: Batch [251/1222] | Accuracy: 0.84
Epoch [1/10]: Batch [301/1222] | Accuracy: 0.84
Epoch [1/10]: Batch [351/1222] | Accuracy: 0.84
Epoch [1/10]: Batch [401/1222] | Accuracy: 0.84
Epoch [1/10]: Batch [451/1222] | Accuracy: 0.84
Epoch [1/10]: Batch [501/1222] | Accuracy: 0.84
Epoch [1/10]: Batch [551/1222] | Accuracy: 0.84
Epoch [1/10]: Batch [601/1222] | Accuracy: 0.85
Epoch [1/10]: Batch [651/1222] | Accuracy: 0.85
Epoch [1/10]: Batch [701/1222] | Accuracy: 0.85
Epoch [1/10]: Batch [751/1222] | Accuracy: 0.85
Epoch [1/10]: Batch [801/1222] | Accuracy: 0.85
Epoch [1/10]: Batch [851/1222] | Accuracy: 0.85
Epoch [1/10]: Batch [901/1222] | Accuracy: 0.85
Epoch [1/10]: Batch [951/1222] | Accuracy: 0.85
Epoch [1/10]: Batch [1001/1222] | Accuracy: