In [1]:
import pandas as pd
import torch.nn as nn
import pickle
import torch
from torchvision import models
from torchvision.models import detection, resnet50, ResNet50_Weights
import os
import numpy as np
import cv2
from torchvision import transforms
import pymc3 as pm
import theano.tensor as tt
from sklearn.preprocessing import LabelEncoder
import scipy

In [2]:
CONFIGS = {
    # determine the current device and based on that set the pin memory
    # flag
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    # specify ImageNet mean and standard deviation
    "IMG_MEAN": [0.485, 0.456, 0.406],
    "IMG_STD": [0.229, 0.224, 0.225],
    "MC_DROPOUT_ENABLED": False,  # Switch to enable/disable MC Dropout for confidence score
    "NUM_DROPOUT_RUNS": 3,
    "CONFIDENCE_THRESHOLD": 0,
    "BIG_MODEL_IMG_SIZE": 320,
    "SMALL_MODEL_IMG_SIZE": 60,
    "MEAN_PRIOR": -15,
}

# Big model

## Model loading

In [3]:
class MultiHeadResNet_BigModel(nn.Module):
    def __init__(self, num_classes_prdtype, num_classes_weight, num_classes_halal, num_classes_healthy):
        super(MultiHeadResNet_BigModel, self).__init__()
        self.base_model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Identity()

        # Define custom fully connected layers for each prediction head
        self.fc_prdtype = nn.Linear(num_ftrs, num_classes_prdtype)
        self.fc_weight = nn.Linear(num_ftrs, num_classes_weight)
        self.fc_halal = nn.Linear(num_ftrs, num_classes_halal)
        self.fc_healthy = nn.Linear(num_ftrs, num_classes_healthy)
        self.fc_bbox = nn.Linear(num_ftrs, 4)

    def forward(self, x):
        x = self.base_model(x)
        prdtype = self.fc_prdtype(x)
        weight = self.fc_weight(x)
        halal = self.fc_halal(x)
        healthy = self.fc_healthy(x)
        box = self.fc_bbox(x)
        return prdtype, weight, halal, healthy, box

    
# load label encoder 
def load_label_encoder_big_model():
    le_prdtype = pickle.loads(open("../big_model/le_prdtype.pickle", "rb").read())
    le_weight = pickle.loads(open("../big_model/le_weight.pickle", "rb").read())
    le_halal = pickle.loads(open("../big_model/le_halal.pickle", "rb").read())
    le_healthy = pickle.loads(open("../big_model/le_healthy.pickle", "rb").read())
    
    return le_prdtype, le_weight, le_halal, le_healthy

le_prdtype, le_weight, le_halal, le_healthy = load_label_encoder_big_model()

# Load the trained MultiHeadResNet model
def load_model():
    # Verify the number of classes for each label
    num_classes_prdtype = len(le_prdtype.classes_)
    num_classes_weight = len(le_weight.classes_)
    num_classes_halal = len(le_halal.classes_)
    num_classes_healthy = len(le_healthy.classes_)
    # print(num_classes_prdtype)
    # print(num_classes_healthy)

    custom_resnet_model = MultiHeadResNet_BigModel(
        num_classes_prdtype=num_classes_prdtype,
        num_classes_weight=num_classes_weight,
        num_classes_halal=num_classes_halal,
        num_classes_healthy=num_classes_healthy
    )

    model_path = '../big_model/multi_head_model.pth'
    # print("test1")
    if os.path.exists(model_path):
        custom_resnet_model.load_state_dict(torch.load(model_path, map_location=CONFIGS['DEVICE']))
    else:
        raise FileNotFoundError(f"Model file not found: {model_path}")
    # print("test2")
    custom_resnet_model.to(CONFIGS['DEVICE'])
    custom_resnet_model.eval()
    return custom_resnet_model

big_model = load_model()

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


## Scoring on main imgs

In [4]:
main_imgs_results_big_model = pd.read_csv("../big_model/main_imgs_results_big_model.csv")
main_imgs_results_big_model.head()

Unnamed: 0,Filename,CorrectTotalLabel,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,IMG_3342_jpeg.rf.66db10f7864d752fb7976d5e2a0d2...,Nuts_1-99g_Halal_NonHealthy,-2.081311,-2.902017,-1.082917,-1.917268,0.513138,-2.056075,-3.103575,-4.815222,...,-1.519879,-2.695246,-2.423932,-3.589839,-3.066173,-3.762192,2.367532,-3.054208,-3.248664,3.118231
1,IMG_6525_jpeg.rf.4819ad251c1c4c1e1271728375ee7...,HoneyOtherSpreads_100-199g_Halal_NonHealthy,-3.528679,-2.220776,-1.372103,-3.527915,-0.806374,-5.616762,-2.908732,-3.535326,...,1.601176,-1.952339,-4.151642,-2.863147,-3.851684,-4.587035,2.521691,-3.121677,-3.247515,2.68848
2,IMG_6549_jpeg.rf.99b9ea7d28228f9e89972b745ddde...,HoneyOtherSpreads_400-499g_Halal_NonHealthy,-3.442461,-1.955826,-3.917761,-3.072334,-1.013375,-5.279501,-2.093924,-1.98879,...,7.286262,-1.452695,-3.765491,-3.505961,-3.42628,-3.930771,2.649693,-3.132369,-4.636736,4.17741
3,IMG_6835_jpeg.rf.2c219ac88826ee3c0452fff0f1964...,Babyfood_1-99g_Halal_NonHealthy,-0.814915,-0.093799,6.678509,-1.392194,-1.317889,-4.297375,-3.702756,-2.697277,...,-0.794399,-2.359164,-1.809866,-2.472234,-2.895805,-3.457967,3.939218,-4.644306,-4.465299,3.686565
4,Crackers_200-299g_0311_NonHalal_6_png.rf.d36bf...,BiscuitsCrackersCookies_200-299g_NonHalal_NonH...,-2.95755,-1.659869,-1.340207,-0.763084,8.276548,-4.057633,-4.451503,-3.005588,...,-1.978522,0.378352,-1.909216,-2.989162,-1.007948,-3.747228,-2.865301,2.106453,-3.784013,3.310216


