In [None]:
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors
from PIL import Image
import io
from rdkit.Chem import rdMolDescriptors  
from rdkit.Chem.Draw import rdMolDraw2D
from PIL import Image
from io import BytesIO
import cairosvg

In [None]:
# 1. Load the fine-tuned causal LM and its tokenizer from a local directory.
model_name = "merged_model"  # local checkpoint path (presumably LoRA-merged weights — confirm)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # let transformers place the weights on the available device(s)
    dtype=torch.float16,
)
# eval() returns the module itself; the trailing ';' keeps Jupyter from
# echoing the whole model architecture as this cell's output.
model.eval();

In [None]:

def _render_molecule(mol, size=(400, 400)):
    """Draw an RDKit molecule in black-and-white and return it as a PIL image.

    The molecule is drawn to SVG with RDKit's 2D drawer, then rasterised to
    PNG with cairosvg.

    Args:
        mol: RDKit ``Mol`` to draw.
        size: ``(width, height)`` of the drawing canvas in pixels.

    Returns:
        A ``PIL.Image.Image`` holding the rasterised drawing.
    """
    drawer = rdMolDraw2D.MolDraw2DSVG(size[0], size[1])
    opts = drawer.drawOptions()
    opts.useBWAtomPalette()
    opts.bondLineWidth = 4.0
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    svg_string = drawer.GetDrawingText()
    png_data = cairosvg.svg2png(bytestring=svg_string.encode('utf-8'))
    return Image.open(io.BytesIO(png_data))


def _build_prompt(smiles):
    """Wrap a SMILES string in the ChatML-style prompt the model is queried with.

    The (Chinese) system message tells the model to predict the HOMO level,
    LUMO level and band gap, and to answer strictly as
    ``label:value;label:value;label:value`` with three decimals and no
    explanation, units or newlines.
    """
    smiles_input_CH = "分子的smiles表示:" + smiles
    return (
        f"<|im_start|>system\n"
        f"你是一名分子性质预测专家。请根据输入分子的smiles表示，"
        f"预测该分子的最高占据分子轨道能级、最低未占分子轨道能级、带隙这三个性质的数值。"
        f"并严格按照如下例子格式输出，最高占据分子轨道能级:数值;最低未占分子轨道能级:数值;带隙:数值"
        f"（数值保留到小数点后三位即可,不要输出任何文字解释、单位或换行符）。"
        f"<|im_end|>\n"
        f"<|im_start|>user\n{smiles_input_CH}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )


def _parse_prediction(response):
    """Split the model reply into ``(homo, lumo, gap)`` value strings.

    Expects ``label:value;label:value;label:value`` (the format requested by
    the prompt). Full-width ``：`` / ``；`` are normalised to ASCII first in
    case the model emits them. Missing fields, or a reply with fewer than three
    ``;``-separated parts, yield ``"-"``.
    """
    normalized = response.replace("：", ":").replace("；", ";")
    parts = normalized.split(";")
    try:
        values = tuple(parts[i].split(":")[1].strip() for i in range(3))
    except IndexError:
        return "-", "-", "-"
    return tuple(value or "-" for value in values)


def predict(smiles_input):
    """Predict HOMO / LUMO / band gap for one SMILES string.

    Args:
        smiles_input: SMILES text from the Gradio textbox. Surrounding
            whitespace is ignored.

    Returns:
        A 4-tuple ``(formula, mol_weight, image, prediction_text)`` matching the
        four output components wired to the submit button. For empty or
        unparsable input it returns
        ``("Invalid SMILES", "-", None, "Invalid SMILES input")``.
    """
    torch.cuda.empty_cache()
    smiles = (smiles_input or "").strip()
    # MolFromSmiles("") yields an empty Mol rather than None, so blank input
    # must be rejected explicitly (both here and via the atom count).
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    if mol is None or mol.GetNumAtoms() == 0:
        return "Invalid SMILES", "-", None, "Invalid SMILES input"

    formula = rdMolDescriptors.CalcMolFormula(mol)
    # Exact (monoisotopic) mass, not the average molecular weight.
    mol_weight = round(rdMolDescriptors.CalcExactMolWt(mol), 3)
    img = _render_molecule(mol)

    # Follow wherever device_map="auto" placed the model instead of assuming CUDA.
    inputs = tokenizer(_build_prompt(smiles), return_tensors="pt").to(model.device)
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    generated_ids = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=64,
        do_sample=False,  # greedy decoding; a temperature would be ignored
        pad_token_id=pad_id,
    )
    # Drop the echoed prompt tokens and keep only the newly generated answer.
    output_ids = generated_ids[0][inputs.input_ids.shape[1]:]
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
    hom, lum, gap = _parse_prediction(response)
    response_english = (
        f"HOMO:{hom}\n"
        f"LUMO:{lum}\n"
        f"Band Gap:{gap}"
    )
    return formula, mol_weight, img, response_english


