In [1]:
import pandas as pd
import torch.nn as nn
import pickle
import torch
from torchvision import models
from torchvision.models import detection, resnet50, ResNet50_Weights
import os
import numpy as np
import cv2
from torchvision import transforms
import pymc3 as pm
import theano.tensor as tt
from sklearn.preprocessing import LabelEncoder
import scipy

In [2]:
# Settings shared by the big-model and small-model sections of this notebook.
CONFIGS = dict(
    # Run on the GPU when torch can see one, otherwise fall back to the CPU.
    DEVICE="cuda" if torch.cuda.is_available() else "cpu",
    # ImageNet mean and standard deviation, passed to transforms.Normalize.
    IMG_MEAN=[0.485, 0.456, 0.406],
    IMG_STD=[0.229, 0.224, 0.225],
    # Switch to enable/disable MC Dropout for confidence score.
    MC_DROPOUT_ENABLED=False,
    NUM_DROPOUT_RUNS=3,
    CONFIDENCE_THRESHOLD=0,
    # Edge length (pixels) of the square each image is resized to before scoring.
    BIG_MODEL_IMG_SIZE=320,
    SMALL_MODEL_IMG_SIZE=60,
    # Mean of the normal distribution that fills in scores for classes a model
    # has no output for.
    MEAN_PRIOR=-15,
)

# Big model

## Model loading

In [3]:
class MultiHeadResNet_BigModel(nn.Module):
    """ResNet-50 backbone feeding five parallel linear heads.

    Four heads classify product type, weight band, halal status and health
    status; the fifth head emits four values for a bounding box.

    Args:
        num_classes_prdtype: Output width of the product-type head.
        num_classes_weight: Output width of the weight head.
        num_classes_halal: Output width of the halal-status head.
        num_classes_healthy: Output width of the health-status head.
    """

    def __init__(self, num_classes_prdtype, num_classes_weight, num_classes_halal, num_classes_healthy):
        super().__init__()
        backbone = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        feature_dim = backbone.fc.in_features
        # Replace the stock classifier with a pass-through so the backbone
        # returns the features that used to feed it.
        backbone.fc = nn.Identity()
        self.base_model = backbone

        # One linear layer per prediction target, all reading the same features.
        self.fc_prdtype = nn.Linear(feature_dim, num_classes_prdtype)
        self.fc_weight = nn.Linear(feature_dim, num_classes_weight)
        self.fc_halal = nn.Linear(feature_dim, num_classes_halal)
        self.fc_healthy = nn.Linear(feature_dim, num_classes_healthy)
        self.fc_bbox = nn.Linear(feature_dim, 4)

    def forward(self, x):
        """Return ``(prdtype, weight, halal, healthy, box)`` head outputs for batch ``x``."""
        features = self.base_model(x)
        return (
            self.fc_prdtype(features),
            self.fc_weight(features),
            self.fc_halal(features),
            self.fc_healthy(features),
            self.fc_bbox(features),
        )

    
# load label encoder 
def load_label_encoder_big_model(model_dir="../big_model"):
    """Load the four pickled label encoders saved for the big model.

    Args:
        model_dir: Directory holding ``le_prdtype.pickle``, ``le_weight.pickle``,
            ``le_halal.pickle`` and ``le_healthy.pickle``. The default is the
            location this notebook has always read from.

    Returns:
        Tuple ``(le_prdtype, le_weight, le_halal, le_healthy)`` of the
        unpickled objects, in that order.

    Raises:
        FileNotFoundError: If any of the four pickle files does not exist.
    """
    # SECURITY: unpickling runs arbitrary code stored in the file. Only point
    # this at encoder files written by a trusted training run.
    encoders = []
    for head in ("prdtype", "weight", "halal", "healthy"):
        # The context manager closes each handle; the bare open(...).read()
        # used previously left all four files open until garbage collection.
        with open(f"{model_dir}/le_{head}.pickle", "rb") as handle:
            encoders.append(pickle.loads(handle.read()))
    return tuple(encoders)

# Module-level encoders for the big model; load_model() below reads them to
# size the prediction heads. The same four names are rebound to the small
# model's encoders later in the notebook, so these cells are order-dependent.
le_prdtype, le_weight, le_halal, le_healthy = load_label_encoder_big_model()

# Load the trained MultiHeadResNet model
def load_model():
    """Build the big multi-head ResNet, restore its trained weights and return it.

    Head sizes come from the module-level ``le_*`` label encoders; the device
    comes from ``CONFIGS['DEVICE']``.

    Returns:
        A ``MultiHeadResNet_BigModel`` moved to the configured device and
        switched to eval mode.

    Raises:
        FileNotFoundError: If '../big_model/multi_head_model.pth' does not exist.
    """
    # Each head needs one output per class known to the matching encoder.
    head_sizes = {
        "num_classes_prdtype": len(le_prdtype.classes_),
        "num_classes_weight": len(le_weight.classes_),
        "num_classes_halal": len(le_halal.classes_),
        "num_classes_healthy": len(le_healthy.classes_),
    }
    net = MultiHeadResNet_BigModel(**head_sizes)

    model_path = '../big_model/multi_head_model.pth'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    state = torch.load(model_path, map_location=CONFIGS['DEVICE'])
    net.load_state_dict(state)
    net.to(CONFIGS['DEVICE'])
    net.eval()
    return net

# Build the big model now, while load_model and the le_* globals still refer to
# the big-model versions; both are redefined in the small-model section below.
big_model = load_model()

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


## Scoring on main imgs

In [4]:
# Load the big-model results for the main images from CSV and preview them.
main_imgs_results_big_model = pd.read_csv("../big_model/main_imgs_results_big_model.csv")
main_imgs_results_big_model.head()

Unnamed: 0,Filename,CorrectTotalLabel,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,20231222_0151,OtherDriedFood_100-199g_NonHalal_NonHealthy,-2.706577,-2.922772,-1.214439,-1.865561,0.283586,-3.440058,-4.692658,-4.688087,...,-1.259161,0.065835,-3.455903,-0.858137,-2.530261,-2.970333,-3.660007,3.940337,-4.098905,3.530842
1,20231215_output_frame_0189,CornChip_1-99g_NonHalal_NonHealthy,-3.233564,-3.182215,-2.141094,-3.1284,-0.067066,-0.05715,-4.155023,-3.69233,...,-0.663016,3.715246,-3.54175,-1.947755,-2.929745,-3.327941,-2.904644,2.911234,-3.37238,3.274534
2,20231222_0128,OtherDriedFood_100-199g_NonHalal_NonHealthy,-2.389156,-2.581268,-0.982575,-1.549859,-0.02521,-3.036732,-4.422239,-4.099058,...,-1.157517,-0.187478,-3.295898,-0.533032,-2.493292,-2.85114,-3.369161,3.539039,-3.949286,3.203784
3,2023_10_25_11_18_47_935674,AdultMilk_1000-1999g_Halal_NonHealthy,7.135806,-1.507035,-0.294488,-2.144132,0.031597,-5.09254,-3.257587,-4.648274,...,-1.969015,-1.605096,1.296935,-2.679279,0.709327,-0.644222,4.02416,-2.919567,-4.004476,4.033983
4,20231222_0869,Pasta_500-599g_Halal_NonHealthy,-3.143818,-2.75022,-2.700432,-2.682302,-1.469182,-4.778887,-4.106943,-2.820279,...,-1.469882,7.507227,-2.209881,-0.109595,-3.094495,-3.523381,3.533983,-2.916756,-3.851676,3.714993


