# Conifer classifier


## Imports

In [1]:
# Importing Library
# !pip install pgmpy
from pgmpy.models import BayesianNetwork
from pgmpy.inference import VariableElimination

# Defining the parameters using CPT
from pgmpy.factors.discrete import TabularCPD

  device = torch.device("cpu")


## Model

In [2]:
# Defining network structure

conifer_model = BayesianNetwork(
    [
        ("FruitType", "BerryTree"),
        ("Color", "BerryTree"),
        ("FruitType", "ConeShapeHint"),
        ("ConeShape", "ConeShapeHint"),
        ("ConeShapeHint", "FirSpruce"),
        ("Orientation", "FirSpruce"),
        ("ConeShapeHint", "NeedleCountHint"),
        ("NeedleCount", "NeedleCountHint"),
        ("NeedleCountHint", "Pines"),
        ("PlantShape", "Pines"),
        ("ConeSize", "Pines"),
        ("FruitType", "Pseudotsuga"),
        ("NeedleScent", "Pseudotsuga"),
        ("NeedleCount", "ConeSeedHint"),
        ("ConeSeed", "ConeSeedHint")
    ]
)

cpd_fruit_type = TabularCPD(
    variable="FruitType",
    variable_card=2,
    values=[[0.1], [0.9]],
    state_names={"FruitType": ["Berry", "Cone"]}
)

cpd_color = TabularCPD(
    variable="Color",
    variable_card=2,
    values=[[0.05], [0.95]],
    state_names={"Color": ["Red", "Blue"]}
)

cpd_berry_tree = TabularCPD(
    variable="BerryTree",
    variable_card=3,
    values=[[0.9, 0.05, 0.001, 0.001], # [[Yew | (berry, red), Yew | (berry, blue), Yew | (cone, red), Yew | (cone, blue)],
            [0.03, 0.9, 0.001, 0.001], # [Juniper | (berry, red), Juniper | (berry, blue), Juniper | (cone, red), Juniper | (cone, blue)]]
            [0.07, 0.05, 0.998, 0.998]],  
    evidence=["FruitType", "Color"],
    evidence_card=[2, 2],
    state_names={"BerryTree": ["Yew", "Juniper", "Other"], 
                 "FruitType": ["Berry", "Cone"], 
                 "Color": ["Red", "Blue"]}
)

cpd_cone_shape = TabularCPD(
    variable="ConeShape",
    variable_card=2,
    values=[[0.2], [0.8]],
    state_names={"ConeShape": ["Elongated", "Oval"]}
)

cpd_cone_shape_hint = TabularCPD(
    variable="ConeShapeHint",
    variable_card=2,
    values=[[0.01, 0.01, 0.96, 0.1], # [[FirSpruceHint | (berry, elongated), FirSpruceHint | (berry, oval), FirSpruceHint | (cone, elongated), FirSpruceHint | (cone, oval)]
            [0.99, 0.99, 0.04, 0.9]],      # [PineLarchHint | (berry, elongated), PineLarchHint | (berry, oval), PineLarchHint | (cone, elongated), PineLarchHint | (cone, oval)]]
    evidence=["FruitType", "ConeShape"],
    evidence_card=[2,2],
    state_names={"ConeShapeHint": ["FirSpruceHint", "PineLarchHint"],
                  "FruitType": ["Berry", "Cone"],
                  "ConeShape": ["Elongated", "Oval"]}
)

cpd_orientation = TabularCPD(
    variable="Orientation",
    variable_card=2,
    values=[[0.005], [0.995]],
    state_names={"Orientation": ["Upwards", "Downwards"]}
)

cpd_fir_spruce = TabularCPD(
    variable="FirSpruce",
    variable_card=3,
    values=[[0.95, 0.04, 0.01, 0.01],
            [0.03, 0.94, 0.01, 0.01],
            [0.02, 0.02, 0.98, 0.98]],
    evidence=["ConeShapeHint", "Orientation"],
    evidence_card=[2,2],
    state_names={"FirSpruce": ["Fir", "Spruce", "Other"],
                 "ConeShapeHint": ["FirSpruceHint", "PineLarchHint"],
                 "Orientation": ["Upwards", "Downwards"]}
)

