<a href="https://colab.research.google.com/github/matt14e/StitchAI/blob/main/Stitchmodel3_(1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip -q install lightning torchvision pyembroidery pillow

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; later cells read images/DSTs and
# write checkpoints under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
import glob, pathlib, textwrap, os

# List every .ipynb file reachable from the current directory.
# NOTE: presumably cwd is /content with Drive mounted at /content/drive, so this
# recursive search also walks the whole mounted Drive and can be slow.
candidates = sorted(glob.glob("**/*.ipynb", recursive=True))

# Build the report line by line. Interpolating a multi-line join into a
# textwrap.dedent() template leaves the 2nd+ paths at column 0, which
# cancels the dedent and misaligns the output.
report_lines = [
    f"🔍 Found {len(candidates)} notebook(s):",
    "----------------------------------------",
    *(candidates or ["(none)"]),
]
print("\n" + "\n".join(report_lines) + "\n")


In [None]:
# ▶︎ 2.  Install nbformat (small dependency)
!pip -q install nbformat
import nbformat

# ▶︎ 3.  Path to your notebook on the mounted Drive. It is absolute, so the
#       cell works regardless of the current working directory.
nb_path = "/content/drive/MyDrive/Colab Notebooks/Stitchmodel3 (1).ipynb"

# ▶︎ 4.  Load → strip bad widgets → save
with open(nb_path, encoding="utf-8") as f:
    nb = nbformat.read(f, as_version=nbformat.NO_CONVERT)

# pop() returns None only when the key is absent. An empty {} widgets block
# is still present and still gets removed, so compare against None rather
# than relying on truthiness.
removed = nb.metadata.pop("widgets", None)

if removed is not None:
    # Only rewrite the file when something actually changed.
    with open(nb_path, "w", encoding="utf-8") as f:
        nbformat.write(nb, f)

print(
    f"✅ Cleaned '{nb_path}'. "
    + ("widgets metadata removed." if removed is not None else "No widgets metadata found.")
)

In [None]:
#patch cell
# 🛠  Make GitHub happy: ensure metadata.widgets.state exists
import nbformat, os, io, json, IPython

# Prefer the env var when it is set. Otherwise fall back to the notebook's
# Drive path instead of crashing with a bare KeyError.
# NOTE(review): it is unconfirmed that Colab exports COLAB_NOTEBOOK_NAME — verify.
FALLBACK_NB_PATH = "/content/drive/MyDrive/Colab Notebooks/Stitchmodel3 (1).ipynb"
nb_path = os.environ.get("COLAB_NOTEBOOK_NAME") or FALLBACK_NB_PATH
if not os.path.isfile(nb_path):
    raise FileNotFoundError(
        f"Notebook file not found: {nb_path!r}. Set nb_path to the .ipynb on "
        "Drive (e.g. under /content/drive/MyDrive/Colab Notebooks/) and re-run."
    )
print("↪︎ Patching", nb_path)

# -- read the notebook file from disk ----------
# NOTE(review): this edits the file on disk, not Colab's in-memory copy.
# Confirm the change survives before using "Save a copy in GitHub".
with open(nb_path, encoding="utf-8") as f:
    nb = nbformat.read(f, as_version=nbformat.NO_CONVERT)

# -- guarantee valid widgets structure -------
widgets = nb.metadata.setdefault("widgets", {})
widgets.setdefault("state", {})                 # <- the missing bit!

# -- write back in place ----------------------
with open(nb_path, "w", encoding="utf-8") as f:
    nbformat.write(nb, f)

print("✅ Added empty widgets.state — now click File ▸ Save a copy in GitHub")


In [None]:
# 🛠 Make GitHub happy: be sure metadata.widgets.state exists.
# Resolution order for the notebook file: the COLAB_NOTEBOOK_NAME env var,
# then the single *.ipynb in the cwd; anything else is an error asking for
# a manual nb_path.
import nbformat, os, glob, pathlib, textwrap