In [5]:
# Rename the two leading columns to 'filepath' and 'label' so this frame uses
# the same identifier names as the frames built later in the notebook.
# The rename is positional: whatever sits in positions 0 and 1 is relabelled.
new_columns = list(main_imgs_results_big_model.columns)
new_columns[0], new_columns[1] = 'filepath', 'label'
main_imgs_results_big_model.columns = new_columns


In [6]:
# Snapshot of the big model's output column names (everything after the two
# identifier columns). It is reused as the header for freshly scored images, so
# it must be taken before any placeholder column is appended to this frame.
big_model_pred_col_name_original = list(main_imgs_results_big_model.columns[2:])

In [7]:
# Load the small-model results from CSV and preview them; the product-type
# columns of this frame are compared against the big model's in the next cells.
new_imgs_results_small_model = pd.read_csv("../small_model/new_imgs_results_small_model.csv")
new_imgs_results_small_model.head()

Unnamed: 0,Filename,CorrectTotalLabel,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,IMG_0722_jpeg.rf.d225be9cf3e9a21b88a9111c79c48...,OtherNoodles_400-499g_NonHalal_NonHealthy,-0.042604,-0.917905,-1.518212,-0.296711,2.354687,-2.620731,-2.82771,-3.824989,...,7.529936,-0.084116,-1.743978,-5.213832,-1.647417,-1.764002,-3.006634,2.80864,-5.266931,5.735337
1,IMG_2125_jpeg.rf.98ceaf3474e8a755edb2d0c969c93...,BiscuitsCrackersCookies_500-599g_NonHalal_NonH...,-3.947608,0.006503,0.566163,-1.790939,9.11619,-4.677173,-5.523167,-5.752591,...,0.353897,10.116268,-3.711791,-4.14294,-1.033993,-4.143282,-3.573488,3.979955,-3.407155,3.641886
2,20231222_0419.jpg,SweetsChocolatesOthers_200-299g_Halal_NonHealthy,-3.898118,-1.853066,-1.576361,-1.745042,0.757353,-1.532205,-6.475076,-2.823888,...,-0.103763,-1.004154,-3.132432,-2.784842,-1.196354,-4.022461,3.352356,-2.715928,-4.029167,5.381159
3,2023_10_25_11_49_41_382262.jpg,BabyMilkPowder_400-499g_Halal_NonHealthy,0.379034,9.634521,0.752696,-3.017066,-0.769694,-3.41203,-2.079742,-1.423823,...,6.357276,-1.178354,-3.272727,-3.53172,3.178313,-0.548229,3.688625,-2.877854,-2.247989,2.201218
4,2023_8_11_12_16_13_156049_png.rf.d1b4db49f97ab...,FlavoredMilk_1-99g_Halal_Healthy,-1.101073,-0.451816,-1.461948,0.367423,0.476725,-4.617911,-2.182668,-2.599943,...,-1.091295,-3.417022,-2.225312,-1.328926,-0.339995,-0.34411,4.364552,-3.306172,3.103263,-0.926259


In [8]:
# Names of every product-type output column the small model produces.
all_prdtypes_new_imgs = list(
    filter(
        lambda name: name.startswith('ProductType'),
        new_imgs_results_small_model.columns,
    )
)

In [9]:
# Product-type columns the small model scores but the big-model frame lacks.
# Built in the small model's column order so the result is the same on every
# run (a set difference would return them in arbitrary order).
new_prdtype = [
    col for col in all_prdtypes_new_imgs
    if col not in main_imgs_results_big_model.columns
]

# Give each missing product type a placeholder score drawn from
# Normal(MEAN_PRIOR, variance 0.1). Every missing type is handled, not only the
# case of exactly one, so several new classes are no longer silently skipped.
# NOTE(review): the draw is unseeded, so the values differ between runs.
for col in new_prdtype:
    main_imgs_results_big_model[col] = np.random.normal(
        loc=CONFIGS["MEAN_PRIOR"],
        scale=np.sqrt(0.1),
        size=main_imgs_results_big_model.shape[0],
    )

main_imgs_results_big_model.head()  # Display the updated DataFrame for verification

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,ProductType_JennyBakery
0,20231222_0151,OtherDriedFood_100-199g_NonHalal_NonHealthy,-2.706577,-2.922772,-1.214439,-1.865561,0.283586,-3.440058,-4.692658,-4.688087,...,0.065835,-3.455903,-0.858137,-2.530261,-2.970333,-3.660007,3.940337,-4.098905,3.530842,-15.709644
1,20231215_output_frame_0189,CornChip_1-99g_NonHalal_NonHealthy,-3.233564,-3.182215,-2.141094,-3.1284,-0.067066,-0.05715,-4.155023,-3.69233,...,3.715246,-3.54175,-1.947755,-2.929745,-3.327941,-2.904644,2.911234,-3.37238,3.274534,-14.892749
2,20231222_0128,OtherDriedFood_100-199g_NonHalal_NonHealthy,-2.389156,-2.581268,-0.982575,-1.549859,-0.02521,-3.036732,-4.422239,-4.099058,...,-0.187478,-3.295898,-0.533032,-2.493292,-2.85114,-3.369161,3.539039,-3.949286,3.203784,-15.631121
3,2023_10_25_11_18_47_935674,AdultMilk_1000-1999g_Halal_NonHealthy,7.135806,-1.507035,-0.294488,-2.144132,0.031597,-5.09254,-3.257587,-4.648274,...,-1.605096,1.296935,-2.679279,0.709327,-0.644222,4.02416,-2.919567,-4.004476,4.033983,-14.632346
4,20231222_0869,Pasta_500-599g_Halal_NonHealthy,-3.143818,-2.75022,-2.700432,-2.682302,-1.469182,-4.778887,-4.106943,-2.820279,...,7.507227,-2.209881,-0.109595,-3.094495,-3.523381,3.533983,-2.916756,-3.851676,3.714993,-14.79223


## Scoring on new imgs

In [10]:
# List of new images to score: file name, combined label and its four parts.
new_imgs_df = pd.read_csv("../small_model/new_imgs_list.csv").reset_index(drop=True)
new_imgs_df.head()

Unnamed: 0,filepath,label,ProductType,Weight,HalalStatus,HealthStatus
0,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
1,5141704785419_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
2,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy


In [11]:
# Inference-time pipeline shared by both scoring loops: convert to a PIL image,
# back to a tensor, then normalise with the configured mean and std.
transforms_test = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(CONFIGS['IMG_MEAN'], CONFIGS['IMG_STD']),
    ]
)

In [12]:
# Score every new image with the big model and keep the outputs of the four
# classification heads; the bounding-box head is ignored.
big_model_img_size = (CONFIGS['BIG_MODEL_IMG_SIZE'], CONFIGS['BIG_MODEL_IMG_SIZE'])
scored_rows = []  # one entry per image: [filepath, label, *head outputs]

