# ImageNet 1K data quality analyze


In [1]:
import lance
import duckdb
import torchvision
import torch
import pandas as pd

In [2]:
%load_ext sql
%sql duckdb:///:memory:

{}


In [3]:
uri = "imagenet.lance"

ds = lance.dataset(uri)

In [4]:
ds.schema

id: int32
image: extension<image[binary]<ImageBinaryType>>
label: int16
name: dictionary<values=string, indices=int16, ordered=0>
split: dictionary<values=string, indices=int8, ordered=0>

In [5]:
ds

<lance.lib.FileSystemDataset at 0x7f24ff2a54b0>

In [6]:
# %%sql --lance

# SELECT image, label FROM ds LIMIT 5

In [7]:
%%sql

SELECT split, count(split) FROM ds GROUP BY split

Took 0.005731344223022461


Unnamed: 0,split,count(split)
0,train,10000
1,test,10000
2,validation,10000


In [8]:
from torchvision.models import resnet50, vit_b_16

resnet = resnet50(weights="DEFAULT").cuda()
vit = vit_b_16(weights="DEFAULT").cuda()

In [9]:
# TODO: make easy conversion between lance.Dataset and pytorch.Dataset

from lance.pytorch import Dataset
dataset = Dataset(uri, columns=["id", "image"], mode="batch")
print(dataset.schema)

id: int32
image: extension<image[binary]<ImageBinaryType>>
label: int16
name: dictionary<values=string, indices=int16, ordered=0>
split: dictionary<values=string, indices=int8, ordered=0>


In [10]:
transform = torchvision.models.ResNet50_Weights.DEFAULT.transforms()
dataset = Dataset(
    uri, 
    columns=["id", "image"],
    mode="batch",
    batch_size=128
)
results = []
with torch.no_grad():
    resnet.eval()
    for batch in dataset:
        imgs = [transform(img).cuda() for img in batch[1]]
        prediction = resnet(torch.stack(imgs)).squeeze(0).softmax(0)
        topk = torch.topk(prediction, 2)
        for pk, scores, indices in zip(
            batch[0], topk.values.tolist(), topk.indices.tolist()
        ):
            results.append({
                "id": pk.item(),
                "resnet": {
                    "label": indices[0], 
                    "score": scores[0], 
                    "second_label": indices[1],  # Secondary guess
                    "second_score": scores[1],  # Confidence of the secondary guess.
                }
            })
            
df = pd.DataFrame(data=results)
df.astype({"id": "int32"})
df

Unnamed: 0,id,resnet
0,1,"{'label': 726, 'score': 0.694125771522522, 'se..."
1,2,"{'label': 917, 'score': 0.8421049118041992, 's..."
2,3,"{'label': 13, 'score': 0.8555712103843689, 'se..."
3,4,"{'label': 939, 'score': 0.9331550002098083, 's..."
4,5,"{'label': 6, 'score': 0.9336700439453125, 'sec..."
...,...,...
29995,29996,"{'label': 658, 'score': 0.9771027565002441, 's..."
29996,29997,"{'label': 433, 'score': 0.9925998449325562, 's..."
29997,29998,"{'label': 47, 'score': 0.9799998998641968, 'se..."
29998,29999,"{'label': 801, 'score': 0.9724040627479553, 's..."


In [19]:
import pyarrow as pa
df = df.astype({"id": "int32"})
tab = pa.Table.from_pandas(df)


In [12]:
ds = ds.merge(tab, left_on="id", right_on="id")

In [13]:
ds.schema

id: int32
image: extension<image[binary]<ImageBinaryType>>
label: int16
name: dictionary<values=string, indices=int16, ordered=0>
split: dictionary<values=string, indices=int8, ordered=0>
resnet: struct<label: int64, score: double, second_label: int64, second_score: double>
  child 0, label: int64
  child 1, score: double
  child 2, second_label: int64
  child 3, second_score: double

In [14]:
ds = lance.dataset(uri)

In [15]:
ds.schema

id: int32
image: extension<image[binary]<ImageBinaryType>>
label: int16
name: dictionary<values=string, indices=int16, ordered=0>
split: dictionary<values=string, indices=int8, ordered=0>
resnet: struct<label: int64, score: double, second_label: int64, second_score: double>
  child 0, label: int64
  child 1, score: double
  child 2, second_label: int64
  child 3, second_score: double

In [18]:
%%sql

SELECT * EXCLUDE(image) FROM ds WHERE label != resnet.label LIMIT 10


(duckdb.InvalidInputException) Invalid Input Error: arrow_scan: get_next failed(): IOError: Failed to open local file 'imagenet.lance/data/imagenet.lance/data/5293ea0f-7291-4632-b77c-63d65f2fbe0f.lance'. Detail: [errno 2] No such file or directory
[SQL: SELECT * EXCLUDE(image) FROM ds WHERE label != resnet.label LIMIT 10]
(Background on this error at: https://sqlalche.me/e/14/f405)


In [17]:
duckdb.query("SELECT id, label from ds LIMIT 5").arrow()

pyarrow.Table
id: int32
label: int16
----
id: [[1,2,3,4,5]]
label: [[726,917,13,939,6]]