In [1]:
import pandas as pd
import torch.nn as nn
import pickle
import torch
from torchvision import models
from torchvision.models import detection, resnet50, ResNet50_Weights
import os
import numpy as np
import cv2
from torchvision import transforms
import pymc3 as pm
import theano.tensor as tt
from sklearn.preprocessing import LabelEncoder
import scipy

In [2]:
CONFIGS = {
    # determine the current device and based on that set the pin memory
    # flag
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    # specify ImageNet mean and standard deviation
    "IMG_MEAN": [0.485, 0.456, 0.406],
    "IMG_STD": [0.229, 0.224, 0.225],
    "MC_DROPOUT_ENABLED": False,  # Switch to enable/disable MC Dropout for confidence score
    "NUM_DROPOUT_RUNS": 3,
    "CONFIDENCE_THRESHOLD": 0,
    "BIG_MODEL_IMG_SIZE": 320,
    "SMALL_MODEL_IMG_SIZE": 60,
    "MEAN_PRIOR": -15,
}

# Big model

## Model loading

In [3]:
class MultiHeadResNet_BigModel(nn.Module):
    def __init__(self, num_classes_prdtype, num_classes_weight, num_classes_halal, num_classes_healthy):
        super(MultiHeadResNet_BigModel, self).__init__()
        self.base_model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Identity()

        # Define custom fully connected layers for each prediction head
        self.fc_prdtype = nn.Linear(num_ftrs, num_classes_prdtype)
        self.fc_weight = nn.Linear(num_ftrs, num_classes_weight)
        self.fc_halal = nn.Linear(num_ftrs, num_classes_halal)
        self.fc_healthy = nn.Linear(num_ftrs, num_classes_healthy)
        self.fc_bbox = nn.Linear(num_ftrs, 4)

    def forward(self, x):
        x = self.base_model(x)
        prdtype = self.fc_prdtype(x)
        weight = self.fc_weight(x)
        halal = self.fc_halal(x)
        healthy = self.fc_healthy(x)
        box = self.fc_bbox(x)
        return prdtype, weight, halal, healthy, box

    
# load label encoder 
def load_label_encoder_big_model():
    le_prdtype = pickle.loads(open("../big_model/le_prdtype.pickle", "rb").read())
    le_weight = pickle.loads(open("../big_model/le_weight.pickle", "rb").read())
    le_halal = pickle.loads(open("../big_model/le_halal.pickle", "rb").read())
    le_healthy = pickle.loads(open("../big_model/le_healthy.pickle", "rb").read())
    
    return le_prdtype, le_weight, le_halal, le_healthy

le_prdtype, le_weight, le_halal, le_healthy = load_label_encoder_big_model()

# Load the trained MultiHeadResNet model
def load_model():
    # Verify the number of classes for each label
    num_classes_prdtype = len(le_prdtype.classes_)
    num_classes_weight = len(le_weight.classes_)
    num_classes_halal = len(le_halal.classes_)
    num_classes_healthy = len(le_healthy.classes_)
    # print(num_classes_prdtype)
    # print(num_classes_healthy)

    custom_resnet_model = MultiHeadResNet_BigModel(
        num_classes_prdtype=num_classes_prdtype,
        num_classes_weight=num_classes_weight,
        num_classes_halal=num_classes_halal,
        num_classes_healthy=num_classes_healthy
    )

    model_path = '../big_model/multi_head_model.pth'
    # print("test1")
    if os.path.exists(model_path):
        custom_resnet_model.load_state_dict(torch.load(model_path, map_location=CONFIGS['DEVICE']))
    else:
        raise FileNotFoundError(f"Model file not found: {model_path}")
    # print("test2")
    custom_resnet_model.to(CONFIGS['DEVICE'])
    custom_resnet_model.eval()
    return custom_resnet_model

big_model = load_model()

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


## Scoring on main imgs

In [4]:
main_imgs_results_big_model = pd.read_csv("main_imgs_results_big_model.csv")
main_imgs_results_big_model.head()

Unnamed: 0,Filename,CorrectTotalLabel,ProductType_AdultMilk,ProductType_Babyfood,ProductType_Babymilk-powder,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_Breakfast-cereals-cornflakes,ProductType_Canned-Packet-Creamers-Sweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,20231215_output_frame_0244,PotatoSticks_1-99g_Halal_NonHealthy,-1.587444,-2.496907,-3.559909,-2.273423,1.671694,-0.49073,-3.641042,-2.682706,...,3.005991,-1.938856,-3.626286,-2.632926,-4.078401,-3.550343,2.026452,-3.092768,-3.751043,2.996024
1,output_frame_0617,Kaya_400-499g_Halal_NonHealthy,-4.127325,-1.5334,-1.689596,-1.041802,-1.024461,-4.602612,-2.558664,0.07872,...,7.988635,-1.308213,-4.263793,-2.572152,-3.962276,-3.851224,4.229371,-4.268166,-4.517332,4.193677
2,IMG_5422_jpeg.rf.40db706026c5805998a4ef32aaea01bd,BiscuitsCrackersCookies_100-199g_Halal_NonHealthy,-3.575266,-1.385911,-2.772634,-0.931874,7.930703,-4.527698,-3.865532,-3.374679,...,-1.14045,-3.05643,-2.557728,-2.868939,-1.775112,-3.771607,2.320258,-3.072567,-3.815917,3.348428
3,IMG_9829_JPG.rf.9dcff8b32299bc2d7285db157c38e6d6,OtherNoodles_500-599g_NonHalal_NonHealthy,-3.725879,-4.687295,-3.840399,-1.293027,-0.585406,-5.36357,-4.092983,-4.843928,...,-1.957848,4.860765,-1.20123,-0.325409,-2.343176,-3.91664,-2.494664,2.036901,-4.395383,3.689328
4,2023_10_25_11_52_33_260721,MaternalMilkPowder_600-699g_Halal_NonHealthy,2.656255,0.116325,2.72413,-3.962348,-1.72144,-6.606655,-3.900061,-5.223273,...,-2.115172,-1.551353,7.838285,-1.78831,-0.916579,-1.628614,1.99219,-3.352175,-3.212959,2.271578


In [5]:
# Create a copy of the current column names to a list
new_columns = main_imgs_results_big_model.columns.tolist()

# Modify the first two elements
new_columns[0] = 'filepath'
new_columns[1] = 'label'

# Assign the modified list of column names back to the DataFrame
main_imgs_results_big_model.columns = new_columns


In [6]:
big_model_pred_col_name_original = main_imgs_results_big_model.columns[2:].tolist()

In [7]:
new_imgs_results_small_model = pd.read_csv("new_imgs_results_small_model.csv")
new_imgs_results_small_model.head()