for filepath, label in zip(new_imgs_df['filepath'], new_imgs_df['label']):
    image_path = "../small_model/new_imgs/" + filepath
    frame = cv2.imread(image_path)
    if frame is None:
        # cv2.imread reports a missing or undecodable file by returning None;
        # stop here with the path instead of failing obscurely in cvtColor.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Preprocessing steps
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = cv2.resize(frame, big_model_img_size)
    frame = frame.transpose((2, 0, 1))  # channels-last -> channels-first
    # NOTE(review): the tensor is cast to float while still holding 0-255 values
    # and is then passed through ToPILImage, which presumably treats float input
    # as being in [0, 1]. Left untouched here; confirm it matches the
    # preprocessing used at training time before changing it.
    frame = torch.from_numpy(frame).float()
    frame = transforms_test(frame).unsqueeze(0).to(CONFIGS['DEVICE'])

    # Perform prediction
    with torch.no_grad():
        out1, out2, out3, out4, _ = big_model(frame)

    # Flatten the four head outputs into one row, in head order.
    prediction_row = [filepath, label]
    for head_output in (out1, out2, out3, out4):
        prediction_row.extend(head_output.cpu().numpy().flatten())
    scored_rows.append(prediction_row)


# Header: the two identifier columns followed by the big model's output column
# names captured from the main-image results.
column_names = ['filepath', 'label']
column_names += big_model_pred_col_name_original

# Build the DataFrame under its own name instead of rebinding the row list.
new_imgs_results_big_model = pd.DataFrame(scored_rows, columns=column_names)
new_imgs_results_big_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.039913,-1.142112,-2.089395,-2.328363,-0.22226,-3.591941,-1.837839,-3.071156,...,0.525842,0.025882,-2.101103,-0.797072,-1.364065,-2.148964,4.077797,-3.363919,1.234696,-1.354995
1,5141704785419_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.231967,-0.849782,-0.971154,-2.580669,-0.117418,-3.130956,-2.25668,-3.139605,...,-0.271013,-0.138061,-1.910356,-0.641025,-1.320471,-1.997742,3.420116,-2.79764,0.566004,-0.747712
2,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.013698,-0.999261,-1.92391,-3.230182,0.373773,-3.073135,-1.915508,-2.893718,...,0.532749,-0.740819,-1.070358,-0.64976,-1.268365,-1.791995,3.33759,-2.694616,0.84793,-1.066223
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.591283,-0.559135,-1.444255,-3.373103,-0.162989,-3.930771,-2.704357,-2.371563,...,0.617335,-1.164889,0.938439,-1.136923,-1.990635,-1.283677,2.217344,-1.924003,-1.514813,1.182325
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.173288,-1.146693,-2.45628,-2.941579,0.184543,-3.305616,-2.411296,-2.792666,...,-0.431126,1.034137,-1.321176,-1.04378,-1.338015,-1.725011,2.448851,-2.084122,-0.886382,0.851273


In [13]:
# Sanity check: one row per new image; columns are the two identifiers plus
# one per big-model output column.
new_imgs_results_big_model.shape

(10, 59)

In [14]:
# Mirror every column that the main-image frame has and this frame lacks, so
# the two frames line up when they are concatenated. Driving this from the
# main frame's columns handles any number of added product types and does not
# depend on `new_prdtype` still holding the value from an earlier cell run.
missing_in_new = [
    col for col in main_imgs_results_big_model.columns
    if col not in new_imgs_results_big_model.columns
]

# Same placeholder distribution as used for the main-image frame.
for col in missing_in_new:
    new_imgs_results_big_model[col] = np.random.normal(
        loc=CONFIGS["MEAN_PRIOR"],
        scale=np.sqrt(0.1),
        size=new_imgs_results_big_model.shape[0],
    )

new_imgs_results_big_model.head()  # Display the updated DataFrame for verification

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,ProductType_JennyBakery
0,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.039913,-1.142112,-2.089395,-2.328363,-0.22226,-3.591941,-1.837839,-3.071156,...,0.025882,-2.101103,-0.797072,-1.364065,-2.148964,4.077797,-3.363919,1.234696,-1.354995,-15.081317
1,5141704785419_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.231967,-0.849782,-0.971154,-2.580669,-0.117418,-3.130956,-2.25668,-3.139605,...,-0.138061,-1.910356,-0.641025,-1.320471,-1.997742,3.420116,-2.79764,0.566004,-0.747712,-14.817607
2,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.013698,-0.999261,-1.92391,-3.230182,0.373773,-3.073135,-1.915508,-2.893718,...,-0.740819,-1.070358,-0.64976,-1.268365,-1.791995,3.33759,-2.694616,0.84793,-1.066223,-15.051154
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.591283,-0.559135,-1.444255,-3.373103,-0.162989,-3.930771,-2.704357,-2.371563,...,-1.164889,0.938439,-1.136923,-1.990635,-1.283677,2.217344,-1.924003,-1.514813,1.182325,-14.569236
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.173288,-1.146693,-2.45628,-2.941579,0.184543,-3.305616,-2.411296,-2.792666,...,1.034137,-1.321176,-1.04378,-1.338015,-1.725011,2.448851,-2.084122,-0.886382,0.851273,-15.425144


In [15]:
# After the placeholder column is added, the column count should match
# main_imgs_results_big_model (checked in the next cell).
new_imgs_results_big_model.shape

(10, 60)

In [16]:
# Column count here must equal that of new_imgs_results_big_model for the
# concatenation below to line up.
main_imgs_results_big_model.shape

(3457, 60)

## All scorings from big model

In [17]:
# Stack main-image rows on top of new-image rows and renumber the index from 0.
all_imgs_results_big_model = pd.concat(
    [main_imgs_results_big_model, new_imgs_results_big_model],
    axis=0,
    ignore_index=True,
)
all_imgs_results_big_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,ProductType_JennyBakery
0,20231222_0151,OtherDriedFood_100-199g_NonHalal_NonHealthy,-2.706577,-2.922772,-1.214439,-1.865561,0.283586,-3.440058,-4.692658,-4.688087,...,0.065835,-3.455903,-0.858137,-2.530261,-2.970333,-3.660007,3.940337,-4.098905,3.530842,-15.709644
1,20231215_output_frame_0189,CornChip_1-99g_NonHalal_NonHealthy,-3.233564,-3.182215,-2.141094,-3.1284,-0.067066,-0.05715,-4.155023,-3.69233,...,3.715246,-3.54175,-1.947755,-2.929745,-3.327941,-2.904644,2.911234,-3.37238,3.274534,-14.892749
2,20231222_0128,OtherDriedFood_100-199g_NonHalal_NonHealthy,-2.389156,-2.581268,-0.982575,-1.549859,-0.02521,-3.036732,-4.422239,-4.099058,...,-0.187478,-3.295898,-0.533032,-2.493292,-2.85114,-3.369161,3.539039,-3.949286,3.203784,-15.631121
3,2023_10_25_11_18_47_935674,AdultMilk_1000-1999g_Halal_NonHealthy,7.135806,-1.507035,-0.294488,-2.144132,0.031597,-5.09254,-3.257587,-4.648274,...,-1.605096,1.296935,-2.679279,0.709327,-0.644222,4.02416,-2.919567,-4.004476,4.033983,-14.632346
4,20231222_0869,Pasta_500-599g_Halal_NonHealthy,-3.143818,-2.75022,-2.700432,-2.682302,-1.469182,-4.778887,-4.106943,-2.820279,...,7.507227,-2.209881,-0.109595,-3.094495,-3.523381,3.533983,-2.916756,-3.851676,3.714993,-14.79223


