### Load corn subplots + June 26 chips

import os, json
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import cv2

# --- Paths and run configuration ---------------------------------------------
RAW_DIR = "data/raw"
PROC_DIR = "data/processed"
DATE = "0626"  # date tag embedded in directory names (June 26 per the header)

# JSON file listing the subplot ids used below.
CORN_IDS_PATH = os.path.join(PROC_DIR, "corn_subplots.json")
# One mask image per subplot, e.g. {subplot_id}.png (0/1 or 0/255).
WEED_MASK_DIR = os.path.join(RAW_DIR, "labels", "weed_masks", DATE)

with open(CORN_IDS_PATH, "r") as f:
    corn_ids = json.load(f)

# Directory holding one "{subplot_id}.npz" chip archive per subplot.
CHIP_DIR = os.path.join(PROC_DIR, "subplots", f"chips_{DATE}")

# Bare expression: shown as cell output in a notebook, no effect in a script.
len(corn_ids), CHIP_DIR

### Dataset (x + weed mask)

def load_npz_x(path):
    """Load the array stored under key ``"x"`` in an ``.npz`` archive.

    Args:
        path: Path to an ``.npz`` file containing an ``"x"`` entry.

    Returns:
        A float32 ``np.ndarray`` copy of the stored array.

    Raises:
        KeyError: If the archive has no ``"x"`` entry.
    """
    # np.load on an .npz returns an NpzFile that keeps the file open until it
    # is closed. This function runs once per sample, so use the context
    # manager to release the handle instead of leaking it.
    with np.load(path, allow_pickle=False) as z:
        # astype() yields an independent in-memory copy, so the result stays
        # valid after the archive is closed.
        return z["x"].astype(np.float32)

class WeedSegDataset(Dataset):
    """Dataset pairing each subplot chip with its binary weed mask.

    For subplot id ``sid`` the chip is read from ``{chip_dir}/{sid}.npz`` and
    the mask from ``{mask_dir}/{sid}.png``.

    Args:
        ids: Sequence of subplot ids.
        chip_dir: Directory containing the ``.npz`` chips.
        mask_dir: Directory containing the ``.png`` masks.
    """

    def __init__(self, ids, chip_dir, mask_dir):
        self.ids = ids
        self.chip_dir = chip_dir
        self.mask_dir = mask_dir

    def __len__(self):
        """Return the number of subplot ids."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(chip, mask, sid)`` for the sample at ``idx``.

        ``chip`` is the float32 chip tensor (laid out as (C,H,W) per the
        original author's note), ``mask`` is a float32 tensor of shape
        (1,H,W) holding 1.0 wherever the mask image is non-zero, and ``sid``
        is the subplot id.

        Raises:
            FileNotFoundError: If OpenCV cannot read the mask image.
        """
        subplot_id = self.ids[idx]

        chip_path = os.path.join(self.chip_dir, f"{subplot_id}.npz")
        chip = load_npz_x(chip_path)

        mask_path = os.path.join(self.mask_dir, f"{subplot_id}.png")
        raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if raw_mask is None:
            raise FileNotFoundError(mask_path)

        # Any non-zero pixel counts as weed, covering both 0/1 and 0/255 masks.
        binary = np.greater(raw_mask, 0).astype(np.float32)
        mask_tensor = torch.from_numpy(binary[np.newaxis, ...])
        return torch.from_numpy(chip), mask_tensor, subplot_id

### Minimal U-Net (compact)

