# Analyze ImageNet-1K Data Quality and Model Performance



In [1]:
import lance
import duckdb
import torchvision
import torch
import pandas as pd
import pyarrow as pa

In [2]:
%load_ext sql
%sql duckdb:///:memory:

{}


In [3]:
uri = "imagenet.lance"

ds = lance.dataset(uri)

In [4]:
ds.schema

id: int32
image: extension<image[binary]<ImageBinaryType>>
label: int16
name: dictionary<values=string, indices=int16, ordered=0>
split: dictionary<values=string, indices=int8, ordered=0>

In [5]:
# %%sql --lance

# SELECT image, label FROM ds LIMIT 5

In [6]:
%%sql

SELECT split, count(split) FROM ds GROUP BY split

Took 0.005430459976196289


Unnamed: 0,split,count(split)
0,train,10000
1,validation,10000
2,test,10000


# Use two official pre-trained models ResNet and VisionTransform

* The ResNet model is based on the [Deep Residual Learning for Image Recognition paper](https://arxiv.org/abs/1512.03385)
* The VisionTransformer model is based on the [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale paper](https://arxiv.org/abs/2010.11929).

In [7]:
from torchvision.models import resnet50, vit_b_16
import torch

device = torch.device(
    'cuda' if torch.cuda.is_available() else (
        "mps" if torch.backends.mps.is_available() else 'cpu')
)

resnet = resnet50(weights="DEFAULT").to(device)
vit = vit_b_16(weights="DEFAULT").to(device)

In [8]:
# TODO: make easy conversion between lance.Dataset and lance.pytorch.Dataset

from lance.pytorch import Dataset
dataset = Dataset(uri, columns=["id", "image"], mode="batch")

In [9]:
transform = torchvision.models.ResNet50_Weights.DEFAULT.transforms()

def run_inference(uri: str, model, transform, col_name: str) -> pa.Table:
    dataset = Dataset(
        uri, 
        columns=["id", "image"],
        mode="batch",
        batch_size=128
    )
    results = []
    with torch.no_grad():
        model.eval()
        
        for batch in dataset:
            imgs = [transform(img).to(device) for img in batch[1]]
            prediction = resnet(torch.stack(imgs)).squeeze(0).softmax(0)
            topk = torch.topk(prediction, 2)
            for pk, scores, indices in zip(
                batch[0], topk.values.tolist(), topk.indices.tolist()
            ):
                results.append({
                    "id": pk.item(),
                    col_name: {
                        "label": indices[0], 
                        "score": scores[0], 
                        "second_label": indices[1],  # Secondary guess
                        "second_score": scores[1],  # Confidence of the secondary guess.
                    }
                })
    df = pd.DataFrame(data=results)
    df = df.astype({"id": "int32"})
    return pa.Table.from_pandas(df)

resnet_table = run_inference(uri, resnet, torchvision.models.ResNet50_Weights.DEFAULT.transforms(), "resnet")
vit_table = run_inference(uri, vit, torchvision.models.ViT_L_16_Weights.DEFAULT.transforms(), "vit")

In [10]:
ds = ds.merge(resnet_table, left_on="id", right_on="id")
ds = ds.merge(vit_table, left_on="id", right_on="id")

## Model Performance

In [11]:
%%sql

SELECT 
  SUM(CAST(resnet.label == label AS FLOAT)) / COUNT(label) as resnet_precision,
  SUM(CAST(vit.label == label AS FLOAT)) / COUNT(label) as vit_precision
FROM ds 
WHERE split = 'validation'

Took 0.009811878204345703


Unnamed: 0,resnet_precision,vit_precision
0,0.7921,0.789


Using DuckDB / SQL, it is trivial to slice into each label class to see model performance in each category.

In [22]:
%%sql 

SELECT
  DISTINCT(name),
  SUM(CAST(resnet.label == label AS FLOAT)) / COUNT(label) as resnet_precision,
  SUM(CAST(vit.label == label AS FLOAT)) / COUNT(label) as vit_precision
FROM ds 
WHERE split = 'validation'
GROUP BY name
ORDER BY resnet_precision ASC
LIMIT 10

Took 0.01755666732788086


Unnamed: 0,name,resnet_precision,vit_precision
0,"sunglasses, dark glasses, shades",0.0,0.0
1,corn,0.0,0.0
2,crutch,0.0,0.2
3,Appenzeller,0.1,0.2
4,"spotlight, spot",0.142857,0.142857
5,maillot,0.142857,0.142857
6,soup bowl,0.166667,0.166667
7,"projectile, missile",0.181818,0.181818
8,"garden spider, Aranea diademata",0.2,0.2
9,"green snake, grass snake",0.222222,0.111111


## Find Potential Mislabels

If the two models strongly agree with each other (i.e., same label and confience score is high), however, the predict label is not what ground truth describes.


In [14]:
%%sql

WITH label_names AS (SELECT DISTINCT label, name FROM ds)

SELECT ds.name AS gt, ds.label as gt_label,
  resnet.label as resnet_label,
  vit.label as vit_label,
  label_names.name as predict_name,
  resnet.score as predict_score
FROM ds, label_names 
WHERE
  split != 'test'
  AND ds.label !=  resnet.label 
  AND resnet.label == vit.label
  AND resnet.label = label_names.label
ORDER BY resnet.score DESC
LIMIT 20

Took 0.02150559425354004


Unnamed: 0,gt,gt_label,resnet_label,vit_label,predict_name,predict_score
0,holster,597,465,465,bulletproof vest,0.992044
1,crate,519,950,950,orange,0.981238
2,stretcher,830,407,407,ambulance,0.980197
3,"harmonica, mouth organ, harp, mouth harp",593,684,684,"ocarina, sweet potato",0.979551
4,table lamp,846,619,619,"lampshade, lamp shade",0.979382
5,goose,99,921,921,"book jacket, dust cover, dust jacket, dust wra...",0.978036
6,jigsaw puzzle,611,207,207,golden retriever,0.977936
7,plate rack,729,534,534,"dishwasher, dish washer, dishwashing machine",0.976898
8,"black stork, Ciconia nigra",128,23,23,vulture,0.97683
9,"passenger car, coach, carriage",705,820,820,steam locomotive,0.97505


The reverse order of the above query (`ORDER BY score ASC`) is also very informative, 
as it shows where the weak agreement cross different models.

In [15]:
%%sql

WITH label_names AS (SELECT DISTINCT label, name FROM ds)

SELECT ds.name AS gt, ds.label as gt_label,
  resnet.label as resnet_label,
  vit.label as vit_label,
  label_names.name as predict_name,
  resnet.score as resnet_score,
  vit.score as vit_score
FROM ds, label_names 
WHERE
  split != 'test'
  AND ds.label !=  resnet.label 
  AND resnet.label == vit.label
  AND resnet.label = label_names.label
ORDER BY resnet.score ASC
LIMIT 20

Took 0.01698017120361328


Unnamed: 0,gt,gt_label,resnet_label,vit_label,predict_name,resnet_score,vit_score
0,"bobsled, bobsleigh, bob",450,670,670,"motor scooter, scooter",0.122233,0.164919
1,conch,112,113,113,snail,0.126517,0.213129
2,military uniform,652,860,860,"tobacco shop, tobacconist shop, tobacconist",0.134379,0.318211
3,"scale, weighing machine",778,826,826,"stopwatch, stop watch",0.138554,0.164369
4,"goldfinch, Carduelis carduelis",11,325,325,"sulphur butterfly, sulfur butterfly",0.14466,0.168432
5,"corkscrew, bottle screw",512,596,596,hatchet,0.148554,0.170637
6,bath towel,434,793,793,shower cap,0.156855,0.120169
7,Shih-Tzu,155,219,219,"cocker spaniel, English cocker spaniel, cocker",0.162032,0.188047
8,red wine,966,621,621,"lawn mower, mower",0.171566,0.184307
9,dumbbell,543,422,422,barbell,0.173661,0.178519