In [5]:
# Create a copy of the current column names to a list
new_columns = main_imgs_results_big_model.columns.tolist()

# Modify the first two elements
new_columns[0] = 'filepath'
new_columns[1] = 'label'

# Assign the modified list of column names back to the DataFrame
main_imgs_results_big_model.columns = new_columns


In [6]:
big_model_pred_col_name_original = main_imgs_results_big_model.columns[2:].tolist()

In [7]:
new_imgs_results_small_model = pd.read_csv("../small_model/new_imgs_results_small_model.csv")
new_imgs_results_small_model.head()

Unnamed: 0,Filename,CorrectTotalLabel,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,IMG_3533_jpeg.rf.7c479d2b82aa319692d4c74ba4acf...,Nuts_300-399g_NonHalal_NonHealthy,-2.493247,-2.45811,-0.752643,3.418245,1.410686,-1.484503,-3.675712,-4.696049,...,1.202736,-0.929754,-2.064444,-3.176098,-1.962588,-3.577293,-3.286956,2.287906,-4.930899,2.953754
1,IMG_2282_JPG.rf.55ff8559732da3aeccad420b17c3f9...,BeehoonVermicelliMeesua_200-299g_Halal_NonHealthy,0.131873,0.044445,-0.519548,7.003169,-0.729551,-3.285197,-1.949823,-2.659936,...,-0.359658,0.070565,-1.163762,-2.505035,-1.926098,-2.406361,3.796422,-2.279831,-1.371966,1.770485
2,IMG_20230428_123200_jpg.rf.34cc17cc4cb5a707327...,Sugar_800-899g_NonHalal_NonHealthy,-2.153643,2.140464,-0.861953,-2.464264,0.675603,-1.984505,-1.661865,-4.919771,...,1.070148,-1.785139,-3.493709,-2.736918,6.79421,-0.917141,-3.683269,4.121268,-4.709025,1.865833
3,Crackers_200-299g_0311_Halal_25_png.rf.c55545c...,BiscuitsCrackersCookies_200-299g_Halal_NonHealthy,-1.230329,-2.523535,0.557129,-2.749805,8.67153,-4.183312,-3.959238,-3.233744,...,0.550248,-2.701535,-0.062121,-1.760023,-0.467273,-4.762103,3.189244,-1.663041,-4.48652,1.161967
4,IMG_6362_jpeg.rf.16a8b744b3be105dc728d6b6bfefc...,OtherBakingNeeds_900-999g_Halal_NonHealthy,-0.172884,1.589989,-0.664544,-1.997345,-0.707637,-3.679213,-1.114295,-2.14486,...,-1.347105,-3.294689,-0.024379,-3.188744,1.041424,6.573578,2.014626,-1.763547,-4.737069,4.561938


In [8]:
# Extract column names that start with 'ProductType'
all_prdtypes_new_imgs = [col for col in new_imgs_results_small_model.columns if col.startswith('ProductType')]
# all_prdtypes_new_imgs = [col.split('_', 1)[1] for col in all_prdtypes_new_imgs]
# all_prdtypes_new_imgs

In [9]:
# Check if any name from 'extracted_names' is not in 'df' and add it as a new column
new_prdtype = list(set(all_prdtypes_new_imgs) - set(main_imgs_results_big_model.columns))

if len(new_prdtype)==1:
    main_imgs_results_big_model[new_prdtype[0]] = np.random.normal(loc=CONFIGS["MEAN_PRIOR"], scale=np.sqrt(0.1), size=main_imgs_results_big_model.shape[0])  # Initialize new columns

main_imgs_results_big_model.head()  # Display the updated DataFrame for verificatio

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,IMG_3342_jpeg.rf.66db10f7864d752fb7976d5e2a0d2...,Nuts_1-99g_Halal_NonHealthy,-2.081311,-2.902017,-1.082917,-1.917268,0.513138,-2.056075,-3.103575,-4.815222,...,-1.519879,-2.695246,-2.423932,-3.589839,-3.066173,-3.762192,2.367532,-3.054208,-3.248664,3.118231
1,IMG_6525_jpeg.rf.4819ad251c1c4c1e1271728375ee7...,HoneyOtherSpreads_100-199g_Halal_NonHealthy,-3.528679,-2.220776,-1.372103,-3.527915,-0.806374,-5.616762,-2.908732,-3.535326,...,1.601176,-1.952339,-4.151642,-2.863147,-3.851684,-4.587035,2.521691,-3.121677,-3.247515,2.68848
2,IMG_6549_jpeg.rf.99b9ea7d28228f9e89972b745ddde...,HoneyOtherSpreads_400-499g_Halal_NonHealthy,-3.442461,-1.955826,-3.917761,-3.072334,-1.013375,-5.279501,-2.093924,-1.98879,...,7.286262,-1.452695,-3.765491,-3.505961,-3.42628,-3.930771,2.649693,-3.132369,-4.636736,4.17741
3,IMG_6835_jpeg.rf.2c219ac88826ee3c0452fff0f1964...,Babyfood_1-99g_Halal_NonHealthy,-0.814915,-0.093799,6.678509,-1.392194,-1.317889,-4.297375,-3.702756,-2.697277,...,-0.794399,-2.359164,-1.809866,-2.472234,-2.895805,-3.457967,3.939218,-4.644306,-4.465299,3.686565
4,Crackers_200-299g_0311_NonHalal_6_png.rf.d36bf...,BiscuitsCrackersCookies_200-299g_NonHalal_NonH...,-2.95755,-1.659869,-1.340207,-0.763084,8.276548,-4.057633,-4.451503,-3.005588,...,-1.978522,0.378352,-1.909216,-2.989162,-1.007948,-3.747228,-2.865301,2.106453,-3.784013,3.310216


## Scoring on new imgs

In [10]:
new_imgs_df = pd.read_csv("../small_model/new_imgs_list.csv")

