In [9]:
import sys
#sys.path.insert(0,'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor
#from dataloaders import get_train_transforms, get_val_transforms, get_triplet_dataloader
# from networks import TripletNet 
# from models import MobileNetv2
# from models import EfficientNetB4
# from losses import TripletLoss
# from trainer import fit
import torchvision
import timm
from IPython.display import clear_output 
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.manifold import TSNE
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import accuracy_score, f1_score , precision_score , recall_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50
from torchvision.models.vgg import vgg16_bn
from torchvision.models.densenet import densenet121
from torchvision.models.mobilenet import mobilenet_v2
from efficientnet_pytorch import EfficientNet
import learn2learn as l2l
from transformers import ViTModel, ViTFeatureExtractor
import numpy as np
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import time
import learn2learn as l2l
#from dataset import SiameseDataset, TripletDataset
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from learn2learn.data import MetaDataset
import random
from PIL import Image
 

class TripletLoss(nn.Module):
    """Margin-based triplet loss on squared Euclidean distances.

    For every (anchor, positive, negative) embedding triple the per-sample
    loss is ``relu(d(a, p) - d(a, n) + margin)``, where ``d`` is the squared
    L2 distance computed by summing over dimension 1.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        """Reduce the per-sample triplet losses of a batch to one value.

        Args:
            anchor: Embeddings of the anchor samples.
            positive: Embeddings of samples sharing the anchor's class.
            negative: Embeddings of samples from a different class.
            size_average: If true, return the mean of the per-sample losses;
                otherwise return their sum.

        Returns:
            The reduced loss tensor.
        """
        # Squared distances only; no square root is taken.
        sq_dist_pos = (anchor - positive).pow(2).sum(dim=1)
        sq_dist_neg = (anchor - negative).pow(2).sum(dim=1)

        per_sample = F.relu(sq_dist_pos - sq_dist_neg + self.margin)

        if size_average:
            return per_sample.mean()
        return per_sample.sum()


class EfficientNetB4(nn.Module):
    """Embedding network built on a pretrained EfficientNet-B4 backbone.

    The backbone's feature map is global-average-pooled, flattened to one
    vector per sample and L2-normalised along dimension 1.
    """

    def __init__(self):
        super().__init__()
        self.efficientnet = EfficientNet.from_pretrained('efficientnet-b4')

    def forward(self, x):
        """Return one unit-length embedding vector per input sample."""
        feature_map = self.efficientnet.extract_features(x)
        pooled = F.adaptive_avg_pool2d(feature_map, (1, 1))
        # Collapse everything after the batch dimension into a single axis.
        flat = pooled.view(pooled.size(0), -1)
        return F.normalize(flat, p=2, dim=1)

class TripletNet(nn.Module):
    """Apply a single shared embedding network to the three triplet inputs."""

    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        """Embed the three inputs, in order, with the shared network.

        Returns:
            A 3-tuple with the embeddings of ``x1``, ``x2`` and ``x3``.
        """
        return tuple(self.embedding_net(batch) for batch in (x1, x2, x3))

    def get_embedding(self, x):
        """Embed a single input with the shared network."""
        embedding = self.embedding_net(x)
        return embedding

class TripletDataset(Dataset):
    """Dataset that yields (anchor, positive, negative) image triplets.

    Samples are discovered with ``ImageFolder`` under ``root``. For the
    anchor at a given index, a positive is drawn at random from the same
    class and a negative is drawn at random from a different class.

    Args:
        root: Root directory passed to ``ImageFolder``.
        transforms: Optional callable applied independently to each of the
            three PIL images of a triplet.
    """

    def __init__(self, root=None, transforms=None):
        img_folder = ImageFolder(root=root)
        meta_data = MetaDataset(img_folder)

        # Entries are indexed as [0] -> image path, [1] -> class label.
        self.img_list = img_folder.imgs
        self.labels_to_indices = meta_data.labels_to_indices
        self.indices_to_labels = meta_data.indices_to_labels
        # Derive the label set from the data instead of hard-coding 0..14:
        # a hard-coded list breaks negative sampling when the dataset has
        # fewer classes and silently ignores classes when it has more.
        self.labels = sorted(self.labels_to_indices)
        self.transforms = transforms

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path* and return it converted to RGB."""
        # The context manager releases the file handle once the pixel data
        # has been converted.
        with Image.open(path) as img:
            return img.convert('RGB')

    def _sample_negative_index(self, label_anchor):
        """Return a random sample index whose class differs from *label_anchor*.

        Raises:
            ValueError: If there is no class other than *label_anchor*.
        """
        candidates = [label for label in self.labels if label != label_anchor]
        if not candidates:
            raise ValueError(
                "TripletDataset needs at least two classes to sample a negative"
            )
        aux_class = random.choice(candidates)
        return random.choice(self.labels_to_indices[aux_class])

    def __getitem__(self, index):
        """Build the triplet whose anchor is the sample at *index*.

        Returns:
            ``((anchor, positive, negative), [])``. The empty list stands in
            for a target, since a triplet carries none.
        """
        path_anchor = self.img_list[index][0]
        label_anchor = self.img_list[index][1]

        # Positive: any sample of the anchor's class (may be the anchor itself).
        idx_positive = random.choice(self.labels_to_indices[label_anchor])
        path_positive = self.img_list[idx_positive][0]

        # Negative: any sample of a different class.
        idx_negative = self._sample_negative_index(label_anchor)
        path_negative = self.img_list[idx_negative][0]

        img_anchor = self._load_rgb(path_anchor)
        img_positive = self._load_rgb(path_positive)
        img_negative = self._load_rgb(path_negative)

        if self.transforms is not None:
            img_anchor = self.transforms(img_anchor)
            img_positive = self.transforms(img_positive)
            img_negative = self.transforms(img_negative)

        return (img_anchor, img_positive, img_negative), []

    def __len__(self):
        """Return the number of anchor samples (one per image)."""
        return len(self.img_list)
    
    