Unnamed: 0,Filename,CorrectTotalLabel,ProductType_AdultMilk,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_CerealBeverage-powder,ProductType_Coffee,ProductType_HoneyOtherSpreads,ProductType_JennyBakery,ProductType_Nuts,...,Weight_1-99g,Weight_200-299g,Weight_300-399g,Weight_400-499g,Weight_500-599g,Weight_800-899g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,IMG_3442_jpeg.rf.3f2785b9cb1ba9a272b60ead15de65e4,Nuts_300-399g_Halal_NonHealthy,-2.208533,-0.570678,-0.431555,-1.769167,-1.163584,-0.342905,-0.103244,5.4338,...,-2.803985,-0.518957,5.729784,-0.945454,-0.719616,0.117333,2.305885,-3.553557,-3.736882,2.210839
1,2023_10_25_11_23_15_511639,AdultMilk_1-99g_Halal_NonHealthy,7.738378,-1.703542,-0.51047,-2.001149,-2.509459,-0.663433,-2.828204,-1.897847,...,6.64024,-1.481095,-3.000674,-2.034991,-2.920876,0.346742,2.217192,-5.763519,-3.322881,4.433139
2,IMG_6461_jpeg.rf.f0150fe3f237233cdbcf34b8bbff3699,HoneyOtherSpreads_300-399g_NonHalal_NonHealthy,-1.331515,-0.76914,-1.431097,-1.83164,-0.163435,5.959034,-1.270023,-0.488899,...,-1.770536,-1.523189,6.178998,-2.019716,-1.540292,-1.861531,-0.880123,4.226662,-3.344068,3.447545
3,5201704785430_.pic,JennyBakery_500-599g_NonHalal_NonHealthy,-1.841695,-0.555624,-0.648654,-0.662479,-0.292953,-0.246531,6.119444,-0.843022,...,-2.126385,-2.102927,-1.257894,-0.933212,4.589014,-2.215257,-1.95677,2.945171,-1.432672,3.168404
4,IMG_6395_jpeg.rf.79d289c209bef1d45039bd34546c1560,Coffee_200-299g_NonHalal_NonHealthy,-2.355077,-0.714852,-2.914923,-1.843791,6.080708,-0.29952,-1.770164,-0.720695,...,-2.035496,5.82118,-0.741027,-0.893428,-1.008977,-1.597336,-2.281424,3.639594,-1.952233,3.726935


In [8]:
# Extract column names that start with 'ProductType'
all_prdtypes_new_imgs = [col for col in new_imgs_results_small_model.columns if col.startswith('ProductType')]
# all_prdtypes_new_imgs = [col.split('_', 1)[1] for col in all_prdtypes_new_imgs]
# all_prdtypes_new_imgs

In [9]:
# Check if any name from 'extracted_names' is not in 'df' and add it as a new column
new_prdtype = list(set(all_prdtypes_new_imgs) - set(main_imgs_results_big_model.columns))

if len(new_prdtype)==1:
    main_imgs_results_big_model[new_prdtype[0]] = np.random.normal(loc=CONFIGS["MEAN_PRIOR"], scale=np.sqrt(0.1), size=main_imgs_results_big_model.shape[0])  # Initialize new columns

main_imgs_results_big_model.head()  # Display the updated DataFrame for verificatio

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_Babyfood,ProductType_Babymilk-powder,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_Breakfast-cereals-cornflakes,ProductType_Canned-Packet-Creamers-Sweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,ProductType_JennyBakery
0,20231215_output_frame_0244,PotatoSticks_1-99g_Halal_NonHealthy,-1.587444,-2.496907,-3.559909,-2.273423,1.671694,-0.49073,-3.641042,-2.682706,...,-1.938856,-3.626286,-2.632926,-4.078401,-3.550343,2.026452,-3.092768,-3.751043,2.996024,-14.959163
1,output_frame_0617,Kaya_400-499g_Halal_NonHealthy,-4.127325,-1.5334,-1.689596,-1.041802,-1.024461,-4.602612,-2.558664,0.07872,...,-1.308213,-4.263793,-2.572152,-3.962276,-3.851224,4.229371,-4.268166,-4.517332,4.193677,-14.926259
2,IMG_5422_jpeg.rf.40db706026c5805998a4ef32aaea01bd,BiscuitsCrackersCookies_100-199g_Halal_NonHealthy,-3.575266,-1.385911,-2.772634,-0.931874,7.930703,-4.527698,-3.865532,-3.374679,...,-3.05643,-2.557728,-2.868939,-1.775112,-3.771607,2.320258,-3.072567,-3.815917,3.348428,-14.93675
3,IMG_9829_JPG.rf.9dcff8b32299bc2d7285db157c38e6d6,OtherNoodles_500-599g_NonHalal_NonHealthy,-3.725879,-4.687295,-3.840399,-1.293027,-0.585406,-5.36357,-4.092983,-4.843928,...,4.860765,-1.20123,-0.325409,-2.343176,-3.91664,-2.494664,2.036901,-4.395383,3.689328,-15.504523
4,2023_10_25_11_52_33_260721,MaternalMilkPowder_600-699g_Halal_NonHealthy,2.656255,0.116325,2.72413,-3.962348,-1.72144,-6.606655,-3.900061,-5.223273,...,-1.551353,7.838285,-1.78831,-0.916579,-1.628614,1.99219,-3.352175,-3.212959,2.271578,-15.146358


## Scoring on new imgs

In [10]:
new_imgs_df = pd.read_csv("../small_model/new_imgs_list.csv")
new_imgs_df.reset_index(drop=True, inplace=True)
new_imgs_df.head()

Unnamed: 0,filepath,label,ProductType,Weight,HalalStatus,HealthStatus
0,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
1,5141704785419_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
2,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy


In [11]:
transforms_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CONFIGS['IMG_MEAN'], std=CONFIGS['IMG_STD'])
])

In [12]:
new_imgs_results_big_model = []  # List to store the results

for idx, row in new_imgs_df.iterrows():
    image_path = "../small_model/new_imgs/" + row['filepath']
    frame = cv2.imread(image_path)

    # Preprocessing steps
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = cv2.resize(frame, (CONFIGS['BIG_MODEL_IMG_SIZE'], CONFIGS['BIG_MODEL_IMG_SIZE']))
    frame = frame.transpose((2, 0, 1))
    frame = torch.from_numpy(frame).float()
    frame = transforms_test(frame).unsqueeze(0).to(CONFIGS['DEVICE'])

    # Perform prediction
    with torch.no_grad():
        out1, out2, out3, out4, _ = big_model(frame)

    # Extract and store the results
    prediction_row = [row['filepath'], row['label']]
    prediction_row.extend(out1.cpu().numpy().flatten())
    prediction_row.extend(out2.cpu().numpy().flatten())
    prediction_row.extend(out3.cpu().numpy().flatten())
    prediction_row.extend(out4.cpu().numpy().flatten())
    new_imgs_results_big_model.append(prediction_row)


