In [21]:
!pip install -U datasets



In [22]:
import polars as pl
from datasets import load_dataset

# Load a slice of each split of the "tasksource/esci" dataset.
train_dataset = load_dataset("tasksource/esci", split="train[:500000]")
test_dataset = load_dataset("tasksource/esci", split="test[:250000]")

# Columns that are removed before any further processing.
cols_to_be_dropped = [
    "example_id",
    "query_id",
    "product_id",
    "small_version",
    "large_version",
]

# Convert to polars (via pandas) and drop the unwanted columns in one chain.
train_df = pl.from_pandas(train_dataset.to_pandas()).drop(cols_to_be_dropped)
test_df = pl.from_pandas(test_dataset.to_pandas()).drop(cols_to_be_dropped)

# Quick sanity check of the resulting training frame.
print(train_df.schema)
print(train_df.head(5))

Schema([('query', String), ('product_locale', String), ('esci_label', String), ('product_title', String), ('product_description', String), ('product_bullet_point', String), ('product_brand', String), ('product_color', String), ('product_text', String)])
shape: (5, 9)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ query     ┆ product_l ┆ esci_labe ┆ product_t ┆ … ┆ product_b ┆ product_b ┆ product_c ┆ product_ │
│ ---       ┆ ocale     ┆ l         ┆ itle      ┆   ┆ ullet_poi ┆ rand      ┆ olor      ┆ text     │
│ str       ┆ ---       ┆ ---       ┆ ---       ┆   ┆ nt        ┆ ---       ┆ ---       ┆ ---      │
│           ┆ str       ┆ str       ┆ str       ┆   ┆ ---       ┆ str       ┆ str       ┆ str      │
│           ┆           ┆           ┆           ┆   ┆ str       ┆           ┆           ┆          │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ revent 80 ┆ us        ┆

In [23]:
# Preview the first five rows of the training frame.
print(train_df.head(5))

shape: (5, 9)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ query     ┆ product_l ┆ esci_labe ┆ product_t ┆ … ┆ product_b ┆ product_b ┆ product_c ┆ product_ │
│ ---       ┆ ocale     ┆ l         ┆ itle      ┆   ┆ ullet_poi ┆ rand      ┆ olor      ┆ text     │
│ str       ┆ ---       ┆ ---       ┆ ---       ┆   ┆ nt        ┆ ---       ┆ ---       ┆ ---      │
│           ┆ str       ┆ str       ┆ str       ┆   ┆ ---       ┆ str       ┆ str       ┆ str      │
│           ┆           ┆           ┆           ┆   ┆ str       ┆           ┆           ┆          │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ revent 80 ┆ us        ┆ Irrelevan ┆ Panasonic ┆ … ┆ WhisperCe ┆ Panasonic ┆ White     ┆ Panasoni │
│ cfm       ┆           ┆ t         ┆ FV-20VQ3  ┆   ┆ iling     ┆           ┆           ┆ c        │
│           ┆           ┆           ┆ WhisperCe ┆   ┆ fans      ┆           ┆

In [24]:
# Keep only the rows whose locale is "us".
train_us = train_df.filter(pl.col("product_locale") == "us")

# Row count per query, restricted to queries with at least 10 rows.
# pl.len() is the replacement for the deprecated pl.count().
query_counts = (
    train_us
    .group_by("query")
    .agg(pl.len().alias("product_count"))
    .filter(pl.col("product_count") >= 10)
)

# The inner join keeps only rows whose query passed the threshold and attaches
# product_count to each of them; the column is left in so it shows in the preview.
filtered_df = train_us.join(query_counts, on="query", how="inner")

print(filtered_df.shape)
print(filtered_df.head(5))

