# In [None]:
from collections import defaultdict  # Add this import at the top of your script

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from transformers import ViTModel
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import sys
sys.path.insert(0,'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')

import torchvision
from dataloaders import get_train_transforms, get_val_transforms, get_triplet_dataloader

class MultiHeadAttention(nn.Module):
    """Self-attention over a batch-first ``(batch, seq, embed_dim)`` tensor.

    ``nn.MultiheadAttention`` is built with its default sequence-first layout,
    so ``forward`` swaps the first two axes before and after the call.

    Args:
        embed_dim: feature size of every token.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        """Return self-attended features with the same shape as ``x``."""
        seq_first = x.transpose(1, 0)
        # The attention-weights output (index 1) is not needed.
        attended = self.mha(seq_first, seq_first, seq_first)[0]
        return attended.transpose(1, 0)

class InnovativeFewShotViT(nn.Module):
    """Episodic few-shot classifier built on a pretrained ViT.

    Images are embedded from the ViT [CLS] token, projected to ``embed_dim``,
    refined residually by ``meta_encoder`` and L2-normalised. Support embeddings
    of an episode pass through self-attention, are averaged per class and
    refined by ``proto_encoder`` to give one prototype per class. Each query is
    scored by the negative Euclidean distance to the prototypes of *its own*
    episode.

    Args:
        n_way: classes per episode.
        k_shot: support images per class.
        embed_dim: size of the projected embedding.
        num_heads: attention heads of the support/query attention modules.
    """

    def __init__(self, n_way=15, k_shot=15, embed_dim=512, num_heads=8):
        super().__init__()
        # Loads pretrained weights through Hugging Face on construction.
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        # 768 is the hidden size of the vit-base checkpoint loaded above.
        self.embedding = nn.Linear(768, embed_dim)
        self.support_attention = MultiHeadAttention(embed_dim, num_heads)
        # NOTE(review): query_attention is constructed but not used by forward().
        self.query_attention = MultiHeadAttention(embed_dim, num_heads)
        self.n_way = n_way
        self.k_shot = k_shot
        # Margin of the auxiliary support-set triplet loss.
        self.margin = 1.0

        # Meta-learning layers: residual refinement of each embedding.
        self.meta_encoder = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, embed_dim),
        )

        # Prototypical-network layers: applied to per-class mean embeddings.
        self.proto_encoder = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def _vit_kwargs(self, x):
        """Extra ViT call kwargs for a channels-first image batch ``x``.

        The HF ViT rejects inputs whose spatial size differs from its
        configured ``image_size`` unless position embeddings are interpolated,
        so ``interpolate_pos_encoding=True`` is requested only in that case.
        """
        image_size = getattr(getattr(self.vit, 'config', None), 'image_size', None)
        if image_size is None:
            return {}
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        if tuple(x.shape[-2:]) == tuple(image_size):
            return {}
        return {'interpolate_pos_encoding': True}

    def forward_one(self, x):
        """Embed a batch of images into unit-norm vectors.

        Accepts ``(N, C, H, W)``, a single ``(C, H, W)`` image, or a tensor with
        more than four dims whose last three are image dims (leading dims are
        flattened). If dim 1 is not 3 the input is treated as channels-last and
        permuted to channels-first.

        Returns:
            Tensor of shape ``(N, embed_dim)`` with unit L2 norm per row.
        """
        if x.dim() > 4:
            x = x.reshape(-1, *x.shape[-3:])
        elif x.dim() == 3:
            x = x.unsqueeze(0)

        if x.size(1) != 3:
            x = x.permute(0, 3, 1, 2)

        # ViT feature extraction: the [CLS] token of the last hidden state.
        cls_features = self.vit(x, **self._vit_kwargs(x)).last_hidden_state[:, 0]
        embedding = self.embedding(cls_features)

        # Meta-learning enhancement (residual).
        enhanced_emb = embedding + self.meta_encoder(embedding)

        return F.normalize(enhanced_emb, p=2, dim=1)

    def get_prototypes(self, support_embeddings):
        """Build one prototype per class from ``(B, n_way, k_shot, D)`` embeddings.

        All ``n_way * k_shot`` support embeddings of an episode attend to each
        other, then each class's shots are averaged and passed through
        ``proto_encoder``.

        Returns:
            Tensor of shape ``(B, n_way, D)``.
        """
        batch_size = support_embeddings.size(0)

        # Self-attention across the whole support set of each episode.
        support_embeddings = support_embeddings.reshape(batch_size, self.n_way * self.k_shot, -1)
        attended_support = self.support_attention(support_embeddings)

        # Prototypical-network processing: per-class mean, then refinement.
        prototypes = attended_support.reshape(batch_size, self.n_way, self.k_shot, -1)
        return self.proto_encoder(prototypes.mean(2))

    def _embed_support(self, support_set):
        """Embed every support image; ``support_set[:, i, j]`` is class i, shot j.

        Returns:
            Tensor of shape ``(B, n_way, k_shot, embed_dim)``.
        """
        per_class = [
            torch.stack([self.forward_one(support_set[:, i, j]) for j in range(self.k_shot)], dim=1)
            for i in range(self.n_way)
        ]
        return torch.stack(per_class, dim=1)

    def _support_triplet_loss(self, support_embeddings):
        """Sum of mean triplet losses over the support set.

        For class ``i`` and shot ``j >= 1`` the anchor is shot ``j``, the positive
        is shot ``j - 1`` of the same class, and the negative is shot ``j`` of
        class ``(i + 1) % n_way``. Already-computed embeddings are reused, so no
        extra ViT passes are needed. Returns the int ``0`` when ``k_shot == 1``.
        """
        loss = 0
        for i in range(self.n_way):
            neg_class = (i + 1) % self.n_way
            for j in range(1, self.k_shot):
                anchor = support_embeddings[:, i, j]
                positive = support_embeddings[:, i, j - 1]
                negative = support_embeddings[:, neg_class, j]
                pos_dist = F.pairwise_distance(anchor, positive)
                neg_dist = F.pairwise_distance(anchor, negative)
                loss = loss + F.relu(pos_dist - neg_dist + self.margin).mean()
        return loss

    def forward(self, support_set, query, mode='train'):
        """Score queries against the prototypes of their own episode.

        Args:
            support_set: ``(B, n_way, k_shot, C, H, W)`` support images.
            query: ``(B, Q, C, H, W)`` query images per episode (or
                ``(Q, C, H, W)`` when ``B == 1``).
            mode: ``'train'`` additionally computes the support triplet loss.

        Returns:
            ``(logits, triplet_loss)``. ``logits`` has shape ``(B * Q, n_way)``,
            ordered episode by episode. In non-train mode the second element is
            a zero tensor.
        """
        batch_size = support_set.size(0)

        support_embeddings = self._embed_support(support_set)
        triplet_loss = self._support_triplet_loss(support_embeddings) if mode == 'train' else 0

        prototypes = self.get_prototypes(support_embeddings)  # (B, n_way, D)

        # Embed queries one episode at a time to bound peak memory.
        query_emb = torch.cat([self.forward_one(query[idx:idx + 1]) for idx in range(query.size(0))])
        query_emb = query_emb.reshape(batch_size, -1, query_emb.size(-1))  # (B, Q, D)

        # Batched distances: episode b's queries vs episode b's prototypes only.
        logits = -torch.cdist(query_emb, prototypes)
        logits = logits.reshape(-1, prototypes.size(1))

        if mode == 'train':
            return logits, triplet_loss
        return logits, torch.tensor(0.0, device=logits.device)

class TripletMarginWithMetaLoss(nn.Module):
    """Cross-entropy on episode logits plus a weighted auxiliary triplet term.

    Args:
        margin: stored on the instance; not used by ``forward``.
        triplet_weight: multiplier for the triplet loss before it is added to
            the cross-entropy loss.
    """

    def __init__(self, margin=1.0, triplet_weight=0.5):
        super().__init__()
        self.margin = margin
        self.triplet_weight = triplet_weight
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, labels, triplet_loss):
        """Return ``CE(logits, labels) + triplet_weight * triplet_loss``."""
        weighted_triplet = self.triplet_weight * triplet_loss
        return self.ce(logits, labels) + weighted_triplet
       
class FewShotTrainer:
    """Train / validate / test loop for episodic few-shot models.

    Each batch is ``(support_imgs, query_imgs, labels)``. The model is called as
    ``model(support, query)`` for training and ``model(support, query,
    mode='val')`` for evaluation, and must return ``(logits, triplet_loss)``.
    The checkpoint with the best validation accuracy is written to
    ``checkpoint_path`` and reloaded for the final test pass.

    Args:
        model: the few-shot model.
        train_loader, val_loader, test_loader: iterables of episode batches.
        criterion: callable ``(logits, labels, triplet_loss) -> loss``.
        optimizer: torch optimizer for ``model``.
        device: device batches are moved to.
        num_epochs: number of training epochs.
        scheduler: optional scheduler stepped with the validation loss.
        checkpoint_path: file the best checkpoint is saved to / loaded from.
    """

    def __init__(self, model, train_loader, val_loader, test_loader, criterion, optimizer,
                 device, num_epochs=50, scheduler=None, checkpoint_path='best_model.pth'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.num_epochs = num_epochs
        self.checkpoint_path = checkpoint_path
        # Best validation *accuracy* so far (attribute name kept for compatibility).
        self.best_val_f1 = 0
        self.train_metrics = []
        self.val_metrics = []

    def _prepare_batch(self, batch):
        """Move an episode batch to ``self.device`` and flatten labels to 1-D."""
        support_imgs, query_imgs, labels = [x.to(self.device) for x in batch]
        return support_imgs, query_imgs, labels.view(-1)

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            Dict with the per-batch mean ``'loss'`` and ``'acc'``.
        """
        self.model.train()
        metrics = defaultdict(float)

        for batch in tqdm(self.train_loader):
            support_imgs, query_imgs, labels = self._prepare_batch(batch)

            self.optimizer.zero_grad()
            logits, triplet_loss = self.model(support_imgs, query_imgs)
            loss = self.criterion(logits, labels, triplet_loss)

            loss.backward()
            self.optimizer.step()

            preds = logits.argmax(dim=1)
            metrics['loss'] += loss.item()
            metrics['acc'] += (preds == labels).float().mean().item()

        return {k: v / len(self.train_loader) for k, v in metrics.items()}

    @torch.no_grad()
    def validate(self, loader):
        """Evaluate on ``loader`` without gradients; triplet loss is excluded.

        Returns:
            Dict with the per-batch mean ``'loss'`` and ``'acc'``.
        """
        self.model.eval()
        metrics = defaultdict(float)

        for batch in tqdm(loader):
            support_imgs, query_imgs, labels = self._prepare_batch(batch)

            logits, _ = self.model(support_imgs, query_imgs, mode='val')
            loss = self.criterion(logits, labels, torch.tensor(0.0, device=self.device))

            preds = logits.argmax(dim=1)
            metrics['loss'] += loss.item()
            metrics['acc'] += (preds == labels).float().mean().item()

        return {k: v / len(loader) for k, v in metrics.items()}

    def _save_checkpoint(self, epoch):
        """Write model/optimizer state and the current best accuracy to disk."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_acc': self.best_val_f1,
        }, self.checkpoint_path)

    def train(self):
        """Train for ``num_epochs``, keep the best checkpoint, then evaluate on test.

        Returns:
            Test metrics dict from :meth:`validate`.
        """
        saved_checkpoint = False
        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch+1}/{self.num_epochs}")

            train_metrics = self.train_epoch()
            val_metrics = self.validate(self.val_loader)

            if self.scheduler:
                self.scheduler.step(val_metrics['loss'])

            self.train_metrics.append(train_metrics)
            self.val_metrics.append(val_metrics)

            print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['acc']:.4f}")
            print(f"Val - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['acc']:.4f}")

            # Always save at least once so the test phase never loads a missing
            # file or a stale checkpoint left over from an earlier run.
            if not saved_checkpoint or val_metrics['acc'] > self.best_val_f1:
                self.best_val_f1 = val_metrics['acc']
                self._save_checkpoint(epoch)
                saved_checkpoint = True

        # Test the best model (or the current one if no epoch ran).
        if saved_checkpoint:
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        test_metrics = self.validate(self.test_loader)

        print("\nTest Results:")
        print(f"Loss: {test_metrics['loss']:.4f}")
        print(f"Accuracy: {test_metrics['acc']:.4f}")

        return test_metrics


class FewShotDataset:
    """Episodic N-way K-shot sampler over an ``ImageFolder`` directory tree.

    Each item is a freshly sampled random episode ``(support_images,
    query_images, query_labels)`` with shapes ``(n_way, k_shot, *img)``,
    ``(n_way * n_query, *img)`` and ``(n_way * n_query,)``, where ``img`` is the
    transformed image shape. Labels are episode-local class positions
    ``0..n_way-1``.

    Args:
        root_dir: directory in ``ImageFolder`` layout (one sub-folder per class).
        transform: transform applied to every image by ``ImageFolder``.
        n_way: classes per episode.
        k_shot: support images per class.
        n_query: query images per class.
        n_episodes: value returned by ``__len__``.

    Raises:
        ValueError: if fewer than ``n_way`` classes have at least
            ``k_shot + n_query`` images.
    """

    def __init__(self, root_dir, transform=None, n_way=15, k_shot=15, n_query=5, n_episodes=1000):
        self.dataset = datasets.ImageFolder(root_dir, transform=transform)
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.n_episodes = n_episodes

        # Group sample indices by label using ImageFolder's precomputed targets.
        # Iterating the dataset itself would load and transform every image
        # just to read its label.
        self.label_to_indices = {}
        for idx, label in enumerate(self.dataset.targets):
            self.label_to_indices.setdefault(label, []).append(idx)

        # Only classes that can supply a full support + query set are sampled.
        self.valid_classes = [
            cls for cls, indices in self.label_to_indices.items()
            if len(indices) >= self.k_shot + self.n_query
        ]

        if len(self.valid_classes) < self.n_way:
            raise ValueError(f"Not enough classes with sufficient samples. Found {len(self.valid_classes)} valid classes, need {self.n_way}")

    def __len__(self):
        return self.n_episodes

    def __getitem__(self, episode_index):
        """Sample one random episode; ``episode_index`` is ignored."""
        selected_classes = np.random.choice(self.valid_classes, self.n_way, replace=False)

        support_rows = []
        query_images = []
        query_labels = []

        for class_idx, class_label in enumerate(selected_classes):
            selected_indices = np.random.choice(
                self.label_to_indices[class_label],
                self.k_shot + self.n_query,
                replace=False,
            )

            # Support set: stack the transformed images so any output size works.
            support_rows.append(
                torch.stack([self.dataset[img_idx][0] for img_idx in selected_indices[:self.k_shot]])
            )

            # Query set, labelled by the class's position within this episode.
            for img_idx in selected_indices[self.k_shot:self.k_shot + self.n_query]:
                img, _ = self.dataset[img_idx]
                query_images.append(img)
                query_labels.append(class_idx)

        support_images = torch.stack(support_rows)
        query_images = torch.stack(query_images)
        query_labels = torch.tensor(query_labels)

        return support_images, query_images, query_labels

# Earlier variant of main() (224x224 images, batch_size=4), kept for reference.
# It is wrapped in a bare string literal, so it is never executed; the active
# main() is defined below.
'''
def main():
   # Hyperparameters
   n_way = 15
   k_shot = 10 
   n_query = 5
   batch_size = 4
   num_epochs = 50
   learning_rate = 3e-4
   
   transform = transforms.Compose([
       transforms.Resize((224, 224)),
       transforms.ToTensor(),
       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
   ])
   
   # Data loading
   path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
   train_dataset = FewShotDataset(path_data+'/train/', transform, n_way, k_shot, n_query, n_episodes=200)
   val_dataset = FewShotDataset(path_data+'/val/', transform, n_way, k_shot, n_query, n_episodes=100) 
   test_dataset = FewShotDataset(path_data+'/test/', transform, n_way, k_shot, n_query, n_episodes=100)

   train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
   val_loader = DataLoader(val_dataset, batch_size=batch_size)
   test_loader = DataLoader(test_dataset, batch_size=batch_size)
   
   # Model setup
   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
   model = InnovativeFewShotViT(n_way=n_way, k_shot=k_shot).to(device)
   criterion = TripletMarginWithMetaLoss()
   optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
   scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
   
   trainer = FewShotTrainer(
       model=model,
       train_loader=train_loader, 
       val_loader=val_loader,
       test_loader=test_loader,
       criterion=criterion,
       optimizer=optimizer,
       scheduler=scheduler,
       device=device,
       num_epochs=num_epochs
   )
   
   return trainer.train()
'''
def main():
    """Build the episodic loaders, model and trainer, then train and test.

    Returns:
        The test-set metrics dict returned by ``FewShotTrainer.train()``.
    """
    # Episode configuration and optimisation hyperparameters.
    n_way, k_shot, n_query = 15, 10, 5
    batch_size = 2  # reduced batch size
    num_epochs = 50
    learning_rate = 3e-4

    image_pipeline = transforms.Compose([
        transforms.Resize((112, 112)),  # reduced image resolution
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Data loading: one episodic dataset per split sub-directory.
    path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'

    def make_split(subdir, n_episodes):
        return FewShotDataset(path_data + subdir, image_pipeline, n_way, k_shot, n_query, n_episodes=n_episodes)

    train_dataset = make_split('/train/', 200)
    val_dataset = make_split('/val/', 100)
    test_dataset = make_split('/test/', 100)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader, test_loader = (
        DataLoader(split_dataset, batch_size=batch_size, num_workers=0)
        for split_dataset in (val_dataset, test_dataset)
    )

    # Model, loss and optimisation setup.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = InnovativeFewShotViT(n_way=n_way, k_shot=k_shot).to(device)
    loss_fn = TripletMarginWithMetaLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    trainer = FewShotTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        criterion=loss_fn,
        optimizer=optimizer,
        scheduler=plateau_scheduler,
        device=device,
        num_epochs=num_epochs,
    )
    return trainer.train()

if __name__ == '__main__':
    # Run the full train/val/test pipeline; the metrics stay available for
    # interactive inspection (e.g. when run inside a notebook).
    test_metrics = main()