In [1]:
!pip install -q transformers accelerate bitsandbytes

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m28.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m33.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import json

# Check for GPU availability
print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


GPU available: True
GPU name: Tesla T4
GPU memory: 15.83 GB


In [3]:
def setup_model():
    """Initialize and return the model and tokenizer"""
    model_name = "microsoft/phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        load_in_4bit=True
    )
    return model, tokenizer


In [4]:
# Material property knowledge base with categories, units, aliases, and realistic ranges
property_knowledge = {
    # ======== MECHANICAL PROPERTIES ========
    # Strength-related properties
    "strength": "tensile_strength",
    "tensile strength": "tensile_strength",
    "yield strength": "yield_strength",
    "compressive strength": "compressive_strength",
    "flexural strength": "flexural_strength",
    "shear strength": "shear_strength",
    "ultimate strength": "tensile_strength",
    "breaking strength": "tensile_strength",
    "tensile_strength": {"min": 30, "max": 3000, "unit": "MPa", "category": "mechanical"},
    "yield_strength": {"min": 20, "max": 2500, "unit": "MPa", "category": "mechanical"},
    "compressive_strength": {"min": 50, "max": 3500, "unit": "MPa", "category": "mechanical"},
    "flexural_strength": {"min": 20, "max": 1500, "unit": "MPa", "category": "mechanical"},
    "shear_strength": {"min": 10, "max": 1200, "unit": "MPa", "category": "mechanical"},

    # Hardness-related properties
    "hardness": "vickers_hardness",
    "vickers hardness": "vickers_hardness",
    "rockwell hardness": "rockwell_hardness",
    "knoop hardness": "knoop_hardness",
    "mohs hardness": "mohs_hardness",
    "vickers_hardness": {"min": 50, "max": 3000, "unit": "HV", "category": "mechanical"},
    "rockwell_hardness": {"min": 20, "max": 70, "unit": "HRC", "category": "mechanical"},
    "knoop_hardness": {"min": 100, "max": 8000, "unit": "HK", "category": "mechanical"},
    "mohs_hardness": {"min": 1, "max": 10, "unit": "Mohs", "category": "mechanical"},

    # Elasticity and stiffness
    "elastic modulus": "youngs_modulus",
    "young's modulus": "youngs_modulus",
    "stiffness": "youngs_modulus",
    "elasticity": "youngs_modulus",
    "modulus of elasticity": "youngs_modulus",
    "rigidity": "shear_modulus",
    "shear modulus": "shear_modulus",
    "bulk modulus": "bulk_modulus",
    "compressibility": "bulk_modulus",
    "poisson ratio": "poisson_ratio",
    "poisson's ratio": "poisson_ratio",
    "youngs_modulus": {"min": 1, "max": 1200, "unit": "GPa", "category": "mechanical"},
    "shear_modulus": {"min": 0.5, "max": 500, "unit": "GPa", "category": "mechanical"},
    "bulk_modulus": {"min": 1, "max": 400, "unit": "GPa", "category": "mechanical"},
    "poisson_ratio": {"min": 0, "max": 0.5, "unit": "dimensionless", "category": "mechanical"},

    # Ductility and plasticity
    "ductility": "elongation_at_break",
    "elongation": "elongation_at_break",
    "plasticity": "elongation_at_break",
    "malleability": "elongation_at_break",
    "brittleness": "elongation_at_break",  # inverse relationship
    "elongation_at_break": {"min": 0.1, "max": 70, "unit": "%", "category": "mechanical"},

    # Toughness and fracture
    "toughness": "fracture_toughness",
    "impact toughness": "impact_toughness",
    "fracture toughness": "fracture_toughness",
    "impact resistance": "impact_toughness",
    "fracture resistance": "fracture_toughness",
    "fracture_toughness": {"min": 0.5, "max": 200, "unit": "MPa·m^(1/2)", "category": "mechanical"},
    "impact_toughness": {"min": 1, "max": 300, "unit": "J", "category": "mechanical"},

    # Fatigue and creep
    "fatigue strength": "fatigue_strength",
    "fatigue limit": "fatigue_strength",
    "endurance limit": "fatigue_strength",
    "creep resistance": "creep_resistance",
    "creep strength": "creep_strength",
    "fatigue_strength": {"min": 10, "max": 700, "unit": "MPa", "category": "mechanical"},
    "creep_resistance": {"min": 0.1, "max": 10, "unit": "rating", "category": "mechanical"},
    "creep_strength": {"min": 10, "max": 400, "unit": "MPa", "category": "mechanical"},

    # ======== THERMAL PROPERTIES ========
    "melting point": "melting_point",
    "melting temperature": "melting_point",
    "liquidus temperature": "melting_point",
    "solidus temperature": "melting_point",
    "freezing point": "melting_point",
    "melting_point": {"min": -270, "max": 4000, "unit": "°C", "category": "thermal"},

    "glass transition": "glass_transition_temperature",
    "glass transition temperature": "glass_transition_temperature",
    "tg": "glass_transition_temperature",
    "glass_transition_temperature": {"min": -150, "max": 600, "unit": "°C", "category": "thermal"},

    "thermal conductivity": "thermal_conductivity",
    "heat conductivity": "thermal_conductivity",
    "thermal conduction": "thermal_conductivity",
    "thermal insulation": "thermal_conductivity",  # inverse relationship
    "thermal_conductivity": {"min": 0.01, "max": 3000, "unit": "W/(m·K)", "category": "thermal"},

    "thermal expansion": "thermal_expansion_coefficient",
    "cte": "thermal_expansion_coefficient",
    "coefficient of thermal expansion": "thermal_expansion_coefficient",
    "thermal_expansion_coefficient": {"min": -10, "max": 100, "unit": "10^-6/K", "category": "thermal"},

    "specific heat": "specific_heat_capacity",
    "heat capacity": "specific_heat_capacity",
    "specific heat capacity": "specific_heat_capacity",
    "specific_heat_capacity": {"min": 100, "max": 5000, "unit": "J/(kg·K)", "category": "thermal"},

    "thermal diffusivity": "thermal_diffusivity",
    "thermal_diffusivity": {"min": 0.1, "max": 200, "unit": "mm²/s", "category": "thermal"},

    "thermal shock resistance": "thermal_shock_resistance",
    "thermal_shock_resistance": {"min": 1, "max": 10, "unit": "rating", "category": "thermal"},

    # ======== ELECTRONIC PROPERTIES ========
    "band gap": "band_gap",
    "bandgap": "band_gap",
    "energy gap": "band_gap",
    "semiconductor gap": "band_gap",
    "band_gap": {"min": 0, "max": 10, "unit": "eV", "category": "electronic"},

    "conductivity": "electrical_conductivity",
    "electrical conductivity": "electrical_conductivity",
    "electrical conduction": "electrical_conductivity",
    "electrical resistivity": "electrical_resistivity",
    "resistivity": "electrical_resistivity",
    "electrical_conductivity": {"min": 1e-16, "max": 1e8, "unit": "S/m", "category": "electronic"},
    "electrical_resistivity": {"min": 1e-8, "max": 1e16, "unit": "Ω·m", "category": "electronic"},

    "carrier mobility": "carrier_mobility",
    "electron mobility": "electron_mobility",
    "hole mobility": "hole_mobility",
    "carrier_mobility": {"min": 1e-5, "max": 1e5, "unit": "cm²/(V·s)", "category": "electronic"},
    "electron_mobility": {"min": 1e-5, "max": 1e5, "unit": "cm²/(V·s)", "category": "electronic"},
    "hole_mobility": {"min": 1e-5, "max": 1e5, "unit": "cm²/(V·s)", "category": "electronic"},

    "carrier concentration": "carrier_concentration",
    "electron concentration": "carrier_concentration",
    "hole concentration": "carrier_concentration",
    "carrier_concentration": {"min": 1e12, "max": 1e23, "unit": "cm^-3", "category": "electronic"},

    "work function": "work_function",
    "work_function": {"min": 2, "max": 7, "unit": "eV", "category": "electronic"},

    "electron affinity": "electron_affinity",
    "electron_affinity": {"min": 0, "max": 5, "unit": "eV", "category": "electronic"},

    "dielectric constant": "dielectric_constant",
    "permittivity": "dielectric_constant",
    "relative permittivity": "dielectric_constant",
    "dielectric_constant": {"min": 1, "max": 1000, "unit": "ε/ε₀", "category": "electronic"},

    # ======== OPTICAL PROPERTIES ========
    "refractive index": "refractive_index",
    "index of refraction": "refractive_index",
    "optical index": "refractive_index",
    "refractive_index": {"min": 1, "max": 5, "unit": "n", "category": "optical"},

    "absorption coefficient": "absorption_coefficient",
    "optical absorption": "absorption_coefficient",
    "absorption_coefficient": {"min": 0, "max": 1e7, "unit": "cm^-1", "category": "optical"},

    "transparency": "transparency",
    "optical transparency": "transparency",
    "light transmission": "transparency",
    "transparency": {"min": 0, "max": 100, "unit": "%", "category": "optical"},

    "reflectivity": "reflectivity",
    "optical reflectance": "reflectivity",
    "reflectance": "reflectivity",
    "reflectivity": {"min": 0, "max": 100, "unit": "%", "category": "optical"},

    # ======== CHEMICAL PROPERTIES ========
    "formation energy": "formation_energy",
    "enthalpy of formation": "formation_energy",
    "heat of formation": "formation_energy",
    "formation_energy": {"min": -20, "max": 5, "unit": "eV/atom", "category": "chemical"},

    "corrosion resistance": "corrosion_resistance",
    "corrosion rate": "corrosion_rate",
    "oxidation resistance": "oxidation_resistance",
    "chemical stability": "chemical_stability",
    "corrosion_resistance": {"min": 1, "max": 10, "unit": "rating", "category": "chemical"},
    "corrosion_rate": {"min": 0, "max": 1000, "unit": "mm/year", "category": "chemical"},
    "oxidation_resistance": {"min": 1, "max": 10, "unit": "rating", "category": "chemical"},
    "chemical_stability": {"min": 1, "max": 10, "unit": "rating", "category": "chemical"},

    # ======== PHYSICAL PROPERTIES ========
    "density": "density",
    "specific gravity": "density",
    "mass density": "density",
    "lightweight": "density",  # inverse relationship
    "heavy": "density",
    "density": {"min": 0.1, "max": 22.5, "unit": "g/cm³", "category": "physical"},

    "porosity": "porosity",
    "pore volume": "porosity",
    "void fraction": "porosity",
    "porosity": {"min": 0, "max": 99, "unit": "%", "category": "physical"},

    "surface area": "specific_surface_area",
    "specific surface area": "specific_surface_area",
    "surface to volume ratio": "specific_surface_area",
    "specific_surface_area": {"min": 0.1, "max": 3000, "unit": "m²/g", "category": "physical"},

    # ======== MAGNETIC PROPERTIES ========
    "magnetic susceptibility": "magnetic_susceptibility",
    "susceptibility": "magnetic_susceptibility",
    "magnetic_susceptibility": {"min": -1, "max": 1e5, "unit": "χ", "category": "magnetic"},

    "saturation magnetization": "saturation_magnetization",
    "magnetic saturation": "saturation_magnetization",
    "saturation_magnetization": {"min": 0, "max": 2.5, "unit": "T", "category": "magnetic"},

    "coercivity": "coercivity",
    "magnetic coercivity": "coercivity",
    "coercive force": "coercivity",
    "coercivity": {"min": 0, "max": 5e6, "unit": "A/m", "category": "magnetic"},

    "curie temperature": "curie_temperature",
    "curie point": "curie_temperature",
    "curie_temperature": {"min": 0, "max": 1400, "unit": "°C", "category": "magnetic"}
}

