In [54]:
from experta import *

class UserNeeds(Fact):
    """User's needs for model recommendation.

    A single instance is declared into ``ModelRecommender``; its rules
    pattern-match on these fields. Every optional field carries a default so a
    request only has to state what it cares about.
    """
    # Mandatory: every recommendation rule matches on ``task`` and the fallback
    # rule prints ``needs['task']`` unconditionally, so a fact without it could
    # never be matched and would only fail later with a KeyError.
    task = Field(str, mandatory=True)
    vram = Field(int, default=0)  # available VRAM, in GB
    minimum_parameters_number = Field(int, default=0)
    # The 1_000_000_000 default acts as "no upper bound": the fallback rule only
    # reports the parameter range when it differs from these defaults.
    maximum_parameters_number = Field(int, default=1_000_000_000)
    latency_sensitive = Field(bool, default=False)
    memory_efficient = Field(bool, default=False)
    training_required = Field(bool, default=False)
    real_time = Field(bool, default=False)
    open_source = Field(bool, default=True)
    # NOTE(review): no rule in ModelRecommender currently matches on ``mobile``.
    mobile = Field(bool, default=False)

class ModelRecommender(KnowledgeEngine):
    """Rule-based engine that prints model recommendations for a ``UserNeeds`` fact.

    Every recommendation rule whose pattern matches the declared ``UserNeeds``
    fires. Several can fire for one request. Each prints the model name and
    framework links, then declares the marker ``Fact(recommended=True)``.
    ``no_model_found`` has salience -5 and requires that marker to be absent.
    It is therefore intended to fire only when no recommendation rule matched.
    """

    def _recommend(self, title, *links):
        """Print one recommendation and record that a model was found.

        Args:
            title: Text printed after "✅ Recommended Model: ".
            *links: ``(label, url)`` pairs, each printed as "🔗 label: url",
                in the given order.
        """
        print(f"✅ Recommended Model: {title}")
        for label, url in links:
            print(f"🔗 {label}: {url}")
        # Marker fact checked by no_model_found's NOT(...) condition.
        self.declare(Fact(recommended=True))

    # The parameter-range tests require the user's [min, max] range to contain
    # the model's approximate parameter count (5.3M here).
    @Rule(UserNeeds(task='image_classification',
                    real_time=True,
                    latency_sensitive=True,
                    memory_efficient=True,
                    vram=P(lambda x: x >= 4),
                    minimum_parameters_number=P(lambda minp: minp <= 5_300_000),
                    maximum_parameters_number=P(lambda maxp: maxp >= 5_300_000)))
    def recommend_efficientnet_lite(self):
        self._recommend(
            "EfficientNet-Lite (Real-time & efficient)",
            ("TensorFlow", "https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2"),
            ("PyTorch (community)", "https://github.com/rwightman/gen-efficientnet-pytorch"),
        )

    @Rule(UserNeeds(task='text_generation',
                    training_required=False,
                    open_source=True,
                    vram=P(lambda x: x >= 12)))
    def recommend_mistral(self):
        self._recommend(
            "Mistral-7B (open-source)",
            ("PyTorch", "https://huggingface.co/mistralai/Mistral-7B-v0.1"),
            ("TensorFlow (convertible)", "https://huggingface.co/docs/transformers/main/en/model_doc/mistral"),
        )

    # Range must contain ~25M parameters.
    @Rule(UserNeeds(task='image_classification',
                    vram=P(lambda x: x >= 8),
                    minimum_parameters_number=P(lambda minp: minp <= 25_000_000),
                    maximum_parameters_number=P(lambda maxp: maxp >= 25_000_000)))
    def recommend_resnet50(self):
        self._recommend(
            "ResNet50",
            ("PyTorch", "https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet50.html"),
            ("TensorFlow", "https://www.tensorflow.org/api_docs/python/tf/keras/applications/ResNet50"),
        )

    @Rule(UserNeeds(task='image_classification',
                    memory_efficient=True))
    def recommend_mobilenetv2(self):
        self._recommend(
            "MobileNetV2",
            ("TensorFlow", "https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV2"),
            ("PyTorch", "https://pytorch.org/vision/stable/models/generated/torchvision.models.mobilenet_v2.html"),
        )

    @Rule(UserNeeds(task='object_detection',
                    vram=P(lambda x: x >= 4)))
    def recommend_yolov5(self):
        self._recommend(
            "YOLOv5",
            ("PyTorch", "https://github.com/ultralytics/yolov5"),
            ("TensorFlow (convertible)", "https://github.com/zldrobit/yolov5_tf"),
        )

    @Rule(UserNeeds(task='object_detection'))
    def recommend_efficientdet(self):
        self._recommend(
            "EfficientDet (TF2)",
            ("TensorFlow", "https://github.com/google/automl/tree/master/efficientdet"),
            ("PyTorch (community)", "https://github.com/rwightman/efficientdet-pytorch"),
        )

    @Rule(UserNeeds(task='text_classification',
                    memory_efficient=True))
    def recommend_distilbert(self):
        self._recommend(
            "DistilBERT (lightweight)",
            ("PyTorch", "https://huggingface.co/distilbert-base-uncased"),
            ("TensorFlow", "https://huggingface.co/docs/transformers/model_doc/distilbert"),
        )

    @Rule(UserNeeds(task='text_generation',
                    vram=P(lambda x: x >= 6)))
    def recommend_gpt2(self):
        self._recommend(
            "GPT-2",
            ("PyTorch", "https://huggingface.co/gpt2"),
            ("TensorFlow", "https://huggingface.co/docs/transformers/model_doc/gpt2"),
        )

    @Rule(UserNeeds(task='image_segmentation'))
    def recommend_deeplabv3(self):
        self._recommend(
            "DeepLabV3",
            ("PyTorch", "https://pytorch.org/vision/stable/models/generated/torchvision.models.segmentation.deeplabv3_resnet101.html"),
            ("TensorFlow", "https://tfhub.dev/tensorflow/deeplabv3/1"),
        )

    @Rule(UserNeeds(task='audio_classification'))
    def recommend_yamnet(self):
        # The second link is TensorFlow's own source tree for YAMNet, so it is
        # labelled as such rather than as a PyTorch port.
        self._recommend(
            "YAMNet",
            ("TensorFlow", "https://tfhub.dev/google/yamnet/1"),
            ("TensorFlow (source code)", "https://github.com/tensorflow/models/tree/master/research/audioset/yamnet"),
        )

    @Rule(UserNeeds(task='image_classification',
                    memory_efficient=True))
    def recommend_squeezenet(self):
        self._recommend(
            "SqueezeNet",
            ("PyTorch", "https://pytorch.org/vision/stable/models/generated/torchvision.models.squeezenet1_0.html"),
            ("TensorFlow (community)", "https://github.com/yeephycho/tensorflow-squeezenet"),
        )

    @Rule(UserNeeds(task='text_classification'))
    def recommend_albert(self):
        self._recommend(
            "ALBERT",
            ("PyTorch", "https://huggingface.co/albert-base-v2"),
            ("TensorFlow", "https://huggingface.co/docs/transformers/model_doc/albert"),
        )

    # Lower salience than the recommendation rules (default 0), so it is only
    # considered after they have had a chance to declare the marker fact.
    @Rule(NOT(Fact(recommended=True)),
          AS.needs << UserNeeds(), salience=-5)
    def no_model_found(self, needs):
        """Explain that nothing matched and echo the user's key constraints."""
        print("\n⚠️ No suitable model found. Adjust your requirements:")
        print(f"- Task: {needs['task']}")
        print(f"- VRAM: {needs['vram']}GB")
        # Only show the parameter range when it differs from the
        # "unconstrained" defaults (0 .. 1_000_000_000).
        if needs['minimum_parameters_number'] > 0 or needs['maximum_parameters_number'] < 1_000_000_000:
            print(f"- Parameter range: {needs['minimum_parameters_number']:,} to {needs['maximum_parameters_number']:,}")
        print("\n💡 Suggestions: Increase VRAM, relax parameter constraints, or broaden the model scope.")



