# In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from transformers import ViTModel
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
from tqdm import tqdm

class TripletViT(nn.Module):
    """ViT backbone + projection head for triplet learning, plus a linear classifier.

    Args:
        num_classes: Number of logits produced by ``self.classifier``.
        pretrained_name: Checkpoint passed to ``ViTModel.from_pretrained``.
        embed_dim: Width of the projected embedding fed to the classifier.
    """

    def __init__(self, num_classes, pretrained_name='google/vit-base-patch16-224',
                 embed_dim=512):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_name)
        # Size the projection from the backbone config (768 for the default
        # base checkpoint) so other ViT checkpoints also work.
        self.embedding = nn.Linear(self.vit.config.hidden_size, embed_dim)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward_one(self, x):
        """Project the backbone's first-token ([CLS]) state for image batch ``x``."""
        cls_state = self.vit(x).last_hidden_state[:, 0, :]
        return self.embedding(cls_state)

    def forward(self, anchor, positive=None, negative=None):
        """Triplet mode when ``negative`` is given, classification mode otherwise.

        Triplet mode returns ``(anchor_emb, positive_emb, negative_emb)``, where
        ``positive_emb`` is None if ``positive`` is None. Classification mode
        returns classifier logits for ``anchor``; ``positive`` is ignored and
        never sent through the backbone.
        """
        if negative is None:
            # The positive embedding is unused here, so skip its ViT pass.
            return self.classifier(self.forward_one(anchor))
        anchor_emb = self.forward_one(anchor)
        positive_emb = self.forward_one(positive) if positive is not None else None
        negative_emb = self.forward_one(negative)
        return anchor_emb, positive_emb, negative_emb

class TripletLoss(nn.Module):
    """Hinge triplet loss on squared Euclidean distances.

    Per row: ``max(0, ||a - p||^2 - ||a - n||^2 + margin)``; the result is the
    mean over rows.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    @staticmethod
    def _squared_distance(left, right):
        """Row-wise squared L2 distance between two (B, D) tensors."""
        return (left - right).pow(2).sum(1)

    def forward(self, anchor, positive, negative):
        gap = self._squared_distance(anchor, positive) - self._squared_distance(anchor, negative)
        return torch.relu(gap + self.margin).mean()

def train_epoch(model, train_loader, criterion_triplet, criterion_ce, optimizer, device):
    """Train ``model`` for one epoch on a joint triplet + cross-entropy objective.

    Each batch must be ``(anchor, positive, negative, labels)``. The triplet loss
    uses the embeddings from ``model(anchor, positive, negative)``. The
    cross-entropy uses the logits from ``model(anchor, None, None)``. The two
    losses are summed.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches, and the
        classification accuracy over every sample seen in the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    seen = 0

    for anchor, positive, negative, labels in tqdm(train_loader, desc="Training"):
        anchor, positive = anchor.to(device), positive.to(device)
        negative, labels = negative.to(device), labels.to(device)

        # Compute embeddings and losses
        anchor_emb, positive_emb, negative_emb = model(anchor, positive, negative)
        loss_triplet = criterion_triplet(anchor_emb, positive_emb, negative_emb)

        # Classification
        pred = model(anchor, None, None)
        loss_ce = criterion_ce(pred, labels)

        loss = loss_triplet + loss_ce

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        # Count hits per sample so a short final batch carries its true weight;
        # averaging per-batch accuracies over-weights it.
        correct += (pred.argmax(dim=1) == labels).sum().item()
        seen += labels.numel()

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_batches, correct / seen

def validate(model, val_loader, criterion_triplet, criterion_ce, device):
    """Evaluate ``model`` on triplet batches without updating its weights.

    Each batch must be ``(anchor, positive, negative, labels)``. The per-batch
    loss is the triplet loss on ``model(anchor, positive, negative)`` plus the
    cross-entropy on ``model(anchor, None, None)``.

    Returns:
        dict with ``loss`` (mean over batches), ``accuracy`` over all samples,
        and weighted ``precision``, ``recall`` and ``f1``.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for anchor, positive, negative, labels in tqdm(val_loader, desc="Validation"):
            anchor, positive = anchor.to(device), positive.to(device)
            negative, labels = negative.to(device), labels.to(device)

            anchor_emb, positive_emb, negative_emb = model(anchor, positive, negative)
            loss_triplet = criterion_triplet(anchor_emb, positive_emb, negative_emb)

            pred = model(anchor, None, None)
            loss_ce = criterion_ce(pred, labels)

            total_loss += (loss_triplet + loss_ce).item()
            num_batches += 1
            all_preds.extend(torch.argmax(pred, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")

    # zero_division=0 is the value sklearn already falls back to for classes
    # that are never predicted; setting it only silences the warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', zero_division=0
    )

    return {
        'loss': total_loss / num_batches,
        # Over all samples; a mean of per-batch accuracies over-weights a
        # smaller final batch.
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

def test(model, test_loader, device):
    """Classify every ``(images, labels)`` batch in ``test_loader`` and score it.

    ``model(images, None, None)`` must return per-class logits. Returns a dict
    with accuracy and weighted precision, recall and F1 over all samples.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            logits = model(images.to(device), None, None)
            predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
            targets.extend(labels.numpy())

    score = accuracy_score(targets, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, predictions, average='weighted'
    )
    return {
        'accuracy': score,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

# Preprocessing shared by every split: resize to 224x224 (the size in the
# checkpoint name google/vit-base-patch16-224), convert to a tensor, then
# normalise each channel.
# NOTE(review): mean/std below are the ImageNet statistics; the Hugging Face
# image-processor config for google/vit-base-patch16-224 appears to use
# mean=std=0.5 instead. Confirm which normalisation the backbone expects.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class _TripletImageFolder(torch.utils.data.Dataset):
    """Expose a labelled image dataset as ``(anchor, positive, negative, label)`` items.

    ``train_epoch`` and ``validate`` unpack 4-tuples, but ``ImageFolder``
    yields ``(image, label)``. For each anchor, the positive is a random other
    sample of the same class (or the anchor itself if the class has only one
    sample). The negative is a random sample from a random other class. All
    randomness comes from ``np.random``.

    Raises:
        ValueError: if the dataset has fewer than two classes.
    """

    def __init__(self, base):
        self.base = base
        # ImageFolder keeps class ids in ``targets``; reading them avoids
        # decoding and transforming every image just to group indices.
        targets = getattr(base, 'targets', None)
        if targets is None:
            targets = [label for _, label in base]
        self.targets = list(targets)
        self.indices_by_class = {}
        for index, label in enumerate(self.targets):
            self.indices_by_class.setdefault(label, []).append(index)
        if len(self.indices_by_class) < 2:
            raise ValueError("triplet sampling needs at least two classes")

    def __len__(self):
        return len(self.base)

    @staticmethod
    def _pick(options):
        """Return one uniformly random element of the non-empty list ``options``."""
        return options[np.random.randint(len(options))]

    def __getitem__(self, index):
        anchor, label = self.base[index]
        key = self.targets[index]
        same_class = self.indices_by_class[key]
        positive_pool = [i for i in same_class if i != index] or same_class
        negative_key = self._pick([c for c in self.indices_by_class if c != key])
        positive, _ = self.base[self._pick(positive_pool)]
        negative, _ = self.base[self._pick(self.indices_by_class[negative_key])]
        return anchor, positive, negative, label


# Training setup
def main():
    """Train TripletViT with triplet + cross-entropy loss, then test the best checkpoint.

    Reads ImageFolder splits ``train/``, ``val/`` and ``test/`` under
    ``path_data``. Saves the state dict with the best weighted validation F1 to
    ``best_model.pth``, then reloads it and prints test metrics.
    """
    # Hyperparameters
    batch_size = 32
    num_epochs = 50
    learning_rate = 1e-4

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load data with the module-level ``transform`` pipeline.
    path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
    train_data = datasets.ImageFolder(root=path_data + '/train/', transform=transform)
    val_data = datasets.ImageFolder(root=path_data + '/val/', transform=transform)
    test_data = datasets.ImageFolder(root=path_data + '/test/', transform=transform)

    # train_epoch/validate unpack (anchor, positive, negative, label) batches,
    # so train/val are wrapped in a triplet sampler; test() takes plain
    # (image, label) batches.
    train_loader = DataLoader(_TripletImageFolder(train_data), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(_TripletImageFolder(val_data), batch_size=batch_size)
    test_loader = DataLoader(test_data, batch_size=batch_size)

    # Model: one logit per class folder found in the training split.
    num_classes = len(train_data.classes)
    model = TripletViT(num_classes).to(device)

    # Loss and optimizer
    criterion_triplet = TripletLoss()
    criterion_ce = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Start below any achievable F1 so the first epoch always writes a
    # checkpoint; otherwise an F1 stuck at 0.0 leaves nothing for torch.load.
    best_val_f1 = float('-inf')
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion_triplet, criterion_ce, optimizer, device
        )

        # Validate
        val_metrics = validate(model, val_loader, criterion_triplet, criterion_ce, device)

        # Print metrics
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}")
        print(f"Val F1: {val_metrics['f1']:.4f}")

        # Save best model
        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            torch.save(model.state_dict(), 'best_model.pth')

    # Test best model (map_location lets a GPU-saved file load on CPU).
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    test_metrics = test(model, test_loader, device)
    print("\nTest Results:")
    print(f"Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Precision: {test_metrics['precision']:.4f}")
    print(f"Recall: {test_metrics['recall']:.4f}")
    print(f"F1 Score: {test_metrics['f1']:.4f}")