nb_path = os.environ.get("COLAB_NOTEBOOK_NAME")
if not nb_path:
    matches = glob.glob("*.ipynb")
    if len(matches) != 1:
        # zero or several candidates: ambiguous, refuse to guess
        raise FileNotFoundError(textwrap.dedent(f"""
            Couldn't auto-detect the notebook file.
            Please set nb_path manually, e.g.:
                nb_path = "drive/MyDrive/Colab Notebooks/Stitchmodel3 (1).ipynb"
        """))
    nb_path = matches[0]

print("↪︎ Patching →", nb_path)

# Round-trip: read the notebook, add an empty widgets.state, write it back.
with open(nb_path, encoding="utf-8") as src:
    nb = nbformat.read(src, as_version=nbformat.NO_CONVERT)

widgets = nb.metadata.setdefault("widgets", {})
widgets.setdefault("state", {})   # the key GitHub's renderer looks for

with open(nb_path, "w", encoding="utf-8") as dst:
    nbformat.write(nb, dst)

print("✅ Added empty widgets.state — now click File ▸ Save a copy in GitHub")


In [None]:
# Folder layout on Drive: change DATA_ROOT (or a sub-folder name) if the data moves.
DATA_ROOT = "/content/drive/MyDrive/Embroidery Files"
IMG_DIR = DATA_ROOT + "/PNG_image_files"      # source images (PNGs)
DST_DIR = DATA_ROOT + "/DST_digitized_files"  # digitised stitch files (DSTs)

In [None]:
#sanity check the folders
import itertools, pathlib, textwrap

# Suffixes EmbroDataset globs for (IMG_EXTS / DST_EXTS in dataset.py); keep in sync.
_IMG_SUFFIXES = {".png", ".PNG", ".jpg", ".JPG"}
_DST_SUFFIXES = {".dst", ".DST"}


def _show_samples(label, folder, suffixes, limit=5):
    """Print *label*, then up to *limit* file names in *folder* whose suffix is in *suffixes*.

    A missing folder is reported explicitly instead of printing an empty list.
    """
    root = pathlib.Path(folder)
    print(label)
    if not root.is_dir():
        print("    ⚠️  folder not found:", root)
        return
    hits = (p for p in sorted(root.iterdir()) if p.suffix in suffixes)
    for p in itertools.islice(hits, limit):
        print("   ", p.name)


_show_samples("📷  sample image files (PNG/JPG):", IMG_DIR, _IMG_SUFFIXES)
_show_samples("\n🧵  sample DST files:", DST_DIR, _DST_SUFFIXES)


In [None]:
%%writefile dataset.py
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset
from pyembroidery import read

# --- file patterns -----------------------------------------------------------
IMG_EXTS = ("*.png", "*.PNG", "*.jpg", "*.JPG")
DST_EXTS = ("*.dst", "*.DST")