(274200, 10)
shape: (5, 10)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ query     ┆ product_l ┆ esci_labe ┆ product_t ┆ … ┆ product_b ┆ product_c ┆ product_t ┆ product_ │
│ ---       ┆ ocale     ┆ l         ┆ itle      ┆   ┆ rand      ┆ olor      ┆ ext       ┆ count    │
│ str       ┆ ---       ┆ ---       ┆ ---       ┆   ┆ ---       ┆ ---       ┆ ---       ┆ ---      │
│           ┆ str       ┆ str       ┆ str       ┆   ┆ str       ┆ str       ┆ str       ┆ u32      │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ revent 80 ┆ us        ┆ Irrelevan ┆ Panasonic ┆ … ┆ Panasonic ┆ White     ┆ Panasonic ┆ 16       │
│ cfm       ┆           ┆ t         ┆ FV-20VQ3  ┆   ┆           ┆           ┆ FV-20VQ3  ┆          │
│           ┆           ┆           ┆ WhisperCe ┆   ┆           ┆           ┆ WhisperCe ┆          │
│           ┆           ┆           ┆ il…       ┆   ┆          

  .agg(pl.count().alias("product_count"))


In [32]:
# Distribution of ESCI labels in the filtered US frame, most frequent first.
# pl.len() is the replacement for the deprecated pl.count().
label_counts = (
    filtered_df
    .group_by("esci_label")
    .agg(pl.len().alias("count"))
    .sort("count", descending=True)
)

print(label_counts)

shape: (4, 2)
┌────────────┬────────┐
│ esci_label ┆ count  │
│ ---        ┆ ---    │
│ str        ┆ u32    │
╞════════════╪════════╡
│ Exact      ┆ 180619 │
│ Substitute ┆ 57393  │
│ Irrelevant ┆ 30223  │
│ Complement ┆ 5965   │
└────────────┴────────┘


  .agg(pl.count().alias("count"))


In [25]:
def filter_us_and_queries(df: pl.DataFrame, limit: int = None) -> pl.DataFrame:
    """Keep US-locale rows whose query has at least 10 rows, optionally truncated.

    Args:
        df: Frame containing at least the ``product_locale`` and ``query``
            columns.
        limit: Maximum number of rows to return. ``None`` (the default)
            returns every matching row.

    Returns:
        The matching rows, in their original order and with the same columns
        as ``df``. ``limit`` truncates by row, so the last query in a
        truncated result may be left with fewer than 10 rows.
    """
    df_us = df.filter(pl.col("product_locale") == "us")

    # Per-query row count as a window expression rather than group_by + join.
    # A filter keeps the input row order, whereas a join does not guarantee it
    # unless explicitly requested -- so head(limit) below now selects the same
    # rows on every run. pl.len() replaces the deprecated pl.count().
    rows_per_query = pl.len().over("query")

    # Null queries are excluded explicitly: the previous inner join on "query"
    # never matched null keys, so such rows were dropped there as well.
    filtered = df_us.filter(
        pl.col("query").is_not_null() & (rows_per_query >= 10)
    )

    return filtered.head(limit) if limit is not None else filtered

# Build the working splits: filtered US rows, capped at a fixed number of rows.
filtered_train_df = filter_us_and_queries(train_df, limit=75_000)
filtered_test_df = filter_us_and_queries(test_df, limit=25_000)

# Report the size of each split, then preview the training rows.
print(f"Train shape: {filtered_train_df.shape}")
print(f"Test shape:  {filtered_test_df.shape}")
print(filtered_train_df.head(3))

  .agg(pl.count().alias("product_count"))


Train shape: (75000, 9)
Test shape:  (25000, 9)
shape: (3, 9)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ query     ┆ product_l ┆ esci_labe ┆ product_t ┆ … ┆ product_b ┆ product_b ┆ product_c ┆ product_ │
│ ---       ┆ ocale     ┆ l         ┆ itle      ┆   ┆ ullet_poi ┆ rand      ┆ olor      ┆ text     │
│ str       ┆ ---       ┆ ---       ┆ ---       ┆   ┆ nt        ┆ ---       ┆ ---       ┆ ---      │
│           ┆ str       ┆ str       ┆ str       ┆   ┆ ---       ┆ str       ┆ str       ┆ str      │
│           ┆           ┆           ┆           ┆   ┆ str       ┆           ┆           ┆          │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ revent 80 ┆ us        ┆ Irrelevan ┆ Panasonic ┆ … ┆ WhisperCe ┆ Panasonic ┆ White     ┆ Panasoni │
│ cfm       ┆           ┆ t         ┆ FV-20VQ3  ┆   ┆ iling     ┆           ┆           ┆ c        │
│           ┆           ┆    

