In [4]:
import os
import shutil
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlxim.model import create_model
from mlxim.data import LabelFolderDataset, DataLoader
from mlxim.trainer.trainer import Trainer
from mlxim.model._utils import load_weights
from mlxim.io.image import read_rgb

import sys
sys.path.insert(0, "../")
import utils

In [61]:
from mlxim.model import list_models
list_models()

Available models:
	- resnet18
	- resnet34
	- resnet50
	- resnet101
	- resnet152
	- wide_resnet50_2
	- wide_resnet101_2
	- vit_base_patch16_224
	- vit_base_patch16_224.swag_lin
	- vit_base_patch16_224.dino
	- vit_base_patch32_224
	- vit_base_patch16_384.swag_e2e
	- vit_large_patch16_224
	- vit_large_patch16_224.swag_lin
	- vit_large_patch16_512.swag_e2e
	- vit_huge_patch14_224.swag_lin
	- vit_huge_patch14_518.swag_e2e
	- vit_small_patch14_518.dinov2
	- vit_base_patch14_518.dinov2
	- vit_large_patch14_518.dinov2
	- vit_small_patch16_224.dino
	- vit_small_patch8_224.dino
	- vit_base_patch8_224.dino
	- swin_tiny_patch4_window7_224
	- swin_small_patch4_window7_224
	- swin_base_patch4_window7_224
	- swin_v2_tiny_patch4_window8_256
	- swin_v2_small_patch4_window8_256
	- swin_v2_base_patch4_window8_256


In [7]:
def _convert_to_float(img):
    return img.astype(float)

In [17]:

train_dataset = LabelFolderDataset(
    root_dir="../testing_data/screenshot_clusters/Messages_kb",
    class_map={0: "keyboard", 1: "no_keyboard"},
    transform=_convert_to_float
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    
)
# model = create_model("resnet18") # pretrained weights loaded from HF
# optimizer = optim.Adam(learning_rate=1e-3)

> [INFO] dataset sanity check OK
 ------- LabelFolderDataset stats -------
	- label 0 - ['keyboard'] - 6463/12233 -> 52.833%
	- label 1 - ['no_keyboard'] - 5770/12233 -> 47.167%
 -------------------------------------


In [4]:
trainer = Trainer(model = model,
        optimizer = optimizer,
        loss_fn = nn.losses.cross_entropy,
        train_loader = train_loader,
        max_epochs = 1, 
        loss_fn_args={})

In [5]:
trainer.train()


******** epoch 0/1 ********

> iter=[0/3059] | train_loss=8.088 | train_throughput=1.98 images/second | lr=0.00100
> iter=[305/3059] | train_loss=0.657 | train_throughput=2.83 images/second | lr=0.00100
> iter=[610/3059] | train_loss=0.516 | train_throughput=2.83 images/second | lr=0.00100
> iter=[915/3059] | train_loss=0.417 | train_throughput=2.83 images/second | lr=0.00100
> iter=[1220/3059] | train_loss=0.350 | train_throughput=2.80 images/second | lr=0.00100
> iter=[1525/3059] | train_loss=0.294 | train_throughput=2.82 images/second | lr=0.00100
> iter=[1830/3059] | train_loss=0.253 | train_throughput=2.84 images/second | lr=0.00100
> iter=[2135/3059] | train_loss=0.226 | train_throughput=2.83 images/second | lr=0.00100
> iter=[2440/3059] | train_loss=0.201 | train_throughput=2.83 images/second | lr=0.00100
> iter=[2745/3059] | train_loss=0.188 | train_throughput=2.83 images/second | lr=0.00100
> iter=[3050/3059] | train_loss=0.172 | train_throughput=2.82 images/second | lr=0.001

In [7]:
trainer.model.save_weights("keyboard_pred.npz")

# Get total accuracy and incorrect preds

In [2]:
model = create_model("resnet18")
model = load_weights(model, "keyboard_pred.npz")

In [3]:
def load_img(img_path):
    im = read_rgb(img_path)
    im = _convert_to_float(im)
    im = mx.array(im)
    im = mx.expand_dims(im, axis=0)
    return im

