In [1]:
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import timm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
class SoilNetDualHead(nn.Module):
    """Hybrid MobileNetV2 / MobileViTv2 image encoder with two prediction heads.

    Image pipeline: 3x3 stem conv (3 -> 32 ch) -> MobileNetV2 stages 0-2 ->
    1x1 adapter (32 -> 16 ch) -> MobileViTv2 (``mobilevitv2_050``) stages ->
    1x1 adapter (256 -> 32 ch) -> MobileNetV2 stages 3-6 -> 1x1 conv
    (320 -> 1280 ch) -> global average pool.  The 1280-d image feature is
    concatenated with a 32-d embedding of a scalar light input and fed to two
    independent MLPs: a 2-output regression head and a ``num_classes``-output
    classification head.

    Both timm backbones are built with ``pretrained=False``; weights are meant
    to be supplied afterwards via ``load_state_dict``.

    Args:
        num_classes: Number of logits produced by the classification head.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.initial_conv = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)

        # Build the MobileNetV2 donor model once and split its stages between
        # the two MobileNetV2 segments, instead of instantiating it twice.
        mnv2_stages = list(
            timm.create_model("mobilenetv2_100.ra_in1k", pretrained=False).blocks.children()
        )
        self.mnv2_block1 = nn.Sequential(*mnv2_stages[0:3])
        self.channel_adapter = nn.Conv2d(32, 16, kernel_size=1, bias=False)
        # The whole MobileViT model is registered (not just its stages), so its
        # parameters appear in state_dict under both ``mobilevit_full.*`` and
        # ``mobilevit_encoder.*``; only ``.stages`` is used in forward().
        # NOTE(review): presumably kept for checkpoint-key compatibility -- confirm
        # before removing.
        self.mobilevit_full = timm.create_model("mobilevitv2_050", pretrained=False)
        self.mobilevit_encoder = self.mobilevit_full.stages
        self.mvit_to_mnv2 = nn.Conv2d(256, 32, kernel_size=1, bias=False)
        self.mnv2_block2 = nn.Sequential(*mnv2_stages[3:7])
        self.final_conv = nn.Conv2d(320, 1280, kernel_size=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Embeds the scalar light reading (one feature per sample) into 32 dims.
        self.light_dense = nn.Sequential(nn.Linear(1, 32), nn.ReLU(inplace=True))
        self.reg_head = nn.Sequential(
            nn.Linear(1280 + 32, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        )
        self.cls_head = nn.Sequential(
            nn.Linear(1280 + 32, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes),
        )

    def forward(self, x_img, x_light):
        """Run the shared encoder and both heads on a batch.

        Args:
            x_img: Image batch, assumed (B, 3, H, W) normalized RGB.
            x_light: Light input of shape (B, 1) (``light_dense`` starts with
                ``nn.Linear(1, 32)`` and the result is concatenated on dim 1).

        Returns:
            (reg_out, cls_out): regression output of shape (B, 2) and
            classification logits of shape (B, num_classes).
        """
        x = self.initial_conv(x_img)
        x = self.mnv2_block1(x)
        x = self.channel_adapter(x)
        x = self.mobilevit_encoder(x)
        x = self.mvit_to_mnv2(x)
        x = self.mnv2_block2(x)
        x = self.final_conv(x)
        x = self.pool(x)
        x_img_feat = torch.flatten(x, 1)          # (B, 1280)
        x_light_feat = self.light_dense(x_light)  # (B, 32)
        x_concat = torch.cat([x_img_feat, x_light_feat], dim=1)
        reg_out = self.reg_head(x_concat)
        cls_out = self.cls_head(x_concat)
        return reg_out, cls_out

# ====================== Configuration ======================
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Preprocessing identical to the one used during training: resize to 224x224,
# convert to a float tensor, then normalize each channel with the standard
# ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load model
# Must equal the number of classes the checkpoint's classification head was
# trained with; change it if you know the exact value.
NUM_CLASSES = 10
model = SoilNetDualHead(num_classes=NUM_CLASSES).to(device)

# Path to the best checkpoint. It can be overridden through the
# SOILNET_CHECKPOINT environment variable instead of editing this cell.
checkpoint_path = os.environ.get(
    "SOILNET_CHECKPOINT",
    "/home/diy-hus/SoilNet_WSL/Best_finetuned_VicReg_mu_25.pth",
)

if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Không tìm thấy file: {checkpoint_path}")

# Load weights. weights_only=True restricts unpickling to tensors and plain
# containers rather than arbitrary Python objects.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
# Accept either a training checkpoint that wraps the weights under
# 'model_state_dict' or a bare state_dict.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint

# strict=False tolerates key mismatches, but they are reported here. Otherwise a
# checkpoint whose keys do not line up would leave the model at random init
# while still printing the success message below.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} model keys not found in checkpoint, "
        f"e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} checkpoint keys not used by the model, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )
if len(load_result.missing_keys) == len(model.state_dict()):
    raise RuntimeError(
        f"No weights from {checkpoint_path} matched the model; check the checkpoint format."
    )
model.eval()
print("Đã load checkpoint thành công!")

# ====================== Inference function ======================
def predict_soil_moisture(image_path, light_value=50.0):
    """Predict soil moisture at 0 cm and 20 cm from a soil image and a light reading.

    Relies on the module-level ``model``, ``transform`` and ``device`` defined
    elsewhere in this notebook.

    Args:
        image_path: Path to the soil image; it is converted to RGB before use.
        light_value: Light reading on a 0-100 scale (use the default 50 if
            unknown). It is divided by 100 before being fed to the model.

    Returns:
        (sm0, sm20): the two regression outputs multiplied by 100, i.e. SM_0
        and SM_20 in percent.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Không tìm thấy ảnh: {image_path}")

    # Open inside a context manager so the file handle is always released.
    # PIL keeps it open after load() for multi-frame formats.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)  # shape: [1, 3, 224, 224]

    # Build the [1, 1] light input directly on the target device.
    light_tensor = torch.tensor(
        [[light_value / 100.0]], dtype=torch.float32, device=device
    )

    with torch.no_grad():
        reg_out, _ = model(img_tensor, light_tensor)

    # Scale the two regression outputs to percent.
    sm0, sm20 = reg_out[0].cpu().numpy() * 100
    return sm0, sm20

# ====================== Usage example ======================
if __name__ == "__main__":
    # Replace with the path to a real soil image.
    test_image = r"/home/diy-hus/SoilNet_WSL/z5719915064697_6860e93b3637879e75266a9cab0b000e_M_30_40_light_brown_M_40_50_light_brown_M_50_60_dark_brown_M_60_70_wet_gray.jpg"
    # Example light reading on the 0-100 scale; change as needed.
    light_val = 65.0

    try:
        sm0, sm20 = predict_soil_moisture(test_image, light_val)
        # One print call emitting the three report lines, newline-separated.
        print(
            f"Ảnh: {test_image}",
            f"→ SM_0  (0cm): {sm0:.2f} %",
            f"→ SM_20 (20cm): {sm20:.2f} %",
            sep="\n",
        )
    except Exception as e:
        print("Lỗi:", e)

Using device: cuda
Đã load checkpoint thành công!
Ảnh: /home/diy-hus/SoilNet_WSL/z5719915064697_6860e93b3637879e75266a9cab0b000e_M_30_40_light_brown_M_40_50_light_brown_M_50_60_dark_brown_M_60_70_wet_gray.jpg
→ SM_0  (0cm): 85.95 %
→ SM_20 (20cm): 90.20 %