class ConvBlock(nn.Module):
    """Two consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    The convolutions use stride 1 and padding 1, so spatial size is kept and
    only the channel count changes (``in_ch`` -> ``out_ch``).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of both stages.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            stages += [
                nn.Conv2d(src_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Kept as ``self.net`` so state_dict keys remain ``net.<index>.*``.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.net(x)

class UNetSmall(nn.Module):
    """Compact three-level U-Net producing a single-channel logit map.

    Encoder widths are 32/64/128 with a 256-channel bottleneck; the decoder
    mirrors them using stride-2 transposed convolutions and skip
    concatenation. The output is raw logits (no sigmoid applied here).

    Input height and width do not have to be multiples of 8: when floor
    pooling makes an upsampled map smaller than its skip connection, the
    upsampled map is zero-padded to the skip's size before concatenation.
    Each spatial dimension must still be at least 8 so that three poolings
    leave a non-empty bottleneck.

    Args:
        in_ch: Number of input channels.
    """

    def __init__(self, in_ch):
        super().__init__()
        # Attribute names and construction order are unchanged so existing
        # checkpoints (state_dict keys) keep loading.
        self.e1 = ConvBlock(in_ch, 32)
        self.p1 = nn.MaxPool2d(2)
        self.e2 = ConvBlock(32, 64)
        self.p2 = nn.MaxPool2d(2)
        self.e3 = ConvBlock(64, 128)
        self.p3 = nn.MaxPool2d(2)
        self.b = ConvBlock(128, 256)
        self.u3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.d3 = ConvBlock(256, 128)
        self.u2 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.d2 = ConvBlock(128, 64)
        self.u1 = nn.ConvTranspose2d(64, 32, 2, 2)
        self.d1 = ConvBlock(64, 32)
        self.out = nn.Conv2d(32, 1, 1)

    @staticmethod
    def _merge(up, skip):
        """Concatenate ``up`` with ``skip`` on channels, padding ``up`` to fit."""
        dh = skip.shape[-2] - up.shape[-2]
        dw = skip.shape[-1] - up.shape[-1]
        # MaxPool2d(2) floors odd sizes, so after the x2 upsample ``up`` can be
        # one pixel short of ``skip``. No-op when the sizes already match.
        if dh or dw:
            # F.pad order for the last two dims: (left, right, top, bottom).
            up = nn.functional.pad(
                up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2]
            )
        return torch.cat([up, skip], 1)

    def forward(self, x):
        """Return logits of shape (N, 1, H, W) for input ``x`` of (N, in_ch, H, W)."""
        e1 = self.e1(x)
        e2 = self.e2(self.p1(e1))
        e3 = self.e3(self.p2(e2))
        b = self.b(self.p3(e3))
        d3 = self.d3(self._merge(self.u3(b), e3))
        d2 = self.d2(self._merge(self.u2(d3), e2))
        d1 = self.d1(self._merge(self.u1(d2), e1))
        return self.out(d1)

### Train + weed count from mask

def dice_loss(logits, target, eps=1e-6):
    """Soft Dice loss computed from raw logits.

    Args:
        logits: Unnormalised scores; reduced over dims 2 and 3, so a 4-D
            (N, C, H, W) tensor is expected.
        target: Tensor broadcastable against ``logits`` holding the reference
            mask values.
        eps: Smoothing term added to numerator and denominator so an empty
            prediction/target pair does not divide by zero.

    Returns:
        Scalar tensor ``1 - mean(dice)`` where dice is computed per
        (sample, channel) over the spatial dims.
    """
    spatial = (2, 3)
    probs = logits.sigmoid()
    overlap = torch.sum(probs * target, dim=spatial)
    # Denominator is the sum of both masses (|P| + |T|), not a set union.
    total = torch.sum(probs, dim=spatial) + torch.sum(target, dim=spatial)
    score = (2 * overlap + eps) / (total + eps)
    return 1 - score.mean()

def mask_to_count(mask_prob, thr=0.5, min_area=8):
    """Count connected blobs in a thresholded probability mask.

    The mask is binarised at ``thr``, cleaned with a 3x3 morphological open
    followed by a 3x3 close, and split into 8-connected components.

    Args:
        mask_prob: (H,W) float array in [0,1].
        thr: Pixels with ``mask_prob >= thr`` are foreground.
        min_area: Components with an area below this many pixels are ignored.

    Returns:
        Number of foreground components whose area is at least ``min_area``.
    """
    kernel = np.ones((3, 3), np.uint8)
    binary = np.where(mask_prob >= thr, 255, 0).astype(np.uint8)

    # Open removes specks, then close fills small holes inside blobs.
    for op in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        binary = cv2.morphologyEx(binary, op, kernel, iterations=1)

    _, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    # Row 0 of ``stats`` is the background; the slice is empty when the mask
    # has no foreground, which yields a count of zero.
    component_areas = stats[1:, cv2.CC_STAT_AREA]
    return int(np.count_nonzero(component_areas >= min_area))

# --- Train/val split ---------------------------------------------------------
# 80/20 split of the corn ids. The shuffle is still in place (corn_ids is
# reordered, as before) but now uses a dedicated seeded generator, so the
# validation set behind the saved checkpoint can be reproduced across runs.
import random

SPLIT_SEED = 0
random.Random(SPLIT_SEED).shuffle(corn_ids)
n = len(corn_ids)
n_train = int(0.8 * n)
train_ids = corn_ids[:n_train]
val_ids = corn_ids[n_train:]
if not train_ids or not val_ids:
    # With fewer than 2 ids the train slice is empty; fail with a clear
    # message instead of an IndexError on train_ids[0] below.
    raise ValueError(
        f"Need at least 2 subplot ids in {CORN_IDS_PATH} for an 80/20 split, got {n}"
    )

# Infer the number of input channels from the first training chip (axis 0).
tmp = load_npz_x(os.path.join(CHIP_DIR, f"{train_ids[0]}.npz"))
in_ch = tmp.shape[0]

model = UNetSmall(in_ch)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
opt = torch.optim.AdamW(model.parameters(), lr=2e-4)
bce = nn.BCEWithLogitsLoss()

# NOTE(review): num_workers=2 at module top level needs an
# `if __name__ == "__main__":` guard on spawn-based platforms (Windows/macOS)
# when this file is run as a script rather than as notebook cells — confirm.
train_loader = DataLoader(
    WeedSegDataset(train_ids, CHIP_DIR, WEED_MASK_DIR),
    batch_size=4, shuffle=True, num_workers=2,
)
val_loader = DataLoader(
    WeedSegDataset(val_ids, CHIP_DIR, WEED_MASK_DIR),
    batch_size=4, shuffle=False, num_workers=2,
)

CKPT_PATH = os.path.join(PROC_DIR, "weed_segmenter_unet.pt")

best = 1e9  # lowest validation count MAE seen so far
for epoch in range(20):
    model.train()
    for x, mask, _ in train_loader:
        x, mask = x.to(device), mask.to(device)
        opt.zero_grad()
        logits = model(x)
        # Combined objective: per-pixel BCE plus soft Dice on the same logits.
        loss = bce(logits, mask) + dice_loss(logits, mask)
        loss.backward()
        opt.step()

    # Validate by count MAE. Ground-truth counts are derived from the GT masks
    # with the same mask_to_count post-processing used on the predictions.
    model.eval()
    maes = []
    with torch.no_grad():
        for x, mask, _ in val_loader:
            x = x.to(device)
            prob = torch.sigmoid(model(x)).cpu().numpy()  # (B,1,H,W)
            gt = mask.numpy()
            for pred_map, gt_map in zip(prob[:, 0], gt[:, 0]):
                maes.append(abs(mask_to_count(pred_map) - mask_to_count(gt_map)))
    # val_ids is guaranteed non-empty by the check above, so maes is too.
    mae = float(np.mean(maes))
    print("epoch", epoch, "val count MAE", mae)
    # Keep only the checkpoint with the best (lowest) validation MAE.
    if mae < best:
        best = mae
        torch.save(model.state_dict(), CKPT_PATH)