# Qualitative mappings with granular descriptors
qualitative_mappings = {
    # General intensity descriptors
    "extremely": {"min_factor": 0.95, "max_factor": 0.98},
    "very": {"min_factor": 0.80, "max_factor": 0.95},
    "high": {"min_factor": 0.70, "max_factor": 0.90},
    "good": {"min_factor": 0.60, "max_factor": 0.80},
    "moderate": {"min_factor": 0.40, "max_factor": 0.60},
    "medium": {"min_factor": 0.40, "max_factor": 0.60},
    "average": {"min_factor": 0.40, "max_factor": 0.60},
    "fair": {"min_factor": 0.30, "max_factor": 0.50},
    "low": {"min_factor": 0.10, "max_factor": 0.30},
    "poor": {"min_factor": 0.05, "max_factor": 0.20},
    "very low": {"min_factor": 0.01, "max_factor": 0.10},
    "minimal": {"min_factor": 0.00, "max_factor": 0.05},

    # Specific combination descriptors
    "ultrahigh": {"min_factor": 0.90, "max_factor": 0.99},
    "ultra-high": {"min_factor": 0.90, "max_factor": 0.99},
    "ultra high": {"min_factor": 0.90, "max_factor": 0.99},
    "very high": {"min_factor": 0.80, "max_factor": 0.95},
    "extremely high": {"min_factor": 0.95, "max_factor": 0.99},
    "exceptionally high": {"min_factor": 0.95, "max_factor": 0.99},
    "reasonably high": {"min_factor": 0.65, "max_factor": 0.80},
    "relatively high": {"min_factor": 0.60, "max_factor": 0.80},
    "moderately high": {"min_factor": 0.55, "max_factor": 0.75},
    "somewhat high": {"min_factor": 0.50, "max_factor": 0.70},
    "not too high": {"min_factor": 0.40, "max_factor": 0.60},
    "not very high": {"min_factor": 0.30, "max_factor": 0.50},
    "not high": {"min_factor": 0.10, "max_factor": 0.40},

    "very low": {"min_factor": 0.01, "max_factor": 0.10},
    "extremely low": {"min_factor": 0.00, "max_factor": 0.05},
    "ultra low": {"min_factor": 0.00, "max_factor": 0.03},
    "ultra-low": {"min_factor": 0.00, "max_factor": 0.03},
    "ultralow": {"min_factor": 0.00, "max_factor": 0.03},
    "exceptionally low": {"min_factor": 0.00, "max_factor": 0.05},
    "negligible": {"min_factor": 0.00, "max_factor": 0.02},

    # Domain-specific descriptors
    "excellent insulation": {"min_factor": 0.00, "max_factor": 0.10},
    "superior conductivity": {"min_factor": 0.90, "max_factor": 1.00},

    # Comparative descriptors
    "higher than average": {"min_factor": 0.60, "max_factor": 0.80},
    "lower than average": {"min_factor": 0.20, "max_factor": 0.40},
    "above average": {"min_factor": 0.60, "max_factor": 0.80},
    "below average": {"min_factor": 0.20, "max_factor": 0.40},

    # Special cases for inversely related properties
    "highly insulating": {"min_factor": 0.00, "max_factor": 0.10},
    "lightweight": {"min_factor": 0.05, "max_factor": 0.30},
    "ultralight": {"min_factor": 0.00, "max_factor": 0.10},
    "heavy": {"min_factor": 0.70, "max_factor": 0.95},
    "brittle": {"min_factor": 0.00, "max_factor": 0.10},
}

