In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
from transformers import ViTModel
from torch.optim import Adam
from PIL import Image
import numpy as np
import os
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# 1. تعریف Dataset برای Triplet‌ها
class TripletDataset(Dataset):
    """Wrap an ImageFolder-style dataset so each item is an (anchor, positive, negative) triplet.

    The wrapped ``dataset`` must expose ``classes`` (class names), ``targets``
    (class index per sample) and ``samples`` (``(path, class_index)`` pairs).
    Item ``idx`` uses sample ``idx`` as the anchor, a random *other* sample of
    the same class as the positive (the anchor itself only when its class has a
    single sample), and a random sample from a uniformly chosen different class
    as the negative. Triplets are re-drawn on every access.

    Images are loaded as RGB, resized to 224x224, converted to tensors and
    normalised with mean 0.5 / std 0.5 per channel (i.e. to [-1, 1]).
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.classes = dataset.classes
        # Convert the targets once rather than once per class.
        targets = np.asarray(dataset.targets)
        self.class_to_indices = {
            cls: np.where(targets == idx)[0] for idx, cls in enumerate(self.classes)
        }
        self.transform = Compose([
            Resize((224, 224)),  # resize images to 224x224
            ToTensor(),  # PIL image -> float tensor in [0, 1]
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # -> [-1, 1]
        ])

    def __len__(self):
        return len(self.dataset)

    def _sample_positive_index(self, idx, anchor_label):
        """Return a random index of another sample in the anchor's class (``idx`` if it is alone)."""
        same_class = self.class_to_indices[self.classes[anchor_label]]
        others = same_class[same_class != idx]
        # A positive equal to the anchor has zero distance and carries no
        # training signal; only fall back to it for single-sample classes.
        pool = others if others.size else same_class
        return int(np.random.choice(pool))

    def _sample_negative_index(self, anchor_label):
        """Return a random sample index from a uniformly chosen class other than the anchor's.

        Raises:
            ValueError: if no other non-empty class exists.
        """
        anchor_class = self.classes[anchor_label]
        negative_classes = [
            cls for cls in self.classes
            if cls != anchor_class and len(self.class_to_indices[cls]) > 0
        ]
        if not negative_classes:
            raise ValueError(
                "TripletDataset needs at least two non-empty classes to draw a negative sample"
            )
        negative_class = negative_classes[np.random.randint(len(negative_classes))]
        return int(np.random.choice(self.class_to_indices[negative_class]))

    def _load(self, path):
        """Open ``path`` as RGB and apply ``self.transform``, closing the file handle."""
        with Image.open(path) as image:
            return self.transform(image.convert("RGB"))

    def __getitem__(self, idx):
        anchor_path, anchor_label = self.dataset.samples[idx]  # file path and label
        positive_path, _ = self.dataset.samples[self._sample_positive_index(idx, anchor_label)]
        negative_path, _ = self.dataset.samples[self._sample_negative_index(anchor_label)]
        return self._load(anchor_path), self._load(positive_path), self._load(negative_path)

# 2. تعریف Attention Layer
class AttentionLayer(nn.Module):
    """Attention pooling that collapses a token sequence into one feature vector.

    A two-layer MLP (feature_dim -> feature_dim // 2 -> 1, ReLU in between)
    scores every token, the scores are softmax-normalised over dim 1, and the
    tokens are summed with those weights.
    """

    def __init__(self, feature_dim):
        super().__init__()
        hidden_dim = feature_dim // 2
        # Kept as one Sequential named ``attention`` so state_dict keys
        # (attention.0.*, attention.2.*) match previously saved checkpoints.
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Pool ``x`` over dim 1: (batch, seq_len, feature_dim) -> (batch, feature_dim).

        NOTE(review): the softmax runs over dim 1, so a 2-D input would get a
        weight of 1 everywhere; the layer expects a 3-D token sequence.
        """
        scores = self.attention(x)  # (batch, seq_len, 1), sums to 1 over seq_len
        return torch.sum(x * scores, dim=1)

# 3. تعریف مدل با Attention Mechanism
class ViTWithAttention(nn.Module):
    """Pretrained ViT backbone followed by attention pooling over its output tokens.

    ``model_name`` is passed to ``ViTModel.from_pretrained``. The backbone's
    ``last_hidden_state`` (every output token) is pooled by ``AttentionLayer``
    into one embedding of size ``config.hidden_size`` per image.
    """

    def __init__(self, model_name):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        hidden_size = self.vit.config.hidden_size
        self.attention = AttentionLayer(hidden_size)

    def forward(self, x):
        """Embed the image batch ``x``; returns features of shape (batch, hidden_size)."""
        token_states = self.vit(x).last_hidden_state  # (batch, seq_len, hidden_size)
        return self.attention(token_states)

# 4. تعریف Triplet Loss
class TripletLoss(nn.Module):
    """Batch-mean triplet margin loss: relu(d(a, p) - d(a, n) + margin).

    ``d`` is ``F.pairwise_distance`` (Euclidean distance per row).
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        gap = F.pairwise_distance(anchor, positive) - F.pairwise_distance(anchor, negative)
        # Triplets already separated by more than ``margin`` contribute zero.
        return F.relu(gap + self.margin).mean()