# Define column names for the new DataFrame
column_names = ['filepath', 'label']
column_names += big_model_pred_col_name_original

# Create the DataFrame
new_imgs_results_big_model = pd.DataFrame(new_imgs_results_big_model, columns=column_names)
new_imgs_results_big_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_Babyfood,ProductType_Babymilk-powder,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_Breakfast-cereals-cornflakes,ProductType_Canned-Packet-Creamers-Sweet,...,Weight_400-499g,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-4.03672,-0.261086,-0.850562,-2.136504,-0.287036,-5.450655,-3.437403,-3.360319,...,-2.111177,-1.777058,-1.678948,0.534387,-0.942685,-2.938613,4.652159,-5.311422,0.440016,-1.574372
1,5141704785419_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.987413,-0.235551,-0.138102,-1.88994,-0.242081,-4.247354,-3.341923,-3.569029,...,-2.444483,-0.382582,-1.367087,0.568081,-0.701762,-2.520679,1.503756,-2.084439,0.783709,-1.582145
2,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-4.175052,-0.719495,-1.335449,-3.06461,0.316437,-5.040257,-3.596658,-3.615277,...,-2.124496,-1.862571,-1.151982,1.287045,-1.623109,-3.315204,2.790001,-3.489346,-0.40789,-0.574956
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-4.210509,-0.793994,-1.282921,-2.872747,-0.510361,-5.225197,-3.520112,-2.811852,...,0.84523,-2.310489,-1.154076,-0.022997,-2.183432,-3.124709,2.588875,-3.55255,-1.222334,0.505059
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.034265,-1.384678,-1.404768,-1.696951,0.030562,-4.573633,-2.98474,-3.017923,...,-1.461493,-1.154279,-0.802572,-0.251512,-0.776558,-2.234903,3.26945,-4.011051,-1.188951,0.407649


In [13]:
new_imgs_results_big_model.shape

(10, 60)

In [14]:
if len(new_prdtype)==1:
    new_imgs_results_big_model[new_prdtype[0]] = np.random.normal(loc=CONFIGS["MEAN_PRIOR"], scale=np.sqrt(0.1), size=new_imgs_results_big_model.shape[0])  # Initialize new columns

new_imgs_results_big_model.head()  # Display the updated DataFrame for verificatio

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_Babyfood,ProductType_Babymilk-powder,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_Breakfast-cereals-cornflakes,ProductType_Canned-Packet-Creamers-Sweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,ProductType_JennyBakery
0,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-4.03672,-0.261086,-0.850562,-2.136504,-0.287036,-5.450655,-3.437403,-3.360319,...,-1.777058,-1.678948,0.534387,-0.942685,-2.938613,4.652159,-5.311422,0.440016,-1.574372,-15.221451
1,5141704785419_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.987413,-0.235551,-0.138102,-1.88994,-0.242081,-4.247354,-3.341923,-3.569029,...,-0.382582,-1.367087,0.568081,-0.701762,-2.520679,1.503756,-2.084439,0.783709,-1.582145,-15.210123
2,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-4.175052,-0.719495,-1.335449,-3.06461,0.316437,-5.040257,-3.596658,-3.615277,...,-1.862571,-1.151982,1.287045,-1.623109,-3.315204,2.790001,-3.489346,-0.40789,-0.574956,-15.039245
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-4.210509,-0.793994,-1.282921,-2.872747,-0.510361,-5.225197,-3.520112,-2.811852,...,-2.310489,-1.154076,-0.022997,-2.183432,-3.124709,2.588875,-3.55255,-1.222334,0.505059,-15.674862
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.034265,-1.384678,-1.404768,-1.696951,0.030562,-4.573633,-2.98474,-3.017923,...,-1.154279,-0.802572,-0.251512,-0.776558,-2.234903,3.26945,-4.011051,-1.188951,0.407649,-15.813519


In [15]:
new_imgs_results_big_model.shape

(10, 61)

In [16]:
main_imgs_results_big_model.shape

(3457, 61)

## All scorings from big model

In [17]:
all_imgs_results_big_model = pd.concat([main_imgs_results_big_model, new_imgs_results_big_model], axis=0)
all_imgs_results_big_model.reset_index(drop=True, inplace=True)
all_imgs_results_big_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_Babyfood,ProductType_Babymilk-powder,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_Breakfast-cereals-cornflakes,ProductType_Canned-Packet-Creamers-Sweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,ProductType_JennyBakery
0,20231215_output_frame_0244,PotatoSticks_1-99g_Halal_NonHealthy,-1.587444,-2.496907,-3.559909,-2.273423,1.671694,-0.49073,-3.641042,-2.682706,...,-1.938856,-3.626286,-2.632926,-4.078401,-3.550343,2.026452,-3.092768,-3.751043,2.996024,-14.959163
1,output_frame_0617,Kaya_400-499g_Halal_NonHealthy,-4.127325,-1.5334,-1.689596,-1.041802,-1.024461,-4.602612,-2.558664,0.07872,...,-1.308213,-4.263793,-2.572152,-3.962276,-3.851224,4.229371,-4.268166,-4.517332,4.193677,-14.926259
2,IMG_5422_jpeg.rf.40db706026c5805998a4ef32aaea01bd,BiscuitsCrackersCookies_100-199g_Halal_NonHealthy,-3.575266,-1.385911,-2.772634,-0.931874,7.930703,-4.527698,-3.865532,-3.374679,...,-3.05643,-2.557728,-2.868939,-1.775112,-3.771607,2.320258,-3.072567,-3.815917,3.348428,-14.93675
3,IMG_9829_JPG.rf.9dcff8b32299bc2d7285db157c38e6d6,OtherNoodles_500-599g_NonHalal_NonHealthy,-3.725879,-4.687295,-3.840399,-1.293027,-0.585406,-5.36357,-4.092983,-4.843928,...,4.860765,-1.20123,-0.325409,-2.343176,-3.91664,-2.494664,2.036901,-4.395383,3.689328,-15.504523
4,2023_10_25_11_52_33_260721,MaternalMilkPowder_600-699g_Halal_NonHealthy,2.656255,0.116325,2.72413,-3.962348,-1.72144,-6.606655,-3.900061,-5.223273,...,-1.551353,7.838285,-1.78831,-0.916579,-1.628614,1.99219,-3.352175,-3.212959,2.271578,-15.146358