In [33]:
# Inspect a few US rows labelled "Complement".
# (The variable name keeps its original spelling so that any later reference still resolves.)
complimentary_numbers = train_us.filter(pl.col("esci_label").eq("Complement"))
print(complimentary_numbers.head(5))

shape: (5, 9)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ query     ┆ product_l ┆ esci_labe ┆ product_t ┆ … ┆ product_b ┆ product_b ┆ product_c ┆ product_ │
│ ---       ┆ ocale     ┆ l         ┆ itle      ┆   ┆ ullet_poi ┆ rand      ┆ olor      ┆ text     │
│ str       ┆ ---       ┆ ---       ┆ ---       ┆   ┆ nt        ┆ ---       ┆ ---       ┆ ---      │
│           ┆ str       ┆ str       ┆ str       ┆   ┆ ---       ┆ str       ┆ str       ┆ str      │
│           ┆           ┆           ┆           ┆   ┆ str       ┆           ┆           ┆          │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ heat      ┆ us        ┆ Complemen ┆ Panasonic ┆ … ┆ Retrofit  ┆ Panasonic ┆ White     ┆ Panasoni │
│ recovery  ┆           ┆ t         ┆ FV-0811VF ┆   ┆ Solution: ┆           ┆           ┆ c FV-081 │
│ ventilato ┆           ┆           ┆ 5 Whisper ┆   ┆ Ideal for ┆           ┆

In [34]:
def create_input_polars(df: pl.DataFrame, product_cols: list[str]) -> pl.DataFrame:
    """Add a ``product_input`` string column built from the given product fields.

    For every row the result has the form
    ``"[CLS] <col> : <value> <col> : <value> ... [SEP]"``, with the fields in
    the order of ``product_cols``. Inside each value, double quotes become
    single quotes, newlines / carriage returns / tabs become spaces, and
    surrounding whitespace is stripped.

    Args:
        df: Input frame; it must contain every column in ``product_cols``.
        product_cols: Names of the columns to concatenate.

    Returns:
        A new frame in which nulls in ``product_cols`` are replaced by ``""``
        and a ``product_input`` column has been added.
    """

    def clean_text(s: str) -> str:
        # Normalise quotes and collapse line/tab breaks so each value stays on one line.
        return (
            str(s)
            .replace('"', "'")
            .replace("\n", " ")
            .replace("\r", " ")
            .replace("\t", " ")
            .strip()
        )

    def build_input_string(row: dict) -> str:
        # NOTE(review): "[CLS]"/"[SEP]" are written as literal text here; if the
        # downstream tokenizer adds its own special tokens they would be
        # duplicated -- confirm against the consumer of this column.
        parts = [f"{col} : {clean_text(row[col])}" for col in product_cols]
        return "[CLS] " + " ".join(parts) + " [SEP]"

    # Replace nulls in all product columns in a single pass.
    df = df.with_columns([pl.col(col).fill_null("") for col in product_cols])

    # An explicit return_dtype avoids the warning polars emits when it has to
    # infer the output type of a Python UDF, and pins the column to String.
    return df.with_columns(
        pl.struct(product_cols)
        .map_elements(build_input_string, return_dtype=pl.String)
        .alias("product_input")
    )

# Product fields concatenated, in this order, into the product_input string.
product_cols = [
    "product_title",
    "product_brand",
    "product_bullet_point",
    "product_description",
    "product_color",
]

filtered_train_df = create_input_polars(filtered_train_df, product_cols)
filtered_test_df = create_input_polars(filtered_test_df, product_cols)