# Entry point: run the full train / validate / test pipeline defined in main().
if __name__ == "__main__":
    main()

# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from transformers import ViTModel
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
from tqdm import tqdm

class FewShotTripletViT(nn.Module):
    """Prototypical few-shot classifier on L2-normalised ViT [CLS] embeddings.

    Each class prototype is the mean support embedding of that class. Query
    scores are dot products between query embeddings and the prototypes.

    Args:
        num_classes: Unused; kept for call-site compatibility.
        n_way, k_shot: Stored for reference only. ``forward`` reads the real
            episode sizes from its input shapes.
    """

    # Maximum images per backbone call; bounds peak memory when a whole
    # episode is encoded (notably under torch.no_grad()).
    embed_chunk_size = 64

    def __init__(self, num_classes, n_way=15, k_shot=15):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.embedding = nn.Linear(768, 512)
        self.n_way = n_way
        self.k_shot = k_shot

    def forward_one(self, x):
        """Embed an image batch (presumably (B, C, H, W)) into unit-norm (B, 512) vectors."""
        x = self.vit(x).last_hidden_state[:, 0, :]
        embedding = self.embedding(x)
        return F.normalize(embedding, p=2, dim=1)  # L2 normalization

    def _embed(self, images, lead_dims):
        """Embed ``images`` whose first ``lead_dims`` axes index individual images.

        Returns a tensor of shape ``images.shape[:lead_dims] + (embed_dim,)``.
        """
        lead_shape = images.shape[:lead_dims]
        flat = images.reshape(-1, *images.shape[lead_dims:])
        chunks = [self.forward_one(chunk) for chunk in flat.split(self.embed_chunk_size)]
        return torch.cat(chunks).view(*lead_shape, -1)

    def forward(self, support_set, query):
        """Score each query image against each class prototype.

        Accepts either one episode or a DataLoader batch of episodes:
        - One episode: ``support_set`` is (n_way, k_shot, C, H, W) and
          ``query`` is (n_query, C, H, W). Returns (n_query, n_way) scores.
        - Batched: both inputs carry a leading batch axis. Returns
          (batch, n_query, n_way).
        """
        batched = support_set.dim() == 6
        if not batched:
            support_set, query = support_set.unsqueeze(0), query.unsqueeze(0)

        # (B, n_way, k_shot, D) -> mean over shots -> (B, n_way, D)
        prototypes = self._embed(support_set, 3).mean(dim=2)
        query_embeddings = self._embed(query, 2)  # (B, n_query, D)
        similarities = torch.bmm(query_embeddings, prototypes.transpose(1, 2))
        return similarities if batched else similarities.squeeze(0)

class FewShotDataset(Dataset):
    """Sample random N-way K-shot episodes from a labelled ``(image, label)`` dataset.

    Each item is ``(support_images, query_images, query_labels)``:
    - ``support_images``: (n_way, k_shot, *image_shape).
    - ``query_images``: (n_way * n_query, *image_shape).
    - ``query_labels``: the episode-local class index (0..n_way-1) of each query.

    Episodes are drawn with ``np.random`` and ignore the requested index;
    ``len()`` is ``n_episodes``. Only classes with at least
    ``k_shot + n_query`` samples can be selected.

    Raises:
        ValueError: if fewer than ``n_way`` classes have enough samples.
    """

    def __init__(self, dataset, n_way=15, k_shot=15, n_query=15, n_episodes=1000):
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.n_episodes = n_episodes

        # Group by label. ImageFolder-style datasets expose ``targets``; using
        # it avoids loading and transforming every image just to read labels.
        targets = getattr(dataset, 'targets', None)
        if targets is None:
            targets = [label for _, label in dataset]
        self.label_to_indices = {}
        for index, label in enumerate(targets):
            self.label_to_indices.setdefault(label, []).append(index)

        # Sampling without replacement needs k_shot + n_query images per class.
        needed = k_shot + n_query
        self.eligible_labels = [
            label for label, indices in self.label_to_indices.items()
            if len(indices) >= needed
        ]
        if len(self.eligible_labels) < n_way:
            raise ValueError(
                f"need {n_way} classes with at least {needed} samples each, "
                f"found {len(self.eligible_labels)}"
            )

    def __len__(self):
        return self.n_episodes

    def __getitem__(self, idx):
        # ``idx`` is ignored: every call draws a fresh random episode.
        picks = np.random.choice(len(self.eligible_labels), self.n_way, replace=False)

        support_images = []
        query_images = []
        query_labels = []
        for episode_label, pick in enumerate(picks):
            chosen = np.random.choice(
                self.label_to_indices[self.eligible_labels[pick]],
                self.k_shot + self.n_query,
                replace=False
            )
            # Appending class by class already yields (n_way, k_shot) order, so
            # no sort is needed. Sorting (label, tensor) pairs would compare
            # tensors on label ties, which raises.
            support_images.extend(self.dataset[i][0] for i in chosen[:self.k_shot])
            for i in chosen[self.k_shot:]:
                query_images.append(self.dataset[i][0])
                query_labels.append(episode_label)

        support = torch.stack(support_images).view(
            self.n_way, self.k_shot, *support_images[0].shape
        )
        return support, torch.stack(query_images), torch.tensor(query_labels)