In [18]:
all_imgs_results_big_model.tail()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_Babyfood,ProductType_Babymilk-powder,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_Breakfast-cereals-cornflakes,ProductType_Canned-Packet-Creamers-Sweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,ProductType_JennyBakery
3462,5181704785427_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.981382,0.204884,-0.888769,-1.459656,-0.265949,-4.899653,-3.37602,-3.299238,...,-1.399942,-0.566004,-0.221259,-0.983596,-2.287874,2.158207,-2.922357,-0.8609,0.138896,-15.087132
3463,5191704785428_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.195043,0.605347,-0.309571,-1.839244,-0.29703,-5.46614,-3.455538,-3.365381,...,-1.34863,-0.21073,0.874258,-0.686935,-1.921793,4.353158,-5.152314,-0.286985,-0.730851,-15.301917
3464,5201704785430_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.386767,-0.337914,-1.020709,-1.718369,-0.23264,-4.663236,-3.434904,-3.262553,...,-1.249037,-1.585015,-0.359816,-1.064754,-2.303001,3.028043,-3.850751,-0.858515,0.188629,-14.334265
3465,5211704785432_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-3.317349,-0.832935,-0.64032,-1.589897,-0.683308,-5.127884,-3.067908,-3.241553,...,-1.210088,-1.176974,0.562943,-0.614958,-1.835613,3.485469,-4.153513,0.489794,-1.378488,-14.452216
3466,5221704785433_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.975153,-1.05683,-0.679077,-2.187956,-0.266953,-5.08936,-2.946344,-2.968317,...,-0.78634,-0.110103,0.111246,0.022641,-1.805679,3.500244,-4.25217,-0.454828,-0.375982,-15.227116


In [19]:
all_imgs_results_big_model.shape

(3467, 61)

In [20]:
all_imgs_results_big_model.to_csv("all_imgs_results_big_model.csv", index=True)

# Small model

## Model loading

In [21]:
class MultiHeadResNet_SmallModel(nn.Module):
    def __init__(self, num_classes_prdtype, num_classes_weight, num_classes_halal, num_classes_healthy):
        super(MultiHeadResNet_SmallModel, self).__init__()
        self.base_model = models.resnet18(pretrained=True)
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Identity()

        # Define custom fully connected layers for each prediction head
        self.fc_prdtype = nn.Linear(num_ftrs, num_classes_prdtype)
        self.fc_weight = nn.Linear(num_ftrs, num_classes_weight)
        self.fc_halal = nn.Linear(num_ftrs, num_classes_halal)
        self.fc_healthy = nn.Linear(num_ftrs, num_classes_healthy)

    def forward(self, x):
        x = self.base_model(x)
        prdtype = self.fc_prdtype(x)
        weight = self.fc_weight(x)
        halal = self.fc_halal(x)
        healthy = self.fc_healthy(x)
        return prdtype, weight, halal, healthy

    
# load label encoder 
def load_label_encoder_small_model():
    le_prdtype = pickle.loads(open("../small_model/output/le_prdtype.pickle", "rb").read())
    le_weight = pickle.loads(open("../small_model/output/le_weight.pickle", "rb").read())
    le_halal = pickle.loads(open("../small_model/output/le_halal.pickle", "rb").read())
    le_healthy = pickle.loads(open("../small_model/output/le_healthy.pickle", "rb").read())
    
    return le_prdtype, le_weight, le_halal, le_healthy

le_prdtype, le_weight, le_halal, le_healthy = load_label_encoder_small_model()

# Load the trained MultiHeadResNet model
def load_model():
    # Verify the number of classes for each label
    num_classes_prdtype = len(le_prdtype.classes_)
    num_classes_weight = len(le_weight.classes_)
    num_classes_halal = len(le_halal.classes_)
    num_classes_healthy = len(le_healthy.classes_)
    # print(num_classes_prdtype)
    # print(num_classes_healthy)

    custom_resnet_model = MultiHeadResNet_SmallModel(
        num_classes_prdtype=num_classes_prdtype,
        num_classes_weight=num_classes_weight,
        num_classes_halal=num_classes_halal,
        num_classes_healthy=num_classes_healthy
    )

    model_path = '../small_model/output/multi_head_model.pth'
    # print("test1")
    if os.path.exists(model_path):
        custom_resnet_model.load_state_dict(torch.load(model_path, map_location=CONFIGS['DEVICE']))
    else:
        raise FileNotFoundError(f"Model file not found: {model_path}")
    # print("test2")
    custom_resnet_model.to(CONFIGS['DEVICE'])
    custom_resnet_model.eval()
    return custom_resnet_model
 
small_model = load_model()

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


## Scoring on new imgs

In [22]:
new_imgs_df = pd.read_csv("../small_model/new_imgs_list.csv")
new_imgs_df.reset_index(drop=True, inplace=True)
new_imgs_df.head()

Unnamed: 0,filepath,label,ProductType,Weight,HalalStatus,HealthStatus
0,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
1,5141704785419_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
2,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
3,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy
4,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,JennyBakery,500-599g,NonHalal,NonHealthy


In [23]:
new_imgs_results_small_model = pd.read_csv("../small_model/new_imgs_results_small_model.csv")
new_imgs_results_small_model = new_imgs_results_small_model.loc[new_imgs_results_small_model.Filename.isin(new_imgs_df.filepath)]
new_imgs_results_small_model.reset_index(drop=True, inplace=True)
new_imgs_results_small_model.head()

Unnamed: 0,Filename,CorrectTotalLabel,ProductType_AdultMilk,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_CerealBeverage-powder,ProductType_Coffee,ProductType_HoneyOtherSpreads,ProductType_JennyBakery,ProductType_Nuts,...,Weight_1-99g,Weight_200-299g,Weight_300-399g,Weight_400-499g,Weight_500-599g,Weight_800-899g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.147366,-2.371338,-1.988551,-0.223216,-1.250532,-2.203384,7.016516,-1.064522,...,-2.214612,-1.035467,-1.500059,-2.540139,5.367744,-2.173713,-2.707652,2.858616,-4.770081,0.812059
1,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-0.632431,-1.257829,-1.865115,-1.02985,-1.761014,-1.383238,6.514539,-1.818218,...,-2.43755,-2.271448,-1.792775,-1.941356,6.102351,-1.492687,-3.665095,3.724813,-4.397122,0.834202
2,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.465411,-2.08965,-1.700874,-0.933922,-1.756638,-0.89477,6.986915,-1.52172,...,-2.14349,-1.823497,-1.557975,-2.061523,5.425486,-1.660315,-3.228247,3.492168,-4.48942,1.09929
3,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.558612,-1.284596,-1.218449,-0.979794,-1.369679,-1.926156,6.73735,-1.105155,...,-1.399189,-1.711422,-2.247945,-1.673357,5.824388,-1.978551,-2.403049,3.220788,-4.598861,0.903832
4,5191704785428_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.998316,-1.860975,-1.638578,-0.757086,-1.686744,-1.167751,6.4787,-0.786045,...,-2.267077,-2.070362,-1.445823,-2.463024,5.510421,-1.757346,-2.597294,2.747287,-4.970373,1.185428


