In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.backends.cudnn.benchmark = True
import torch.multiprocessing as mp
mp.set_start_method('spawn', force=True)
import timm
import pandas as pd
import numpy as np
from PIL import Image, UnidentifiedImageError
import os
import glob
import random
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import logging
from torch.cuda.amp import autocast, GradScaler

In [2]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f"PyTorch version: {torch.__version__}")
logging.info(f"CUDA is available: {torch.cuda.is_available()}")
logging.info(f"Using device: {device}")

2025-11-15 08:47:14,307 - INFO - PyTorch version: 2.6.0+cu118
2025-11-15 08:47:14,311 - INFO - CUDA is available: True
2025-11-15 08:47:14,311 - INFO - Using device: cuda


In [3]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f"PyTorch version: {torch.__version__}")
logging.info(f"CUDA is available: {torch.cuda.is_available()}")
logging.info(f"Using device: {device}")

2025-11-15 08:47:14,330 - INFO - PyTorch version: 2.6.0+cu118
2025-11-15 08:47:14,331 - INFO - CUDA is available: True
2025-11-15 08:47:14,332 - INFO - Using device: cuda


In [4]:
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

In [5]:
def preprocess_unlabeled_images(image_paths):
    """Xác thực các ảnh không có nhãn và xóa các ảnh bị hỏng."""
    new_image_paths = []
    deleted_count = 0
    for path in tqdm(image_paths, desc="Preprocessing unlabeled images"):
        try:
            img = Image.open(path).convert("RGB")
            img.verify()
            new_image_paths.append(path)
        except (UnidentifiedImageError, FileNotFoundError, OSError) as e:
            logging.info(f"Image error: {path} ({str(e)})")
            try:
                if os.path.exists(path):
                    os.remove(path)
                    deleted_count += 1
                    logging.info(f"Deleted image: {path}")
            except (PermissionError, OSError) as e:
                logging.error(f"Cannot delete image {path}: {str(e)}")
    logging.info(f"Deleted {deleted_count} unlabeled images. {len(new_image_paths)} images remain.")
    return new_image_paths

In [6]:
image_dir = r"F:\unlabeled_images"
image_paths = []
for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp"]:
    image_paths.extend(glob.glob(os.path.join(image_dir, "**", ext), recursive=True))
logging.info(f"Found {len(image_paths)} unlabeled images.")
#image_paths = preprocess_unlabeled_images(image_paths)

2025-11-15 08:47:14,499 - INFO - Found 11995 unlabeled images.


In [7]:
image_size = 224
transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [8]:
class UnlabeledImageDataset(Dataset):
    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        try:
            img_path = self.image_paths[idx]
            image = Image.open(img_path).convert("RGB")
            img1 = self.transform(image)
            img2 = self.transform(image)
            return img1, img2
        except Exception as e:
            logging.error(f"Error reading image {self.image_paths[idx]}: {e}. Skipping.")
            return None

