### Load corn ids + late-season chips + point labels

import os, json
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import cv2
from tqdm import tqdm

# Root folders for raw inputs and derived artefacts.
RAW_DIR = "data/raw"
PROC_DIR = "data/processed"

DATE = "0731"  # late season (adjust if Task 3 specifies another date)

# Per-subplot image chips for DATE, stored as <subplot_id>.npz.
CHIP_DIR = os.path.join(PROC_DIR, "subplots", "chips_" + DATE)
# JSON file holding the ids of the subplots that are corn.
CORN_IDS_PATH = os.path.join(PROC_DIR, "corn_subplots.json")
# Point annotations: {subplot_id}.json with [{"x":..,"y":..},...]
POINT_DIR = os.path.join(RAW_DIR, "labels", "corn_points", DATE)

# Read the corn subplot ids once at import time; later cells split them
# into train/validation sets.
with open(CORN_IDS_PATH, "r") as f:
    corn_ids = json.loads(f.read())

### Points → density map

def load_points(path):
    """Read point annotations from a JSON file.

    The file must contain a list whose entries are either mappings with
    ``"x"`` and ``"y"`` keys or indexable ``[x, y]`` pairs; both forms may
    be mixed in one file.

    Args:
        path: Path of the JSON annotation file.

    Returns:
        A list of ``(x, y)`` tuples of floats, in file order.
    """

    def _as_xy(entry):
        # Dict entries carry named coordinates; anything else is treated
        # as a positional [x, y] pair.
        if isinstance(entry, dict):
            return (float(entry["x"]), float(entry["y"]))
        return (float(entry[0]), float(entry[1]))

    with open(path, "r") as f:
        raw_points = json.load(f)
    return [_as_xy(entry) for entry in raw_points]

def points_to_density(pts, H, W, sigma=3.0):
    """Rasterise point annotations into a Gaussian-smoothed density map.

    Each point is rounded to the nearest pixel and contributes one unit of
    mass at that pixel; the impulse image is then blurred with an isotropic
    Gaussian. Points that round to a pixel outside the ``H x W`` grid are
    ignored.

    The map is rescaled so that it sums to the number of points that were
    actually placed inside the grid. (Normalising to ``len(pts)`` instead
    would silently pile the mass of out-of-bounds points onto the in-bounds
    ones, while yielding an all-zero map when *every* point is out of
    bounds.)

    Args:
        pts: Sequence of ``(x, y)`` pixel coordinates (x = column, y = row).
        H: Map height in pixels.
        W: Map width in pixels.
        sigma: Standard deviation of the Gaussian, in pixels.

    Returns:
        ``float32`` array of shape ``(H, W)``. All zeros when no point falls
        inside the grid.
    """
    den = np.zeros((H, W), dtype=np.float32)
    if len(pts) == 0:
        return den

    # Drop one unit impulse per in-bounds point and remember how many landed.
    placed = 0
    for x, y in pts:
        ix, iy = int(round(x)), int(round(y))
        if not (0 <= ix < W and 0 <= iy < H):
            continue
        den[iy, ix] += 1.0
        placed += 1

    # Nothing inside the chip: the blur of an all-zero image is all zero,
    # so skip it.
    if placed == 0:
        return den

    # ksize=(0, 0) lets OpenCV derive the kernel size from sigma.
    den = cv2.GaussianBlur(den, ksize=(0, 0), sigmaX=sigma, sigmaY=sigma)

    # Blurring near the border can change the total mass; rescale so the
    # map integrates to the number of plants inside the chip.
    total = float(den.sum())
    if total > 1e-6:
        den *= placed / total
    return den

### Dataset

def load_npz_x(path):
    """Load the array stored under key ``"x"`` in an ``.npz`` archive.

    Args:
        path: Path of the ``.npz`` file.

    Returns:
        The ``"x"`` array converted to ``float32``.

    Raises:
        KeyError: If the archive has no ``"x"`` entry.
    """
    # np.load keeps the archive's file handle open until the NpzFile is
    # closed; the context manager releases it so that loading many chips
    # (e.g. from DataLoader workers) does not leak descriptors.
    with np.load(path, allow_pickle=False) as z:
        return z["x"].astype(np.float32)

