In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
from tqdm import tqdm

# Initial configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
num_epochs = 10
margin = 1.0  # margin for the triplet loss
input_size = 224  # side length the input images are resized to
patch_size = 16  # side length of each patch
num_patches = (input_size // patch_size) ** 2  # number of patches per image
hidden_size = 768  # hidden-state size of the ViT
num_heads = 12  # number of heads in multi-head attention

# Image transforms: resize to a square, convert to a tensor, normalise per channel
transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Dataset definition
class PlantDiseaseDataset(Dataset):
    """Map-style dataset pairing image files with their class labels.

    Images are opened lazily in ``__getitem__``. A sample that cannot be
    loaded is reported on stdout and returned as ``(None, None)`` so the
    caller can filter it out instead of crashing.
    """

    def __init__(self, image_paths, labels, transform=None):
        """Keep the parallel path/label sequences and the optional transform."""
        self.image_paths, self.labels = image_paths, labels
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, or ``(None, None)`` on failure."""
        try:
            path = self.image_paths[idx]
            sample = Image.open(path).convert("RGB")
            target = self.labels[idx]
            # A falsy transform (e.g. None) leaves the PIL image untouched.
            sample = self.transform(sample) if self.transform else sample
            return sample, target
        except Exception as e:
            # Best-effort loading: report the failing path and hand back a
            # sentinel pair. For an out-of-range idx the lookup in the message
            # below raises IndexError again, which is what ends plain
            # iteration over the dataset.
            print(f"Error loading image {self.image_paths[idx]}: {e}")
            return None, None

# Load data from the train, val and test folders
def load_data_from_folders(base_dir):
    """Collect image paths and integer labels for the train/val/test splits.

    The class list is the sorted set of sub-directories of
    ``<base_dir>/train``; the same class-to-index mapping is applied to every
    split. A class directory missing from a split is skipped.

    Args:
        base_dir: Dataset root containing ``train``, ``val`` and ``test``.

    Returns:
        A ``(data, class_to_idx)`` tuple. ``data[split]`` is a dict with the
        parallel lists ``'image_paths'`` and ``'labels'``; ``class_to_idx``
        maps each class directory name to its integer index.

    Raises:
        OSError: If ``<base_dir>/train`` cannot be listed.
    """
    splits = ['train', 'val', 'test']
    data = {split: {'image_paths': [], 'labels': []} for split in splits}

    # Only directories count as classes. A stray file in train/ (e.g. a
    # thumbnail cache) would otherwise receive a class index and inflate the
    # number of classes the model is built with.
    train_dir = os.path.join(base_dir, 'train')
    class_names = sorted(
        name for name in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, name))
    )
    class_to_idx = {class_name: idx for idx, class_name in enumerate(class_names)}

    for split in splits:
        split_dir = os.path.join(base_dir, split)
        for class_name in class_names:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            # os.listdir order is arbitrary; sort so sample order is reproducible.
            for image_name in sorted(os.listdir(class_dir)):
                image_path = os.path.join(class_dir, image_name)
                if os.path.isfile(image_path):
                    data[split]['image_paths'].append(image_path)
                    data[split]['labels'].append(class_to_idx[class_name])

    return data, class_to_idx

# Dataset root directory
base_dir = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'

# Load the image paths and labels for every split
data, class_to_idx = load_data_from_folders(base_dir)

# Create the Dataset objects for train, val and test
train_dataset = PlantDiseaseDataset(data['train']['image_paths'], data['train']['labels'], transform=transform)
val_dataset = PlantDiseaseDataset(data['val']['image_paths'], data['val']['labels'], transform=transform)
test_dataset = PlantDiseaseDataset(data['test']['image_paths'], data['test']['labels'], transform=transform)

# Drop samples that failed to load (returned as (None, None)).
# NOTE: iterating the datasets here runs __getitem__ for every sample, so each
# split becomes an in-memory list of already-transformed (image, label) pairs.
train_dataset = [data for data in train_dataset if data[0] is not None]
val_dataset = [data for data in val_dataset if data[0] is not None]
test_dataset = [data for data in test_dataset if data[0] is not None]

# Create the DataLoaders (only the training split is shuffled)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Print summary information
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")
print(f"Class to index mapping: {class_to_idx}")

# ViT adapted for few-shot learning with a triplet network
class CustomViT(nn.Module):
    """Small Vision Transformer returning class logits and a CLS embedding.

    The image is cut into non-overlapping patches by a strided convolution, a
    learnable CLS token is prepended, learnable positional embeddings are
    added, and the token sequence is passed through a Transformer encoder.
    The encoder output at the CLS position is the embedding; a linear head
    maps it to class logits.

    Args:
        num_patches: Number of patches per image (the CLS token is added on top).
        hidden_size: Token / embedding dimension.
        num_heads: Number of attention heads in each encoder layer.
        num_classes: Size of the classification head output.
        patch_size: Side length of a patch; must be the value ``num_patches``
            was derived from. Defaults to 16.
        num_layers: Number of stacked encoder layers. Defaults to 6.
    """

    def __init__(self, num_patches, hidden_size, num_heads, num_classes,
                 patch_size=16, num_layers=6):
        super().__init__()
        self.patch_embedding = nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size)
        # +1 position for the CLS token
        self.positional_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        # forward() lays tokens out as (batch, tokens, hidden). The encoder
        # layer defaults to (tokens, batch, hidden), which would make attention
        # run across the images of a batch instead of across the patches of
        # one image, so batch_first must be enabled.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads, batch_first=True),
            num_layers=num_layers,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return ``(logits, cls_output)`` for a batch of images.

        ``x`` goes through a 3-input-channel convolution, so it is expected to
        be ``(batch, 3, height, width)``. ``cls_output`` is
        ``(batch, hidden_size)`` and ``logits`` is ``(batch, num_classes)``.
        """
        # Split the image into patches: (batch, hidden, patches_h, patches_w)
        x = self.patch_embedding(x)
        # Flatten the patch grid: (batch, num_patches, hidden)
        x = x.flatten(2).transpose(1, 2)

        # Prepend the CLS token and add the positional embedding
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.positional_embedding

        # Transformer encoder
        x = self.transformer(x)

        # The CLS position is used both as the embedding and for classification
        cls_output = x[:, 0, :]
        logits = self.fc(cls_output)
        return logits, cls_output

# Triplet loss definition
class TripletLoss(nn.Module):
    """Margin-based triplet loss on squared Euclidean distances.

    Per row the loss is ``relu(d(a, p) - d(a, n) + margin)``, where ``d`` is
    the squared difference summed over dimension 1; the batch mean is returned.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean hinge loss over all rows of the batch."""

        def squared_distance(other):
            # Squared L2 distance between each anchor row and the matching row.
            delta = anchor - other
            return delta.pow(2).sum(1)

        to_positive = squared_distance(positive)
        to_negative = squared_distance(negative)
        hinge = torch.relu(to_positive - to_negative + self.margin)
        return hinge.mean()

# Build the model, the loss function and the optimizer
num_classes = len(class_to_idx)  # number of classes
model = CustomViT(num_patches, hidden_size, num_heads, num_classes).to(device)
triplet_loss = TripletLoss(margin=margin)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Model training
def _hardest_triplet_indices(embeddings, labels):
    """Select the hardest in-batch positive and negative for each usable anchor.

    An anchor is usable when the batch holds at least one other sample with
    the same label and at least one sample with a different label. For such an
    anchor the positive is the farthest same-label sample and the negative is
    the closest different-label sample, both by squared Euclidean distance.

    Returns:
        ``(anchor_idx, positive_idx, negative_idx)``: three 1-D index tensors
        of equal length (empty when the batch yields no usable anchor).
    """
    # Mining only picks indices, so it is kept out of the autograd graph.
    with torch.no_grad():
        diffs = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
        sq_dists = diffs.pow(2).sum(dim=2)
        same_class = labels.unsqueeze(0) == labels.unsqueeze(1)
        not_self = ~torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
        positive_mask = same_class & not_self
        negative_mask = ~same_class
        usable = positive_mask.any(dim=1) & negative_mask.any(dim=1)
        hardest_positive = sq_dists.masked_fill(~positive_mask, float("-inf")).argmax(dim=1)
        hardest_negative = sq_dists.masked_fill(~negative_mask, float("inf")).argmin(dim=1)
        anchor_idx = usable.nonzero(as_tuple=True)[0]
    return anchor_idx, hardest_positive[anchor_idx], hardest_negative[anchor_idx]


def train(model, train_loader, optimizer, triplet_loss, device):
    """Train ``model`` for one epoch with an in-batch triplet loss.

    Triplets are mined inside every batch from the labels. Using one tensor as
    anchor, positive and negative at once would make both distances zero, so
    the loss would equal the margin and carry no gradient. Only the embeddings
    enter the loss; the classification logits are not used here.

    Args:
        model: Module returning ``(logits, embeddings)`` for a batch of images.
        train_loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer over the model parameters.
        triplet_loss: Callable ``(anchor, positive, negative) -> scalar loss``.
        device: Device the batches are moved to.

    Returns:
        Mean loss over the batches that produced at least one triplet, or
        ``0.0`` when no batch did.
    """
    model.train()
    running_loss = 0.0
    num_steps = 0
    for images, labels in tqdm(train_loader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)

        _, embeddings = model(images)

        anchor_idx, positive_idx, negative_idx = _hardest_triplet_indices(embeddings, labels)
        if anchor_idx.numel() == 0:
            # E.g. a batch with a single class: nothing to contrast against.
            continue

        # Index the live embeddings so gradients reach the model.
        loss = triplet_loss(
            embeddings[anchor_idx], embeddings[positive_idx], embeddings[negative_idx]
        )

        # Weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_steps += 1
    return running_loss / num_steps if num_steps else 0.0

# Model evaluation
def evaluate(model, val_loader, device):
    """Run ``model`` over ``val_loader`` and score its argmax predictions.

    Returns:
        ``(accuracy, precision, recall, f1, conf_matrix, embeddings, labels)``
        where precision, recall and F1 are macro-averaged and the last two
        items are NumPy arrays of the collected embeddings and true labels.
    """
    model.eval()
    targets, predictions, features = [], [], []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)
            logits, embeddings = model(images)
            # Index of the largest logit is the predicted class.
            preds = torch.max(logits, 1)[1]
            targets.extend(labels.cpu().numpy())
            predictions.extend(preds.cpu().numpy())
            features.extend(embeddings.cpu().numpy())

    # Evaluation metrics
    macro = dict(average='macro')
    accuracy = accuracy_score(targets, predictions)
    precision = precision_score(targets, predictions, **macro)
    recall = recall_score(targets, predictions, **macro)
    f1 = f1_score(targets, predictions, **macro)
    conf_matrix = confusion_matrix(targets, predictions)

    return accuracy, precision, recall, f1, conf_matrix, np.array(features), np.array(targets)

