In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50
from torchvision.models.vgg import vgg16_bn
from torchvision.models.densenet import densenet121
from torchvision.models.mobilenet import mobilenet_v2
from efficientnet_pytorch import EfficientNet
from transformers import ViTModel, ViTFeatureExtractor
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from learn2learn.data import MetaDataset
import random
from PIL import Image
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn


class PatchEmbedding(nn.Module):
    def __init__(self, patch_size=16, embed_dim=128):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.embedding_network = nn.Sequential(
            nn.Linear(patch_size * patch_size * 3, 256),
            nn.ReLU(),
            nn.Linear(256, embed_dim),
            nn.LayerNorm(embed_dim)
        )
        
    def forward(self, x):
        B, C, H, W = x.shape
        # Split image into patches
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.permute(0, 2, 3, 1, 4, 5)  # (B, H/16, W/16, C, patch_size, patch_size)
        
        # Flatten patches while keeping spatial information
        patch_vectors = patches.reshape(B, H//self.patch_size, W//self.patch_size, -1)
        
        # Create embeddings for each patch
        embeddings = self.embedding_network(patch_vectors)  # (B, 14, 14, embed_dim)
        
        return embeddings

class EfficientNetPatchB4(nn.Module):
    def __init__(self, patch_size=16, embed_dim=128):
        super().__init__()
        self.efficientnet = EfficientNet.from_pretrained('efficientnet-b4')
        self.patch_embed = PatchEmbedding(patch_size, embed_dim)
        
    def forward(self, x):
        # Generate patch embeddings with spatial information preserved
        patch_embeddings = self.patch_embed(x)  # (batch_size, 14, 14, embed_dim)
        patch_embeddings = F.normalize(patch_embeddings, p=2, dim=-1)
        return patch_embeddings

class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        output1 = self.embedding_net(x1)
        output2 = self.embedding_net(x2)
        output3 = self.embedding_net(x3)
        return output1, output2, output3

    def get_embedding(self, x):
        return self.embedding_net(x)

class SpatialTripletLoss(nn.Module):
    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = (anchor - positive).pow(2).sum(dim=-1)  # Sum over embedding dimension
        distance_negative = (anchor - negative).pow(2).sum(dim=-1)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()

class TripletDataset(Dataset):
    def __init__(self, root=None, transforms=None):
        img_folder = ImageFolder(root=root)
        meta_data = MetaDataset(img_folder)
        
        self.img_list = img_folder.imgs
        self.labels_to_indices = meta_data.labels_to_indices
        self.indices_to_labels = meta_data.indices_to_labels
        self.labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
        self.transforms = transforms

    def __getitem__(self, index):
        img_anchor = self.img_list[index][0]
        label_anchor = self.img_list[index][1]

        idx_positive = random.choice(self.labels_to_indices[label_anchor])            
        img_positive = self.img_list[idx_positive][0]
        
        aux_class = random.choice(list(set(self.labels) - set([label_anchor])))
        idx_negative = random.choice(self.labels_to_indices[aux_class])
        img_negative = self.img_list[idx_negative][0]

        img_anchor = Image.open(img_anchor).convert('RGB')
        img_positive = Image.open(img_positive).convert('RGB')
        img_negative = Image.open(img_negative).convert('RGB')

        if self.transforms is not None:
            img_anchor = self.transforms(img_anchor)
            img_positive = self.transforms(img_positive)
            img_negative = self.transforms(img_negative)

        return (img_anchor, img_positive, img_negative), []

    def __len__(self):
        return len(self.img_list)

def get_triplet_dataloader(root=None, batch_size=1, transforms=None):
    dataset = TripletDataset(root=root, transforms=transforms)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

def get_train_transforms():
    return T.Compose([
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(0.5),
        T.RandomVerticalFlip(0.5),
        T.RandomApply([T.RandomRotation(10)], 0.25),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

def get_val_transforms():
    return T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

def fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler, n_epochs, device, log_interval=100):
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            img_a, img_p, img_n = data
            img_a, img_p, img_n = img_a.to(device), img_p.to(device), img_n.to(device)
            
            optimizer.zero_grad()
            embeddings = model(img_a, img_p, img_n)
            loss = loss_fn(*embeddings)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % log_interval == 0:
                print(f'Epoch {epoch}: [{batch_idx * len(img_a)}/{len(train_loader.dataset)}] '
                      f'Loss: {total_loss / (batch_idx + 1):.6f}')
        
        scheduler.step()
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for data, _ in val_loader:
                img_a, img_p, img_n = data
                img_a, img_p, img_n = img_a.to(device), img_p.to(device), img_n.to(device)
                embeddings = model(img_a, img_p, img_n)
                val_loss += loss_fn(*embeddings).item()
        
        val_loss /= len(val_loader)
        print(f'Epoch {epoch}: Validation Loss: {val_loss:.6f}')

# Initialize model and training
path_data = 'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
embedding_net = EfficientNetPatchB4(patch_size=16, embed_dim=128)
triplet_model = TripletNet(embedding_net=embedding_net)
loss_fn = SpatialTripletLoss(margin=1.0)
optimizer = torch.optim.Adam(triplet_model.parameters(), lr=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    triplet_model = triplet_model.cuda()

# Create dataloaders
triplet_train_loader = get_triplet_dataloader(
    root=path_data + '/train/',
    batch_size=5,
    transforms=get_train_transforms()
)
triplet_val_loader = get_triplet_dataloader(
    root=path_data + '/val/',
    batch_size=5,
    transforms=get_val_transforms()
)

# Train
n_epochs = 1
fit(triplet_train_loader, triplet_val_loader, triplet_model, loss_fn, optimizer, scheduler, n_epochs, device)
torch.save(triplet_model, "C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/Patch_emdebbing_generator_triplet_model.h5" )

C:\Users\Mey\AppData\Roaming\Python\Python39\site-packages\numpy\.libs\libopenblas.XWYDX2IKJW2NMTWSFYNGFUWKQU3LYTCZ.gfortran-win_amd64.dll
C:\Users\Mey\AppData\Roaming\Python\Python39\site-packages\numpy\.libs\libopenblas64__v0.3.21-gcc_10_3_0.dll
  warn(


Loaded pretrained weights for efficientnet-b4
Epoch 0: [0/225] Loss: 1.078055
Epoch 0: Validation Loss: 0.928408


In [None]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from datetime import datetime
from torch.utils.data import DataLoader
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys
sys.path.insert(0,'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
#sys.path.insert(0,'f:/Meysam-Khodarahi/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
from dataloaders import get_val_transforms
import numpy as np
from sklearn.manifold import TSNE
import matplotlib as mpl
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import transforms
from torch.autograd import Variable
import os
import pandas as pd
import seaborn as sns
import dataloaders
from dataloaders import get_train_transforms, get_val_transforms, get_triplet_dataloader
from transformers import ViTForImageClassification, ViTFeatureExtractor
import torch
from sklearn.metrics import accuracy_score, f1_score , precision_score , recall_score
from sklearn.manifold import TSNE
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from datetime import datetime

class RefinedViT(nn.Module):
    def __init__(self, num_classes=15):
        super(RefinedViT, self).__init__()
        self.num_patches = 196  # 14x14
        self.embed_dim = 128    # Match previous network
        self.num_heads = 8      
        
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, self.embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=10)
        
        self.norm = nn.LayerNorm(self.embed_dim)
        self.fc = nn.Linear(self.embed_dim, num_classes)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.reshape(batch_size, 196, -1)
        
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        
        x = self.transformer_encoder(x)
        x = x[:, 0]
        x = self.norm(x)
        x = self.dropout(x)
        return self.fc(x)

def generate_embeddings(data_loader, model, device):
    model.eval()
    all_embeddings = []
    all_labels = []
    
    with torch.no_grad():
        for batch_imgs, batch_labels in tqdm(data_loader):
            batch_imgs = batch_imgs.to(device)
            embeddings = model.get_embedding(batch_imgs)  # Using get_embedding from triplet model
            print(f"Size of embeddings: {embeddings.shape}")
            all_embeddings.append(embeddings.cpu().numpy())
            all_labels.append(batch_labels.numpy())
    
    return np.concatenate(all_embeddings), np.concatenate(all_labels)

def train_and_validate(model, train_embeddings, train_labels, val_embeddings, val_labels, 
                      num_epochs=100, batch_size=32, learning_rate=1e-4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    X_train = torch.tensor(train_embeddings, dtype=torch.float32).to(device)
    y_train = torch.tensor(train_labels, dtype=torch.long).to(device)
    X_val = torch.tensor(val_embeddings, dtype=torch.float32).to(device)
    y_val = torch.tensor(val_labels, dtype=torch.long).to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    
    for epoch in tqdm(range(num_epochs)):
        model.train()
        total_train_loss = 0
        
        for i in range(0, len(X_train), batch_size):
            batch_X = X_train[i:i+batch_size]
            batch_y = y_train[i:i+batch_size]
            
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
        
        model.eval()
        with torch.no_grad():
            val_outputs = model(X_val)
            val_loss = criterion(val_outputs, y_val).item()
            val_acc = (torch.argmax(val_outputs, dim=1) == y_val).float().mean().item()
            
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), 'best_model_refined_vit.pth')
        
        scheduler.step()
        
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}: Train Loss={total_train_loss/len(X_train):.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}')
    
    return model

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    path_data =  'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
    
    # Load triplet model
    #triplet_model = torch.load('C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/Patch_emdebbing_generator_triplet_model.h5', map_location=device)
    triplet_model.eval()
    
    # Data loaders
    transform = get_val_transforms()
    train_data = torchvision.datasets.ImageFolder(root=f'{path_data}/train/', transform=transform)
    val_data = torchvision.datasets.ImageFolder(root=f'{path_data}/val/', transform=transform)
    test_data = torchvision.datasets.ImageFolder(root=f'{path_data}/test/', transform=transform)
    
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    test_loader = DataLoader(test_data, batch_size=32)
    
    # Generate embeddings
    train_embeddings, train_labels = generate_embeddings(train_loader, triplet_model, device)
    val_embeddings, val_labels = generate_embeddings(val_loader, triplet_model, device)
    test_embeddings, test_labels = generate_embeddings(test_loader, triplet_model, device)
    
    # Train RefinedViT
    model = RefinedViT(num_classes=15)
    trained_model = train_and_validate(model, train_embeddings, train_labels, val_embeddings, val_labels)
    
    return trained_model, (test_embeddings, test_labels)


model, test_data = main()    

 12%|██████████▌                                                                         | 1/8 [00:00<00:02,  2.60it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 25%|█████████████████████                                                               | 2/8 [00:00<00:02,  2.64it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 38%|███████████████████████████████▌                                                    | 3/8 [00:01<00:01,  2.52it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 50%|██████████████████████████████████████████                                          | 4/8 [00:01<00:01,  2.61it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 62%|████████████████████████████████████████████████████▌                               | 5/8 [00:01<00:01,  2.70it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 75%|███████████████████████████████████████████████████████████████                     | 6/8 [00:02<00:00,  2.63it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  2.92it/s]


Size of embeddings: torch.Size([32, 14, 14, 128])
Size of embeddings: torch.Size([1, 14, 14, 128])


  2%|█▎                                                                                 | 1/65 [00:00<00:25,  2.51it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  3%|██▌                                                                                | 2/65 [00:00<00:24,  2.56it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  5%|███▊                                                                               | 3/65 [00:01<00:24,  2.49it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  6%|█████                                                                              | 4/65 [00:01<00:25,  2.42it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  8%|██████▍                                                                            | 5/65 [00:02<00:24,  2.45it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  9%|███████▋                                                                           | 6/65 [00:02<00:23,  2.55it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 11%|████████▉                                                                          | 7/65 [00:02<00:22,  2.60it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 12%|██████████▏                                                                        | 8/65 [00:03<00:22,  2.54it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 14%|███████████▍                                                                       | 9/65 [00:03<00:21,  2.58it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 15%|████████████▌                                                                     | 10/65 [00:03<00:22,  2.49it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 17%|█████████████▉                                                                    | 11/65 [00:04<00:21,  2.55it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 18%|███████████████▏                                                                  | 12/65 [00:04<00:21,  2.45it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 20%|████████████████▍                                                                 | 13/65 [00:05<00:22,  2.31it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 22%|█████████████████▋                                                                | 14/65 [00:05<00:23,  2.19it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 23%|██████████████████▉                                                               | 15/65 [00:06<00:23,  2.12it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 25%|████████████████████▏                                                             | 16/65 [00:06<00:24,  2.03it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 26%|█████████████████████▍                                                            | 17/65 [00:07<00:24,  1.97it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 28%|██████████████████████▋                                                           | 18/65 [00:07<00:23,  2.01it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 29%|███████████████████████▉                                                          | 19/65 [00:08<00:22,  2.08it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 31%|█████████████████████████▏                                                        | 20/65 [00:08<00:21,  2.12it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 32%|██████████████████████████▍                                                       | 21/65 [00:09<00:20,  2.18it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 34%|███████████████████████████▊                                                      | 22/65 [00:09<00:18,  2.30it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 35%|█████████████████████████████                                                     | 23/65 [00:09<00:17,  2.34it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 37%|██████████████████████████████▎                                                   | 24/65 [00:10<00:17,  2.39it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 38%|███████████████████████████████▌                                                  | 25/65 [00:10<00:16,  2.42it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 40%|████████████████████████████████▊                                                 | 26/65 [00:11<00:15,  2.56it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 42%|██████████████████████████████████                                                | 27/65 [00:11<00:14,  2.60it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 43%|███████████████████████████████████▎                                              | 28/65 [00:11<00:13,  2.69it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 45%|████████████████████████████████████▌                                             | 29/65 [00:12<00:14,  2.48it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 46%|█████████████████████████████████████▊                                            | 30/65 [00:12<00:13,  2.54it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 48%|███████████████████████████████████████                                           | 31/65 [00:12<00:12,  2.66it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 49%|████████████████████████████████████████▎                                         | 32/65 [00:13<00:12,  2.68it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 51%|█████████████████████████████████████████▋                                        | 33/65 [00:13<00:12,  2.60it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 52%|██████████████████████████████████████████▉                                       | 34/65 [00:14<00:12,  2.58it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 54%|████████████████████████████████████████████▏                                     | 35/65 [00:14<00:11,  2.67it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 55%|█████████████████████████████████████████████▍                                    | 36/65 [00:14<00:10,  2.71it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 57%|██████████████████████████████████████████████▋                                   | 37/65 [00:15<00:10,  2.75it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 58%|███████████████████████████████████████████████▉                                  | 38/65 [00:15<00:09,  2.73it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 60%|█████████████████████████████████████████████████▏                                | 39/65 [00:15<00:09,  2.73it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 62%|██████████████████████████████████████████████████▍                               | 40/65 [00:16<00:09,  2.78it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 63%|███████████████████████████████████████████████████▋                              | 41/65 [00:16<00:08,  2.71it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 65%|████████████████████████████████████████████████████▉                             | 42/65 [00:17<00:08,  2.75it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 66%|██████████████████████████████████████████████████████▏                           | 43/65 [00:17<00:08,  2.50it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 68%|███████████████████████████████████████████████████████▌                          | 44/65 [00:17<00:08,  2.53it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 69%|████████████████████████████████████████████████████████▊                         | 45/65 [00:18<00:07,  2.51it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 71%|██████████████████████████████████████████████████████████                        | 46/65 [00:18<00:07,  2.43it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 72%|███████████████████████████████████████████████████████████▎                      | 47/65 [00:19<00:07,  2.37it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 74%|████████████████████████████████████████████████████████████▌                     | 48/65 [00:19<00:06,  2.49it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 75%|█████████████████████████████████████████████████████████████▊                    | 49/65 [00:19<00:06,  2.50it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 77%|███████████████████████████████████████████████████████████████                   | 50/65 [00:20<00:05,  2.58it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 78%|████████████████████████████████████████████████████████████████▎                 | 51/65 [00:20<00:05,  2.64it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 80%|█████████████████████████████████████████████████████████████████▌                | 52/65 [00:21<00:04,  2.75it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 82%|██████████████████████████████████████████████████████████████████▊               | 53/65 [00:21<00:04,  2.84it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 83%|████████████████████████████████████████████████████████████████████              | 54/65 [00:21<00:04,  2.56it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 85%|█████████████████████████████████████████████████████████████████████▍            | 55/65 [00:22<00:03,  2.66it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 86%|██████████████████████████████████████████████████████████████████████▋           | 56/65 [00:22<00:03,  2.54it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 88%|███████████████████████████████████████████████████████████████████████▉          | 57/65 [00:23<00:03,  2.39it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 89%|█████████████████████████████████████████████████████████████████████████▏        | 58/65 [00:23<00:03,  2.29it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 91%|██████████████████████████████████████████████████████████████████████████▍       | 59/65 [00:23<00:02,  2.33it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 92%|███████████████████████████████████████████████████████████████████████████▋      | 60/65 [00:24<00:02,  2.47it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 94%|████████████████████████████████████████████████████████████████████████████▉     | 61/65 [00:24<00:01,  2.43it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 95%|██████████████████████████████████████████████████████████████████████████████▏   | 62/65 [00:25<00:01,  2.56it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 97%|███████████████████████████████████████████████████████████████████████████████▍  | 63/65 [00:25<00:00,  2.50it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 98%|████████████████████████████████████████████████████████████████████████████████▋ | 64/65 [00:25<00:00,  2.53it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


100%|██████████████████████████████████████████████████████████████████████████████████| 65/65 [00:26<00:00,  2.48it/s]


Size of embeddings: torch.Size([28, 14, 14, 128])


  1%|▋                                                                                 | 1/129 [00:00<00:41,  3.12it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  2%|█▎                                                                                | 2/129 [00:00<00:44,  2.86it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  2%|█▉                                                                                | 3/129 [00:01<00:47,  2.63it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  3%|██▌                                                                               | 4/129 [00:01<00:45,  2.77it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  4%|███▏                                                                              | 5/129 [00:01<00:47,  2.59it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  5%|███▊                                                                              | 6/129 [00:02<00:50,  2.45it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  5%|████▍                                                                             | 7/129 [00:02<00:54,  2.23it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  6%|█████                                                                             | 8/129 [00:03<00:55,  2.18it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  7%|█████▋                                                                            | 9/129 [00:03<00:53,  2.23it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  8%|██████▎                                                                          | 10/129 [00:04<00:54,  2.19it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  9%|██████▉                                                                          | 11/129 [00:04<00:52,  2.25it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


  9%|███████▌                                                                         | 12/129 [00:05<00:52,  2.22it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 10%|████████▏                                                                        | 13/129 [00:05<00:54,  2.12it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 11%|████████▊                                                                        | 14/129 [00:06<00:53,  2.13it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 12%|█████████▍                                                                       | 15/129 [00:06<00:51,  2.23it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 12%|██████████                                                                       | 16/129 [00:06<00:51,  2.18it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 13%|██████████▋                                                                      | 17/129 [00:07<00:54,  2.07it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 14%|███████████▎                                                                     | 18/129 [00:07<00:53,  2.08it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 15%|███████████▉                                                                     | 19/129 [00:08<00:48,  2.27it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 16%|████████████▌                                                                    | 20/129 [00:08<00:49,  2.21it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 16%|█████████████▏                                                                   | 21/129 [00:09<00:51,  2.09it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 17%|█████████████▊                                                                   | 22/129 [00:09<00:48,  2.22it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 18%|██████████████▍                                                                  | 23/129 [00:10<00:45,  2.31it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 19%|███████████████                                                                  | 24/129 [00:10<00:42,  2.45it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 19%|███████████████▋                                                                 | 25/129 [00:10<00:39,  2.60it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 20%|████████████████▎                                                                | 26/129 [00:11<00:40,  2.52it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 21%|████████████████▉                                                                | 27/129 [00:11<00:41,  2.45it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 22%|█████████████████▌                                                               | 28/129 [00:12<00:42,  2.37it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 22%|██████████████████▏                                                              | 29/129 [00:12<00:40,  2.46it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 23%|██████████████████▊                                                              | 30/129 [00:12<00:38,  2.57it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 24%|███████████████████▍                                                             | 31/129 [00:13<00:37,  2.58it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 25%|████████████████████                                                             | 32/129 [00:13<00:36,  2.68it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 26%|████████████████████▋                                                            | 33/129 [00:13<00:35,  2.71it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 26%|█████████████████████▎                                                           | 34/129 [00:14<00:34,  2.74it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 27%|█████████████████████▉                                                           | 35/129 [00:14<00:36,  2.60it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 28%|██████████████████████▌                                                          | 36/129 [00:15<00:39,  2.37it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 29%|███████████████████████▏                                                         | 37/129 [00:15<00:40,  2.26it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 29%|███████████████████████▊                                                         | 38/129 [00:16<00:42,  2.16it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 30%|████████████████████████▍                                                        | 39/129 [00:16<00:41,  2.17it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 31%|█████████████████████████                                                        | 40/129 [00:17<00:38,  2.29it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 32%|█████████████████████████▋                                                       | 41/129 [00:17<00:38,  2.31it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 33%|██████████████████████████▎                                                      | 42/129 [00:17<00:37,  2.32it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 33%|███████████████████████████                                                      | 43/129 [00:18<00:36,  2.34it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 34%|███████████████████████████▋                                                     | 44/129 [00:18<00:34,  2.45it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 35%|████████████████████████████▎                                                    | 45/129 [00:19<00:32,  2.57it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 36%|████████████████████████████▉                                                    | 46/129 [00:19<00:31,  2.63it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 36%|█████████████████████████████▌                                                   | 47/129 [00:19<00:30,  2.71it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 37%|██████████████████████████████▏                                                  | 48/129 [00:20<00:29,  2.76it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 38%|██████████████████████████████▊                                                  | 49/129 [00:20<00:28,  2.84it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 39%|███████████████████████████████▍                                                 | 50/129 [00:20<00:27,  2.84it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 40%|████████████████████████████████                                                 | 51/129 [00:21<00:27,  2.83it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 40%|████████████████████████████████▋                                                | 52/129 [00:21<00:29,  2.58it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 41%|█████████████████████████████████▎                                               | 53/129 [00:22<00:30,  2.45it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 42%|█████████████████████████████████▉                                               | 54/129 [00:22<00:32,  2.32it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 43%|██████████████████████████████████▌                                              | 55/129 [00:23<00:33,  2.18it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 43%|███████████████████████████████████▏                                             | 56/129 [00:23<00:32,  2.28it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 44%|███████████████████████████████████▊                                             | 57/129 [00:23<00:33,  2.15it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 45%|████████████████████████████████████▍                                            | 58/129 [00:24<00:30,  2.34it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 46%|█████████████████████████████████████                                            | 59/129 [00:24<00:28,  2.48it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 47%|█████████████████████████████████████▋                                           | 60/129 [00:25<00:27,  2.51it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 47%|██████████████████████████████████████▎                                          | 61/129 [00:25<00:29,  2.29it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 48%|██████████████████████████████████████▉                                          | 62/129 [00:26<00:29,  2.30it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 49%|███████████████████████████████████████▌                                         | 63/129 [00:26<00:28,  2.28it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 50%|████████████████████████████████████████▏                                        | 64/129 [00:26<00:29,  2.20it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 50%|████████████████████████████████████████▊                                        | 65/129 [00:27<00:27,  2.33it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 51%|█████████████████████████████████████████▍                                       | 66/129 [00:27<00:26,  2.35it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 52%|██████████████████████████████████████████                                       | 67/129 [00:28<00:28,  2.16it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 53%|██████████████████████████████████████████▋                                      | 68/129 [00:28<00:27,  2.21it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 53%|███████████████████████████████████████████▎                                     | 69/129 [00:29<00:25,  2.35it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 54%|███████████████████████████████████████████▉                                     | 70/129 [00:29<00:23,  2.48it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 55%|████████████████████████████████████████████▌                                    | 71/129 [00:29<00:22,  2.61it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 56%|█████████████████████████████████████████████▏                                   | 72/129 [00:30<00:24,  2.37it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 57%|█████████████████████████████████████████████▊                                   | 73/129 [00:30<00:24,  2.29it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 57%|██████████████████████████████████████████████▍                                  | 74/129 [00:31<00:33,  1.62it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 58%|███████████████████████████████████████████████                                  | 75/129 [00:32<00:30,  1.78it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 59%|███████████████████████████████████████████████▋                                 | 76/129 [00:32<00:28,  1.84it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 60%|████████████████████████████████████████████████▎                                | 77/129 [00:33<00:26,  1.95it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 60%|████████████████████████████████████████████████▉                                | 78/129 [00:33<00:25,  2.01it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 61%|█████████████████████████████████████████████████▌                               | 79/129 [00:34<00:23,  2.12it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 62%|██████████████████████████████████████████████████▏                              | 80/129 [00:34<00:21,  2.31it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 63%|██████████████████████████████████████████████████▊                              | 81/129 [00:34<00:20,  2.38it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 64%|███████████████████████████████████████████████████▍                             | 82/129 [00:35<00:18,  2.52it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 64%|████████████████████████████████████████████████████                             | 83/129 [00:35<00:17,  2.64it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 65%|████████████████████████████████████████████████████▋                            | 84/129 [00:35<00:16,  2.72it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 66%|█████████████████████████████████████████████████████▎                           | 85/129 [00:36<00:16,  2.67it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 67%|██████████████████████████████████████████████████████                           | 86/129 [00:36<00:16,  2.65it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 67%|██████████████████████████████████████████████████████▋                          | 87/129 [00:36<00:15,  2.72it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 68%|███████████████████████████████████████████████████████▎                         | 88/129 [00:37<00:15,  2.72it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 69%|███████████████████████████████████████████████████████▉                         | 89/129 [00:37<00:14,  2.79it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 70%|████████████████████████████████████████████████████████▌                        | 90/129 [00:37<00:13,  2.79it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 71%|█████████████████████████████████████████████████████████▏                       | 91/129 [00:38<00:14,  2.69it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 71%|█████████████████████████████████████████████████████████▊                       | 92/129 [00:38<00:14,  2.58it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 72%|██████████████████████████████████████████████████████████▍                      | 93/129 [00:39<00:13,  2.60it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 73%|███████████████████████████████████████████████████████████                      | 94/129 [00:39<00:15,  2.31it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 74%|███████████████████████████████████████████████████████████▋                     | 95/129 [00:40<00:14,  2.29it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 74%|████████████████████████████████████████████████████████████▎                    | 96/129 [00:40<00:14,  2.34it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 75%|████████████████████████████████████████████████████████████▉                    | 97/129 [00:40<00:13,  2.40it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 76%|█████████████████████████████████████████████████████████████▌                   | 98/129 [00:41<00:12,  2.40it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 77%|██████████████████████████████████████████████████████████████▏                  | 99/129 [00:41<00:12,  2.47it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 78%|██████████████████████████████████████████████████████████████                  | 100/129 [00:42<00:11,  2.56it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 78%|██████████████████████████████████████████████████████████████▋                 | 101/129 [00:42<00:10,  2.68it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 79%|███████████████████████████████████████████████████████████████▎                | 102/129 [00:42<00:10,  2.65it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 80%|███████████████████████████████████████████████████████████████▉                | 103/129 [00:43<00:09,  2.75it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 81%|████████████████████████████████████████████████████████████████▍               | 104/129 [00:43<00:08,  2.79it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 81%|█████████████████████████████████████████████████████████████████               | 105/129 [00:43<00:08,  2.76it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 82%|█████████████████████████████████████████████████████████████████▋              | 106/129 [00:44<00:09,  2.43it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 83%|██████████████████████████████████████████████████████████████████▎             | 107/129 [00:44<00:09,  2.39it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 84%|██████████████████████████████████████████████████████████████████▉             | 108/129 [00:45<00:08,  2.38it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 84%|███████████████████████████████████████████████████████████████████▌            | 109/129 [00:45<00:07,  2.55it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 85%|████████████████████████████████████████████████████████████████████▏           | 110/129 [00:45<00:07,  2.67it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 86%|████████████████████████████████████████████████████████████████████▊           | 111/129 [00:46<00:06,  2.71it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 87%|█████████████████████████████████████████████████████████████████████▍          | 112/129 [00:46<00:06,  2.76it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 88%|██████████████████████████████████████████████████████████████████████          | 113/129 [00:46<00:05,  2.77it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 88%|██████████████████████████████████████████████████████████████████████▋         | 114/129 [00:47<00:05,  2.83it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 89%|███████████████████████████████████████████████████████████████████████▎        | 115/129 [00:47<00:04,  2.81it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 90%|███████████████████████████████████████████████████████████████████████▉        | 116/129 [00:48<00:04,  2.84it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 91%|████████████████████████████████████████████████████████████████████████▌       | 117/129 [00:48<00:04,  2.85it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 91%|█████████████████████████████████████████████████████████████████████████▏      | 118/129 [00:48<00:03,  2.85it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 92%|█████████████████████████████████████████████████████████████████████████▊      | 119/129 [00:49<00:03,  2.79it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 93%|██████████████████████████████████████████████████████████████████████████▍     | 120/129 [00:49<00:03,  2.79it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 94%|███████████████████████████████████████████████████████████████████████████     | 121/129 [00:49<00:02,  2.81it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 95%|███████████████████████████████████████████████████████████████████████████▋    | 122/129 [00:50<00:02,  2.76it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 95%|████████████████████████████████████████████████████████████████████████████▎   | 123/129 [00:50<00:02,  2.79it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 96%|████████████████████████████████████████████████████████████████████████████▉   | 124/129 [00:50<00:01,  2.75it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 97%|█████████████████████████████████████████████████████████████████████████████▌  | 125/129 [00:51<00:01,  2.79it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 98%|██████████████████████████████████████████████████████████████████████████████▏ | 126/129 [00:51<00:01,  2.65it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 98%|██████████████████████████████████████████████████████████████████████████████▊ | 127/129 [00:52<00:00,  2.71it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


 99%|███████████████████████████████████████████████████████████████████████████████▍| 128/129 [00:52<00:00,  2.74it/s]

Size of embeddings: torch.Size([32, 14, 14, 128])


100%|████████████████████████████████████████████████████████████████████████████████| 129/129 [00:52<00:00,  2.45it/s]

Size of embeddings: torch.Size([25, 14, 14, 128])



  0%|                                                                                          | 0/100 [00:00<?, ?it/s]