# In [None]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random

class AgricultureVisionDataset(Dataset):
    """Image / segmentation-mask pairs read from ``<root_dir>/<split>``.

    Images are listed from ``<root_dir>/<split>/images/rgb``. The mask of an
    image is looked up in ``<root_dir>/<split>/masks`` under the same file
    stem with a ``.png`` extension.
    """

    def __init__(self, root_dir, split='train', transform=None):
        """Index the image files of one split.

        Args:
            root_dir: Dataset root containing one sub-directory per split.
            split: Name of the split sub-directory.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping with
                ``'image'`` and ``'mask'`` entries.
        """
        split_dir = os.path.join(root_dir, split)

        self.split = split
        self.transform = transform
        self.images_dir = os.path.join(split_dir, "images", "rgb")
        self.masks_dir = os.path.join(split_dir, "masks")

        # Sorted so that the index -> file mapping is deterministic.
        self.image_ids = sorted(os.listdir(self.images_dir))

    def __len__(self):
        """Return the number of image files found for the split."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        The mask is always returned as a ``torch.long`` tensor. Without a
        transform the image is only passed through ``ToTensorV2``.
        """
        file_name = self.image_ids[idx]
        stem = os.path.splitext(file_name)[0]

        image = np.array(
            Image.open(os.path.join(self.images_dir, file_name)).convert('RGB')
        )
        mask = np.array(Image.open(os.path.join(self.masks_dir, stem + ".png")))

        if self.transform is None:
            image = ToTensorV2()(image=image)['image']
        else:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']

        # One conversion covers both branches (numpy array or tensor input).
        return image, torch.as_tensor(mask, dtype=torch.long)

def get_transforms(image_size=(256, 256)):
    """Build the albumentations pipelines for training and validation.

    Both pipelines resize to ``image_size``, normalise with a fixed per-channel
    mean/std and finish with ``ToTensorV2``. The training pipeline additionally
    applies random flips, 90-degree rotations and colour jitter (each with
    probability 0.5) between the resize and the normalisation.

    Args:
        image_size: Two values unpacked positionally into ``A.Resize``.

    Returns:
        Tuple ``(train_transform, val_transform)``.
    """
    # These values match the channel statistics commonly used with
    # ImageNet-pretrained encoders.
    norm_stats = {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
    }

    def resize_step():
        return [A.Resize(*image_size)]

    def tensor_steps():
        return [A.Normalize(**norm_stats), ToTensorV2()]

    random_augmentations = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, p=0.5),
    ]

    train_transform = A.Compose(resize_step() + random_augmentations + tensor_steps())
    val_transform = A.Compose(resize_step() + tensor_steps())
    return train_transform, val_transform