In [24]:
new_imgs_results_small_model.shape

(10, 21)

In [25]:
# Create a copy of the current column names to a list
new_columns = new_imgs_results_small_model.columns.tolist()

# Modify the first two elements
new_columns[0] = 'filepath'
new_columns[1] = 'label'

# Assign the modified list of column names back to the DataFrame
new_imgs_results_small_model.columns = new_columns
new_imgs_results_small_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_CerealBeverage-powder,ProductType_Coffee,ProductType_HoneyOtherSpreads,ProductType_JennyBakery,ProductType_Nuts,...,Weight_1-99g,Weight_200-299g,Weight_300-399g,Weight_400-499g,Weight_500-599g,Weight_800-899g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.147366,-2.371338,-1.988551,-0.223216,-1.250532,-2.203384,7.016516,-1.064522,...,-2.214612,-1.035467,-1.500059,-2.540139,5.367744,-2.173713,-2.707652,2.858616,-4.770081,0.812059
1,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-0.632431,-1.257829,-1.865115,-1.02985,-1.761014,-1.383238,6.514539,-1.818218,...,-2.43755,-2.271448,-1.792775,-1.941356,6.102351,-1.492687,-3.665095,3.724813,-4.397122,0.834202
2,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.465411,-2.08965,-1.700874,-0.933922,-1.756638,-0.89477,6.986915,-1.52172,...,-2.14349,-1.823497,-1.557975,-2.061523,5.425486,-1.660315,-3.228247,3.492168,-4.48942,1.09929
3,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.558612,-1.284596,-1.218449,-0.979794,-1.369679,-1.926156,6.73735,-1.105155,...,-1.399189,-1.711422,-2.247945,-1.673357,5.824388,-1.978551,-2.403049,3.220788,-4.598861,0.903832
4,5191704785428_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.998316,-1.860975,-1.638578,-0.757086,-1.686744,-1.167751,6.4787,-0.786045,...,-2.267077,-2.070362,-1.445823,-2.463024,5.510421,-1.757346,-2.597294,2.747287,-4.970373,1.185428


In [26]:
# Check if any name from 'extracted_names' is not in 'df' and add it as a new column
new_prdtype = list(set(all_imgs_results_big_model.columns) - set(new_imgs_results_small_model.columns))

if len(new_prdtype)>0:
    for col in new_prdtype:
        new_imgs_results_small_model[col] = np.random.normal(loc=CONFIGS["MEAN_PRIOR"], scale=np.sqrt(0.1), size=new_imgs_results_small_model.shape[0])  # Initialize new columns

new_imgs_results_small_model.head()  # Display the updated DataFrame for verificatio

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_CerealBeverage-powder,ProductType_Coffee,ProductType_HoneyOtherSpreads,ProductType_JennyBakery,ProductType_Nuts,...,ProductType_MaternalMilkPowder,ProductType_NutellaChocolate,ProductType_Babymilk-powder,ProductType_Sardines,ProductType_Canned-Packet-Creamers-Sweet,ProductType_Flour,Weight_1000-1999g,Weight_900-999g,ProductType_OtherBakingNeeds,ProductType_Tea-powder-leaves-
0,5151704785420_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.147366,-2.371338,-1.988551,-0.223216,-1.250532,-2.203384,7.016516,-1.064522,...,-14.696456,-15.029445,-14.697664,-15.056149,-15.452013,-14.790854,-14.92173,-15.187017,-15.27231,-14.957445
1,5161704785422_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-0.632431,-1.257829,-1.865115,-1.02985,-1.761014,-1.383238,6.514539,-1.818218,...,-15.135721,-15.108998,-15.232126,-14.736608,-15.018925,-15.439603,-14.915325,-14.912819,-14.568728,-14.681042
2,5131704785418_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.465411,-2.08965,-1.700874,-0.933922,-1.756638,-0.89477,6.986915,-1.52172,...,-15.115242,-14.779723,-14.733982,-14.830189,-15.29867,-15.166452,-14.635122,-14.74406,-15.143663,-14.357549
3,5171704785423_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.558612,-1.284596,-1.218449,-0.979794,-1.369679,-1.926156,6.73735,-1.105155,...,-14.553035,-15.405671,-14.729835,-14.686917,-14.92776,-14.335607,-14.838941,-14.74381,-15.083127,-14.979224
4,5191704785428_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.998316,-1.860975,-1.638578,-0.757086,-1.686744,-1.167751,6.4787,-0.786045,...,-15.117487,-15.184627,-15.176205,-15.030695,-14.265709,-14.462826,-14.862856,-14.165603,-15.570852,-15.041956


In [27]:
new_imgs_results_small_model.shape

(10, 61)

## Scoring on main imgs

In [28]:
main_imgs_master_list = pd.read_csv("../master_list.csv")
main_imgs_master_list.head()

Unnamed: 0,filepath,xmin,ymin,xmax,ymax,label,ProductType,Weight,HalalStatus,HealthStatus,new_camera,tag
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,151,42,497,591,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,88,81,442,567,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,35,34,492,622,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,99,122,428,587,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,103,17,474,592,Sugar_400-499g_NonHalal_NonHealthy,Sugar,400-499g,NonHalal,NonHealthy,0,


In [29]:
main_imgs_results_small_model = []  # List to store the results
le_prdtype, le_weight, le_halal, le_healthy = load_label_encoder_small_model()

for idx, row in main_imgs_master_list.iterrows():
    image_path = "../all_images/" + row['filepath']
    frame = cv2.imread(image_path)

    # Preprocessing steps
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = cv2.resize(frame, (CONFIGS['SMALL_MODEL_IMG_SIZE'], CONFIGS['SMALL_MODEL_IMG_SIZE']))
    frame = frame.transpose((2, 0, 1))
    frame = torch.from_numpy(frame).float()
    frame = transforms_test(frame).unsqueeze(0).to(CONFIGS['DEVICE'])

    # Perform prediction
    with torch.no_grad():
        out1, out2, out3, out4 = small_model(frame)
    
    # Extract and store the results
    prediction_row = [row['filepath'], row['label']]
    prediction_row.extend(out1.cpu().numpy().flatten())
    prediction_row.extend(out2.cpu().numpy().flatten())
    prediction_row.extend(out3.cpu().numpy().flatten())
    prediction_row.extend(out4.cpu().numpy().flatten())
    main_imgs_results_small_model.append(prediction_row)


# Define column names for the new DataFrame
column_names = ['filepath', 'label']
column_names += ['ProductType_' + name for name in le_prdtype.classes_]
column_names += ['Weight_' + name for name in le_weight.classes_]
column_names += ['HalalStatus_' + name for name in le_halal.classes_]
column_names += ['HealthStatus_' + name for name in le_healthy.classes_]