# Properties where qualitative descriptors have inverse meaning
inverse_properties = {
    "thermal_insulation": "thermal_conductivity",
    "electrical_insulation": "electrical_conductivity",
    "lightweight": "density",
    "brittleness": "elongation_at_break",
    "weight": "density",
    "thermal_resistance": "thermal_conductivity",
    "electrical_resistance": "electrical_conductivity"
}

# Patterns for extracting numeric values with units
numeric_patterns = {
    "MPa": {"pattern": r"(\d+(?:\.\d+)?)\s*(?:MPa|mpa|megapascals?)", "property": "tensile_strength"},
    "GPa": {"pattern": r"(\d+(?:\.\d+)?)\s*(?:GPa|gpa|gigapascals?)", "property": "youngs_modulus"},
    "g/cm3": {"pattern": r"(\d+(?:\.\d+)?)\s*(?:g\/cm3|g\/cm\^3|g\/cc|g\/ml|grams?\s*per\s*(?:cubic\s*)?(?:cm|centimeters?|milliliters?|cc))", "property": "density"},
    "eV": {"pattern": r"(\d+(?:\.\d+)?)\s*(?:eV|ev|electron\s*volts?)", "property": "band_gap"},
    "°C": {"pattern": r"(\d+(?:\.\d+)?)\s*(?:°C|C|degrees?\s*(?:celsius|centigrade)?)", "property": "melting_point"},
    "%": {"pattern": r"(\d+(?:\.\d+)?)\s*(?:%|percent)", "property": "elongation_at_break"}
}

