In [None]:
!pip -q install daft pyarrow pandas numpy scikit-learn

import os
os.environ["DO_NOT_TRACK"] = "true"

import numpy as np
import pandas as pd
import daft
from daft import col

# Report the installed Daft version, falling back to "unknown" when the
# module exposes no __version__ attribute.
daft_version = getattr(daft, "__version__", "unknown")
print("Daft version:", daft_version)

# Remote gzipped JSON file that the rest of the notebook works from.
URL = "https://github.com/Eventual-Inc/mnist-json/raw/master/mnist_handwritten_test.json.gz"

# Build the base DataFrame straight from the remote file.
df = daft.read_json(URL)

# Header and schema are printed on consecutive lines.
print("\nSchema (sampled):", df.schema(), sep="\n")

print("\nPeek:")
df.show(5)

In [None]:
def to_28x28(pixels):
    """Convert a flat run of 784 pixel values into a 28x28 float32 array.

    Args:
        pixels: Array-like of numeric pixel values.

    Returns:
        A ``(28, 28)`` ``float32`` NumPy array, or ``None`` when the input does
        not hold exactly 784 values.
    """
    flat = np.array(pixels, dtype=np.float32)
    is_full_image = flat.size == 784
    return flat.reshape(28, 28) if is_full_image else None

def _pixel_stat(reducer):
    """Build a float32 expression applying ``reducer`` to each ``img_28x28`` value.

    Rows whose image is ``None`` produce ``None`` instead of being reduced.
    """
    def _reduce(img):
        return None if img is None else float(reducer(img))

    return col("img_28x28").apply(_reduce, return_dtype=daft.DataType.float32())


# Reshape every raw pixel list first; the two statistics below read that column.
reshaped_image = col("image").apply(to_28x28, return_dtype=daft.DataType.python())

df2 = (
    df
    .with_column("img_28x28", reshaped_image)
    .with_column("pixel_mean", _pixel_stat(np.mean))
    .with_column("pixel_std", _pixel_stat(np.std))
)

print("\nAfter reshaping + simple features:")
df2.select("label", "pixel_mean", "pixel_std").show(5)

In [None]:
@daft.udf(return_dtype=daft.DataType.list(daft.DataType.float32()), batch_size=512)
def featurize(images_28x28):
    """Turn each image of the batch into a flat list of float32 features.

    For every non-null image the vector holds, in order: the per-row sums and
    the per-column sums (each divided by 255), the intensity-weighted centroid
    (row coordinate, then column coordinate, each divided by 28), and the image
    mean and standard deviation (each divided by 255).

    Args:
        images_28x28: Batch of images; ``to_pylist()`` is called on it.

    Returns:
        A list with one entry per input image: the feature list, or ``None``
        where the image itself was ``None``.
    """

    def _describe(image):
        pixels = np.asarray(image, dtype=np.float32)
        # The small epsilon keeps the centroid division defined for an all-zero image.
        mass = pixels.sum() + 1e-6
        row_idx, col_idx = np.indices(pixels.shape)
        centroid_y = float((row_idx * pixels).sum() / mass) / 28.0
        centroid_x = float((col_idx * pixels).sum() / mass) / 28.0
        summary = np.array(
            [centroid_y, centroid_x, pixels.mean() / 255.0, pixels.std() / 255.0],
            dtype=np.float32,
        )
        parts = [pixels.sum(axis=1) / 255.0, pixels.sum(axis=0) / 255.0, summary]
        return np.concatenate(parts).astype(np.float32).tolist()

    return [
        None if image is None else _describe(image)
        for image in images_28x28.to_pylist()
    ]

df3 = df2.with_column("features", featurize(col("img_28x28")))

print("\nFeature column created (list[float]):")
df3.select("label", "features").show(2)

In [None]:
# One row per label: how many rows it has and the averages of the per-image
# mean and standard deviation, ordered by label.
per_label = df3.groupby("label")
label_stats = per_label.agg(
    col("label").count().alias("n"),
    col("pixel_mean").mean().alias("mean_pixel_mean"),
    col("pixel_std").mean().alias("mean_pixel_std"),
).sort("label")

print("\nLabel distribution + summary stats:")
label_stats.show(10)

# Left-join the aggregates back so every row carries its own label's statistics.
df4 = df3.join(label_stats, how="left", on="label")