In [18]:
# The last rows should be the newly scored images appended by the concat above.
all_imgs_results_big_model.tail()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,ProductType_JennyBakery
3462,5181704785427_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.085035,-0.69423,-1.524884,-2.258776,-0.523402,-3.853667,-2.045266,-2.82182,...,-0.60893,-1.379059,-0.607297,-0.91048,-1.412673,3.083356,-2.480405,0.093135,-0.154447,-14.92466
3463,5191704785428_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.781084,-0.384382,-1.344725,-2.581363,-0.442115,-3.333997,-1.991416,-3.166846,...,1.277482,-1.916547,-0.338003,-0.876995,-1.690415,3.801517,-3.136102,0.626521,-0.60006,-14.891737
3464,5201704785430_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.312832,-1.113134,-2.11366,-2.963626,-0.180593,-3.317151,-2.546812,-3.186569,...,1.074274,-2.676588,-1.309397,-1.522568,-2.150415,2.411772,-1.994419,-0.605169,0.582726,-14.524239
3465,5211704785432_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.179433,-0.726598,-2.506155,-2.856576,-0.116809,-3.409679,-1.881398,-2.863961,...,1.071388,-1.775874,-0.696511,-1.079474,-1.285486,3.013429,-2.538661,0.530294,-0.569061,-14.201457
3466,5221704785433_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.128102,-0.579057,-2.348396,-2.582589,-0.019272,-3.056315,-1.428041,-2.344664,...,-0.28426,-1.116279,-0.539162,-0.89878,-1.04927,3.385865,-2.904839,0.386792,-0.330786,-14.979305


In [19]:
# Row count should be the sum of the main-image and new-image frames.
all_imgs_results_big_model.shape

(3467, 60)

In [20]:
# Persist the combined big-model scores into the current working directory.
# NOTE(review): index=True also writes the row index as an unnamed first column;
# a reader that renames columns by position (as cells in this notebook do)
# would have to account for it. Confirm against whatever consumes this file.
all_imgs_results_big_model.to_csv("all_imgs_results_big_model.csv", index=True)

# Small model

## Model loading

In [21]:
class MultiHeadResNet_SmallModel(nn.Module):
    """ResNet-18 backbone feeding four parallel linear classification heads.

    The heads classify product type, weight band, halal status and health
    status. Unlike the big model there is no bounding-box head.

    Args:
        num_classes_prdtype: Output width of the product-type head.
        num_classes_weight: Output width of the weight head.
        num_classes_halal: Output width of the halal-status head.
        num_classes_healthy: Output width of the health-status head.
    """

    def __init__(self, num_classes_prdtype, num_classes_weight, num_classes_halal, num_classes_healthy):
        super(MultiHeadResNet_SmallModel, self).__init__()
        # `pretrained=True` is deprecated since torchvision 0.13 and emits a
        # warning. IMAGENET1K_V1 is the checkpoint that flag selected, so the
        # initial weights are unchanged.
        self.base_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_ftrs = self.base_model.fc.in_features
        # Replace the stock classifier with a pass-through so the backbone
        # returns the features that used to feed it.
        self.base_model.fc = nn.Identity()

        # Define custom fully connected layers for each prediction head
        self.fc_prdtype = nn.Linear(num_ftrs, num_classes_prdtype)
        self.fc_weight = nn.Linear(num_ftrs, num_classes_weight)
        self.fc_halal = nn.Linear(num_ftrs, num_classes_halal)
        self.fc_healthy = nn.Linear(num_ftrs, num_classes_healthy)

    def forward(self, x):
        """Return ``(prdtype, weight, halal, healthy)`` head outputs for batch ``x``."""
        x = self.base_model(x)
        prdtype = self.fc_prdtype(x)
        weight = self.fc_weight(x)
        halal = self.fc_halal(x)
        healthy = self.fc_healthy(x)
        return prdtype, weight, halal, healthy

    
# load label encoder 
def load_label_encoder_small_model(model_dir="../small_model/output"):
    """Load the four pickled label encoders saved for the small model.

    Args:
        model_dir: Directory holding ``le_prdtype.pickle``, ``le_weight.pickle``,
            ``le_halal.pickle`` and ``le_healthy.pickle``. The default is the
            location this notebook has always read from.

    Returns:
        Tuple ``(le_prdtype, le_weight, le_halal, le_healthy)`` of the
        unpickled objects, in that order.

    Raises:
        FileNotFoundError: If any of the four pickle files does not exist.
    """
    # SECURITY: unpickling runs arbitrary code stored in the file. Only point
    # this at encoder files written by a trusted training run.
    encoders = []
    for head in ("prdtype", "weight", "halal", "healthy"):
        # The context manager closes each handle; the bare open(...).read()
        # used previously left all four files open until garbage collection.
        with open(f"{model_dir}/le_{head}.pickle", "rb") as handle:
            encoders.append(pickle.loads(handle.read()))
    return tuple(encoders)

# Rebinds the module-level encoder names, previously holding the big model's
# encoders, to the small model's. From here on load_model() and the column
# naming below see the small-model classes.
le_prdtype, le_weight, le_halal, le_healthy = load_label_encoder_small_model()

# Load the trained MultiHeadResNet model
def load_model():
    """Build the small multi-head ResNet, restore its trained weights and return it.

    This definition replaces the earlier ``load_model`` that built the big
    model. Head sizes come from the module-level ``le_*`` label encoders; the
    device comes from ``CONFIGS['DEVICE']``.

    Returns:
        A ``MultiHeadResNet_SmallModel`` moved to the configured device and
        switched to eval mode.

    Raises:
        FileNotFoundError: If '../small_model/output/multi_head_model.pth'
            does not exist.
    """
    # Each head needs one output per class known to the matching encoder.
    head_sizes = {
        "num_classes_prdtype": len(le_prdtype.classes_),
        "num_classes_weight": len(le_weight.classes_),
        "num_classes_halal": len(le_halal.classes_),
        "num_classes_healthy": len(le_healthy.classes_),
    }
    net = MultiHeadResNet_SmallModel(**head_sizes)

    model_path = '../small_model/output/multi_head_model.pth'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    state = torch.load(model_path, map_location=CONFIGS['DEVICE'])
    net.load_state_dict(state)
    net.to(CONFIGS['DEVICE'])
    net.eval()
    return net
 
# Uses the load_model defined just above (the small-model version) together
# with the small-model encoders bound to the le_* globals.
small_model = load_model()

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


## Scoring on new imgs

In [22]:
# Re-read the list of new images (same file as in the big-model section).
new_imgs_df = pd.read_csv("../small_model/new_imgs_list.csv").reset_index(drop=True)
new_imgs_df.head()

Unnamed: 0,filepath,label,ProductType,Weight,HalalStatus,HealthStatus
0,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
1,5141704785419_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
2,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy


In [23]:
# Small-model results restricted to the images listed in new_imgs_df, with the
# index renumbered from 0. Row order follows the CSV, not new_imgs_df.
new_imgs_results_small_model = (
    pd.read_csv("../small_model/new_imgs_results_small_model.csv")
    .loc[lambda scores: scores.Filename.isin(new_imgs_df.filepath)]
    .reset_index(drop=True)
)
new_imgs_results_small_model.head()

