# In [None]:
import os
import time
import random
import shutil
from pathlib import Path
from PIL import Image
import torch
import numpy as np
import pandas as pd
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

from models import *
from image_utils import *
from train_utils import *
from model_utils import *

# Set seeds
# Seed Python, NumPy and PyTorch RNGs (torch.manual_seed seeds every device,
# including CUDA).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# --- Step 1: Train TuiGAN on one healthy/sick pair ---
print("\n🚀 Step 1: Train TuiGAN on single healthy-sick pair")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): NUM_ITERS is not passed to train() below and is not referenced
# by any other visible call; the iteration count train() actually uses comes
# from train_utils. TODO confirm and wire it through if train() accepts it.
NUM_ITERS = 4000
NUM_SCALES = 5  # number of generator/discriminator scales built by create_models

pathA = "data/healthy2sick_A.jpg"  # single healthy
pathB = "data/healthy2sick_B.jpg"  # single sick
data_name = "healthy2sick"  # tag under which save_models stores the trained models

# NOTE(review): pathA/pathB already start with "data/" and "data" is also passed
# as the first argument — confirm read_domains does not join them into "data/data/...".
imgA, imgB = read_domains("data", pathA, pathB, resize=True)
# Build one image pyramid per domain on DEVICE, then reverse it in place
# (presumably so index 0 is the coarsest scale — confirm against construct_scale_pyramid).
listA = construct_scale_pyramid(normalize_image_to_tensor(imgA).to(DEVICE))
listA.reverse()
listB = construct_scale_pyramid(normalize_image_to_tensor(imgB).to(DEVICE))
listB.reverse()

listg_ab, listg_ba, listd_a, listd_b = create_models(num_scale=NUM_SCALES, device=DEVICE)

# perf_counter is monotonic and high-resolution, unlike time.time(), which can
# jump if the system clock is adjusted during a long training run.
start_time = time.perf_counter()
train(listA, listB, (listg_ab, listg_ba, listd_a, listd_b))
if DEVICE == "cuda":
    # CUDA kernels execute asynchronously; wait for queued work so the
    # measured duration covers all of training.
    torch.cuda.synchronize()
end_time = time.perf_counter()

save_models(listg_ab, listg_ba, listd_a, listd_b, data_name)
print(f"✅ TuiGAN model trained in {(end_time - start_time):.2f} seconds")

# --- Step 2: Prepare dataset folders ---
print("\n📁 Step 2: Preparing train/test folders")

src_healthy = "../data/plant_pathology/healthy"
src_sick = "../data/plant_pathology/sick"
dst_train_h = "./data/plant_pathology/train/healthy"
dst_train_s = "./data/plant_pathology/train/sick"
dst_test_h = "./data/plant_pathology/test/healthy"
dst_test_s = "./data/plant_pathology/test/sick"

# Split sizes: the first N_TRAIN_HEALTHY shuffled healthy images go to train
# (the rest to test); N_TEST_SICK sick images go to test, and exactly one other
# real sick image goes to train (Step 3 adds synthetic sick images beside it).
N_TRAIN_HEALTHY = 416
N_TEST_SICK = 81


def _copy_files(names, src_dir, dst_dir):
    """Copy each file in *names* from *src_dir* into *dst_dir*, keeping its name."""
    for name in names:
        shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))


for f in [dst_train_h, dst_train_s, dst_test_h, dst_test_s]:
    os.makedirs(f, exist_ok=True)
# NOTE(review): destination folders are not emptied first, so files left over
# from an earlier run with a different split would remain in them.

# Sort before shuffling so the split depends only on the seed, not on the
# filesystem's listing order.
healthy_imgs = sorted(os.listdir(src_healthy))
random.shuffle(healthy_imgs)
train_healthy = healthy_imgs[:N_TRAIN_HEALTHY]
test_healthy = healthy_imgs[N_TRAIN_HEALTHY:]
_copy_files(train_healthy, src_healthy, dst_train_h)
_copy_files(test_healthy, src_healthy, dst_test_h)