def train_epoch(model, train_loader, optimizer, device):
    """Run one epoch of episodic training with cross-entropy on query scores.

    Handles single episodes (scores (n_query, n_way), labels (n_query,)) and
    DataLoader-batched episodes (scores (B, n_query, n_way), labels
    (B, n_query)). Leading axes are flattened before the loss.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over loader iterations.

    Raises:
        ValueError: if ``train_loader`` yields nothing.
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    episodes = 0

    for support_images, query_images, query_labels in tqdm(train_loader, desc="Training"):
        support_images = support_images.to(device)
        query_images = query_images.to(device)
        query_labels = query_labels.to(device)

        # Forward pass; cross_entropy needs (N, C) scores with (N,) targets.
        scores = model(support_images, query_images)
        scores = scores.reshape(-1, scores.size(-1))
        query_labels = query_labels.reshape(-1)
        loss = F.cross_entropy(scores, query_labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics
        pred = scores.argmax(dim=1)
        total_loss += loss.item()
        total_acc += (pred == query_labels).float().mean().item()
        episodes += 1

    if episodes == 0:
        raise ValueError("train_loader yielded no episodes")
    return total_loss / episodes, total_acc / episodes

def evaluate(model, data_loader, device, mode="val"):
    """Evaluate an episodic model on ``data_loader`` without updating weights.

    Scores (..., n_way) and labels (...) are flattened, so single and
    DataLoader-batched episodes both work. ``mode`` labels the progress bar.
    With FewShotDataset episodes the labels are episode-local indices, so
    precision/recall/F1 are weighted over episode slots, not dataset classes.

    Returns:
        dict with ``loss`` (mean over loader iterations) and ``accuracy``,
        ``precision``, ``recall``, ``f1`` over all query predictions.

    Raises:
        ValueError: if ``data_loader`` yields nothing.
    """
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []
    episodes = 0

    with torch.no_grad():
        for support_images, query_images, query_labels in tqdm(data_loader, desc=mode.capitalize()):
            support_images = support_images.to(device)
            query_images = query_images.to(device)
            query_labels = query_labels.to(device).reshape(-1)

            # Forward pass, flattened to (N, n_way) for cross_entropy.
            scores = model(support_images, query_images)
            scores = scores.reshape(-1, scores.size(-1))
            loss = F.cross_entropy(scores, query_labels)

            all_preds.extend(scores.argmax(dim=1).cpu().numpy())
            all_labels.extend(query_labels.cpu().numpy())
            total_loss += loss.item()
            episodes += 1

    if episodes == 0:
        raise ValueError(f"{mode} loader yielded no episodes")

    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 matches sklearn's fallback for undefined scores while
    # suppressing the per-call UndefinedMetricWarning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', zero_division=0
    )

    return {
        'loss': total_loss / episodes,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

def main():
    """Train FewShotTripletViT on random episodes and test the best checkpoint.

    Episodes come from ``FewShotDataset`` over the ImageFolder splits under
    ``path_data``. The checkpoint with the highest weighted validation F1 is
    written to ``best_model.pth`` and reloaded for the final test.
    """
    # Hyperparameters
    n_way = 15
    k_shot = 15
    n_query = 15
    n_episodes = 1000
    batch_size = 4
    num_epochs = 50
    learning_rate = 1e-4

    # Transform
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    # Datasets
    path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
    train_dataset = FewShotDataset(
        datasets.ImageFolder(path_data+'/train/', transform=transform),
        n_way=n_way, k_shot=k_shot, n_query=n_query, n_episodes=n_episodes
    )
    val_dataset = FewShotDataset(
        datasets.ImageFolder(path_data+'/val/', transform=transform),
        n_way=n_way, k_shot=k_shot, n_query=n_query, n_episodes=200
    )
    test_dataset = FewShotDataset(
        datasets.ImageFolder(path_data+'/test/', transform=transform),
        n_way=n_way, k_shot=k_shot, n_query=n_query, n_episodes=200
    )

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = FewShotTripletViT(num_classes=n_way, n_way=n_way, k_shot=k_shot).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop. Starting at -inf guarantees the first epoch writes a
    # checkpoint. Starting at 0, an F1 stuck at 0.0 would leave
    # best_model.pth unwritten, or stale from another cell that also saves it.
    best_val_f1 = float('-inf')
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, device)

        # Validate
        val_metrics = evaluate(model, val_loader, device, mode="val")

        # Print metrics
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}")
        print(f"Val F1: {val_metrics['f1']:.4f}")

        # Save best model
        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain float: sklearn returns a NumPy scalar, which
                # torch.load's default weights_only mode (torch >= 2.6)
                # refuses to unpickle.
                'val_f1': float(best_val_f1),
            }, 'best_model.pth')

    # Test best model (map_location lets a GPU-saved file load on CPU).
    checkpoint = torch.load('best_model.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    test_metrics = evaluate(model, test_loader, device, mode="test")

    print("\nFinal Test Results:")
    print(f"Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Precision: {test_metrics['precision']:.4f}")
    print(f"Recall: {test_metrics['recall']:.4f}")
    print(f"F1 Score: {test_metrics['f1']:.4f}")

# Entry point: run episodic training, validation and the final test in main().
if __name__ == "__main__":
    main()

# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import ViTModel
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

class InnovativeFewShotViT(nn.Module):
    """Prototypical few-shot ViT with an auxiliary triplet loss on the support set.

    Embeddings are L2-normalised projections of the ViT [CLS] state. Logits are
    negative Euclidean distances from each query to each class prototype (the
    mean support embedding of that class).
    """

    # Maximum images per backbone call; bounds peak memory when a whole
    # episode is encoded (notably under torch.no_grad()).
    embed_chunk_size = 64

    def __init__(self, n_way=15, k_shot=15):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.embedding = nn.Linear(768, 512)
        self.n_way = n_way
        self.k_shot = k_shot
        self.margin = 1.0

    def forward_one(self, x):
        """Embed an image batch (presumably (B, C, H, W)) into unit-norm (B, 512) vectors."""
        x = self.vit(x).last_hidden_state[:, 0, :]
        embedding = self.embedding(x)
        return F.normalize(embedding, p=2, dim=1)

    def _embed(self, images, lead_dims):
        """Embed ``images`` whose first ``lead_dims`` axes index individual images.

        Returns a tensor of shape ``images.shape[:lead_dims] + (embed_dim,)``.
        """
        lead_shape = images.shape[:lead_dims]
        flat = images.reshape(-1, *images.shape[lead_dims:])
        chunks = [self.forward_one(chunk) for chunk in flat.split(self.embed_chunk_size)]
        return torch.cat(chunks).view(*lead_shape, -1)

    def _support_triplet_loss(self, support_embeddings):
        """Mean hinge triplet loss over support embeddings (B, n_way, k_shot, D).

        For class i and shot j >= 1, the triplet is:
        - anchor: shot j of class i;
        - positive: shot j-1 of class i;
        - negative: shot j of class (i + 1) % n_way.
        Returns a zero scalar when n_way < 2 or k_shot < 2 (no valid triplets).
        """
        n_way, k_shot = support_embeddings.shape[1:3]
        if n_way < 2 or k_shot < 2:
            return support_embeddings.new_zeros(())
        dim = support_embeddings.size(-1)
        anchor = support_embeddings[:, :, 1:].reshape(-1, dim)
        positive = support_embeddings[:, :, :-1].reshape(-1, dim)
        # roll(-1) on the class axis places class (i + 1) % n_way at position i.
        negative = support_embeddings.roll(-1, dims=1)[:, :, 1:].reshape(-1, dim)
        d_pos = F.pairwise_distance(anchor, positive)
        d_neg = F.pairwise_distance(anchor, negative)
        # Averaged rather than summed, so the scale does not grow with
        # n_way * k_shot relative to the cross-entropy term.
        return F.relu(d_pos - d_neg + self.margin).mean()

    def forward(self, support_set, query, mode='train'):
        """Compute prototype logits for ``query``.

        Accepts one episode or a DataLoader batch of episodes:
        - One episode: ``support_set`` is (n_way, k_shot, C, H, W) and
          ``query`` is (n_query, C, H, W). Logits are (n_query, n_way).
        - Batched: both inputs carry a leading batch axis. Logits are
          (B, n_query, n_way).

        Returns ``(logits, triplet_loss)`` when ``mode == 'train'``, otherwise
        just ``logits``.
        """
        batched = support_set.dim() == 6
        if not batched:
            support_set, query = support_set.unsqueeze(0), query.unsqueeze(0)

        # Each support image is encoded once and reused for both the
        # prototypes and the triplet terms.
        support_embeddings = self._embed(support_set, 3)     # (B, n_way, k_shot, D)
        prototypes = support_embeddings.mean(dim=2)          # (B, n_way, D)
        query_embeddings = self._embed(query, 2)             # (B, n_query, D)
        logits = -torch.cdist(query_embeddings, prototypes)  # (B, n_query, n_way)
        if not batched:
            logits = logits.squeeze(0)

        if mode == 'train':
            return logits, self._support_triplet_loss(support_embeddings)
        return logits

class FewShotDataset:
    """Sample random N-way K-shot episodes from a labelled ``(image, label)`` dataset.

    Each item is ``(support_images, query_images, query_labels)``:
    - ``support_images``: (n_way, k_shot, *image_shape).
    - ``query_images``: (n_way * n_query, *image_shape).
    - ``query_labels``: the episode-local class index (0..n_way-1) of each query.

    Episodes are drawn with ``np.random`` and ignore the requested index;
    ``len()`` is ``n_episodes``. Only classes with at least
    ``k_shot + n_query`` samples can be selected.

    Raises:
        ValueError: if fewer than ``n_way`` classes have enough samples.
    """

    def __init__(self, dataset, n_way=15, k_shot=15, n_query=5, n_episodes=1000):
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.n_episodes = n_episodes

        # Group by label. ImageFolder-style datasets expose ``targets``; using
        # it avoids loading and transforming every image just to read labels.
        targets = getattr(dataset, 'targets', None)
        if targets is None:
            targets = [label for _, label in dataset]
        self.label_to_indices = {}
        for index, label in enumerate(targets):
            self.label_to_indices.setdefault(label, []).append(index)

        # Sampling without replacement needs k_shot + n_query images per class.
        needed = k_shot + n_query
        self.eligible_labels = [
            label for label, indices in self.label_to_indices.items()
            if len(indices) >= needed
        ]
        if len(self.eligible_labels) < n_way:
            raise ValueError(
                f"need {n_way} classes with at least {needed} samples each, "
                f"found {len(self.eligible_labels)}"
            )

    def __len__(self):
        return self.n_episodes

    def __getitem__(self, idx):
        # ``idx`` is ignored: every call draws a fresh random episode.
        picks = np.random.choice(len(self.eligible_labels), self.n_way, replace=False)

        support_images = []
        query_images = []
        query_labels = []
        for episode_label, pick in enumerate(picks):
            chosen = np.random.choice(
                self.label_to_indices[self.eligible_labels[pick]],
                self.k_shot + self.n_query,
                replace=False
            )
            # Appending class by class already yields (n_way, k_shot) order, so
            # no sort is needed. Sorting (label, tensor) pairs would compare
            # tensors on label ties, which raises.
            support_images.extend(self.dataset[i][0] for i in chosen[:self.k_shot])
            for i in chosen[self.k_shot:]:
                query_images.append(self.dataset[i][0])
                query_labels.append(episode_label)

        support = torch.stack(support_images).view(
            self.n_way, self.k_shot, *support_images[0].shape
        )
        return support, torch.stack(query_images), torch.tensor(query_labels)

def train_epoch(model, train_loader, optimizer, device):
    """Run one epoch of episodic training on ``cross_entropy + 0.5 * triplet``.

    ``model(support, query, mode='train')`` must return ``(logits,
    triplet_loss)``. Logits (..., n_way) and labels (...) are flattened, so
    single and DataLoader-batched episodes both work.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over loader iterations.

    Raises:
        ValueError: if ``train_loader`` yields nothing.
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    episodes = 0

    for support_images, query_images, query_labels in tqdm(train_loader):
        support_images = support_images.to(device)
        query_images = query_images.to(device)
        query_labels = query_labels.to(device)

        logits, triplet_loss = model(support_images, query_images, mode='train')
        # cross_entropy needs (N, C) logits with (N,) targets.
        logits = logits.reshape(-1, logits.size(-1))
        query_labels = query_labels.reshape(-1)
        ce_loss = F.cross_entropy(logits, query_labels)

        # Combined loss
        loss = ce_loss + 0.5 * triplet_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pred = logits.argmax(dim=1)
        total_loss += loss.item()
        total_acc += (pred == query_labels).float().mean().item()
        episodes += 1

    if episodes == 0:
        raise ValueError("train_loader yielded no episodes")
    return total_loss / episodes, total_acc / episodes

