In [1]:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import pandas as pd

# Segmentation architectures from segmentation_models_pytorch that the
# compatibility sweep instantiates, one after another.
models = [
    smp.Unet, smp.UnetPlusPlus, smp.MAnet,
    smp.Linknet, smp.FPN, smp.PSPNet,
    smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus,
]

# Every encoder name registered in smp's encoder table.
all_encoders = smp.encoders.encoders.keys()
# timm-backed encoders (names containing 'timm') are left out for now.
timm_encoders = set(filter(lambda name: 'timm' in name, all_encoders))
# The remaining encoder names, in sorted order.
native_encoders = sorted(
    name for name in all_encoders if name not in timm_encoders
)

# One dict per attempted (model, encoder, weights) combination; each dict
# becomes one row of the compatibility table.
results = []

for model_class in models:
    model_name = model_class.__name__
    for encoder_name in native_encoders:
        try:
            pretrained_options = smp.encoders.encoders[encoder_name]['pretrained_settings'].keys()
        except Exception as e:
            # The lookup of the encoder's pretrained settings failed, so no
            # weights name exists for this row: record it with an empty
            # "Weights" value. The "DOWNLOAD ERROR" label is kept unchanged
            # so the CSV format stays the same.
            results.append({
                "Model": model_name,
                "Encoder": encoder_name,
                "Weights": "",
                "Preprocessing_fn or Error": f"DOWNLOAD ERROR: {str(e)}"
            })
            # Skip this encoder: without this, the loop below would run on
            # an unbound name (first failure) or on the previous encoder's
            # options, producing rows attributed to the wrong encoder.
            continue

        for pretrained in pretrained_options:
            try:
                model = model_class(
                    encoder_name=encoder_name,
                    encoder_weights=pretrained,
                    in_channels=3,
                    classes=1
                )
                preprocessing_fn = get_preprocessing_fn(encoder_name, pretrained)
            except Exception as e:
                # Any failure while building the model or resolving its
                # preprocessing function is recorded instead of aborting
                # the sweep.
                results.append({
                    "Model": model_name,
                    "Encoder": encoder_name,
                    "Weights": pretrained,
                    "Preprocessing_fn or Error": f"ERROR: {str(e)}"
                })
            else:
                # Success: store the function's name when it has one,
                # otherwise its string representation.
                results.append({
                    "Model": model_name,
                    "Encoder": encoder_name,
                    "Weights": pretrained,
                    "Preprocessing_fn or Error": preprocessing_fn.__name__ if hasattr(preprocessing_fn, '__name__') else str(preprocessing_fn)
                })

# Append a sentinel row so the saved table itself shows that the sweep ran
# to completion; every column except "Model" is left empty.
results.append(
    dict(
        zip(
            ("Model", "Encoder", "Weights", "Preprocessing_fn or Error"),
            ('FINISHED!', "", "", ""),
        )
    )
)

# Tabulate all collected rows and write them to CSV without the index column.
df = pd.DataFrame(results)
df.to_csv(path_or_buf="smp_model_encoder_compatibility.csv", index=False)

print('Finished!')

Finished!
