In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50
from torchvision.models.vgg import vgg16_bn
from torchvision.models.densenet import densenet121
from torchvision.models.mobilenet import mobilenet_v2
from efficientnet_pytorch import EfficientNet
from transformers import ViTModel, ViTFeatureExtractor
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from learn2learn.data import MetaDataset
import random
from PIL import Image
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn


class PatchEmbedding(nn.Module):
    def __init__(self, patch_size=16, embed_dim=128):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.embedding_network = nn.Sequential(
            nn.Linear(patch_size * patch_size * 3, 256),
            nn.ReLU(),
            nn.Linear(256, embed_dim),
            nn.LayerNorm(embed_dim)
        )
        
    def forward(self, x):
        B, C, H, W = x.shape
        # Split image into patches
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.permute(0, 2, 3, 1, 4, 5)  # (B, H/16, W/16, C, patch_size, patch_size)
        
        # Flatten patches while keeping spatial information
        patch_vectors = patches.reshape(B, H//self.patch_size, W//self.patch_size, -1)
        
        # Create embeddings for each patch
        embeddings = self.embedding_network(patch_vectors)  # (B, 14, 14, embed_dim)
        
        return embeddings

class EfficientNetPatchB4(nn.Module):
    def __init__(self, patch_size=16, embed_dim=128):
        super().__init__()
        self.efficientnet = EfficientNet.from_pretrained('efficientnet-b4')
        self.patch_embed = PatchEmbedding(patch_size, embed_dim)
        
    def forward(self, x):
        # Generate patch embeddings with spatial information preserved
        patch_embeddings = self.patch_embed(x)  # (batch_size, 14, 14, embed_dim)
        patch_embeddings = F.normalize(patch_embeddings, p=2, dim=-1)
        return patch_embeddings

class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        output1 = self.embedding_net(x1)
        output2 = self.embedding_net(x2)
        output3 = self.embedding_net(x3)
        return output1, output2, output3

    def get_embedding(self, x):
        return self.embedding_net(x)

class SpatialTripletLoss(nn.Module):
    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = (anchor - positive).pow(2).sum(dim=-1)  # Sum over embedding dimension
        distance_negative = (anchor - negative).pow(2).sum(dim=-1)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()

class TripletDataset(Dataset):
    def __init__(self, root=None, transforms=None):
        img_folder = ImageFolder(root=root)
        meta_data = MetaDataset(img_folder)
        
        self.img_list = img_folder.imgs
        self.labels_to_indices = meta_data.labels_to_indices
        self.indices_to_labels = meta_data.indices_to_labels
        self.labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
        self.transforms = transforms

    def __getitem__(self, index):
        img_anchor = self.img_list[index][0]
        label_anchor = self.img_list[index][1]

        idx_positive = random.choice(self.labels_to_indices[label_anchor])            
        img_positive = self.img_list[idx_positive][0]
        
        aux_class = random.choice(list(set(self.labels) - set([label_anchor])))
        idx_negative = random.choice(self.labels_to_indices[aux_class])
        img_negative = self.img_list[idx_negative][0]

        img_anchor = Image.open(img_anchor).convert('RGB')
        img_positive = Image.open(img_positive).convert('RGB')
        img_negative = Image.open(img_negative).convert('RGB')

        if self.transforms is not None:
            img_anchor = self.transforms(img_anchor)
            img_positive = self.transforms(img_positive)
            img_negative = self.transforms(img_negative)

        return (img_anchor, img_positive, img_negative), []

    def __len__(self):
        return len(self.img_list)

def get_triplet_dataloader(root=None, batch_size=1, transforms=None):
    dataset = TripletDataset(root=root, transforms=transforms)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

def get_train_transforms():
    return T.Compose([
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(0.5),
        T.RandomVerticalFlip(0.5),
        T.RandomApply([T.RandomRotation(10)], 0.25),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

def get_val_transforms():
    return T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

def fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler, n_epochs, device, log_interval=100):
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            img_a, img_p, img_n = data
            img_a, img_p, img_n = img_a.to(device), img_p.to(device), img_n.to(device)
            
            optimizer.zero_grad()
            embeddings = model(img_a, img_p, img_n)
            loss = loss_fn(*embeddings)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % log_interval == 0:
                print(f'Epoch {epoch}: [{batch_idx * len(img_a)}/{len(train_loader.dataset)}] '
                      f'Loss: {total_loss / (batch_idx + 1):.6f}')
        
        scheduler.step()
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for data, _ in val_loader:
                img_a, img_p, img_n = data
                img_a, img_p, img_n = img_a.to(device), img_p.to(device), img_n.to(device)
                embeddings = model(img_a, img_p, img_n)
                val_loss += loss_fn(*embeddings).item()
        
        val_loss /= len(val_loader)
        print(f'Epoch {epoch}: Validation Loss: {val_loss:.6f}')

# Initialize model and training
path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
embedding_net = EfficientNetPatchB4(patch_size=16, embed_dim=128)
triplet_model = TripletNet(embedding_net=embedding_net)
loss_fn = SpatialTripletLoss(margin=1.0)
optimizer = torch.optim.Adam(triplet_model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    triplet_model = triplet_model.cuda()

# Create dataloaders
triplet_train_loader = get_triplet_dataloader(
    root=path_data + '/train/',
    batch_size=5,
    transforms=get_train_transforms()
)
triplet_val_loader = get_triplet_dataloader(
    root=path_data + '/val/',
    batch_size=5,
    transforms=get_val_transforms()
)

# Train
n_epochs = 100
fit(triplet_train_loader, triplet_val_loader, triplet_model, loss_fn, optimizer, scheduler, n_epochs, device)
torch.save(triplet_model, "f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/Patch_emdebbing_generator_triplet_model.h5" )

Loaded pretrained weights for efficientnet-b4
Epoch 0: [0/225] Loss: 0.842478
Epoch 0: Validation Loss: 0.928688
Epoch 1: [0/225] Loss: 0.998621
Epoch 1: Validation Loss: 0.929748
Epoch 2: [0/225] Loss: 1.019156
Epoch 2: Validation Loss: 0.922473
Epoch 3: [0/225] Loss: 0.990100
Epoch 3: Validation Loss: 0.927282
Epoch 4: [0/225] Loss: 0.846351
Epoch 4: Validation Loss: 0.922718
Epoch 5: [0/225] Loss: 0.890962
Epoch 5: Validation Loss: 0.910829
Epoch 6: [0/225] Loss: 0.899605
Epoch 6: Validation Loss: 0.909570
Epoch 7: [0/225] Loss: 0.859286
Epoch 7: Validation Loss: 0.910903
Epoch 8: [0/225] Loss: 0.922489
Epoch 8: Validation Loss: 0.902706
Epoch 9: [0/225] Loss: 0.900027
Epoch 9: Validation Loss: 0.904689
Epoch 10: [0/225] Loss: 0.905161
Epoch 10: Validation Loss: 0.902101
Epoch 11: [0/225] Loss: 0.946863
Epoch 11: Validation Loss: 0.895931
Epoch 12: [0/225] Loss: 1.061231
Epoch 12: Validation Loss: 0.900739
Epoch 13: [0/225] Loss: 0.892946
Epoch 13: Validation Loss: 0.900799
Epoch 14

In [5]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from datetime import datetime
from torch.utils.data import DataLoader
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys
#sys.path.insert(0,'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
sys.path.insert(0,'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
from dataloaders import get_val_transforms
import numpy as np
from sklearn.manifold import TSNE
import matplotlib as mpl
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import transforms
from torch.autograd import Variable
import os
import pandas as pd
import seaborn as sns
import dataloaders
from dataloaders import get_train_transforms, get_val_transforms, get_triplet_dataloader
from transformers import ViTForImageClassification, ViTFeatureExtractor
import torch
from sklearn.metrics import accuracy_score, f1_score , precision_score , recall_score
from sklearn.manifold import TSNE
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from datetime import datetime

class RefinedViT(nn.Module):
    def __init__(self, num_classes=15):
        super(RefinedViT, self).__init__()
        self.num_patches = 196  # 14x14
        self.embed_dim = 128    # Match previous network
        self.num_heads = 8      
        
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, self.embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=10)
        
        self.norm = nn.LayerNorm(self.embed_dim)
        self.fc = nn.Linear(self.embed_dim, num_classes)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.reshape(batch_size, 196, -1)
        
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        
        x = self.transformer_encoder(x)
        x = x[:, 0]
        x = self.norm(x)
        x = self.dropout(x)
        return self.fc(x)

def generate_embeddings(data_loader, model, device):
    model.eval()
    all_embeddings = []
    all_labels = []
    
    with torch.no_grad():
        for batch_imgs, batch_labels in tqdm(data_loader):
            batch_imgs = batch_imgs.to(device)
            embeddings = model.get_embedding(batch_imgs)  # Using get_embedding from triplet model
            print(f"Size of embeddings: {embeddings.shape}")
            all_embeddings.append(embeddings.cpu().numpy())
            all_labels.append(batch_labels.numpy())
    
    return np.concatenate(all_embeddings), np.concatenate(all_labels)

def train_and_validate(model, train_embeddings, train_labels, val_embeddings, val_labels, 
                      num_epochs=500, batch_size=32, learning_rate=1e-4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    X_train = torch.tensor(train_embeddings, dtype=torch.float32).to(device)
    y_train = torch.tensor(train_labels, dtype=torch.long).to(device)
    X_val = torch.tensor(val_embeddings, dtype=torch.float32).to(device)
    y_val = torch.tensor(val_labels, dtype=torch.long).to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    
    for epoch in tqdm(range(num_epochs)):
        model.train()
        total_train_loss = 0
        
        for i in range(0, len(X_train), batch_size):
            batch_X = X_train[i:i+batch_size]
            batch_y = y_train[i:i+batch_size]
            
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
        
        model.eval()
        with torch.no_grad():
            val_outputs = model(X_val)
            val_loss = criterion(val_outputs, y_val).item()
            val_acc = (torch.argmax(val_outputs, dim=1) == y_val).float().mean().item()
            
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), 'best_model_refined_vit.pth')
        
        scheduler.step()
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}: Train Loss={total_train_loss/len(X_train):.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}')
    
    return model

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    path_data =  'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
    
    # Load triplet model
    #triplet_model = torch.load('C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/Patch_emdebbing_generator_triplet_model.h5', map_location=device)
    triplet_model.eval()
    
    # Data loaders
    transform = get_val_transforms()
    train_data = torchvision.datasets.ImageFolder(root=f'{path_data}/train/', transform=transform)
    val_data = torchvision.datasets.ImageFolder(root=f'{path_data}/val/', transform=transform)
    test_data = torchvision.datasets.ImageFolder(root=f'{path_data}/test/', transform=transform)
    
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    test_loader = DataLoader(test_data, batch_size=32)
    
    # Generate embeddings
    train_embeddings, train_labels = generate_embeddings(train_loader, triplet_model, device)
    val_embeddings, val_labels = generate_embeddings(val_loader, triplet_model, device)
    test_embeddings, test_labels = generate_embeddings(test_loader, triplet_model, device)
    
    # Train RefinedViT
    model = RefinedViT(num_classes=15)
    trained_model = train_and_validate(model, train_embeddings, train_labels, val_embeddings, val_labels)
    
    return trained_model, (test_embeddings, test_labels)


model, test_data = main()  
torch.save(model, "f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/Patch_Rvit_triplet_model.h5" )

 38%|███████████████████████████████▌                                                    | 3/8 [00:00<00:00, 10.19it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 75%|███████████████████████████████████████████████████████████████                     | 6/8 [00:00<00:00, 10.83it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 11.54it/s]


Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([1, 14, 14, 128])


  3%|██▌                                                                                | 2/65 [00:00<00:05, 11.76it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


  6%|█████                                                                              | 4/65 [00:00<00:05, 10.76it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


  9%|███████▋                                                                           | 6/65 [00:00<00:05, 11.03it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 12%|██████████▏                                                                        | 8/65 [00:00<00:05, 10.64it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 15%|████████████▌                                                                     | 10/65 [00:00<00:05, 10.81it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 22%|█████████████████▋                                                                | 14/65 [00:01<00:04, 10.93it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 25%|████████████████████▏                                                             | 16/65 [00:01<00:04, 10.69it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 31%|█████████████████████████▏                                                        | 20/65 [00:01<00:04, 10.75it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 34%|███████████████████████████▊                                                      | 22/65 [00:02<00:04, 10.73it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 37%|██████████████████████████████▎                                                   | 24/65 [00:02<00:03, 10.88it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 40%|████████████████████████████████▊                                                 | 26/65 [00:02<00:03, 10.72it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 43%|███████████████████████████████████▎                                              | 28/65 [00:02<00:03, 10.80it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 46%|█████████████████████████████████████▊                                            | 30/65 [00:02<00:03, 10.87it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 52%|██████████████████████████████████████████▉                                       | 34/65 [00:03<00:03,  9.16it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 55%|█████████████████████████████████████████████▍                                    | 36/65 [00:03<00:02,  9.87it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 62%|██████████████████████████████████████████████████▍                               | 40/65 [00:03<00:02, 10.87it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 65%|████████████████████████████████████████████████████▉                             | 42/65 [00:03<00:02, 11.05it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 71%|██████████████████████████████████████████████████████████                        | 46/65 [00:04<00:01, 11.53it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 74%|████████████████████████████████████████████████████████████▌                     | 48/65 [00:04<00:01, 11.50it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 80%|█████████████████████████████████████████████████████████████████▌                | 52/65 [00:04<00:01, 11.15it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 83%|████████████████████████████████████████████████████████████████████              | 54/65 [00:05<00:00, 11.48it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 89%|█████████████████████████████████████████████████████████████████████████▏        | 58/65 [00:05<00:00, 11.86it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 92%|███████████████████████████████████████████████████████████████████████████▋      | 60/65 [00:05<00:00, 11.85it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 95%|██████████████████████████████████████████████████████████████████████████████▏   | 62/65 [00:05<00:00, 11.79it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


100%|██████████████████████████████████████████████████████████████████████████████████| 65/65 [00:05<00:00, 10.87it/s]


Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([28, 14, 14, 128])


  0%|                                                                                          | 0/129 [00:00<?, ?it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  2%|█▎                                                                                | 2/129 [00:00<00:10, 11.70it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


  3%|██▌                                                                               | 4/129 [00:00<00:10, 12.33it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  5%|███▊                                                                              | 6/129 [00:00<00:09, 12.59it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


  6%|█████                                                                             | 8/129 [00:00<00:09, 12.84it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


  8%|██████▎                                                                          | 10/129 [00:00<00:09, 12.37it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  9%|███████▌                                                                         | 12/129 [00:00<00:09, 11.91it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 11%|████████▊                                                                        | 14/129 [00:01<00:09, 11.91it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 12%|██████████                                                                       | 16/129 [00:01<00:09, 11.73it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 14%|███████████▎                                                                     | 18/129 [00:01<00:09, 11.74it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 16%|████████████▌                                                                    | 20/129 [00:01<00:09, 11.83it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 17%|█████████████▊                                                                   | 22/129 [00:01<00:09, 11.77it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 19%|███████████████                                                                  | 24/129 [00:02<00:08, 11.92it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 20%|████████████████▎                                                                | 26/129 [00:02<00:08, 11.48it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 23%|██████████████████▊                                                              | 30/129 [00:02<00:08, 11.52it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 25%|████████████████████                                                             | 32/129 [00:02<00:08, 11.17it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 28%|██████████████████████▌                                                          | 36/129 [00:03<00:09,  9.48it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 31%|█████████████████████████                                                        | 40/129 [00:03<00:08, 10.41it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 33%|██████████████████████████▎                                                      | 42/129 [00:03<00:08, 10.87it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 36%|████████████████████████████▉                                                    | 46/129 [00:04<00:06, 12.11it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 37%|██████████████████████████████▏                                                  | 48/129 [00:04<00:06, 12.48it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 40%|████████████████████████████████▋                                                | 52/129 [00:04<00:06, 12.67it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 43%|███████████████████████████████████▏                                             | 56/129 [00:04<00:05, 12.86it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 45%|████████████████████████████████████▍                                            | 58/129 [00:04<00:05, 13.15it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 48%|██████████████████████████████████████▉                                          | 62/129 [00:05<00:04, 13.42it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 50%|████████████████████████████████████████▏                                        | 64/129 [00:05<00:04, 13.78it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 53%|██████████████████████████████████████████▋                                      | 68/129 [00:05<00:04, 12.72it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 54%|███████████████████████████████████████████▉                                     | 70/129 [00:05<00:04, 12.80it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 57%|██████████████████████████████████████████████▍                                  | 74/129 [00:06<00:04, 11.88it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 59%|███████████████████████████████████████████████▋                                 | 76/129 [00:06<00:04, 12.06it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 62%|██████████████████████████████████████████████████▏                              | 80/129 [00:06<00:04, 12.14it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 64%|███████████████████████████████████████████████████▍                             | 82/129 [00:06<00:04, 11.46it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 67%|██████████████████████████████████████████████████████                           | 86/129 [00:07<00:03, 12.13it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 68%|███████████████████████████████████████████████████████▎                         | 88/129 [00:07<00:03, 11.64it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 70%|████████████████████████████████████████████████████████▌                        | 90/129 [00:07<00:03, 11.32it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 73%|███████████████████████████████████████████████████████████                      | 94/129 [00:07<00:03, 11.13it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 74%|████████████████████████████████████████████████████████████▎                    | 96/129 [00:08<00:03, 10.58it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 78%|██████████████████████████████████████████████████████████████                  | 100/129 [00:08<00:02, 10.74it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 79%|███████████████████████████████████████████████████████████████▎                | 102/129 [00:08<00:02, 10.85it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 81%|████████████████████████████████████████████████████████████████▍               | 104/129 [00:08<00:02, 10.84it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 84%|██████████████████████████████████████████████████████████████████▉             | 108/129 [00:09<00:01, 10.61it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 85%|████████████████████████████████████████████████████████████████████▏           | 110/129 [00:09<00:01, 10.06it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 87%|█████████████████████████████████████████████████████████████████████▍          | 112/129 [00:09<00:01, 10.07it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 90%|███████████████████████████████████████████████████████████████████████▉        | 116/129 [00:10<00:01, 10.36it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 91%|█████████████████████████████████████████████████████████████████████████▏      | 118/129 [00:10<00:01,  9.83it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 94%|███████████████████████████████████████████████████████████████████████████     | 121/129 [00:10<00:00, 10.21it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 95%|████████████████████████████████████████████████████████████████████████████▎   | 123/129 [00:10<00:00,  9.57it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


 96%|████████████████████████████████████████████████████████████████████████████▉   | 124/129 [00:10<00:00,  9.48it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])


100%|████████████████████████████████████████████████████████████████████████████████| 129/129 [00:11<00:00, 11.25it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([26, 14, 14, 128])



  2%|█▌                                                                               | 10/500 [00:06<05:31,  1.48it/s]

Epoch 10: Train Loss=0.0479, Val Loss=1.6531, Val Acc=0.5010


  4%|███▏                                                                             | 20/500 [00:13<05:25,  1.47it/s]

Epoch 20: Train Loss=0.0300, Val Loss=1.3057, Val Acc=0.5934


  6%|████▊                                                                            | 30/500 [00:20<05:12,  1.50it/s]

Epoch 30: Train Loss=0.0173, Val Loss=1.2359, Val Acc=0.6329


  8%|██████▍                                                                          | 40/500 [00:27<05:20,  1.44it/s]

Epoch 40: Train Loss=0.0079, Val Loss=1.0680, Val Acc=0.6744


 10%|████████                                                                         | 50/500 [00:33<04:56,  1.52it/s]

Epoch 50: Train Loss=0.0039, Val Loss=1.1141, Val Acc=0.6850


 12%|█████████▋                                                                       | 60/500 [00:40<04:48,  1.52it/s]

Epoch 60: Train Loss=0.0026, Val Loss=1.0961, Val Acc=0.6845


 14%|███████████▎                                                                     | 70/500 [00:47<04:43,  1.52it/s]

Epoch 70: Train Loss=0.0017, Val Loss=1.1347, Val Acc=0.6879


 16%|████████████▉                                                                    | 80/500 [00:53<04:35,  1.53it/s]

Epoch 80: Train Loss=0.0013, Val Loss=1.1421, Val Acc=0.6941


 18%|██████████████▌                                                                  | 90/500 [01:00<04:29,  1.52it/s]

Epoch 90: Train Loss=0.0010, Val Loss=1.1501, Val Acc=0.6975


 20%|████████████████                                                                | 100/500 [01:06<04:23,  1.52it/s]

Epoch 100: Train Loss=0.0008, Val Loss=1.1678, Val Acc=0.6951


 22%|█████████████████▌                                                              | 110/500 [01:13<04:16,  1.52it/s]

Epoch 110: Train Loss=0.0007, Val Loss=1.1871, Val Acc=0.6936


 24%|███████████████████▏                                                            | 120/500 [01:20<04:09,  1.53it/s]

Epoch 120: Train Loss=0.0006, Val Loss=1.2032, Val Acc=0.6946


 26%|████████████████████▊                                                           | 130/500 [01:26<04:08,  1.49it/s]

Epoch 130: Train Loss=0.0006, Val Loss=1.2109, Val Acc=0.6951


 28%|██████████████████████▍                                                         | 140/500 [01:33<03:56,  1.52it/s]

Epoch 140: Train Loss=0.0005, Val Loss=1.2331, Val Acc=0.6908


 30%|████████████████████████                                                        | 150/500 [01:39<03:52,  1.51it/s]

Epoch 150: Train Loss=0.0004, Val Loss=1.2364, Val Acc=0.6932


 32%|█████████████████████████▌                                                      | 160/500 [01:46<03:44,  1.52it/s]

Epoch 160: Train Loss=0.0004, Val Loss=1.2507, Val Acc=0.6917


 34%|███████████████████████████▏                                                    | 170/500 [01:53<03:37,  1.52it/s]

Epoch 170: Train Loss=0.0004, Val Loss=1.2644, Val Acc=0.6936


 36%|████████████████████████████▊                                                   | 180/500 [01:59<03:31,  1.51it/s]

Epoch 180: Train Loss=0.0003, Val Loss=1.2763, Val Acc=0.6936


 38%|██████████████████████████████▍                                                 | 190/500 [02:06<03:23,  1.53it/s]

Epoch 190: Train Loss=0.0003, Val Loss=1.2927, Val Acc=0.6912


 40%|████████████████████████████████                                                | 200/500 [02:13<03:24,  1.47it/s]

Epoch 200: Train Loss=0.0003, Val Loss=1.3007, Val Acc=0.6898


 42%|█████████████████████████████████▌                                              | 210/500 [02:19<03:13,  1.50it/s]

Epoch 210: Train Loss=0.0003, Val Loss=1.3003, Val Acc=0.6946


 44%|███████████████████████████████████▏                                            | 220/500 [02:26<03:06,  1.50it/s]

Epoch 220: Train Loss=0.0003, Val Loss=1.3110, Val Acc=0.6903


 46%|████████████████████████████████████▊                                           | 230/500 [02:33<03:00,  1.49it/s]

Epoch 230: Train Loss=0.0003, Val Loss=1.3125, Val Acc=0.6946


 48%|██████████████████████████████████████▍                                         | 240/500 [02:39<02:56,  1.47it/s]

Epoch 240: Train Loss=0.0002, Val Loss=1.3284, Val Acc=0.6903


 50%|████████████████████████████████████████                                        | 250/500 [02:46<02:45,  1.51it/s]

Epoch 250: Train Loss=0.0002, Val Loss=1.3396, Val Acc=0.6922


 52%|█████████████████████████████████████████▌                                      | 260/500 [02:53<02:38,  1.52it/s]

Epoch 260: Train Loss=0.0002, Val Loss=1.3369, Val Acc=0.6956


 54%|███████████████████████████████████████████▏                                    | 270/500 [02:59<02:31,  1.52it/s]

Epoch 270: Train Loss=0.0002, Val Loss=1.3421, Val Acc=0.6922


 56%|████████████████████████████████████████████▊                                   | 280/500 [03:06<02:25,  1.52it/s]

Epoch 280: Train Loss=0.0002, Val Loss=1.3561, Val Acc=0.6932


 58%|██████████████████████████████████████████████▍                                 | 290/500 [03:13<02:18,  1.52it/s]

Epoch 290: Train Loss=0.0002, Val Loss=1.3559, Val Acc=0.6922


 60%|████████████████████████████████████████████████                                | 300/500 [03:19<02:11,  1.52it/s]

Epoch 300: Train Loss=0.0002, Val Loss=1.3576, Val Acc=0.6956


 62%|█████████████████████████████████████████████████▌                              | 310/500 [03:26<02:05,  1.51it/s]

Epoch 310: Train Loss=0.0002, Val Loss=1.3721, Val Acc=0.6946


 64%|███████████████████████████████████████████████████▏                            | 320/500 [03:32<01:58,  1.51it/s]

Epoch 320: Train Loss=0.0002, Val Loss=1.3654, Val Acc=0.6936


 66%|████████████████████████████████████████████████████▊                           | 330/500 [03:39<01:52,  1.51it/s]

Epoch 330: Train Loss=0.0002, Val Loss=1.3740, Val Acc=0.6946


 68%|██████████████████████████████████████████████████████▍                         | 340/500 [03:46<01:45,  1.52it/s]

Epoch 340: Train Loss=0.0002, Val Loss=1.3830, Val Acc=0.6932


 70%|████████████████████████████████████████████████████████                        | 350/500 [03:52<01:38,  1.52it/s]

Epoch 350: Train Loss=0.0002, Val Loss=1.3923, Val Acc=0.6922


 72%|█████████████████████████████████████████████████████████▌                      | 360/500 [03:59<01:32,  1.52it/s]

Epoch 360: Train Loss=0.0002, Val Loss=1.3914, Val Acc=0.6898


 74%|███████████████████████████████████████████████████████████▏                    | 370/500 [04:06<01:26,  1.50it/s]

Epoch 370: Train Loss=0.0002, Val Loss=1.3894, Val Acc=0.6927


 76%|████████████████████████████████████████████████████████████▊                   | 380/500 [04:12<01:21,  1.48it/s]

Epoch 380: Train Loss=0.0002, Val Loss=1.3943, Val Acc=0.6917


 78%|██████████████████████████████████████████████████████████████▍                 | 390/500 [04:19<01:12,  1.52it/s]

Epoch 390: Train Loss=0.0002, Val Loss=1.3931, Val Acc=0.6951


 80%|████████████████████████████████████████████████████████████████                | 400/500 [04:25<01:05,  1.52it/s]

Epoch 400: Train Loss=0.0001, Val Loss=1.3957, Val Acc=0.6922


 82%|█████████████████████████████████████████████████████████████████▌              | 410/500 [04:32<01:00,  1.50it/s]

Epoch 410: Train Loss=0.0002, Val Loss=1.3993, Val Acc=0.6941


 84%|███████████████████████████████████████████████████████████████████▏            | 420/500 [04:39<00:53,  1.50it/s]

Epoch 420: Train Loss=0.0001, Val Loss=1.4007, Val Acc=0.6917


 86%|████████████████████████████████████████████████████████████████████▊           | 430/500 [04:45<00:47,  1.48it/s]

Epoch 430: Train Loss=0.0001, Val Loss=1.4003, Val Acc=0.6927


 88%|██████████████████████████████████████████████████████████████████████▍         | 440/500 [04:52<00:39,  1.52it/s]

Epoch 440: Train Loss=0.0002, Val Loss=1.4022, Val Acc=0.6941


 90%|████████████████████████████████████████████████████████████████████████        | 450/500 [04:59<00:32,  1.52it/s]

Epoch 450: Train Loss=0.0001, Val Loss=1.4014, Val Acc=0.6932


 92%|█████████████████████████████████████████████████████████████████████████▌      | 460/500 [05:05<00:26,  1.52it/s]

Epoch 460: Train Loss=0.0001, Val Loss=1.4034, Val Acc=0.6932


 94%|███████████████████████████████████████████████████████████████████████████▏    | 470/500 [05:12<00:19,  1.52it/s]

Epoch 470: Train Loss=0.0002, Val Loss=1.4043, Val Acc=0.6932


 96%|████████████████████████████████████████████████████████████████████████████▊   | 480/500 [05:18<00:13,  1.53it/s]

Epoch 480: Train Loss=0.0001, Val Loss=1.4053, Val Acc=0.6927


 98%|██████████████████████████████████████████████████████████████████████████████▍ | 490/500 [05:25<00:06,  1.52it/s]

Epoch 490: Train Loss=0.0002, Val Loss=1.4054, Val Acc=0.6927


100%|████████████████████████████████████████████████████████████████████████████████| 500/500 [05:32<00:00,  1.51it/s]

Epoch 500: Train Loss=0.0002, Val Loss=1.4054, Val Acc=0.6932





In [None]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import AdamW
from torchvision.datasets import ImageFolder

# Improved RefinedViT Model
class RefinedViT(nn.Module):
    def __init__(self, num_classes=15):
        super(RefinedViT, self).__init__()
        self.num_patches = 196  # 14x14
        self.embed_dim = 128    # Match previous network
        self.num_heads = 8      
        
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, self.embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)  # Increased layers
        
        self.norm = nn.LayerNorm(self.embed_dim)
        self.fc = nn.Sequential(
            nn.Linear(self.embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.reshape(batch_size, 196, -1)
        
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        
        x = self.transformer_encoder(x)
        x = x[:, 0]
        x = self.norm(x)
        x = self.dropout(x)
        return self.fc(x)

# Generate Embeddings
def generate_embeddings(data_loader, model, device):
    model.eval()
    all_embeddings = []
    all_labels = []
    
    with torch.no_grad():
        for batch_imgs, batch_labels in tqdm(data_loader):
            batch_imgs = batch_imgs.to(device)
            embeddings = model.get_embedding(batch_imgs)  # Using get_embedding from triplet model
            all_embeddings.append(embeddings.cpu().numpy())
            all_labels.append(batch_labels.numpy())
    
    return np.concatenate(all_embeddings), np.concatenate(all_labels)

# Train and Validate
def train_and_validate(model, train_embeddings, train_labels, val_embeddings, val_labels, 
                      num_epochs=500, batch_size=32, learning_rate=1e-4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    X_train = torch.tensor(train_embeddings, dtype=torch.float32).to(device)
    y_train = torch.tensor(train_labels, dtype=torch.long).to(device)
    X_val = torch.tensor(val_embeddings, dtype=torch.float32).to(device)
    y_val = torch.tensor(val_labels, dtype=torch.long).to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    best_val_loss = float('inf')
    early_stopping_patience = 10
    patience_counter = 0
    
    for epoch in tqdm(range(num_epochs)):
        model.train()
        total_train_loss = 0
        
        for i in range(0, len(X_train), batch_size):
            batch_X = X_train[i:i+batch_size]
            batch_y = y_train[i:i+batch_size]
            
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()
            
            total_train_loss += loss.item()
        
        model.eval()
        with torch.no_grad():
            val_outputs = model(X_val)
            val_loss = criterion(val_outputs, y_val).item()
            val_acc = (torch.argmax(val_outputs, dim=1) == y_val).float().mean().item()
            
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), 'best_model_refined_vit.pth')
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= early_stopping_patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break
        
        scheduler.step()
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}: Train Loss={total_train_loss/len(X_train):.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}')
    
    return model

# Main Function
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
    
    # Load triplet model
    triplet_model = torch.load('f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/Patch_emdebbing_generator_triplet_model.h5', map_location=device)
    triplet_model.eval()
    
    # Data loaders
    transform = get_val_transforms()
    train_data = ImageFolder(root=f'{path_data}/train/', transform=transform)
    val_data = ImageFolder(root=f'{path_data}/val/', transform=transform)
    test_data = ImageFolder(root=f'{path_data}/test/', transform=transform)
    
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    test_loader = DataLoader(test_data, batch_size=32)
    
    # Generate embeddings
    train_embeddings, train_labels = generate_embeddings(train_loader, triplet_model, device)
    val_embeddings, val_labels = generate_embeddings(val_loader, triplet_model, device)
    test_embeddings, test_labels = generate_embeddings(test_loader, triplet_model, device)
    
    # Train RefinedViT
    model = RefinedViT(num_classes=15)
    trained_model = train_and_validate(model, train_embeddings, train_labels, val_embeddings, val_labels)
    
    return trained_model, (test_embeddings, test_labels)

# Run the main function
model, test_data = main()  
torch.save(model, "f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/Patch_Rvit_triplet_model.h5")

In [None]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import AdamW
from torchvision.datasets import ImageFolder

# Improved RefinedViT Model
class RefinedViT(nn.Module):
    def __init__(self, num_classes=15):
        super(RefinedViT, self).__init__()
        self.num_patches = 196  # 14x14
        self.embed_dim = 128    # Match previous network
        self.num_heads = 16      
        
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, self.embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)  # Increased layers
        
        self.norm = nn.LayerNorm(self.embed_dim)
        self.fc = nn.Sequential(
            nn.Linear(self.embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.reshape(batch_size, 196, -1)
        
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        
        x = self.transformer_encoder(x)
        x = x[:, 0]
        x = self.norm(x)
        x = self.dropout(x)
        return self.fc(x)

# Generate Embeddings
def generate_embeddings(data_loader, model, device):
    model.eval()
    all_embeddings = []
    all_labels = []
    
    with torch.no_grad():
        for batch_imgs, batch_labels in tqdm(data_loader):
            batch_imgs = batch_imgs.to(device)
            embeddings = model.get_embedding(batch_imgs)  # Using get_embedding from triplet model
            all_embeddings.append(embeddings.cpu().numpy())
            all_labels.append(batch_labels.numpy())
    
    return np.concatenate(all_embeddings), np.concatenate(all_labels)

# Train and Validate
def train_and_validate(model, train_embeddings, train_labels, val_embeddings, val_labels, 
                      num_epochs=500, batch_size=32, learning_rate=1e-4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    X_train = torch.tensor(train_embeddings, dtype=torch.float32).to(device)
    y_train = torch.tensor(train_labels, dtype=torch.long).to(device)
    X_val = torch.tensor(val_embeddings, dtype=torch.float32).to(device)
    y_val = torch.tensor(val_labels, dtype=torch.long).to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    best_val_loss = float('inf')
    early_stopping_patience = 10
    patience_counter = 0
    
    for epoch in tqdm(range(num_epochs)):
        model.train()
        total_train_loss = 0
        
        for i in range(0, len(X_train), batch_size):
            batch_X = X_train[i:i+batch_size]
            batch_y = y_train[i:i+batch_size]
            
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()
            
            total_train_loss += loss.item()
        
        model.eval()
        with torch.no_grad():
            val_outputs = model(X_val)
            val_loss = criterion(val_outputs, y_val).item()
            val_acc = (torch.argmax(val_outputs, dim=1) == y_val).float().mean().item()
            
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), 'best_model_refined_vit.pth')
                patience_counter = 0
            # else:
            #     patience_counter += 1
            #     if patience_counter >= early_stopping_patience:
            #         print(f"Early stopping at epoch {epoch+1}")
            #         break
        
        scheduler.step()
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}: Train Loss={total_train_loss/len(X_train):.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}')
    
    return model

# Main Function
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
    
    # Load triplet model
    triplet_model = torch.load('f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/Patch_emdebbing_generator_triplet_model.h5', map_location=device)
    triplet_model.eval()
    
    # Data loaders
    transform = get_val_transforms()
    train_data = ImageFolder(root=f'{path_data}/train/', transform=transform)
    val_data = ImageFolder(root=f'{path_data}/val/', transform=transform)
    test_data = ImageFolder(root=f'{path_data}/test/', transform=transform)
    
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    test_loader = DataLoader(test_data, batch_size=32)
    
    # Generate embeddings
    train_embeddings, train_labels = generate_embeddings(train_loader, triplet_model, device)
    val_embeddings, val_labels = generate_embeddings(val_loader, triplet_model, device)
    test_embeddings, test_labels = generate_embeddings(test_loader, triplet_model, device)
    
    # Train RefinedViT
    model = RefinedViT(num_classes=15)
    trained_model = train_and_validate(model, train_embeddings, train_labels, val_embeddings, val_labels)
    
    return trained_model, (test_embeddings, test_labels)

# Run the main function
model, test_data = main()  
torch.save(model, "f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/Patch_Rvit_triplet_model.h5")

  triplet_model = torch.load('f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/Patch_emdebbing_generator_triplet_model.h5', map_location=device)
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x0000018B0F2AA9D0>
Traceback (most recent call last):
  File "C:\Users\pc\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "C:\Users\pc\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 1435, in _shutdown_workers
    if self._persistent_workers or self._workers_status[worker_id]:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_workers_status'
100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 11.76it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 65/65 [00:05<00:00, 11.45it/s]
100%|█████████████████████████████████████████████████