In [5]:
# Defining various properties
def preprocess_query(query):
    """
    Preprocess the natural language query to standardize formatting and handle common variations.
    """
    # Convert to lowercase for easier matching
    processed_query = query.lower()

    # Replace common abbreviations and variations
    replacements = {
        "ymodulus": "young's modulus",
        "young modulus": "young's modulus",
        "e-modulus": "young's modulus",
        "emodulus": "young's modulus",
        "tensile mod": "young's modulus",
        "elastic mod": "young's modulus",
        "fracture tough": "fracture toughness",
        "k1c": "fracture toughness",
        "kic": "fracture toughness",
        "yield str": "yield strength",
        "tensile str": "tensile strength",
        "uts": "tensile strength",
        "ultimate tensile": "tensile strength",
        "elong": "elongation",
        "melting temp": "melting point",
        "cond": "conductivity",
        "therm cond": "thermal conductivity",
        "elec cond": "electrical conductivity",
        "bandgap": "band gap",
        "eg": "band gap",
        "tc": "thermal conductivity",
        "tm": "melting point",
        "ef": "formation energy",
        "bm": "bulk modulus",
        "ym": "young's modulus",
    }

    for old, new in replacements.items():
        processed_query = processed_query.replace(old, new)

    return processed_query


In [6]:
# ========== CORE PROCESSING FUNCTIONS ==========