def get_dataloaders(root_dir, batch_size=8, image_size=(256, 256), num_workers=0):
    """Create the training and validation ``DataLoader`` objects.

    The ``'train'`` split uses the training transform and is shuffled; the
    ``'val'`` split uses the validation transform and keeps its order.

    Args:
        root_dir: Dataset root handed to ``AgricultureVisionDataset``.
        batch_size: Batch size used by both loaders.
        image_size: Target size forwarded to ``get_transforms``.
        num_workers: Worker process count used by both loaders.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    transform_for = dict(zip(("train", "val"), get_transforms(image_size=image_size)))

    loaders = []
    for split, shuffle in (("train", True), ("val", False)):
        dataset = AgricultureVisionDataset(
            root_dir, split=split, transform=transform_for[split]
        )
        loaders.append(
            DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
            )
        )

    train_loader, val_loader = loaders
    return train_loader, val_loader

def create_model(model_name, encoder_name, num_classes, in_channels=3, encoder_weights="imagenet"):
    """Instantiate a segmentation model through ``smp.create_model``.

    Args:
        model_name: Forwarded as ``arch``.
        encoder_name: Forwarded as ``encoder_name``.
        num_classes: Forwarded as ``classes``.
        in_channels: Forwarded as ``in_channels``.
        encoder_weights: Forwarded as ``encoder_weights``.

    Returns:
        The model object returned by ``smp.create_model``.
    """
    smp_kwargs = {
        "arch": model_name,
        "encoder_name": encoder_name,
        "encoder_weights": encoder_weights,
        "in_channels": in_channels,
        "classes": num_classes,
    }
    return smp.create_model(**smp_kwargs)

def compute_confusion_matrix_and_metrics(model, data_loader, device, num_classes):
    """Accumulate a pixel-level confusion matrix and derive macro metrics.

    The model is switched to eval mode and run over every batch of
    ``data_loader`` without gradients. Predictions are the arg-max over
    dimension 1 of the model output. Rows of the confusion matrix index the
    ground-truth class, columns the predicted class.

    Pixels whose ground-truth label lies outside ``[0, num_classes)`` (for
    example an "ignore" value) are skipped instead of corrupting the bin
    indices and crashing the reshape.

    Args:
        model: Callable module mapping an image batch to per-class scores.
        data_loader: Iterable yielding ``(images, masks)`` batches.
        device: Device the images are moved to before the forward pass.
        num_classes: Number of classes, i.e. the side of the confusion matrix.

    Returns:
        Tuple ``(conf_matrix, avg_precision, avg_recall, avg_f1)`` where
        ``conf_matrix`` is an ``int64`` array of shape
        ``(num_classes, num_classes)`` and the three averages are unweighted
        means over all classes. A class with no true positives contributes 0
        to each average.
    """
    conf_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    model.eval()

    with torch.no_grad():
        for images, masks in tqdm(data_loader, desc="Computing metrics"):
            outputs = model(images.to(device))      # [B, num_classes, H, W]
            preds = torch.argmax(outputs, dim=1)    # [B, H, W]

            # reshape() also works for non-contiguous tensors; the masks are
            # only needed on the host, so they are never moved to `device`.
            preds_flat = preds.reshape(-1).cpu().numpy()
            masks_flat = masks.reshape(-1).cpu().numpy()

            # Drop labels that are not real classes; np.bincount rejects
            # negative values and would otherwise return more than
            # num_classes**2 bins, breaking the reshape below.
            valid = (masks_flat >= 0) & (masks_flat < num_classes)

            # Cast before multiplying so narrow integer masks cannot overflow.
            indices = masks_flat[valid].astype(np.int64) * num_classes + preds_flat[valid]
            bincount = np.bincount(indices, minlength=num_classes * num_classes)
            conf_matrix += bincount.reshape((num_classes, num_classes))

    tp = np.diag(conf_matrix)
    fp = conf_matrix.sum(axis=0) - tp   # predicted as c, labelled otherwise
    fn = conf_matrix.sum(axis=1) - tp   # labelled c, predicted otherwise

    # The small epsilon avoids division by zero for absent classes.
    per_class_precision = tp / (tp + fp + 1e-7)
    per_class_recall = tp / (tp + fn + 1e-7)
    per_class_f1 = (
        2 * per_class_precision * per_class_recall
        / (per_class_precision + per_class_recall + 1e-7)
    )

    avg_precision = np.mean(per_class_precision)
    avg_recall = np.mean(per_class_recall)
    avg_f1 = np.mean(per_class_f1)

    return conf_matrix, avg_precision, avg_recall, avg_f1


def evaluate_epoch_metrics(model, data_loader, criterion, device, num_classes):
    """Evaluate loss and classification metrics of ``model`` on a loader.

    The loader is traversed twice: once here to average the criterion, and
    once inside ``compute_confusion_matrix_and_metrics``. Each batch loss is
    weighted by the batch's pixel count (N * H * W) before averaging.

    Args:
        model: Network to evaluate; it is put into eval mode.
        data_loader: Iterable yielding ``(images, masks)`` batches.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar
            tensor.
        device: Device the batches are moved to.
        num_classes: Number of classes for the confusion matrix.

    Returns:
        Dict with keys ``seg_loss``, ``precision``, ``recall``, ``f1`` and
        ``conf_mat``.
    """
    model.eval()
    loss_sum = 0.0
    pixel_count = 0

    with torch.no_grad():
        for images, masks in tqdm(data_loader, desc="Evaluating epoch"):
            images, masks = images.to(device), masks.to(device)
            batch_loss = criterion(model(images), masks).item()

            n_pixels = images.shape[0] * images.shape[2] * images.shape[3]
            loss_sum += batch_loss * n_pixels
            pixel_count += n_pixels

    # Epsilon keeps the division defined for an empty loader.
    mean_loss = loss_sum / (pixel_count + 1e-7)

    conf_mat, precision, recall, f1 = compute_confusion_matrix_and_metrics(
        model, data_loader, device, num_classes
    )

    return dict(
        seg_loss=mean_loss,
        precision=precision,
        recall=recall,
        f1=f1,
        conf_mat=conf_mat,
    )

def train_model(
    model, 
    train_loader, 
    val_loader, 
    criterion, 
    optimizer, 
    device, 
    num_classes, 
    num_epochs=25, 
    log_dir="logs",
    model_name="Unet"
):
    """Train ``model``, checkpoint the best epoch and log per-epoch metrics.

    After each training epoch the model is evaluated on both loaders with
    ``evaluate_epoch_metrics``. Whenever the validation loss improves, the
    model's ``state_dict`` is saved to ``<log_dir>/best_<model_name>.pth``.
    The metric history is written to ``<log_dir>/history_<model_name>.csv``
    after every epoch, so an interrupted run keeps the epochs it completed.

    Args:
        model: Network to optimise (expected to already be on ``device``).
        train_loader: Iterable of ``(images, masks)`` training batches.
        val_loader: Iterable of ``(images, masks)`` validation batches.
        criterion: Callable ``criterion(outputs, masks)`` returning the loss.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.
        num_classes: Number of classes, forwarded to the metric computation.
        num_epochs: Number of training epochs.
        log_dir: Directory for the checkpoint and the CSV (created if absent).
        model_name: Tag used in file names and progress messages.

    Returns:
        ``pandas.DataFrame`` with one row per epoch; the same table that is
        written to the CSV file.
    """
    os.makedirs(log_dir, exist_ok=True)
    ckpt_path = os.path.join(log_dir, f"best_{model_name}.pth")
    df_path = os.path.join(log_dir, f"history_{model_name}.csv")

    best_val_loss = float('inf')
    history = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        total_pixels = 0

        pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}] - {model_name}")
        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by its pixel count so that a smaller
            # final batch does not skew the epoch average.
            batch_pixels = images.size(0) * images.size(2) * images.size(3)
            running_loss += loss.item() * batch_pixels
            total_pixels += batch_pixels

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        # Running loss gathered in train mode while the weights were changing;
        # it is only reported on the console. The logged train loss below is
        # measured in eval mode after the epoch.
        train_seg_loss = running_loss / (total_pixels + 1e-7)
        train_metrics = evaluate_epoch_metrics(model, train_loader, criterion, device, num_classes)
        val_metrics = evaluate_epoch_metrics(model, val_loader, criterion, device, num_classes)

        if val_metrics["seg_loss"] < best_val_loss:
            best_val_loss = val_metrics["seg_loss"]
            torch.save(model.state_dict(), ckpt_path)

        history.append({
            "epoch": epoch+1,
            "train_seg_loss": train_metrics["seg_loss"],
            "train_precision": train_metrics["precision"],
            "train_recall": train_metrics["recall"],
            "train_f1": train_metrics["f1"],
            "val_seg_loss": val_metrics["seg_loss"],
            "val_precision": val_metrics["precision"],
            "val_recall": val_metrics["recall"],
            "val_f1": val_metrics["f1"]
        })

        # Flush the log every epoch so a crash or interrupt does not lose it.
        pd.DataFrame(history).to_csv(df_path, index=False)

        print(f"[{model_name}] Epoch {epoch+1}/{num_epochs} | Train seg_loss={train_seg_loss:.4f}, Val seg_loss={val_metrics['seg_loss']:.4f}, Val F1={val_metrics['f1']:.4f}")

    df = pd.DataFrame(history)
    # Final write also covers num_epochs == 0, where the loop never runs.
    df.to_csv(df_path, index=False)
    return df

def visualize_predictions_and_save(
    model,
    data_loader,
    device,
    num_images=4,
    num_classes=9,
    save_dir="SMP",
    model_name="model",
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Plot image / prediction / ground-truth triplets and save them as a PNG.

    Up to ``num_images`` samples are taken from the start of ``data_loader``.
    Each figure row shows the input image, the predicted mask overlaid on it
    and the ground-truth mask overlaid on it. The figure is written to
    ``<save_dir>/<model_name>_visualization.png``.

    Args:
        model: Network mapping an image batch to per-class scores; it is put
            into eval mode.
        data_loader: Iterable yielding ``(images, masks)`` batches with
            channel-first images.
        device: Device used for the forward pass.
        num_images: Maximum number of samples (figure rows) to draw. If the
            loader holds fewer samples, only those are drawn.
        num_classes: Number of class ids that receive a colour; other ids stay
            black in the overlays.
        save_dir: Output directory (created if absent).
        model_name: Prefix of the output file name.
        mean: Per-channel mean used to undo input normalisation before
            display. The default matches ``get_transforms`` in this file.
        std: Per-channel std used to undo input normalisation before display.
            Pass ``mean=None`` or ``std=None`` to display the tensors as-is.

    Returns:
        None.
    """
    model.eval()
    os.makedirs(save_dir, exist_ok=True)

    # A private generator yields the same colours as seeding the module-level
    # RNG with 42, without resetting global random state for the caller.
    rng = random.Random(42)
    color_map = [
        (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))
        for _ in range(num_classes)
    ]

    undo_norm = mean is not None and std is not None
    if undo_norm:
        mean_arr = np.asarray(mean, dtype=np.float32)
        std_arr = np.asarray(std, dtype=np.float32)

    images_list = []
    preds_list = []
    gt_list = []
    if num_images > 0:
        with torch.no_grad():
            for images, masks in data_loader:
                remaining = num_images - len(images_list)
                images = images[:remaining]
                masks = masks[:remaining]

                # One forward pass per batch instead of one per image.
                preds = torch.argmax(model(images.to(device)), dim=1).cpu().numpy()

                for img, pred, gt in zip(images, preds, masks):
                    img_np = img.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
                    if undo_norm:
                        # Normalised inputs must be mapped back to [0, 1]
                        # before clipping, or most pixels saturate.
                        img_np = img_np * std_arr + mean_arr
                    images_list.append(np.clip(img_np, 0, 1))
                    preds_list.append(pred)
                    gt_list.append(gt.cpu().numpy())

                if len(images_list) >= num_images:
                    break

    n_rows = len(images_list)
    if n_rows == 0:
        print(f"No samples to visualize for {model_name}; nothing was saved.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

    for idx, (orig, pred, gt_img) in enumerate(zip(images_list, preds_list, gt_list)):
        pred_color = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
        gt_color = np.zeros((gt_img.shape[0], gt_img.shape[1], 3), dtype=np.uint8)

        for c in range(num_classes):
            pred_color[pred == c] = color_map[c]
            gt_color[gt_img == c] = color_map[c]

        axes[idx, 0].imshow(orig)
        axes[idx, 0].set_title("Original")
        axes[idx, 0].axis("off")

        axes[idx, 1].imshow(orig, alpha=0.6)
        axes[idx, 1].imshow(pred_color, alpha=0.4)
        axes[idx, 1].set_title("Predicted Mask")
        axes[idx, 1].axis("off")

        axes[idx, 2].imshow(orig, alpha=0.6)
        axes[idx, 2].imshow(gt_color, alpha=0.4)
        axes[idx, 2].set_title("Ground Truth")
        axes[idx, 2].axis("off")

    fig.tight_layout()
    save_path = os.path.join(save_dir, f"{model_name}_visualization.png")
    fig.savefig(save_path)
    plt.close(fig)
    print(f"Saved visualization to {save_path}")

if __name__ == "__main__":
    # ---- Experiment settings ----
    dataset_root = "./AgricultureVision"
    num_classes  = 10
    batch_size   = 4
    num_epochs   = 5
    image_size   = (512, 512)
    log_dir      = "smp_logs"
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Create the DataLoaders (shared by every model configuration below).
    train_loader, val_loader = get_dataloaders(
        dataset_root, batch_size=batch_size, image_size=image_size, num_workers=2
    )

    # Each entry is an (architecture, encoder) pair; fill in as needed.
    model_configs = [
        # Your configs
    ]

    for arch, encoder in model_configs:
        model_identifier = f"{arch}_{encoder}"
        print(f"\n=== Training {model_identifier} ===")

        model = create_model(
            arch, encoder, num_classes, in_channels=3, encoder_weights="imagenet"
        )
        model = model.to(device)

        # A fresh loss and optimizer for every model.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-4)

        train_model(
            model, train_loader, val_loader, criterion, optimizer, device, num_classes,
            num_epochs=num_epochs, log_dir=log_dir, model_name=model_identifier,
        )
        print(f"=== Finished training {model_identifier} ===")

        # Restore the best checkpoint written during training before plotting.
        ckpt_path = os.path.join(log_dir, f"best_{model_identifier}.pth")
        model.load_state_dict(torch.load(ckpt_path))

        visualize_predictions_and_save(
            model, val_loader, device,
            num_images=4, num_classes=num_classes,
            save_dir="SMP", model_name=model_identifier,
        )
