# MaterialVision Model Loading Demo

This notebook demonstrates how to load and use the different vision-language models available in the MaterialVision project:

- **CLIPP-SciBERT**: CLIPP model with SciBERT text encoder
- **CLIPP-DistilBERT**: CLIPP model with DistilBERT text encoder  
- **MobileCLIP**: Apple's MobileCLIP model
- **BLIP**: Salesforce's BLIP model for image-text retrieval

Each model has its own loading function that handles checkpoint loading and device placement and returns the same kind of `(model, tokenizer_or_processor, dataset)` tuple, so every model is used through a consistent interface.

## 1. Import Required Libraries

First, let's import all the necessary libraries and modules.

In [4]:
import sys
import os
from pathlib import Path
import importlib.util
import torch
import numpy as np
from PIL import Image
import warnings

# Since we're already in the webapp directory, we can import models.py directly
# No need to add paths since models.py is in the same directory

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

Using device: cuda


## 2. Import Model Loading Functions from `models.py`

Now let's import the model loading functions from the `models.py` file.

In [5]:
try:
    # Import model loading functions from models.py
    from models import (
        load_clipp_scibert,
        load_clipp_distilbert, 
        load_mobileclip,
        load_blip
    )
    
    print("✅ Successfully imported model loading functions:")
    print("  - load_clipp_scibert")
    print("  - load_clipp_distilbert")
    print("  - load_mobileclip") 
    print("  - load_blip")
    
except ImportError as e:
    print(f"❌ Error importing model functions: {e}")
    print("Make sure you're running this notebook from the MaterialVision/webapp directory")
    print("so that models.py can be imported directly.")

Adding to path: /home/jipengsun/MaterialVision/models/CLIPP_allenai
✅ Successfully imported CLIPP SciBERT
Adding to path: /home/jipengsun/MaterialVision/models/CLIPP_bert
✅ Successfully imported CLIPP DistilBERT
Adding to path: /home/jipengsun/MaterialVision/models/Apple_MobileCLIP
✅ Successfully imported MobileCLIP
Adding to path: /home/jipengsun/MaterialVision/models/Salesforce
✅ Successfully imported BLIP
✅ Successfully imported model loading functions:
  - load_clipp_scibert
  - load_clipp_distilbert
  - load_mobileclip
  - load_blip

## 3. Call Loaded Functions with Sample Data

Let's check for available checkpoints and demonstrate loading each model.

In [6]:
# Define checkpoint paths (relative to webapp directory, go up one level to access models)
checkpoint_paths = {
    'clipp_scibert': '../models/CLIPP_allenai/checkpoints/best_clipp.pth',
    'clipp_distilbert': '../models/CLIPP_bert/checkpoints/best_clipp_bert.pth', 
    'mobileclip': '../models/Apple_MobileCLIP/checkpoints/best_clipp_apple.pth',
    'blip': '../models/Salesforce/checkpoints_blip/best_blip.pth'
}

# Check which checkpoints exist
available_models = {}
for model_name, path in checkpoint_paths.items():
    full_path = Path(path)
    if full_path.exists():
        available_models[model_name] = str(full_path)
        print(f"✅ {model_name}: {path}")
    else:
        print(f"❌ {model_name}: {path} (not found)")

print(f"\nFound {len(available_models)} available model checkpoints.")

✅ clipp_scibert: ../models/CLIPP_allenai/checkpoints/best_clipp.pth
✅ clipp_distilbert: ../models/CLIPP_bert/checkpoints/best_clipp_bert.pth
✅ mobileclip: ../models/Apple_MobileCLIP/checkpoints/best_clipp_apple.pth
✅ blip: ../models/Salesforce/checkpoints_blip/best_blip.pth

Found 4 available model checkpoints.
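The checkpoint paths above are resolved relative to the current working directory, so the check only succeeds when the notebook is launched from `webapp/`. A minimal sketch of a more robust variant resolves each path against an explicit base directory instead; `find_checkpoints` and its `base_dir` parameter are hypothetical names, not part of `models.py`.

```python
from pathlib import Path


def find_checkpoints(checkpoint_paths, base_dir):
    """Return {name: absolute path} for every checkpoint that exists.

    checkpoint_paths maps a model name to a path relative to base_dir.
    (Hypothetical helper for illustration; not part of models.py.)
    """
    base = Path(base_dir).resolve()
    found = {}
    for name, rel_path in checkpoint_paths.items():
        # resolve() collapses "../" segments into a clean absolute path
        full_path = (base / rel_path).resolve()
        if full_path.is_file():
            found[name] = str(full_path)
    return found


# Usage, assuming the notebook lives in MaterialVision/webapp:
# available_models = find_checkpoints(checkpoint_paths, base_dir="/path/to/MaterialVision/webapp")
```

Anchoring on an explicit base directory keeps the `../models/...` entries in `checkpoint_paths` valid no matter where the kernel was started.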


### 3.1 Load CLIPP-SciBERT Model

In [4]:
if 'clipp_scibert' in available_models:
    try:
        print("Loading CLIPP-SciBERT model...")
        clipp_scibert_model, clipp_scibert_tokenizer, clipp_scibert_dataset = load_clipp_scibert(
            checkpoint_path=available_models['clipp_scibert'],
            device=str(device)
        )
        
        print("✅ CLIPP-SciBERT model loaded successfully!")
        print(f"   Model device: {next(clipp_scibert_model.parameters()).device}")
        print(f"   Tokenizer type: {type(clipp_scibert_tokenizer).__name__}")
        print(f"   Dataset type: {type(clipp_scibert_dataset).__name__}")

        # Test tokenization
        sample_text = "The chemical formula is UGe2Pt2. The mbj_bandgap value is 0.0."
        caption, input_ids, attention_mask = clipp_scibert_dataset.prepare_caption(sample_text)
        print(f"sample input_ids: {input_ids}")
        print(f"sample attention_mask: {attention_mask}")

        # Test text embedding (view(1, -1) adds a batch dimension of size 1)
        txt_emb = clipp_scibert_model.get_text_features(
            input_ids.view(1, -1).to(device),
            attention_mask.view(1, -1).to(device),
        )
        print(f"Text embedding shape: {txt_emb.shape}")
    except Exception as e:
        print(f"❌ Error loading CLIPP-SciBERT: {e}")
else:
    print("⏭️  CLIPP-SciBERT checkpoint not available, skipping...")

Loading CLIPP-SciBERT model...


2025-11-09 23:18:23,539 INFO: Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.augreg2_in21k_ft_in1k)
2025-11-09 23:18:23,581 INFO: [timm/vit_base_patch16_224.augreg2_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


✅ CLIPP-SciBERT model loaded successfully!
   Model device: cuda:0
   Tokenizer type: BertTokenizerFast
   Dataset type: ImageTextDataset
sample input_ids: tensor([ 102,  158,  504,  170, 1240,  170, 3471,  244,  205,  244,  103,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0
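The printed `input_ids` hold the real token ids first and are then filled with `0` up to a fixed length; the matching `attention_mask` tells the text encoder which positions to attend to. A minimal plain-Python sketch of that relationship follows. It assumes right-padding with pad id `0` (as the tensor above suggests) and the usual Hugging Face convention of `1` for real tokens and `0` for padding; both helper functions are hypothetical.

```python
def attention_mask_from_ids(input_ids, pad_id=0):
    """Build a mask with 1 for real tokens and 0 for padding (assumes pad_id=0)."""
    return [0 if tok == pad_id else 1 for tok in input_ids]


def strip_padding(input_ids, attention_mask):
    """Keep only the positions that the mask marks as real tokens."""
    return [tok for tok, keep in zip(input_ids, attention_mask) if keep]


# First positions of the SciBERT sample above, followed by some padding.
ids = [102, 158, 504, 170, 1240, 170, 3471, 244, 205, 244, 103, 0, 0, 0]
mask = attention_mask_from_ids(ids)
print(sum(mask))                 # 11 real tokens
print(strip_padding(ids, mask))  # the 11 non-padding ids
```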

### 3.2 Load CLIPP-DistilBERT Model

In [5]:
if 'clipp_distilbert' in available_models:
    try:
        print("Loading CLIPP-DistilBERT model...")
        clipp_distilbert_model, clipp_distilbert_tokenizer, clipp_distilbert_dataset = load_clipp_distilbert(
            checkpoint_path=available_models['clipp_distilbert'],
            device=str(device)
        )
        
        print("✅ CLIPP-DistilBERT model loaded successfully!")
        print(f"   Model device: {next(clipp_distilbert_model.parameters()).device}")
        print(f"   Tokenizer type: {type(clipp_distilbert_tokenizer).__name__}")
        print(f"   Dataset type: {type(clipp_distilbert_dataset).__name__}")

        # Test tokenization and text embedding
        sample_text = "The chemical formula is UGe2Pt2. The mbj_bandgap value is 0.0."
        caption, input_ids, attention_mask = clipp_distilbert_dataset.prepare_caption(sample_text)
        embeddings = clipp_distilbert_model.get_text_features(
            input_ids.view(1, -1).to(device),
            attention_mask.view(1, -1).to(device),
        )
        print(f"Text embedding shape: {embeddings.shape}")

    except Exception as e:
        print(f"❌ Error loading CLIPP-DistilBERT: {e}")
else:
    print("⏭️  CLIPP-DistilBERT checkpoint not available, skipping...")

Loading CLIPP-DistilBERT model...


2025-11-09 23:18:40,711 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet50.a1_in1k)
2025-11-09 23:18:40,756 INFO: [timm/resnet50.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


✅ CLIPP-DistilBERT model loaded successfully!
   Model device: cuda:0
   Tokenizer type: DistilBertTokenizer
   Dataset type: ImageTextDataset
Text embedding shape: torch.Size([1, 256])
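The text encoders in this notebook each produce a 256-dimensional embedding (`torch.Size([1, 256])`). For image-text retrieval, such embeddings are typically ranked against image embeddings by cosine similarity. The NumPy sketch below ranks a gallery against one query vector. It assumes image embeddings of the same dimension are already available (how they are computed is not shown here), and it uses random vectors as stand-ins.

```python
import numpy as np


def top_k_by_cosine(query, gallery, k=5):
    """Return (indices, scores) of the k gallery rows most similar to query.

    query:   shape (d,), e.g. one text embedding
    gallery: shape (n, d), e.g. n image embeddings in the same space
    """
    q = query / np.linalg.norm(query)
    g = gallery / np.linalg.norm(gallery, axis=1, keepdims=True)
    scores = g @ q                   # cosine similarity for each gallery row
    order = np.argsort(-scores)[:k]  # highest similarity first
    return order, scores[order]


# Stand-in data: 100 random 256-d "image" embeddings.
rng = np.random.default_rng(0)
gallery = rng.normal(size=(100, 256))
query = gallery[42] + 0.01 * rng.normal(size=256)  # near-duplicate of row 42
idx, sims = top_k_by_cosine(query, gallery, k=3)
print(idx[0])  # 42
```

A torch embedding from the cells above could be passed in after conversion with `embeddings.detach().cpu().numpy()[0]`.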


### 3.3 Load MobileCLIP Model

In [38]:
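# Load precomputed MobileCLIP (Apple) validation embeddings from a saved pickle file
import pickle
import pandas as pd  # pandas must be installed to unpickle the stored DataFrame
from pathlib import Path

# Define the path to the MobileCLIP embeddings
mobileclip_embeddings_path = Path('./embeddings/val_df_with_embeddings_apple.pkl')

print(f"🔄 Loading MobileCLIP embeddings from: {mobileclip_embeddings_path}")

if mobileclip_embeddings_path.exists():
    # Load the pickle file (a DataFrame with a "val_txt_embs" column)
    with open(mobileclip_embeddings_path, 'rb') as f:
        mobileclip_embeddings_data = pickle.load(f)
else:
    print(f"❌ Embeddings file not found: {mobileclip_embeddings_path}")

🔄 Loading MobileCLIP embeddings from: embeddings/val_df_with_embeddings_apple.pkl


In [52]:
# Sanity check: compare the first element of the stored text embedding with the
# `embeddings` computed by the MobileCLIP cell below (this cell was executed after it)
torch.allclose(mobileclip_embeddings_data.iloc[0]["val_txt_embs"][0][0], embeddings.cpu()[0][0])


True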
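The check above compares a single scalar (index `[0][0]`), so it would still pass if the rest of the vector differed. A stricter sketch compares whole vectors and also reports the largest absolute difference. It uses NumPy, so torch tensors would first be converted with `.detach().cpu().numpy()`; `compare_embeddings` is a hypothetical helper, and its `rtol`/`atol` defaults simply mirror those of `np.allclose`.

```python
import numpy as np


def compare_embeddings(a, b, rtol=1e-5, atol=1e-8):
    """Compare two embedding arrays element-wise.

    Returns (all_close, max_abs_diff). Shapes must match exactly, so a
    (1, 256) array is not silently broadcast against a (256,) one.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    all_close = bool(np.allclose(a, b, rtol=rtol, atol=atol))
    return all_close, float(np.max(np.abs(a - b)))


# Example with small stand-in arrays:
stored = np.ones((1, 4))
fresh = stored.copy()
print(compare_embeddings(stored, fresh))  # (True, 0.0)
```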

In [23]:
if 'mobileclip' in available_models:
    try:
        print("Loading MobileCLIP model...")
        mobileclip_model, mobileclip_tokenizer, mobileclip_dataset = load_mobileclip(
            checkpoint_path=available_models['mobileclip'],
            device=str(device)
        )
        
        print("✅ MobileCLIP model loaded successfully!")
        print(f"   Model device: {next(mobileclip_model.parameters()).device}")
        print(f"   Tokenizer type: {type(mobileclip_tokenizer)}")
        print(f"   Dataset type: {type(mobileclip_dataset).__name__}")

        # Test tokenization (MobileCLIP uses different tokenization)
        sample_text = "The chemical formula is LiGeS. The  mbj_bandgap value is 0.0."
        caption, text_tokens = mobileclip_dataset.prepare_caption(sample_text)
        embeddings = mobileclip_model.get_text_features(text_tokens.to(device))
        print(f"Text embedding shape: {embeddings.shape}")

    except Exception as e:
        print(f"❌ Error loading MobileCLIP: {e}")
else:
    print("⏭️  MobileCLIP checkpoint not available, skipping...")

2025-11-11 00:20:59,847 INFO: Loaded MobileCLIP-S2 model config.


Loading MobileCLIP model...


2025-11-11 00:21:01,495 INFO: Loading pretrained MobileCLIP-S2 weights (datacompdr).


✅ MobileCLIP model loaded successfully!
   Model device: cuda:0
   Tokenizer type: <class 'open_clip.tokenizer.SimpleTokenizer'>
   Dataset type: ImageTextDataset
Text embedding shape: torch.Size([1, 256])
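MobileCLIP's `prepare_caption` returns `(caption, text_tokens)` and its `get_text_features` takes only the tokens, whereas the CLIPP datasets return `(caption, input_ids, attention_mask)` and need the mask as well. A minimal sketch of a wrapper that hides this difference follows. `encode_text` and its `keyword_args` flag are hypothetical, and the sketch assumes a single caption whose returned tokens are torch tensors supporting `.view()` and `.to()`.

```python
def encode_text(model, dataset, text, device, keyword_args=False):
    """Encode one caption with any of the loaded models (hypothetical wrapper).

    Dispatches on the length of the tuple returned by prepare_caption:
      2 items -> (caption, text_tokens), as in the MobileCLIP cell
      3 items -> (caption, input_ids, attention_mask), as in the CLIPP/BLIP cells
    keyword_args=True passes input_ids=/attention_mask= by name, as the BLIP cell does.
    """
    prepared = dataset.prepare_caption(text)
    if len(prepared) == 2:
        _, tokens = prepared
        return model.get_text_features(tokens.to(device))

    _, input_ids, attention_mask = prepared
    # Add a batch dimension of 1 (a no-op if the tensors are already batched).
    input_ids = input_ids.view(1, -1).to(device)
    attention_mask = attention_mask.view(1, -1).to(device)
    if keyword_args:
        return model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
    return model.get_text_features(input_ids, attention_mask)


# Usage with the objects loaded above:
# emb = encode_text(mobileclip_model, mobileclip_dataset, sample_text, device)
# emb = encode_text(blip_model, blip_dataset, sample_text, device, keyword_args=True)
```

For inference only, wrapping the call in `torch.no_grad()` avoids building an autograd graph.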


### 3.4 Load BLIP Model

In [11]:
if 'blip' in available_models:
    try:
        print("Loading BLIP model...")
        blip_model, blip_processor, blip_dataset = load_blip(
            checkpoint_path=available_models['blip'],
            device=str(device)
        )
        
        print("✅ BLIP model loaded successfully!")
        print(f"   Model device: {next(blip_model.parameters()).device}")
        print(f"   Processor type: {type(blip_processor).__name__}")
        print(f"   Dataset type: {type(blip_dataset).__name__}")

        # Test text processing
        sample_text = "The chemical formula is LiGeS. The  mbj_bandgap value is 0.0."
        caption, input_ids, attention_mask = blip_dataset.prepare_caption(sample_text)
        embeddings = blip_model.get_text_features(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
        )
        print(f"Text embedding shape: {embeddings.shape}")

    except Exception as e:
        print(f"❌ Error loading BLIP: {e}")
else:
    print("⏭️  BLIP checkpoint not available, skipping...")

Loading BLIP model...
✅ BLIP model loaded successfully!
   Model device: cuda:0
   Processor type: BlipProcessor
   Dataset type: ImageTextDataset
Text embedding shape: torch.Size([1, 256])
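The four loading cells in this section share one pattern: check that the checkpoint exists, call the loader with `checkpoint_path=` and `device=`, and report errors without stopping. A minimal sketch of that pattern as a single loop follows; `load_available` is a hypothetical helper, and it assumes every loader accepts those two keyword arguments, as the cells above do.

```python
def load_available(available_models, loaders, device):
    """Call each loader whose checkpoint was found; collect results and errors.

    available_models: {name: checkpoint_path}, as built earlier in section 3
    loaders:          {name: loader}, each called as
                      loader(checkpoint_path=..., device=...)
    Returns (loaded, errors): loaded maps name -> loader result,
    errors maps name -> the exception that loader raised.
    """
    loaded, errors = {}, {}
    for name, loader in loaders.items():
        if name not in available_models:
            print(f"Skipping {name}: checkpoint not available")
            continue
        try:
            loaded[name] = loader(checkpoint_path=available_models[name], device=str(device))
            print(f"Loaded {name}")
        except Exception as exc:  # keep going if one model fails
            errors[name] = exc
            print(f"Error loading {name}: {exc}")
    return loaded, errors


# Usage with the functions imported from models.py:
# loaded, errors = load_available(
#     available_models,
#     {"clipp_scibert": load_clipp_scibert, "clipp_distilbert": load_clipp_distilbert,
#      "mobileclip": load_mobileclip, "blip": load_blip},
#     device,
# )
# clipp_scibert_model, clipp_scibert_tokenizer, clipp_scibert_dataset = loaded["clipp_scibert"]
```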