cpd_needle_count = TabularCPD(
    variable="NeedleCount",
    variable_card=3,
    values=[[0.2], [0.7], [0.1]],
    state_names={"NeedleCount": ["MoreThanTwenty", "AlwaysTwo", "AlwaysFive"]}
)

cpd_needle_count_hint = TabularCPD(
    variable="NeedleCountHint",
    variable_card=2,
            # [[Larch | (FirSpruceHint, MoreThanTwenty), Larch | (FirSpruceHint, AlwaysTwo), Larch | (FirSpruceHint, AlwaysFive), Larch | (PineLarchHint, MoreThanTwenty), Larch | (PineLarchHint, AlwaysTwo), Larch | (PineLarchHint, AlwaysFive)],
            # [TwoNeedlePine | (FirSpruceHint, MoreThanTwenty), TwoNeedlePine | (FirSpruceHint, AlwaysTwo), TwoNeedlePine | (FirSpruceHint, AlwaysFive), TwoNeedlePine | (PineLarchHint, MoreThanTwenty), TwoNeedlePine | (PineLarchHint,   AlwaysTwo), TwoNeedlePine | (PineLarchHint, AlwaysFive)]]
    values=[[0.15, 0.05, 0.7, 0.85, 0.18, 0.2],
            [0.85, 0.95, 0.3, 0.15, 0.82, 0.8]],
    evidence=["ConeShapeHint", "NeedleCount"],
    evidence_card=[2,3],
    state_names={"NeedleCountHint": ["Larch", "TwoNeedlePine"],
                 "ConeShapeHint": ["FirSpruceHint", "PineLarchHint"],
                 "NeedleCount": ["MoreThanTwenty", "AlwaysTwo", "AlwaysFive"]}
)

cpd_plant_shape = TabularCPD(
    variable="PlantShape",
    variable_card=2,
    values=[[0.1], [0.9]],
    state_names={"PlantShape": ["Bush", "Tree"]}
)

cpd_cone_size = TabularCPD(
    variable="ConeSize",
    variable_card=2,
    values=[[0.15], [0.85]],
    state_names={"ConeSize": ["Small", "Big"]}
)

cpd_pines = TabularCPD(
    variable="Pines",
    variable_card=4,
     # [[PinusMugo | (Larch, Bush, Small), PinusMugo | (Larch, Bush, Big), PinusMugo | (Larch, Tree, Small), PinusMugo | (Larch, Tree, Big),
     #  PinusMugo | (TwoNeedlePine, Bush, Small), PinusMugo | (TwoNeedlePine, Bush, Big), PinusMugo | (TwoNeedlePine, Tree, Small), PinusMugo | (TwoNeedlePine, Tree, Big)],
     # [ScotsPine | (Larch, Bush, Small), ScotsPine | (Larch, Bush, Big), ScotsPine | (Larch, Tree, Small), ScotsPine | (Larch, Tree, Big),
     #  ScotsPine | (TwoNeedlePine, Bush, Small), ScotsPine | (TwoNeedlePine, Bush, Big), ScotsPine | (TwoNeedlePine, Tree, Small), ScotsPine | (TwoNeedlePine, Tree, Big)],
     # [BlackPine | (Larch, Bush, Small), BlackPine | (Larch, Bush, Big), BlackPine | (Larch, Tree, Small), BlackPine | (Larch, Tree, Big),
     #  BlackPine | (TwoNeedlePine, Bush, Small), BlackPine | (TwoNeedlePine, Bush, Big), BlackPine | (TwoNeedlePine, Tree, Small), BlackPine | (TwoNeedlePine, Tree, Big)],
     # [Other | (Larch, Bush, Small), Other | (Larch, Bush, Big), Other | (Larch, Tree, Small), Other | (Larch, Tree, Big),
     #  Other | (TwoNeedlePine, Bush, Small), Other | (TwoNeedlePine, Bush, Big), Other | (TwoNeedlePine, Tree, Small), Other | (TwoNeedlePine, Tree, Big)]]
    values=[[0.12, 0.10, 0.06, 0.03, 0.75, 0.03, 0.09, 0.07],
            [0.07, 0.04, 0.15, 0.12, 0.13, 0.14, 0.78, 0.16],
            [0.03, 0.05, 0.11, 0.13, 0.07, 0.09, 0.12, 0.76],
            [0.78, 0.81, 0.68, 0.72, 0.05, 0.74, 0.01, 0.01]],
    evidence=["NeedleCountHint", "PlantShape", "ConeSize"],
    evidence_card=[2,2,2],
    state_names={"Pines": ["PinusMugo", "ScotsPine", "BlackPine", "Other"],
                 "NeedleCountHint": ["Larch", "TwoNeedlePine"],
                 "PlantShape": ["Bush", "Tree"],
                 "ConeSize": ["Small", "Big"]}
)