# ADHOC: change the new imgs to existing type
new_imgs_df['label'] = 'AdultMilk_1-99g_Halal_NonHealthy'
new_imgs_df['ProductType'] = 'AdultMilk'
new_imgs_df['Weight'] = '1-99g'
new_imgs_df['HalalStatus'] = 'Halal'
new_imgs_df['HealthStatus'] = 'NonHealthy'

new_imgs_df.reset_index(drop=True, inplace=True)
new_imgs_df.head()

Unnamed: 0,filepath,label,ProductType,Weight,HalalStatus,HealthStatus
0,5131704785418_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,AdultMilk,1-99g,Halal,NonHealthy
1,5141704785419_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,AdultMilk,1-99g,Halal,NonHealthy
2,5151704785420_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,AdultMilk,1-99g,Halal,NonHealthy
3,5161704785422_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,AdultMilk,1-99g,Halal,NonHealthy
4,5171704785423_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,AdultMilk,1-99g,Halal,NonHealthy


In [11]:
transforms_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CONFIGS['IMG_MEAN'], std=CONFIGS['IMG_STD'])
])

In [12]:
new_imgs_results_big_model = []  # List to store the results

for idx, row in new_imgs_df.iterrows():
    image_path = "../small_model/new_imgs/" + row['filepath']
    frame = cv2.imread(image_path)

    # Preprocessing steps
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = cv2.resize(frame, (CONFIGS['BIG_MODEL_IMG_SIZE'], CONFIGS['BIG_MODEL_IMG_SIZE']))
    frame = frame.transpose((2, 0, 1))
    frame = torch.from_numpy(frame).float()
    frame = transforms_test(frame).unsqueeze(0).to(CONFIGS['DEVICE'])

    # Perform prediction
    with torch.no_grad():
        out1, out2, out3, out4, _ = big_model(frame)

    # Extract and store the results
    prediction_row = [row['filepath'], row['label']]
    prediction_row.extend(out1.cpu().numpy().flatten())
    prediction_row.extend(out2.cpu().numpy().flatten())
    prediction_row.extend(out3.cpu().numpy().flatten())
    prediction_row.extend(out4.cpu().numpy().flatten())
    new_imgs_results_big_model.append(prediction_row)


# Define column names for the new DataFrame
column_names = ['filepath', 'label']
column_names += big_model_pred_col_name_original

# Create the DataFrame
new_imgs_results_big_model = pd.DataFrame(new_imgs_results_big_model, columns=column_names)
new_imgs_results_big_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5131704785418_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-3.039913,-1.142112,-2.089395,-2.328363,-0.22226,-3.591941,-1.837839,-3.071156,...,0.525842,0.025882,-2.101103,-0.797072,-1.364065,-2.148964,4.077797,-3.363919,1.234696,-1.354995
1,5141704785419_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-2.231967,-0.849782,-0.971154,-2.580669,-0.117418,-3.130956,-2.25668,-3.139605,...,-0.271013,-0.138061,-1.910356,-0.641025,-1.320471,-1.997742,3.420116,-2.79764,0.566004,-0.747712
2,5151704785420_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-3.013698,-0.999261,-1.92391,-3.230182,0.373773,-3.073135,-1.915508,-2.893718,...,0.532749,-0.740819,-1.070358,-0.64976,-1.268365,-1.791995,3.33759,-2.694616,0.84793,-1.066223
3,5161704785422_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-2.591283,-0.559135,-1.444255,-3.373103,-0.162989,-3.930771,-2.704357,-2.371563,...,0.617335,-1.164889,0.938439,-1.136923,-1.990635,-1.283677,2.217344,-1.924003,-1.514813,1.182325
4,5171704785423_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-2.173288,-1.146693,-2.45628,-2.941579,0.184543,-3.305616,-2.411296,-2.792666,...,-0.431126,1.034137,-1.321176,-1.04378,-1.338015,-1.725011,2.448851,-2.084122,-0.886382,0.851273


In [13]:
new_imgs_results_big_model.shape

(10, 59)

In [14]:
if len(new_prdtype)==1:
    new_imgs_results_big_model[new_prdtype[0]] = np.random.normal(loc=CONFIGS["MEAN_PRIOR"], scale=np.sqrt(0.1), size=new_imgs_results_big_model.shape[0])  # Initialize new columns

new_imgs_results_big_model.head()  # Display the updated DataFrame for verificatio

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5131704785418_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-3.039913,-1.142112,-2.089395,-2.328363,-0.22226,-3.591941,-1.837839,-3.071156,...,0.525842,0.025882,-2.101103,-0.797072,-1.364065,-2.148964,4.077797,-3.363919,1.234696,-1.354995
1,5141704785419_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-2.231967,-0.849782,-0.971154,-2.580669,-0.117418,-3.130956,-2.25668,-3.139605,...,-0.271013,-0.138061,-1.910356,-0.641025,-1.320471,-1.997742,3.420116,-2.79764,0.566004,-0.747712
2,5151704785420_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-3.013698,-0.999261,-1.92391,-3.230182,0.373773,-3.073135,-1.915508,-2.893718,...,0.532749,-0.740819,-1.070358,-0.64976,-1.268365,-1.791995,3.33759,-2.694616,0.84793,-1.066223
3,5161704785422_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-2.591283,-0.559135,-1.444255,-3.373103,-0.162989,-3.930771,-2.704357,-2.371563,...,0.617335,-1.164889,0.938439,-1.136923,-1.990635,-1.283677,2.217344,-1.924003,-1.514813,1.182325
4,5171704785423_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-2.173288,-1.146693,-2.45628,-2.941579,0.184543,-3.305616,-2.411296,-2.792666,...,-0.431126,1.034137,-1.321176,-1.04378,-1.338015,-1.725011,2.448851,-2.084122,-0.886382,0.851273


In [15]:
new_imgs_results_big_model.shape

(10, 59)

In [16]:
main_imgs_results_big_model.shape

(3457, 59)

## All scorings from big model