# t-SNE plot
def plot_tsne(embeddings, labels, title="t-SNE Visualization"):
    """Project ``embeddings`` to 2-D with t-SNE and show a scatter plot coloured by ``labels``."""
    projected = TSNE(n_components=2, random_state=42).fit_transform(embeddings)
    xs, ys = projected[:, 0], projected[:, 1]

    plt.figure(figsize=(10, 8))
    points = plt.scatter(xs, ys, c=labels, cmap='viridis', alpha=0.6)
    plt.colorbar(points)
    plt.title(title)
    plt.show()

# Training loop: one training pass and one validation pass per epoch
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, triplet_loss, device)
    accuracy, precision, recall, f1, conf_matrix, embeddings, true_labels = evaluate(model, val_loader, device)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")
    
    # Show the confusion matrix only for the last epoch
    if epoch == num_epochs - 1:
        print("Confusion Matrix (Final Epoch):")
        print(conf_matrix)
    
    # Plot t-SNE of the validation embeddings on the last epoch
    if epoch == num_epochs - 1:
        plot_tsne(embeddings, true_labels, title=f"t-SNE Visualization (Epoch {epoch+1})")

# Save the model weights
torch.save(model.state_dict(), "custom_vit_triplet_fewshot.pth")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
from tqdm import tqdm