class CornDensityDataset(Dataset):
    """Pairs each subplot chip with a density map built from its point labels.

    Args:
        ids: Sequence of subplot ids; each id names ``<id>.npz`` in
            ``chip_dir`` and ``<id>.json`` in ``point_dir``.
        chip_dir: Directory holding the image chips.
        point_dir: Directory holding the point annotation files.
        sigma: Gaussian sigma (pixels) passed to ``points_to_density``.
    """

    def __init__(self, ids, chip_dir, point_dir, sigma=3.0):
        self.ids = ids
        self.chip_dir = chip_dir
        self.point_dir = point_dir
        self.sigma = sigma

    def __len__(self):
        """Return the number of subplots in the dataset."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(chip, density, subplot_id)`` for the ``idx``-th id.

        ``chip`` is the float32 tensor loaded from the ``.npz`` file
        (expected layout ``(C, H, W)``) and ``density`` is a ``(1, H, W)``
        tensor matching the chip's spatial size.
        """
        sid = self.ids[idx]
        chip_path = os.path.join(self.chip_dir, f"{sid}.npz")
        label_path = os.path.join(self.point_dir, f"{sid}.json")

        chip = load_npz_x(chip_path)  # (C,H,W)
        height = chip.shape[1]
        width = chip.shape[2]

        points = load_points(label_path)
        density = points_to_density(points, height, width, sigma=self.sigma)
        density = density[np.newaxis, ...]  # add channel axis -> (1,H,W)

        return torch.from_numpy(chip), torch.from_numpy(density), sid

### Density Model

class DensityNet(nn.Module):
    """Small convolutional encoder-decoder that regresses a density map.

    The encoder downsamples by 4 with two max-pool stages; the decoder
    upsamples by 4 with two bilinear stages and ends in a 1x1 convolution
    producing a single-channel map.

    Args:
        in_ch: Number of input channels.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(in_ch, 32, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True),
        )
        self.dec = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(128, 64, 3, 1, 1), nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(64, 32, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(32, 1, 1)
        )

    def forward(self, x):
        """Map ``x`` of shape ``(N, in_ch, H, W)`` to ``(N, 1, H, W)``.

        The output always has the same spatial size as the input.
        """
        out = self.dec(self.enc(x))
        # MaxPool2d floors odd sizes, so when H or W is not a multiple of 4
        # the decoder output is slightly smaller than the input and could
        # not be compared with a full-size target. Resize in that case
        # only; inputs whose sides are multiples of 4 are unaffected.
        if out.shape[-2:] != x.shape[-2:]:
            out = nn.functional.interpolate(
                out, size=x.shape[-2:], mode="bilinear", align_corners=False
            )
        return out

### Train + MAE on counts

# In [None]:
import random

# Fixed seed so the train/validation split is identical on every run and
# on every re-execution of this cell. The best checkpoint is selected on
# validation MAE, so a split that changes between runs would let
# previously-seen samples leak into validation.
SPLIT_SEED = 0

# Shuffle a copy: corn_ids itself is left untouched, so re-running the cell
# reproduces the same split instead of reshuffling an already shuffled list.
shuffled_ids = list(corn_ids)
random.Random(SPLIT_SEED).shuffle(shuffled_ids)
n = len(shuffled_ids)
if n < 2:
    # With fewer than two ids the 80/20 split leaves the training set empty.
    raise ValueError(
        f"need at least 2 corn subplots for a train/val split, got {n}"
    )
n_train = int(0.8 * n)
train_ids = shuffled_ids[:n_train]
val_ids = shuffled_ids[n_train:]  # never empty: int(0.8 * n) < n for n >= 1

# Peek at one chip to find the number of input channels.
tmp = load_npz_x(os.path.join(CHIP_DIR, f"{train_ids[0]}.npz"))
in_ch = tmp.shape[0]

model = DensityNet(in_ch)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

opt = torch.optim.AdamW(model.parameters(), lr=2e-4)
mse = nn.MSELoss()

train_loader = DataLoader(
    CornDensityDataset(train_ids, CHIP_DIR, POINT_DIR, sigma=3.0),
    batch_size=4, shuffle=True, num_workers=2,
)
val_loader = DataLoader(
    CornDensityDataset(val_ids, CHIP_DIR, POINT_DIR, sigma=3.0),
    batch_size=4, shuffle=False, num_workers=2,
)

ckpt_path = os.path.join(PROC_DIR, "corn_density_net.pt")
best = float("inf")
for epoch in range(25):
    # --- train: pixel-wise MSE between predicted and target density ---
    model.train()
    for x, den, _ in train_loader:
        x, den = x.to(device), den.to(device)
        opt.zero_grad()
        pred = model(x)
        loss = mse(pred, den)
        loss.backward()
        opt.step()

    # --- validate: the stand count is the integral of the density map ---
    model.eval()
    maes = []
    with torch.no_grad():
        for x, den, _ in val_loader:
            pred = model(x.to(device)).cpu()
            pred_cnt = pred.sum(dim=(1, 2, 3))
            gt_cnt = den.sum(dim=(1, 2, 3))
            maes.extend((pred_cnt - gt_cnt).abs().tolist())
    mae = float(np.mean(maes))
    print("epoch", epoch, "val stand-count MAE", mae)

    # Keep only the weights with the lowest validation count error.
    if mae < best:
        best = mae
        torch.save(model.state_dict(), ckpt_path)