Unnamed: 0,Filename,CorrectTotalLabel,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5191704785428_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.73214,-3.000211,-0.198857,0.194554,-0.865558,-4.121096,-4.387307,-1.700306,...,-1.372936,6.730967,-2.435776,-2.367831,-3.973673,-1.865151,-2.280402,3.703254,-3.360344,2.330563
1,5211704785432_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.367036,-2.281196,-0.743124,-0.448515,-0.53996,-2.552695,-3.838513,-1.861633,...,-0.718885,6.298512,-1.400892,-2.321977,-2.438059,-1.812233,-2.273663,2.617148,-3.586301,2.025738
2,5201704785430_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.268956,-2.461139,-0.739557,-0.349508,-0.425265,-2.53597,-3.361694,-1.725807,...,-0.215318,6.07436,-0.978758,-2.304519,-2.627634,-1.396129,-1.958033,1.95553,-3.441591,2.107739
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.9195,-2.373328,-1.200829,-0.910363,-1.116349,-3.314448,-6.152689,-3.126892,...,-1.2912,8.329418,-0.659951,-3.034494,-3.698917,-3.169655,-2.345009,3.980706,-4.012096,2.409551
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.867137,-1.932994,-0.826453,-0.438431,-0.686115,-2.052685,-3.965569,-1.633217,...,-0.831345,6.307858,-0.90048,-2.531227,-2.455461,-1.456185,-2.266023,2.336849,-4.000704,2.411709


In [24]:
# Sanity check: one row per image in new_imgs_df is expected after filtering.
new_imgs_results_small_model.shape

(10, 60)

In [25]:
# Rename the two leading columns to 'filepath' and 'label' so this frame uses
# the same identifier names as the big-model frames.
# The rename is positional: whatever sits in positions 0 and 1 is relabelled.
new_columns = list(new_imgs_results_small_model.columns)
new_columns[0], new_columns[1] = 'filepath', 'label'
new_imgs_results_small_model.columns = new_columns
new_imgs_results_small_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5191704785428_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.73214,-3.000211,-0.198857,0.194554,-0.865558,-4.121096,-4.387307,-1.700306,...,-1.372936,6.730967,-2.435776,-2.367831,-3.973673,-1.865151,-2.280402,3.703254,-3.360344,2.330563
1,5211704785432_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.367036,-2.281196,-0.743124,-0.448515,-0.53996,-2.552695,-3.838513,-1.861633,...,-0.718885,6.298512,-1.400892,-2.321977,-2.438059,-1.812233,-2.273663,2.617148,-3.586301,2.025738
2,5201704785430_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.268956,-2.461139,-0.739557,-0.349508,-0.425265,-2.53597,-3.361694,-1.725807,...,-0.215318,6.07436,-0.978758,-2.304519,-2.627634,-1.396129,-1.958033,1.95553,-3.441591,2.107739
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.9195,-2.373328,-1.200829,-0.910363,-1.116349,-3.314448,-6.152689,-3.126892,...,-1.2912,8.329418,-0.659951,-3.034494,-3.698917,-3.169655,-2.345009,3.980706,-4.012096,2.409551
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.867137,-1.932994,-0.826453,-0.438431,-0.686115,-2.052685,-3.965569,-1.633217,...,-0.831345,6.307858,-0.90048,-2.531227,-2.455461,-1.456185,-2.266023,2.336849,-4.000704,2.411709


In [26]:
# Columns present in the combined big-model frame but absent from the
# small-model frame. Built in the big-model frame's column order so the result
# is the same on every run (a set difference would return arbitrary order).
new_prdtype = [
    col for col in all_imgs_results_big_model.columns
    if col not in new_imgs_results_small_model.columns
]

# Fill each missing column with placeholder scores drawn from
# Normal(MEAN_PRIOR, variance 0.1); nothing happens when no column is missing.
# NOTE(review): the draw is unseeded, so the values differ between runs.
for col in new_prdtype:
    new_imgs_results_small_model[col] = np.random.normal(
        loc=CONFIGS["MEAN_PRIOR"],
        scale=np.sqrt(0.1),
        size=new_imgs_results_small_model.shape[0],
    )

new_imgs_results_small_model.head()  # Display the updated DataFrame for verification

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5191704785428_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.73214,-3.000211,-0.198857,0.194554,-0.865558,-4.121096,-4.387307,-1.700306,...,-1.372936,6.730967,-2.435776,-2.367831,-3.973673,-1.865151,-2.280402,3.703254,-3.360344,2.330563
1,5211704785432_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.367036,-2.281196,-0.743124,-0.448515,-0.53996,-2.552695,-3.838513,-1.861633,...,-0.718885,6.298512,-1.400892,-2.321977,-2.438059,-1.812233,-2.273663,2.617148,-3.586301,2.025738
2,5201704785430_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.268956,-2.461139,-0.739557,-0.349508,-0.425265,-2.53597,-3.361694,-1.725807,...,-0.215318,6.07436,-0.978758,-2.304519,-2.627634,-1.396129,-1.958033,1.95553,-3.441591,2.107739
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.9195,-2.373328,-1.200829,-0.910363,-1.116349,-3.314448,-6.152689,-3.126892,...,-1.2912,8.329418,-0.659951,-3.034494,-3.698917,-3.169655,-2.345009,3.980706,-4.012096,2.409551
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.867137,-1.932994,-0.826453,-0.438431,-0.686115,-2.052685,-3.965569,-1.633217,...,-0.831345,6.307858,-0.90048,-2.531227,-2.455461,-1.456185,-2.266023,2.336849,-4.000704,2.411709


In [27]:
# Shape after the missing-column fill; unchanged when no column was missing.
new_imgs_results_small_model.shape

(10, 60)

## Scoring on main imgs

In [28]:
# Master list of the main images; the scoring loop below reads its 'filepath'
# and 'label' columns.
main_imgs_master_list = pd.read_csv("../master_list.csv")
main_imgs_master_list.head()

Unnamed: 0,filepath,xmin,ymin,xmax,ymax,label,ProductType,Weight,HalalStatus,HealthStatus,new_camera,tag
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,151,42,497,591,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,88,81,442,567,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,35,34,492,622,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,99,122,428,587,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,103,17,474,592,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,


In [29]:
# Score every main image with the small model and keep the outputs of its
# four classification heads.
# Reload the small-model encoders so the column names below come from them
# regardless of what the le_* globals last held.
le_prdtype, le_weight, le_halal, le_healthy = load_label_encoder_small_model()

small_model_img_size = (CONFIGS['SMALL_MODEL_IMG_SIZE'], CONFIGS['SMALL_MODEL_IMG_SIZE'])
scored_rows = []  # one entry per image: [filepath, label, *head outputs]

for filepath, label in zip(main_imgs_master_list['filepath'], main_imgs_master_list['label']):
    image_path = "../all_images/" + filepath
    frame = cv2.imread(image_path)
    if frame is None:
        # cv2.imread reports a missing or undecodable file by returning None;
        # stop here with the path instead of failing obscurely in cvtColor.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Preprocessing steps
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = cv2.resize(frame, small_model_img_size)
    frame = frame.transpose((2, 0, 1))  # channels-last -> channels-first
    # NOTE(review): the tensor is cast to float while still holding 0-255 values
    # and is then passed through ToPILImage, which presumably treats float input
    # as being in [0, 1]. Left untouched here; confirm it matches the
    # preprocessing used at training time before changing it.
    frame = torch.from_numpy(frame).float()
    frame = transforms_test(frame).unsqueeze(0).to(CONFIGS['DEVICE'])

    # Perform prediction
    with torch.no_grad():
        out1, out2, out3, out4 = small_model(frame)

    # Flatten the four head outputs into one row, in head order.
    prediction_row = [filepath, label]
    for head_output in (out1, out2, out3, out4):
        prediction_row.extend(head_output.cpu().numpy().flatten())
    scored_rows.append(prediction_row)