cpd_lemon_scent = TabularCPD(
    variable="NeedleScent",
    variable_card=2,
    values=[[0.3], [0.7]],
    state_names={"NeedleScent": ["Lemon", "Other"]}
)

cpd_pseudotsuga = TabularCPD(
    variable="Pseudotsuga",
    variable_card=2,
    values=[[0.14, 0.02, 0.88, 0.16],   # [[Pseudotsuga | (Berry, LemonScent), Pseudotsuga | (Berry, OtherScent), Pseudotsuga | (Cone, LemonScent), Pseudotsuga | (Cone, OtherScent)],
            [0.86, 0.98, 0.12, 0.84]],  # [Other | (Berry, LemonScent), Other | (Berry, OtherScent), Other | (Cone, LemonScent), Other | (Cone, OtherScent)]]
    evidence=["FruitType", "NeedleScent"],
    evidence_card=[2,2],
    state_names={"Pseudotsuga": ["Yes", "No"],
                 "FruitType": ["Berry", "Cone"],
                 "NeedleScent": ["Lemon", "Other"]}
)

cpd_cone_seed = TabularCPD(
    variable="ConeSeed",
    variable_card=2,
    values=[[0.25], [0.75]],
    state_names={"ConeSeed": ["Nut", "Wing"]}
)

cpd_cone_seed_hint = TabularCPD(
    variable="ConeSeedHint",
    variable_card=2,
              # [PinusCembra | (MoreThanTwenty, Nut), PinusCembra | (MoreThanTwenty, Wing), PinusCembra | (AlwaysTwo, Nut), PinusCembra | (AlwaysTwo, Wing),
              #  PinusCembra | (AlwaysFive, Nut), PinusCembra | (AlwaysFive, Wing)]
              # [PinusStrobus | (MoreThanTwenty, Nut), PinusStrobus | (MoreThanTwenty, Wing), PinusStrobus | (AlwaysTwo, Nut), PinusStrobus | (AlwaysTwo, Wing),
              #  PinusStrobus | (AlwaysFive, Nut), PinusStrobus | (AlwaysFive, Wing)]
    values=[[0.09, 0.06, 0.12, 0.08, 0.82, 0.18],
            [0.91, 0.94, 0.88, 0.92, 0.18, 0.82]],
    evidence=["NeedleCount", "ConeSeed"],
    evidence_card=[3,2],
    state_names={"ConeSeedHint": ["PinusCembra", "PinusStrobus"],
                 "NeedleCount": ["MoreThanTwenty", "AlwaysTwo", "AlwaysFive"],
                 "ConeSeed": ["Nut", "Wing"]}
)

# Associating the parameters with the model structure
conifer_model.add_cpds(
    cpd_fruit_type, cpd_color, cpd_berry_tree, cpd_cone_shape, cpd_cone_shape_hint, cpd_orientation, cpd_fir_spruce,
    cpd_needle_count, cpd_needle_count_hint, cpd_plant_shape, cpd_cone_size, cpd_pines, cpd_lemon_scent, cpd_pseudotsuga, cpd_cone_seed, cpd_cone_seed_hint
)

In [3]:
# Checking if the cpds are valid for the model
conifer_model.check_model()

True

In [4]:
print(cpd_fruit_type)
print(cpd_color)
print(cpd_berry_tree)
print(cpd_cone_shape)
print(cpd_orientation)
print(cpd_fir_spruce)
print(cpd_needle_count)
print(cpd_needle_count_hint)
print(cpd_plant_shape)
print(cpd_cone_size)
print(cpd_pines)
print(cpd_lemon_scent)
print(cpd_pseudotsuga)
print(cpd_cone_seed)
print(cpd_cone_seed_hint)