def preprocess_query(query):
    """Standardize formatting and handle common variations in the query text"""
    processed_query = query.lower()

    replacements = {
        "ymodulus": "young's modulus",
        "young modulus": "young's modulus",
        "e-modulus": "young's modulus",
        "emodulus": "young's modulus",
        "tensile mod": "young's modulus",
        "elastic mod": "young's modulus",
        "fracture tough": "fracture toughness",
        "k1c": "fracture toughness",
        "kic": "fracture toughness",
        "yield str": "yield strength",
        "tensile str": "tensile strength",
        "uts": "tensile strength",
        "ultimate tensile": "tensile strength",
        "elong": "elongation",
        "melting temp": "melting point",
        "cond": "conductivity",
        "therm cond": "thermal conductivity",
        "elec cond": "electrical conductivity",
        "bandgap": "band gap",
        "eg": "band gap",
        "tc": "thermal conductivity",
        "tm": "melting point",
        "ef": "formation energy",
        "bm": "bulk modulus",
        "ym": "young's modulus",
    }

    for old, new in replacements.items():
        processed_query = processed_query.replace(old, new)

    return processed_query

def extract_explicit_numeric_values(query):
    """Extract explicit numeric values mentioned in the query"""
    explicit_constraints = {}
    processed_query = preprocess_query(query)

    for unit, pattern_info in numeric_patterns.items():
        matches = re.findall(pattern_info["pattern"], processed_query)
        if matches:
            property_name = pattern_info["property"]
            values = [float(match) for match in matches]

            # Look for ranges (between X and Y)
            range_pattern = f"between\\s+({pattern_info['pattern']})\\s+and\\s+({pattern_info['pattern']})"
            range_matches = re.findall(range_pattern, processed_query)

            if range_matches:
                for match in range_matches:
                    min_val = float(match[0])
                    max_val = float(match[1])
                    explicit_constraints[property_name] = {"min": min_val, "max": max_val}
            elif len(values) >= 2:
                # If multiple values found, use min and max
                explicit_constraints[property_name] = {"min": min(values), "max": max(values)}
            elif len(values) == 1:
                # If only one value found, look for modifiers
                value = values[0]

                # Look for "above X", "greater than X", etc.
                if re.search(f"(?:above|greater than|more than|over|exceeding)\\s+{value}\\s*{unit}", processed_query):
                    prop_info = property_knowledge.get(property_name, {})
                    max_possible = prop_info.get("max", value * 2)
                    explicit_constraints[property_name] = {"min": value, "max": max_possible}

                # Look for "below X", "less than X", etc.
                elif re.search(f"(?:below|less than|under|not exceeding)\\s+{value}\\s*{unit}", processed_query):
                    prop_info = property_knowledge.get(property_name, {})
                    min_possible = prop_info.get("min", 0)
                    explicit_constraints[property_name] = {"min": min_possible, "max": value}

                # If no modifiers, use as both min and max with small range
                else:
                    margin = value * 0.1  # 10% margin
                    explicit_constraints[property_name] = {"min": value - margin, "max": value + margin}

    return explicit_constraints

def extract_qualitative_descriptions(query):
    """Extract qualitative descriptions of properties from the query"""
    qualitative_constraints = {}
    processed_query = preprocess_query(query)

    # For each property in our knowledge base
    for prop_name, prop_info in property_knowledge.items():
        # Skip aliases (string values)
        if isinstance(prop_info, str):
            continue

        # Skip properties without category (likely not well-defined)
        if not isinstance(prop_info, dict) or "category" not in prop_info:
            continue

        # Check for direct mentions in query
        if prop_name in processed_query:
            # Look for qualitative descriptors around this property
            for descriptor, factors in qualitative_mappings.items():
                pattern = f"{descriptor}\\s+{prop_name}|{prop_name}\\s+{descriptor}"
                if re.search(pattern, processed_query):
                    qualitative_constraints[prop_name] = descriptor
                    break

    # Check for aliases too
    for alias, canonical in property_knowledge.items():
        if isinstance(canonical, str) and alias in processed_query:
            # Look for qualitative descriptors around this alias
            for descriptor, factors in qualitative_mappings.items():
                pattern = f"{descriptor}\\s+{alias}|{alias}\\s+{descriptor}"
                if re.search(pattern, processed_query):
                    qualitative_constraints[canonical] = descriptor
                    break

    return qualitative_constraints