class EmbroDataset(Dataset):
    """
    Yields (image_tensor, stitch_tensor) pairs, matched by file stem.

    image_tensor  : transform(RGB PIL image). This is C × H × W float32 in
                    [0, 1] when the transform ends in ToTensor, or the PIL
                    image itself when transform is None.
    stitch_tensor : L × 3       (Δx, Δy, flag)  float32, L <= max_len

    pyembroidery's ``pattern.stitches`` rows are absolute ``[x, y, command]``.
    They are converted to per-stitch deltas so the tensor really is (Δx, Δy).
    The first delta is measured from the origin (0, 0). ``flag`` is the raw
    pyembroidery command value.
    """
    def __init__(self, img_dir, dst_dir, transform=None, max_len=4096):
        img_dir, dst_dir = Path(img_dir), Path(dst_dir)

        # -------- gather every image and dst file into {stem: path} ---------
        # (if two files share a stem, the one globbed last wins)
        img_files = {p.stem: p for pat in IMG_EXTS for p in img_dir.glob(pat)}
        dst_files = {d.stem: d for pat in DST_EXTS for d in dst_dir.glob(pat)}

        # -------- keep only names that exist in *both* dicts --------------
        self.common_names = sorted(img_files.keys() & dst_files.keys())
        if not self.common_names:
            raise RuntimeError("No matching (image, DST) pairs found!")

        self.img_files  = img_files
        self.dst_files  = dst_files
        self.transform  = transform
        self.max_len    = max_len

    # ------------------------------------------------------------------------
    def __len__(self):
        return len(self.common_names)

    @staticmethod
    def _stitches_to_tensor(stitches, max_len):
        """Turn absolute ``[x, y, command]`` rows into an (L, 3) float32 tensor of (Δx, Δy, flag)."""
        # reshape keeps an empty pattern 2-D as (0, 3) so padding still works
        seq = torch.tensor(stitches, dtype=torch.float32).reshape(-1, 3)[:max_len]
        deltas = seq.clone()
        deltas[1:, :2] = seq[1:, :2] - seq[:-1, :2]
        return deltas

    def __getitem__(self, idx):
        name = self.common_names[idx]

        # load & optionally transform image; the with-block closes the file
        # handle (convert() returns an independent copy of the pixels)
        with Image.open(self.img_files[name]) as src:
            img = src.convert("RGB")
        if self.transform:
            img = self.transform(img)

        # load DST stitches
        pattern  = read(str(self.dst_files[name]))
        stitches = self._stitches_to_tensor(pattern.stitches, self.max_len)

        return img, stitches


In [None]:
# Don't run: alternative EmbroDataset kept for reference only (writing it would replace dataset.py).

%%writefile dataset.py
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset
from pyembroidery import read

IMG_EXTS = ("*.png", "*.PNG", "*.jpg", "*.JPG")
DST_EXTS = ("*.dst", "*.DST")

class EmbroDataset(Dataset):
    """
    (image_tensor, stitch_tensor) pairs, one per image file.

    image_tensor  : transform(RGB PIL image), e.g. C×H×W float32 [0,1] with ToTensor
    stitch_tensor : L×3    (Δx, Δy, flag), L <= max_len

    pyembroidery's ``pattern.stitches`` rows are absolute ``[x, y, command]``.
    They are converted to per-stitch deltas, with the first measured from (0, 0).
    The DST with the same stem is looked up lazily. A missing DST raises
    FileNotFoundError when that item is fetched.
    """
    def __init__(self, img_dir, dst_dir, transform=None, max_len=4096):
        self.img_paths = []
        for pat in IMG_EXTS:
            self.img_paths.extend(Path(img_dir).glob(pat))
        self.img_paths = sorted(self.img_paths)

        self.dst_dir   = Path(dst_dir)
        self.transform = transform
        self.max_len   = max_len

    def __len__(self):
        return len(self.img_paths)

    @staticmethod
    def _stitches_to_tensor(stitches, max_len):
        """Turn absolute ``[x, y, command]`` rows into an (L, 3) float32 tensor of (Δx, Δy, flag)."""
        # reshape keeps an empty pattern 2-D as (0, 3) so padding still works
        seq = torch.tensor(stitches, dtype=torch.float32).reshape(-1, 3)[:max_len]
        deltas = seq.clone()
        deltas[1:, :2] = seq[1:, :2] - seq[:-1, :2]
        return deltas

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]

        # resolve the DST first so a missing match fails before image decoding
        stem = img_path.stem
        for ext in DST_EXTS:
            cand = self.dst_dir / f"{stem}{ext[1:]}"   # '*.dst' ➜ '.dst'
            if cand.exists():
                dst_path = cand
                break
        else:
            raise FileNotFoundError(f"No DST match for {stem}")

        # the with-block closes the file handle; convert() returns a copy
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        if self.transform:
            img = self.transform(img)

        pattern  = read(str(dst_path))
        stitches = self._stitches_to_tensor(pattern.stitches, self.max_len)
        return img, stitches