# Create the DataFrame
main_imgs_results_small_model = pd.DataFrame(main_imgs_results_small_model, columns=column_names)
main_imgs_results_small_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_CerealBeverage-powder,ProductType_Coffee,ProductType_HoneyOtherSpreads,ProductType_JennyBakery,ProductType_Nuts,...,Weight_1-99g,Weight_200-299g,Weight_300-399g,Weight_400-499g,Weight_500-599g,Weight_800-899g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,Sugar_400-499g_NonHalal_NonHealthy,-2.020259,0.163589,-1.813907,-3.735976,1.387246,-0.651287,-2.600132,-0.727436,...,-1.949923,3.06635,-0.163866,0.050203,-2.694914,-1.263772,2.164521,-0.87571,-1.927328,1.24159
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,Sugar_400-499g_NonHalal_NonHealthy,-2.641503,0.462159,-0.569475,-2.081272,0.113894,1.860511,-1.228555,-1.344834,...,-2.98093,1.640464,0.333417,0.724001,-2.085833,-0.309271,-0.488723,-0.346736,-1.337615,1.005025
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,Sugar_400-499g_NonHalal_NonHealthy,-0.671336,0.410651,-1.233106,-0.897077,-0.460516,-1.281306,-1.350093,3.289432,...,-0.170819,0.474442,1.491116,-0.791732,-1.938339,-1.256038,2.769793,-2.604231,-2.954857,1.74513
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,Sugar_400-499g_NonHalal_NonHealthy,-0.981495,0.316365,-1.370192,-3.57127,2.038956,1.858516,-1.785913,-1.991065,...,-1.17929,2.026139,0.406941,0.789664,-1.613296,-1.482098,-0.939,0.84435,-2.002786,1.324591
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,Sugar_400-499g_NonHalal_NonHealthy,-2.975871,-0.488483,-1.954772,-1.458116,1.271758,-0.236836,-0.708023,0.353072,...,-2.878357,2.866438,0.611732,-1.107133,-1.39973,-1.165009,0.037813,0.362254,-3.250632,1.42515


In [30]:
main_imgs_results_small_model.shape

(3457, 21)

In [31]:
# Check if any name from 'extracted_names' is not in 'df' and add it as a new column
new_prdtype = list(set(all_imgs_results_big_model.columns) - set(main_imgs_results_small_model.columns))

if len(new_prdtype)>0:
    for col in new_prdtype:
        main_imgs_results_small_model[col] = np.random.normal(loc=CONFIGS["MEAN_PRIOR"], scale=np.sqrt(0.1), size=main_imgs_results_small_model.shape[0])  # Initialize new columns

main_imgs_results_small_model.head()  # Display the updated DataFrame for verificatio

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_CerealBeverage-powder,ProductType_Coffee,ProductType_HoneyOtherSpreads,ProductType_JennyBakery,ProductType_Nuts,...,ProductType_MaternalMilkPowder,ProductType_NutellaChocolate,ProductType_Babymilk-powder,ProductType_Sardines,ProductType_Canned-Packet-Creamers-Sweet,ProductType_Flour,Weight_1000-1999g,Weight_900-999g,ProductType_OtherBakingNeeds,ProductType_Tea-powder-leaves-
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,Sugar_400-499g_NonHalal_NonHealthy,-2.020259,0.163589,-1.813907,-3.735976,1.387246,-0.651287,-2.600132,-0.727436,...,-14.523182,-15.670701,-15.098836,-14.750618,-14.601651,-15.032916,-14.629478,-14.076728,-14.983516,-15.088025
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,Sugar_400-499g_NonHalal_NonHealthy,-2.641503,0.462159,-0.569475,-2.081272,0.113894,1.860511,-1.228555,-1.344834,...,-15.142751,-14.929293,-15.060374,-15.736236,-15.091257,-15.487034,-15.558057,-15.340283,-14.700655,-15.038746
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,Sugar_400-499g_NonHalal_NonHealthy,-0.671336,0.410651,-1.233106,-0.897077,-0.460516,-1.281306,-1.350093,3.289432,...,-15.822778,-14.624924,-15.273829,-15.047397,-15.183998,-15.031076,-15.190045,-14.509043,-15.013265,-15.480076
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,Sugar_400-499g_NonHalal_NonHealthy,-0.981495,0.316365,-1.370192,-3.57127,2.038956,1.858516,-1.785913,-1.991065,...,-15.238906,-15.691112,-15.272874,-15.001783,-14.726097,-15.203532,-15.368335,-14.785217,-14.786104,-15.488359
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,Sugar_400-499g_NonHalal_NonHealthy,-2.975871,-0.488483,-1.954772,-1.458116,1.271758,-0.236836,-0.708023,0.353072,...,-14.739255,-15.0759,-14.665226,-15.199823,-15.588275,-14.399301,-14.750325,-15.193965,-14.934311,-14.80777


In [32]:
main_imgs_results_small_model.shape

(3457, 61)

## All scorings from small model

In [33]:
all_imgs_results_small_model = pd.concat([main_imgs_results_small_model, new_imgs_results_small_model], axis=0)
all_imgs_results_small_model.reset_index(drop=True, inplace=True)
all_imgs_results_small_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_CerealBeverage-powder,ProductType_Coffee,ProductType_HoneyOtherSpreads,ProductType_JennyBakery,ProductType_Nuts,...,ProductType_MaternalMilkPowder,ProductType_NutellaChocolate,ProductType_Babymilk-powder,ProductType_Sardines,ProductType_Canned-Packet-Creamers-Sweet,ProductType_Flour,Weight_1000-1999g,Weight_900-999g,ProductType_OtherBakingNeeds,ProductType_Tea-powder-leaves-
0,IMG_20230428_123528_jpg.rf.5687b7b914f6d9aa98c...,Sugar_400-499g_NonHalal_NonHealthy,-2.020259,0.163589,-1.813907,-3.735976,1.387246,-0.651287,-2.600132,-0.727436,...,-14.523182,-15.670701,-15.098836,-14.750618,-14.601651,-15.032916,-14.629478,-14.076728,-14.983516,-15.088025
1,IMG_20230428_123522_jpg.rf.204ff37f497f2dce442...,Sugar_400-499g_NonHalal_NonHealthy,-2.641503,0.462159,-0.569475,-2.081272,0.113894,1.860511,-1.228555,-1.344834,...,-15.142751,-14.929293,-15.060374,-15.736236,-15.091257,-15.487034,-15.558057,-15.340283,-14.700655,-15.038746
2,IMG_20230428_123708_jpg.rf.141ecd0cefaea75c0b7...,Sugar_400-499g_NonHalal_NonHealthy,-0.671336,0.410651,-1.233106,-0.897077,-0.460516,-1.281306,-1.350093,3.289432,...,-15.822778,-14.624924,-15.273829,-15.047397,-15.183998,-15.031076,-15.190045,-14.509043,-15.013265,-15.480076
3,IMG_20230428_123521_jpg.rf.1069b402272252862ec...,Sugar_400-499g_NonHalal_NonHealthy,-0.981495,0.316365,-1.370192,-3.57127,2.038956,1.858516,-1.785913,-1.991065,...,-15.238906,-15.691112,-15.272874,-15.001783,-14.726097,-15.203532,-15.368335,-14.785217,-14.786104,-15.488359
4,IMG_20230428_123659_jpg.rf.5e1b6c4caabe48cf360...,Sugar_400-499g_NonHalal_NonHealthy,-2.975871,-0.488483,-1.954772,-1.458116,1.271758,-0.236836,-0.708023,0.353072,...,-14.739255,-15.0759,-14.665226,-15.199823,-15.588275,-14.399301,-14.750325,-15.193965,-14.934311,-14.80777