def evaluate(model, data_loader, device, mode="val"):
    """Score ``model(support, query, mode='test')`` predictions over ``data_loader``.

    ``mode`` labels the progress bar. Logits (..., n_way) and labels (...)
    are flattened, so single and DataLoader-batched episodes both work. With
    FewShotDataset episodes the labels are episode-local indices, so the
    weighted metrics are over episode slots, not dataset classes.

    Returns:
        dict with ``accuracy`` and weighted ``precision``, ``recall``, ``f1``.

    Raises:
        ValueError: if ``data_loader`` yields no query predictions.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for support_images, query_images, query_labels in tqdm(data_loader, desc=mode.capitalize()):
            support_images = support_images.to(device)
            query_images = query_images.to(device)

            logits = model(support_images, query_images, mode='test')
            pred = logits.reshape(-1, logits.size(-1)).argmax(dim=1)

            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(query_labels.reshape(-1).cpu().numpy())

    if not all_preds:
        raise ValueError(f"{mode} loader yielded no query predictions")

    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 matches sklearn's fallback for undefined scores while
    # suppressing the per-call UndefinedMetricWarning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', zero_division=0
    )

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

def main():
    """Train InnovativeFewShotViT episodically, then test the best checkpoint.

    Builds ``FewShotDataset`` episodes from the ImageFolder splits under
    ``path_data``. Keeps the checkpoint with the best weighted validation F1 in
    ``best_model.pth``, then reloads it for the test split.
    """
    # Only ``transforms`` is imported at the top of this cell.
    from torchvision import datasets

    # Hyperparameters
    n_way = 15
    k_shot = 15
    n_query = 5
    batch_size = 2
    num_epochs = 50
    learning_rate = 1e-4

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    # Datasets: ImageFolder splits under the same root the other cells use.
    path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
    train_dataset = FewShotDataset(
        datasets.ImageFolder(path_data + '/train/', transform=transform),
        n_way=n_way, k_shot=k_shot, n_query=n_query
    )
    val_dataset = FewShotDataset(
        datasets.ImageFolder(path_data + '/val/', transform=transform),
        n_way=n_way, k_shot=k_shot, n_query=n_query
    )
    test_dataset = FewShotDataset(
        datasets.ImageFolder(path_data + '/test/', transform=transform),
        n_way=n_way, k_shot=k_shot, n_query=n_query
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = InnovativeFewShotViT(n_way=n_way, k_shot=k_shot).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # -inf guarantees the first epoch writes a checkpoint, so torch.load below
    # never finds a missing or stale file.
    best_val_f1 = float('-inf')
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, device)
        val_metrics = evaluate(model, val_loader, device, mode="val")

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Acc: {val_metrics['accuracy']:.4f}, Val F1: {val_metrics['f1']:.4f}")

        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain float: a NumPy scalar is rejected by torch.load's
                # default weights_only mode (torch >= 2.6).
                'val_f1': float(best_val_f1),
            }, 'best_model.pth')

    checkpoint = torch.load('best_model.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    test_metrics = evaluate(model, test_loader, device, mode="test")

    print("\nTest Results:")
    print(f"Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Precision: {test_metrics['precision']:.4f}")
    print(f"Recall: {test_metrics['recall']:.4f}")
    print(f"F1 Score: {test_metrics['f1']:.4f}")

# Entry point: run episodic training, validation and the final test in main().
if __name__ == "__main__":
    main()

# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel
import numpy as np

class MultiHeadAttention(nn.Module):
    """Self-attention over batch-first sequences, wrapping ``nn.MultiheadAttention``.

    Accepts ``(N, L, E)`` (batch, sequence, embedding) and returns the same
    shape. A 2-D ``(L, E)`` input is treated as one unbatched sequence and
    returns ``(L, E)``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        if x.dim() == 2:
            # Unbatched (L, E) is accepted by nn.MultiheadAttention as-is;
            # transposing it would swap sequence and embedding axes.
            attn_output, _ = self.mha(x, x, x)
            return attn_output
        x = x.transpose(0, 1)  # (N, L, E) -> (L, N, E)
        attn_output, _ = self.mha(x, x, x)
        return attn_output.transpose(0, 1)  # (L, N, E) -> (N, L, E)

class InnovativeFewShotViT(nn.Module):
    """Few-shot ViT whose support and query embeddings pass through self-attention.

    Prototypes are per-class means of the attended support embeddings. Logits
    are negative Euclidean distances from the attended queries to the
    prototypes. In train mode an auxiliary triplet loss over the support set
    is also returned.

    NOTE(review): ``query_attention`` attends across all queries of an episode,
    so each query's logits depend on the other queries (transductive).
    """

    # Maximum images per backbone call; bounds peak memory when a whole
    # episode is encoded (notably under torch.no_grad()).
    embed_chunk_size = 64

    def __init__(self, n_way=15, k_shot=15, embed_dim=512, num_heads=8):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.embedding = nn.Linear(768, embed_dim)

        # Multi-head attention for support set processing
        self.support_attention = MultiHeadAttention(embed_dim, num_heads)
        self.query_attention = MultiHeadAttention(embed_dim, num_heads)

        # Few-shot parameters (stored for reference; forward reads the real
        # episode sizes from its inputs)
        self.n_way = n_way
        self.k_shot = k_shot
        self.margin = 1.0

    def forward_one(self, x):
        """Embed an image batch (presumably (B, C, H, W)) into unit-norm [CLS] projections."""
        x = self.vit(x).last_hidden_state[:, 0, :]  # Get [CLS] token
        embedding = self.embedding(x)
        return F.normalize(embedding, p=2, dim=1)

    def _embed(self, images, lead_dims):
        """Embed ``images`` whose first ``lead_dims`` axes index individual images.

        Returns a tensor of shape ``images.shape[:lead_dims] + (embed_dim,)``.
        """
        lead_shape = images.shape[:lead_dims]
        flat = images.reshape(-1, *images.shape[lead_dims:])
        chunks = [self.forward_one(chunk) for chunk in flat.split(self.embed_chunk_size)]
        return torch.cat(chunks).view(*lead_shape, -1)

    def get_prototypes(self, support_embeddings):
        """Attend jointly over all support embeddings, then average per class.

        Args:
            support_embeddings: (n_way, k_shot, D) or batched (B, n_way, k_shot, D).

        Returns:
            Prototypes of shape (n_way, D), or (B, n_way, D) for batched input.
        """
        batched = support_embeddings.dim() == 4
        if not batched:
            support_embeddings = support_embeddings.unsqueeze(0)
        # Sizes come from the tensor, not self.n_way/self.k_shot, so any
        # episode shape works.
        b, n_way, k_shot, dim = support_embeddings.shape
        # One batch-first sequence of n_way * k_shot tokens per episode.
        attended = self.support_attention(support_embeddings.reshape(b, n_way * k_shot, dim))
        # reshape, not view: the attention wrapper may return a transposed
        # (non-contiguous) tensor.
        prototypes = attended.reshape(b, n_way, k_shot, dim).mean(dim=2)
        return prototypes if batched else prototypes.squeeze(0)

    def _support_triplet_loss(self, support_embeddings):
        """Mean hinge triplet loss over support embeddings (B, n_way, k_shot, D).

        For class i and shot j >= 1, the triplet is:
        - anchor: shot j of class i;
        - positive: shot j-1 of class i;
        - negative: shot j of class (i + 1) % n_way.
        Returns a zero scalar when n_way < 2 or k_shot < 2 (no valid triplets).
        """
        n_way, k_shot = support_embeddings.shape[1:3]
        if n_way < 2 or k_shot < 2:
            return support_embeddings.new_zeros(())
        dim = support_embeddings.size(-1)
        anchor = support_embeddings[:, :, 1:].reshape(-1, dim)
        positive = support_embeddings[:, :, :-1].reshape(-1, dim)
        # roll(-1) on the class axis places class (i + 1) % n_way at position i.
        negative = support_embeddings.roll(-1, dims=1)[:, :, 1:].reshape(-1, dim)
        d_pos = F.pairwise_distance(anchor, positive)
        d_neg = F.pairwise_distance(anchor, negative)
        # Averaged rather than summed, so the scale does not grow with
        # n_way * k_shot relative to the cross-entropy term.
        return F.relu(d_pos - d_neg + self.margin).mean()

    def forward(self, support_set, query, mode='train'):
        """Compute attended prototype logits for ``query``.

        Accepts one episode or a DataLoader batch of episodes:
        - One episode: ``support_set`` is (n_way, k_shot, C, H, W) and
          ``query`` is (n_query, C, H, W). Logits are (n_query, n_way).
        - Batched: both inputs carry a leading batch axis. Logits are
          (B, n_query, n_way).

        Returns ``(logits, triplet_loss)`` when ``mode == 'train'``, otherwise
        just ``logits``.
        """
        batched = support_set.dim() == 6
        if not batched:
            support_set, query = support_set.unsqueeze(0), query.unsqueeze(0)

        # Each support image is encoded once and reused for prototypes and triplets.
        support_embeddings = self._embed(support_set, 3)  # (B, n_way, k_shot, D)
        prototypes = self.get_prototypes(support_embeddings)  # (B, n_way, D)

        # Batch-first (B, n_query, D) sequence per episode.
        attended_query = self.query_attention(self._embed(query, 2))
        logits = -torch.cdist(attended_query, prototypes)  # (B, n_query, n_way)
        if not batched:
            logits = logits.squeeze(0)

        if mode == 'train':
            return logits, self._support_triplet_loss(support_embeddings)
        return logits