In [17]:
main_imgs_results_big_model['img_type'] = "existing"
new_imgs_results_big_model['img_type'] = "new"
all_imgs_results_big_model = pd.concat([main_imgs_results_big_model, new_imgs_results_big_model], axis=0)
all_imgs_results_big_model.reset_index(drop=True, inplace=True)
all_imgs_results_big_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,img_type
0,IMG_3342_jpeg.rf.66db10f7864d752fb7976d5e2a0d2...,Nuts_1-99g_Halal_NonHealthy,-2.081311,-2.902017,-1.082917,-1.917268,0.513138,-2.056075,-3.103575,-4.815222,...,-2.695246,-2.423932,-3.589839,-3.066173,-3.762192,2.367532,-3.054208,-3.248664,3.118231,existing
1,IMG_6525_jpeg.rf.4819ad251c1c4c1e1271728375ee7...,HoneyOtherSpreads_100-199g_Halal_NonHealthy,-3.528679,-2.220776,-1.372103,-3.527915,-0.806374,-5.616762,-2.908732,-3.535326,...,-1.952339,-4.151642,-2.863147,-3.851684,-4.587035,2.521691,-3.121677,-3.247515,2.68848,existing
2,IMG_6549_jpeg.rf.99b9ea7d28228f9e89972b745ddde...,HoneyOtherSpreads_400-499g_Halal_NonHealthy,-3.442461,-1.955826,-3.917761,-3.072334,-1.013375,-5.279501,-2.093924,-1.98879,...,-1.452695,-3.765491,-3.505961,-3.42628,-3.930771,2.649693,-3.132369,-4.636736,4.17741,existing
3,IMG_6835_jpeg.rf.2c219ac88826ee3c0452fff0f1964...,Babyfood_1-99g_Halal_NonHealthy,-0.814915,-0.093799,6.678509,-1.392194,-1.317889,-4.297375,-3.702756,-2.697277,...,-2.359164,-1.809866,-2.472234,-2.895805,-3.457967,3.939218,-4.644306,-4.465299,3.686565,existing
4,Crackers_200-299g_0311_NonHalal_6_png.rf.d36bf...,BiscuitsCrackersCookies_200-299g_NonHalal_NonH...,-2.95755,-1.659869,-1.340207,-0.763084,8.276548,-4.057633,-4.451503,-3.005588,...,0.378352,-1.909216,-2.989162,-1.007948,-3.747228,-2.865301,2.106453,-3.784013,3.310216,existing


In [18]:
all_imgs_results_big_model.tail()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,img_type
3462,5181704785427_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-2.085035,-0.69423,-1.524884,-2.258776,-0.523402,-3.853667,-2.045266,-2.82182,...,-0.60893,-1.379059,-0.607297,-0.91048,-1.412673,3.083356,-2.480405,0.093135,-0.154447,new
3463,5191704785428_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-1.781084,-0.384382,-1.344725,-2.581363,-0.442115,-3.333997,-1.991416,-3.166846,...,1.277482,-1.916547,-0.338003,-0.876995,-1.690415,3.801517,-3.136102,0.626521,-0.60006,new
3464,5201704785430_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-2.312832,-1.113134,-2.11366,-2.963626,-0.180593,-3.317151,-2.546812,-3.186569,...,1.074274,-2.676588,-1.309397,-1.522568,-2.150415,2.411772,-1.994419,-0.605169,0.582726,new
3465,5211704785432_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-2.179433,-0.726598,-2.506155,-2.856576,-0.116809,-3.409679,-1.881398,-2.863961,...,1.071388,-1.775874,-0.696511,-1.079474,-1.285486,3.013429,-2.538661,0.530294,-0.569061,new
3466,5221704785433_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,-2.128102,-0.579057,-2.348396,-2.582589,-0.019272,-3.056315,-1.428041,-2.344664,...,-0.28426,-1.116279,-0.539162,-0.89878,-1.04927,3.385865,-2.904839,0.386792,-0.330786,new


In [19]:
all_imgs_results_big_model.shape

(3467, 60)

In [20]:
all_imgs_results_big_model.to_csv("all_imgs_results_big_model.csv", index=True)

# Small model

## Model loading

In [21]:
class MultiHeadResNet_SmallModel(nn.Module):
    def __init__(self, num_classes_prdtype, num_classes_weight, num_classes_halal, num_classes_healthy):
        super(MultiHeadResNet_SmallModel, self).__init__()
        self.base_model = models.resnet18(pretrained=True)
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Identity()

        # Define custom fully connected layers for each prediction head
        self.fc_prdtype = nn.Linear(num_ftrs, num_classes_prdtype)
        self.fc_weight = nn.Linear(num_ftrs, num_classes_weight)
        self.fc_halal = nn.Linear(num_ftrs, num_classes_halal)
        self.fc_healthy = nn.Linear(num_ftrs, num_classes_healthy)

    def forward(self, x):
        x = self.base_model(x)
        prdtype = self.fc_prdtype(x)
        weight = self.fc_weight(x)
        halal = self.fc_halal(x)
        healthy = self.fc_healthy(x)
        return prdtype, weight, halal, healthy

    
# load label encoder 
def load_label_encoder_small_model():
    le_prdtype = pickle.loads(open("../small_model/output/le_prdtype.pickle", "rb").read())
    le_weight = pickle.loads(open("../small_model/output/le_weight.pickle", "rb").read())
    le_halal = pickle.loads(open("../small_model/output/le_halal.pickle", "rb").read())
    le_healthy = pickle.loads(open("../small_model/output/le_healthy.pickle", "rb").read())
    
    return le_prdtype, le_weight, le_halal, le_healthy

le_prdtype, le_weight, le_halal, le_healthy = load_label_encoder_small_model()

# Load the trained MultiHeadResNet model
def load_model():
    # Verify the number of classes for each label
    num_classes_prdtype = len(le_prdtype.classes_)
    num_classes_weight = len(le_weight.classes_)
    num_classes_halal = len(le_halal.classes_)
    num_classes_healthy = len(le_healthy.classes_)
    # print(num_classes_prdtype)
    # print(num_classes_healthy)

    custom_resnet_model = MultiHeadResNet_SmallModel(
        num_classes_prdtype=num_classes_prdtype,
        num_classes_weight=num_classes_weight,
        num_classes_halal=num_classes_halal,
        num_classes_healthy=num_classes_healthy
    )

    model_path = '../small_model/output/multi_head_model.pth'
    # print("test1")
    if os.path.exists(model_path):
        custom_resnet_model.load_state_dict(torch.load(model_path, map_location=CONFIGS['DEVICE']))
    else:
        raise FileNotFoundError(f"Model file not found: {model_path}")
    # print("test2")
    custom_resnet_model.to(CONFIGS['DEVICE'])
    custom_resnet_model.eval()
    return custom_resnet_model
 
