### Load models + IDs + helpers

import os, json
import numpy as np
import pandas as pd
import torch
import cv2
from tqdm import tqdm

PROC_DIR = "data/processed"
RAW_DIR = "data/raw"
CORN_IDS_PATH = os.path.join(PROC_DIR, "corn_subplots.json")

# Corn subplot IDs: every prediction below is made once per ID.
# JSON is UTF-8 by specification; pass the encoding explicitly so the
# platform default (e.g. cp1252 on Windows) cannot mis-decode the file.
with open(CORN_IDS_PATH, "r", encoding="utf-8") as f:
    corn_ids = json.load(f)
# Both model-loading steps index corn_ids[0] to probe the chip channel
# count, so fail here with a clear message instead of an IndexError later.
if not corn_ids:
    raise ValueError(f"No corn subplot IDs found in {CORN_IDS_PATH}")

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_npz_x(path):
    """Load the array stored under key ``"x"`` in an ``.npz`` file as float32.

    Args:
        path: Path to an ``.npz`` archive that contains an ``"x"`` entry.

    Returns:
        A float32 copy of the ``"x"`` array (independent of the file).

    Raises:
        KeyError: If the archive has no ``"x"`` entry.
    """
    # np.load on an .npz returns an NpzFile that keeps the file open and reads
    # members lazily; the context manager closes it once the data is copied
    # out by astype (which always returns a new array). allow_pickle=False
    # refuses object arrays, so loading never unpickles data.
    with np.load(path, allow_pickle=False) as z:
        return z["x"].astype(np.float32)

### Load Task 2 model + predict weed counts (Jun 26)

# Task 2 (weed counts) reads the per-subplot chips for date 0626 (Jun 26).
DATE_T2 = "0626"
CHIP_DIR_T2 = os.path.join(PROC_DIR, "subplots", "chips_" + DATE_T2)

# Re-declare model class (keep notebook independent)
import torch.nn as nn

class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> batch-norm -> ReLU stages.

    Padding 1 keeps the spatial size unchanged; only the channel count goes
    from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # The (conv, bn, relu, conv, bn, relu) order fixes the "net.<i>.*"
        # state-dict keys, which saved checkpoints depend on.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class UNetSmall(nn.Module):
    """Three-level U-Net that maps an image to a single-channel logit map.

    The output keeps the input's height and width. Because the encoder pools
    three times and the decoder concatenates upsampled features with the
    matching encoder output, H and W must be divisible by 8.
    """

    def __init__(self, in_ch):
        super().__init__()
        # Encoder: ConvBlocks widen the channels, max-pooling halves H and W.
        self.e1 = ConvBlock(in_ch, 32)
        self.p1 = nn.MaxPool2d(kernel_size=2)
        self.e2 = ConvBlock(32, 64)
        self.p2 = nn.MaxPool2d(kernel_size=2)
        self.e3 = ConvBlock(64, 128)
        self.p3 = nn.MaxPool2d(kernel_size=2)
        # Bottleneck at 1/8 resolution.
        self.b = ConvBlock(128, 256)
        # Decoder: each transposed conv doubles H and W; the ConvBlock after it
        # takes the upsampled tensor concatenated with the skip, which is why
        # its input width is twice its output width.
        self.u3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.d3 = ConvBlock(256, 128)
        self.u2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.d2 = ConvBlock(128, 64)
        self.u1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.d1 = ConvBlock(64, 32)
        # 1x1 projection down to one output channel.
        self.out = nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, x):
        skip1 = self.e1(x)
        skip2 = self.e2(self.p1(skip1))
        skip3 = self.e3(self.p2(skip2))
        bottleneck = self.b(self.p3(skip3))

        up = self.d3(torch.cat([self.u3(bottleneck), skip3], dim=1))
        up = self.d2(torch.cat([self.u2(up), skip2], dim=1))
        up = self.d1(torch.cat([self.u1(up), skip1], dim=1))
        return self.out(up)

def mask_to_count(mask_prob, thr=0.5, min_area=8):
    """Count blobs in a 2-D probability map after light morphological cleanup.

    Args:
        mask_prob: 2-D array of per-pixel scores.
        thr: Pixels with ``mask_prob >= thr`` are treated as foreground.
        min_area: Minimum component size, in pixels, for a blob to be counted.

    Returns:
        The number of 8-connected foreground components with area at least
        ``min_area``, measured after a 3x3 opening (drops isolated specks)
        followed by a 3x3 closing (bridges small gaps).
    """
    kernel = np.ones((3, 3), dtype=np.uint8)
    binary = np.where(mask_prob >= thr, 255, 0).astype(np.uint8)
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=1)
    _, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    # Row 0 of `stats` describes the background label; skip it. When there is
    # no foreground this slice is empty and the count is 0.
    component_areas = stats[1:, cv2.CC_STAT_AREA]
    return int(np.count_nonzero(component_areas >= min_area))