# Loss function
class CombinedLoss(nn.Module):
    """Cross-entropy on logits plus a weighted auxiliary triplet term.

    ``forward(logits, labels, triplet_loss)`` returns
    ``CE(logits, labels) + triplet_weight * triplet_loss``.
    """

    def __init__(self, triplet_weight=0.5):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.triplet_weight = triplet_weight

    def forward(self, logits, labels, triplet_loss):
        weighted_triplet = self.triplet_weight * triplet_loss
        return self.ce(logits, labels) + weighted_triplet

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one epoch of episodic training with a combined CE + triplet criterion.

    ``model(support, query, mode='train')`` must return ``(logits,
    triplet_loss)``. ``criterion(logits, labels, triplet_loss)`` must return
    the scalar loss. Logits (..., n_way) and labels (...) are flattened, so
    single and DataLoader-batched episodes both work.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over loader iterations.

    Raises:
        ValueError: if ``train_loader`` yields nothing.
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    episodes = 0

    for support_images, query_images, query_labels in train_loader:
        support_images = support_images.to(device)
        query_images = query_images.to(device)
        query_labels = query_labels.to(device)

        # Forward pass
        logits, triplet_loss = model(support_images, query_images, mode='train')
        # Cross-entropy needs (N, C) logits with (N,) targets.
        logits = logits.reshape(-1, logits.size(-1))
        query_labels = query_labels.reshape(-1)

        # Calculate combined loss
        loss = criterion(logits, query_labels, triplet_loss)

        # Optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics
        pred = logits.argmax(dim=1)
        total_loss += loss.item()
        total_acc += (pred == query_labels).float().mean().item()
        episodes += 1

    if episodes == 0:
        raise ValueError("train_loader yielded no episodes")
    return total_loss / episodes, total_acc / episodes