In [None]:
%%writefile model.py
import torch, torch.nn as nn

class EmbroNet(nn.Module):
    """Tiny CNN encoder + GRU decoder baseline.

    The image is encoded to one vector, and every GRU layer starts from it.
    The GRU then maps a sequence of previous 3-value commands to predicted
    next commands.

    Args:
        hidden: GRU hidden size (also the size of the image projection).
        num_layers: number of stacked GRU layers (default 2).
    """
    def __init__(self, hidden=256, num_layers=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(64,128, 3, 2, 1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)          # 128×1×1
        )
        self.enc_fc  = nn.Linear(128, hidden)

        self.gru     = nn.GRU(3, hidden, num_layers=num_layers, batch_first=True)
        self.dec_fc  = nn.Linear(hidden, 3)

    def forward(self, img, prev_cmds):
        """Predict the next command for every position of ``prev_cmds``.

        Args:
            img: B × 3 × H × W image batch.
            prev_cmds: B × L × 3 previous commands.
        Returns:
            B × L × 3 predictions.
        """
        B = img.size(0)
        feats = self.encoder(img).view(B, -1)      # B×128
        # one copy of the image embedding per GRU layer: num_layers × B × hidden
        # (taken from the GRU itself so the depth is defined in one place)
        h0 = torch.tanh(self.enc_fc(feats)).unsqueeze(0).repeat(self.gru.num_layers, 1, 1)
        out, _ = self.gru(prev_cmds, h0)           # B×L×H
        return self.dec_fc(out)                    # B×L×3


In [None]:
import torch
import lightning as L
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Resize
from dataset import EmbroDataset
from model   import EmbroNet

# ---------- 1. custom collate_fn ---------------------------------
def pad_collate(batch):
    """
    Collate (img_tensor, seq_tensor) pairs into a padded batch.

    All images must already share one size (e.g. via Resize). Sequences are
    L_i × F and are zero-padded at the end to the longest length in the batch.

    Returns:
        imgs   : B×C×H×W  stacked images
        tgt    : B×Lmax×F padded sequences (padding value 0, dtype of the inputs)
        lens   : LongTensor of shape (B,) with the original sequence lengths
    """
    imgs, seqs = zip(*batch)
    imgs = torch.stack(imgs)

    lens   = torch.tensor([s.size(0) for s in seqs], dtype=torch.long)
    padded = torch.nn.utils.rnn.pad_sequence(list(seqs), batch_first=True, padding_value=0.0)
    return imgs, padded, lens

# ---------- 2. LightningModule -----------------------------------
class LitModule(L.LightningModule):
    """Lightning wrapper that trains EmbroNet with teacher forcing.

    Each step feeds stitches [0 .. n-2] of every padded sequence and regresses
    stitches [1 .. n-1] with per-sequence MSE, averaged over sequences. Padded
    positions are masked out.

    Args:
        lr: AdamW learning rate.
    """
    def __init__(self, lr=1e-3):
        super().__init__()
        self.net = EmbroNet()
        # kept for backward compatibility; training_step computes the masked
        # equivalent of this per-sequence "mean" MSE directly
        self.loss_fn = torch.nn.MSELoss(reduction="mean")
        self.lr = lr

    def forward(self, img, prev_cmds):
        return self.net(img, prev_cmds)

    def training_step(self, batch, _):
        img, tgt, lens = batch                 # tgt: B×Lmax×3
        pred = self(img, tgt[:, :-1])          # B×(Lmax-1)×3, predicts tgt[:, 1:]
        target = tgt[:, 1:]

        # A sequence of n stitches gives n-1 (input, target) pairs. Sequences
        # with fewer than 2 stitches have none; a plain MSE over their empty
        # slice would be NaN and poison the whole batch, so they are skipped.
        n_pairs = (torch.as_tensor(lens, device=pred.device) - 1).clamp(min=0)
        steps = torch.arange(pred.size(1), device=pred.device)
        mask = (steps.unsqueeze(0) < n_pairs.unsqueeze(1)).unsqueeze(-1)   # B×(Lmax-1)×1

        sq_err = (pred - target).pow(2) * mask
        per_seq = sq_err.sum(dim=(1, 2)) / (n_pairs * pred.size(2)).clamp(min=1)
        has_pairs = n_pairs > 0
        if bool(has_pairs.any()):
            loss = per_seq[has_pairs].mean()
        else:
            loss = pred.sum() * 0.0            # zero loss that keeps the graph
        # NOTE(review): inference starts from an all-zero start token, which
        # this step never sees as an input — consider prepending one here.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

