# Conifer classifier


In [83]:
# Importing Library
# !pip install pgmpy
from pgmpy.models import BayesianNetwork
from pgmpy.inference import VariableElimination

# TabularCPD is used below to define each node's conditional probability table (CPT)
from pgmpy.factors.discrete import TabularCPD

In [84]:
# Defining network structure

# Structure of the conifer network as directed (parent, child) edges.
# The edges are listed in their original order; the node order shown by
# conifer_model.nodes() further down follows this insertion order.
conifer_model = BayesianNetwork(
    ebunch=[
        # Parents of BerryTree
        ("FruitType", "BerryTree"),
        ("Color", "BerryTree"),
        # Parents of ConeShapeHint
        ("FruitType", "ConeShapeHint"),
        ("ConeShape", "ConeShapeHint"),
        # Parents of FirSpruce
        ("ConeShapeHint", "FirSpruce"),
        ("Orientation", "FirSpruce"),
    ],
)

# Prior over the root node FruitType: P(Berry) = 0.001, P(Cone) = 0.999.
cpd_fruit_type = TabularCPD(
    "FruitType",
    2,
    [
        [0.001],  # Berry
        [0.999],  # Cone
    ],
    state_names={"FruitType": ["Berry", "Cone"]},
)

# Prior over the root node Color: P(Red) = 0.05, P(Blue) = 0.95.
cpd_color = TabularCPD(
    "Color",
    2,
    [
        [0.05],  # Red
        [0.95],  # Blue
    ],
    state_names={"Color": ["Red", "Blue"]},
)

# P(BerryTree | FruitType, Color).
# One row per BerryTree state; one column per (FruitType, Color) pair, with
# the last evidence variable (Color) varying fastest:
#   (Berry, Red), (Berry, Blue), (Cone, Red), (Cone, Blue)
# Each column sums to 1.
cpd_berry_tree = TabularCPD(
    "BerryTree",
    3,
    [
        [0.9, 0.05, 0.001, 0.001],    # Yew
        [0.03, 0.9, 0.001, 0.001],    # Juniper
        [0.07, 0.05, 0.998, 0.998],   # Other
    ],
    evidence=["FruitType", "Color"],
    evidence_card=[2, 2],
    state_names={
        "BerryTree": ["Yew", "Juniper", "Other"],
        "FruitType": ["Berry", "Cone"],
        "Color": ["Red", "Blue"],
    },
)

# Prior over the root node ConeShape: P(Elongated) = 0.2, P(Oval) = 0.8.
cpd_cone_shape = TabularCPD(
    "ConeShape",
    2,
    [
        [0.2],  # Elongated
        [0.8],  # Oval
    ],
    state_names={"ConeShape": ["Elongated", "Oval"]},
)

# P(ConeShapeHint | FruitType, ConeShape).
# One row per ConeShapeHint state; one column per (FruitType, ConeShape)
# pair, with the last evidence variable (ConeShape) varying fastest:
#   (Berry, Elongated), (Berry, Oval), (Cone, Elongated), (Cone, Oval)
# Each column sums to 1.
cpd_cone_shape_hint = TabularCPD(
    "ConeShapeHint",
    2,
    [
        [0.01, 0.01, 0.96, 0.1],  # FirSpruceHint
        [0.99, 0.99, 0.04, 0.9],  # PineLarchHint
    ],
    evidence=["FruitType", "ConeShape"],
    evidence_card=[2, 2],
    state_names={
        "ConeShapeHint": ["FirSpruceHint", "PineLarchHint"],
        "FruitType": ["Berry", "Cone"],
        "ConeShape": ["Elongated", "Oval"],
    },
)

# Prior over the root node Orientation: P(Upwards) = 0.005, P(Downwards) = 0.995.
cpd_orientation = TabularCPD(
    "Orientation",
    2,
    [
        [0.005],  # Upwards
        [0.995],  # Downwards
    ],
    state_names={"Orientation": ["Upwards", "Downwards"]},
)

# P(FirSpruce | ConeShapeHint, Orientation).
# One row per FirSpruce state; one column per (ConeShapeHint, Orientation)
# pair, with the last evidence variable (Orientation) varying fastest:
#   (FirSpruceHint, Upwards), (FirSpruceHint, Downwards),
#   (PineLarchHint, Upwards), (PineLarchHint, Downwards)
# Each column sums to 1.
cpd_fir_spruce = TabularCPD(
    "FirSpruce",
    3,
    [
        [0.95, 0.04, 0.01, 0.01],  # Fir
        [0.03, 0.94, 0.01, 0.01],  # Spruce
        [0.02, 0.02, 0.98, 0.98],  # Other
    ],
    evidence=["ConeShapeHint", "Orientation"],
    evidence_card=[2, 2],
    state_names={
        "FirSpruce": ["Fir", "Spruce", "Other"],
        "ConeShapeHint": ["FirSpruceHint", "PineLarchHint"],
        "Orientation": ["Upwards", "Downwards"],
    },
)