small_model = load_model()

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


## Scoring on new imgs

In [22]:
new_imgs_df = pd.read_csv("../small_model/new_imgs_list.csv")
new_imgs_df.reset_index(drop=True, inplace=True)

# ADHOC: change the new imgs to existing type
new_imgs_df['label'] = 'AdultMilk_1-99g_Halal_NonHealthy'
new_imgs_df['ProductType'] = 'AdultMilk'
new_imgs_df['Weight'] = '1-99g'
new_imgs_df['HalalStatus'] = 'Halal'
new_imgs_df['HealthStatus'] = 'NonHealthy'

new_imgs_df.head()

Unnamed: 0,filepath,label,ProductType,Weight,HalalStatus,HealthStatus
0,5131704785418_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,AdultMilk,1-99g,Halal,NonHealthy
1,5141704785419_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,AdultMilk,1-99g,Halal,NonHealthy
2,5151704785420_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,AdultMilk,1-99g,Halal,NonHealthy
3,5161704785422_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,AdultMilk,1-99g,Halal,NonHealthy
4,5171704785423_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,AdultMilk,1-99g,Halal,NonHealthy


In [23]:
new_imgs_results_small_model = pd.read_csv("../small_model/new_imgs_results_small_model.csv")
new_imgs_results_small_model = new_imgs_results_small_model.loc[new_imgs_results_small_model.Filename.isin(new_imgs_df.filepath)]
new_imgs_results_small_model.reset_index(drop=True, inplace=True)
new_imgs_results_small_model.head()

Unnamed: 0,Filename,CorrectTotalLabel,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5191704785428_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,7.467985,-0.455903,-1.165236,-0.801069,-1.308381,-3.237241,-4.273062,-2.624339,...,-0.989483,-1.511478,-1.901927,-1.902937,-1.082075,-1.482124,2.499995,-3.375038,-3.133447,1.911802
1,5131704785418_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,10.115109,-0.741757,-1.852127,-2.395665,-1.157235,-3.507548,-4.208895,-1.817594,...,-1.521327,-3.257041,-1.830936,-2.256523,-1.258991,-1.090913,3.655709,-5.586953,-4.097594,3.037369
2,5161704785422_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,10.826588,-0.742384,-0.055338,-3.254927,0.313838,-3.043004,-3.574817,-1.882671,...,0.850131,-3.959396,-0.241461,-2.913476,-0.968537,-1.842475,4.468235,-5.253281,-4.024357,4.338975
3,5221704785433_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,10.402787,-1.199843,-1.399939,-3.267868,-0.732076,-2.685109,-4.122629,-1.684803,...,-1.405638,-3.556877,-0.852787,-1.800704,-0.837352,-1.654783,2.876932,-4.567831,-4.165457,3.584487
4,5171704785423_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,9.232785,-0.72291,-1.045706,-1.95562,-0.930536,-2.480214,-3.910032,-1.965349,...,-1.367624,-3.448782,-0.539122,-1.300058,-0.615255,-1.874075,2.897145,-3.574468,-4.037078,3.543715


In [24]:
new_imgs_results_small_model.shape

(10, 59)

In [25]:
# Create a copy of the current column names to a list
new_columns = new_imgs_results_small_model.columns.tolist()

# Modify the first two elements
new_columns[0] = 'filepath'
new_columns[1] = 'label'

# Assign the modified list of column names back to the DataFrame
new_imgs_results_small_model.columns = new_columns
new_imgs_results_small_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5191704785428_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,7.467985,-0.455903,-1.165236,-0.801069,-1.308381,-3.237241,-4.273062,-2.624339,...,-0.989483,-1.511478,-1.901927,-1.902937,-1.082075,-1.482124,2.499995,-3.375038,-3.133447,1.911802
1,5131704785418_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,10.115109,-0.741757,-1.852127,-2.395665,-1.157235,-3.507548,-4.208895,-1.817594,...,-1.521327,-3.257041,-1.830936,-2.256523,-1.258991,-1.090913,3.655709,-5.586953,-4.097594,3.037369
2,5161704785422_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,10.826588,-0.742384,-0.055338,-3.254927,0.313838,-3.043004,-3.574817,-1.882671,...,0.850131,-3.959396,-0.241461,-2.913476,-0.968537,-1.842475,4.468235,-5.253281,-4.024357,4.338975
3,5221704785433_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,10.402787,-1.199843,-1.399939,-3.267868,-0.732076,-2.685109,-4.122629,-1.684803,...,-1.405638,-3.556877,-0.852787,-1.800704,-0.837352,-1.654783,2.876932,-4.567831,-4.165457,3.584487
4,5171704785423_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,9.232785,-0.72291,-1.045706,-1.95562,-0.930536,-2.480214,-3.910032,-1.965349,...,-1.367624,-3.448782,-0.539122,-1.300058,-0.615255,-1.874075,2.897145,-3.574468,-4.037078,3.543715


In [26]:
# Check if any name from 'extracted_names' is not in 'df' and add it as a new column
new_prdtype = list(set(all_imgs_results_big_model.columns) - set(new_imgs_results_small_model.columns))

if len(new_prdtype)>0:
    for col in new_prdtype:
        new_imgs_results_small_model[col] = np.random.normal(loc=CONFIGS["MEAN_PRIOR"], scale=np.sqrt(0.1), size=new_imgs_results_small_model.shape[0])  # Initialize new columns