# ---------- 3. dataset & dataloader ------------------------------
tfms = Compose([Resize((128,128)), ToTensor()])
ds   = EmbroDataset(IMG_DIR, DST_DIR, tfms)   # raises if no (image, DST) pairs match
print("Dataset length:", len(ds))

dl   = DataLoader(ds,
                  batch_size=8,
                  shuffle=True,
                  num_workers=0,            # change later if you like
                  collate_fn=pad_collate)   # pads variable-length stitch sequences

# ---------- 4. train ---------------------------------------------
# True  → one-batch smoke test. Per the Lightning Trainer docs, fast_dev_run
#         also disables checkpoint callbacks and loggers.
# False → full 10-epoch run with mixed precision.
FAST_DEV_RUN = True
if FAST_DEV_RUN:
    trainer = L.Trainer(fast_dev_run=True)
else:
    trainer = L.Trainer(max_epochs=10,
                        precision="16-mixed",
                        accelerator="auto")
trainer.fit(LitModule(), dl)


In [None]:
import glob, pprint, os

# Search everything under /content (recursively, including the mounted
# Drive) for checkpoint files named like "…epoch…=….ckpt".
ckpts = glob.glob("/content/**/*epoch*=*.ckpt", recursive=True)
if not ckpts:
    print("No checkpoints found – you may need to save one first.")
else:
    pprint.pp(ckpts)

In [None]:
# Run once trainer.fit has finished: persist the trained weights to Drive.
trainer.save_checkpoint(
    filepath="/content/drive/MyDrive/EmbroideryTests/embronet_latest.ckpt",
)


In [None]:
# (A) List any checkpoints Lightning has already saved.
#     Drive is mounted at /content/drive, so one recursive glob over /content
#     already covers it. A second Drive-only glob would list every Drive
#     checkpoint twice and walk Drive twice.
import glob, pprint
ckpts = sorted(set(glob.glob("/content/**/checkpoints/*.ckpt", recursive=True)))
pprint.pp(ckpts)

# To use one of those instead, copy its path, e.g.:
# CKPT_PATH = "/content/lightning_logs/version_3/checkpoints/epoch=4-step=99.ckpt"

# (B) Always write a fresh checkpoint from the in-memory trainer; the
#     inference cell below loads this exact path.
CKPT_PATH = "/content/drive/MyDrive/EmbroideryTests/embronet.ckpt"
trainer.save_checkpoint(CKPT_PATH)


In [None]:
import os

# Confirm the checkpoint file is on disk before running inference.
print(
    "checkpoint exists?",
    os.path.exists(CKPT_PATH),
)

In [None]:
#### PNG ➜ DST inference (single cell)
# ================================================================
#  1‑CELL  PNG  ➜  EmbroNet  ➜  DST
#  (just edit the three paths + SCALE, then press ▶)
# ================================================================
CKPT_PATH = "/content/drive/MyDrive/EmbroideryTests/embronet.ckpt"  # <— your .ckpt
PNG_PATH  = "/content/drive/MyDrive/Embroidery Files/pngstart/MLPP_LOGO.PNG"                 # <— input PNG
OUT_PATH  = "/content/drive/MyDrive/Embroidery Files/dststop/test_logo.dst"         # <— output DST
SCALE     = 7.0          # multiply Δx,Δy by this (set 1.0 if you trained with normalised targets)
MAX_LEN   = 4096         # number of stitches to generate (no stop condition yet)

