In [18]:
# Imports for model loading and inference (the models themselves are loaded in a later cell)

import os
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [19]:
import numpy as np
import json
def list_to_image(img_list, size=128):
    """
    Convert a flattened pixel list into a square 2D image array.

    Parameters
    ----------
    img_list : str | bytes | bytearray | sequence
        Flattened pixel values. A JSON-encoded string (the form stored in the
        dataset's ``image`` column) is decoded with ``json.loads``; an
        already-decoded list/sequence or array is used as-is.
    size : int, default 128
        Side length of the output image; the input must contain exactly
        ``size * size`` values.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(size, size)``; the dtype is whatever NumPy infers
        from the values.

    Raises
    ------
    ValueError
        If the number of values is not ``size * size`` (raised by ``reshape``),
        or if a string input is not valid JSON (``json.JSONDecodeError``).
    """
    # Only text needs decoding; lists/arrays can go straight to NumPy.
    if isinstance(img_list, (str, bytes, bytearray)):
        img_list = json.loads(img_list)
    return np.array(img_list).reshape(size, size)

In [20]:
from clip import clip

# Only CLIP's image preprocessing pipeline is needed from the pretrained checkpoint:
# the model clip.load returns was immediately replaced by the fine-tuned one, so build
# it on the CPU and keep just `preprocess` rather than allocating a throwaway GPU model.
preprocess = clip.load('ViT-B/32', device='cpu')[1]

# SECURITY: weights_only=False unpickles arbitrary Python objects and can execute code;
# only load these files from a trusted source. It is presumably required because the
# files hold whole pickled objects (the model is used directly as a module later on).
# map_location puts the fine-tuned weights straight onto the inference GPU instead of
# whichever device they were saved from.
model = torch.load("./model/clip_finetuned.pt", weights_only=False, map_location="cuda:1")
# NOTE(review): `processor` is not used in the cells shown here — confirm it is needed.
processor = torch.load("./model/fine_tuned_clip_processor.pt", weights_only=False)

In [21]:
# Candidate captions for zero-shot ranking, as (formula, mbj_bandgap) pairs.
# Bandgaps are kept as strings so every caption reproduces its exact text
# (trailing zeros such as "3.20" and the three-decimal "4.382" are preserved).
candidate_materials = [
    ("ZnO", "3.37"),
    ("GaN", "3.44"),
    ("SiC", "2.36"),
    ("AlN", "6.20"),
    ("InP", "1.42"),
    ("CdTe", "1.50"),
    ("PbS", "0.41"),
    ("SnSe", "0.90"),
    ("Bi2Te3", "0.30"),
    ("MoS2", "1.80"),
    ("WS2", "2.10"),
    ("Cu2O", "2.17"),
    ("Fe2O3", "2.10"),
    ("TiO2", "3.20"),
    ("ZrO2", "5.00"),
    ("HfO2", "5.30"),
    ("SrTiO3", "3.20"),
    ("BaTiO3", "3.18"),
    ("NaNbO3", "3.90"),
    ("KNbO3", "3.65"),
    ("MgO", "7.80"),
    ("CaO", "7.10"),
    ("BeO", "10.60"),
    ("SrO", "5.90"),
    ("LaAlO3", "5.60"),
    ("Y2O3", "5.50"),
    ("Al2O3", "8.70"),
    ("Te3SeO8", "4.382"),
    ("SiO2", "9.00"),
    ("GeO2", "5.40"),
    ("SnO2", "3.60"),
    ("Sb2O3", "3.20"),
    ("Bi2O3", "2.80"),
    ("CeO2", "3.10"),
    ("VO2", "0.70"),
    ("Nb2O5", "3.50"),
    ("Ta2O5", "4.20"),
    ("WO3", "2.60"),
    ("Cr2O3", "3.40"),
    ("MnO2", "1.30"),
    ("Co3O4", "1.60"),
    ("NiO", "4.00"),
    ("CuO", "1.70"),
    ("ZnS", "3.60"),
    ("CdS", "2.42"),
    ("HgS", "2.10"),
    ("InAs", "0.36"),
    ("GaAs", "1.52"),
    ("GaSb", "0.72"),
    ("InSb", "0.17"),
]

