# Code for running SiameseNet and TripletNet on BRACOL Dataset

***The code used in this notebook is mostly from [this](https://github.com/adambielski/siamese-triplet/). If you are interested in it, check it out, it is extremely well documented***

***Install libs***

In [5]:
import sys
sys.path.insert(0,'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/')
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor
from dataloaders import get_train_transforms, get_val_transforms, get_triplet_dataloader
from networks import TripletNet 
from models import MobileNetv2
from models import EfficientNetB4
from losses import TripletLoss
from trainer import fit
import torchvision
import timm
from IPython.display import clear_output 
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.manifold import TSNE
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import accuracy_score, f1_score , precision_score , recall_score

# استفاده از مدل EfficientNet به جای MobileNetv2 برای استخراج ویژگی‌ها
embedding_net =EfficientNetB4() #timm.create_model('efficientnet_b0', pretrained=True)  # استفاده از مدل EfficientNet
siamese_model = TripletNet(embedding_net=embedding_net)

optimizer = torch.optim.Adam(siamese_model.parameters(), lr=1e-4)  # تغییر به Adam برای بهبود عملکرد
lr_scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)  # استفاده از Cosine Annealing برای تغییر نرخ یادگیری
loss_fn = TripletLoss(1.)
n_epochs = 100  # تعداد epochs
device = torch.cuda.is_available()

if device:
    siamese_model.cuda()

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to C:\Users\Mey/.cache\torch\hub\checkpoints\efficientnet-b4-6ed6700e.pth
100%|█████████████████████████████████████████████████████████████████████████████| 74.4M/74.4M [33:37<00:00, 38.7kB/s]


Loaded pretrained weights for efficientnet-b4


***Import some libs***

In [None]:
# بارگذاری داده‌ها
path_data = 'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
triplet_train_loader = get_triplet_dataloader(root=path_data + '/train/', batch_size=5, transforms=get_train_transforms())
triplet_val_loader = get_triplet_dataloader(root=path_data + '/val/', batch_size=5, transforms=get_val_transforms())




***Define model hiperparams***

In [6]:
import torch
# آموزش مدل Siamese
#fit(triplet_train_loader, triplet_val_loader, siamese_model, loss_fn, optimizer, lr_scheduler, n_epochs, device, log_interval=10)
siamese_model = torch.load("C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/siamese.h5",map_location=torch.device('cpu'))
siamese_model.eval()

  siamese_model = torch.load("C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/siamese.h5",map_location=torch.device('cpu'))