# Initial configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
num_epochs = 10
margin = 1.0  # margin for the triplet loss
input_size = 224  # side length the input images are resized to
patch_size = 16  # side length of each patch
num_patches = (input_size // patch_size) ** 2  # number of patches per image
hidden_size = 768  # hidden-state size of the ViT
num_heads = 12  # number of heads in multi-head attention

# Image transforms: resize to a square, convert to a tensor, normalise per channel
transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Dataset definition
class PlantDiseaseDataset(Dataset):
    """Map-style dataset pairing image files with their class labels.

    Images are opened lazily in ``__getitem__``. A sample that cannot be
    loaded is reported on stdout and returned as ``(None, None)`` so the
    caller can filter it out instead of crashing.
    """

    def __init__(self, image_paths, labels, transform=None):
        """Keep the parallel path/label sequences and the optional transform."""
        self.image_paths, self.labels = image_paths, labels
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, or ``(None, None)`` on failure."""
        try:
            path = self.image_paths[idx]
            sample = Image.open(path).convert("RGB")
            target = self.labels[idx]
            # A falsy transform (e.g. None) leaves the PIL image untouched.
            sample = self.transform(sample) if self.transform else sample
            return sample, target
        except Exception as e:
            # Best-effort loading: report the failing path and hand back a
            # sentinel pair. For an out-of-range idx the lookup in the message
            # below raises IndexError again, which is what ends plain
            # iteration over the dataset.
            print(f"Error loading image {self.image_paths[idx]}: {e}")
            return None, None

# Load data from the train, val and test folders
def load_data_from_folders(base_dir):
    """Collect image paths and integer labels for the train/val/test splits.

    The class list is the sorted set of sub-directories of
    ``<base_dir>/train``; the same class-to-index mapping is applied to every
    split. A class directory missing from a split is skipped.

    Args:
        base_dir: Dataset root containing ``train``, ``val`` and ``test``.

    Returns:
        A ``(data, class_to_idx)`` tuple. ``data[split]`` is a dict with the
        parallel lists ``'image_paths'`` and ``'labels'``; ``class_to_idx``
        maps each class directory name to its integer index.

    Raises:
        OSError: If ``<base_dir>/train`` cannot be listed.
    """
    splits = ['train', 'val', 'test']
    data = {split: {'image_paths': [], 'labels': []} for split in splits}

    # Only directories count as classes. A stray file in train/ (e.g. a
    # thumbnail cache) would otherwise receive a class index and inflate the
    # number of classes the model is built with.
    train_dir = os.path.join(base_dir, 'train')
    class_names = sorted(
        name for name in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, name))
    )
    class_to_idx = {class_name: idx for idx, class_name in enumerate(class_names)}

    for split in splits:
        split_dir = os.path.join(base_dir, split)
        for class_name in class_names:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            # os.listdir order is arbitrary; sort so sample order is reproducible.
            for image_name in sorted(os.listdir(class_dir)):
                image_path = os.path.join(class_dir, image_name)
                if os.path.isfile(image_path):
                    data[split]['image_paths'].append(image_path)
                    data[split]['labels'].append(class_to_idx[class_name])

    return data, class_to_idx

# Dataset root directory
base_dir = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'

# Load the image paths and labels for every split
data, class_to_idx = load_data_from_folders(base_dir)

# Create the Dataset objects for train, val and test
train_dataset = PlantDiseaseDataset(data['train']['image_paths'], data['train']['labels'], transform=transform)
val_dataset = PlantDiseaseDataset(data['val']['image_paths'], data['val']['labels'], transform=transform)
test_dataset = PlantDiseaseDataset(data['test']['image_paths'], data['test']['labels'], transform=transform)

# Drop samples that failed to load (returned as (None, None)).
# NOTE: iterating the datasets here runs __getitem__ for every sample, so each
# split becomes an in-memory list of already-transformed (image, label) pairs.
train_dataset = [data for data in train_dataset if data[0] is not None]
val_dataset = [data for data in val_dataset if data[0] is not None]
test_dataset = [data for data in test_dataset if data[0] is not None]

# Create the DataLoaders (only the training split is shuffled)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Print summary information
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")
print(f"Class to index mapping: {class_to_idx}")

# ViT adapted for few-shot learning with a triplet network
class CustomViT(nn.Module):
    """Small Vision Transformer returning class logits and a CLS embedding.

    The image is cut into non-overlapping patches by a strided convolution, a
    learnable CLS token is prepended, learnable positional embeddings are
    added, and the token sequence is passed through a Transformer encoder.
    The encoder output at the CLS position is the embedding; a linear head
    maps it to class logits.

    Args:
        num_patches: Number of patches per image (the CLS token is added on top).
        hidden_size: Token / embedding dimension.
        num_heads: Number of attention heads in each encoder layer.
        num_classes: Size of the classification head output.
        patch_size: Side length of a patch; must be the value ``num_patches``
            was derived from. Defaults to 16.
        num_layers: Number of stacked encoder layers. Defaults to 6.
    """

    def __init__(self, num_patches, hidden_size, num_heads, num_classes,
                 patch_size=16, num_layers=6):
        super().__init__()
        self.patch_embedding = nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size)
        # +1 position for the CLS token
        self.positional_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        # forward() lays tokens out as (batch, tokens, hidden). The encoder
        # layer defaults to (tokens, batch, hidden), which would make attention
        # run across the images of a batch instead of across the patches of
        # one image, so batch_first must be enabled.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads, batch_first=True),
            num_layers=num_layers,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return ``(logits, cls_output)`` for a batch of images.

        ``x`` goes through a 3-input-channel convolution, so it is expected to
        be ``(batch, 3, height, width)``. ``cls_output`` is
        ``(batch, hidden_size)`` and ``logits`` is ``(batch, num_classes)``.
        """
        # Split the image into patches: (batch, hidden, patches_h, patches_w)
        x = self.patch_embedding(x)
        # Flatten the patch grid: (batch, num_patches, hidden)
        x = x.flatten(2).transpose(1, 2)

        # Prepend the CLS token and add the positional embedding
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.positional_embedding

        # Transformer encoder
        x = self.transformer(x)

        # The CLS position is used both as the embedding and for classification
        cls_output = x[:, 0, :]
        logits = self.fc(cls_output)
        return logits, cls_output

# Triplet loss definition
class TripletLoss(nn.Module):
    """Margin-based triplet loss on squared Euclidean distances.

    Per row the loss is ``relu(d(a, p) - d(a, n) + margin)``, where ``d`` is
    the squared difference summed over dimension 1; the batch mean is returned.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean hinge loss over all rows of the batch."""

        def squared_distance(other):
            # Squared L2 distance between each anchor row and the matching row.
            delta = anchor - other
            return delta.pow(2).sum(1)

        to_positive = squared_distance(positive)
        to_negative = squared_distance(negative)
        hinge = torch.relu(to_positive - to_negative + self.margin)
        return hinge.mean()

# Build the model, the loss function and the optimizer
num_classes = len(class_to_idx)  # number of classes
model = CustomViT(num_patches, hidden_size, num_heads, num_classes).to(device)
triplet_loss = TripletLoss(margin=margin)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Model training
def _hardest_triplet_indices(embeddings, labels):
    """Select the hardest in-batch positive and negative for each usable anchor.

    An anchor is usable when the batch holds at least one other sample with
    the same label and at least one sample with a different label. For such an
    anchor the positive is the farthest same-label sample and the negative is
    the closest different-label sample, both by squared Euclidean distance.

    Returns:
        ``(anchor_idx, positive_idx, negative_idx)``: three 1-D index tensors
        of equal length (empty when the batch yields no usable anchor).
    """
    # Mining only picks indices, so it is kept out of the autograd graph.
    with torch.no_grad():
        diffs = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
        sq_dists = diffs.pow(2).sum(dim=2)
        same_class = labels.unsqueeze(0) == labels.unsqueeze(1)
        not_self = ~torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
        positive_mask = same_class & not_self
        negative_mask = ~same_class
        usable = positive_mask.any(dim=1) & negative_mask.any(dim=1)
        hardest_positive = sq_dists.masked_fill(~positive_mask, float("-inf")).argmax(dim=1)
        hardest_negative = sq_dists.masked_fill(~negative_mask, float("inf")).argmin(dim=1)
        anchor_idx = usable.nonzero(as_tuple=True)[0]
    return anchor_idx, hardest_positive[anchor_idx], hardest_negative[anchor_idx]


def train(model, train_loader, optimizer, triplet_loss, device):
    """Train ``model`` for one epoch with an in-batch triplet loss.

    Triplets are mined inside every batch from the labels. Using one tensor as
    anchor, positive and negative at once would make both distances zero, so
    the loss would equal the margin and carry no gradient. Only the embeddings
    enter the loss; the classification logits are not used here.

    Args:
        model: Module returning ``(logits, embeddings)`` for a batch of images.
        train_loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer over the model parameters.
        triplet_loss: Callable ``(anchor, positive, negative) -> scalar loss``.
        device: Device the batches are moved to.

    Returns:
        Mean loss over the batches that produced at least one triplet, or
        ``0.0`` when no batch did.
    """
    model.train()
    running_loss = 0.0
    num_steps = 0
    for images, labels in tqdm(train_loader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)

        _, embeddings = model(images)

        anchor_idx, positive_idx, negative_idx = _hardest_triplet_indices(embeddings, labels)
        if anchor_idx.numel() == 0:
            # E.g. a batch with a single class: nothing to contrast against.
            continue

        # Index the live embeddings so gradients reach the model.
        loss = triplet_loss(
            embeddings[anchor_idx], embeddings[positive_idx], embeddings[negative_idx]
        )

        # Weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_steps += 1
    return running_loss / num_steps if num_steps else 0.0

# Model evaluation
def evaluate(model, val_loader, device):
    """Run ``model`` over ``val_loader`` and score its argmax predictions.

    Returns:
        ``(precision, recall, f1, conf_matrix, embeddings, labels)`` where
        precision, recall and F1 are macro-averaged and the last two items are
        NumPy arrays of the collected embeddings and true labels.
    """
    model.eval()
    targets, predictions, features = [], [], []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)
            logits, embeddings = model(images)
            # Index of the largest logit is the predicted class.
            preds = torch.max(logits, 1)[1]
            targets.extend(labels.cpu().numpy())
            predictions.extend(preds.cpu().numpy())
            features.extend(embeddings.cpu().numpy())

    # Evaluation metrics
    macro = dict(average='macro')
    precision = precision_score(targets, predictions, **macro)
    recall = recall_score(targets, predictions, **macro)
    f1 = f1_score(targets, predictions, **macro)
    conf_matrix = confusion_matrix(targets, predictions)

    return precision, recall, f1, conf_matrix, np.array(features), np.array(targets)