+------------------+-----+
| FruitType(Berry) | 0.1 |
+------------------+-----+
| FruitType(Cone)  | 0.9 |
+------------------+-----+
+-------------+------+
| Color(Red)  | 0.05 |
+-------------+------+
| Color(Blue) | 0.95 |
+-------------+------+
+--------------------+-----+-----------------+
| FruitType          | ... | FruitType(Cone) |
+--------------------+-----+-----------------+
| Color              | ... | Color(Blue)     |
+--------------------+-----+-----------------+
| BerryTree(Yew)     | ... | 0.001           |
+--------------------+-----+-----------------+
| BerryTree(Juniper) | ... | 0.001           |
+--------------------+-----+-----------------+
| BerryTree(Other)   | ... | 0.998           |
+--------------------+-----+-----------------+
+----------------------+-----+
| ConeShape(Elongated) | 0.2 |
+----------------------+-----+
| ConeShape(Oval)      | 0.8 |
+----------------------+-----+
+------------------------+-------+
| Orientation(Upwards)   | 0.005 |
+-------

In [5]:
# Viewing nodes of the model
conifer_model.nodes()

NodeView(('FruitType', 'BerryTree', 'Color', 'ConeShapeHint', 'ConeShape', 'FirSpruce', 'Orientation', 'NeedleCountHint', 'NeedleCount', 'Pines', 'PlantShape', 'ConeSize', 'Pseudotsuga', 'NeedleScent', 'ConeSeedHint', 'ConeSeed'))

In [6]:
# Viewing edges of the model
conifer_model.edges()

OutEdgeView([('FruitType', 'BerryTree'), ('FruitType', 'ConeShapeHint'), ('FruitType', 'Pseudotsuga'), ('Color', 'BerryTree'), ('ConeShapeHint', 'FirSpruce'), ('ConeShapeHint', 'NeedleCountHint'), ('ConeShape', 'ConeShapeHint'), ('Orientation', 'FirSpruce'), ('NeedleCountHint', 'Pines'), ('NeedleCount', 'NeedleCountHint'), ('NeedleCount', 'ConeSeedHint'), ('PlantShape', 'Pines'), ('ConeSize', 'Pines'), ('NeedleScent', 'Pseudotsuga'), ('ConeSeed', 'ConeSeedHint')])

In [7]:
# Listing all Independencies
#conifer_model.get_independencies()

## Inference

In [8]:
conifer_infer = VariableElimination(conifer_model)

### Juniper

In [9]:
q = conifer_infer.query(variables=["BerryTree"], evidence={
    "FruitType": "Berry",
    "Color": "Blue"
})

print(q)

+--------------------+------------------+
| BerryTree          |   phi(BerryTree) |
| BerryTree(Yew)     |           0.0500 |
+--------------------+------------------+
| BerryTree(Juniper) |           0.9000 |
+--------------------+------------------+
| BerryTree(Other)   |           0.0500 |
+--------------------+------------------+


### Yew

In [10]:
p = conifer_infer.query(variables=["BerryTree"], evidence={
    "FruitType": "Berry",
    "Color": "Red"
})

print(p)

+--------------------+------------------+
| BerryTree          |   phi(BerryTree) |
| BerryTree(Yew)     |           0.9000 |
+--------------------+------------------+
| BerryTree(Juniper) |           0.0300 |
+--------------------+------------------+
| BerryTree(Other)   |           0.0700 |
+--------------------+------------------+


### Fir

In [11]:
p = conifer_infer.query(variables=["FirSpruce"], evidence={
    "FruitType": "Cone",
    "ConeShape": "Elongated",
    "Orientation": "Upwards"
})

print(p)

+-------------------+------------------+
| FirSpruce         |   phi(FirSpruce) |
| FirSpruce(Fir)    |           0.9124 |
+-------------------+------------------+
| FirSpruce(Spruce) |           0.0292 |
+-------------------+------------------+
| FirSpruce(Other)  |           0.0584 |
+-------------------+------------------+


### Spruce

In [12]:
p = conifer_infer.query(variables=["FirSpruce"], evidence={
    "FruitType": "Cone",
    "ConeShape": "Elongated",
    "Orientation": "Downwards"
})

print(p)

