# ImageNet 1K Data Quality


In [1]:
import lance
import duckdb
import torchvision
import torch
import pandas as pd

In [2]:
%load_ext sql
%sql duckdb:///:memory:

{}


In [3]:
uri = "imagenet.lance"

ds = lance.dataset(uri)

In [4]:
ds.schema

id: int32
image: extension<image[binary]<ImageBinaryType>>
label: int16
name: dictionary<values=string, indices=int16, ordered=0>
split: dictionary<values=string, indices=int8, ordered=0>

In [6]:
# %%sql --lance

# SELECT image, label FROM ds LIMIT 5

In [7]:
%%sql

SELECT split, count(split) FROM ds GROUP BY split

Took 0.006806373596191406


Unnamed: 0,split,count(split)
0,train,10000
1,test,10000
2,validation,10000


# Use two official pre-trained models ResNet and VisionTransform

* The ResNet model is based on the [Deep Residual Learning for Image Recognition paper](https://arxiv.org/abs/1512.03385)
* The VisionTransformer model is based on the [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale paper](https://arxiv.org/abs/2010.11929).

In [52]:
from torchvision.models import resnet50, vit_b_16

resnet = resnet50(weights="DEFAULT").cuda()
vit = vit_b_16(weights="DEFAULT").cuda()

In [53]:
# TODO: make easy conversion between lance.Dataset and pytorch.Dataset

from lance.pytorch import Dataset
dataset = Dataset(uri, columns=["id", "image"], mode="batch")

In [22]:
transform = torchvision.models.ResNet50_Weights.DEFAULT.transforms()

def run_inference(uri: str, model, transform, col_name: str) -> pa.Table:
    dataset = Dataset(
        uri, 
        columns=["id", "image"],
        mode="batch",
        batch_size=128
    )
    results = []
    with torch.no_grad():
        model.eval()
        
        for batch in dataset:
            imgs = [transform(img).cuda() for img in batch[1]]
            prediction = resnet(torch.stack(imgs)).squeeze(0).softmax(0)
            topk = torch.topk(prediction, 2)
            for pk, scores, indices in zip(
                batch[0], topk.values.tolist(), topk.indices.tolist()
            ):
                results.append({
                    "id": pk.item(),
                    col_name: {
                        "label": indices[0], 
                        "score": scores[0], 
                        "second_label": indices[1],  # Secondary guess
                        "second_score": scores[1],  # Confidence of the secondary guess.
                    }
                })
    df = pd.DataFrame(data=results)
    df = df.astype({"id": "int32"})
    return df

resnet_df = run_inference(uri, resnet, torchvision.models.ResNet50_Weights.DEFAULT.transforms(), "resnet")
vit_df = run_inference(uri, vit, torchvision.models.ViT_L_16_Weights.DEFAULT.transforms(), "vit")

display(resnet_df)
display(vit_df)

Unnamed: 0,id,resnet
0,1,"{'label': 726, 'score': 0.694125771522522, 'se..."
1,2,"{'label': 917, 'score': 0.8421049118041992, 's..."
2,3,"{'label': 13, 'score': 0.8555712103843689, 'se..."
3,4,"{'label': 939, 'score': 0.9331550002098083, 's..."
4,5,"{'label': 6, 'score': 0.9336700439453125, 'sec..."
...,...,...
29995,29996,"{'label': 658, 'score': 0.9771027565002441, 's..."
29996,29997,"{'label': 433, 'score': 0.9925998449325562, 's..."
29997,29998,"{'label': 47, 'score': 0.9799998998641968, 'se..."
29998,29999,"{'label': 801, 'score': 0.9724040627479553, 's..."


Unnamed: 0,id,vit
0,1,"{'label': 726, 'score': 0.6790212392807007, 's..."
1,2,"{'label': 917, 'score': 0.8430271148681641, 's..."
2,3,"{'label': 13, 'score': 0.8379276990890503, 'se..."
3,4,"{'label': 939, 'score': 0.9347924590110779, 's..."
4,5,"{'label': 6, 'score': 0.9419549703598022, 'sec..."
...,...,...
29995,29996,"{'label': 658, 'score': 0.9710634350776672, 's..."
29996,29997,"{'label': 433, 'score': 0.9846142530441284, 's..."
29997,29998,"{'label': 47, 'score': 0.9774839878082275, 'se..."
29998,29999,"{'label': 801, 'score': 0.9822220802307129, 's..."


In [23]:
import pyarrow as pa
resnet_table = pa.Table.from_pandas(resnet_df)
vit_table = pa.Table.from_pandas(vit_df)


In [24]:
ds = ds.merge(resnet_table, left_on="id", right_on="id")
ds = ds.merge(vit_table, left_on="id", right_on="id")

In [26]:
ds = lance.dataset(uri)

In [28]:
%%sql

SELECT * EXCLUDE(image)
FROM ds 
WHERE 
  label != resnet.label AND split != 'test'
LIMIT 20


Took 0.00791788101196289


Unnamed: 0,id,label,name,split,resnet,vit
0,13,575,"golfcart, golf cart",train,"{'label': 785, 'score': 0.3907656967639923, 's...","{'label': 785, 'score': 0.35822126269340515, '..."
1,16,219,"cocker spaniel, English cocker spaniel, cocker",train,"{'label': 213, 'score': 0.8196388483047485, 's...","{'label': 220, 'score': 0.6967827677726746, 's..."
2,45,968,cup,train,"{'label': 899, 'score': 0.4809149503707886, 's...","{'label': 529, 'score': 0.6069347858428955, 's..."
3,52,304,"leaf beetle, chrysomelid",train,"{'label': 305, 'score': 0.8165855407714844, 's...","{'label': 305, 'score': 0.783991813659668, 'se..."
4,54,940,spaghetti squash,train,"{'label': 936, 'score': 0.70237797498703, 'sec...","{'label': 936, 'score': 0.7519966959953308, 's..."
5,60,67,"diamondback, diamondback rattlesnake, Crotalus...",train,"{'label': 68, 'score': 0.8087248802185059, 'se...","{'label': 68, 'score': 0.8199281096458435, 'se..."
6,71,459,"brassiere, bra, bandeau",train,"{'label': 824, 'score': 0.8276260495185852, 's...","{'label': 824, 'score': 0.895666778087616, 'se..."
7,79,782,"screen, CRT screen",train,"{'label': 664, 'score': 0.8757317066192627, 's...","{'label': 664, 'score': 0.8014572858810425, 's..."
8,86,256,"Newfoundland, Newfoundland dog",train,"{'label': 205, 'score': 0.6166163086891174, 's...","{'label': 208, 'score': 0.601875901222229, 'se..."
9,105,371,"patas, hussar monkey, Erythrocebus patas",train,"{'label': 370, 'score': 0.8230134844779968, 's...","{'label': 370, 'score': 0.8000330328941345, 's..."


# Find Potential Mislabels

If the two models strongly agree with each other (i.e., same label and confience score is high), however, the predict label is not what ground truth describes.


In [51]:
%%sql

WITH label_names AS (SELECT DISTINCT label, name FROM ds)

SELECT ds.name AS gt, ds.label as gt_label,
  resnet.label as resnet_label,
  vit.label as vit_label,
  label_names.name as predict_name,
  resnet.score as predict_score
FROM ds, label_names 
WHERE
  split != 'test'
  AND ds.label !=  resnet.label 
  AND resnet.label == vit.label
  AND resnet.label = label_map.label
ORDER BY resnet.score DESC
LIMIT 20

Took 0.017688751220703125


Unnamed: 0,gt,gt_label,resnet_label,vit_label,predict_name,predict_score
0,holster,597,465,465,bulletproof vest,0.992044
1,crate,519,950,950,orange,0.981238
2,stretcher,830,407,407,ambulance,0.980197
3,"harmonica, mouth organ, harp, mouth harp",593,684,684,"ocarina, sweet potato",0.979551
4,table lamp,846,619,619,"lampshade, lamp shade",0.979382
5,goose,99,921,921,"book jacket, dust cover, dust jacket, dust wra...",0.978036
6,jigsaw puzzle,611,207,207,golden retriever,0.977936
7,plate rack,729,534,534,"dishwasher, dish washer, dishwashing machine",0.976898
8,"black stork, Ciconia nigra",128,23,23,vulture,0.97683
9,"passenger car, coach, carriage",705,820,820,steam locomotive,0.97505


The reverse order of the above query (`ORDER BY score ASC`) is also very informative, 
as it shows where the worst detections cross different models.

In [56]:
%%sql

WITH label_names AS (SELECT DISTINCT label, name FROM ds)

SELECT ds.name AS gt, ds.label as gt_label,
  resnet.label as resnet_label,
  vit.label as vit_label,
  label_names.name as predict_name,
  resnet.score as resnet_score,
  vit.score as vit_score
FROM ds, label_names 
WHERE
  split != 'test'
  AND ds.label !=  resnet.label 
  AND resnet.label == vit.label
  AND resnet.label = label_names.label
ORDER BY resnet.score ASC
LIMIT 20

Took 0.017398595809936523


Unnamed: 0,gt,gt_label,resnet_label,vit_label,predict_name,resnet_score,vit_score
0,"bobsled, bobsleigh, bob",450,670,670,"motor scooter, scooter",0.122233,0.164919
1,conch,112,113,113,snail,0.126517,0.213129
2,military uniform,652,860,860,"tobacco shop, tobacconist shop, tobacconist",0.134379,0.318211
3,"scale, weighing machine",778,826,826,"stopwatch, stop watch",0.138554,0.164369
4,"goldfinch, Carduelis carduelis",11,325,325,"sulphur butterfly, sulfur butterfly",0.14466,0.168432
5,"corkscrew, bottle screw",512,596,596,hatchet,0.148554,0.170637
6,bath towel,434,793,793,shower cap,0.156855,0.120169
7,Shih-Tzu,155,219,219,"cocker spaniel, English cocker spaniel, cocker",0.162032,0.188047
8,red wine,966,621,621,"lawn mower, mower",0.171566,0.184307
9,dumbbell,543,422,422,barbell,0.173661,0.178519