# t-SNE plot
def plot_tsne(embeddings, labels, title="t-SNE Visualization"):
    """Project ``embeddings`` to 2-D with t-SNE and show a scatter plot coloured by ``labels``."""
    projected = TSNE(n_components=2, random_state=42).fit_transform(embeddings)
    xs, ys = projected[:, 0], projected[:, 1]

    plt.figure(figsize=(10, 8))
    points = plt.scatter(xs, ys, c=labels, cmap='viridis', alpha=0.6)
    plt.colorbar(points)
    plt.title(title)
    plt.show()

# Training loop: one training pass and one validation pass per epoch
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, triplet_loss, device)
    precision, recall, f1, conf_matrix, embeddings, true_labels = evaluate(model, val_loader, device)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}")
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")
    print("Confusion Matrix:")
    print(conf_matrix)
    
    # Plot t-SNE of the validation embeddings on the last epoch
    if epoch == num_epochs - 1:
        plot_tsne(embeddings, true_labels, title=f"t-SNE Visualization (Epoch {epoch+1})")

# Save the model weights
torch.save(model.state_dict(), "custom_vit_triplet_fewshot.pth")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
from tqdm import tqdm

# Initial configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
num_epochs = 10
margin = 1.0  # margin for the triplet loss
input_size = 224  # side length the input images are resized to
patch_size = 16  # side length of each patch
num_patches = (input_size // patch_size) ** 2  # number of patches per image
hidden_size = 768  # hidden-state size of the ViT
num_heads = 12  # number of heads in multi-head attention

# Image transforms: resize to a square, convert to a tensor, normalise per channel
transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Dataset definition
class PlantDiseaseDataset(Dataset):
    """Map-style dataset pairing image files with their class labels.

    Images are opened lazily in ``__getitem__``. A sample that cannot be
    loaded is reported on stdout and returned as ``(None, None)`` so the
    caller can filter it out instead of crashing.
    """

    def __init__(self, image_paths, labels, transform=None):
        """Keep the parallel path/label sequences and the optional transform."""
        self.image_paths, self.labels = image_paths, labels
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, or ``(None, None)`` on failure."""
        try:
            path = self.image_paths[idx]
            sample = Image.open(path).convert("RGB")
            target = self.labels[idx]
            # A falsy transform (e.g. None) leaves the PIL image untouched.
            sample = self.transform(sample) if self.transform else sample
            return sample, target
        except Exception as e:
            # Best-effort loading: report the failing path and hand back a
            # sentinel pair. For an out-of-range idx the lookup in the message
            # below raises IndexError again, which is what ends plain
            # iteration over the dataset.
            print(f"Error loading image {self.image_paths[idx]}: {e}")
            return None, None

# Load data from the train, val and test folders
def load_data_from_folders(base_dir):
    """Collect image paths and integer labels for the train/val/test splits.

    The class list is the sorted set of sub-directories of
    ``<base_dir>/train``; the same class-to-index mapping is applied to every
    split. A class directory missing from a split is skipped.

    Args:
        base_dir: Dataset root containing ``train``, ``val`` and ``test``.

    Returns:
        A ``(data, class_to_idx)`` tuple. ``data[split]`` is a dict with the
        parallel lists ``'image_paths'`` and ``'labels'``; ``class_to_idx``
        maps each class directory name to its integer index.

    Raises:
        OSError: If ``<base_dir>/train`` cannot be listed.
    """
    splits = ['train', 'val', 'test']
    data = {split: {'image_paths': [], 'labels': []} for split in splits}

    # Only directories count as classes. A stray file in train/ (e.g. a
    # thumbnail cache) would otherwise receive a class index and inflate the
    # number of classes the model is built with.
    train_dir = os.path.join(base_dir, 'train')
    class_names = sorted(
        name for name in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, name))
    )
    class_to_idx = {class_name: idx for idx, class_name in enumerate(class_names)}

    for split in splits:
        split_dir = os.path.join(base_dir, split)
        for class_name in class_names:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            # os.listdir order is arbitrary; sort so sample order is reproducible.
            for image_name in sorted(os.listdir(class_dir)):
                image_path = os.path.join(class_dir, image_name)
                if os.path.isfile(image_path):
                    data[split]['image_paths'].append(image_path)
                    data[split]['labels'].append(class_to_idx[class_name])

    return data, class_to_idx

# Dataset root directory
base_dir = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'

# Load the image paths and labels for every split
data, class_to_idx = load_data_from_folders(base_dir)

# Create the Dataset objects for train, val and test
train_dataset = PlantDiseaseDataset(data['train']['image_paths'], data['train']['labels'], transform=transform)
val_dataset = PlantDiseaseDataset(data['val']['image_paths'], data['val']['labels'], transform=transform)
test_dataset = PlantDiseaseDataset(data['test']['image_paths'], data['test']['labels'], transform=transform)

# Drop samples that failed to load (returned as (None, None)).
# NOTE: iterating the datasets here runs __getitem__ for every sample, so each
# split becomes an in-memory list of already-transformed (image, label) pairs.
train_dataset = [data for data in train_dataset if data[0] is not None]
val_dataset = [data for data in val_dataset if data[0] is not None]
test_dataset = [data for data in test_dataset if data[0] is not None]

# Create the DataLoaders (only the training split is shuffled)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Print summary information
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")
print(f"Class to index mapping: {class_to_idx}")

# ViT adapted for few-shot learning with a triplet network
class CustomViT(nn.Module):
    """Small Vision Transformer returning class logits and a CLS embedding.

    The image is cut into non-overlapping patches by a strided convolution, a
    learnable CLS token is prepended, learnable positional embeddings are
    added, and the token sequence is passed through a Transformer encoder.
    The encoder output at the CLS position is the embedding; a linear head
    maps it to class logits.

    Args:
        num_patches: Number of patches per image (the CLS token is added on top).
        hidden_size: Token / embedding dimension.
        num_heads: Number of attention heads in each encoder layer.
        num_classes: Size of the classification head output.
        patch_size: Side length of a patch; must be the value ``num_patches``
            was derived from. Defaults to 16.
        num_layers: Number of stacked encoder layers. Defaults to 6.
    """

    def __init__(self, num_patches, hidden_size, num_heads, num_classes,
                 patch_size=16, num_layers=6):
        super().__init__()
        self.patch_embedding = nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size)
        # +1 position for the CLS token
        self.positional_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        # forward() lays tokens out as (batch, tokens, hidden). The encoder
        # layer defaults to (tokens, batch, hidden), which would make attention
        # run across the images of a batch instead of across the patches of
        # one image, so batch_first must be enabled.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads, batch_first=True),
            num_layers=num_layers,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return ``(logits, cls_output)`` for a batch of images.

        ``x`` goes through a 3-input-channel convolution, so it is expected to
        be ``(batch, 3, height, width)``. ``cls_output`` is
        ``(batch, hidden_size)`` and ``logits`` is ``(batch, num_classes)``.
        """
        # Split the image into patches: (batch, hidden, patches_h, patches_w)
        x = self.patch_embedding(x)
        # Flatten the patch grid: (batch, num_patches, hidden)
        x = x.flatten(2).transpose(1, 2)

        # Prepend the CLS token and add the positional embedding
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.positional_embedding

        # Transformer encoder
        x = self.transformer(x)

        # The CLS position is used both as the embedding and for classification
        cls_output = x[:, 0, :]
        logits = self.fc(cls_output)
        return logits, cls_output

# Triplet loss definition
class TripletLoss(nn.Module):
    """Margin-based triplet loss on squared Euclidean distances.

    Per row the loss is ``relu(d(a, p) - d(a, n) + margin)``, where ``d`` is
    the squared difference summed over dimension 1; the batch mean is returned.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean hinge loss over all rows of the batch."""

        def squared_distance(other):
            # Squared L2 distance between each anchor row and the matching row.
            delta = anchor - other
            return delta.pow(2).sum(1)

        to_positive = squared_distance(positive)
        to_negative = squared_distance(negative)
        hinge = torch.relu(to_positive - to_negative + self.margin)
        return hinge.mean()

# Build the model, the loss functions and the optimizer
num_classes = len(class_to_idx)
model = CustomViT(num_patches, hidden_size, num_heads, num_classes).to(device)
triplet_loss = TripletLoss(margin=margin)
cross_entropy_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Triplet generation
def generate_triplets(labels, embeddings):
    """Build one ``(anchor, positive, negative)`` triplet per sample where possible.

    For sample ``i`` the positive is the first other sample with the same
    label and the negative is the first sample with a different label.
    Samples lacking either partner are skipped, so every returned triplet is
    complete and never contains ``None``.

    Args:
        labels: Sequence of per-sample labels.
        embeddings: Sequence indexable by the same positions as ``labels``.

    Returns:
        List of ``(anchor, positive, negative)`` tuples taken from
        ``embeddings``; it may be shorter than ``labels`` or empty.
    """
    triplets = []
    for i, anchor_label in enumerate(labels):
        positive_idx = next(
            (j for j, label in enumerate(labels) if j != i and label == anchor_label),
            None,
        )
        negative_idx = next(
            (j for j, label in enumerate(labels) if label != anchor_label),
            None,
        )
        # A label that occurs once has no positive; a single-class batch has
        # no negative. Either way there is no valid triplet for this anchor.
        if positive_idx is None or negative_idx is None:
            continue
        triplets.append((embeddings[i], embeddings[positive_idx], embeddings[negative_idx]))
    return triplets

# Model training
def train(model, train_loader, optimizer, triplet_loss, cross_entropy_loss, device):
    """Train ``model`` for one epoch on cross-entropy plus triplet loss.

    Triplets are built from the live embedding tensor and stacked into
    batches, so the triplet term is differentiable and reaches the model.
    A batch with no complete triplet is trained on cross-entropy alone.

    Args:
        model: Module returning ``(logits, embeddings)`` for a batch of images.
        train_loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer over the model parameters.
        triplet_loss: Callable ``(anchor, positive, negative) -> scalar loss``
            operating on 2-D batches.
        cross_entropy_loss: Callable ``(logits, labels) -> scalar loss``.
        device: Device the batches are moved to.

    Returns:
        Mean total loss over the processed batches, or ``0.0`` if there were none.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels in tqdm(train_loader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)

        logits, embeddings = model(images)

        # Cross-entropy loss on the classification head
        ce_loss = cross_entropy_loss(logits, labels)

        # Build triplets from the tensor rows themselves (not detached NumPy
        # copies) and drop any triplet that is missing a positive or negative.
        triplets = [
            triplet
            for triplet in generate_triplets(labels.tolist(), embeddings)
            if triplet[1] is not None and triplet[2] is not None
        ]
        if triplets:
            anchors, positives, negatives = (torch.stack(rows) for rows in zip(*triplets))
            triplet_loss_value = triplet_loss(anchors, positives, negatives)
        else:
            triplet_loss_value = embeddings.new_zeros(())

        # Sum of the two losses
        total_loss = ce_loss + triplet_loss_value

        # Weight update
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()
        num_batches += 1

    return running_loss / num_batches if num_batches else 0.0

# Model evaluation
def evaluate(model, val_loader, device):
    """Run ``model`` over ``val_loader`` and score its argmax predictions.

    Returns:
        ``(precision, recall, f1, conf_matrix, embeddings, labels)`` where
        precision, recall and F1 are macro-averaged and the last two items are
        NumPy arrays of the collected embeddings and true labels.
    """
    model.eval()
    targets, predictions, features = [], [], []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)
            logits, embeddings = model(images)
            # Index of the largest logit is the predicted class.
            preds = torch.max(logits, 1)[1]
            targets.extend(labels.cpu().numpy())
            predictions.extend(preds.cpu().numpy())
            features.extend(embeddings.cpu().numpy())

    # Evaluation metrics
    macro = dict(average='macro')
    precision = precision_score(targets, predictions, **macro)
    recall = recall_score(targets, predictions, **macro)
    f1 = f1_score(targets, predictions, **macro)
    conf_matrix = confusion_matrix(targets, predictions)

    return precision, recall, f1, conf_matrix, np.array(features), np.array(targets)

# t-SNE plot
def plot_tsne(embeddings, labels, title="t-SNE Visualization"):
    """Project ``embeddings`` to 2-D with t-SNE and show a scatter plot coloured by ``labels``."""
    projected = TSNE(n_components=2, random_state=42).fit_transform(embeddings)
    xs, ys = projected[:, 0], projected[:, 1]

    plt.figure(figsize=(10, 8))
    points = plt.scatter(xs, ys, c=labels, cmap='viridis', alpha=0.6)
    plt.colorbar(points)
    plt.title(title)
    plt.show()

# Training loop: one training pass and one validation pass per epoch
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, triplet_loss, cross_entropy_loss, device)
    precision, recall, f1, conf_matrix, embeddings, true_labels = evaluate(model, val_loader, device)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}")
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")
    print("Confusion Matrix:")
    print(conf_matrix)
    
    # Plot t-SNE of the validation embeddings on the last epoch
    if epoch == num_epochs - 1:
        plot_tsne(embeddings, true_labels, title=f"t-SNE Visualization (Epoch {epoch+1})")

# Save the model weights
torch.save(model.state_dict(), "custom_vit_triplet_fewshot.pth")
