## Embedder 사용 방법 예시

PlantNet-300K 데이터셋의 특정 레이블 하나에 대해 embedding 추출하기.

In [None]:
import torch
from torch import nn
from models import EfficientB4
from torchvision import transforms
from config.path import PATH
from experiment.extract_embeddings import Embedder

# Dataset location and the split whose images will be embedded.
root = PATH["PLANTNET-300K"]
split = 'train'

# Target device and the checkpoint to restore.
# NOTE(review): despite its name, weight_dir points at a checkpoint *file* (.pt).
device = "cuda:2"
weight_dir = "/home/files/experiments/efficientB4/exp_set3/checkpoints/checkpoint.pt"

# Preprocessing pipeline: ToPILImage -> Resize to 380x380 -> ToTensor.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
    ]
)

emb = Embedder(root, split)
labels = list(emb.label_to_class.keys())

# Build the classifier sized to the dataset's class count, then restore its weights.
model = EfficientB4(num_classes=emb.num_classes, loss_fn=nn.CrossEntropyLoss())
model.load(weight_dir)

# Hand the model, the preprocessing pipeline and the device over to the embedder.
emb.get_model(model, transform, device)


In [None]:
# Demonstration target: the first label known to the embedder.
label = labels[0]

# extract_embeddings yields five values; unpack each into its own name.
(
    embeddings,
    top_1_class,
    top_1_prob,
    correctness,
    file_paths,
) = emb.extract_embeddings(label)

## 전체 데이터에서 embedding 추출하기

for 루프를 통해 모든 split(train/val/test)의 전체 레이블에 대한 embedding을 추출하여 저장하기

In [None]:
import torch
from torch import nn
from models import EfficientB4
from torchvision import transforms
from config.path import PATH
from experiment.extract_embeddings import Embedder

# Shared configuration for the all-splits extraction loop in the next cell.
weight_dir = "/home/files/experiments/efficientB4/exp_set3/checkpoints/checkpoint.pt"  # checkpoint file
device = "cuda:0"
root = PATH["PLANTNET-300K"]

# Preprocessing pipeline: ToPILImage -> Resize to 380x380 -> ToTensor.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
    ]
)

In [None]:
from os.path import join

# Base directory; embeddings for each split go into a sub-folder named after the split.
embedding_root = "/home/files/experiments/plantnet_embeddings"

for split in ["train", "val", "test"]:
    # The output folder depends only on the split, so build it once per split
    # instead of once per label.
    path = join(embedding_root, split)

    emb = Embedder(root, split)
    labels = list(emb.label_to_class.keys())
    # Build the classifier sized to this split's class count and restore the checkpoint.
    model = EfficientB4(num_classes=emb.num_classes, loss_fn=nn.CrossEntropyLoss())
    model.load(weight_dir)
    emb.get_model(model, transform, device)

    for label in labels:
        emb.save_embeddings(path, label)

## Run everything in a single cell

위의 전체 레이블 embedding 추출 과정을 하나의 셀에서 실행하기 (tqdm으로 진행 상황 표시)

In [None]:
import torch
from torch import nn
from models import EfficientB4
from torchvision import transforms
from config.path import PATH
from experiment.extract_embeddings import Embedder
from tqdm import tqdm

from os.path import join

# Dataset location, preprocessing pipeline, device and checkpoint file.
root = PATH["PLANTNET-300K"]
transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((380, 380)),
            transforms.ToTensor()])
device = "cuda:0"
weight_dir = "/home/files/experiments/efficientB4/exp_set3/checkpoints/checkpoint.pt"

# Base directory; embeddings for each split go into a sub-folder named after the split.
embedding_root = "/home/files/experiments/plantnet_embeddings"

for split in ["train", "val", "test"]:
    # The output folder depends only on the split, so build it once per split
    # instead of once per label.
    path = join(embedding_root, split)

    emb = Embedder(root, split)
    labels = list(emb.label_to_class.keys())
    # Build the classifier sized to this split's class count and restore the checkpoint.
    model = EfficientB4(num_classes=emb.num_classes, loss_fn=nn.CrossEntropyLoss())
    model.load(weight_dir)
    emb.get_model(model, transform, device)

    # Label the progress bar with the split so the three passes are distinguishable.
    for label in tqdm(labels, desc=split):
        emb.save_embeddings(path, label)