In [34]:
all_imgs_results_small_model.tail()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_CerealBeverage-powder,ProductType_Coffee,ProductType_HoneyOtherSpreads,ProductType_JennyBakery,ProductType_Nuts,...,ProductType_MaternalMilkPowder,ProductType_NutellaChocolate,ProductType_Babymilk-powder,ProductType_Sardines,ProductType_Canned-Packet-Creamers-Sweet,ProductType_Flour,Weight_1000-1999g,Weight_900-999g,ProductType_OtherBakingNeeds,ProductType_Tea-powder-leaves-
3462,5181704785427_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.171966,-1.259724,-1.661261,-0.965875,-1.458263,-1.110078,5.579266,-1.103462,...,-15.15211,-14.441778,-14.739814,-14.631389,-15.041283,-14.976716,-15.46541,-14.862346,-15.080062,-15.348622
3463,5201704785430_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-2.130002,-2.354889,-0.707954,-1.655996,-1.911104,-0.936194,7.391262,-1.22591,...,-15.017373,-15.140136,-15.010095,-15.213837,-15.381639,-15.159058,-15.59905,-14.750674,-15.250578,-14.836721
3464,5221704785433_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.16544,-1.391705,-1.73015,-1.116981,-1.299423,-1.351957,5.95632,-1.349616,...,-14.711848,-14.72053,-14.777285,-14.464187,-14.804071,-14.788373,-14.590291,-15.10588,-14.973858,-14.922586
3465,5141704785419_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-0.644989,-1.19498,-1.084898,-1.202564,-1.92269,-0.043557,5.377807,-0.971059,...,-14.803693,-14.678619,-14.271427,-14.738265,-14.970498,-14.701385,-15.078151,-14.74308,-14.718485,-15.364645
3466,5211704785432_.pic.jpg,JennyBakery_500-599g_NonHalal_NonHealthy,-1.348141,-1.360862,-1.819118,-1.04397,-1.776025,-1.458479,6.231342,-1.379669,...,-15.007516,-14.80348,-15.406728,-14.627572,-14.975432,-14.692928,-14.768497,-15.050787,-15.197558,-14.478196


In [35]:
all_imgs_results_small_model.shape

(3467, 61)

In [36]:
all_imgs_results_small_model.to_csv("all_imgs_results_small_model.csv", index=True)

# Bayesian model

In [37]:
prdtype_cols = [col for col in all_imgs_results_small_model.columns if col.startswith('ProductType_')]

In [36]:
all_imgs_results_small_model_prdtype = all_imgs_results_small_model[['label']+prdtype_cols]
all_imgs_results_big_model_prdtype = all_imgs_results_big_model[['label']+prdtype_cols]
all_imgs_results_small_model_prdtype = all_imgs_results_small_model_prdtype.sort_values(by='label').reset_index(drop=True)
all_imgs_results_big_model_prdtype = all_imgs_results_big_model_prdtype.sort_values(by='label').reset_index(drop=True)

In [37]:
assert (all_imgs_results_small_model_prdtype['label'][(all_imgs_results_small_model_prdtype['label'] == all_imgs_results_big_model_prdtype['label'])]).all()

In [38]:
all_imgs_results_small_model_prdtype['label_prdtype'] = all_imgs_results_small_model_prdtype['label'].str.split('_').str[0]
all_imgs_results_big_model_prdtype['label_prdtype'] = all_imgs_results_big_model_prdtype['label'].str.split('_').str[0]

In [39]:
# Remove the prefix from column names
all_imgs_results_small_model_prdtype.columns = [col.replace("ProductType_", '') if col.startswith("ProductType_") else col for col in all_imgs_results_small_model_prdtype.columns]
all_imgs_results_big_model_prdtype.columns = [col.replace("ProductType_", '') if col.startswith("ProductType_") else col for col in all_imgs_results_big_model_prdtype.columns]

In [40]:
prdtype_label_encoder = LabelEncoder()
truelabel = prdtype_label_encoder.fit_transform(all_imgs_results_big_model_prdtype['label_prdtype'])

In [41]:
# Assuming 'category_names' is the list of unique category names in the order they appear in logitscoresA
category_names = list(all_imgs_results_small_model_prdtype['label_prdtype'].unique())
category_to_encoded = {name: prdtype_label_encoder.transform([name])[0] for name in category_names}

# Reorder columns of logitscoresA and logitscoresB to match the order of encoded labels
ordered_columns = [category_names[i] for i in prdtype_label_encoder.transform(category_names)]
logitscoresA = all_imgs_results_big_model_prdtype[ordered_columns].values
logitscoresB = all_imgs_results_small_model_prdtype[ordered_columns].values


In [42]:
all_imgs_results_big_model.head()