# In [None]:
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel
import numpy as np
import torchvision
import sys
sys.path.insert(0,'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
from dataloaders import get_train_transforms, get_val_transforms, get_triplet_dataloader
class MultiHeadAttention(nn.Module):
    """Batch-first self-attention wrapper around ``nn.MultiheadAttention``.

    The wrapped layer is built with its default sequence-first layout, so
    inputs are transposed in and out.

    Args:
        embed_dim: feature size E of every token.
        num_heads: number of attention heads (must divide ``embed_dim``).
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        """Apply self-attention over the sequence dimension.

        Args:
            x: (N, L, E) batch of sequences, or (L, E) for a single sequence.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # A 2-D input used to be transposed to (E, L), which the wrapped layer
        # misreads as L-dimensional tokens; treat it as one sequence instead.
        unbatched = x.dim() == 2
        if unbatched:
            x = x.unsqueeze(0)
        x = x.transpose(0, 1)  # (N, L, E) -> (L, N, E)
        attn_output, _ = self.mha(x, x, x)
        out = attn_output.transpose(0, 1)  # (L, N, E) -> (N, L, E)
        return out.squeeze(0) if unbatched else out

class InnovativeFewShotViT(nn.Module):
    """Prototypical few-shot classifier built on a pretrained ViT backbone.

    Images are embedded from the ViT ``[CLS]`` token, projected to
    ``embed_dim`` and L2-normalised. The support set goes through a
    self-attention layer and is averaged per class into prototypes. The query
    embeddings go through a separate attention layer. Each query is scored by
    the negative Euclidean distance to every prototype.

    Args:
        n_way: classes per episode; used by ``get_prototypes`` when it is given
            already-flattened support embeddings.
        k_shot: support images per class; same use as ``n_way``.
        embed_dim: size of the projected embedding.
        num_heads: heads in the support and query attention layers.
    """

    def __init__(self, n_way=15, k_shot=15, embed_dim=512, num_heads=8):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.embedding = nn.Linear(768, embed_dim)

        # Separate self-attention layers for the support set and the queries.
        self.support_attention = MultiHeadAttention(embed_dim, num_heads)
        self.query_attention = MultiHeadAttention(embed_dim, num_heads)

        # Few-shot parameters
        self.n_way = n_way
        self.k_shot = k_shot
        self.margin = 1.0  # triplet-loss margin

    def forward_one(self, x):
        """Embed images ``x`` (B, C, H, W) into L2-normalised (B, embed_dim) vectors."""
        x = self.vit(x).last_hidden_state[:, 0, :]  # [CLS] token
        embedding = self.embedding(x)
        return F.normalize(embedding, p=2, dim=1)

    def get_prototypes(self, support_embeddings):
        """Return one prototype per class, shape (n_way, embed_dim).

        Args:
            support_embeddings: (n_way, k_shot, embed_dim), or any tensor that
                holds ``self.n_way * self.k_shot`` embeddings.
        """
        if support_embeddings.dim() == 3:
            n_way, k_shot = support_embeddings.shape[:2]
        else:
            n_way, k_shot = self.n_way, self.k_shot
        flat = support_embeddings.reshape(n_way * k_shot, -1)
        # The attention wrapper is batch-first (N, L, E): pass the whole
        # support set as a single sequence so every shot attends to every
        # other shot.
        attended = self.support_attention(flat.unsqueeze(0)).squeeze(0)
        return attended.reshape(n_way, k_shot, -1).mean(1)

    def _embed_support(self, support_set):
        """Embed a (n_way, k_shot, C, H, W) support set in one backbone pass."""
        n_way, k_shot = support_set.shape[:2]
        images = support_set.reshape(n_way * k_shot, *support_set.shape[2:])
        return self.forward_one(images).reshape(n_way, k_shot, -1)

    def _triplet_loss(self, support_embeddings):
        """Summed triplet loss over consecutive shots of each class.

        For class ``i`` and shot ``j >= 1`` the anchor is shot ``j``, the
        positive is shot ``j - 1`` and the negative is shot ``j`` of class
        ``(i + 1) % n_way``. Negatives reuse the embeddings already computed
        for this episode instead of re-running the backbone.
        """
        dim = support_embeddings.size(-1)
        anchors = support_embeddings[:, 1:].reshape(-1, dim)
        positives = support_embeddings[:, :-1].reshape(-1, dim)
        # roll(-1) moves class (i + 1) % n_way into row i.
        negatives = support_embeddings.roll(-1, dims=0)[:, 1:].reshape(-1, dim)
        d_pos = F.pairwise_distance(anchors, positives)
        d_neg = F.pairwise_distance(anchors, negatives)
        return F.relu(d_pos - d_neg + self.margin).sum()

    def _query_logits(self, query, prototypes):
        """Score queries (Q, C, H, W) against prototypes; returns (Q, n_way) logits."""
        query_embedding = self.forward_one(query)
        # Queries attend to each other as one batch-first sequence.
        attended_query = self.query_attention(query_embedding.unsqueeze(0)).squeeze(0)
        return -torch.cdist(attended_query, prototypes)

    def forward(self, support_set, query, mode='train'):
        """Classify ``query`` images against the episode's ``support_set``.

        Args:
            support_set: (n_way, k_shot, C, H, W) support images.
            query: (Q, C, H, W) query images.
            mode: ``'train'`` also returns the triplet loss; any other value
                returns only the logits.

        Returns:
            ``(logits, triplet_loss)`` in train mode, otherwise ``logits``;
            logits have shape (Q, n_way).
        """
        support_embeddings = self._embed_support(support_set)
        prototypes = self.get_prototypes(support_embeddings)
        logits = self._query_logits(query, prototypes)
        if mode == 'train':
            return logits, self._triplet_loss(support_embeddings)
        return logits

# Loss function
class CombinedLoss(nn.Module):
    """Cross-entropy on the query logits plus a weighted triplet term.

    Args:
        triplet_weight: multiplier applied to the triplet loss before it is
            added to the cross-entropy.
    """

    def __init__(self, triplet_weight=0.5):
        super(CombinedLoss, self).__init__()
        self.triplet_weight = triplet_weight
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, labels, triplet_loss):
        """Return ``CE(logits, labels) + triplet_weight * triplet_loss``."""
        classification = self.ce(logits, labels)
        weighted_triplet = self.triplet_weight * triplet_loss
        return classification + weighted_triplet

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over a loader of few-shot episodes.

    Args:
        model: called as ``model(support, query, mode='train')``; expected to
            return ``(logits, triplet_loss)``.
        train_loader: iterable of ``(support_images, query_images, query_labels)``.
        criterion: called as ``criterion(logits, query_labels, triplet_loss)``.
        optimizer: optimiser stepping the model's parameters.
        device: device every batch tensor is moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over episodes.

    Raises:
        ValueError: if ``train_loader`` yields no episodes.
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    episodes = 0

    for support_images, query_images, query_labels in train_loader:
        support_images = support_images.to(device)
        query_images = query_images.to(device)
        query_labels = query_labels.to(device)

        logits, triplet_loss = model(support_images, query_images, mode='train')
        loss = criterion(logits, query_labels, triplet_loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy is a per-episode mean, so every episode weighs equally in
        # the epoch average regardless of its query count.
        acc = (logits.argmax(dim=1) == query_labels).float().mean()
        total_loss += loss.item()
        total_acc += acc.item()
        episodes += 1

    if episodes == 0:
        # Averaging over zero episodes would otherwise surface as an opaque
        # ZeroDivisionError.
        raise ValueError("train_loader yielded no episodes")
    return total_loss / episodes, total_acc / episodes

class Trainer:
    """Train, validate and test a few-shot model over episodic loaders.

    Each loader batch is unpacked as ``(support_images, query_images,
    query_labels)``. The model is called as ``model(support, query, mode=...)``
    with ``mode='train'`` while training and ``mode='test'`` otherwise. It may
    return either ``logits`` or ``(logits, triplet_loss)``. Any leading
    batch/episode dimensions of logits and labels are flattened before the
    loss and metrics are computed.

    Args:
        model, train_loader, val_loader, test_loader, optimizer, device: the
            usual training components.
        criterion: called as ``criterion(logits, labels, triplet_loss)``.
        num_epochs: number of training epochs.
        checkpoint_path: file the best model (by validation F1) is saved to
            and reloaded from before testing.
    """

    def __init__(self, model, train_loader, val_loader, test_loader,
                 criterion, optimizer, device, num_epochs=50,
                 checkpoint_path='best_model.pth'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.num_epochs = num_epochs
        self.checkpoint_path = checkpoint_path

        # Start below any achievable F1 so the first epoch always writes a
        # checkpoint; starting at 0 left nothing (or a stale file) to load
        # when validation F1 never rose above 0.
        self.best_val_f1 = float('-inf')
        self.train_metrics = []
        self.val_metrics = []

    @staticmethod
    def _split_output(output):
        """Return ``(logits, triplet_loss)``; ``triplet_loss`` is None if the model returned only logits."""
        if isinstance(output, (tuple, list)):
            logits, triplet_loss = output
            return logits, triplet_loss
        return output, None

    @staticmethod
    def _flatten(logits, labels):
        """Reshape logits to (M, n_way) and labels to (M,), merging any leading dims."""
        return logits.reshape(-1, logits.size(-1)), labels.reshape(-1)

    def _run_episode(self, batch, mode):
        """Move a batch to the device, run the model and return flattened (logits, labels, triplet_loss)."""
        support_images, query_images, query_labels = (t.to(self.device) for t in batch)
        output = self.model(support_images, query_images, mode=mode)
        logits, triplet_loss = self._split_output(output)
        logits, query_labels = self._flatten(logits, query_labels)
        if triplet_loss is None:
            triplet_loss = logits.new_zeros(())
        return logits, query_labels, triplet_loss

    def train_epoch(self):
        """Run one training epoch; record and return its metrics."""
        self.model.train()
        total_loss = 0.0
        all_preds = []
        all_labels = []

        for batch in tqdm(self.train_loader):
            logits, query_labels, triplet_loss = self._run_episode(batch, mode='train')
            loss = self.criterion(logits, query_labels, triplet_loss)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            all_preds.extend(logits.argmax(dim=1).cpu().tolist())
            all_labels.extend(query_labels.cpu().tolist())
            total_loss += loss.item()

        metrics = self.calculate_metrics(all_labels, all_preds, total_loss, len(self.train_loader))
        self.train_metrics.append(metrics)
        return metrics

    def validate(self):
        """Evaluate on the validation loader; record and return its metrics."""
        self.model.eval()
        total_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader):
                logits, query_labels, _ = self._run_episode(batch, mode='test')
                # Validation loss is computed with a zero triplet term.
                loss = self.criterion(logits, query_labels, logits.new_zeros(()))

                all_preds.extend(logits.argmax(dim=1).cpu().tolist())
                all_labels.extend(query_labels.cpu().tolist())
                total_loss += loss.item()

        metrics = self.calculate_metrics(all_labels, all_preds, total_loss, len(self.val_loader))
        self.val_metrics.append(metrics)
        return metrics

    def test(self):
        """Evaluate on the test loader and return accuracy/precision/recall/F1."""
        self.model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(self.test_loader):
                logits, query_labels, _ = self._run_episode(batch, mode='test')
                all_preds.extend(logits.argmax(dim=1).cpu().tolist())
                all_labels.extend(query_labels.cpu().tolist())

        return self.calculate_metrics(all_labels, all_preds)

    def calculate_metrics(self, labels, preds, loss=None, n_batches=None):
        """Return accuracy and weighted precision/recall/F1; adds mean ``loss`` when given."""
        accuracy = accuracy_score(labels, preds)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='weighted', zero_division=0
        )

        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
        }

        if loss is not None:
            metrics['loss'] = loss / n_batches

        return metrics

    def plot_confusion_matrix(self, labels, preds, title='Confusion Matrix'):
        """Show a seaborn heatmap of the confusion matrix."""
        cm = confusion_matrix(labels, preds)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title(title)
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.show()

    def plot_metrics(self):
        """Plot train vs. validation accuracy, F1 and loss per epoch."""
        epochs = range(1, len(self.train_metrics) + 1)
        metrics = ['accuracy', 'f1', 'loss']

        plt.figure(figsize=(15, 5))
        for i, metric in enumerate(metrics):
            plt.subplot(1, 3, i+1)
            train_values = [m[metric] for m in self.train_metrics]
            val_values = [m[metric] for m in self.val_metrics]

            plt.plot(epochs, train_values, 'b-', label='Train')
            plt.plot(epochs, val_values, 'r-', label='Validation')
            plt.title(f'{metric.capitalize()} vs Epochs')
            plt.xlabel('Epochs')
            plt.ylabel(metric.capitalize())
            plt.legend()

        plt.tight_layout()
        plt.show()

    def train(self):
        """Train for ``num_epochs``, keep the best checkpoint by validation F1, then test it."""
        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch+1}/{self.num_epochs}")

            train_metrics = self.train_epoch()
            val_metrics = self.validate()

            print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, F1: {train_metrics['f1']:.4f}")
            print(f"Val   - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1']:.4f}")

            if val_metrics['f1'] > self.best_val_f1:
                self.best_val_f1 = val_metrics['f1']
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_f1': self.best_val_f1,
                }, self.checkpoint_path)

        self.plot_metrics()

        # map_location keeps the reload working if the checkpoint was written
        # on a different device than the one used now.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        test_metrics = self.test()

        print("\nTest Results:")
        print(f"Accuracy: {test_metrics['accuracy']:.4f}")
        print(f"Precision: {test_metrics['precision']:.4f}")
        print(f"Recall: {test_metrics['recall']:.4f}")
        print(f"F1 Score: {test_metrics['f1']:.4f}")

        return test_metrics