# One caption per material, in table order (list positions are the indices reported by the ranking).
text_inputs = [
    f"The chemical formula is {formula}. The mbj_bandgap value is {bandgap}."
    for formula, bandgap in candidate_materials
]
text_inputs_tokens = clip.tokenize(text_inputs).to("cuda:1")

In [22]:
import pandas as pd

# Training split; the sample image/caption inspected later is taken from this frame.
train_csv_path = './dataset/alpaca_mbj_bandgap_train.csv'
data_train = pd.read_csv(train_csv_path)


In [23]:
from PIL import Image


In [24]:
# Take the first row by position (.iloc), so the choice does not depend on the
# frame's index labels the way Series[0] (label lookup) does.
# NOTE(review): this row comes from the training split, so a correct top-1 match
# below does not measure generalization — use a held-out row for evaluation.
sample_row = data_train.iloc[0]
sample_text = sample_row['input']

# Decode the stored pixel list, then apply CLIP's preprocessing and add a batch dim.
# NOTE(review): Image.fromarray infers the PIL mode from the array dtype — confirm the
# pixel value range is what `preprocess` expects.
sample_pixels = list_to_image(sample_row['image'])
sample_img = preprocess(Image.fromarray(sample_pixels)).unsqueeze(0).to("cuda:1")

In [25]:
# Encode every candidate caption and the sample image with the fine-tuned model.
model.to("cuda:1")
# The model was unpickled with whatever train/eval mode it was saved in; switch to
# eval so training-only behavior (e.g. dropout) is disabled for inference.
model.eval()
# Inference only: skip building the autograd graph.
with torch.no_grad():
    text_features = model.encode_text(text_inputs_tokens)
    image_features = model.encode_image(sample_img)

In [26]:
# Rank the candidate captions by similarity to the image and show the 10 best.
# Clamp k: topk raises if asked for more items than there are labels.
top_k = min(10, len(text_inputs))

# L2-normalize so the dot product below is a cosine similarity. Rebinding (instead of
# in-place `/=`) avoids mutating tensor objects created by the previous cell; renormalizing
# already-unit vectors leaves them unchanged, so re-running this cell is safe.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# 100.0 scales the cosine similarities (an inverse temperature) before the softmax
# turns them into a distribution over the candidate captions.
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(top_k)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values.tolist(), indices.tolist()):
    print(f"index: {index}  {text_inputs[index]} {100 * value:.2f}%")
    


Top predictions:

index: 8  The chemical formula is Bi2Te3. The mbj_bandgap value is 0.30. 19.57%
index: 15  The chemical formula is HfO2. The mbj_bandgap value is 5.30. 11.50%
index: 27  The chemical formula is Te3SeO8. The mbj_bandgap value is 4.382. 9.84%
index: 36  The chemical formula is Ta2O5. The mbj_bandgap value is 4.20. 6.76%
index: 10  The chemical formula is WS2. The mbj_bandgap value is 2.10. 6.76%
index: 32  The chemical formula is Bi2O3. The mbj_bandgap value is 2.80. 6.66%
index: 45  The chemical formula is HgS. The mbj_bandgap value is 2.10. 4.57%
index: 47  The chemical formula is GaAs. The mbj_bandgap value is 1.52. 4.36%
index: 49  The chemical formula is InSb. The mbj_bandgap value is 0.17. 4.17%
index: 31  The chemical formula is Sb2O3. The mbj_bandgap value is 3.20. 4.10%


In [27]:
# Ground-truth caption of the sample image, for comparison with the ranking above.
# NOTE(review): the stored caption reads "The  mbj_bandgap" (two spaces) while the
# candidate captions use one; CLIP's tokenizer is expected to collapse whitespace, so
# this should not affect matching — confirm.
sample_text

'The chemical formula is Te3SeO8. The  mbj_bandgap value is 4.382.'