# --------------------------- code starts ---------------------------
import os
import torch, numpy as np
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from pyembroidery import EmbPattern, write
from model import EmbroNet                                # <-- comes from model.py

# --- load network -----------------------------------------------
# NOTE: torch>=2.6 defaults torch.load(weights_only=True); if this fails with
# an UnpicklingError, pass weights_only=False for a checkpoint you trust.
raw_state = torch.load(CKPT_PATH, map_location="cpu")
state     = raw_state["state_dict"] if "state_dict" in raw_state else raw_state
# LitModule stores EmbroNet under self.net, so keys start with "net.". Strip
# only that leading prefix; str.replace would also rewrite "net." mid-key.
PREFIX    = "net."
state     = {(k[len(PREFIX):] if k.startswith(PREFIX) else k): v for k, v in state.items()}
net = EmbroNet().eval()
incompatible = net.load_state_dict(state, strict=False)
if incompatible.missing_keys:
    # strict=False would otherwise run silently with random weights here
    raise RuntimeError(f"checkpoint is missing weights for: {incompatible.missing_keys}")
if incompatible.unexpected_keys:
    print("⚠️  ignored unexpected checkpoint keys:", incompatible.unexpected_keys)
print("✓ checkpoint loaded")

# --- preprocess PNG ----------------------------------------------
tfms = Compose([Resize((128,128)), ToTensor()])
with Image.open(PNG_PATH) as src:                         # closes the file handle
    img = tfms(src.convert("RGB")).unsqueeze(0)           # 1×3×128×128

# --- autoregressive inference ------------------------------------
# Mirrors EmbroNet.forward, but carries the GRU hidden state forward one step
# at a time. The result equals re-running net(img, seq) on the whole growing
# sequence and keeping the last output, at O(L) instead of O(L²) cost.
with torch.no_grad():
    feats  = net.encoder(img).view(img.size(0), -1)
    hidden = torch.tanh(net.enc_fc(feats)).unsqueeze(0).repeat(net.gru.num_layers, 1, 1)
    step   = torch.zeros(1, 1, 3)      # start token (Δx=Δy=flag=0)
    outputs = []
    for _ in range(MAX_LEN):
        out, hidden = net.gru(step, hidden)
        step = net.dec_fc(out)          # 1×1×3 prediction, fed back as next input
        outputs.append(step)
pred_seq = torch.cat(outputs, dim=1).squeeze(0)           # L×3

# --- rescale ------------------------------------------------------
pred_seq[:, :2] *= SCALE

# --- build EmbPattern & save DST ---------------------------------
# x, y are accumulated absolute positions. pyembroidery's EmbPattern.stitch()
# takes a *relative* move, so stitch_abs is used to avoid integrating twice.
pat = EmbPattern()
x = y = 0.0
for dx, dy, _flag in pred_seq.numpy():
    x += dx;  y += dy
    pat.stitch_abs(float(x), float(y))  # simple stitch; flag ignored for now
pat.end()

os.makedirs(os.path.dirname(OUT_PATH) or ".", exist_ok=True)
write(pat, OUT_PATH)
print("✓ saved", OUT_PATH, f"({len(pred_seq)} stitches)")


In [None]:
# Quick qualitative check: teacher-forced predictions vs. ground truth for
# the first sequence of one batch.
img, tgt, lens = next(iter(dl))          # pad_collate yields (imgs, padded, lens)
net = trainer.model.net.to(img.device)
with torch.no_grad():
    pred = net(img, tgt[:, :-1])
# pred[:, t] is the model's guess for tgt[:, t + 1], so compare shifted rows.
print("GT  :", tgt[0, 1:6])
print("Pred:", pred[0, :5])