+-------------------+------------------+
| FirSpruce         |   phi(FirSpruce) |
| FirSpruce(Fir)    |           0.0388 |
+-------------------+------------------+
| FirSpruce(Spruce) |           0.9028 |
+-------------------+------------------+
| FirSpruce(Other)  |           0.0584 |
+-------------------+------------------+


### Larch

In [13]:
p = conifer_infer.query(variables=["NeedleCountHint"], evidence={
    "FruitType": "Cone",
    "ConeShape": "Oval",
    "NeedleCount": "MoreThanTwenty"
})

print(p)

+--------------------------------+------------------------+
| NeedleCountHint                |   phi(NeedleCountHint) |
| NeedleCountHint(Larch)         |                 0.7800 |
+--------------------------------+------------------------+
| NeedleCountHint(TwoNeedlePine) |                 0.2200 |
+--------------------------------+------------------------+


### Pinus Mugo

In [14]:
p = conifer_infer.query(variables=["Pines"], evidence={
    "FruitType": "Cone",
    "ConeShape": "Oval",
    "NeedleCount": "AlwaysTwo",
    "PlantShape": "Bush",
    "ConeSize": "Small"
})

print(p)

+------------------+--------------+
| Pines            |   phi(Pines) |
| Pines(PinusMugo) |       0.6448 |
+------------------+--------------+
| Pines(ScotsPine) |       0.1200 |
+------------------+--------------+
| Pines(BlackPine) |       0.0633 |
+------------------+--------------+
| Pines(Other)     |       0.1719 |
+------------------+--------------+


### Scots Pine

In [15]:
p = conifer_infer.query(variables=["Pines"], evidence={
    "FruitType": "Cone",
    "ConeShape": "Oval",
    "NeedleCount": "AlwaysTwo",
    "ConeSize": "Small"
})

print(p)

+------------------+--------------+
| Pines            |   phi(Pines) |
| Pines(PinusMugo) |       0.1410 |
+------------------+--------------+
| Pines(ScotsPine) |       0.6193 |
+------------------+--------------+
| Pines(BlackPine) |       0.1128 |
+------------------+--------------+
| Pines(Other)     |       0.1269 |
+------------------+--------------+


### Black Pine

In [16]:
p = conifer_infer.query(variables=["Pines"], evidence={
    "FruitType": "Cone",
    "ConeShape": "Oval",
    "NeedleCount": "AlwaysTwo",
    "ConeSize": "Big"
})

print(p)

+------------------+--------------+
| Pines            |   phi(Pines) |
| Pines(PinusMugo) |       0.0612 |
+------------------+--------------+
| Pines(ScotsPine) |       0.1503 |
+------------------+--------------+
| Pines(BlackPine) |       0.5976 |
+------------------+--------------+
| Pines(Other)     |       0.1909 |
+------------------+--------------+


### Pseudotsuga

In [17]:
p = conifer_infer.query(variables=["Pseudotsuga"], evidence={
    "FruitType": "Cone",
    "NeedleScent": "Lemon"
})

print(p)

+------------------+--------------------+
| Pseudotsuga      |   phi(Pseudotsuga) |
| Pseudotsuga(Yes) |             0.8800 |
+------------------+--------------------+
| Pseudotsuga(No)  |             0.1200 |
+------------------+--------------------+


### Pinus Cembra

In [18]:
p = conifer_infer.query(variables=["ConeSeedHint"], evidence={
    "NeedleCount": "AlwaysFive",
    "ConeSeed": "Nut"
})

print(p)

+----------------------------+---------------------+
| ConeSeedHint               |   phi(ConeSeedHint) |
| ConeSeedHint(PinusCembra)  |              0.8200 |
+----------------------------+---------------------+
| ConeSeedHint(PinusStrobus) |              0.1800 |
+----------------------------+---------------------+


### Pinus Strobus

In [19]:
p = conifer_infer.query(variables=["ConeSeedHint"], evidence={
    "NeedleCount": "AlwaysFive",
    "ConeSeed": "Wing"
})

print(p)

+----------------------------+---------------------+
| ConeSeedHint               |   phi(ConeSeedHint) |
| ConeSeedHint(PinusCembra)  |              0.1800 |
+----------------------------+---------------------+
| ConeSeedHint(PinusStrobus) |              0.8200 |
+----------------------------+---------------------+