# Header: identifiers, then one column per encoder class for each head.
# Assumes each head's outputs are ordered like the matching encoder's
# classes_ -- TODO confirm against the training code.
column_names = ['filepath', 'label']
column_names += ['ProductType_' + name for name in le_prdtype.classes_]
column_names += ['Weight_' + name for name in le_weight.classes_]
column_names += ['HalalStatus_' + name for name in le_halal.classes_]
column_names += ['HealthStatus_' + name for name in le_healthy.classes_]


# Build the DataFrame under its own name instead of rebinding the row list.
main_imgs_results_small_model = pd.DataFrame(scored_rows, columns=column_names)
main_imgs_results_small_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,Sugar_400-499g_NonHalal_NonHealthy,-1.953461,-1.542739,-0.997972,0.460374,-2.042936,-2.646236,-3.269565,-0.885203,...,5.276489,-1.422885,-2.510418,-2.427734,1.232989,-1.789387,-0.594374,0.807424,-1.844332,3.114218
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,Sugar_400-499g_NonHalal_NonHealthy,-1.794917,-2.511401,-1.626677,-0.188273,-1.597507,-2.022752,-1.949365,-0.757312,...,3.734361,-2.239303,-2.30619,-1.83063,2.118354,-2.629922,-1.808498,1.737238,-2.04311,2.796489
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,Sugar_400-499g_NonHalal_NonHealthy,-3.244689,-2.878056,-0.071886,-0.601693,1.579423,-4.955077,-2.63188,-0.901233,...,9.777389,-1.973637,-1.529812,-5.475394,0.709039,-2.600585,-3.76091,4.619157,-4.566307,5.265439
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,Sugar_400-499g_NonHalal_NonHealthy,-1.286951,-0.810264,-2.171451,-1.45576,-1.87449,-1.876697,-2.205654,-1.915093,...,3.40488,-2.823619,-1.256304,-1.315832,0.255522,-3.013819,-1.771451,2.683984,-1.932551,2.173832
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,Sugar_400-499g_NonHalal_NonHealthy,-2.4259,-3.212403,-0.505121,0.699909,0.794093,-4.230815,-2.915214,-1.216002,...,6.881946,0.698764,-3.240481,-4.012082,-1.049979,-2.509206,-1.758401,3.187051,-3.176752,4.081141


In [30]:
# (rows, columns) of the small-model result table built in the previous cell.
main_imgs_results_small_model.shape

(3457, 60)

In [31]:
# Add every column that the big-model results have but the small-model results lack,
# filled with draws from a normal distribution centred on CONFIGS["MEAN_PRIOR"]
# (variance 0.1).
# sorted() pins the order in which columns are added (set iteration order is
# arbitrary), and the seeded generator makes the filled values reproducible.
new_prdtype = sorted(set(all_imgs_results_big_model.columns) - set(main_imgs_results_small_model.columns))

# NOTE(review): the visible config cell has no "SEED" entry, so this falls back to 0;
# add CONFIGS["SEED"] to control it.
fill_rng = np.random.default_rng(CONFIGS.get("SEED", 0))
for col in new_prdtype:
    main_imgs_results_small_model[col] = fill_rng.normal(
        loc=CONFIGS["MEAN_PRIOR"],
        scale=np.sqrt(0.1),
        size=main_imgs_results_small_model.shape[0],
    )  # Initialize new columns

main_imgs_results_small_model.head()  # Display the updated DataFrame for verification

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,Sugar_400-499g_NonHalal_NonHealthy,-1.953461,-1.542739,-0.997972,0.460374,-2.042936,-2.646236,-3.269565,-0.885203,...,5.276489,-1.422885,-2.510418,-2.427734,1.232989,-1.789387,-0.594374,0.807424,-1.844332,3.114218
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,Sugar_400-499g_NonHalal_NonHealthy,-1.794917,-2.511401,-1.626677,-0.188273,-1.597507,-2.022752,-1.949365,-0.757312,...,3.734361,-2.239303,-2.30619,-1.83063,2.118354,-2.629922,-1.808498,1.737238,-2.04311,2.796489
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,Sugar_400-499g_NonHalal_NonHealthy,-3.244689,-2.878056,-0.071886,-0.601693,1.579423,-4.955077,-2.63188,-0.901233,...,9.777389,-1.973637,-1.529812,-5.475394,0.709039,-2.600585,-3.76091,4.619157,-4.566307,5.265439
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,Sugar_400-499g_NonHalal_NonHealthy,-1.286951,-0.810264,-2.171451,-1.45576,-1.87449,-1.876697,-2.205654,-1.915093,...,3.40488,-2.823619,-1.256304,-1.315832,0.255522,-3.013819,-1.771451,2.683984,-1.932551,2.173832
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,Sugar_400-499g_NonHalal_NonHealthy,-2.4259,-3.212403,-0.505121,0.699909,0.794093,-4.230815,-2.915214,-1.216002,...,6.881946,0.698764,-3.240481,-4.012082,-1.049979,-2.509206,-1.758401,3.187051,-3.176752,4.081141


In [32]:
# (rows, columns) after the missing-column fill in the previous cell.
main_imgs_results_small_model.shape

(3457, 60)

## All scorings from small model

In [33]:
# Stack main_imgs_results_small_model on top of new_imgs_results_small_model and
# renumber the rows 0..n-1 in the combined table.
all_imgs_results_small_model = pd.concat(
    [main_imgs_results_small_model, new_imgs_results_small_model],
    axis=0,
    ignore_index=True,
)
all_imgs_results_small_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,Sugar_400-499g_NonHalal_NonHealthy,-1.953461,-1.542739,-0.997972,0.460374,-2.042936,-2.646236,-3.269565,-0.885203,...,5.276489,-1.422885,-2.510418,-2.427734,1.232989,-1.789387,-0.594374,0.807424,-1.844332,3.114218
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,Sugar_400-499g_NonHalal_NonHealthy,-1.794917,-2.511401,-1.626677,-0.188273,-1.597507,-2.022752,-1.949365,-0.757312,...,3.734361,-2.239303,-2.30619,-1.83063,2.118354,-2.629922,-1.808498,1.737238,-2.04311,2.796489
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,Sugar_400-499g_NonHalal_NonHealthy,-3.244689,-2.878056,-0.071886,-0.601693,1.579423,-4.955077,-2.63188,-0.901233,...,9.777389,-1.973637,-1.529812,-5.475394,0.709039,-2.600585,-3.76091,4.619157,-4.566307,5.265439
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,Sugar_400-499g_NonHalal_NonHealthy,-1.286951,-0.810264,-2.171451,-1.45576,-1.87449,-1.876697,-2.205654,-1.915093,...,3.40488,-2.823619,-1.256304,-1.315832,0.255522,-3.013819,-1.771451,2.683984,-1.932551,2.173832
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,Sugar_400-499g_NonHalal_NonHealthy,-2.4259,-3.212403,-0.505121,0.699909,0.794093,-4.230815,-2.915214,-1.216002,...,6.881946,0.698764,-3.240481,-4.012082,-1.049979,-2.509206,-1.758401,3.187051,-3.176752,4.081141