# Load the Task 2 weed segmenter.
# Probe one chip for its channel count. Chips are loaded as (C, H, W) arrays
# (the loop below adds a batch axis with [None] before the Conv2d stack), so
# shape[0] is the number of input channels.
# NOTE(review): this must match the channel count the saved weights were
# trained with, otherwise load_state_dict fails on the first convolution.
tmp = load_npz_x(os.path.join(CHIP_DIR_T2, f"{corn_ids[0]}.npz"))
in_ch = tmp.shape[0]
weed_model = UNetSmall(in_ch).to(device)
# map_location=device remaps the checkpoint's tensors onto the chosen device.
weed_model.load_state_dict(torch.load(os.path.join(PROC_DIR,"weed_segmenter_unet.pt"), map_location=device))
# Inference mode: BatchNorm layers use their running statistics.
weed_model.eval()

# One row per corn subplot: predict a weed-probability map, then count blobs.
weed_rows=[]
with torch.no_grad():  # inference only; no autograd graph is built
    for sid in tqdm(corn_ids):
        x = load_npz_x(os.path.join(CHIP_DIR_T2, f"{sid}.npz"))
        # (C, H, W) -> (1, C, H, W) tensor on the model's device.
        # NOTE(review): UNetSmall pools three times and concatenates skip
        # connections, so chip H and W presumably must be divisible by 8 —
        # confirm the chip size.
        xt = torch.from_numpy(x)[None].float().to(device)
        # Sigmoid turns the single-channel logits into per-pixel
        # probabilities; [0,0] drops the batch and channel axes -> (H, W).
        prob = torch.sigmoid(weed_model(xt)).cpu().numpy()[0,0]
        cnt = mask_to_count(prob, thr=0.5, min_area=8)
        weed_rows.append({"subplot_id": sid, "weed_count": cnt})

weed_df = pd.DataFrame(weed_rows)
# Preview; only displayed when this is the last expression of a notebook cell.
weed_df.head()

### Load Task 3 model + predict stand counts (late season)

# Task 3 (stand counts) reads the late-season per-subplot chips for date 0731.
DATE_T3 = "0731"
CHIP_DIR_T3 = os.path.join(PROC_DIR, "subplots", "chips_" + DATE_T3)

class DensityNet(nn.Module):
    """Small encoder-decoder mapping an image to a one-channel density map.

    The encoder pools twice and the decoder upsamples twice (bilinear), so the
    output matches the input's H and W when both are divisible by 4. The last
    layer is a plain 1x1 convolution with no activation, so values are
    unbounded.
    """

    def __init__(self, in_ch):
        super().__init__()
        # Layer order fixes the "enc.<i>" / "dec.<i>" state-dict indices that
        # saved checkpoints rely on; do not insert or reorder entries.
        encoder_layers = [
            nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        ]
        decoder_layers = [
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1),
        ]
        self.enc = nn.Sequential(*encoder_layers)
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, x):
        features = self.enc(x)
        return self.dec(features)

# Load the Task 3 density model. As for Task 2, probe one chip: chips are
# (C, H, W) arrays, so shape[0] is the input channel count.
# NOTE(review): must match the channel count the saved weights expect.
tmp = load_npz_x(os.path.join(CHIP_DIR_T3, f"{corn_ids[0]}.npz"))
in_ch = tmp.shape[0]
stand_model = DensityNet(in_ch).to(device)
# map_location=device remaps the checkpoint's tensors onto the chosen device.
stand_model.load_state_dict(torch.load(os.path.join(PROC_DIR,"corn_density_net.pt"), map_location=device))
stand_model.eval()

# One row per corn subplot: predict a density map and integrate it.
stand_rows=[]
with torch.no_grad():  # inference only; no autograd graph is built
    for sid in tqdm(corn_ids):
        x = load_npz_x(os.path.join(CHIP_DIR_T3, f"{sid}.npz"))
        # (C, H, W) -> (1, C, H, W) tensor on the model's device.
        xt = torch.from_numpy(x)[None].float().to(device)
        # [0,0] drops the batch and channel axes -> 2-D density map.
        den = stand_model(xt).cpu().numpy()[0,0]
        # The count is the sum of the density map. It is a float, and because
        # DensityNet ends in an unactivated 1x1 conv it is not guaranteed to be
        # a non-negative integer.
        # NOTE(review): confirm whether the submission expects rounding/clamping.
        cnt = float(den.sum())
        stand_rows.append({"subplot_id": sid, "stand_count": cnt})

stand_df = pd.DataFrame(stand_rows)
# Preview; only displayed when this is the last expression of a notebook cell.
stand_df.head()

### Export

OUT_DIR = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)

# One submission CSV per task.
# NOTE(review): "TeamName" looks like a placeholder prefix — confirm the
# required file names before submitting.
weed_csv, stand_csv = (
    os.path.join(OUT_DIR, file_name)
    for file_name in ("TeamName_Task2_WeedCount.csv", "TeamName_Task3_StandCount.csv")
)

# Write without the pandas index so each CSV has only the
# subplot_id and count columns.
weed_df.to_csv(path_or_buf=weed_csv, index=False)
stand_df.to_csv(path_or_buf=stand_csv, index=False)

# Show where the files went (displayed as a notebook cell's last expression).
(weed_csv, stand_csv)