# Prior over NeedleCount: P(AlwaysTwo) = 0.2, P(MoreThanTwenty) = 0.8.
# NOTE(review): "NeedleCount" does not appear in any edge of conifer_model
# and this CPD is not passed to add_cpds(), so it currently has no effect on
# the model. Presumably a placeholder for a future pine/larch branch of the
# network -- confirm, then either wire it in or remove it.
cpd_needle_count = TabularCPD(
    "NeedleCount",
    2,
    [
        [0.2],  # AlwaysTwo
        [0.8],  # MoreThanTwenty
    ],
    state_names={"NeedleCount": ["AlwaysTwo", "MoreThanTwenty"]},
)

# Attach one CPD per node of the network to the model: the four root priors
# and the three conditional tables.
conifer_model.add_cpds(
    cpd_fruit_type,
    cpd_color,
    cpd_berry_tree,
    cpd_cone_shape,
    cpd_cone_shape_hint,
    cpd_orientation,
    cpd_fir_spruce,
)

In [85]:
# Check that the CPDs attached to conifer_model are valid for its structure.
# The recorded cell output for this call is True.
conifer_model.check_model()

True

In [86]:
# Print every CPD attached to the model, in the order they were added.
# Fix: the fifth print referenced `cpd_fir_spruce_hint`, a name that is never
# defined in this notebook (NameError on a fresh kernel); the CPD for the
# ConeShapeHint node is `cpd_cone_shape_hint`.
print(cpd_fruit_type)
print(cpd_color)
print(cpd_berry_tree)
print(cpd_cone_shape)
print(cpd_cone_shape_hint)
print(cpd_orientation)
print(cpd_fir_spruce)

+------------------+-------+
| FruitType(Berry) | 0.001 |
+------------------+-------+
| FruitType(Cone)  | 0.999 |
+------------------+-------+
+-------------+------+
| Color(Red)  | 0.05 |
+-------------+------+
| Color(Blue) | 0.95 |
+-------------+------+
+--------------------+------------------+------------------+-----------------+-----------------+
| FruitType          | FruitType(Berry) | FruitType(Berry) | FruitType(Cone) | FruitType(Cone) |
+--------------------+------------------+------------------+-----------------+-----------------+
| Color              | Color(Red)       | Color(Blue)      | Color(Red)      | Color(Blue)     |
+--------------------+------------------+------------------+-----------------+-----------------+
| BerryTree(Yew)     | 0.9              | 0.05             | 0.001           | 0.001           |
+--------------------+------------------+------------------+-----------------+-----------------+
| BerryTree(Juniper) | 0.03             | 0.9              | 

In [87]:
# Display the model's nodes (the seven variables named in the edge list).
conifer_model.nodes()

NodeView(('FruitType', 'BerryTree', 'Color', 'ConeShapeHint', 'ConeShape', 'FirSpruce', 'Orientation'))

In [88]:
# Display the model's directed (parent, child) edges.
conifer_model.edges()

OutEdgeView([('FruitType', 'BerryTree'), ('FruitType', 'ConeShapeHint'), ('Color', 'BerryTree'), ('ConeShapeHint', 'FirSpruce'), ('ConeShape', 'ConeShapeHint'), ('Orientation', 'FirSpruce')])

In [89]:
# List the independence assertions of the network. Each output line reads
# (X ⟂ Y1, Y2, ... | Z1, Z2, ...); the listing is long and is cut off in
# this export.
conifer_model.get_independencies()

(BerryTree ⟂ Orientation, ConeShape)
(BerryTree ⟂ Orientation, ConeShape | Color)
(BerryTree ⟂ FirSpruce, ConeShapeHint, Orientation, ConeShape | FruitType)
(BerryTree ⟂ Orientation | ConeShape)
(BerryTree ⟂ ConeShape | Orientation)
(BerryTree ⟂ FirSpruce, Orientation | ConeShapeHint)
(BerryTree ⟂ ConeShape, Orientation, ConeShapeHint | FirSpruce, FruitType)
(BerryTree ⟂ Orientation | FirSpruce, ConeShapeHint)
(BerryTree ⟂ FirSpruce, ConeShape, Orientation, ConeShapeHint | Color, FruitType)
(BerryTree ⟂ Orientation | Color, ConeShape)
(BerryTree ⟂ ConeShape | Color, Orientation)
(BerryTree ⟂ FirSpruce, Orientation | Color, ConeShapeHint)
(BerryTree ⟂ FirSpruce, Orientation, ConeShapeHint | ConeShape, FruitType)
(BerryTree ⟂ FirSpruce, ConeShape, ConeShapeHint | Orientation, FruitType)
(BerryTree ⟂ FirSpruce, Orientation, ConeShape | ConeShapeHint, FruitType)
(BerryTree ⟂ FirSpruce, Orientation | ConeShapeHint, ConeShape)
(BerryTree ⟂ FirSpruce | Orientation, ConeShapeHint)
(BerryTree ⟂