# 5. بارگذاری داده‌ها از پوشه‌ها
# Dataset root. Set the PLANT_DATASET_DIR environment variable to run on another
# machine without editing the script (another location used before:
# 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset').
data_dir = os.environ.get(
    "PLANT_DATASET_DIR",
    'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset',
)
train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "val")
test_dir = os.path.join(data_dir, "test")

# Each split is a folder of per-class sub-folders (ImageFolder layout).
train_dataset = ImageFolder(train_dir)
# NOTE(review): val_dataset is not referenced by the code visible here.
val_dataset = ImageFolder(val_dir)
test_dataset = ImageFolder(test_dir)

# Wrap the training split so each item is an (anchor, positive, negative) triplet.
train_triplet_dataset = TripletDataset(train_dataset)
train_loader = DataLoader(train_triplet_dataset, batch_size=32, shuffle=True)

# 6. Model, optimizer and loss function.
model_name = "google/vit-base-patch16-224"
model = ViTWithAttention(model_name)
optimizer = Adam(model.parameters(), lr=1e-5)
triplet_loss = TripletLoss(margin=1.0)

# 7. آموزش مدل و محاسبه Accuracy
num_epochs = 1
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    correct = 0
    total = 0

    for anchor, positive, negative in train_loader:
        # Embed all three images of each triplet with the shared network.
        anchor_features = model(anchor)
        positive_features = model(positive)
        negative_features = model(negative)

        loss = triplet_loss(anchor_features, positive_features, negative_features)
        epoch_loss += loss.item()

        # Triplet accuracy: fraction of triplets whose anchor is closer to its
        # positive than to its negative. Computed under no_grad because it is
        # a metric only and needs no autograd history.
        with torch.no_grad():
            distance_positive = F.pairwise_distance(anchor_features, positive_features)
            distance_negative = F.pairwise_distance(anchor_features, negative_features)
            correct += (distance_positive < distance_negative).sum().item()
            total += distance_positive.size(0)

        # Update the model.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if total == 0:
        raise RuntimeError("train_loader produced no batches; check the training directory")
    accuracy = correct / total
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / len(train_loader)}, Accuracy: {accuracy * 100:.2f}%")

# 8. ذخیره‌سازی مدل Fine-Tuned
# Persist the fine-tuned model twice: its state_dict (reload into a fresh
# ViTWithAttention) and the whole pickled module.
# Alternative directory used on another machine:
#   f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net
save_dir = "C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net"
torch.save(model.state_dict(), f"{save_dir}/vitWithAttentionState.pth")
# NOTE(review): unpickling the full module needs these class definitions importable,
# and likely torch.load(..., weights_only=False) on recent torch versions — confirm.
torch.save(model, f"{save_dir}/vitWithAttention.pth")