# ---------- USER REQUIREMENTS ----------
# Hard-coded example request. To query the user interactively instead,
# replace these assignments with the commented-out prompts below.
task = 'image_segmentation'
vram = 12  # in GB
min_params = 10_000
max_params = 250_000_000
memory_efficient = False
latency_sensitive = False
real_time = False
training_required = False
open_source = True
mobile = False

# Interactive alternative. UserNeeds.vram is declared as Field(int), so the
# VRAM answer is parsed with int() rather than float().
# task = input("Enter task (e.g., image_classification, object_detection, etc.): ").strip()
# vram = int(input("Enter your available VRAM in GB: "))
# min_params = int(input("Minimum number of model parameters (or 0 to skip): "))
# max_params = int(input("Maximum number of model parameters (or 1000000000 to skip): "))
# memory_efficient = input("Need memory-efficient model? (yes/no): ").strip().lower() == 'yes'
# latency_sensitive = input("Latency-sensitive application? (yes/no): ").strip().lower() == 'yes'
# real_time = input("Need real-time inference? (yes/no): ").strip().lower() == 'yes'
# training_required = input("Do you need to train the model from scratch? (yes/no): ").strip().lower() == 'yes'
# open_source = input("Only open-source models? (yes/no): ").strip().lower() == 'yes'
# mobile = input("Should the model run on mobile devices? (yes/no): ").strip().lower() == 'yes'


# ---------- RUN ENGINE ----------

engine = ModelRecommender()
engine.reset()  # reset working memory before declaring the request
engine.declare(UserNeeds(
    task=task,
    vram=vram,
    minimum_parameters_number=min_params,
    maximum_parameters_number=max_params,
    memory_efficient=memory_efficient,
    latency_sensitive=latency_sensitive,
    real_time=real_time,
    training_required=training_required,
    open_source=open_source,
    mobile=mobile,
))
engine.run()  # fire every matching rule; each prints its recommendation

✅ Recommended Model: DeepLabV3
🔗 PyTorch: https://pytorch.org/vision/stable/models/generated/torchvision.models.segmentation.deeplabv3_resnet101.html
🔗 TensorFlow: https://tfhub.dev/tensorflow/deeplabv3/1
