In [31]:
import gc
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import pandas as pd
from efficientnet_pytorch import EfficientNet
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import DataLoader
from tqdm import tqdm

from func import ShopeeDataset, device, ShopeeNet, torch, get_train_transforms, np, f1_score_cal

In [2]:
# Dataset root (train.csv + train_images/) inside the user's OneDrive folder.
path = Path.home().joinpath('OneDrive - Seagroup', 'computer_vison', 'shopee_item_images')
path_img = path.joinpath('train_images')

df = pd.read_csv(path.joinpath('train.csv'))
# Absolute path (as a string) of each listing's image, for the Dataset to open.
df["filepath"] = df["image"].apply(lambda image_name: str(path_img.joinpath(image_name)))
df.head()

Unnamed: 0,posting_id,image,image_phash,title,label_group,filepath
0,train_129225211,0000a68812bc7e98c42888dfb1c07da0.jpg,94974f937d4c2433,Paper Bag Victoria Secret,249114794,/Users/kevin/OneDrive - Seagroup/computer_viso...
1,train_3386243561,00039780dfc94d01db8676fe789ecd05.jpg,af3f9460c2838f0f,"Double Tape 3M VHB 12 mm x 4,5 m ORIGINAL / DO...",2937985045,/Users/kevin/OneDrive - Seagroup/computer_viso...
2,train_2288590299,000a190fdd715a2a36faed16e2c65df7.jpg,b94cb00ed3e50f78,Maling TTS Canned Pork Luncheon Meat 397 gr,2395904891,/Users/kevin/OneDrive - Seagroup/computer_viso...
3,train_2406599165,00117e4fc239b1b641ff08340b429633.jpg,8514fc58eafea283,Daster Batik Lengan pendek - Motif Acak / Camp...,4093212188,/Users/kevin/OneDrive - Seagroup/computer_viso...
4,train_3369186413,00136d1cf4edede0203f32f05f660588.jpg,a6f319f924ad708c,Nescafe \xc3\x89clair Latte 220ml,3648931069,/Users/kevin/OneDrive - Seagroup/computer_viso...


In [30]:
# Wrap the training frame in the project Dataset; the loader is not shuffled, so batches follow df's row order.
dataset_data = ShopeeDataset(csv=df, train=True)
data_loader = DataLoader(dataset_data, batch_size=16)

n_items = len(dataset_data)
first_shape = dataset_data[0][0].shape
print(f"Dataset Len: {n_items:,}\nImage Shape [0]: {first_shape}")

Dataset Len: 34,250
Image Shape [0]: torch.Size([3, 256, 256])


In [54]:
# Baseline: run every training image through an EfficientNet-B3 and collect the outputs.
# NOTE(review): `EfficientNet.from_name` builds the architecture with randomly initialised
# weights; `from_pretrained` is presumably what was intended — confirm.
# NOTE(review): in efficientnet_pytorch, calling the model runs the final FC layer too, so
# these are presumably class logits rather than pooled features. `all_image_embeddings`
# is also not used by the later kNN cell.
model_effnet = EfficientNet.from_name("efficientnet-b3").to(device)
# Inference mode: otherwise BatchNorm uses per-batch statistics and dropout stays active,
# making the embeddings depend on batch composition and differ between runs.
model_effnet.eval()

embeddings = []
with torch.no_grad():
    for image, _ in tqdm(data_loader):
        image = image.to(device)
        # no_grad already disables autograd, so no .detach() is needed before .cpu().
        embeddings.append(model_effnet(image).cpu().numpy())

# Concatenate all embeddings
all_image_embeddings = np.concatenate(embeddings)
print("image_embeddings shape: {:,}/{:,}".format(all_image_embeddings.shape[0], all_image_embeddings.shape[1]))

# Release the backbone before the next model is built.
del model_effnet
_ = gc.collect()

In [55]:
# NOTE(review): `tmp` is not defined in any cell of this notebook, so this cell only works
# with leftover kernel state and raises NameError after Restart & Run All.
# Define `tmp` explicitly in this cell, or delete this debugging cell.
tmp.shape

(512, 512, 3)

In [12]:
# Embed every training image with the ArcFace-trained ShopeeNet (EfficientNet-B3 backbone).
CLASSES = 11014
BATCH_SIZE = 32
dim = (512, 512)
model_name = 'efficientnet_b3'
model_params = {
    'n_classes': CLASSES,
    'model_name': model_name,
    'use_fc': False,
    'fc_dim': 512,
    'dropout': 0.0,
    'loss_module': 'arcface',
    's': 30.0,
    'margin': 0.50,
    'ls_eps': 0.0,
    'theta_zero': 0.785,
    'pretrained': True
}

model = ShopeeNet(**model_params)
model = model.to(device)
model.eval()