def main():
    """Wire up the model, optimiser, ImageFolder loaders and ``Trainer``, then train.

    NOTE(review): these loaders wrap plain ``ImageFolder`` datasets, whose
    batches are ``(images, labels)`` pairs. ``Trainer`` unpacks every batch as
    ``(support_images, query_images, query_labels)``, so training will fail on
    the first batch until an episodic dataset is plugged in here. ``n_query``
    is not used yet, and the train split is built with ``get_val_transforms()``
    rather than ``get_train_transforms()``; confirm both are intended.
    """
    # --- hyperparameters ---
    n_way, k_shot, n_query = 15, 15, 5
    batch_size, num_epochs, learning_rate = 2, 50, 1e-4

    # --- model and optimisation ---
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = InnovativeFewShotViT(n_way=n_way, k_shot=k_shot).to(device)
    criterion = CombinedLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # --- data ---
    path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'

    def split_loader(subdir, **loader_options):
        folder = torchvision.datasets.ImageFolder(root=path_data + subdir, transform=get_val_transforms())
        return DataLoader(folder, batch_size=batch_size, **loader_options)

    train_loader = split_loader('/train/', shuffle=True)
    val_loader = split_loader('/val/')
    test_loader = split_loader('/test/')

    # --- training, validation and final test ---
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=num_epochs,
    )
    test_metrics = trainer.train()


if __name__ == "__main__":
    main()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from transformers import ViTModel
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import sys
sys.path.insert(0,'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
 
import torchvision
from dataloaders import get_train_transforms, get_val_transforms, get_triplet_dataloader

class MultiHeadAttention(nn.Module):
    """Batch-first self-attention wrapper around ``nn.MultiheadAttention``.

    The wrapped layer is built with its default sequence-first layout, so
    inputs are transposed in and out.

    Args:
        embed_dim: feature size E of every token.
        num_heads: number of attention heads (must divide ``embed_dim``).
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        """Apply self-attention over the sequence dimension.

        Args:
            x: (N, L, E) batch of sequences, or (L, E) for a single sequence.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # A 2-D input used to be transposed to (E, L), which the wrapped layer
        # misreads as L-dimensional tokens; treat it as one sequence instead.
        unbatched = x.dim() == 2
        if unbatched:
            x = x.unsqueeze(0)
        x = x.transpose(0, 1)  # (N, L, E) -> (L, N, E)
        attn_output, _ = self.mha(x, x, x)
        out = attn_output.transpose(0, 1)  # (L, N, E) -> (N, L, E)
        return out.squeeze(0) if unbatched else out

class InnovativeFewShotViT(nn.Module):
    """Batched prototypical few-shot classifier on a pretrained ViT backbone.

    Works on a batch of episodes. Every support and query image is embedded
    from the ViT ``[CLS]`` token, projected to ``embed_dim`` and
    L2-normalised. Support embeddings go through self-attention and are
    averaged per class into prototypes. Each query is scored by the negative
    Euclidean distance to the prototypes of *its own* episode.

    Args:
        n_way: classes per episode; used by ``get_prototypes`` when it is given
            already-flattened support embeddings.
        k_shot: support images per class; same use as ``n_way``.
        embed_dim: size of the projected embedding.
        num_heads: heads in the support/query attention layers.
    """

    def __init__(self, n_way=15, k_shot=15, embed_dim=512, num_heads=8):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.embedding = nn.Linear(768, embed_dim)
        self.support_attention = MultiHeadAttention(embed_dim, num_heads)
        # Kept for checkpoint compatibility; queries are not attended here.
        self.query_attention = MultiHeadAttention(embed_dim, num_heads)
        self.n_way = n_way
        self.k_shot = k_shot
        self.margin = 1.0  # triplet margin; the triplet term is currently disabled

    def forward_one(self, x):
        """Embed a batch of images into L2-normalised (B, embed_dim) vectors.

        Accepts (B, C, H, W), a single image (C, H, W), or a tensor with extra
        leading dims (..., C, H, W), which are merged into the batch. If dim 1
        is not 3 channels the input is assumed channels-last and permuted.

        Raises:
            ValueError: if the input cannot be brought to (B, 3, H, W).
        """
        # Bring the input to a 4-D batch for the ViT.
        if x.dim() > 4:
            x = x.reshape(-1, x.size(-3), x.size(-2), x.size(-1))
        elif x.dim() == 3:
            x = x.unsqueeze(0)

        if x.dim() != 4:
            raise ValueError(f"نادرست shape ورودی: {x.shape}")

        # Convert channels-last (B, H, W, C) to channels-first.
        if x.size(1) != 3:
            x = x.permute(0, 3, 1, 2)

        if x.size(1) != 3:
            raise ValueError(f"نادرست shape ورودی: {x.shape}")

        x = self.vit(x).last_hidden_state[:, 0]  # [CLS] token
        embedding = self.embedding(x)
        return F.normalize(embedding, p=2, dim=1)

    def get_prototypes(self, support_embeddings):
        """Return per-episode class prototypes, shape (B, n_way, embed_dim).

        Args:
            support_embeddings: (B, n_way, k_shot, embed_dim), or any tensor
                with leading dim B holding ``self.n_way * self.k_shot``
                embeddings per episode.
        """
        batch_size = support_embeddings.size(0)
        if support_embeddings.dim() == 4:
            n_way, k_shot = support_embeddings.shape[1:3]
        else:
            n_way, k_shot = self.n_way, self.k_shot
        flat = support_embeddings.reshape(batch_size, n_way * k_shot, -1)
        # Batch-first attention: each episode's support set is one sequence.
        attended = self.support_attention(flat)
        return attended.reshape(batch_size, n_way, k_shot, -1).mean(2)

    def forward(self, support_set, query, mode='train'):
        """Classify each episode's queries against that episode's prototypes.

        Args:
            support_set: (B, n_way, k_shot, C, H, W) support images.
            query: (B, Q, C, H, W) query images.
            mode: ``'train'`` also returns a (currently zero) triplet loss; any
                other value returns only the logits.

        Returns:
            ``(logits, triplet_loss)`` in train mode, otherwise ``logits``;
            logits have shape (B * Q, n_way), episode-major.
        """
        batch_size, n_way, k_shot = support_set.shape[:3]

        # Embed every support image of every episode in one backbone pass.
        support_images = support_set.reshape(-1, *support_set.shape[-3:])
        support_embeddings = self.forward_one(support_images).reshape(batch_size, n_way, k_shot, -1)
        prototypes = self.get_prototypes(support_embeddings)  # (B, N, E)

        query_images = query.reshape(-1, *query.shape[-3:])
        query_emb = self.forward_one(query_images).reshape(batch_size, -1, prototypes.size(-1))

        # Batched cdist keeps each episode's queries against its own prototypes.
        logits = -torch.cdist(query_emb, prototypes)  # (B, Q, N)
        logits = logits.reshape(-1, n_way)

        if mode == 'train':
            # The triplet term is disabled for now.
            return logits, torch.zeros((), device=logits.device)
        return logits
        
        

class CombinedLoss(nn.Module):
    """Cross-entropy on the query logits plus a weighted triplet term.

    Args:
        triplet_weight: multiplier applied to the triplet loss before it is
            added to the cross-entropy.
    """

    def __init__(self, triplet_weight=0.5):
        super(CombinedLoss, self).__init__()
        self.triplet_weight = triplet_weight
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, labels, triplet_loss):
        """Return ``CE(logits, labels) + triplet_weight * triplet_loss``."""
        classification = self.ce(logits, labels)
        weighted_triplet = self.triplet_weight * triplet_loss
        return classification + weighted_triplet

class FewShotDataset:
    """Episodic N-way K-shot sampler over an ``ImageFolder`` directory.

    Each item is one randomly drawn episode (``episode_index`` is ignored):

    - ``support_images``: (n_way, k_shot, C, H, W)
    - ``query_images``: (n_way * n_query, C, H, W)
    - ``query_labels``: (n_way * n_query,), episode-local class indices
      0..n_way-1 in the order the classes were sampled.

    Only classes with at least ``k_shot + n_query`` images are eligible.

    Raises:
        ValueError: if fewer than ``n_way`` classes are eligible.
    """

    def __init__(self, root_dir, transform=None, n_way=15, k_shot=15, n_query=5, n_episodes=1000):
        self.dataset = datasets.ImageFolder(root_dir, transform=transform)
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.n_episodes = n_episodes

        self.label_to_indices = self._build_label_index(self.dataset)

        self.valid_classes = [
            cls for cls, indices in self.label_to_indices.items()
            if len(indices) >= self.k_shot + self.n_query
        ]

        if len(self.valid_classes) < self.n_way:
            raise ValueError(f"Not enough classes with sufficient samples. Found {len(self.valid_classes)} valid classes, need {self.n_way}")

    @staticmethod
    def _build_label_index(dataset):
        """Map each class label to the list of dataset indices that carry it."""
        # ImageFolder keeps every label in ``targets``. Reading it avoids
        # loading and transforming every image just to learn its class.
        # Datasets without it fall back to iterating (image, label) pairs.
        labels = getattr(dataset, 'targets', None)
        if labels is None:
            labels = (label for _, label in dataset)
        label_to_indices = {}
        for idx, label in enumerate(labels):
            label_to_indices.setdefault(label, []).append(idx)
        return label_to_indices

    def __len__(self):
        return self.n_episodes

    def __getitem__(self, episode_index):
        selected_classes = np.random.choice(self.valid_classes, self.n_way, replace=False)

        support_images = []
        query_images = []
        query_labels = []

        for class_idx, class_label in enumerate(selected_classes):
            chosen = np.random.choice(
                self.label_to_indices[class_label],
                self.k_shot + self.n_query,
                replace=False,
            )

            # Support set: the first k_shot draws. Stacking the loaded images
            # (instead of filling a fixed 3x224x224 buffer) supports any
            # image size the transform produces.
            support_images.append(torch.stack([self.dataset[i][0] for i in chosen[:self.k_shot]]))

            # Query set: the remaining n_query draws.
            for img_idx in chosen[self.k_shot:]:
                img, _ = self.dataset[img_idx]
                query_images.append(img)
                query_labels.append(class_idx)

        return torch.stack(support_images), torch.stack(query_images), torch.tensor(query_labels)