# 9. Model evaluation and t-SNE visualisation.
def _collect_triplet_outputs(model, dataloader):
    """Embed every triplet from ``dataloader``; return (anchor features, anchor labels, predictions).

    ``predictions`` holds 1.0 where the anchor is closer to its positive than to
    its negative, else 0.0.

    Raises:
        ValueError: if the loader is not sequential, or yields fewer than two triplets.
    """
    from torch.utils.data import SequentialSampler

    # Anchor labels are recovered by position, which is only valid when the
    # loader visits the anchors in dataset order.
    if not isinstance(dataloader.sampler, SequentialSampler):
        raise ValueError("evaluate_model requires a DataLoader built with shuffle=False")

    dataset_labels = np.array([label for _, label in dataloader.dataset.dataset.samples])
    features, labels, predictions = [], [], []
    offset = 0
    with torch.no_grad():
        for anchor, positive, negative in dataloader:
            anchor_features = model(anchor)
            positive_features = model(positive)
            negative_features = model(negative)

            # Labels for exactly the anchors in this batch.
            batch_len = anchor_features.shape[0]
            features.append(anchor_features.cpu().numpy())
            labels.append(dataset_labels[offset:offset + batch_len])
            offset += batch_len

            distance_positive = F.pairwise_distance(anchor_features, positive_features)
            distance_negative = F.pairwise_distance(anchor_features, negative_features)
            predictions.append((distance_positive < distance_negative).float().cpu().numpy())

    if offset < 2:
        raise ValueError("evaluate_model needs at least two triplets for t-SNE")
    return np.concatenate(features), np.concatenate(labels), np.concatenate(predictions)


def _plot_tsne(features, labels, classes):
    """Project ``features`` to 2-D with t-SNE and scatter-plot them coloured by class label."""
    # t-SNE requires perplexity < n_samples; the default of 30 fails on small
    # few-shot test sets.
    perplexity = min(30.0, len(features) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    tsne_results = tsne.fit_transform(features)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels, cmap='viridis', alpha=0.6)
    plt.colorbar(scatter)
    plt.title("t-SNE Visualization of Embeddings")

    # Legend colours come from the scatter's own cmap/norm so they match the
    # plotted points exactly (and a single class no longer divides by zero).
    handles = [
        plt.Line2D([0], [0], marker='o', color='w',
                   markerfacecolor=scatter.cmap(scatter.norm(i)), markersize=10)
        for i in range(len(classes))
    ]
    plt.legend(handles, classes, title="Classes")
    plt.show()


def evaluate_model(model, dataloader):
    """Evaluate a triplet-embedding model, plot a t-SNE of anchor embeddings and print metrics.

    Args:
        model: embedding network; switched to eval mode and run without gradients.
        dataloader: DataLoader over a ``TripletDataset`` built with ``shuffle=False``;
            ``dataloader.dataset.dataset`` must expose ``samples`` and ``classes``.

    Returns:
        dict with keys "accuracy", "precision", "recall" and "f1".

    Raises:
        ValueError: if the loader shuffles or yields fewer than two triplets.
    """
    model.eval()
    features, labels, predictions = _collect_triplet_outputs(model, dataloader)
    _plot_tsne(features, labels, dataloader.dataset.dataset.classes)

    # NOTE: the ground truth of every triplet is 1 ("anchor closer to positive"),
    # so precision is trivially 1.0 (0.0 if nothing is predicted 1), recall
    # equals accuracy and F1 follows from them; only accuracy is informative.
    true_labels = np.ones_like(predictions)
    metrics = {
        "accuracy": accuracy_score(true_labels, predictions),
        "precision": precision_score(true_labels, predictions, zero_division=0),
        "recall": recall_score(true_labels, predictions, zero_division=0),
        "f1": f1_score(true_labels, predictions, zero_division=0),
    }

    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")
    return metrics


# ارزیابی مدل روی داده‌های تست
# Report the number of test classes, then evaluate the model once on the test split.
num_classes = len(test_dataset.classes)
print(f"Number of classes: {num_classes}")
# shuffle=False keeps the anchors in dataset order.
test_triplet_dataset = TripletDataset(test_dataset)
test_loader = DataLoader(test_triplet_dataset, batch_size=32, shuffle=False)
evaluate_model(model, test_loader)

C:\Users\Mey\AppData\Roaming\Python\Python39\site-packages\numpy\.libs\libopenblas.XWYDX2IKJW2NMTWSFYNGFUWKQU3LYTCZ.gfortran-win_amd64.dll
C:\Users\Mey\AppData\Roaming\Python\Python39\site-packages\numpy\.libs\libopenblas64__v0.3.21-gcc_10_3_0.dll
  warn(
Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
