# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50
from torchvision.models.vgg import vgg16_bn
from torchvision.models.densenet import densenet121
from torchvision.models.mobilenet import mobilenet_v2
from efficientnet_pytorch import EfficientNet
from transformers import ViTModel, ViTFeatureExtractor
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from learn2learn.data import MetaDataset
import random
from PIL import Image
from torch.optim.lr_scheduler import CosineAnnealingLR

class PatchEmbedding(nn.Module):
    """Embed non-overlapping square image patches with a small MLP.

    The image is cut into a grid of ``patch_size x patch_size`` patches
    (stride equals ``patch_size``, so patches do not overlap). Each patch is
    flattened channel-first to ``in_channels * patch_size * patch_size`` values
    and mapped to ``embed_dim`` by Linear -> ReLU -> Linear -> LayerNorm. The
    grid layout is preserved, giving one embedding per patch position.

    Args:
        patch_size: Side length of each square patch, in pixels.
        embed_dim: Size of each output patch embedding.
        in_channels: Number of input image channels (previously hard-coded to 3).
    """

    def __init__(self, patch_size=16, embed_dim=128, in_channels=3):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.in_channels = in_channels
        self.embedding_network = nn.Sequential(
            nn.Linear(patch_size * patch_size * in_channels, 256),
            nn.ReLU(),
            nn.Linear(256, embed_dim),
            nn.LayerNorm(embed_dim),
        )

    def forward(self, x):
        """Return patch embeddings of shape (B, H // patch_size, W // patch_size, embed_dim).

        Args:
            x: Image batch of shape (B, C, H, W) with ``C == in_channels``.
                Trailing rows/columns that do not fill a whole patch are
                dropped (``Tensor.unfold`` discards the remainder).

        Raises:
            ValueError: If the channel count does not match ``in_channels``.
        """
        batch, channels, _, _ = x.shape
        if channels != self.in_channels:
            raise ValueError(
                f"PatchEmbedding expects {self.in_channels} input channels, got {channels}"
            )
        p = self.patch_size
        # (B, C, H//p, W//p, p, p): non-overlapping windows along H then W
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # (B, H//p, W//p, C, p, p): move the grid axes in front of the patch content
        patches = patches.permute(0, 2, 3, 1, 4, 5)
        grid_h, grid_w = patches.shape[1], patches.shape[2]

        # Explicit feature size (not -1) so an empty batch still reshapes cleanly
        patch_vectors = patches.reshape(batch, grid_h, grid_w, channels * p * p)
        return self.embedding_network(patch_vectors)

class EfficientNetPatchB4(nn.Module):
    """Per-patch, L2-normalised embedding network.

    NOTE(review): despite the name, ``forward`` never calls ``self.efficientnet``;
    the embeddings come solely from ``PatchEmbedding``. The pretrained
    EfficientNet-B4 is still loaded by default so existing code and
    ``state_dict`` keys keep working, but it does not affect the output.
    Pass ``load_backbone=False`` to skip downloading and holding those
    unused weights.

    Args:
        patch_size: Side length of each square patch, in pixels.
        embed_dim: Size of each patch embedding.
        load_backbone: Whether to load the (currently unused) pretrained
            EfficientNet-B4 into ``self.efficientnet``; ``None`` is stored otherwise.
    """

    def __init__(self, patch_size=16, embed_dim=128, load_backbone=True):
        super().__init__()
        if load_backbone:
            self.efficientnet = EfficientNet.from_pretrained('efficientnet-b4')
        else:
            self.efficientnet = None
        self.patch_embed = PatchEmbedding(patch_size, embed_dim)

    def forward(self, x):
        """Return unit-L2-norm patch embeddings.

        Output shape is (B, H // patch_size, W // patch_size, embed_dim), i.e.
        (B, 14, 14, embed_dim) for 224x224 inputs with ``patch_size=16``.
        """
        patch_embeddings = self.patch_embed(x)
        # Normalise each patch vector so squared distances lie in [0, 4]
        return F.normalize(patch_embeddings, p=2, dim=-1)

