### Imports + load chips + labels

import os, json
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

RAW_DIR = "data/raw"
PROC_DIR = "data/processed"
# Date tags whose chips are stacked channel-wise (original note: using all three is best).
DATE_STACK = ["0626", "0710", "0731"]

# Label table with columns subplot_id, crop (5 crop classes).
LABEL_CSV = os.path.join(RAW_DIR, "labels", "crop_type.csv")

# Output directory for the train/val/test id lists.
SPLIT_OUT = os.path.join(PROC_DIR, "splits")
os.makedirs(SPLIT_OUT, exist_ok=True)

labels = pd.read_csv(LABEL_CSV)
labels.head()  # notebook-style preview; no effect when run as a plain script

### Dataset: stack multiple dates into one tensor

def load_npz(path):
    """Load a chip ``.npz`` archive and return ``(x, meta)``.

    The archive must contain an array ``"x"`` and a ``"meta"`` entry holding a
    JSON document stored as a single string (``str`` or UTF-8 ``bytes``).
    Pickled objects are refused (``allow_pickle=False``).

    Returns:
        x: the ``"x"`` array, fully read into memory.
        meta: the decoded JSON metadata.
    """
    # Use the NpzFile as a context manager so the file handle is closed
    # deterministically instead of whenever the object is garbage-collected.
    with np.load(path, allow_pickle=False) as z:
        x = z["x"]  # indexing reads the array into memory, so it outlives the close
        raw_meta = z["meta"].item()  # 0-d / single-element array -> Python scalar
    # str() on a bytes array would yield "b'...'", so decode bytes explicitly.
    if isinstance(raw_meta, bytes):
        raw_meta = raw_meta.decode("utf-8")
    meta = json.loads(str(raw_meta))
    return x, meta

class MultiDateSubplotDataset(Dataset):
    """Per-subplot samples whose chips from several dates are stacked channel-wise.

    For each subplot id, one ``.npz`` chip per date is read from
    ``<proc_dir>/subplots/chips_<date>/<sid>.npz`` and the ``x`` arrays are
    concatenated along axis 0 in the order given by ``dates``.

    Args:
        ids: sequence of subplot ids, indexed by position.
        label_map: mapping from subplot id to integer class label.
        proc_dir: root of the processed-data directory.
        dates: date tags selecting the ``chips_<date>`` sub-directories.

    Each item is ``(x, y, sid)``: ``x`` a float32 tensor, ``y`` a 0-d int64
    tensor, and ``sid`` the subplot id.
    """

    def __init__(self, ids, label_map, proc_dir, dates):
        self.ids = ids
        self.label_map = label_map
        self.proc_dir = proc_dir
        self.dates = dates

    def __len__(self):
        return len(self.ids)

    def _chip_path(self, sid, date):
        """Return the path of the chip file for subplot `sid` on `date`."""
        return os.path.join(self.proc_dir, "subplots", f"chips_{date}", f"{sid}.npz")

    def __getitem__(self, idx):
        sid = self.ids[idx]
        # Presumably each chip is (C_d, H, W) with identical H, W across dates — TODO confirm.
        chips = [load_npz(self._chip_path(sid, d))[0] for d in self.dates]
        x = np.concatenate(chips, axis=0)  # (C_total, H, W)
        # Cast in NumPy first: torch.from_numpy rejects unsigned dtypes such as
        # uint16 on older torch releases; for supported dtypes the result equals
        # from_numpy(x).float().
        x = np.ascontiguousarray(x, dtype=np.float32)
        y = torch.tensor(self.label_map[sid], dtype=torch.long)
        return torch.from_numpy(x), y, sid

# Subplot id (as a string) -> integer crop class, taken from the label table.
label_map = {
    str(sub_id): int(crop)
    for sub_id, crop in zip(labels["subplot_id"], labels["crop"])
}
# Keep only subplots that have a chip file for every date in DATE_STACK.
all_ids = [
    sid
    for sid in label_map
    if all(
        os.path.exists(os.path.join(PROC_DIR, "subplots", f"chips_{date}", f"{sid}.npz"))
        for date in DATE_STACK
    )
]
len(all_ids)

### Split (grouped if you have plot_id; otherwise simple)

# If you have plot_id in labels, group-split by plot_id. Otherwise random split.
# NOTE(review): only the random split is implemented below, so subplots from the
# same plot can end up in different splits.
import random

SPLIT_SEED = 42  # fixed so the split (and the JSON files below) is reproducible
# Sort into a canonical order first so the seeded shuffle does not depend on CSV row order.
all_ids.sort()
random.Random(SPLIT_SEED).shuffle(all_ids)