sick_imgs = sorted(os.listdir(src_sick))
test_sick = random.sample(sick_imgs, N_TEST_SICK)
# Keep the remaining sick images in sorted order. Building this list from a
# set difference made its order depend on PYTHONHASHSEED, so random.choice
# below picked a different image on every run despite the fixed seed.
test_sick_set = set(test_sick)
train_sick_real = [name for name in sick_imgs if name not in test_sick_set]
train_sick_sample = random.choice(train_sick_real)
_copy_files([train_sick_sample], src_sick, dst_train_s)
_copy_files(test_sick, src_sick, dst_test_s)

print("✅ Dataset folders ready")

# --- Step 3: Generate synthetic sick images using TuiGAN ---
print("\n🎨 Step 3: Generating fake sick images using TuiGAN")

# Synthetic sick images are written next to the single real sick training image.
output_folder = dst_train_s
healthy_folder = dst_train_h

healthy_list = sorted(os.listdir(healthy_folder))
# Reload the models saved in Step 1 under the same tag.
listg_ab, listg_ba, listd_a, listd_b = load_models(data_name, device=DEVICE)

# Pure inference: disable autograd so no computation graph is built per image.
with torch.no_grad():
    for i, fname in enumerate(healthy_list):
        pathA = os.path.join(healthy_folder, fname)
        # Context manager guarantees the source file handle is released.
        with Image.open(pathA) as src_img:
            imgA = src_img.convert("RGB")
        # NOTE(review): Step 1 loaded its pair via read_domains(..., resize=True);
        # these images are not resized — confirm generate_outputs tolerates a
        # size mismatch with listB.
        listA = construct_scale_pyramid(normalize_image_to_tensor(imgA).to(DEVICE))
        listA.reverse()
        # listB is the sick-domain pyramid built in Step 1; the finest scale
        # index is passed as the generation depth.
        fake_img, _ = generate_outputs((listA, listB), (listg_ab, listg_ba, listd_a, listd_b), NUM_SCALES - 1)
        save_path = os.path.join(output_folder, f"tuigan_fake_sick_{i+1}.jpg")
        fake_img.save(save_path)

print(f"✅ Generated {len(healthy_list)} synthetic sick images to: {output_folder}")

# --- Step 4: Classification (ResNet18 & VGG16) ---
print("\n🧠 Step 4: Classification on augmented dataset")

# Train and test share one preprocessing recipe (resize to 288x288, then
# convert to a tensor); a separate Compose object is built for each split.
# NOTE(review): no mean/std normalization is applied although train_and_eval
# loads ImageNet-pretrained weights — confirm this is intended.
train_transform, test_transform = (
    transforms.Compose([transforms.Resize((288, 288)), transforms.ToTensor()])
    for _split in ("train", "test")
)

# ImageFolder derives class labels from the sub-folder names (healthy / sick).
train_dataset = datasets.ImageFolder(root='./data/plant_pathology/train', transform=train_transform)
test_dataset = datasets.ImageFolder(root='./data/plant_pathology/test', transform=test_transform)

# Only the training loader shuffles; evaluation order is fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

def train_and_eval(model_name):
    print(f"📌 Training model: {model_name}")
    if model_name == "resnet18":
        model = models.resnet18(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, 2)
    elif model_name == "vgg16":
        model = models.vgg16(pretrained=True)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, 2)
    else:
        raise ValueError("Unsupported model")

    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(10):
        model.train()
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # Evaluation
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    acc = 100 * correct / total
    print(f"✅ {model_name} accuracy: {acc:.2f}%")
    return acc

resnet_acc = train_and_eval("resnet18")
vgg_acc = train_and_eval("vgg16")

# --- Step 5: Save results to CSV ---
# Single source of truth for the output path, so the progress messages always
# name the file that is actually written (they previously said
# "tuigan_43_results.csv" while writing "tuigan_plant_results.csv").
results_csv = "tuigan_plant_results.csv"
print(f"\n💾 Step 5: Saving results to {results_csv}")

df = pd.DataFrame([
    {"Model": "ResNet18 (TuiGAN)", "Accuracy (%)": round(resnet_acc, 2)},
    {"Model": "VGG16 (TuiGAN)", "Accuracy (%)": round(vgg_acc, 2)},
])
df.to_csv(results_csv, index=False)
print(f"📁 Results saved to {results_csv}")