In [8]:
wrong_preds = list()
d = "../testing_data/screenshot_clusters/Messages_kb/no_keyboard"
keyboard_pred = "../testing_data/screenshot_clusters/Messages_kb/keyboard_pred"
utils.make_dir(keyboard_pred)
for f in os.listdir(d):
    path = os.path.join(d, f)
    im = load_img(path)
    logits = model(im)
    pred = mx.argsort(logits, axis=1)[:, -1][0]
    if pred == 0:
        wrong_preds.append(path)
        shutil.copy(path, os.path.join(keyboard_pred, f))

FileNotFoundError: The path ../testing_data/screenshot_clusters/Messages_kb/no_keyboard/com-google-android-apps-messaging_1721684986456.jpg does not exist

In [None]:
len(wrong_preds)

In [58]:
if pred == 1:
    print('h')

h


# Custom train

In [4]:
model.train()
for epoch in range(1):
    for batch in train_loader:
        x, target = batch
        x = x.astype(mx.float32) / 255.0
        train_step_fn = nn.value_and_grad(model, train_step)
        loss, grads = train_step_fn(model, x, target)
        optimizer.update(model, grads)
        mx.eval(model.state, optimizer.state)
model.save_weights("keyboard_pred.npz")

In [4]:
losses = list()
for batch in train_loader:
    x, target = batch
    x = x.astype(mx.float32) / 255.0
    logits = model(x)
    loss = mx.mean(nn.losses.cross_entropy(logits, target))
    losses.append(loss)

In [12]:
mo

In [9]:
loss

array(0.031888, dtype=float32)

In [6]:
target

array([0, 1, 0, 0], dtype=int32)

In [7]:
pred = nn.softmax(logits)

In [8]:
pred

array([[0.995103, 0.00489663, 7.35773e-16, ..., 1.29947e-15, 1.1011e-14, 1.21547e-14],
       [0.111527, 0.888473, 6.30638e-12, ..., 1.08314e-11, 1.42913e-10, 4.22585e-10],
       [0.995898, 0.00410203, 2.20345e-16, ..., 4.10853e-16, 3.60592e-15, 3.41826e-15],
       [0.999717, 0.00028271, 2.49472e-17, ..., 3.39398e-17, 3.00098e-16, 4.19621e-16]], dtype=float32)

In [5]:
logits

array([[15.9034, 10.5891, -18.9373, ..., -18.3685, -16.2315, -16.1327],
       [7.06692, 9.14216, -16.529, ..., -15.9882, -13.4084, -12.3242],
       [16.6627, 11.1706, -19.3845, ..., -18.7615, -16.5894, -16.6428],
       [18.3014, 10.1306, -19.9281, ..., -19.6203, -17.4408, -17.1055]], dtype=float32)

In [11]:
logits.shape

(4, 1000)

In [8]:
losses

[array(1.28949, dtype=float32),
 array(0.484149, dtype=float32),
 array(1.01843, dtype=float32),
 array(1.07637, dtype=float32),
 array(0.332378, dtype=float32),
 array(1.59821, dtype=float32),
 array(0.446189, dtype=float32),
 array(0.394225, dtype=float32),
 array(0.863487, dtype=float32),
 array(0.539639, dtype=float32),
 array(0.683185, dtype=float32),
 array(0.291809, dtype=float32),
 array(1.37085, dtype=float32),
 array(1.64025, dtype=float32),
 array(0.501101, dtype=float32),
 array(0.781838, dtype=float32),
 array(0.386523, dtype=float32),
 array(0.691342, dtype=float32),
 array(0.320483, dtype=float32),
 array(0.779269, dtype=float32),
 array(0.495409, dtype=float32),
 array(0.403279, dtype=float32),
 array(0.318139, dtype=float32),
 array(0.446375, dtype=float32),
 array(1.21727, dtype=float32),
 array(0.475137, dtype=float32),
 array(0.435272, dtype=float32),
 array(1.26628, dtype=float32),
 array(0.506136, dtype=float32),
 array(1.59153, dtype=float32),
 array(1.10125, dty