Unnamed: 0,filepath,label,ProductType_AdultMilk,ProductType_Babyfood,ProductType_Babymilk-powder,ProductType_BeehoonVermicelliMeesua,ProductType_BiscuitsCrackersCookies,ProductType_Book,ProductType_Breakfast-cereals-cornflakes,ProductType_Canned-Packet-Creamers-Sweet,...,Weight_500-599g,Weight_600-699g,Weight_700-799g,Weight_800-899g,Weight_900-999g,HalalStatus_Halal,HalalStatus_NonHalal,HealthStatus_Healthy,HealthStatus_NonHealthy,ProductType_JennyBakery
0,20231215_output_frame_0244,PotatoSticks_1-99g_Halal_NonHealthy,-1.587444,-2.496907,-3.559909,-2.273423,1.671694,-0.49073,-3.641042,-2.682706,...,-1.938856,-3.626286,-2.632926,-4.078401,-3.550343,2.026452,-3.092768,-3.751043,2.996024,-19.986669
1,output_frame_0617,Kaya_400-499g_Halal_NonHealthy,-4.127325,-1.5334,-1.689596,-1.041802,-1.024461,-4.602612,-2.558664,0.07872,...,-1.308213,-4.263793,-2.572152,-3.962276,-3.851224,4.229371,-4.268166,-4.517332,4.193677,-20.119013
2,IMG_5422_jpeg.rf.40db706026c5805998a4ef32aaea01bd,BiscuitsCrackersCookies_100-199g_Halal_NonHealthy,-3.575266,-1.385911,-2.772634,-0.931874,7.930703,-4.527698,-3.865532,-3.374679,...,-3.05643,-2.557728,-2.868939,-1.775112,-3.771607,2.320258,-3.072567,-3.815917,3.348428,-19.627854
3,IMG_9829_JPG.rf.9dcff8b32299bc2d7285db157c38e6d6,OtherNoodles_500-599g_NonHalal_NonHealthy,-3.725879,-4.687295,-3.840399,-1.293027,-0.585406,-5.36357,-4.092983,-4.843928,...,4.860765,-1.20123,-0.325409,-2.343176,-3.91664,-2.494664,2.036901,-4.395383,3.689328,-20.315636
4,2023_10_25_11_52_33_260721,MaternalMilkPowder_600-699g_Halal_NonHealthy,2.656255,0.116325,2.72413,-3.962348,-1.72144,-6.606655,-3.900061,-5.223273,...,-1.551353,7.838285,-1.78831,-0.916579,-1.628614,1.99219,-3.352175,-3.212959,2.271578,-20.562292


In [43]:
# big model accuracy - total
pred_big_model_prdtype = np.argmax(logitscoresA, axis=1)
sum(pred_big_model_prdtype == truelabel) / len(truelabel)

0.9798096336890684

In [44]:
# small model accuracy - total
pred_small_model_prdtype = np.argmax(logitscoresB, axis=1)
sum(pred_small_model_prdtype == truelabel) / len(truelabel)

0.23853475627343523

In [45]:
# big model accuracy - new imgs
indices = np.where(truelabel == category_to_encoded['JennyBakery'])
sum(pred_big_model_prdtype[indices] == truelabel[indices]) / len(indices[0].tolist())

0.0

In [46]:
# small model accuracy - new imgs
indices = np.where(truelabel == category_to_encoded['JennyBakery'])
sum(pred_small_model_prdtype[indices] == truelabel[indices]) / len(indices[0].tolist())

1.0

In [47]:
len(truelabel)

3467

In [48]:
logitscoresA.shape[1]

43

In [49]:
import pymc3 as pm
import theano.tensor as tt
import numpy as np
import scipy.stats

# Sample data setup (replace with your actual data)
# logitscoresA and logitscoresB are matrices of logit scores for each category from classifiers A and B
# truelabel is an already existing 1D array of integers representing the true labels
indices = [np.random.choice(100, 3, replace=False)]  # Replace with your indices for missing data

N = len(truelabel)
L = logitscoresA.shape[1]
missingidx = indices[0].tolist()  # Indices of missing data

# Initialize truelabel_with_missing with the original truelabel and set missing indices to -1
truelabel_with_missing = np.array(truelabel, dtype=np.int)
truelabel_with_missing[missingidx] = -1

# Mask the missing values
masked_truelabel = np.ma.masked_where(truelabel_with_missing == -1, truelabel_with_missing)

with pm.Model() as model:
    # Priors
    muA1 = pm.Normal('muA1', mu=0, sigma=10)
    muA0 = pm.Normal('muA0', mu=0, sigma=10)
    sigmaA = pm.Uniform('sigmaA', lower=0.01, upper=1.0)
    muB1 = pm.Normal('muB1', mu=0, sigma=10)
    muB0 = pm.Normal('muB0', mu=0, sigma=10)
    sigmaB = pm.Uniform('sigmaB', lower=0.01, upper=1.0)
    rho = pm.Uniform('rho', lower=-1, upper=1)
    
    # Uniform prior over labels
    labelprob = pm.Dirichlet('labelprob', a=tt.ones(L))

    # Likelihood
    muA = pm.math.switch(tt.eq(tt.arange(L), masked_truelabel[:, None]), muA1, muA0)
    muB = pm.math.switch(tt.eq(tt.arange(L), masked_truelabel[:, None]), muB1, muB0)
    
    logitscoresA_obs = pm.Normal('logitscoresA_obs', mu=muA, sigma=sigmaA, observed=logitscoresA)
    logitscoresB_obs = pm.Normal('logitscoresB_obs', mu=muB + rho * (logitscoresA - muA) / sigmaA, sigma=tt.sqrt((1 - rho ** 2) * sigmaB ** 2), observed=logitscoresB)
    
    # Define the categorical distribution for the true labels
    truelabel_obs = pm.Categorical('truelabel_obs', p=labelprob, observed=masked_truelabel)

    # Inference
    trace = pm.sample(2000, tune=500, cores=1)

    # Plotting within the model context
    # az.plot_trace(trace)
    # plt.show()

    # Posterior predictive checks
    ppc = pm.sample_posterior_predictive(trace, var_names=['truelabel_obs'])

# Process the posterior predictive checks for missing indices
infer_labels = []
for idx in missingidx:
    label_samples = ppc['truelabel_obs'][:, idx]
    inferred_label = scipy.stats.mode(label_samples).mode[0]
    infer_labels.append(inferred_label)

# Output the inferred labels for missing indices
print("Inferred labels for missing indices:", infer_labels)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  app.launch_new_instance()
  return wrapped_(*args_, **kwargs_)
Sequential sampling (2 chains in 1 job)
CompoundStep
>NUTS: [labelprob, rho, sigmaB, muB0, muB1, sigmaA, muA0, muA1]
>CategoricalGibbsMetropolis: [truelabel_obs_missing]


  return _boost._beta_ppf(q, a, b)
Sampling 1 chain for 500 tune and 371 draw iterations (500 + 371 draws total) took 5946 seconds.
The acceptance probability does not match the target. It is 0.9999996489687598, but should be close to 0.8. Try to increase the number of tuning steps.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
Only one chain was sampled, this makes it impossible to run some convergence checks


Inferred labels for missing indices: [8, 10, 42]


In [50]:
np.unique(infer_labels)

array([ 8, 10, 42])

In [51]:
~np.isin(np.arange(N), missingidx)

array([ True,  True,  True, ...,  True,  True,  True])

In [52]:
np.array(truelabel, dtype=np.int)[missingidx]

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  """Entry point for launching an IPython kernel.


array([0, 0, 1])

In [53]:
pred_small_model_prdtype[missingidx]

array([ 0, 11,  0])

In [54]:
pred_big_model_prdtype[missingidx]

array([0, 0, 1])

In [55]:
missingidx

[53, 24, 66]

In [56]:
ppc['truelabel_obs'].shape

(371, 3467)