# Show the generated input string of the first training row.
print(filtered_train_df.head(1).select("product_input"))

  df = df.with_columns(
  df = df.with_columns(


shape: (1, 1)
┌─────────────────────────────────┐
│ product_input                   │
│ ---                             │
│ str                             │
╞═════════════════════════════════╡
│ [CLS] product_title : Panasoni… │
└─────────────────────────────────┘


In [27]:
# Preview the training frame, which now carries the product_input column.
print(filtered_train_df.head())

shape: (5, 10)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ query     ┆ product_l ┆ esci_labe ┆ product_t ┆ … ┆ product_b ┆ product_c ┆ product_t ┆ product_ │
│ ---       ┆ ocale     ┆ l         ┆ itle      ┆   ┆ rand      ┆ olor      ┆ ext       ┆ input    │
│ str       ┆ ---       ┆ ---       ┆ ---       ┆   ┆ ---       ┆ ---       ┆ ---       ┆ ---      │
│           ┆ str       ┆ str       ┆ str       ┆   ┆ str       ┆ str       ┆ str       ┆ str      │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ revent 80 ┆ us        ┆ Irrelevan ┆ Panasonic ┆ … ┆ Panasonic ┆ White     ┆ Panasonic ┆ [CLS]    │
│ cfm       ┆           ┆ t         ┆ FV-20VQ3  ┆   ┆           ┆           ┆ FV-20VQ3  ┆ product_ │
│           ┆           ┆           ┆ WhisperCe ┆   ┆           ┆           ┆ WhisperCe ┆ title :  │
│           ┆           ┆           ┆ il…       ┆   ┆           ┆           

In [35]:
def apply_label_mapping(df: pl.DataFrame) -> pl.DataFrame:
    """Replace the string ``esci_label`` column with a numeric score.

    Scores: Irrelevant -> 0.0, Complement -> 0.01, Substitute -> 0.1,
    Exact -> 1.0. Any other value becomes null. The result overwrites the
    ``esci_label`` column of the returned frame.
    """
    label_scores = {
        "Irrelevant": 0.0,
        "Complement": 0.01,
        "Substitute": 0.1,
        "Exact": 1.0,
    }
    label = pl.col("esci_label")

    # Build the when/then chain from the table, in the table's order.
    pairs = iter(label_scores.items())
    first_name, first_score = next(pairs)
    score_expr = pl.when(label == first_name).then(first_score)
    for name, score in pairs:
        score_expr = score_expr.when(label == name).then(score)

    return df.with_columns(score_expr.otherwise(None).alias("esci_label"))

# Replace the string ESCI labels with numeric scores in both splits.
# NOTE(review): apply_label_mapping overwrites "esci_label", so this cell is meant to run once per
# fresh frame -- its string comparisons no longer apply to an already-mapped (numeric) column.
filtered_train_df = apply_label_mapping(filtered_train_df)
filtered_test_df = apply_label_mapping(filtered_test_df)

In [None]:
# Preview the training frame after label mapping.
# NOTE(review): the saved output of this cell appears to list a `label_score` column (11 columns in
# total) that the current apply_label_mapping does not create -- it looks stale; re-run from a
# fresh kernel to refresh it.
print(filtered_train_df.head())

shape: (5, 11)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ query     ┆ product_l ┆ esci_labe ┆ product_t ┆ … ┆ product_c ┆ product_t ┆ product_i ┆ label_sc │
│ ---       ┆ ocale     ┆ l         ┆ itle      ┆   ┆ olor      ┆ ext       ┆ nput      ┆ ore      │
│ str       ┆ ---       ┆ ---       ┆ ---       ┆   ┆ ---       ┆ ---       ┆ ---       ┆ ---      │
│           ┆ str       ┆ f64       ┆ str       ┆   ┆ str       ┆ str       ┆ str       ┆ f64      │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ revent 80 ┆ us        ┆ 0.0       ┆ Panasonic ┆ … ┆ White     ┆ Panasonic ┆ [CLS] pro ┆ 0.0      │
│ cfm       ┆           ┆           ┆ FV-20VQ3  ┆   ┆           ┆ FV-20VQ3  ┆ duct_titl ┆          │
│           ┆           ┆           ┆ WhisperCe ┆   ┆           ┆ WhisperCe ┆ e :       ┆          │
│           ┆           ┆           ┆ il…       ┆   ┆           ┆ il…       

In [36]:
# Columns that make up the exported dataset.
final_columns = ["query", "product_input", "esci_label"]

train_final = filtered_train_df.select(final_columns)
test_final = filtered_test_df.select(final_columns)

# Write each split to its own CSV file in the working directory.
train_final.write_csv("filtered_train.csv")
test_final.write_csv("filtered_test.csv")

In [37]:
from google.colab import files

# Trigger a browser download for each exported CSV (Colab only).
for csv_name in ("filtered_train.csv", "filtered_test.csv"):
    files.download(csv_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>