TripletNet(
  (embedding_net): MobileNetv2(
    (features): Sequential(
      (0): ConvNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): ConvNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): InvertedResidual(
        (conv): Sequential(
          (0): ConvNormActivation(
            (0): Conv2d(16, 96, kernel_size=(1,

In [11]:
import cv2
import numpy as np
from sklearn.manifold import TSNE
import matplotlib as mpl
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import transforms
from torch.autograd import Variable
import os
import pandas as pd
import seaborn as sns
# استخراج ویژگی‌ها با استفاده از مدل Siamese
#def generate_embeddings(data_loader, model):
#    with torch.no_grad():
#        #device = 'cuda'
#        model.eval()
#        #model.to(device)
#        labels = None
#        embeddings = None
#        for batch_idx, data in tqdm(enumerate(data_loader)):
#            batch_imgs, batch_labels = data
#            batch_labels = batch_labels.numpy()
#           # batch_imgs = Variable(batch_imgs.to('cuda'))
#            batch_E = model.get_embedding(batch_imgs)
#            batch_E = batch_E.data.cpu().numpy()
#            embeddings = np.concatenate((embeddings, batch_E), axis=0) if embeddings is not None else batch_E
#            labels = np.concatenate((labels, batch_labels), axis=0) if labels is not None else batch_labels
#    return embeddings, labels
def generate_embeddings(data_loader, model):
    with torch.no_grad():
        #device = 'cuda'
        model.eval()
        #model.to(device)
        embeddings = []
        labels = []
        for batch_imgs, batch_labels in data_loader:
            #if device:
            #    batch_imgs = batch_imgs.cuda()
            batch_E = model.get_embedding(batch_imgs)
            print("*************************************")
            print(batch_E)
            print("#####################################")
            embeddings.append(batch_E.cpu().numpy())
            labels.append(batch_labels.numpy())
    return np.concatenate(embeddings), np.concatenate(labels)

def vis_tSNE(embeddings, labels, backbone='Convnet'):
    num_samples = embeddings.shape[0]
    X_embedded = TSNE(n_components=2).fit_transform(embeddings[0:num_samples, :])
    plt.figure(figsize=(16, 16))
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd','#1fa7b4', '#fb7f0e', '#27a02c', '#da2758', '#a46abd','#af7bb4', '#fa7fbe', '#2baf2c', '#4f2d28', '#b4f7bd']
    labels_name = ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___healthy', 'Potato___Late_blight','Tomato__Target_Spot','Tomato__Tomato_mosaic_virus','Tomato__Tomato_YellowLeaf__Curl_Virus','Tomato_Bacterial_spot','Tomato_Early_blight','Tomato_healthy','Tomato_Late_blight','Tomato_Leaf_Mold','Tomato_Septoria_leaf_spot','Tomato_Spider_mites_Two_spotted_spider_mite']
    for i in range(16):
        inds = np.where(labels==i)[0]
        plt.scatter(X_embedded[inds,0], X_embedded[inds,1], alpha=.8, color=colors[i], s=200)
    # plt.title(f't-SNE', fontweight='bold', fontsize=24)
    plt.legend(labels_name, fontsize=30)
    plt.savefig(f'./tsne_{backbone}.png')
    
# تعریف مدل سفارشی ViT
path_data = 'C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/src/dataset'
class CustomViT(nn.Module):
    def __init__(self, original_vit_model, embedding_dim, num_classes):
        super(CustomViT, self).__init__()
        # حذف لایه‌های ابتدایی
        self.vit = nn.Sequential(*list(original_vit_model.children())[2:])  # لایه‌های ابتدایی حذف شده
        #self.fc = nn.Linear(embedding_dim, num_classes)  # لایه نهایی برای طبقه‌بندی
        self.fc = nn.Linear(1280, num_classes)  # لایه نهایی برای طبقه‌بندی
        #self.fc = nn.Linear(1792, num_classes)
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.vit(x)
        x = self.dropout(x)
        x = self.fc(x)
        return x    

train_data = torchvision.datasets.ImageFolder(root=path_data + '/train/', transform=get_val_transforms())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32)

test_data = torchvision.datasets.ImageFolder(root=path_data + '/test/', transform=get_val_transforms())
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32)

# بارگذاری مدل ViT از Hugging Face
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
vit_model.classifier = torch.nn.Linear(vit_model.config.hidden_size, 15)  # تعداد کلاس‌ها

embedding_dim = vit_model.config.hidden_size  # ابعاد embedding
model = CustomViT(vit_model, embedding_dim, num_classes=15)

# استخراج ویژگی‌ها از داده‌های آموزش
train_embeddings, train_labels = generate_embeddings(train_loader, siamese_model )
# استخراج ویژگی‌ها از داده‌های تست
test_embeddings, test_labels = generate_embeddings(test_loader, siamese_model )
 
# تبدیل داده‌ها به تنسور
X_train, y_train = train_embeddings, train_labels
X_val, y_val = test_embeddings, test_labels

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.long)

# انتقال مدل به دستگاه CUDA
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
    model.cuda()

# تنظیمات loss function و optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # استفاده از AdamW

# آموزش مدل
model.train()
for epoch in range(1000):  # تعداد epochs
    optimizer.zero_grad()
    outputs = model(X_train_tensor.to(device))
    loss = criterion(outputs, y_train_tensor.to(device))
    loss.backward()
    optimizer.step()
    lr_scheduler.step()  # اعمال تغییرات در نرخ یادگیری
    print(f'Epoch [{epoch+1}/1000], Loss: {loss.item():.4f}')

# ارزیابی مدل
model.eval()
with torch.no_grad():
    val_outputs = model(X_val_tensor.to(device))
    _, predicted = torch.max(val_outputs.data, 1)

# محاسبه دقت و امتیاز F1
accuracy = accuracy_score(y_val_tensor.cpu(), predicted.cpu())
f1 = f1_score(y_val_tensor.cpu(), predicted.cpu(), average='weighted')
precision = precision_score(y_val_tensor.cpu(), predicted.cpu(), average='macro') 
recall = recall_score(y_val_tensor.cpu(), predicted.cpu(), average='macro')

# نمایش نتایج
print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}')
print(f'Precision Score: {precision}')
print(f'Recall Score: {recall}')

# ذخیره مدل
torch.save(model.state_dict(), "C:/Users/Mey/Documents/PlantDiseaseDiagnosisFewShotLearning/siamese_triplet_net/siameseNFnet_improvedLocal.h5")

# تابع تولید Embedding‌ها
def generate_embeddings(data_loader, model):
    with torch.no_grad():
        model.eval()
        embeddings = []
        labels = []
        for batch_imgs, batch_labels in data_loader:
            #if device:
                #batch_imgs = batch_imgs.cuda()
            batch_E = model.get_embedding(batch_imgs)  # استخراج ویژگی‌ها
            embeddings.append(batch_E.cpu().numpy())
            labels.append(batch_labels.numpy())
    return np.concatenate(embeddings), np.concatenate(labels)

# تابع vis_tSNE برای نمایش ویژگی‌ها
def vis_tSNE(embeddings, labels, backbone='EfficientNet'):
    num_samples = embeddings.shape[0]
    X_embedded = TSNE(n_components=2).fit_transform(embeddings[0:num_samples, :])
    plt.figure(figsize=(16, 16))
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd','#1fa7b4', '#fb7f0e', '#27a02c', '#da2758', '#a46abd','#af7bb4', '#fa7fbe', '#2baf2c', '#4f2d28', '#b4f7bd']
    labels_name = ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___healthy', 'Potato___Late_blight','Tomato__Target_Spot','Tomato__Tomato_mosaic_virus','Tomato__Tomato_YellowLeaf__Curl_Virus','Tomato_Bacterial_spot','Tomato_Early_blight','Tomato_healthy','Tomato_Late_blight','Tomato_Leaf_Mold','Tomato_Septoria_leaf_spot','Tomato_Spider_mites_Two_spotted_spider_mite']
    for i in range(16):
        inds = np.where(labels==i)[0]
        plt.scatter(X_embedded[inds,0], X_embedded[inds,1], alpha=.8, color=colors[i], s=200)
    plt.legend(labels_name, fontsize=30)
    plt.savefig(f'./tsne_{backbone}.png')

# نمایش t-SNE برای داده‌های تست
test_data = torchvision.datasets.ImageFolder(root=path_data + '/test/', transform=get_val_transforms())
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32)
#val_embeddings_cl, val_labels_cl = generate_embeddings(test_loader, siamese_model)
val_embeddings_cl, val_labels_cl = generate_embeddings(test_loader, model)
vis_tSNE(val_embeddings_cl, val_labels_cl)

# نتیجه نهایی
print("Model training and evaluation completed.")

*************************************
tensor([[0.0000e+00, 6.2374e-02, 1.6497e-02,  ..., 2.8724e-03, 8.1951e-05,
         5.1641e-05],
        [0.0000e+00, 6.0241e-02, 1.7119e-02,  ..., 4.1973e-03, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 6.0687e-02, 1.7050e-02,  ..., 2.4782e-03, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 1.3913e-02,  ..., 0.0000e+00, 0.0000e+00,
         1.2238e-03],
        [4.9006e-02, 0.0000e+00, 1.5301e-02,  ..., 4.7019e-02, 0.0000e+00,
         5.6611e-02],
        [5.1137e-02, 5.5524e-05, 1.7489e-02,  ..., 4.5040e-02, 0.0000e+00,
         5.7446e-02]])
#####################################
*************************************
tensor([[0.0466, 0.0004, 0.0173,  ..., 0.0448, 0.0000, 0.0562],
        [0.0478, 0.0010, 0.0171,  ..., 0.0466, 0.0000, 0.0544],
        [0.0490, 0.0000, 0.0218,  ..., 0.0458, 0.0000, 0.0555],
        ...,
        [0.0000, 0.0000, 0.0043,  ..., 0.0000, 0.0000, 0.0020],
        [0.0000, 0.00

*************************************
tensor([[0.0000e+00, 0.0000e+00, 1.6879e-02,  ..., 0.0000e+00, 0.0000e+00,
         3.1674e-03],
        [0.0000e+00, 0.0000e+00, 1.5161e-02,  ..., 0.0000e+00, 0.0000e+00,
         1.1653e-03],
        [0.0000e+00, 0.0000e+00, 1.7548e-02,  ..., 1.6139e-05, 0.0000e+00,
         3.7442e-03],
        ...,
        [0.0000e+00, 0.0000e+00, 1.6293e-02,  ..., 0.0000e+00, 0.0000e+00,
         2.7208e-03],
        [0.0000e+00, 0.0000e+00, 1.5568e-02,  ..., 0.0000e+00, 0.0000e+00,
         1.9123e-03],
        [0.0000e+00, 0.0000e+00, 1.7026e-02,  ..., 0.0000e+00, 0.0000e+00,
         2.1763e-03]])
#####################################
*************************************
tensor([[0.0000e+00, 0.0000e+00, 1.5912e-02,  ..., 0.0000e+00, 0.0000e+00,
         2.5558e-03],
        [0.0000e+00, 0.0000e+00, 1.6186e-02,  ..., 0.0000e+00, 0.0000e+00,
         1.7392e-03],
        [0.0000e+00, 0.0000e+00, 1.6418e-02,  ..., 3.5549e-05, 0.0000e+00,
         2.8453e-03],

*************************************
tensor([[0.0495, 0.0002, 0.0172,  ..., 0.0459, 0.0000, 0.0569],
        [0.0511, 0.0000, 0.0176,  ..., 0.0476, 0.0000, 0.0563],
        [0.0505, 0.0000, 0.0159,  ..., 0.0460, 0.0000, 0.0573],
        ...,
        [0.0000, 0.0479, 0.0008,  ..., 0.0000, 0.0000, 0.0235],
        [0.0000, 0.0483, 0.0008,  ..., 0.0000, 0.0000, 0.0232],
        [0.0000, 0.0472, 0.0000,  ..., 0.0000, 0.0000, 0.0189]])
#####################################
*************************************
tensor([[0.0000e+00, 4.8947e-02, 1.0239e-03,  ..., 0.0000e+00, 3.7388e-05,
         1.8898e-02],
        [0.0000e+00, 4.8473e-02, 2.5679e-04,  ..., 0.0000e+00, 0.0000e+00,
         1.9531e-02],
        [0.0000e+00, 4.6946e-02, 9.2397e-04,  ..., 0.0000e+00, 0.0000e+00,
         2.1189e-02],
        ...,
        [0.0000e+00, 4.8570e-02, 9.2164e-04,  ..., 0.0000e+00, 0.0000e+00,
         2.1704e-02],
        [0.0000e+00, 4.4509e-02, 4.5105e-04,  ..., 0.0000e+00, 0.0000e+00,
         2.0

*************************************
tensor([[0.0000e+00, 0.0000e+00, 1.5916e-03,  ..., 0.0000e+00, 3.9242e-02,
         0.0000e+00],
        [0.0000e+00, 7.7627e-05, 0.0000e+00,  ..., 0.0000e+00, 4.2068e-02,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 2.5746e-03,  ..., 0.0000e+00, 3.9949e-02,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 4.6217e-04,  ..., 0.0000e+00, 4.2714e-02,
         0.0000e+00],
        [0.0000e+00, 2.8583e-04, 2.0112e-03,  ..., 0.0000e+00, 3.9961e-02,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 3.8597e-02,
         0.0000e+00]])
#####################################
*************************************
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 3.5712e-02,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 1.7206e-03,  ..., 0.0000e+00, 4.1797e-02,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 3.2601e-03,  ..., 0.0000e+00, 3.9159e-02,
         0.0000e+00],

*************************************
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0038],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
#####################################
*************************************
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
#####################################
*************************************
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.000

*************************************
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0017, 0.0000, 0.0000],
        [0.0013, 0.0000, 0.0000,  ..., 0.0018, 0.0000, 0.0000],
        [0.0007, 0.0000, 0.0000,  ..., 0.0017, 0.0000, 0.0000],
        ...,
        [0.0017, 0.0000, 0.0000,  ..., 0.0026, 0.0000, 0.0000],
        [0.0044, 0.0000, 0.0000,  ..., 0.0029, 0.0000, 0.0000],
        [0.0008, 0.0000, 0.0000,  ..., 0.0020, 0.0000, 0.0000]])
#####################################
*************************************
tensor([[0.0017, 0.0000, 0.0000,  ..., 0.0018, 0.0000, 0.0000],
        [0.0031, 0.0000, 0.0000,  ..., 0.0042, 0.0000, 0.0000],
        [0.0004, 0.0000, 0.0000,  ..., 0.0013, 0.0000, 0.0000],
        ...,
        [0.0031, 0.0000, 0.0000,  ..., 0.0057, 0.0000, 0.0000],
        [0.0014, 0.0000, 0.0000,  ..., 0.0038, 0.0000, 0.0000],
        [0.0042, 0.0000, 0.0000,  ..., 0.0036, 0.0000, 0.0000]])
#####################################
*************************************
tensor([[2.905

*************************************
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
#####################################
*************************************
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
#####################################
*************************************
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
#####################################
**********

*************************************
tensor([[0.0000e+00, 0.0000e+00, 3.2090e-02,  ..., 7.6104e-05, 4.9720e-02,
         3.8341e-02],
        [0.0000e+00, 0.0000e+00, 2.9227e-02,  ..., 3.9782e-05, 5.0527e-02,
         3.5819e-02],
        [0.0000e+00, 0.0000e+00, 3.3121e-02,  ..., 0.0000e+00, 5.6331e-02,
         3.9630e-02],
        ...,
        [8.9715e-05, 5.7339e-04, 3.3699e-02,  ..., 1.8079e-04, 5.5772e-02,
         3.7311e-02],
        [3.7871e-05, 0.0000e+00, 2.8445e-02,  ..., 6.8144e-04, 5.3815e-02,
         3.3904e-02],
        [5.8656e-05, 5.6636e-05, 3.0683e-02,  ..., 7.8618e-04, 5.4592e-02,
         3.6403e-02]])
#####################################
*************************************
tensor([[0.0000e+00, 0.0000e+00, 2.8788e-02,  ..., 0.0000e+00, 4.8902e-02,
         3.5152e-02],
        [0.0000e+00, 3.0475e-04, 3.3533e-02,  ..., 4.4437e-05, 5.5648e-02,
         3.9389e-02],
        [0.0000e+00, 0.0000e+00, 3.3496e-02,  ..., 0.0000e+00, 5.7268e-02,
         3.6363e-02],

*************************************
tensor([[0.0000, 0.0196, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0385, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0231, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0389, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0414, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0400, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
#####################################
*************************************
tensor([[0.0410, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0411, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0375, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0404, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0376, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0394, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
#####################################
*************************************
tensor([[4.149

Epoch [114/1000], Loss: 2.5076
Epoch [115/1000], Loss: 2.5087
Epoch [116/1000], Loss: 2.5081
Epoch [117/1000], Loss: 2.5042
Epoch [118/1000], Loss: 2.5062
Epoch [119/1000], Loss: 2.5017
Epoch [120/1000], Loss: 2.4985
Epoch [121/1000], Loss: 2.4991
Epoch [122/1000], Loss: 2.4970
Epoch [123/1000], Loss: 2.4966
Epoch [124/1000], Loss: 2.4911
Epoch [125/1000], Loss: 2.4924
Epoch [126/1000], Loss: 2.4892
Epoch [127/1000], Loss: 2.4896
Epoch [128/1000], Loss: 2.4863
Epoch [129/1000], Loss: 2.4856
Epoch [130/1000], Loss: 2.4823
Epoch [131/1000], Loss: 2.4797
Epoch [132/1000], Loss: 2.4797
Epoch [133/1000], Loss: 2.4748
Epoch [134/1000], Loss: 2.4771
Epoch [135/1000], Loss: 2.4757
Epoch [136/1000], Loss: 2.4711
Epoch [137/1000], Loss: 2.4720
Epoch [138/1000], Loss: 2.4696
Epoch [139/1000], Loss: 2.4654
Epoch [140/1000], Loss: 2.4625
Epoch [141/1000], Loss: 2.4626
Epoch [142/1000], Loss: 2.4608
Epoch [143/1000], Loss: 2.4612
Epoch [144/1000], Loss: 2.4612
Epoch [145/1000], Loss: 2.4537
Epoch [1

Epoch [386/1000], Loss: 2.0654
Epoch [387/1000], Loss: 2.0656
Epoch [388/1000], Loss: 2.0602
Epoch [389/1000], Loss: 2.0625
Epoch [390/1000], Loss: 2.0625
Epoch [391/1000], Loss: 2.0567
Epoch [392/1000], Loss: 2.0558
Epoch [393/1000], Loss: 2.0606
Epoch [394/1000], Loss: 2.0543
Epoch [395/1000], Loss: 2.0547
Epoch [396/1000], Loss: 2.0496
Epoch [397/1000], Loss: 2.0509
Epoch [398/1000], Loss: 2.0483
Epoch [399/1000], Loss: 2.0436
Epoch [400/1000], Loss: 2.0438
Epoch [401/1000], Loss: 2.0487
Epoch [402/1000], Loss: 2.0421
Epoch [403/1000], Loss: 2.0378
Epoch [404/1000], Loss: 2.0456
Epoch [405/1000], Loss: 2.0382
Epoch [406/1000], Loss: 2.0449
Epoch [407/1000], Loss: 2.0345
Epoch [408/1000], Loss: 2.0383
Epoch [409/1000], Loss: 2.0347
Epoch [410/1000], Loss: 2.0278
Epoch [411/1000], Loss: 2.0298
Epoch [412/1000], Loss: 2.0301
Epoch [413/1000], Loss: 2.0250
Epoch [414/1000], Loss: 2.0295
Epoch [415/1000], Loss: 2.0207
Epoch [416/1000], Loss: 2.0188
Epoch [417/1000], Loss: 2.0165
Epoch [4

Epoch [678/1000], Loss: 1.6638
Epoch [679/1000], Loss: 1.6612
Epoch [680/1000], Loss: 1.6660
Epoch [681/1000], Loss: 1.6579
Epoch [682/1000], Loss: 1.6643
Epoch [683/1000], Loss: 1.6653
Epoch [684/1000], Loss: 1.6647
Epoch [685/1000], Loss: 1.6631
Epoch [686/1000], Loss: 1.6477
Epoch [687/1000], Loss: 1.6621
Epoch [688/1000], Loss: 1.6539
Epoch [689/1000], Loss: 1.6464
Epoch [690/1000], Loss: 1.6541
Epoch [691/1000], Loss: 1.6555
Epoch [692/1000], Loss: 1.6476
Epoch [693/1000], Loss: 1.6406
Epoch [694/1000], Loss: 1.6429
Epoch [695/1000], Loss: 1.6388
Epoch [696/1000], Loss: 1.6472
Epoch [697/1000], Loss: 1.6405
Epoch [698/1000], Loss: 1.6392
Epoch [699/1000], Loss: 1.6422
Epoch [700/1000], Loss: 1.6446
Epoch [701/1000], Loss: 1.6438
Epoch [702/1000], Loss: 1.6284
Epoch [703/1000], Loss: 1.6266
Epoch [704/1000], Loss: 1.6355
Epoch [705/1000], Loss: 1.6266
Epoch [706/1000], Loss: 1.6325
Epoch [707/1000], Loss: 1.6267
Epoch [708/1000], Loss: 1.6210
Epoch [709/1000], Loss: 1.6279
Epoch [7

Epoch [956/1000], Loss: 1.3592
Epoch [957/1000], Loss: 1.3512
Epoch [958/1000], Loss: 1.3506
Epoch [959/1000], Loss: 1.3577
Epoch [960/1000], Loss: 1.3601
Epoch [961/1000], Loss: 1.3552
Epoch [962/1000], Loss: 1.3597
Epoch [963/1000], Loss: 1.3518
Epoch [964/1000], Loss: 1.3492
Epoch [965/1000], Loss: 1.3507
Epoch [966/1000], Loss: 1.3436
Epoch [967/1000], Loss: 1.3444
Epoch [968/1000], Loss: 1.3486
Epoch [969/1000], Loss: 1.3460
Epoch [970/1000], Loss: 1.3447
Epoch [971/1000], Loss: 1.3336
Epoch [972/1000], Loss: 1.3451
Epoch [973/1000], Loss: 1.3405
Epoch [974/1000], Loss: 1.3494
Epoch [975/1000], Loss: 1.3382
Epoch [976/1000], Loss: 1.3348
Epoch [977/1000], Loss: 1.3351
Epoch [978/1000], Loss: 1.3414
Epoch [979/1000], Loss: 1.3241
Epoch [980/1000], Loss: 1.3317
Epoch [981/1000], Loss: 1.3277
Epoch [982/1000], Loss: 1.3370
Epoch [983/1000], Loss: 1.3402
Epoch [984/1000], Loss: 1.3366
Epoch [985/1000], Loss: 1.3384
Epoch [986/1000], Loss: 1.3214
Epoch [987/1000], Loss: 1.3180
Epoch [9

AttributeError: 'CustomViT' object has no attribute 'get_embedding'