In [94]:
import json
import torch
import tqdm
import torchmetrics

import numpy as np
import sklearn
import sklearn.model_selection
import warnings

In [95]:
# Load the full dataset and collect the raw note vocabularies: the notes of each
# blend, and the notes of its two component molecules (mol1 / mol2).
# JSON is UTF-8 by spec; pin the encoding so decoding does not depend on the
# platform's locale default.
with open("../dataset/full.json", encoding="utf-8") as f:
    full_data = json.load(f)

all_blend_notes = set()
all_single_notes = set()
# presumably each record is a dict whose *_notes values are lists of strings — verify against the dataset
for d in full_data:
    all_blend_notes.update(d["blend_notes"])
    all_single_notes.update(d["mol1_notes"])
    all_single_notes.update(d["mol2_notes"])
print(f"Found {len(all_blend_notes)} blend notes and {len(all_single_notes)} single notes")

Found 109 blend notes and 496 single notes


In [96]:
# Spelling variants and junk labels found in the raw note vocabulary.
# A value of None means "drop this label entirely"; any label not listed here
# is kept unchanged.
NOTE_CANONICALIZATION = {
    "": None,
    "No odor group found for these": None,
    "anisic": "anise",
    "corn chip": "corn",
    "medicinal,": "medicinal",
}


def canonize(notes, canon=None):
    """Map raw note labels onto their canonical spelling.

    Args:
        notes: iterable of note strings.
        canon: optional mapping ``label -> canonical label or None``. Labels
            mapped to a falsy value are dropped. Defaults to
            ``NOTE_CANONICALIZATION``.

    Returns:
        A new set of canonical note labels. The input is not modified.
    """
    if canon is None:
        canon = NOTE_CANONICALIZATION

    cleaned = set()
    for note in notes:
        if note not in canon:
            # Already a valid note: keep as-is.
            cleaned.add(note)
        elif canon[note]:
            # Known variant: replace with its canonical form.
            cleaned.add(canon[note])
        # Otherwise the label maps to None/"" and is dropped.
    return cleaned

In [97]:
# Replace the blend vocabulary with its canonical form (junk labels dropped,
# spelling variants merged) and list it alphabetically for inspection.
canonical_blend_notes = canonize(all_blend_notes)
all_blend_notes = canonical_blend_notes
print(sorted(canonical_blend_notes))