def process_query_with_llm(query, model, tokenizer):
    """Process a query with the LLM to extract constraints"""
    processed_query = preprocess_query(query)

    prompt = f"""
    You are a materials science expert. Convert the following material science query into a structured constraints dictionary.

    Query: {processed_query}

    Output the result as a Python dictionary with the format:
    constraints = {{
        'property_name': {{'min': value, 'max': value}},
        'property_name': {{'min': value, 'max': value}},
        ...
    }}

    Guidelines:
    1. Only include properties that are explicitly mentioned or strongly implied in the query.
    2. Use realistic value ranges for common material properties:
       - tensile_strength: 30-3000 MPa
       - youngs_modulus: 1-1200 GPa
       - elongation_at_break: 0.1-70%
       - band_gap: 0-10 eV
       - melting_point: -270-4000 °C
       - density: 0.1-22.5 g/cm³
       - thermal_conductivity: 0.01-3000 W/(m·K)
       - electrical_conductivity: 10^-16 to 10^8 S/m
    3. Use standard property names (e.g., 'youngs_modulus' instead of 'elasticity').
    4. For qualitative terms like "high" or "low", translate to appropriate numerical ranges.

    The output should ONLY contain the Python dictionary, nothing else.
    """

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        inputs.input_ids,
        max_length=512,
        temperature=0.1,
        top_p=0.9,
        do_sample=True
    )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the dictionary part from the response
    try:
        # First attempt: standard dictionary format
        dict_match = re.search(r'constraints\s*=\s*({[^{}]*(?:{[^{}]*}[^{}]*)*})', response, re.DOTALL)
        if dict_match:
            dict_str = dict_match.group(1)
        else:
            # Second attempt: look for any dictionary-like structure
            dict_match = re.search(r'({[^{}]*(?:{[^{}]*}[^{}]*)*})', response, re.DOTALL)
            if dict_match:
                dict_str = dict_match.group(1)
            else:
                # Third attempt: try to find line-by-line property definitions
                prop_pattern = r"'(\w+)':\s*{'min':\s*(-?\d+(?:\.\d+)?),\s*'max':\s*(-?\d+(?:\.\d+)?)}"
                prop_matches = re.findall(prop_pattern, response)

                if prop_matches:
                    # Manually construct dictionary
                    constraints = {}
                    for prop, min_val, max_val in prop_matches:
                        constraints[prop] = {
                            'min': float(min_val),
                            'max': float(max_val)
                        }
                    return constraints
                else:
                    return {"error": "Parsing failed", "raw_output": response}

        # Clean up the dictionary string
        dict_str = dict_str.replace("'", '"')  # Replace single quotes with double quotes
        dict_str = re.sub(r'(\w+):', r'"\1":', dict_str)  # Add quotes around property names

        # Fix common JSON formatting issues
        dict_str = dict_str.replace("None", "null")
        dict_str = dict_str.replace("True", "true")
        dict_str = dict_str.replace("False", "false")

        # Clean trailing commas
        dict_str = re.sub(r',\s*}', '}', dict_str)
        dict_str = re.sub(r',\s*]', ']', dict_str)

        # Try to parse the JSON
        try:
            constraints = json.loads(dict_str)
        except json.JSONDecodeError:
            # More aggressive cleaning for complex cases
            dict_str = re.sub(r'//.*?\n', '\n', dict_str)  # Remove comments
            dict_str = re.sub(r'/\*.*?\*/', '', dict_str, flags=re.DOTALL)  # Remove block comments

            # Try to cast numeric strings to actual numbers
            dict_str = re.sub(r'"(-?\d+(?:\.\d+)?)"', r'\1', dict_str)

            # Final attempt to parse JSON
            try:
                constraints = json.loads(dict_str)
            except json.JSONDecodeError as e:
                return {"error": f"JSON parsing failed: {e}", "raw_output": response}

        return constraints

    except Exception as e:
        return {"error": str(e), "raw_output": response}

