In [1]:
import sys
import os
import torch
from torch import nn

sys.path.append(os.path.abspath('../GFM'))
from GFM.models import build_model
from GFM.config import get_config

ModuleNotFoundError: No module named 'GFM'

In [None]:
# 1. Import necessary modules
import torch
from types import SimpleNamespace
import yaml
import os
import sys

# 4. Define a helper class to simulate argparse.Namespace
class Args:
    def __init__(self, cfg, opts=None, batch_size=None, data_path=None, pretrained=None,
                 resume=None, accumulation_steps=None, use_checkpoint=False, amp_opt_level=None,
                 output=None, tag=None, eval=False, throughput=False, train_frac=None,
                 no_val=False, alpha=None, local_rank=0):
        self.cfg = cfg
        self.opts = opts
        self.batch_size = batch_size
        self.data_path = data_path
        self.pretrained = pretrained
        self.resume = resume
        self.accumulation_steps = accumulation_steps
        self.use_checkpoint = use_checkpoint
        self.amp_opt_level = amp_opt_level
        self.output = output
        self.tag = tag
        self.eval = eval
        self.throughput = throughput
        self.train_frac = train_frac
        self.no_val = no_val
        self.alpha = alpha
        self.local_rank = local_rank

# 5. Create an instance of Args with the necessary attributes
args = Args(
    cfg='../configs/gfm_config.yaml',                # Path to BEN.yaml
    # opts=None,                                        # Additional options (if any)
    # batch_size=None,                                  # Use the value from BEN.yaml
    # data_path='/path/to/your/dataset',                # **Replace with your actual dataset path**
    pretrained='../simmim_pretrain/gfm.pth',           # Path to gfm.pth
    # resume=None,                                      # Resume from checkpoint (if any)
    # accumulation_steps=None,                          # Gradient accumulation steps
    # use_checkpoint=False,                             # Whether to use gradient checkpointing
    # amp_opt_level='O1',                               # Mixed precision opt level ('O0', 'O1', 'O2')
    # output='output',                                  # Output directory
    # tag='simmim_finetune__swin_base__img128_window4__100ep',  # Experiment tag
    # eval=False,                                       # Set to True for evaluation only
    # throughput=False,                                 # Set to True to test throughput only
    # train_frac=1.0,                                   # Fraction of training data to use
    # no_val=False,                                     # Whether to skip validation
    # alpha=None,                                       # Mixup/Cutmix alpha (if applicable)
    # local_rank=0                                      # Local rank for DistributedDataParallel
)

# 6. Load the configuration using get_config
config = get_config(args)

# # 7. Print the configuration for verification
# print("Configuration:")
# print(config)


# # 10. Determine the device to run the model on (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nUsing device: {device}")

# 11. Create the output directory if it doesn't exist
os.makedirs(config.OUTPUT, exist_ok=True)

# 12. Instantiate the model using build_model
print("\nBuilding the model...")
model = build_model(config, is_pretrain=False)  # Set is_pretrain=True if building pretraining model

# 13. Move the model to the specified device
model = model.to(device)
print("Model has been moved to the device.")

# 14. Print the model architecture (optional)
print("\nModel Architecture:")
# print(model)

# 15. Load the pretrained weights into the model
# Assuming that 'build_model' does not automatically load the weights, perform it manually

print("\nLoading pretrained weights...")
state_dict = torch.load(config.PRETRAINED, map_location=device)
state_dict = state_dict['model']

# Remove 'module.' prefix if the model was saved using DataParallel
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items()}

# Optionally, filter out keys related to 'teacher.' and 'projector.' if they exist
filtered_state_dict = {k: v for k, v in state_dict.items() if not k.startswith(('teacher.', 'projector.'))}

# Load the state dictionary into the model
missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)

# Print missing and unexpected keys for debugging
if missing_keys:
    print("\nMissing keys when loading pretrained weights:")
    for key in missing_keys:
        print(f"  - {key}")
if unexpected_keys:
    print("\nUnexpected keys when loading pretrained weights:")
    for key in unexpected_keys:
        print(f"  - {key}")

if not missing_keys and not unexpected_keys:
    print("\nAll pretrained weights loaded successfully!")
else:
    print("\nPretrained weights loaded with some missing/unexpected keys.")


In [None]:
model

In [None]:

# 16. Test the model with a dummy input (optional)
# This helps verify that the model is operational
dummy_input = torch.randn(1, 3, config.DATA.IMG_SIZE, config.DATA.IMG_SIZE).to(device)
model.eval()  # Set model to evaluation mode
with torch.no_grad():
    output = model(dummy_input)
print(f"\nDummy input output shape: {output.shape}")  # Expected: [1, NUM_CLASSES, H, W]