In [34]:
# Last rows of the combined table (the rows appended from new_imgs_results_small_model).
all_imgs_results_small_model.tail()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
3462,5221704785433_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.907207,-2.304957,-0.337367,-0.218335,-0.591471,-2.655747,-3.853101,-1.22437,...,-0.53311,6.911612,-1.540647,-2.023374,-2.235946,-1.670889,-2.494186,2.710708,-3.854862,2.337982
3463,5181704785427_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.292028,-2.185293,0.262104,0.405771,-0.62938,-4.159621,-4.812106,-1.642707,...,-0.389206,8.422403,-1.084105,-2.937346,-3.019527,-2.680791,-2.630916,3.818831,-3.281815,2.328891
3464,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.015124,-2.114432,0.85622,-0.371208,-2.311827,-2.53076,-3.636174,-3.03951,...,-1.151676,3.027241,-1.234231,-0.329523,-3.034024,-2.849426,-1.300969,1.714323,-2.70315,1.817381
3465,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.440625,-1.464343,-0.454979,-0.059144,-1.00937,-2.70346,-3.949614,-1.50936,...,-1.283507,5.098607,-2.437192,-1.084394,-2.959121,-2.58699,-1.724097,2.619304,-4.050291,1.825536
3466,5141704785419_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.787833,-1.96039,0.144495,-0.061411,-1.025943,-1.64722,-3.985024,-1.873846,...,-1.052803,3.858581,-1.255909,-1.349741,-2.695801,-1.702375,-1.558985,1.279799,-3.167977,2.474307


In [35]:
# (rows, columns) of the combined small-model result table.
all_imgs_results_small_model.shape

(3467, 60)

In [36]:
# Persist the combined small-model scores to the working directory.
# index=True also writes the (renumbered) row index as the first, unnamed CSV column.
all_imgs_results_small_model.to_csv("all_imgs_results_small_model.csv", index=True)

# Bayesian model

In [37]:
# Names of all columns in the small-model results that carry the 'ProductType_' prefix.
prdtype_cols = [
    column_name
    for column_name in all_imgs_results_small_model.columns
    if column_name.startswith('ProductType_')
]

In [38]:
# Keep the true label plus the product-type columns from each model's results.
# Rows are ordered by (label, filepath): 'label' alone is not unique, so sorting on it
# only would leave the relative order of same-label rows up to the sort algorithm, and
# the two frames could then hold different images at the same row position.
# NOTE(review): this pairs rows image-by-image only if both tables store the same
# 'filepath' value for the same image - confirm against how the two tables were built.
all_imgs_results_small_model_prdtype = (
    all_imgs_results_small_model
    .sort_values(by=['label', 'filepath'])[['label'] + prdtype_cols]
    .reset_index(drop=True)
)
all_imgs_results_big_model_prdtype = (
    all_imgs_results_big_model
    .sort_values(by=['label', 'filepath'])[['label'] + prdtype_cols]
    .reset_index(drop=True)
)

In [39]:
# The two frames must list the same true labels in the same row order: the cells below
# score both models' logits against a single `truelabel` vector, row by row.
# (Selecting the labels with the equality mask and calling .all() only tested that the
# matching label strings were truthy, so that form could never fail on a mismatch.)
assert len(all_imgs_results_small_model_prdtype) == len(all_imgs_results_big_model_prdtype), \
    "small-model and big-model result tables differ in length"
assert (all_imgs_results_small_model_prdtype['label'] == all_imgs_results_big_model_prdtype['label']).all(), \
    "small-model and big-model labels are not aligned row by row"

In [40]:
# The product type is the part of 'label' before the first underscore
# (e.g. 'Sugar_400-499g_NonHalal_NonHealthy' -> 'Sugar').
all_imgs_results_small_model_prdtype['label_prdtype'] = (
    all_imgs_results_small_model_prdtype['label'].str.split('_').str.get(0)
)
all_imgs_results_big_model_prdtype['label_prdtype'] = (
    all_imgs_results_big_model_prdtype['label'].str.split('_').str.get(0)
)

In [41]:
# Remove the prefix from column names
all_imgs_results_small_model_prdtype.columns = [col.replace("ProductType_", '') if col.startswith("ProductType_") else col for col in all_imgs_results_small_model_prdtype.columns]
all_imgs_results_big_model_prdtype.columns = [col.replace("ProductType_", '') if col.startswith("ProductType_") else col for col in all_imgs_results_big_model_prdtype.columns]

In [42]:
# Fit the encoder on the product-type names of the big-model frame, then encode that
# same column to obtain the integer ground-truth vector.
prdtype_label_encoder = LabelEncoder()
prdtype_label_encoder.fit(all_imgs_results_big_model_prdtype['label_prdtype'])
truelabel = prdtype_label_encoder.transform(all_imgs_results_big_model_prdtype['label_prdtype'])

In [43]:
# Assuming 'category_names' is the list of unique category names in the order they appear in logitscoresA
category_names = list(all_imgs_results_small_model_prdtype['label_prdtype'].unique())
category_to_encoded = {name: prdtype_label_encoder.transform([name])[0] for name in category_names}

# Reorder columns of logitscoresA and logitscoresB to match the order of encoded labels
ordered_columns = [category_names[i] for i in prdtype_label_encoder.transform(category_names)]
logitscoresA = all_imgs_results_big_model_prdtype[ordered_columns].values
logitscoresB = all_imgs_results_small_model_prdtype[ordered_columns].values


In [44]:
# First rows of the big-model result table, for comparison with the small-model table.
all_imgs_results_big_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,ProductType_JennyBakery
0,20231222_0151,OtherDriedFood_100-199g_NonHalal_NonHealthy,-2.706577,-2.922772,-1.214439,-1.865561,0.283586,-3.440058,-4.692658,-4.688087,...,0.065835,-3.455903,-0.858137,-2.530261,-2.970333,-3.660007,3.940337,-4.098905,3.530842,-15.709644
1,20231215_output_frame_0189,CornChip_1-99g_NonHalal_NonHealthy,-3.233564,-3.182215,-2.141094,-3.1284,-0.067066,-0.05715,-4.155023,-3.69233,...,3.715246,-3.54175,-1.947755,-2.929745,-3.327941,-2.904644,2.911234,-3.37238,3.274534,-14.892749
2,20231222_0128,OtherDriedFood_100-199g_NonHalal_NonHealthy,-2.389156,-2.581268,-0.982575,-1.549859,-0.02521,-3.036732,-4.422239,-4.099058,...,-0.187478,-3.295898,-0.533032,-2.493292,-2.85114,-3.369161,3.539039,-3.949286,3.203784,-15.631121
3,2023_10_25_11_18_47_935674,AdultMilk_1000-1999g_Halal_NonHealthy,7.135806,-1.507035,-0.294488,-2.144132,0.031597,-5.09254,-3.257587,-4.648274,...,-1.605096,1.296935,-2.679279,0.709327,-0.644222,4.02416,-2.919567,-4.004476,4.033983,-14.632346
4,20231222_0869,Pasta_500-599g_Halal_NonHealthy,-3.143818,-2.75022,-2.700432,-2.682302,-1.469182,-4.778887,-4.106943,-2.820279,...,7.507227,-2.209881,-0.109595,-3.094495,-3.523381,3.533983,-2.916756,-3.851676,3.714993,-14.79223


In [45]:
# big model accuracy - total
pred_big_model_prdtype = np.argmax(logitscoresA, axis=1)
sum(pred_big_model_prdtype == truelabel) / len(truelabel)