class TripletNet(nn.Module):
    """Run one shared embedding network over each member of a triplet.

    Args:
        embedding_net: Module mapping an input batch to embeddings. The same
            instance (and therefore the same weights) embeds the anchor,
            positive and negative.
    """

    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        """Return ``(embed(x1), embed(x2), embed(x3))``, evaluated in that order."""
        return tuple(self.embedding_net(item) for item in (x1, x2, x3))

    def get_embedding(self, x):
        """Embed a single input batch with the shared network."""
        embed = self.embedding_net
        return embed(x)

class SpatialTripletLoss(nn.Module):
    """Hinge triplet loss evaluated independently at every position.

    Distances are squared Euclidean over the last (embedding) dimension::

        loss = mean(max(0, ||a - p||^2 - ||a - n||^2 + margin))

    and the mean is taken over all remaining dimensions (batch plus any
    spatial axes).

    Args:
        margin: Required gap between negative and positive squared distances.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        def squared_gap(u, v):
            # Reduce only the embedding axis; keep every other axis per-position
            return (u - v).pow(2).sum(dim=-1)

        gap = squared_gap(anchor, positive) - squared_gap(anchor, negative)
        return F.relu(gap + self.margin).mean()

class TripletDataset(Dataset):
    """Dataset yielding randomly sampled (anchor, positive, negative) image triplets.

    Images are read from an ``ImageFolder``-style directory. For item ``index``:
    the anchor is that image; the positive is a different image of the same
    class (the anchor itself only if its class has a single image); the
    negative is a random image from a randomly chosen other class. Sampling
    uses the ``random`` module, so repeated calls give different triplets.

    Args:
        root: Root directory passed to ``ImageFolder``.
        transforms: Optional callable applied to each loaded RGB PIL image.

    Raises:
        ValueError: If the folder has fewer than two classes, so no negative
            can be sampled.
    """

    def __init__(self, root=None, transforms=None):
        img_folder = ImageFolder(root=root)
        meta_data = MetaDataset(img_folder)

        # (image path, class index) pairs, indexed the same way as the MetaDataset
        self.img_list = img_folder.imgs
        self.labels_to_indices = meta_data.labels_to_indices
        self.indices_to_labels = meta_data.indices_to_labels
        # Derive the class set from the data instead of hard-coding 0..14, so
        # folders with any number of classes work.
        self.labels = sorted(self.labels_to_indices)
        if len(self.labels) < 2:
            raise ValueError(
                f"TripletDataset needs at least two classes to sample negatives, "
                f"found {len(self.labels)} in {root!r}"
            )
        self.transforms = transforms

    def _sample_triplet_indices(self, index):
        """Return ``(positive_index, negative_index)`` for the anchor at ``index``."""
        label_anchor = self.img_list[index][1]

        # Exclude the anchor so the positive pair is informative when possible
        same_class = [i for i in self.labels_to_indices[label_anchor] if i != index]
        idx_positive = random.choice(same_class) if same_class else index

        other_labels = [label for label in self.labels if label != label_anchor]
        aux_class = random.choice(other_labels)
        idx_negative = random.choice(self.labels_to_indices[aux_class])
        return idx_positive, idx_negative

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB image and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        idx_positive, idx_negative = self._sample_triplet_indices(index)

        images = [
            self._load_rgb(self.img_list[i][0])
            for i in (index, idx_positive, idx_negative)
        ]
        if self.transforms is not None:
            images = [self.transforms(img) for img in images]

        img_anchor, img_positive, img_negative = images
        # The empty list is a placeholder target; the training loop ignores it.
        return (img_anchor, img_positive, img_negative), []

    def __len__(self):
        return len(self.img_list)

def get_triplet_dataloader(root=None, batch_size=1, transforms=None):
    """Build a shuffling ``DataLoader`` over a ``TripletDataset`` rooted at ``root``.

    Args:
        root: Image-folder root directory.
        batch_size: Number of triplets per batch.
        transforms: Optional per-image transform forwarded to the dataset.

    Returns:
        A ``DataLoader`` yielding ``((anchors, positives, negatives), target)``
        batches, where ``target`` is the dataset's empty placeholder.
    """
    return DataLoader(
        TripletDataset(root=root, transforms=transforms),
        batch_size=batch_size,
        shuffle=True,
    )

def get_train_transforms():
    """Return the training-time preprocessing and augmentation pipeline.

    Resizes to 224x224, applies random horizontal and vertical flips (p=0.5
    each), a random rotation of up to 10 degrees (applied with p=0.25) and
    colour jitter, then converts to a tensor normalised with the standard
    ImageNet per-channel mean/std.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(0.5),
        T.RandomVerticalFlip(0.5),
        T.RandomApply([T.RandomRotation(10)], 0.25),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)

