In [16]:
import cv2
import glob
import numpy as np
import pandas as pd

import timm 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import Dataset, DataLoader

from collections import Counter
from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from torch.utils.data import Dataset, DataLoader

# Paths of all unlabeled negative patches found on disk.
flist = glob.glob("data/unlabel_negative_patch_2048/*.jpg")

# Preprocessing applied to every patch before it reaches the model:
# resize to 224x224, normalize with Albumentations' default statistics,
# then convert to a tensor.
_preprocess_steps = [
    A.Resize(height=224, width=224, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2(),
]
transforms = A.Compose(_preprocess_steps)

In [17]:
model_name = "resnet18"
data_util_percentage = 100

# Rebuild the weak classifier and restore its trained weights for inference.
weights_path = f"weights/{model_name}_{str(data_util_percentage)}%_weak-model/weak_model.pt"
model = timm.create_model(model_name, pretrained=False, num_classes=2).to("cuda:0")
# map_location makes loading independent of the device the checkpoint was saved from.
model.load_state_dict(torch.load(weights_path, map_location="cuda:0"))
model.eval()

# Reference patches: the false positives produced by the weak model.
ref_patches = glob.glob(f"results/false_positive_{model_name}_{str(data_util_percentage)}%_weak-model/*.png")
if not ref_patches:
    # torch.cat on an empty list would raise a far less helpful RuntimeError below.
    raise RuntimeError(
        f"No reference patches found for {model_name} at {data_util_percentage}% data utilisation"
    )

ref_feats = []

# Inference only: no_grad avoids building an autograd graph for every forward pass.
with torch.no_grad():
    for patch in ref_patches:
        # NOTE(review): cv2.imread yields BGR channel order -- confirm the weak
        # model was trained on images loaded the same way.
        query_image = cv2.imread(patch)
        if query_image is None:
            # cv2.imread signals an unreadable/missing file by returning None.
            raise OSError(f"cv2.imread could not read reference patch: {patch}")
        query_image = transforms(image=query_image)['image']
        query_feat = model.forward_features(query_image.unsqueeze(0).to("cuda:0"))
        # Flatten the feature map to one row per reference patch.
        ref_feats.append(query_feat.reshape(1, -1).detach())

# Stack into a single (num_reference_patches, feature_dim) tensor.
ref_feats = torch.cat(ref_feats, dim=0)

In [18]:
class PatchDataset(Dataset):
    """Dataset over image file paths that yields ``(file name, transformed image)``.

    Args:
        df: Indexable sequence of image file paths; each item is read with
            ``self.df[idx]`` and passed to ``cv2.imread``.
        transform: Callable invoked as ``transform(image=array)`` that returns
            a mapping holding the processed image under the ``'image'`` key.
    """

    def __init__(self, df, transform):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load and transform the image at position ``idx``.

        Returns:
            A ``(basename, image)`` tuple: the file name without its directory
            and the output of ``transform`` for that image.

        Raises:
            OSError: If ``cv2.imread`` cannot read the file.
        """
        # Function-scope import keeps the class usable without relying on a
        # module-level ``os`` import.
        import os

        fname = self.df[idx]
        x = cv2.imread(fname)
        if x is None:
            # cv2.imread returns None instead of raising on a missing or
            # corrupt file; fail here with the offending path rather than
            # with an opaque error inside the transform.
            raise OSError(f"cv2.imread could not read image: {fname}")
        x = self.transform(image=x)['image']

        # basename also strips directories from paths using the platform separator.
        return os.path.basename(fname), x
    
batch_size = 32

# Candidate pool: every unlabeled negative patch on disk.
candidates = glob.glob("data/unlabel_negative_patch_2048/*.jpg")
cand_dataset = PatchDataset(df=candidates, transform=transforms)

# NOTE: drop_last=True discards the final partial batch, so up to
# batch_size - 1 candidates are skipped on each pass over the loader, and
# shuffle=True changes which candidates those are from pass to pass.
cand_dataloader = DataLoader(
    cand_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=32,
)

In [19]:
from tqdm.notebook import tqdm

# For every reference feature, find the TOP_K candidate patches whose features
# are closest in L2 distance, then collect the union of their file names.
TOP_K = 20  # number of nearest candidates kept per reference patch

pairwise_l2 = nn.PairwiseDistance(p=2)
cand_fnames = []
# One distance list per reference feature, kept index-aligned with cand_fnames.
dists_per_ref = [[] for _ in range(len(ref_feats))]

# Candidate features do not depend on the reference patch, so each batch is
# pushed through the model once and compared against every reference, instead
# of re-running the whole loader for each reference.
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for fnames, images in tqdm(cand_dataloader, desc="Query Images"):
        try:
            cand_feat = model.forward_features(images.to("cuda:0"))
            # Use the real batch length so a partial batch is handled as well.
            cand_feat = cand_feat.reshape(images.shape[0], -1)
            batch_dists = [
                pairwise_l2(ref_feat, cand_feat).cpu().tolist()
                for ref_feat in ref_feats
            ]
        except Exception as e:
            # Best effort: a failing batch is reported and left out of the
            # ranking rather than silently given a sentinel distance.
            print(f"Skipping a batch of {len(fnames)} candidates: {e!r}")
            continue

        # Commit only after the whole batch succeeded so all lists stay aligned.
        cand_fnames.extend(fnames)
        for ref_dists, new_dists in zip(dists_per_ref, batch_dists):
            ref_dists.extend(new_dists)

query_images = []
for ref_dists in tqdm(dists_per_ref, desc="Ref Image"):
    line_result = pd.DataFrame({"fname": cand_fnames, "dist": ref_dists})
    nearest = line_result.sort_values("dist", ascending=True).head(TOP_K)
    query_images.extend(nearest.fname.values.tolist())

# De-duplicate: the same candidate can be close to several reference patches.
query_images = np.unique(query_images).tolist()

Ref Image:   0%|          | 0/12 [00:00<?, ?it/s]

Query Images:   0%|          | 0/46 [00:00<?, ?it/s]

In [20]:
import os
import shutil

# Copy every selected candidate patch into a per-model results folder.
results_folder_fname = f"results/aug-data_{model_name}_{str(data_util_percentage)}%"
# makedirs(exist_ok=True) also creates missing parent folders and is a no-op
# when the folder already exists.
os.makedirs(results_folder_fname, exist_ok=True)

source_dir = "data/unlabel_negative_patch_2048"
for f in query_images:
    # Reduce each entry to its bare file name before building both paths.
    patch_name = os.path.basename(f)
    shutil.copyfile(os.path.join(source_dir, patch_name),
                    os.path.join(results_folder_fname, patch_name))