In [None]:
# Gradio UI: a SMILES input and submit button, with output panels for the
# molecular formula, mass, predicted orbital energies and a 2D structure image.
# The CSS styles the main container, the .header/.info boxes, the textareas and
# buttons, and pins the structure image (#mol-img) to a fixed-size white box.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue"),
    css="""
    .main-container {
        max-width: 800px;
        margin: 15px auto;
        padding: 20px;
        background-color: #fafafa;
        border-radius: 12px;
        box-shadow: 0 4px 10px rgba(0,0,0,0.1);
    }
    .header {
        text-align: center;
        font-size: 22px;
        font-weight: bold;
        background-color: #e8e8e8;
        padding: 10px;
        border-radius: 8px;
        margin-bottom: 12px;
    }
    .info {
        background-color: #f0f7ff;
        border-left: 4px solid #2196f3;
        padding: 10px 15px;
        font-size: 13px;
        line-height: 1.6;
        border-radius: 6px;
        margin-bottom: 15px;
    }
    textarea {
        font-weight: bold !important;
        background-color: #f2f2f2 !important;
        border: 1px solid #333 !important;
        border-radius: 8px !important;
    }
    .gr-row {
        justify-content: flex-start !important;
        gap: 6px !important;
        margin-top: 4px;
        margin-bottom: 8px;
    }
    .gr-button, button {
        border-radius: 5px !important;
        font-size: 12px !important;
        padding: 3px 8px !important;
        width: 150px !important;          
        min-width: unset !important;     
        max-width: 150px !important;
        transition: all 0.2s ease-in-out;
    }

    button.secondary, .gr-button.secondary {
        background-color: white !important;
        border: 1px solid #bbb !important;
        color: #333 !important;
        font-weight: 500 !important;
        box-shadow: 0 1px 2px rgba(0,0,0,0.05);
    }
    button.secondary:hover, .gr-button.secondary:hover {
        background-color: #f7f7f7 !important;
        transform: scale(1.02);
        border-color: #999 !important;
    }

    #button-row {
        margin-top: -6px !important;  
    }
    #mol-img {
        width: 420px !important; 
        height: 328.9px !important;
        display: flex !important;
        justify-content: center !important;
        align-items: center !important;
        background: white !important;
        border-radius: 4px !important;
        overflow: hidden !important;
        padding: 10px !important;
    }
    #mol-img img {
        width: 300px !important; 
        height: 300px !important; 
        object-fit: contain !important; 
        object-position: center !important;
        max-width: none !important;
        max-height: none !important;
        margin: 0 !important;
        padding: 0 !important;
    }
    """
) as demo:

    with gr.Column(elem_classes=["main-container"]):
        # Title banner and project description: HTML snippets styled by the
        # .header and .info CSS classes above.
        gr.Markdown("<div class='header'>Molecular Property Prediction AI</div>")
        gr.Markdown("""
        <div class='info'>
         1. This website aims to rapidly predict the HOMO, LUMO, and Band Gap properties of molecules using fine-tuned large language models based on their SMILES structures.<br>
         2. Code and data are available at <a href="https://github.com/ggyy020628/molpredai" target="_blank">GitHub</a>.
        </div>
        """)
        smiles_input = gr.Textbox(
            label="Input Molecule SMILES",
            placeholder="e.g., NC(CC#C)C(F)(F)F",
            # NOTE(review): Gradio documents `lines` as an int; 1.1 is presumably
            # a deliberate height tweak — confirm it renders as intended.
            lines=1.1
        )
        # The #button-row CSS rule pulls this row 6px up toward the input box.
        with gr.Row(elem_id="button-row"):
            submit_btn = gr.Button("Submit and Predict", elem_classes=["secondary"])
        # Left column: text outputs; right column (scale 4 vs 2): structure image.
        with gr.Row():
            with gr.Column(scale=2):
                formula_out = gr.Textbox(
                    label="Molecular Formula",
                    placeholder="C5H6F3N",
                    interactive=False,
                    lines=1.1
                )
                # NOTE(review): predict() fills this with RDKit's exact
                # (monoisotopic) mass via CalcExactMolWt, not the average
                # molar mass the "(g/mol)" label suggests — confirm wording.
                mw_out = gr.Textbox(
                    label="Molecular Weight (g/mol)",
                    placeholder="137.045",
                    interactive=False,
                    lines=1.1
                )
                # NOTE(review): predict() emits "HOMO:x\nLUMO:y\nBand Gap:z"
                # without trailing ';', and in this example LUMO - HOMO is
                # 5.745, not 8.215 — confirm the placeholder values.
                model_output = gr.Textbox(
                    label="AI Prediction Output (eV)",
                    interactive=False,
                    placeholder="HOMO:-6.980;\nLUMO:-1.235;\nBand Gap:8.215;",
                    lines=3
                )
            with gr.Column(scale=4):
                # predict() returns a PIL image, hence type="pil".
                mol_img = gr.Image(
                    label = "Molecular Structure",
                    type = "pil",
                    show_download_button=False,   
                    show_fullscreen_button=False,    
                    elem_id = "mol-img"
                )
        # The output order must match the 4-tuple returned by predict():
        # (formula, mol_weight, image, prediction text).
        submit_btn.click(
            fn=predict,
            inputs=smiles_input,
            outputs=[formula_out, mw_out, mol_img, model_output]
        )