new_imgs_results_small_model.head()  # Display the updated DataFrame for verificatio

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,img_type
0,5191704785428_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,7.467985,-0.455903,-1.165236,-0.801069,-1.308381,-3.237241,-4.273062,-2.624339,...,-1.511478,-1.901927,-1.902937,-1.082075,-1.482124,2.499995,-3.375038,-3.133447,1.911802,-14.907014
1,5131704785418_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,10.115109,-0.741757,-1.852127,-2.395665,-1.157235,-3.507548,-4.208895,-1.817594,...,-3.257041,-1.830936,-2.256523,-1.258991,-1.090913,3.655709,-5.586953,-4.097594,3.037369,-15.152727
2,5161704785422_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,10.826588,-0.742384,-0.055338,-3.254927,0.313838,-3.043004,-3.574817,-1.882671,...,-3.959396,-0.241461,-2.913476,-0.968537,-1.842475,4.468235,-5.253281,-4.024357,4.338975,-14.836858
3,5221704785433_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,10.402787,-1.199843,-1.399939,-3.267868,-0.732076,-2.685109,-4.122629,-1.684803,...,-3.556877,-0.852787,-1.800704,-0.837352,-1.654783,2.876932,-4.567831,-4.165457,3.584487,-15.454279
4,5171704785423_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,9.232785,-0.72291,-1.045706,-1.95562,-0.930536,-2.480214,-3.910032,-1.965349,...,-3.448782,-0.539122,-1.300058,-0.615255,-1.874075,2.897145,-3.574468,-4.037078,3.543715,-15.328803


In [27]:
new_imgs_results_small_model.shape

(10, 60)

## Scoring on main imgs

In [28]:
main_imgs_master_list = pd.read_csv("../master_list.csv")
main_imgs_master_list.head()

Unnamed: 0,filepath,xmin,ymin,xmax,ymax,label,ProductType,Weight,HalalStatus,HealthStatus,new_camera,tag
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,151,42,497,591,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,88,81,442,567,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,35,34,492,622,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,99,122,428,587,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,103,17,474,592,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,


In [29]:
main_imgs_results_small_model = []  # List to store the results
le_prdtype, le_weight, le_halal, le_healthy = load_label_encoder_small_model()

for idx, row in main_imgs_master_list.iterrows():
    image_path = "../all_images/" + row['filepath']
    frame = cv2.imread(image_path)

    # Preprocessing steps
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = cv2.resize(frame, (CONFIGS['SMALL_MODEL_IMG_SIZE'], CONFIGS['SMALL_MODEL_IMG_SIZE']))
    frame = frame.transpose((2, 0, 1))
    frame = torch.from_numpy(frame).float()
    frame = transforms_test(frame).unsqueeze(0).to(CONFIGS['DEVICE'])

    # Perform prediction
    with torch.no_grad():
        out1, out2, out3, out4 = small_model(frame)
    
    # Extract and store the results
    prediction_row = [row['filepath'], row['label']]
    prediction_row.extend(out1.cpu().numpy().flatten())
    prediction_row.extend(out2.cpu().numpy().flatten())
    prediction_row.extend(out3.cpu().numpy().flatten())
    prediction_row.extend(out4.cpu().numpy().flatten())
    main_imgs_results_small_model.append(prediction_row)


# Define column names for the new DataFrame
column_names = ['filepath', 'label']
column_names += ['ProductType_' + name for name in le_prdtype.classes_]
column_names += ['Weight_' + name for name in le_weight.classes_]
column_names += ['HalalStatus_' + name for name in le_halal.classes_]
column_names += ['HealthStatus_' + name for name in le_healthy.classes_]


# Create the DataFrame
main_imgs_results_small_model = pd.DataFrame(main_imgs_results_small_model, columns=column_names)
main_imgs_results_small_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,Sugar_400-499g_NonHalal_NonHealthy,-3.593526,-0.834642,-0.036741,-0.449187,2.034535,-2.399136,-1.877603,-3.115813,...,7.441017,-1.874432,-1.630215,-3.345615,0.154984,-4.595025,-1.709256,0.742516,-4.266921,1.251087
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,Sugar_400-499g_NonHalal_NonHealthy,-3.660402,-2.571959,0.224866,-2.18387,-0.816285,-2.118711,-1.726907,-4.019604,...,6.182879,-2.065965,-2.811384,-5.012823,0.3272,-1.995343,-1.523705,1.806258,-4.367521,0.752712
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,Sugar_400-499g_NonHalal_NonHealthy,-0.960679,-3.024067,0.840389,-1.106379,-1.344823,-1.859851,-0.9774,-2.414862,...,7.435597,0.022743,-2.701884,-5.151578,-0.955235,-1.765768,-3.311442,2.412364,-4.632944,4.555023
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,Sugar_400-499g_NonHalal_NonHealthy,-2.445158,-2.502353,-0.03009,-2.150193,-0.496334,-1.86334,-1.602406,-2.944317,...,5.663602,-1.559119,-3.43228,-4.082855,-0.294293,-0.931375,-1.490215,1.05406,-3.768671,-0.287054
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,Sugar_400-499g_NonHalal_NonHealthy,-3.006482,-2.20697,-2.354793,-1.360836,-0.575183,-2.720716,-0.906097,-2.562416,...,8.325328,-0.317981,-1.517341,-4.780657,-3.586008,-2.380914,-5.034392,3.682207,-6.363514,2.529441


In [30]:
main_imgs_results_small_model.shape

(4458, 59)

In [31]:
# Check if any name from 'extracted_names' is not in 'df' and add it as a new column
new_prdtype = list(set(all_imgs_results_big_model.columns) - set(main_imgs_results_small_model.columns))

if len(new_prdtype)>0:
    for col in new_prdtype:
        main_imgs_results_small_model[col] = np.random.normal(loc=CONFIGS["MEAN_PRIOR"], scale=np.sqrt(0.1), size=main_imgs_results_small_model.shape[0])  # Initialize new columns