def get_val_transforms():
    """Return the deterministic evaluation preprocessing pipeline.

    Resizes to 224x224 and converts to a tensor normalised with the standard
    ImageNet per-channel mean/std; no random augmentation is applied.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return T.Compose(
        [
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(channel_mean, channel_std),
        ]
    )

def fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler, n_epochs, device, log_interval=100):
    """Train a triplet model, validating after every epoch.

    Every batch from either loader must be ``((anchor, positive, negative), _)``;
    the second element is ignored. ``model(anchor, positive, negative)`` must
    return a sequence that is unpacked into ``loss_fn``.

    Args:
        train_loader: Iterable of training batches with a ``.dataset``
            attribute (used only for progress logging), e.g. a ``DataLoader``.
        val_loader: Iterable of validation batches. If it yields nothing,
            validation is skipped for that epoch instead of dividing by zero.
        model: Module called as ``model(anchor, positive, negative)``; assumed
            to already live on ``device``.
        loss_fn: Callable mapping the model outputs to a scalar loss tensor.
        optimizer: Optimizer over the model's parameters.
        scheduler: LR scheduler stepped once per epoch, or ``None``.
        n_epochs: Number of passes over ``train_loader``.
        device: Device each batch tensor is moved to.
        log_interval: Print the running training loss every this many batches.

    Returns:
        One dict per epoch with keys ``'epoch'``, ``'train_loss'`` (mean batch
        loss, NaN if there were no batches) and ``'val_loss'`` (mean
        validation batch loss, or ``None`` if ``val_loader`` was empty).
    """
    history = []
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0

        for batch_idx, (data, _) in enumerate(train_loader):
            img_a, img_p, img_n = (t.to(device) for t in data)

            optimizer.zero_grad()
            embeddings = model(img_a, img_p, img_n)
            loss = loss_fn(*embeddings)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

            if batch_idx % log_interval == 0:
                print(f'Epoch {epoch}: [{batch_idx * len(img_a)}/{len(train_loader.dataset)}] '
                      f'Loss: {total_loss / (batch_idx + 1):.6f}')

        if scheduler is not None:
            scheduler.step()

        train_loss = total_loss / n_batches if n_batches else float('nan')

        # Validation: gradients disabled, model in eval mode
        model.eval()
        val_total = 0.0
        val_batches = 0
        with torch.no_grad():
            for data, _ in val_loader:
                img_a, img_p, img_n = (t.to(device) for t in data)
                embeddings = model(img_a, img_p, img_n)
                val_total += loss_fn(*embeddings).item()
                val_batches += 1

        if val_batches:
            val_loss = val_total / val_batches
            print(f'Epoch {epoch}: Validation Loss: {val_loss:.6f}')
        else:
            val_loss = None
            print(f'Epoch {epoch}: Validation skipped (empty val_loader)')

        history.append({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss})
    return history

# Initialize model and training
path_data = 'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
n_epochs = 100

# Setup device first so the model is on its final device before the optimizer is built
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

embedding_net = EfficientNetPatchB4(patch_size=16, embed_dim=128)
triplet_model = TripletNet(embedding_net=embedding_net).to(device)
loss_fn = SpatialTripletLoss(margin=1.0)
optimizer = torch.optim.Adam(triplet_model.parameters(), lr=1e-4)
# fit() steps the scheduler once per epoch, so T_max = n_epochs gives a single
# cosine decay over the whole run. The previous T_max=10 with 100 epochs made
# the LR reach eta_min at epoch 10 and then climb back up, cycling repeatedly.
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=1e-6)

# Create dataloaders
triplet_train_loader = get_triplet_dataloader(
    root=path_data + '/train/',
    batch_size=5,
    transforms=get_train_transforms()
)
triplet_val_loader = get_triplet_dataloader(
    root=path_data + '/val/',
    batch_size=5,
    transforms=get_val_transforms()
)

# Train
fit(triplet_train_loader, triplet_val_loader, triplet_model, loss_fn, optimizer, scheduler, n_epochs, device)