def enhance_constraints(query, llm_constraints):
    """Enhance constraints by combining LLM output with knowledge base"""
    # Get explicit numeric constraints and qualitative descriptions
    explicit_constraints = extract_explicit_numeric_values(query)
    qualitative_descriptions = extract_qualitative_descriptions(query)

    # If LLM failed completely, still try to proceed with other data
    if "error" in llm_constraints:
        llm_constraints = {}

    # Collect all mentioned properties from all sources
    all_properties = set(list(explicit_constraints.keys()) +
                         list(qualitative_descriptions.keys()) +
                         list(llm_constraints.keys()))

    # Initialize enhanced constraints dictionary
    enhanced_constraints = {}

    # Process each property
    for prop in all_properties:
        # Normalize property name (handle aliases)
        canonical_prop = prop
        if prop in property_knowledge and isinstance(property_knowledge[prop], str):
            canonical_prop = property_knowledge[prop]

        # Get standard range for this property
        std_range = None
        if canonical_prop in property_knowledge and isinstance(property_knowledge[canonical_prop], dict):
            std_range = property_knowledge[canonical_prop]

        # Initialize with standard range if available
        if std_range:
            min_val = std_range.get('min', 0)
            max_val = std_range.get('max', 0)
        else:
            # If no standard range, use defaults or LLM values
            if canonical_prop in llm_constraints:
                min_val = llm_constraints[canonical_prop].get('min', 0)
                max_val = llm_constraints[canonical_prop].get('max', 0)
            else:
                # No information available for this property
                continue

        # Update with explicit values if available
        if canonical_prop in explicit_constraints:
            explicit_min = explicit_constraints[canonical_prop].get('min')
            explicit_max = explicit_constraints[canonical_prop].get('max')

            if explicit_min is not None:
                min_val = explicit_min
            if explicit_max is not None:
                max_val = explicit_max

        # Update with LLM values if available and not overridden by explicit values
        elif canonical_prop in llm_constraints:
            llm_min = llm_constraints[canonical_prop].get('min')
            llm_max = llm_constraints[canonical_prop].get('max')

            if llm_min is not None and canonical_prop not in explicit_constraints:
                min_val = llm_min
            if llm_max is not None and canonical_prop not in explicit_constraints:
                max_val = llm_max

        # Apply qualitative adjustments if available
        if canonical_prop in qualitative_descriptions and std_range:
            descriptor = qualitative_descriptions[canonical_prop]

            # Check if this is an inverse property
            is_inverse = False
            for inv_prop, mapped_prop in inverse_properties.items():
                if canonical_prop == mapped_prop or prop == inv_prop:
                    is_inverse = True
                    break

            # Get adjustment factors based on descriptor
            if descriptor in qualitative_mappings:
                factors = qualitative_mappings[descriptor]

                # For regular properties (high value is good)
                if not is_inverse:
                    range_span = std_range['max'] - std_range['min']

                    # Adjust min and max based on the qualitative descriptor
                    min_factor = factors.get('min_factor', 0.5)
                    max_factor = factors.get('max_factor', 0.8)

                    min_val = std_range['min'] + (range_span * min_factor)
                    max_val = std_range['min'] + (range_span * max_factor)
                else:
                    # For inverse properties (low value is good)
                    range_span = std_range['max'] - std_range['min']

                    # Invert the factors for inverse properties
                    min_factor = 1 - factors.get('max_factor', 0.8)
                    max_factor = 1 - factors.get('min_factor', 0.5)

                    min_val = std_range['min'] + (range_span * min_factor)
                    max_val = std_range['min'] + (range_span * max_factor)

        # Add the property to final constraints
        enhanced_constraints[canonical_prop] = {
            'min': round(min_val, 3),
            'max': round(max_val, 3)
        }

        # Add units and category if available
        if std_range:
            if 'unit' in std_range:
                enhanced_constraints[canonical_prop]['unit'] = std_range['unit']
            if 'category' in std_range:
                enhanced_constraints[canonical_prop]['category'] = std_range['category']

    # Validate the constraints (ensure min <= max)
    for prop, values in enhanced_constraints.items():
        if values['min'] > values['max']:
            # Swap min and max if they're inverted
            enhanced_constraints[prop]['min'], enhanced_constraints[prop]['max'] = \
                enhanced_constraints[prop]['max'], enhanced_constraints[prop]['min']

    return enhanced_constraints

def validate_constraints(constraints):
    """Validate constraints to ensure they are physically realistic"""
    validated = {}

    for prop, values in constraints.items():
        # Skip non-dictionary values or entries with errors
        if not isinstance(values, dict) or 'error' in values:
            validated[prop] = values
            continue

        # Ensure min and max exist
        if 'min' not in values or 'max' not in values:
            continue

        min_val = values['min']
        max_val = values['max']

        # Fix reversed min/max
        if min_val > max_val:
            min_val, max_val = max_val, min_val

        # Check against known physical limits if available
        if prop in property_knowledge and isinstance(property_knowledge[prop], dict):
            std_range = property_knowledge[prop]
            absolute_min = std_range.get('min')
            absolute_max = std_range.get('max')

            if absolute_min is not None and min_val < absolute_min:
                min_val = absolute_min

            if absolute_max is not None and max_val > absolute_max:
                max_val = absolute_max

        # Create validated entry
        validated[prop] = {
            'min': min_val,
            'max': max_val
        }

        # Copy additional fields
        for key, value in values.items():
            if key not in ['min', 'max']:
                validated[prop][key] = value

    return validated