class Trainer:
    """Train, validate and test a few-shot model over episodic loaders.

    Each loader batch is unpacked as ``(support_images, query_images,
    query_labels)``. The model is called as ``model(support, query, mode=...)``
    with ``mode='train'`` while training and ``mode='test'`` otherwise. It may
    return either ``logits`` or ``(logits, triplet_loss)``. Leading
    batch/episode dimensions of logits and labels (e.g. labels of shape
    (batch, n_query_total) from a DataLoader) are flattened before the loss
    and metrics are computed.

    NOTE(review): ``FewShotDataset`` in this notebook emits episode-local labels
    (0..n_way-1 per sampled episode). Metrics and the confusion matrix pooled
    across episodes therefore do not map to fixed dataset classes; confirm
    this is intended.

    Args:
        model, train_loader, val_loader, test_loader, optimizer, device: the
            usual training components.
        criterion: called as ``criterion(logits, labels, triplet_loss)``.
        num_epochs: number of training epochs.
        checkpoint_path: file the best model (by validation F1) is saved to
            and reloaded from before testing.
    """

    def __init__(self, model, train_loader, val_loader, test_loader, criterion, optimizer, device, num_epochs=50,
                 checkpoint_path='best_model.pth'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.num_epochs = num_epochs
        self.checkpoint_path = checkpoint_path
        # Start below any achievable F1 so the first epoch always writes a
        # checkpoint; starting at 0 left nothing (or a stale file) to load
        # when validation F1 never rose above 0.
        self.best_val_f1 = float('-inf')
        self.train_metrics = []
        self.val_metrics = []

    @staticmethod
    def _split_output(output):
        """Return ``(logits, triplet_loss)``; ``triplet_loss`` is None if the model returned only logits."""
        if isinstance(output, (tuple, list)):
            logits, triplet_loss = output
            return logits, triplet_loss
        return output, None

    @staticmethod
    def _flatten(logits, labels):
        """Reshape logits to (M, n_way) and labels to (M,), merging any leading dims."""
        return logits.reshape(-1, logits.size(-1)), labels.reshape(-1)

    def _run_episode(self, batch, mode):
        """Move a batch to the device, run the model and return flattened (logits, labels, triplet_loss)."""
        support_images, query_images, query_labels = (t.to(self.device) for t in batch)
        output = self.model(support_images, query_images, mode=mode)
        logits, triplet_loss = self._split_output(output)
        logits, query_labels = self._flatten(logits, query_labels)
        if triplet_loss is None:
            triplet_loss = logits.new_zeros(())
        return logits, query_labels, triplet_loss

    def calculate_metrics(self, labels, preds, loss=None, n_batches=None):
        """Return accuracy and weighted precision/recall/F1; adds mean ``loss`` when given."""
        accuracy = accuracy_score(labels, preds)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='weighted', zero_division=0
        )
        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
        }
        if loss is not None:
            metrics['loss'] = loss / n_batches
        return metrics

    def train_epoch(self):
        """Run one training epoch and return its metrics."""
        self.model.train()
        total_loss = 0.0
        all_preds = []
        all_labels = []

        for batch in tqdm(self.train_loader):
            logits, query_labels, triplet_loss = self._run_episode(batch, mode='train')
            loss = self.criterion(logits, query_labels, triplet_loss)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            all_preds.extend(logits.argmax(dim=1).cpu().tolist())
            all_labels.extend(query_labels.cpu().tolist())
            total_loss += loss.item()

        return self.calculate_metrics(all_labels, all_preds, total_loss, len(self.train_loader))

    def validate(self):
        """Evaluate on the validation loader and return its metrics."""
        self.model.eval()
        total_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader):
                logits, query_labels, _ = self._run_episode(batch, mode='test')
                # Validation loss is computed with a zero triplet term.
                loss = self.criterion(logits, query_labels, logits.new_zeros(()))

                all_preds.extend(logits.argmax(dim=1).cpu().tolist())
                all_labels.extend(query_labels.cpu().tolist())
                total_loss += loss.item()

        return self.calculate_metrics(all_labels, all_preds, total_loss, len(self.val_loader))

    def test(self):
        """Evaluate on the test loader, plot its confusion matrix and return the metrics."""
        self.model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(self.test_loader):
                logits, query_labels, _ = self._run_episode(batch, mode='test')
                all_preds.extend(logits.argmax(dim=1).cpu().tolist())
                all_labels.extend(query_labels.cpu().tolist())

        metrics = self.calculate_metrics(all_labels, all_preds)
        self.plot_confusion_matrix(all_labels, all_preds, "Test Set Confusion Matrix")
        return metrics

    def plot_confusion_matrix(self, labels, preds, title):
        """Show a seaborn heatmap of the confusion matrix."""
        cm = confusion_matrix(labels, preds)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title(title)
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.show()

    def plot_metrics(self):
        """Plot train vs. validation accuracy, F1 and loss per epoch."""
        epochs = range(1, len(self.train_metrics) + 1)
        metrics = ['accuracy', 'f1', 'loss']

        plt.figure(figsize=(15, 5))
        for i, metric in enumerate(metrics):
            plt.subplot(1, 3, i+1)
            train_values = [m[metric] for m in self.train_metrics]
            val_values = [m[metric] for m in self.val_metrics]

            plt.plot(epochs, train_values, 'b-', label='Train')
            plt.plot(epochs, val_values, 'r-', label='Validation')
            plt.title(f'{metric.capitalize()} vs Epochs')
            plt.xlabel('Epochs')
            plt.ylabel(metric.capitalize())
            plt.legend()

        plt.tight_layout()
        plt.show()

    def train(self):
        """Train for ``num_epochs``, keep the best checkpoint by validation F1, then test it."""
        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch+1}/{self.num_epochs}")

            train_metrics = self.train_epoch()
            val_metrics = self.validate()

            self.train_metrics.append(train_metrics)
            self.val_metrics.append(val_metrics)

            print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, F1: {train_metrics['f1']:.4f}")
            print(f"Val   - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1']:.4f}")

            if val_metrics['f1'] > self.best_val_f1:
                self.best_val_f1 = val_metrics['f1']
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_f1': self.best_val_f1,
                }, self.checkpoint_path)

        self.plot_metrics()

        # map_location keeps the reload working if the checkpoint was written
        # on a different device than the one used now.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        test_metrics = self.test()

        print("\nTest Results:")
        print(f"Accuracy: {test_metrics['accuracy']:.4f}")
        print(f"Precision: {test_metrics['precision']:.4f}")
        print(f"Recall: {test_metrics['recall']:.4f}")
        print(f"F1 Score: {test_metrics['f1']:.4f}")

        return test_metrics

 
 
def main():
    """Train and evaluate the few-shot ViT on episodic train/val/test splits.

    Returns:
        The test-set metrics dict returned by ``Trainer.train()``.
    """
    # Episode and optimisation settings.
    n_way, k_shot, n_query = 15, 10, 5
    batch_size = 1  # one episode per step (reduced batch size)
    num_epochs = 50
    learning_rate = 1e-4

    # Resize to 224x224 and normalise with the listed per-channel mean/std.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'

    # Episodes drawn per epoch for each split (built in train, val, test order).
    episodes_per_split = (('train', 100), ('val', 50), ('test', 50))
    episodic_sets = {
        split: FewShotDataset(f'{path_data}/{split}/', transform, n_way, k_shot, n_query, n_episodes=count)
        for split, count in episodes_per_split
    }

    # No num_workers: episodes are sampled in the main process.
    train_loader = DataLoader(episodic_sets['train'], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(episodic_sets['val'], batch_size=batch_size)
    test_loader = DataLoader(episodic_sets['test'], batch_size=batch_size)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = InnovativeFewShotViT(n_way=n_way, k_shot=k_shot).to(device)
    criterion = CombinedLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    return Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=num_epochs,
    ).train()


# Runs whenever this cell executes (there is no __main__ guard, so the
# notebook trains as soon as the cell is run).
print("Starting training...")
test_metrics = main()
print("\nTraining completed!")
                                   