main_imgs_results_small_model.head()  # Display the updated DataFrame for verificatio

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,img_type
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,Sugar_400-499g_NonHalal_NonHealthy,-3.593526,-0.834642,-0.036741,-0.449187,2.034535,-2.399136,-1.877603,-3.115813,...,-1.874432,-1.630215,-3.345615,0.154984,-4.595025,-1.709256,0.742516,-4.266921,1.251087,-14.744955
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,Sugar_400-499g_NonHalal_NonHealthy,-3.660402,-2.571959,0.224866,-2.18387,-0.816285,-2.118711,-1.726907,-4.019604,...,-2.065965,-2.811384,-5.012823,0.3272,-1.995343,-1.523705,1.806258,-4.367521,0.752712,-14.986271
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,Sugar_400-499g_NonHalal_NonHealthy,-0.960679,-3.024067,0.840389,-1.106379,-1.344823,-1.859851,-0.9774,-2.414862,...,0.022743,-2.701884,-5.151578,-0.955235,-1.765768,-3.311442,2.412364,-4.632944,4.555023,-15.342698
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,Sugar_400-499g_NonHalal_NonHealthy,-2.445158,-2.502353,-0.03009,-2.150193,-0.496334,-1.86334,-1.602406,-2.944317,...,-1.559119,-3.43228,-4.082855,-0.294293,-0.931375,-1.490215,1.05406,-3.768671,-0.287054,-15.23036
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,Sugar_400-499g_NonHalal_NonHealthy,-3.006482,-2.20697,-2.354793,-1.360836,-0.575183,-2.720716,-0.906097,-2.562416,...,-0.317981,-1.517341,-4.780657,-3.586008,-2.380914,-5.034392,3.682207,-6.363514,2.529441,-15.048306


In [32]:
main_imgs_results_small_model.shape

(4458, 60)

## All scorings from small model

In [33]:
main_imgs_results_small_model['img_type'] = "existing"
new_imgs_results_small_model['img_type'] = "new"
all_imgs_results_small_model = pd.concat([main_imgs_results_small_model, new_imgs_results_small_model], axis=0)
all_imgs_results_small_model.reset_index(drop=True, inplace=True)
all_imgs_results_small_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,img_type
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,Sugar_400-499g_NonHalal_NonHealthy,-3.593526,-0.834642,-0.036741,-0.449187,2.034535,-2.399136,-1.877603,-3.115813,...,-1.874432,-1.630215,-3.345615,0.154984,-4.595025,-1.709256,0.742516,-4.266921,1.251087,existing
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,Sugar_400-499g_NonHalal_NonHealthy,-3.660402,-2.571959,0.224866,-2.18387,-0.816285,-2.118711,-1.726907,-4.019604,...,-2.065965,-2.811384,-5.012823,0.3272,-1.995343,-1.523705,1.806258,-4.367521,0.752712,existing
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,Sugar_400-499g_NonHalal_NonHealthy,-0.960679,-3.024067,0.840389,-1.106379,-1.344823,-1.859851,-0.9774,-2.414862,...,0.022743,-2.701884,-5.151578,-0.955235,-1.765768,-3.311442,2.412364,-4.632944,4.555023,existing
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,Sugar_400-499g_NonHalal_NonHealthy,-2.445158,-2.502353,-0.03009,-2.150193,-0.496334,-1.86334,-1.602406,-2.944317,...,-1.559119,-3.43228,-4.082855,-0.294293,-0.931375,-1.490215,1.05406,-3.768671,-0.287054,existing
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,Sugar_400-499g_NonHalal_NonHealthy,-3.006482,-2.20697,-2.354793,-1.360836,-0.575183,-2.720716,-0.906097,-2.562416,...,-0.317981,-1.517341,-4.780657,-3.586008,-2.380914,-5.034392,3.682207,-6.363514,2.529441,existing


In [34]:
all_imgs_results_small_model.tail()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BabyMilkPowder,ProductType_Babyfood,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_BreakfastCerealsCornflakes,ProductType_CannedPacketCreamersSweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,img_type
4463,5151704785420_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,9.389573,-2.244836,-1.150056,-2.540976,-0.885031,-3.015272,-3.191794,-0.882367,...,-3.866841,-2.000998,-2.607693,-1.701272,-0.689729,2.932843,-4.720117,-3.526289,3.709331,new
4464,5201704785430_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,10.029627,-0.458718,-1.476229,-2.266092,-1.438141,-3.720444,-4.646971,-2.070412,...,-2.709237,-0.505904,-1.779749,-0.78844,-2.308589,3.605042,-4.219198,-4.332861,2.899928,new
4465,5211704785432_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,9.932009,-0.960113,-1.285829,-2.377143,-0.638341,-3.437519,-4.333635,-2.049438,...,-3.223189,-0.972247,-1.808021,-0.82121,-1.525869,3.330086,-4.575498,-3.431279,3.549468,new
4466,5181704785427_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,9.380651,-0.720351,-0.960074,-1.787928,-0.386909,-3.336283,-3.495619,-1.825738,...,-2.817045,-0.112533,-2.407389,-0.821929,-1.994526,3.785856,-4.707764,-3.405584,2.569691,new
4467,5141704785419_.pic.jpg,AdultMilk_1-99g_Halal_NonHealthy,7.635812,-0.835084,-0.277589,-1.891913,0.151093,-2.115914,-3.328227,-1.542094,...,-2.875587,-0.740422,-1.826304,-1.601982,-1.813279,2.628387,-3.892666,-3.14226,2.987863,new


In [35]:
all_imgs_results_small_model.shape

(4468, 60)

In [36]:
all_imgs_results_small_model.to_csv("all_imgs_results_small_model.csv", index=True)

# Bayesian model

In [37]:
prdtype_cols = [col for col in all_imgs_results_small_model.columns if col.startswith('ProductType_')]

In [38]:
all_imgs_results_small_model_prdtype = all_imgs_results_small_model[['label']+prdtype_cols]
all_imgs_results_big_model_prdtype = all_imgs_results_big_model[['label']+prdtype_cols]
all_imgs_results_small_model_prdtype = all_imgs_results_small_model_prdtype.sort_values(by='label').reset_index(drop=True)
all_imgs_results_big_model_prdtype = all_imgs_results_big_model_prdtype.sort_values(by='label').reset_index(drop=True)

In [39]:
assert (all_imgs_results_small_model_prdtype['label'][(all_imgs_results_small_model_prdtype['label'] == all_imgs_results_big_model_prdtype['label'])]).all()