def custom_collate_fn(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return torch.tensor([]), torch.tensor([])
    return torch.utils.data.dataloader.default_collate(batch)

In [9]:
unlabeled_dataset = UnlabeledImageDataset(image_paths, transform)
unlabeled_dataloader = DataLoader(
    unlabeled_dataset,
    batch_size=16,
    shuffle=True,
    drop_last=True,
    num_workers=0,  
    pin_memory=True, 
    persistent_workers=False,  
    collate_fn=custom_collate_fn
)

In [10]:
class SoilNetDualHead(nn.Module):
    def __init__(self, num_classes=10, simclr_mode=False):
        super().__init__()
        self.simclr_mode = simclr_mode
        self.initial_conv = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.mnv2_block1 = nn.Sequential(*list(
            timm.create_model("mobilenetv2_100.ra_in1k", pretrained=True).blocks.children())[0:3]
        )
        self.channel_adapter = nn.Conv2d(32, 16, kernel_size=1, bias=False)
        self.mobilevit_full = timm.create_model("mobilevitv2_050", pretrained=True)
        self.mobilevit_encoder = self.mobilevit_full.stages
        self.mvit_to_mnv2 = nn.Conv2d(256, 32, kernel_size=1, bias=False)
        self.mnv2_block2 = nn.Sequential(*list(
            timm.create_model("mobilenetv2_100.ra_in1k", pretrained=True).blocks.children())[3:7]
        )
        self.final_conv = nn.Conv2d(320, 1280, kernel_size=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.light_dense = nn.Sequential(nn.Linear(1, 32), nn.ReLU(inplace=True))
        self.reg_head = nn.Sequential(
            nn.Linear(1280 + 32, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2)
        )
        self.cls_head = nn.Sequential(
            nn.Linear(1280 + 32, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes)
        )
    def forward(self, x_img, x_light=None):
        x = self.initial_conv(x_img)
        x = self.mnv2_block1(x)
        x = self.channel_adapter(x)
        x = self.mobilevit_encoder(x)
        x = self.mvit_to_mnv2(x)
        x = self.mnv2_block2(x)
        x = self.final_conv(x)
        x = self.pool(x)
        x_img_feat = torch.flatten(x, 1)
        if self.simclr_mode:
            return x_img_feat
        x_light_feat = self.light_dense(x_light)
        x_concat = torch.cat([x_img_feat, x_light_feat], dim=1)
        reg_out = self.reg_head(x_concat)
        cls_out = self.cls_head(x_concat)
        return reg_out, cls_out

class Projector(nn.Module):
    def __init__(self, input_dim=1280, proj_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(inplace=True),
            nn.Linear(input_dim, proj_dim)
        )
    def forward(self, x):
        return self.net(x)

def vicreg_loss(z1, z2, lambda_=25.0, mu=25.0, nu=1.0, epsilon=1e-4):
    invariance_loss = F.mse_loss(z1, z2)
    def variance_term(z):
        z_std = torch.sqrt(z.var(dim=0) + epsilon)
        return torch.mean(F.relu(1 - z_std))
    var_loss = variance_term(z1) + variance_term(z2)
    def covariance_term(z):
        z = z - z.mean(dim=0)
        cov = (z.T @ z) / (z.shape[0] - 1)
        off_diag = cov - torch.diag(cov.diag())
        return off_diag.pow(2).sum() / z.shape[1]
    cov_loss = covariance_term(z1) + covariance_term(z2)
    return lambda_ * invariance_loss + mu * var_loss + nu * cov_loss

# Define num_classes
num_classes = 10

# Instantiate the model
model = SoilNetDualHead(num_classes=num_classes, simclr_mode=True).to(device)
projector = Projector(input_dim=1280, proj_dim=128).to(device)

2025-11-15 08:47:14,617 - INFO - Loading pretrained weights from Hugging Face hub (timm/mobilenetv2_100.ra_in1k)
2025-11-15 08:47:14,938 - INFO - [timm/mobilenetv2_100.ra_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-11-15 08:47:14,997 - INFO - Loading pretrained weights from Hugging Face hub (timm/mobilevitv2_050.cvnets_in1k)
2025-11-15 08:47:15,288 - INFO - [timm/mobilevitv2_050.cvnets_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-11-15 08:47:15,360 - INFO - Loading pretrained weights from Hugging Face hub (timm/mobilenetv2_100.ra_in1k)
2025-11-15 08:47:15,612 - INFO - [timm/mobilenetv2_100.ra_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


In [11]:
try:
    checkpoint = torch.load(r"C:\Users\PC\soilNet\Model\SoilNet_orginal.pth", map_location=device)
    
    model_state_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    
    # Update current model's state dict
    model_state_dict.update(pretrained_dict)
    model.load_state_dict(model_state_dict, strict=False)
    
    logging.info(f"Successfully loaded {len(pretrained_dict)} compatible pretrained weights.")
    logging.warning("Heads (cls_head, reg_head, light_dense) were NOT loaded — will be randomly initialized.")
    
except FileNotFoundError:
    logging.warning("Pre-trained model weights not found. Starting from scratch.")
except Exception as e:
    logging.error(f"Error loading pretrained weights: {e}")
    raise

2025-11-15 08:47:15,957 - INFO - Successfully loaded 846 compatible pretrained weights.


In [12]:
optimizer_vicreg = torch.optim.Adam(list(model.parameters()) + list(projector.parameters()), lr=1e-4)
scaler = GradScaler()

checkpoint_dir = r"F:\checkpoints_VicReg_original_NEW"
os.makedirs(checkpoint_dir, exist_ok=True)

  scaler = GradScaler()


In [13]:
def train_vicreg_with_mu(mu=25.0, num_epochs=150, metrics_df=None):
    vicreg_losses = []
   
    for epoch in range(1, num_epochs + 1):
        model.train()
        projector.train()
       
        running_vicreg_loss = 0.0
        num_batches = 0
       
        pbar = tqdm(unlabeled_dataloader, desc=f"Epoch {epoch}/{num_epochs} (mu={mu})", leave=False)
       
        for batch_data in pbar:
            if len(batch_data[0]) == 0:
                logging.warning("Skipping empty unlabeled batch.")
                continue
               
            img1, img2 = batch_data
            img1, img2 = img1.to(device), img2.to(device)
           
            optimizer_vicreg.zero_grad()
            with autocast():
                feat1 = model(img1, x_light=None)
                feat2 = model(img2, x_light=None)
                z1 = projector(feat1)
                z2 = projector(feat2)
                vicreg_loss_val = vicreg_loss(z1, z2, mu=mu)
           
            scaler.scale(vicreg_loss_val).backward()
            scaler.step(optimizer_vicreg)
            scaler.update()
           
            running_vicreg_loss += vicreg_loss_val.item()
            num_batches += 1
           
            pbar.set_postfix(vicreg_loss=vicreg_loss_val.item())
        if num_batches == 0:
            logging.error(f"Epoch {epoch} không có batch nào được xử lý. Dừng huấn luyện.")
            break
           
        avg_vicreg_loss = running_vicreg_loss / num_batches
       
        vicreg_losses.append(avg_vicreg_loss)
       
        print(f"✅ Epoch {epoch:3d}/{num_epochs} (mu={mu}) - VICReg Loss: {avg_vicreg_loss:.4f}")
       
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'projector_state_dict': projector.state_dict(),
            'vicreg_loss': avg_vicreg_loss
        }
        checkpoint_path = os.path.join(checkpoint_dir, f'vicreg_mu_{mu}_epoch_{epoch}.pth')
        torch.save(checkpoint, checkpoint_path)
       
        if metrics_df is not None:
            metrics_df.append({
                'mu': mu,
                'epoch': epoch,
                'vicreg_loss': avg_vicreg_loss
            })
    final_model_path = os.path.join(checkpoint_dir, f'vicreg_model_final_mu_{mu}.pth')
    final_projector_path = os.path.join(checkpoint_dir, f'vicreg_projector_final_mu_{mu}.pth')
   
    torch.save(model.state_dict(), final_model_path)
    torch.save(projector.state_dict(), final_projector_path)
   
    logging.info(f"Saved final models (mu={mu}): {final_model_path}, ...")

In [14]:
all_metrics = []
for mu_val in [30.0]:
    logging.info(f"--- Starting training for mu = {mu_val} ---")
    train_vicreg_with_mu(mu=mu_val, num_epochs=80, metrics_df=all_metrics)

metrics_df = pd.DataFrame(all_metrics)
metrics_csv_path = os.path.join(checkpoint_dir, 'vicreg_metrics_original_soilnet_mu_30.csv')
metrics_df.to_csv(metrics_csv_path, index=False)
logging.info(f"Saved metrics to {metrics_csv_path}")
print(f"All metrics saved to {metrics_csv_path}")

2025-11-15 08:47:16,014 - INFO - --- Starting training for mu = 30.0 ---
  with autocast():
                                                                                                                       

✅ Epoch   1/80 (mu=30.0) - VICReg Loss: 40.8137


                                                                                                                       

✅ Epoch   2/80 (mu=30.0) - VICReg Loss: 38.7496


                                                                                                                       

✅ Epoch   3/80 (mu=30.0) - VICReg Loss: 37.1015


                                                                                                                       

✅ Epoch   4/80 (mu=30.0) - VICReg Loss: 35.3533


                                                                                                                       

✅ Epoch   5/80 (mu=30.0) - VICReg Loss: 33.6452


                                                                                                                       

✅ Epoch   6/80 (mu=30.0) - VICReg Loss: 32.2888


                                                                                                                       

✅ Epoch   7/80 (mu=30.0) - VICReg Loss: 31.0657


                                                                                                                       

✅ Epoch   8/80 (mu=30.0) - VICReg Loss: 29.6976


                                                                                                                       

✅ Epoch   9/80 (mu=30.0) - VICReg Loss: 28.3350


                                                                                                                       

✅ Epoch  10/80 (mu=30.0) - VICReg Loss: 27.2281


                                                                                                                       

✅ Epoch  11/80 (mu=30.0) - VICReg Loss: 26.0967


                                                                                                                       

✅ Epoch  12/80 (mu=30.0) - VICReg Loss: 25.1864


                                                                                                                       

✅ Epoch  13/80 (mu=30.0) - VICReg Loss: 24.2113


                                                                                                                       

✅ Epoch  14/80 (mu=30.0) - VICReg Loss: 23.4336


                                                                                                                       

✅ Epoch  15/80 (mu=30.0) - VICReg Loss: 22.8429


                                                                                                                       

✅ Epoch  16/80 (mu=30.0) - VICReg Loss: 22.3620


                                                                                                                       

✅ Epoch  17/80 (mu=30.0) - VICReg Loss: 21.8803


                                                                                                                       

✅ Epoch  18/80 (mu=30.0) - VICReg Loss: 21.4093


                                                                                                                       

✅ Epoch  19/80 (mu=30.0) - VICReg Loss: 21.0959


                                                                                                                       

✅ Epoch  20/80 (mu=30.0) - VICReg Loss: 20.8804


                                                                                                                       

✅ Epoch  21/80 (mu=30.0) - VICReg Loss: 20.5121


                                                                                                                       

✅ Epoch  22/80 (mu=30.0) - VICReg Loss: 20.3650


                                                                                                                       

✅ Epoch  23/80 (mu=30.0) - VICReg Loss: 20.1436


                                                                                                                       

✅ Epoch  24/80 (mu=30.0) - VICReg Loss: 19.9366


                                                                                                                       

✅ Epoch  25/80 (mu=30.0) - VICReg Loss: 19.7033


                                                                                                                       

✅ Epoch  26/80 (mu=30.0) - VICReg Loss: 19.6884


                                                                                                                       

✅ Epoch  27/80 (mu=30.0) - VICReg Loss: 19.5735


                                                                                                                       

✅ Epoch  28/80 (mu=30.0) - VICReg Loss: 19.5006


                                                                                                                       

✅ Epoch  29/80 (mu=30.0) - VICReg Loss: 19.2839


                                                                                                                       

✅ Epoch  30/80 (mu=30.0) - VICReg Loss: 19.2542


                                                                                                                       

✅ Epoch  31/80 (mu=30.0) - VICReg Loss: 19.2526


                                                                                                                       

✅ Epoch  32/80 (mu=30.0) - VICReg Loss: 19.0993


                                                                                                                       

✅ Epoch  33/80 (mu=30.0) - VICReg Loss: 19.0706


                                                                                                                       

✅ Epoch  34/80 (mu=30.0) - VICReg Loss: 19.0162


                                                                                                                       

✅ Epoch  35/80 (mu=30.0) - VICReg Loss: 18.8882


                                                                                                                       

✅ Epoch  36/80 (mu=30.0) - VICReg Loss: 18.8687


                                                                                                                       

✅ Epoch  37/80 (mu=30.0) - VICReg Loss: 18.7488


                                                                                                                       

✅ Epoch  38/80 (mu=30.0) - VICReg Loss: 18.7497


                                                                                                                       

✅ Epoch  39/80 (mu=30.0) - VICReg Loss: 18.7967


                                                                                                                       

✅ Epoch  40/80 (mu=30.0) - VICReg Loss: 18.6446


                                                                                                                       

✅ Epoch  41/80 (mu=30.0) - VICReg Loss: 18.5691


                                                                                                                       

✅ Epoch  42/80 (mu=30.0) - VICReg Loss: 18.5701


                                                                                                                       

✅ Epoch  43/80 (mu=30.0) - VICReg Loss: 18.6066


                                                                                                                       

✅ Epoch  44/80 (mu=30.0) - VICReg Loss: 18.4825


                                                                                                                       

✅ Epoch  45/80 (mu=30.0) - VICReg Loss: 18.4782


                                                                                                                       

✅ Epoch  46/80 (mu=30.0) - VICReg Loss: 18.4980


                                                                                                                       

✅ Epoch  47/80 (mu=30.0) - VICReg Loss: 18.3541


                                                                                                                       

✅ Epoch  48/80 (mu=30.0) - VICReg Loss: 18.3205


                                                                                                                       

✅ Epoch  49/80 (mu=30.0) - VICReg Loss: 18.2956


                                                                                                                       

✅ Epoch  50/80 (mu=30.0) - VICReg Loss: 18.3018


                                                                                                                       

✅ Epoch  51/80 (mu=30.0) - VICReg Loss: 18.3559


                                                                                                                       

✅ Epoch  52/80 (mu=30.0) - VICReg Loss: 18.1555


                                                                                                                       

✅ Epoch  53/80 (mu=30.0) - VICReg Loss: 18.2168


                                                                                                                       

✅ Epoch  54/80 (mu=30.0) - VICReg Loss: 18.0863


                                                                                                                       

✅ Epoch  55/80 (mu=30.0) - VICReg Loss: 18.2411


                                                                                                                       

✅ Epoch  56/80 (mu=30.0) - VICReg Loss: 18.1900


                                                                                                                       

✅ Epoch  57/80 (mu=30.0) - VICReg Loss: 18.0342


                                                                                                                       

✅ Epoch  58/80 (mu=30.0) - VICReg Loss: 18.1296


                                                                                                                       

✅ Epoch  59/80 (mu=30.0) - VICReg Loss: 18.0912


                                                                                                                       

✅ Epoch  60/80 (mu=30.0) - VICReg Loss: 18.0010


                                                                                                                       

✅ Epoch  61/80 (mu=30.0) - VICReg Loss: 17.9895


                                                                                                                       

✅ Epoch  62/80 (mu=30.0) - VICReg Loss: 18.0544


                                                                                                                       

✅ Epoch  63/80 (mu=30.0) - VICReg Loss: 18.0549


                                                                                                                       

✅ Epoch  64/80 (mu=30.0) - VICReg Loss: 17.9361


                                                                                                                       

✅ Epoch  65/80 (mu=30.0) - VICReg Loss: 17.9275


                                                                                                                       

✅ Epoch  66/80 (mu=30.0) - VICReg Loss: 17.9890


                                                                                                                       

✅ Epoch  67/80 (mu=30.0) - VICReg Loss: 17.9472


                                                                                                                       

✅ Epoch  68/80 (mu=30.0) - VICReg Loss: 17.9328


                                                                                                                       

✅ Epoch  69/80 (mu=30.0) - VICReg Loss: 17.9061


                                                                                                                       

✅ Epoch  70/80 (mu=30.0) - VICReg Loss: 17.8845


                                                                                                                       

✅ Epoch  71/80 (mu=30.0) - VICReg Loss: 17.7795


                                                                                                                       

✅ Epoch  72/80 (mu=30.0) - VICReg Loss: 17.8828


                                                                                                                       

✅ Epoch  73/80 (mu=30.0) - VICReg Loss: 17.8961


                                                                                                                       

✅ Epoch  74/80 (mu=30.0) - VICReg Loss: 17.8206


                                                                                                                       

✅ Epoch  75/80 (mu=30.0) - VICReg Loss: 17.7590


                                                                                                                       

✅ Epoch  76/80 (mu=30.0) - VICReg Loss: 17.7135


                                                                                                                       

✅ Epoch  77/80 (mu=30.0) - VICReg Loss: 17.7594


                                                                                                                       

✅ Epoch  78/80 (mu=30.0) - VICReg Loss: 17.8064


                                                                                                                       

✅ Epoch  79/80 (mu=30.0) - VICReg Loss: 17.7448


2025-11-16 02:50:38,011 - INFO - Saved final models (mu=30.0): F:\checkpoints_VicReg_original_NEW\vicreg_model_final_mu_30.0.pth, ...
2025-11-16 02:50:38,018 - INFO - Saved metrics to F:\checkpoints_VicReg_original_NEW\vicreg_metrics_original_soilnet_mu_30.csv


✅ Epoch  80/80 (mu=30.0) - VICReg Loss: 17.7065
All metrics saved to F:\checkpoints_VicReg_original_NEW\vicreg_metrics_original_soilnet_mu_30.csv
