# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from transformers import ViTModel
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import sys
sys.path.insert(0,'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
import torchvision
from dataloaders import get_train_transforms, get_val_transforms, get_triplet_dataloader

class MultiHeadAttention(nn.Module):
    """Self-attention layer: a thin wrapper around ``nn.MultiheadAttention``.

    The wrapped layer is created with ``batch_first=True`` and is always
    called with the same tensor as query, key and value.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )

    def forward(self, x):
        """Let ``x`` attend to itself and return only the attended output.

        The attention weights returned by the wrapped layer are discarded.
        """
        return self.mha(query=x, key=x, value=x)[0]

class InnovativeFewShotViT(nn.Module):
    """Prototype-based few-shot classifier on top of a pretrained ViT encoder.

    Each image is encoded by the ViT, its CLS token is projected to
    ``embed_dim`` and L2-normalised. Within an episode, all support embeddings
    pass through a self-attention layer and are then averaged per class to
    form one prototype per class. A query is scored against each prototype by
    the negative Euclidean distance.

    Args:
        n_way: Number of classes per episode.
        k_shot: Number of support images per class.
        embed_dim: Size of the projected embedding.
        num_heads: Number of heads in the attention layers.
    """

    def __init__(self, n_way=15, k_shot=15, embed_dim=512, num_heads=8):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.embedding = nn.Linear(768, embed_dim)
        self.support_attention = MultiHeadAttention(embed_dim, num_heads)
        # NOTE: query_attention is created (and therefore part of the state
        # dict) but is not used by any method of this class.
        self.query_attention = MultiHeadAttention(embed_dim, num_heads)
        self.n_way = n_way
        self.k_shot = k_shot
        # NOTE: margin is stored but not read by any method of this class.
        self.margin = 1.0

        # Trade compute for memory inside the ViT encoder.
        self.vit.gradient_checkpointing_enable()

    @torch.cuda.amp.autocast()
    def forward_one(self, x):
        """Embed a batch of images.

        Args:
            x: Image tensor. A single ``(C, H, W)`` image gets a batch axis
                added; inputs with more than four dimensions have all leading
                axes flattened into one batch axis.

        Returns:
            L2-normalised embeddings of shape ``(num_images, embed_dim)``.
        """
        if x.dim() > 4:
            x = x.reshape(-1, *x.shape[-3:])
        elif x.dim() == 3:
            x = x.unsqueeze(0)

        # Index 0 of the sequence axis is the CLS token.
        cls_token = self.vit(x).last_hidden_state[:, 0]
        return F.normalize(self.embedding(cls_token), p=2, dim=1)

    def get_prototypes(self, support_embeddings):
        """Turn support embeddings into one prototype per class and episode.

        Args:
            support_embeddings: Tensor of shape
                ``(batch, n_way, k_shot, embed_dim)``.

        Returns:
            Tensor of shape ``(batch, n_way, embed_dim)``.
        """
        batch_size = support_embeddings.size(0)
        # Flatten classes and shots so attention runs over the whole support
        # set of an episode, across class boundaries.
        flat = support_embeddings.reshape(batch_size, self.n_way * self.k_shot, -1)
        attended = self.support_attention(flat)
        return attended.reshape(batch_size, self.n_way, self.k_shot, -1).mean(2)

    @torch.cuda.amp.autocast()
    def forward(self, support_set, query, mode='train'):
        """Score every query image against the prototypes of its own episode.

        Args:
            support_set: Tensor of shape ``(batch, n_way, k_shot, C, H, W)``.
            query: Query images whose leading axes flatten to
                ``batch * queries_per_episode`` images, ordered episode by
                episode (e.g. ``(batch, queries_per_episode, C, H, W)``).
            mode: Accepted for interface compatibility; currently unused.

        Returns:
            A tuple ``(logits, aux_loss)``. ``logits`` has shape
            ``(batch * queries_per_episode, n_way)`` and holds negative
            Euclidean distances to the prototypes. ``aux_loss`` is always a
            zero scalar tensor on the input device.

        Raises:
            ValueError: If the number of query images is not a multiple of
                the episode batch size.
        """
        batch_size = support_set.size(0)
        device = support_set.device
        image_shape = support_set.shape[-3:]

        # Encode the support set one class at a time to bound peak memory.
        support_embeddings = []
        for way in range(self.n_way):
            way_images = support_set[:, way].reshape(-1, *image_shape)
            way_emb = self.forward_one(way_images)
            support_embeddings.append(way_emb.reshape(batch_size, self.k_shot, -1))

        support_embeddings = torch.stack(support_embeddings, dim=1)
        prototypes = self.get_prototypes(support_embeddings)

        # Encode the queries and regroup them by episode.
        query_emb = self.forward_one(query.reshape(-1, *query.shape[-3:]))
        if query_emb.size(0) % batch_size != 0:
            raise ValueError(
                f"Number of query images ({query_emb.size(0)}) is not divisible "
                f"by the episode batch size ({batch_size})"
            )
        query_emb = query_emb.reshape(batch_size, -1, query_emb.size(-1))

        # Batched distance: each episode's queries are compared with that
        # episode's prototypes only (previously every query was compared with
        # the prototypes of episode 0).
        distances = torch.cdist(query_emb, prototypes)
        logits = -distances.reshape(-1, prototypes.size(1))
        return logits, torch.tensor(0.0, device=device)

class CombinedLoss(nn.Module):
    """Cross-entropy on the logits plus a weighted auxiliary (triplet) term."""

    def __init__(self, triplet_weight=0.5):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.triplet_weight = triplet_weight

    def forward(self, logits, labels, triplet_loss):
        """Return ``CE(logits, labels) + triplet_weight * triplet_loss``."""
        classification_term = self.ce(logits, labels)
        auxiliary_term = self.triplet_weight * triplet_loss
        return classification_term + auxiliary_term

class FewShotDataset(Dataset):
    """Episodic N-way K-shot dataset built on a ``torchvision`` ``ImageFolder``.

    Every item is one randomly sampled episode; the index passed to
    ``__getitem__`` is ignored.

    Args:
        root_dir: Root directory handed to ``datasets.ImageFolder``.
        transform: Transform applied to each image by ``ImageFolder``.
        n_way: Number of classes per episode.
        k_shot: Number of support images per class.
        n_query: Number of query images per class.
        n_episodes: Reported dataset length, i.e. episodes per epoch.

    Raises:
        ValueError: If fewer than ``n_way`` classes have at least
            ``k_shot + n_query`` images.
    """

    def __init__(self, root_dir, transform=None, n_way=15, k_shot=15, n_query=5, n_episodes=1000):
        self.dataset = datasets.ImageFolder(root_dir, transform=transform)
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.n_episodes = n_episodes

        # Read labels from ImageFolder's ``targets`` list. Iterating the
        # dataset itself would decode and transform every image just to
        # obtain its label.
        self.label_to_indices = {}
        for idx, label in enumerate(self.dataset.targets):
            self.label_to_indices.setdefault(label, []).append(idx)

        # Only classes with enough images to fill support and query sets.
        self.valid_classes = [
            cls for cls, indices in self.label_to_indices.items()
            if len(indices) >= self.k_shot + self.n_query
        ]

        if len(self.valid_classes) < self.n_way:
            raise ValueError(f"Not enough classes with sufficient samples. Found {len(self.valid_classes)} valid classes, need {self.n_way}")

    def __len__(self):
        """Return the number of episodes per epoch."""
        return self.n_episodes

    def __getitem__(self, episode_index):
        """Sample one episode.

        Args:
            episode_index: Ignored; every call draws a fresh random episode.

        Returns:
            A tuple ``(support_images, query_images, query_labels)``:
            ``support_images`` has shape ``(n_way, k_shot, *image_shape)``,
            ``query_images`` has shape ``(n_way * n_query, *image_shape)`` and
            ``query_labels`` holds episode-local class indices in
            ``[0, n_way)``, grouped class by class.
        """
        selected_classes = np.random.choice(self.valid_classes, self.n_way, replace=False)
        per_class = self.k_shot + self.n_query

        support_rows = []
        query_images = []
        query_labels = []

        for class_idx, class_label in enumerate(selected_classes):
            picked = np.random.choice(
                self.label_to_indices[int(class_label)],
                per_class,
                replace=False,
            )
            images = [self.dataset[int(img_idx)][0] for img_idx in picked]

            # First k_shot images form the support set, the rest the queries.
            support_rows.append(torch.stack(images[:self.k_shot]))
            query_images.extend(images[self.k_shot:])
            # Labels are positions within this episode, not dataset class ids.
            query_labels.extend([class_idx] * self.n_query)

        support_images = torch.stack(support_rows)
        query_images = torch.stack(query_images)
        query_labels = torch.tensor(query_labels)

        return support_images, query_images, query_labels

class Trainer:
    """Train, validate and test an episodic few-shot model.

    Args:
        model: Module called as ``model(support, query, mode=...)`` and
            returning ``(logits, aux_loss)``.
        train_loader: Iterable of ``(support, query, labels)`` training batches.
        val_loader: Iterable of validation batches in the same format.
        test_loader: Iterable of test batches in the same format.
        criterion: Callable ``criterion(logits, labels, aux_loss)``.
        optimizer: Optimizer over the model parameters.
        device: Device the batches are moved to.
        num_epochs: Number of training epochs.
    """

    def __init__(self, model, train_loader, val_loader, test_loader, criterion, optimizer, device, num_epochs=50):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.num_epochs = num_epochs
        self.best_val_f1 = 0
        self.train_metrics = []
        self.val_metrics = []
        self.scaler = torch.cuda.amp.GradScaler()
        # True once this trainer has written 'best_model.pth' itself.
        self._has_checkpoint = False

    def _prepare_batch(self, support_images, query_images, query_labels):
        """Move one batch to the device and flatten the labels to 1-D.

        Flattening keeps the labels aligned with the model's logits, which
        have one row per query image across all episodes in the batch.
        """
        return (
            support_images.to(self.device),
            query_images.to(self.device),
            query_labels.to(self.device).reshape(-1),
        )

    def calculate_metrics(self, labels, preds, loss=None, n_batches=None):
        """Compute accuracy and weighted precision, recall and F1.

        Args:
            labels: Ground-truth labels.
            preds: Predicted labels, aligned with ``labels``.
            loss: Optional summed loss; when given, ``n_batches`` must be
                given too.
            n_batches: Number of batches ``loss`` was summed over.

        Returns:
            Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``
            and, when ``loss`` is given, ``loss`` (the per-batch mean).
        """
        accuracy = accuracy_score(labels, preds)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average='weighted', zero_division=0
        )
        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
        }
        if loss is not None:
            metrics['loss'] = loss / n_batches
        return metrics

    def train_epoch(self):
        """Run one training epoch and return its metrics dict (with loss)."""
        self.model.train()
        total_loss = 0
        all_preds = []
        all_labels = []

        for support_images, query_images, query_labels in tqdm(self.train_loader):
            support_images, query_images, query_labels = self._prepare_batch(
                support_images, query_images, query_labels
            )

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast():
                logits, triplet_loss = self.model(support_images, query_images, mode='train')
                loss = self.criterion(logits, query_labels, triplet_loss)

            self.scaler.scale(loss).backward()

            # Gradients must be unscaled before clipping so that max_norm
            # applies to their true magnitude.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.scaler.step(self.optimizer)
            self.scaler.update()

            with torch.no_grad():
                preds = logits.argmax(dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(query_labels.cpu().numpy())
                total_loss += loss.item()

        return self.calculate_metrics(all_labels, all_preds, total_loss, len(self.train_loader))

    def validate(self):
        """Evaluate on the validation loader and return its metrics dict (with loss)."""
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for support_images, query_images, query_labels in tqdm(self.val_loader):
                support_images, query_images, query_labels = self._prepare_batch(
                    support_images, query_images, query_labels
                )

                logits, triplet_loss = self.model(support_images, query_images, mode='test')
                loss = self.criterion(logits, query_labels, triplet_loss)

                preds = logits.argmax(dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(query_labels.cpu().numpy())
                total_loss += loss.item()

        return self.calculate_metrics(all_labels, all_preds, total_loss, len(self.val_loader))

    def test(self):
        """Evaluate on the test loader, plot the confusion matrix, return metrics (no loss)."""
        self.model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for support_images, query_images, query_labels in tqdm(self.test_loader):
                # Labels are flattened here as in train/validate; otherwise a
                # batch of several episodes yields label rows that do not line
                # up with the per-query predictions.
                support_images, query_images, query_labels = self._prepare_batch(
                    support_images, query_images, query_labels
                )

                logits, _ = self.model(support_images, query_images, mode='test')
                preds = logits.argmax(dim=1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(query_labels.cpu().numpy())

        metrics = self.calculate_metrics(all_labels, all_preds)
        self.plot_confusion_matrix(all_labels, all_preds, "Test Set Confusion Matrix")
        return metrics

    def plot_confusion_matrix(self, labels, preds, title):
        """Show a heatmap of the confusion matrix of ``labels`` vs ``preds``."""
        cm = confusion_matrix(labels, preds)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title(title)
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.show()

    def plot_metrics(self):
        """Plot train and validation accuracy, F1 and loss over the recorded epochs."""
        epochs = range(1, len(self.train_metrics) + 1)
        metrics = ['accuracy', 'f1', 'loss']

        plt.figure(figsize=(15, 5))
        for i, metric in enumerate(metrics):
            plt.subplot(1, 3, i+1)
            train_values = [m[metric] for m in self.train_metrics]
            val_values = [m[metric] for m in self.val_metrics]

            plt.plot(epochs, train_values, 'b-', label='Train')
            plt.plot(epochs, val_values, 'r-', label='Validation')
            plt.title(f'{metric.capitalize()} vs Epochs')
            plt.xlabel('Epochs')
            plt.ylabel(metric.capitalize())
            plt.legend()

        plt.tight_layout()
        plt.show()

    def train(self):
        """Run the full training loop, then test the best checkpoint.

        After every epoch the model is validated; whenever the validation F1
        improves, model and optimizer state are saved to ``best_model.pth``.
        After the last epoch the metric curves are plotted, the best
        checkpoint written by this trainer is restored (if there is one) and
        the model is evaluated on the test loader.

        Returns:
            The test metrics dict.
        """
        checkpoint_path = 'best_model.pth'

        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch+1}/{self.num_epochs}")

            train_metrics = self.train_epoch()
            val_metrics = self.validate()

            self.train_metrics.append(train_metrics)
            self.val_metrics.append(val_metrics)

            print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, F1: {train_metrics['f1']:.4f}")
            print(f"Val   - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1']:.4f}")

            if val_metrics['f1'] > self.best_val_f1:
                self.best_val_f1 = val_metrics['f1']
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    # Stored as a plain float so the file contains no NumPy
                    # scalar objects and loads under restricted unpickling.
                    'val_f1': float(self.best_val_f1),
                }, checkpoint_path)
                self._has_checkpoint = True

        self.plot_metrics()

        if self._has_checkpoint:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # No epoch improved on the initial F1, so there is nothing of ours
            # to restore; do not pick up a stale file from an earlier run.
            print("No checkpoint was saved during this run; testing the current model weights.")
        test_metrics = self.test()

        print("\nTest Results:")
        print(f"Accuracy: {test_metrics['accuracy']:.4f}")
        print(f"Precision: {test_metrics['precision']:.4f}")
        print(f"Recall: {test_metrics['recall']:.4f}")
        print(f"F1 Score: {test_metrics['f1']:.4f}")

        return test_metrics

def main():
    """Build the datasets, model and trainer, run training, return test metrics."""
    n_way = 15
    k_shot = 10
    n_query = 5
    batch_size = 2
    num_epochs = 50
    learning_rate = 1e-4

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])

    path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'

    # The training dataset was previously never created, so the DataLoader
    # below raised NameError.
    # NOTE(review): assumes a 'train' split sits next to 'val' and 'test' and
    # that the dataset's default episode count is wanted - confirm both.
    train_dataset = FewShotDataset(path_data+'/train/', transform, n_way, k_shot, n_query)
    val_dataset = FewShotDataset(path_data+'/val/', transform, n_way, k_shot, n_query, n_episodes=50)
    test_dataset = FewShotDataset(path_data+'/test/', transform, n_way, k_shot, n_query, n_episodes=50)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=pin_memory)

    model = InnovativeFewShotViT(n_way=n_way, k_shot=k_shot).to(device)
    criterion = CombinedLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=num_epochs
    )

    test_metrics = trainer.train()
    return test_metrics

# Script entry point: run the full train/validate/test pipeline only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    print("Starting training...")
    # main() returns the metrics dict of the final test evaluation.
    test_metrics = main()
    print("\nTraining completed!")