ValueError: Can only compare identically-labeled Series objects

In [None]:
all_imgs_results_small_model_prdtype['label_prdtype'] = all_imgs_results_small_model_prdtype['label'].str.split('_').str[0]
all_imgs_results_big_model_prdtype['label_prdtype'] = all_imgs_results_big_model_prdtype['label'].str.split('_').str[0]

In [None]:
# Remove the prefix from column names
all_imgs_results_small_model_prdtype.columns = [col.replace("ProductType_", '') if col.startswith("ProductType_") else col for col in all_imgs_results_small_model_prdtype.columns]
all_imgs_results_big_model_prdtype.columns = [col.replace("ProductType_", '') if col.startswith("ProductType_") else col for col in all_imgs_results_big_model_prdtype.columns]

In [None]:
prdtype_label_encoder = LabelEncoder()
truelabel = prdtype_label_encoder.fit_transform(all_imgs_results_big_model_prdtype['label_prdtype'])

In [None]:
# Assuming 'category_names' is the list of unique category names in the order they appear in logitscoresA
category_names = list(all_imgs_results_small_model_prdtype['label_prdtype'].unique())
category_to_encoded = {name: prdtype_label_encoder.transform([name])[0] for name in category_names}

# Reorder columns of logitscoresA and logitscoresB to match the order of encoded labels
ordered_columns = [category_names[i] for i in prdtype_label_encoder.transform(category_names)]
logitscoresA = all_imgs_results_big_model_prdtype[ordered_columns].values
logitscoresB = all_imgs_results_small_model_prdtype[ordered_columns].values


In [None]:
all_imgs_results_big_model.head()

In [None]:
# big model accuracy - total
pred_big_model_prdtype = np.argmax(logitscoresA, axis=1)
sum(pred_big_model_prdtype == truelabel) / len(truelabel)

In [None]:
# small model accuracy - total
pred_small_model_prdtype = np.argmax(logitscoresB, axis=1)
sum(pred_small_model_prdtype == truelabel) / len(truelabel)

In [None]:
# big model accuracy - new imgs
indices = np.where(truelabel == category_to_encoded['JennyBakery'])
sum(pred_big_model_prdtype[indices] == truelabel[indices]) / len(indices[0].tolist())

In [None]:
# small model accuracy - new imgs
indices = np.where(truelabel == category_to_encoded['JennyBakery'])
sum(pred_small_model_prdtype[indices] == truelabel[indices]) / len(indices[0].tolist())

In [None]:
len(truelabel)

In [None]:
logitscoresA.shape[1]

In [None]:
ordered_columns

In [None]:
# import pymc3 as pm
# import theano.tensor as tt
# import numpy as np
# import scipy.stats

# # Sample data setup (replace with your actual data)
# # logitscoresA and logitscoresB are matrices of logit scores for each category from classifiers A and B
# # truelabel is an already existing 1D array of integers representing the true labels
# indices = [np.random.choice(100, 3, replace=False)]  # Replace with your indices for missing data

# N = len(truelabel)
# L = logitscoresA.shape[1]
# missingidx = indices[0].tolist()  # Indices of missing data

# # Initialize truelabel_with_missing with the original truelabel and set missing indices to -1
# truelabel_with_missing = np.array(truelabel, dtype=np.int)
# truelabel_with_missing[missingidx] = -1

# # Mask the missing values
# masked_truelabel = np.ma.masked_where(truelabel_with_missing == -1, truelabel_with_missing)

# with pm.Model() as model:
#     # Priors
#     muA1 = pm.Normal('muA1', mu=0, sigma=10)
#     muA0 = pm.Normal('muA0', mu=0, sigma=10)
#     sigmaA = pm.Uniform('sigmaA', lower=0.01, upper=1.0)
#     muB1 = pm.Normal('muB1', mu=0, sigma=10)
#     muB0 = pm.Normal('muB0', mu=0, sigma=10)
#     sigmaB = pm.Uniform('sigmaB', lower=0.01, upper=1.0)
#     rho = pm.Uniform('rho', lower=-1, upper=1)
    
#     # Uniform prior over labels
#     labelprob = pm.Dirichlet('labelprob', a=tt.ones(L))

#     # Likelihood
#     muA = pm.math.switch(tt.eq(tt.arange(L), masked_truelabel[:, None]), muA1, muA0)
#     muB = pm.math.switch(tt.eq(tt.arange(L), masked_truelabel[:, None]), muB1, muB0)
    
#     logitscoresA_obs = pm.Normal('logitscoresA_obs', mu=muA, sigma=sigmaA, observed=logitscoresA)
#     logitscoresB_obs = pm.Normal('logitscoresB_obs', mu=muB + rho * (logitscoresA - muA) / sigmaA, sigma=tt.sqrt((1 - rho ** 2) * sigmaB ** 2), observed=logitscoresB)
    
#     # Define the categorical distribution for the true labels
#     truelabel_obs = pm.Categorical('truelabel_obs', p=labelprob, observed=masked_truelabel)

#     # Inference
#     trace = pm.sample(2000, tune=500, cores=1)

#     # Plotting within the model context
#     # az.plot_trace(trace)
#     # plt.show()

#     # Posterior predictive checks
#     ppc = pm.sample_posterior_predictive(trace, var_names=['truelabel_obs'])

# # Process the posterior predictive checks for missing indices
# infer_labels = []
# for idx in missingidx:
#     label_samples = ppc['truelabel_obs'][:, idx]
#     inferred_label = scipy.stats.mode(label_samples).mode[0]
#     infer_labels.append(inferred_label)

# # Output the inferred labels for missing indices
# print("Inferred labels for missing indices:", infer_labels)

In [None]:
np.unique(infer_labels)

In [None]:
~np.isin(np.arange(N), missingidx)

In [None]:
np.array(truelabel, dtype=np.int)[missingidx]

In [None]:
pred_small_model_prdtype[missingidx]

In [None]:
pred_big_model_prdtype[missingidx]

In [None]:
missingidx

In [None]:
ppc['truelabel_obs'].shape