n = len(all_ids)
train_end, val_end = int(0.7 * n), int(0.85 * n)  # 70% / 15% / 15%
train_ids = all_ids[:train_end]
val_ids = all_ids[train_end:val_end]
test_ids = all_ids[val_end:]

for split_name, split_ids in (("train", train_ids), ("val", val_ids), ("test", test_ids)):
    with open(os.path.join(SPLIT_OUT, f"{split_name}_ids.json"), "w", encoding="utf-8") as f:
        json.dump(split_ids, f)
len(train_ids), len(val_ids), len(test_ids)

### Simple classifier (custom first conv for C channels)

import torchvision.models as models

def make_resnet18(in_ch, n_classes=5):
    """Build an untrained torchvision ResNet-18 for `in_ch` input channels.

    The stem convolution is swapped for one accepting `in_ch` channels (64
    filters, 7x7 kernel, stride 2, padding 3, no bias), and the final linear
    layer is resized to produce `n_classes` logits.
    """
    net = models.resnet18(weights=None)
    stem = nn.Conv2d(in_ch, 64, kernel_size=7, stride=2, padding=3, bias=False)
    head = nn.Linear(net.fc.in_features, n_classes)  # read in_features before replacing fc
    net.conv1, net.fc = stem, head
    return net

# Probe one training sample to learn the stacked channel count (channels summed
# over all dates in DATE_STACK) so the first conv layer matches the input.
if not train_ids:
    raise RuntimeError("train split is empty; cannot infer the input channel count")
tmp_x, _, _ = MultiDateSubplotDataset(train_ids[:1], label_map, PROC_DIR, DATE_STACK)[0]
C_total = tmp_x.shape[0]
model = make_resnet18(C_total, n_classes=5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

### Train loop + export corn_subplots.json

from sklearn.metrics import f1_score, confusion_matrix

train_ds = MultiDateSubplotDataset(train_ids, label_map, PROC_DIR, DATE_STACK)
val_ds = MultiDateSubplotDataset(val_ids, label_map, PROC_DIR, DATE_STACK)

# Pinned host memory only helps host->GPU copies; on a CPU-only machine it is
# useless (newer PyTorch versions warn), so enable it only for CUDA.
pin = device == "cuda"
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2, pin_memory=pin)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2, pin_memory=pin)

opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
crit = nn.CrossEntropyLoss()

CKPT_PATH = os.path.join(PROC_DIR, "crop_classifier_resnet18.pt")  # best-on-validation weights
N_EPOCHS = 15

best_f1 = -1
for epoch in range(N_EPOCHS):
    model.train()
    for x, y, _ in train_loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        logits = model(x)
        loss = crit(logits, y)
        loss.backward()
        opt.step()

    model.eval()
    ys, ps = [], []
    with torch.no_grad():
        for x, y, _ in val_loader:
            logits = model(x.to(device))
            ps += logits.argmax(1).cpu().tolist()
            ys += y.tolist()
    # zero_division=0 gives the same score as the default ("warn" also uses 0)
    # but without a warning whenever a class is never predicted.
    f1 = f1_score(ys, ps, average="macro", zero_division=0)
    print("epoch", epoch, "val macro-f1", f1)
    # Keep only the checkpoint with the best validation macro-F1 so far.
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), CKPT_PATH)

# Inference on all subplots -> select corn IDs (adjust class index/name for corn)
# NOTE(review): all_ids covers the train/val/test subplots alike, so most of these
# predictions are on data the model was trained or selected on.
CORN_CLASS = 0  # <-- set correctly based on your label encoding
ckpt_path = os.path.join(PROC_DIR, "crop_classifier_resnet18.pt")
# The checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=True))
model.eval()

# Run all subplots through one batched loader instead of building a Dataset and
# a forward pass per id. shuffle=False keeps results in all_ids order; in eval
# mode predictions should not depend on batch composition (up to float noise).
infer_loader = DataLoader(
    MultiDateSubplotDataset(all_ids, label_map, PROC_DIR, DATE_STACK),
    batch_size=16,
    shuffle=False,
    num_workers=2,
)

corn_ids = []
with torch.no_grad():
    for x, _, sids in infer_loader:
        preds = model(x.to(device)).argmax(1).cpu().tolist()
        corn_ids += [sid for sid, pred in zip(sids, preds) if pred == CORN_CLASS]

corn_path = os.path.join(PROC_DIR, "corn_subplots.json")
with open(corn_path, "w", encoding="utf-8") as f:
    json.dump(corn_ids, f)
corn_path, len(corn_ids)