0.9786558984713009

In [46]:
# small model accuracy - total
pred_small_model_prdtype = np.argmax(logitscoresB, axis=1)
sum(pred_small_model_prdtype == truelabel) / len(truelabel)

0.6509950966253245

In [47]:
# big model accuracy - new imgs
indices = np.where(truelabel == category_to_encoded['JennyBakery'])
sum(pred_big_model_prdtype[indices] == truelabel[indices]) / len(indices[0].tolist())

0.0

In [48]:
# small model accuracy - new imgs
indices = np.where(truelabel == category_to_encoded['JennyBakery'])
sum(pred_small_model_prdtype[indices] == truelabel[indices]) / len(indices[0].tolist())

1.0

In [49]:
# Number of encoded ground-truth labels (one per row of the logit matrices).
len(truelabel)

3467

In [50]:
# Number of columns in the big-model logit matrix (one per entry of ordered_columns).
logitscoresA.shape[1]

42

In [51]:
# Column order used for both logit matrices; position j is the class encoded as j.
ordered_columns

['AdultMilk',
 'BabyMilkPowder',
 'Babyfood',
 'BeehoonVermicelliMeesua',
 'BiscuitsCrackersCookies',
 'Book',
 'BreakfastCerealsCornflakes',
 'CannedPacketCreamersSweet',
 'CerealBeveragePowder',
 'ChilliSauce',
 'Coffee',
 'CornChip',
 'FlavoredMilk',
 'Flour',
 'FruitJuice',
 'HoneyOtherSpreads',
 'InstantMeals',
 'InstantNoodles',
 'JennyBakery',
 'Kaya',
 'MaternalMilkPowder',
 'MiloPowder',
 'NutellaChocolate',
 'Nuts',
 'Oil',
 'OtherBakingNeeds',
 'OtherDriedFood',
 'OtherNoodles',
 'OtherSauceDressingSoupbasePaste',
 'Pasta',
 'Peanutbutter',
 'PotatoSticks',
 'PotatochipsKeropok',
 'RiceBrownOthers',
 'RolledOatsInstantOatmeal',
 'Salt',
 'Sardines',
 'SoftDrinks',
 'Sugar',
 'SweetsChocolatesOthers',
 'TeaPowderLeaves',
 'WetWiper']

In [52]:
# NOTE(review): this whole cell is commented out, yet the cells after it read
# `infer_labels`, `N`, `missingidx` and `ppc`, which in the visible part of the notebook
# are assigned only here. On a fresh kernel (Restart & Run All) those cells raise
# NameError; re-enable this cell or remove the dependent cells.
# NOTE(review): `dtype=np.int` below is a deprecated NumPy alias; use the builtin `int`
# when re-enabling.

# import pymc3 as pm
# import theano.tensor as tt
# import numpy as np
# import scipy.stats

# # Sample data setup (replace with your actual data)
# # logitscoresA and logitscoresB are matrices of logit scores for each category from classifiers A and B
# # truelabel is an already existing 1D array of integers representing the true labels
# indices = [np.random.choice(100, 3, replace=False)]  # Replace with your indices for missing data

# N = len(truelabel)
# L = logitscoresA.shape[1]
# missingidx = indices[0].tolist()  # Indices of missing data

# # Initialize truelabel_with_missing with the original truelabel and set missing indices to -1
# truelabel_with_missing = np.array(truelabel, dtype=np.int)
# truelabel_with_missing[missingidx] = -1

# # Mask the missing values
# masked_truelabel = np.ma.masked_where(truelabel_with_missing == -1, truelabel_with_missing)

# with pm.Model() as model:
#     # Priors
#     muA1 = pm.Normal('muA1', mu=0, sigma=10)
#     muA0 = pm.Normal('muA0', mu=0, sigma=10)
#     sigmaA = pm.Uniform('sigmaA', lower=0.01, upper=1.0)
#     muB1 = pm.Normal('muB1', mu=0, sigma=10)
#     muB0 = pm.Normal('muB0', mu=0, sigma=10)
#     sigmaB = pm.Uniform('sigmaB', lower=0.01, upper=1.0)
#     rho = pm.Uniform('rho', lower=-1, upper=1)
    
#     # Uniform prior over labels
#     labelprob = pm.Dirichlet('labelprob', a=tt.ones(L))

#     # Likelihood
#     muA = pm.math.switch(tt.eq(tt.arange(L), masked_truelabel[:, None]), muA1, muA0)
#     muB = pm.math.switch(tt.eq(tt.arange(L), masked_truelabel[:, None]), muB1, muB0)
    
#     logitscoresA_obs = pm.Normal('logitscoresA_obs', mu=muA, sigma=sigmaA, observed=logitscoresA)
#     logitscoresB_obs = pm.Normal('logitscoresB_obs', mu=muB + rho * (logitscoresA - muA) / sigmaA, sigma=tt.sqrt((1 - rho ** 2) * sigmaB ** 2), observed=logitscoresB)
    
#     # Define the categorical distribution for the true labels
#     truelabel_obs = pm.Categorical('truelabel_obs', p=labelprob, observed=masked_truelabel)

#     # Inference
#     trace = pm.sample(2000, tune=500, cores=1)

#     # Plotting within the model context
#     # az.plot_trace(trace)
#     # plt.show()

#     # Posterior predictive checks
#     ppc = pm.sample_posterior_predictive(trace, var_names=['truelabel_obs'])

# # Process the posterior predictive checks for missing indices
# infer_labels = []
# for idx in missingidx:
#     label_samples = ppc['truelabel_obs'][:, idx]
#     inferred_label = scipy.stats.mode(label_samples).mode[0]
#     infer_labels.append(inferred_label)

# # Output the inferred labels for missing indices
# print("Inferred labels for missing indices:", infer_labels)

In [50]:
# Distinct values among the inferred labels.
# NOTE(review): in the visible part of the notebook `infer_labels` is assigned only in
# commented-out code, so this cell depends on leftover kernel state.
np.unique(infer_labels)

array([ 8, 10, 42])

In [51]:
# Boolean mask over row positions 0..N-1: True where the position is NOT in missingidx.
np.isin(np.arange(N), missingidx, invert=True)

array([ True,  True,  True, ...,  True,  True,  True])

In [52]:
# True encoded labels at the held-out row positions.
# Uses the builtin `int` as dtype: the `np.int` alias is deprecated (see the warning
# this cell used to print) and is absent from newer NumPy releases.
np.array(truelabel, dtype=int)[missingidx]

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  """Entry point for launching an IPython kernel.


array([0, 0, 1])

In [53]:
# Small-model argmax predictions at the row positions listed in missingidx.
pred_small_model_prdtype[missingidx]

array([ 0, 11,  0])

In [54]:
# Big-model argmax predictions at the row positions listed in missingidx.
pred_big_model_prdtype[missingidx]

array([0, 0, 1])

In [55]:
# Row positions treated as missing.
# NOTE(review): `missingidx` is assigned only in commented-out code in the visible part
# of the notebook, so this value comes from leftover kernel state.
missingidx

[53, 24, 66]

In [56]:
# Shape of the 'truelabel_obs' entry of ppc.
# NOTE(review): `ppc` is assigned only in commented-out code in the visible part of the
# notebook, so this cell depends on leftover kernel state.
ppc['truelabel_obs'].shape

(371, 3467)