joined_cols = ["label", "n", "mean_pixel_mean", "mean_pixel_std"]
print("\nJoined label stats back onto each row:")
df4.select(*joined_cols).show(5)

In [1]:
# scikit-learn pieces used by this cell, imported before any work is done so
# the cell's dependencies are visible at a glance.
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

TEST_FRACTION = 0.25  # share of rows held out for evaluation
SPLIT_SEED = 7  # fixed seed so the train/test split is reproducible
MAX_ITER = 1000  # iteration cap handed to LogisticRegression

# Pull only the modelling columns into pandas and drop incomplete rows.
# to_pandas() executes the lazy plan itself, so no separate collect() is needed.
small = (
    df4.select("label", "features")
    .to_pandas()
    .dropna(subset=["label", "features"])
    .reset_index(drop=True)
)

# Stack the per-row feature lists into a single 2-D float32 matrix.
X = np.vstack(small["features"].apply(np.array).values).astype(np.float32)
y = small["label"].astype(int).values

# Stratify on the label so both splits keep the same class proportions.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_FRACTION, random_state=SPLIT_SEED, stratify=y
)

clf = LogisticRegression(max_iter=MAX_ITER)
clf.fit(X_train, y_train)

pred = clf.predict(X_test)
acc = accuracy_score(y_test, pred)

print("\nBaseline accuracy (feature-engineered LogisticRegression):", round(acc, 4))
print("\nClassification report:")
print(classification_report(y_test, pred, digits=4))

out_df = df4.select("label", "features", "pixel_mean", "pixel_std", "n")
# Resolve the output against the current working directory rather than a fixed
# absolute directory, so the cell also runs on machines without /content.
# When the working directory is /content this is /content/daft_mnist_features.parquet.
out_path = os.path.abspath("daft_mnist_features.parquet")
out_df.write_parquet(out_path)

print("\nWrote parquet to:", out_path)

# Read the file back as a quick sanity check that the write round-trips.
df_back = daft.read_parquet(out_path)
print("\nRead-back check:")
df_back.show(3)

label Int64,features List[Float32],pixel_mean Float32,pixel_std Float32,n UInt64
2,"[0, 0, 0, 2.6941175, 5.0666666, 5.380392, 4.7058825, 3.3411765, 2.7019608, 2.372549, 2.3647058, 1.9372549, 2.5686274, 2.6352942, 2.8392158, 2.772549, 2.5843136, 2.7019608, 2.835294, 4.109804, 12.529411, 10.937255, 3.8039215, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.09803922, 2.3294117, 6.337255, 8.035295, 7.309804, 7.0862746, 7.4980392, 7.9764705, 9.188235, 10.34902, 7.062745, 1.9607843, 1.4196079, 1.4196079, 1.054902, 1.1019608, 0.654902, 0, 0, 0, 0, 0.49524328, 0.48875672, 0.10316627, 0.2726879]",26.307398,69.535416,1032
2,"[0, 0, 0, 3.0431373, 7.062745, 8.407844, 7.847059, 4.607843, 4.0941176, 4.090196, 3.9333334, 3.5843136, 5.0313725, 8.149019, 10.333333, 10.65098, 9.317647, 8.207843, 11.976471, 14.447059, 14.129412, 10.815686, 4.3137255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.7176471, 7.098039, 10.717647, 12.556863, 11.207843, 11.282353, 12.3450985, 13.078431, 14.603922, 14.843137, 12.592156, 11.215686, 6.4862747, 4.6039214, 4.239216, 3.4980392, 2.482353, 0.4745098, 0, 0, 0, 0.5010501, 0.51200235, 0.1964836, 0.37072396]",50.103317,94.53461,1032
6,"[0, 0, 0, 2, 2.7490196, 3.2509804, 3.5019608, 3.2509804, 5.254902, 8.505882, 9.501961, 9.254902, 9.254902, 9.505882, 8.752941, 8.505882, 9, 9.254902, 9.254902, 9.752941, 12.25098, 11, 7.0039215, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5.262745, 13.254902, 14.258823, 11.752941, 9.25098, 5.752941, 3.2509804, 6.5058823, 9.25098, 9.74902, 9.752941, 10.0039215, 9.25098, 8, 8.0039215, 8.74902, 7.2509804, 1.5058824, 0, 0, 0, 0, 0.5116571, 0.49304307, 0.19235694, 0.36872894]",49.05102,94.02588,958
