In [1]:
import pandas as pd
import numpy as np
import copy
import pickle
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModel
from torch.nn.utils.rnn import pad_sequence
import torch
from sklearn.cluster import KMeans
import pickle
import warnings
warnings.filterwarnings('ignore')

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
pd.set_option('display.max_columns', 5000)

The purpose of this notebook is to cluster reviews into groups of similar instances, so that a contrastive loss can effectively learn meaningful differences between them.

## Load data

In [2]:
# Per-review features (incl. review_id), pre-tokenized reviews (input_ids / attention_mask)
# and the match table whose row labels the saved K-Means groups refer to.
train_reviews_features = pd.read_parquet(path='train_reviews_features.parquet')
train_reviews_tokens = pd.read_parquet(path='train_reviews_tokens.parquet')
train_matches = pd.read_csv(filepath_or_buffer='train_matches.csv')

## Get embeddings for K-Means

In [4]:
model_name = 'sentence-transformers/all-MiniLM-L12-v2'
model = AutoModel.from_pretrained(model_name)
model = model.to(device)  # Module.to returns the same (moved) module
# NOTE(review): `tokenizer` is not used by the visible code -- reviews arrive already
# tokenized; confirm they were produced with this same checkpoint's tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_name)

def get_review_embeddings(train_reviews_tokens, batch_size=32, max_length=128, encoder=None):
    """
    Get mean-pooled embeddings for pre-tokenized reviews.

    Token vectors from the encoder's ``last_hidden_state`` are averaged over the
    real (attention_mask == 1) positions only, so the padding added to equalise
    a batch does not dilute the embeddings of shorter reviews.

    Args:
        train_reviews_tokens: dataframe with ``input_ids`` and ``attention_mask``
            columns, one token-id / mask sequence per row
        batch_size: number of reviews encoded per forward pass
        max_length: each sequence is truncated to its first ``max_length`` tokens
        encoder: model to run; defaults to the notebook-level ``model``
    Returns:
        embeddings: list with one 1-D numpy array per review, in row order
    """
    if encoder is None:
        encoder = model  # notebook-level MiniLM loaded in this cell
    # Send inputs to wherever the encoder's weights live.
    target_device = next(encoder.parameters()).device

    embeddings = []
    encoder.eval()
    with torch.no_grad():
        # tqdm infers the number of batches from the range itself.
        for start in tqdm(range(0, len(train_reviews_tokens), batch_size)):
            batch = train_reviews_tokens.iloc[start:start + batch_size]
            # Truncate before padding so the padded batch is never wider than max_length.
            input_ids = [torch.tensor(ids[:max_length], dtype=torch.long)
                         for ids in batch['input_ids'].tolist()]
            attention_mask = [torch.tensor(mask[:max_length], dtype=torch.long)
                              for mask in batch['attention_mask'].tolist()]

            input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0).to(target_device)
            attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0).to(target_device)

            outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
            hidden = outputs.last_hidden_state
            # Masked mean pooling: zero out padded positions, divide by the real token count.
            weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * weights).sum(dim=1)
            counts = weights.sum(dim=1).clamp(min=1e-9)  # avoid 0/0 for an all-padding row
            embeddings.extend((summed / counts).cpu().numpy())

    return embeddings

def create_embedding_df(train_reviews_tokens, **embedding_kwargs):
    """
    Create dataframe with embeddings for reviews, one ``emb_<i>`` column per dimension.

    Args:
        train_reviews_tokens: dataframe with tokenized reviews, as accepted by
            ``get_review_embeddings``
        **embedding_kwargs: forwarded to ``get_review_embeddings``
            (e.g. ``batch_size``, ``max_length``)
    Returns:
        embedding_df: dataframe with one row per review in input order and a
            default RangeIndex; an empty dataframe when there are no reviews
    """
    embeddings = get_review_embeddings(train_reviews_tokens, **embedding_kwargs)
    if len(embeddings) == 0:
        # Nothing to size the columns from; avoid IndexError on embeddings[0].
        return pd.DataFrame()
    embedding_dim = embeddings[0].shape[0]
    embedding_df = pd.DataFrame(embeddings, columns=[f'emb_{i}' for i in range(embedding_dim)])
    return embedding_df

embedding_df = create_embedding_df(train_reviews_tokens)

train_reviews_features = train_reviews_features.reset_index(drop=True)
embedding_df = embedding_df.reset_index(drop=True)

# pd.concat(axis=1) aligns on the index, so a row-count mismatch would silently
# pad with NaN rows instead of failing -- check the frames line up first.
assert len(train_reviews_features) == len(embedding_df), (
    f'{len(train_reviews_features)} feature rows vs {len(embedding_df)} embedding rows'
)
# Rows are paired purely by position; when the token frame carries review ids, verify the order matches.
if 'review_id' in train_reviews_tokens.columns:
    assert (
        train_reviews_tokens['review_id'].to_numpy() == train_reviews_features['review_id'].to_numpy()
    ).all(), 'train_reviews_tokens and train_reviews_features are not in the same review order'

# NOTE(review): assumes the first two columns are identifiers rather than model features
# -- confirm against the parquet schema.
# 'kmeans_label' is excluded because the K-Means cell appends it to train_reviews_features;
# otherwise re-running this cell would feed the previous cluster labels back in as a feature.
feature_columns = train_reviews_features.columns[2:].drop('kmeans_label', errors='ignore')
feature_and_embedding_df = pd.concat([train_reviews_features[feature_columns], embedding_df], axis=1)

## K-Means

In [None]:
n_clusters = 100  # number of K-Means groups to form
# NOTE(review): scikit-learn's default n_init differs across versions (and the resulting
# FutureWarning is hidden by the warnings filter); pin it if cluster assignments must be
# reproducible across environments.
kmeans = KMeans(random_state=42, n_clusters=n_clusters)
kmeans.fit(feature_and_embedding_df)
kmeans_labels = kmeans.labels_  # what fit_predict would return

train_reviews_features['kmeans_label'] = kmeans_labels

## Create groups from K-Means clusters and save them for dataset creation

In [None]:
# review_id -> row label in train_matches. As with a row-by-row dict build, a review_id that
# appears more than once keeps only its LAST row.
# NOTE(review): confirm review_id is unique in train_matches, otherwise earlier rows are dropped.
review_id_to_idx = dict(zip(train_matches['review_id'], train_matches.index))

grouped = train_reviews_features.groupby('kmeans_label')

# cluster label -> train_matches row labels of its reviews, in feature-frame order.
# Reviews absent from train_matches are skipped; clusters with no matched review are omitted.
kmeans_groups = {}
for label, group in grouped:
    matched_indices = [review_id_to_idx[rid] for rid in group['review_id'] if rid in review_id_to_idx]
    if matched_indices:
        kmeans_groups[label] = matched_indices

with open('kmeans_groups.pkl', 'wb') as f:
    pickle.dump(kmeans_groups, f)