# Embeddings must be row-aligned with df: the kNN cell maps neighbour indices back through
# df['posting_id'].iloc[...]. NOTE(review): assumes df["filepath"] is the intended path list.
image_paths = df["filepath"]
# NOTE(review): training transforms at inference may include random augmentation,
# which would make embeddings non-deterministic — confirm func offers a validation transform.
image_dataset = ShopeeDataset(image_paths=image_paths.values, transforms=get_train_transforms(dim))
image_loader = DataLoader(image_dataset,
                          batch_size=BATCH_SIZE,
                          pin_memory=True,
                          num_workers=4)

embeds = []
with torch.no_grad():
    for img, label in tqdm(image_loader):
        img = img.to(device)
        label = label.to(device)
        # The ArcFace head is called with labels; the first output (`feat`) is kept as the embedding.
        feat, _ = model(img, label)
        # no_grad already disables autograd, so no .detach() is needed.
        embeds.append(feat.cpu().numpy())
image_embeddings = np.concatenate(embeds)

# Free the network now that the embeddings are on the host.
del model
_ = gc.collect()

Model building for efficientnet_b3 backbone


100%|██████████████████████████████████████████████████████████████████████████████| 1071/1071 [03:55<00:00,  4.54it/s]


In [4]:
# Find each listing's nearest neighbours in embedding space (scikit-learn default metric: Euclidean).
# A distinct name keeps the kNN index from being confused with the ShopeeNet `model`.
N_NEIGHBORS = 50
# kneighbors raises if n_neighbors exceeds the number of fitted samples, so clamp it.
knn_model = NearestNeighbors(n_neighbors=min(N_NEIGHBORS, len(image_embeddings)))
knn_model.fit(image_embeddings)
distances, indices = knn_model.kneighbors(image_embeddings)

In [6]:
# Sweep distance thresholds and keep the one that maximises mean F1 on the training set.
# Run interactively; once a good value is known, it can be hard-coded as `best_threshold`.
# NOTE(review): in the recorded output every threshold up to ~17 gives the same score, which
# suggests most neighbour distances lie above this range — check the scale of `distances`.

if "matches" not in df.columns:
    # Ground truth: all posting_ids sharing the row's label_group, space-separated.
    # NOTE(review): assumes f1_score_cal expects space-joined posting_id strings (the same
    # format as pred_matches below) — confirm against func.f1_score_cal.
    df["matches"] = df.groupby("label_group")["posting_id"].transform(lambda ids: " ".join(ids))

posting_id_values = df["posting_id"].to_numpy()


def _neighbour_posting_ids(threshold):
    """Return one array per row with the posting_ids of its k-NN closer than `threshold`."""
    within = distances < threshold
    return [posting_id_values[indices[k, within[k]]] for k in range(image_embeddings.shape[0])]


thresholds = list(np.arange(0.01, 35, 1))
scores = []
for threshold in thresholds:
    df['pred_matches'] = [' '.join(ids) for ids in _neighbour_posting_ids(threshold)]
    df['f1'] = f1_score_cal(df['matches'], df['pred_matches'])
    score = df['f1'].mean()
    print(f'Our f1 score for threshold {threshold} is {score}')
    scores.append(score)

thresholds_scores = pd.DataFrame({'thresholds': thresholds, 'scores': scores})
# idxmax picks the first row holding the maximum score (first-best tie-break).
best_row = thresholds_scores.loc[thresholds_scores['scores'].idxmax()]
best_threshold = best_row['thresholds']
best_score = best_row['scores']
print(f'Our best score is {best_score} and has a threshold {best_threshold}')

# Final image-based predictions at the chosen threshold: one array of posting_ids per row.
predictions = _neighbour_posting_ids(best_threshold)

Our f1 score for threshold 0.01 is 0.045308828548734095
Our f1 score for threshold 1.01 is 0.045308828548734095
Our f1 score for threshold 2.01 is 0.045308828548734095
Our f1 score for threshold 3.01 is 0.045308828548734095
Our f1 score for threshold 4.01 is 0.045308828548734095
Our f1 score for threshold 5.01 is 0.045308828548734095
Our f1 score for threshold 6.01 is 0.045308828548734095
Our f1 score for threshold 7.01 is 0.045308828548734095
Our f1 score for threshold 8.01 is 0.045308828548734095
Our f1 score for threshold 9.01 is 0.045308828548734095
Our f1 score for threshold 10.01 is 0.045308828548734095
Our f1 score for threshold 11.01 is 0.045308828548734095
Our f1 score for threshold 12.01 is 0.045308828548734095
Our f1 score for threshold 13.01 is 0.045308828548734095
Our f1 score for threshold 14.01 is 0.045308828548734095
Our f1 score for threshold 15.01 is 0.045308828548734095
Our f1 score for threshold 16.01 is 0.045308828548734095
Our f1 score for threshold 17.01 is 0.045

In [7]:
# Attach the image-based neighbour predictions (one posting_id array per row) to df.
image_predictions = predictions
df['image_predictions'] = image_predictions

In [10]:
# Persist the frame, including its prediction columns, in Feather format.
# NOTE(review): despite the file name, df here is built from train.csv.
feather_path = 'test.ftr'
df.to_feather(feather_path)

In [13]:
# Return cached, unused GPU memory held by PyTorch's allocator (nothing to free without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()