# create TripletNet dataloader given a root path
def get_triplet_dataloader(root=None, batch_size=1, transforms=None):
    """Return a shuffling ``DataLoader`` over a ``TripletDataset``.

    Args:
        root: Dataset root directory forwarded to ``TripletDataset``.
        batch_size: Number of triplets per batch.
        transforms: Optional image transform forwarded to ``TripletDataset``.
    """
    return DataLoader(
        TripletDataset(root=root, transforms=transforms),
        batch_size=batch_size,
        shuffle=True,
    )

def get_train_transforms():
    """Return the training pipeline: resize, random augmentation, tensor, normalise."""
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        T.Resize((224, 224)),
        # Augmentations applied on the PIL image before tensor conversion.
        T.RandomHorizontalFlip(0.5),
        T.RandomVerticalFlip(0.5),
        T.RandomApply([T.RandomRotation(10)], 0.25),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)


# get validation and test transforms
def get_val_transforms():
    """Return the deterministic pipeline: resize, tensor conversion, normalise."""
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)
    
# Use EfficientNet instead of MobileNetv2 as the feature-extraction (embedding) network.
embedding_net =EfficientNetB4() #timm.create_model('efficientnet_b0', pretrained=True)  # alternative: EfficientNet model from timm
siamese_model = TripletNet(embedding_net=embedding_net)

optimizer = torch.optim.Adam(siamese_model.parameters(), lr=1e-4)  # switched to Adam to improve performance
# NOTE(review): T_max=10 is much smaller than n_epochs=100; confirm how often the
# trainer steps the scheduler and whether this annealing period is intended.
lr_scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)  # cosine annealing of the learning rate
loss_fn = TripletLoss(1.)  # triplet loss with margin 1.0
n_epochs = 100  # number of epochs
# NOTE: despite its name, `device` is a bool (CUDA availability), not a torch.device.
device = torch.cuda.is_available()

if device:
    siamese_model.cuda()

# Load the data
#path_data = 'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
path_data = 'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
triplet_train_loader = get_triplet_dataloader(root=path_data + '/train/', batch_size=5, transforms=get_train_transforms())
triplet_val_loader = get_triplet_dataloader(root=path_data + '/val/', batch_size=5, transforms=get_val_transforms())
  

Loaded pretrained weights for efficientnet-b4


In [10]:
import sys
#sys.path.insert(0,'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
# Put the project's src directory first on the import path (presumably where trainer.py lives).
sys.path.insert(0,'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
from trainer import fit
# Train the Siamese (triplet) model
fit(triplet_train_loader, triplet_val_loader, siamese_model, loss_fn, optimizer, lr_scheduler, n_epochs, device, log_interval=10)




KeyboardInterrupt: 