def clean_output_format(constraints):
    """Clean the output format to match the desired simple structure"""
    cleaned = {}

    for prop, values in constraints.items():
        if not isinstance(values, dict) or 'min' not in values or 'max' not in values:
            continue

        # Only keep min and max in the output
        cleaned[prop] = {
            'min': values['min'],
            'max': values['max']
        }

    return cleaned


In [7]:
# ========== MAIN INTERFACE FUNCTION ==========

def ensure_required_properties(constraints):
    """
    Ensure that required properties for MEGNet+VAE model are present in the constraints dictionary.

    Args:
        constraints (dict): Constraints dictionary from NLP pipeline

    Returns:
        dict: Constraints dictionary with required properties
    """
    required_properties = {
        'band_gap': {'min': 0.5, 'max': 2.5},  # Default semiconducting range
        'formation_energy': {'min': -2.0, 'max': -0.1},  # Default stability range
        'bulk_modulus': {'min': 50, 'max': 200}  # Default mechanical property range
    }

    # Add required properties if missing
    for prop, default_range in required_properties.items():
        if prop not in constraints:
            constraints[prop] = default_range

    return constraints

def material_constraint_converter(query, model=None, tokenizer=None, use_knowledge_base=True, include_metadata=False):
    """
    Convert a natural language query about material properties into a constraints dictionary.

    Args:
        query (str): The natural language query
        model: The LLM model (if None, will be initialized)
        tokenizer: The tokenizer (if None, will be initialized)
        use_knowledge_base (bool): Whether to enhance results with material science knowledge
        include_metadata (bool): Whether to include units and categories in output

    Returns:
        dict: A dictionary of material property constraints
    """
    # Initialize model if not provided
    if model is None or tokenizer is None:
        model, tokenizer = setup_model()

    # Get constraints from LLM
    llm_constraints = process_query_with_llm(query, model, tokenizer)

    # Enhance with knowledge base if requested
    if use_knowledge_base:
        constraints = enhance_constraints(query, llm_constraints)
    else:
        constraints = llm_constraints

    # Validate the constraints
    validated = validate_constraints(constraints)

    # Format the output
    if include_metadata:
        output = validated
    else:
        output = clean_output_format(validated)

    # Ensure required properties for MEGNet+VAE integration
    output = ensure_required_properties(output)

    return output

def batch_process_queries(queries, model=None, tokenizer=None, use_knowledge_base=True, include_metadata=False):
    """Process a batch of queries and return constraints for each"""
    # Initialize model if not provided
    if model is None or tokenizer is None:
        model, tokenizer = setup_model()

    results = {}

    for query in queries:
        results[query] = material_constraint_converter(
            query,
            model=model,
            tokenizer=tokenizer,
            use_knowledge_base=use_knowledge_base,
            include_metadata=include_metadata
        )

    return results


In [8]:
# ========== USAGE EXAMPLE ==========
if __name__ == "__main__":
    # Example usage
    query = "I need to create a material with high tensile strength and ductility"

    # Initialize the model (do this once)
    model, tokenizer = setup_model()

    # Process the query
    constraints = material_constraint_converter(query, model, tokenizer)

    # Print the result
    print(f"Query: {query}")
    print(f"Constraints: {json.dumps(constraints, indent=2)}")

tokenizer_config.json:   0%|          | 0.00/3.44k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.94M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/306 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/599 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors.index.json:   0%|          | 0.00/16.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Query: I need to create a material with high tensile strength and ductility
Constraints: {
  "tensile_strength": {
    "min": 2109.0,
    "max": 2703.0
  },
  "band_gap": {
    "min": 0.5,
    "max": 2.5
  },
  "formation_energy": {
    "min": -2.0,
    "max": -0.1
  },
  "bulk_modulus": {
    "min": 50,
    "max": 200
  }
}


In [9]:
def prepare_for_megnet_vae(constraints):
    """Filter constraints to only include properties the model was trained on"""
    megnet_constraints = {
        'band_gap': constraints.get('band_gap', {'min': 0.5, 'max': 2.5}),
        'formation_energy': constraints.get('formation_energy', {'min': -2.0, 'max': -0.1}),
        'bulk_modulus': constraints.get('bulk_modulus', {'min': 50, 'max': 200})
    }
    return megnet_constraints

# Process the query
constraints = material_constraint_converter(query)

# Filter to only include properties the model knows about
megnet_constraints = prepare_for_megnet_vae(constraints)

# Now use with your model
generated_materials = generate_materials_with_constraints(
    vae, property_scaler, recovery, constraints=megnet_constraints
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

NameError: name 'generate_materials_with_constraints' is not defined