In [None]:
from datasets import load_dataset
from gensim.models import Word2Vec
import pandas as pd
from scipy.stats import spearmanr, pearsonr
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.metrics import mean_squared_error
import re
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Open the English C4 training split in streaming mode (streaming=True)
dataset = load_dataset("c4", "en", split="train", streaming=True)

# Fetch the NLTK resources used by the stop-word list and the lemmatizer
for nltk_resource in ('stopwords', 'wordnet'):
    nltk.download(nltk_resource)

# Shared preprocessing helpers used by `preprocess`
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

In [None]:
def preprocess(text):
    """Turn raw text into a list of lowercased, lemmatized, non-stop-word tokens.

    Every character that is not an ASCII letter or whitespace is deleted
    (not replaced by a space), the result is lowercased and split on
    whitespace, stop words are dropped, and each remaining token is passed
    through the module-level ``lemmatizer``.

    Args:
        text: The raw document string.

    Returns:
        list[str]: The processed tokens, in document order.
    """
    # NOTE: the 4th positional argument of re.sub is `count`, not `flags`.
    # Passing `re.I|re.A` there capped the cleanup at the first 258 matches
    # and applied no flags at all. The character class already lists both
    # cases, so no flags are needed; every match must be removed.
    cleaned = re.sub(r'[^a-zA-Z\s]', '', text).lower()
    return [
        lemmatizer.lemmatize(word)
        for word in cleaned.split()
        if word not in stop_words
    ]

# Initialize model (created lazily by the training cell)
word2vec_model = None

In [None]:
# Process and train
max_docs = 1000000
initial_vocab_size = 5000  # Number of documents to build initial vocabulary
update_vocab_size = 500    # Update vocabulary every 'update_vocab_size' documents


def _train_on_batch(model, batch):
    """Create the Word2Vec model from `batch`, or extend and further train an existing one.

    Args:
        model: An existing Word2Vec model, or None if none has been built yet.
        batch: List of token lists (one per document).

    Returns:
        The newly created model, or the same `model` after the incremental update.
    """
    if model is None:
        return Word2Vec(batch, vector_size=150, window=7, min_count=3, workers=4, epochs=20)
    # update=True extends the existing vocabulary instead of rebuilding it
    model.build_vocab(batch, update=True)
    model.train(batch, total_examples=len(batch), epochs=model.epochs)
    return model


processed_texts = []
# start=1 so the progress message reports how many documents were actually processed
for docs_seen, doc in enumerate(dataset.take(max_docs), start=1):
    processed_texts.append(preprocess(doc['text']))

    # The first batch (model creation) is larger than the later incremental batches
    batch_limit = initial_vocab_size if word2vec_model is None else update_vocab_size
    if len(processed_texts) >= batch_limit:
        word2vec_model = _train_on_batch(word2vec_model, processed_texts)
        processed_texts = []

    if docs_seen % 1000 == 0:
        print(f"Processed {docs_seen} documents")

# Flush the tail: documents buffered when the stream ends would otherwise be
# silently dropped (and the model would stay None if the stream yielded fewer
# than `initial_vocab_size` documents).
if processed_texts:
    word2vec_model = _train_on_batch(word2vec_model, processed_texts)
    processed_texts = []

In [None]:
# Read the SimLex-999 word-pair table (path is relative to the working directory)
simlex_csv_path = './simlex999.csv'
simlex = pd.read_csv(simlex_csv_path)

def calculate_similarity(model, word1, word2):
    """Return the model's similarity for a word pair, or 0 if either word is unknown.

    Args:
        model: Object exposing ``wv.key_to_index`` and ``wv.similarity``.
        word1: First word of the pair.
        word2: Second word of the pair.

    Returns:
        ``model.wv.similarity(word1, word2)`` when both words are keys of
        ``model.wv.key_to_index``; otherwise the integer 0.
    """
    vocabulary = model.wv.key_to_index
    # Out-of-vocabulary pairs are scored as 0 rather than skipped
    if word1 not in vocabulary or word2 not in vocabulary:
        return 0
    return model.wv.similarity(word1, word2)

# Score every SimLex pair with the trained model, multiplying each similarity by 10
def _scaled_similarity(row):
    """Return 10x the Word2Vec similarity of the row's 'word1'/'word2' pair."""
    return calculate_similarity(word2vec_model, row['word1'], row['word2'])*10

simlex['predicted_score_word2vec'] = simlex.apply(_scaled_similarity, axis=1)

In [None]:
# Gold scores vs. model scores over the whole SimLex table
actual_scores, predicted_scores = simlex['sim_value'], simlex['predicted_score_word2vec']

# Rank-based agreement (Spearman)
spearman_corr, _ = spearmanr(actual_scores, predicted_scores)
print(f"Spearman's Correlation: {spearman_corr}")

# Linear agreement (Pearson)
pearson_corr, _ = pearsonr(actual_scores, predicted_scores)
print(f"Pearson Correlation: {pearson_corr}")

# Error magnitude: MSE is kept as an intermediate, only its square root is reported
mse = mean_squared_error(y_true=actual_scores, y_pred=predicted_scores)
rmse = np.sqrt(mse)
print(f"Root Mean Squared Error: {rmse}")


In [None]:
categories = simlex['label1'].unique()
spearman_corrs = []
pearson_corrs = []

# Category-wise Performance Analysis
# Category-scoped names are used inside the loop so the overall
# `actual_scores` / `predicted_scores` / `spearman_corr` / `pearson_corr`
# computed for the full table are not overwritten.
for category in categories:
    category_data = simlex[simlex['label1'] == category]
    category_actual = category_data['sim_value']
    category_predicted = category_data['predicted_score_word2vec']

    if len(category_data) < 2:
        # A correlation needs at least two pairs; record NaN instead of raising
        category_spearman = category_pearson = np.nan
    else:
        category_spearman, _ = spearmanr(category_actual, category_predicted)
        category_pearson, _ = pearsonr(category_actual, category_predicted)

    print(f"Category: {category}")
    print(f"Spearman's Correlation: {category_spearman}")
    print(f"Pearson Correlation: {category_pearson}")
    print()

    spearman_corrs.append(category_spearman)
    pearson_corrs.append(category_pearson)

# Plotting
# The two series are drawn side by side: each bar is centred half a bar-width
# to the left/right of the tick so they no longer overlap.
bar_width = 0.4
x = np.arange(len(categories))
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(x - bar_width / 2, spearman_corrs, width=bar_width, label='Spearman')
ax.bar(x + bar_width / 2, pearson_corrs, width=bar_width, label='Pearson')
ax.set_xlabel('Category')
ax.set_ylabel('Correlation Coefficient')
ax.set_title('Category-wise Performance Analysis')
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.legend()
plt.show()