['acetic', 'acidic', 'alcoholic', 'aldehydic', 'alliaceous', 'amber', 'ammoniacal', 'animal', 'anise', 'aromatic', 'balsamic', 'berry', 'bitter', 'bready', 'brown', 'burnt', 'buttery', 'cabbage', 'camphoreous', 'caramellic', 'celery', 'cheesy', 'chemical', 'cherry', 'chocolate', 'citrus', 'clean', 'cocoa', 'coconut', 'coffee', 'cooling', 'corn', 'coumarinic', 'creamy', 'dairy', 'dusty', 'earthy', 'eggy', 'estery', 'ethereal', 'fatty', 'fermented', 'fishy', 'floral', 'fresh', 'fruity', 'fungal', 'fusel', 'garlic', 'green', 'hay', 'herbal', 'honey', 'jammy', 'juicy', 'lactonic', 'leathery', 'licorice', 'malty', 'marine', 'meaty', 'medicinal', 'melon', 'mentholic', 'minty', 'moldy', 'mossy', 'mushroom', 'musk', 'mustard', 'musty', 'nutty', 'oily', 'onion', 'orris', 'peach', 'phenolic', 'pine', 'potato', 'powdery', 'pungent', 'roasted', 'rooty', 'rummy', 'salty', 'smoky', 'soapy', 'solvent', 'sour', 'spicy', 'sulfurous', 'sweet', 'tarragon', 'thujonic', 'toasted', 'tobacco', 'tomato', 'ton

In [108]:
from InstructorEmbedding import INSTRUCTOR

# Instruction-conditioned embedder: inputs are [instruction, text] pairs, and the
# instruction below is prepended to every note when encoding.
INSTRUCTOR_CHECKPOINT = 'hkunlp/instructor-large'
model = INSTRUCTOR(INSTRUCTOR_CHECKPOINT)
prompt = 'Represent the perfume note for canonicalization:'

load INSTRUCTOR_Transformer
max_seq_length  512


In [109]:
# Encode every canonical blend note as an [instruction, note] pair. Notes are
# sorted so that row i of the result corresponds to the i-th note in sorted order.
sorted_blend_notes = sorted(all_blend_notes)
blend_inputs = [[prompt, note] for note in sorted_blend_notes]
blend_embeddings = model.encode(blend_inputs)
blend_embeddings.shape

(104, 768)

In [110]:
# Encode every single-molecule note the same way; row i of the result matches
# the i-th note of sorted(all_single_notes).
sorted_single_notes = sorted(all_single_notes)
single_inputs = [[prompt, note] for note in sorted_single_notes]
single_embeddings = model.encode(single_inputs)
single_embeddings.shape

(496, 768)

In [113]:
# Map each single-molecule note to its nearest blend note (Euclidean distance
# between INSTRUCTOR embeddings). Both embedding matrices were built from the
# sorted vocabularies, so sort once here and reuse the lists for lookup instead of
# re-sorting inside the comprehension.
sorted_single_notes = sorted(all_single_notes)
sorted_blend_notes = sorted(all_blend_notes)

dists = torch.cdist(torch.from_numpy(single_embeddings), torch.from_numpy(blend_embeddings))
# Row i = single note i, column j = blend note j.
assert dists.shape == (len(sorted_single_notes), len(sorted_blend_notes))

closest = dists.argmin(axis=-1)
canonization_dictionary = {
    single_note: sorted_blend_notes[int(blend_idx)]
    for single_note, blend_idx in zip(sorted_single_notes, closest)
}
canonization_dictionary

{'absinthe': 'alcoholic',
 'acacia': 'coconut',
 'acetic': 'acetic',
 'acetone': 'acetic',
 'acidic': 'acidic',
 'acorn': 'corn',
 'acrylate': 'phenolic',
 'agarwood': 'woody',
 'alcoholic': 'alcoholic',
 'aldehydic': 'aldehydic',
 'algae': 'fungal',
 'alliaceous': 'alliaceous',
 'allspice': 'alliaceous',
 'almond': 'nutty',
 'almond bitter almond': 'bitter',
 'almond roasted almond': 'roasted',
 'almond toasted almond': 'toasted',
 'amber': 'amber',
 'ambergris': 'amber',
 'ambrette': 'amber',
 'ammoniacal': 'ammoniacal',
 'angelica': 'anise',
 'animal': 'animal',
 'anise': 'anise',
 'anisic': 'anise',
 'apple': 'fruity',
 'apple cooked apple': 'roasted',
 'apple dried apple': 'fruity',
 'apple green apple': 'green',
 'apple skin': 'melon',
 'apricot': 'fruity',
 'aromatic': 'aromatic',
 'arrack': 'aromatic',
 'artichoke': 'celery',
 'asparagus': 'vegetable',
 'astringent': 'pungent',
 'autumn': 'floral',
 'bacon': 'meaty',
 'baked': 'bready',
 'balsamic': 'balsamic',
 'banana': 'melo

In [104]:
from transformers import AutoTokenizer, BertModel
import torch

# Baseline: plain BERT embeddings of each note placed inside a fixed context string.
# NOTE: this rebinds `model` and `prompt`, shadowing the INSTRUCTOR model/prompt
# defined earlier — re-run the INSTRUCTOR load cell before re-running its cells.
BERT_CHECKPOINT = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT)
model = BertModel.from_pretrained(BERT_CHECKPOINT)

prompt = "many woody fresh butter popcorn woody orange biscuit rubbery {note} intense garlic woody comments warm evaluation comments umami grape rain one"

# One filled-in template per note, in sorted order (matches the lookup order used later).
blend_inputs = [prompt.format(note=note) for note in sorted(all_blend_notes)]
single_inputs = [prompt.format(note=note) for note in sorted(all_single_notes)]

In [105]:
# Embed each blend-note prompt as the mean of BERT's last hidden layer over all
# tokens (including whatever special tokens the tokenizer adds).
pooled_blend = []
# Inference only: no_grad avoids building and retaining an autograd graph for
# every forward pass.
with torch.no_grad():
    for inp in tqdm.tqdm(blend_inputs):
        inputs = tokenizer(inp, return_tensors="pt")
        # Batch of one: take element 0 explicitly rather than squeeze(), then
        # average over the token axis.
        embed = model(**inputs).last_hidden_state[0].mean(axis=0)
        pooled_blend.append(embed)
blend_embeddings = torch.stack(pooled_blend)
blend_embeddings.shape

100%|███████████████████████████████████████| 104/104 [00:03<00:00, 29.29it/s]


torch.Size([104, 768])

In [106]:
# Embed each single-note prompt the same way as the blend notes: mean of BERT's
# last hidden layer over all tokens.
pooled_single = []
# Inference only: no_grad avoids building and retaining an autograd graph for
# every forward pass.
with torch.no_grad():
    for inp in tqdm.tqdm(single_inputs):
        inputs = tokenizer(inp, return_tensors="pt")
        # Batch of one: take element 0 explicitly rather than squeeze(), then
        # average over the token axis.
        embed = model(**inputs).last_hidden_state[0].mean(axis=0)
        pooled_single.append(embed)
single_embeddings = torch.stack(pooled_single)
single_embeddings.shape

100%|███████████████████████████████████████| 496/496 [00:20<00:00, 23.79it/s]


torch.Size([496, 768])

In [107]:
# Print each single-molecule note with its nearest blend note under the BERT
# embeddings. Sort the vocabularies once instead of re-sorting on every iteration.
sorted_single_notes = sorted(all_single_notes)
sorted_blend_notes = sorted(all_blend_notes)

dists = torch.cdist(single_embeddings, blend_embeddings)
# Row i = single note i, column j = blend note j.
assert dists.shape == (len(sorted_single_notes), len(sorted_blend_notes))

closest = dists.argmin(axis=-1)
for single_note, blend_idx in zip(sorted_single_notes, closest):
    print(single_note, "->", sorted_blend_notes[int(blend_idx)])

absinthe -> licorice
acacia -> marine
acetic -> acetic
acetone -> acetic
acidic -> acidic
acorn -> orris
acrylate -> acetic
agarwood -> pine
alcoholic -> alcoholic
aldehydic -> aldehydic
algae -> berry
alliaceous -> alliaceous
allspice -> licorice
almond -> coconut
almond bitter almond -> bitter
almond roasted almond -> roasted
almond toasted almond -> toasted
amber -> amber
ambergris -> tarragon
ambrette -> estery
ammoniacal -> ammoniacal
angelica -> berry
animal -> animal
anise -> anise
anisic -> estery
apple -> berry
apple cooked apple -> potato
apple dried apple -> corn
apple green apple -> green
apple skin -> green
apricot -> melon
aromatic -> aromatic
arrack -> fusel
artichoke -> tonka
asparagus -> licorice
astringent -> celery
autumn -> bitter
bacon -> tomato
baked -> roasted
balsamic -> balsamic
banana -> tomato
banana peel -> coconut
banana ripe banana -> coconut
banana unripe banana -> melon
barley roasted barley -> roasted
basil -> berry
bay -> fresh
bean green bean -> corn
