In [1]:
import sys
from tqdm import tqdm
sys.path.append('..')

from utils import seed_everything
from utils import MetricsCallback, load_yaml_as_omegaconf, build_dataset
from models import build_pretrained_model
from data.birdsetwrapper import BirdSetWrapper
from torch.utils.data import DataLoader
import torch
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

In [2]:
# Load the run configuration from the YAML file one directory up.
args = load_yaml_as_omegaconf(yaml_file_path="../config.yaml")

# Seed with the configured seed plus a fixed offset of 42
# (presumably seeds all relevant RNGs -- confirm in utils.seed_everything).
seed_everything(args.random_seed + 42)

# `is_available` must be *called*: the bare function object is always truthy,
# which would select 'cuda' even on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 1. Get Finetuning Dataset (NBP)

In [3]:
# Build the data module from the config and let it prepare its data.
dm = build_dataset(args)
dm.prepare_data()

# "fit" stage: wrap the train and validation splits exposed by the data module.
dm.setup(stage="fit")
train_dataset = BirdSetWrapper(dm.train_dataset)
val_dataset = BirdSetWrapper(dm.val_dataset)

# "test" stage: wrap the test split.
dm.setup(stage="test")
test_dataset = BirdSetWrapper(dm.test_dataset)

# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=32, num_workers=args.dataset.num_workers)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

2024-10-11 22:54:55 | INFO | birdset.datamodule.base_datamodule | Check if preparing has already been done.
2024-10-11 22:54:55 | INFO | birdset.datamodule.base_datamodule | Prepare Data
2024-10-11 22:54:55 | INFO | birdset.datamodule.base_datamodule | > Loading data set.
2024-10-11 22:54:58 | INFO | birdset.datamodule.birdset_datamodule | > Mapping data set.
2024-10-11 22:54:58 | INFO | birdset.datamodule.birdset_datamodule | >> Smart sampling to self.dataset_config.classlimit=500, self.dataset_config.eventlimit=1
2024-10-11 22:55:07 | INFO | birdset.datamodule.base_datamodule | Train fingerprint found in /mnt/stud/work/deeplearninglab/ss2024/ssl-1/data_birdset/HSN/HSN_processed_42_77fbb47c0e6ba6ab, saving to disk is skipped
2024-10-11 22:55:07 | INFO | birdset.datamodule.base_datamodule | fit
2024-10-11 22:55:07 | INFO | birdset.datamodule.base_datamodule | test


## 2. Init model with pretrained weights

In [4]:
# Construct the model from the config (pretrained weights are presumably
# loaded inside build_pretrained_model -- confirm), then move it to `device`.
model = build_pretrained_model(args)
model = model.to(device)

2024-10-11 22:55:09 | INFO | models.EAT_pretraining | making target model


In [5]:
# Extract one flattened feature vector per training sample.
features_list = []
labels_list = []

# Inference only: switch off train-time behaviour (e.g. dropout) and autograd.
# Without no_grad every stored feature would keep its computation graph alive.
model.eval()
with torch.no_grad():
    # `sample` instead of `input` so the builtin is not shadowed.
    for sample, label in tqdm(train_dataset):
        sample = sample.to(device)
        # unsqueeze adds a leading batch dimension of size 1 for the model call.
        features = model._extract_features(sample.unsqueeze(dim=0)).flatten()
        # Move to host memory right away so features do not accumulate on the GPU.
        features_list.append(features.cpu())
        labels_list.append(label.item())

100%|██████████| 8417/8417 [03:34<00:00, 39.24it/s]


In [6]:
# Turn the collected Python lists into tensors.
# NOTE: both names are rebound from list to tensor, so this cell is meant to
# be run once after the extraction cell.
labels_list = torch.tensor(labels_list)
features_list = (
    torch.stack(features_list, dim=0)  # one row per sample
    .detach()
    .cpu()
)

## 3. Plot T-SNE Embedding

In [None]:
# Project the feature vectors to 2-D with t-SNE.
# t-SNE is stochastic: pass an explicit random_state so re-running this cell
# reproduces the same embedding instead of depending on global RNG state.
# The tensor is handed over as a NumPy array, which is what scikit-learn expects.
embeddings = TSNE(random_state=args.random_seed).fit_transform(features_list.numpy())

In [None]:
# Scatter plot of the 2-D t-SNE embedding, coloured by label.
fig, ax = plt.subplots(figsize=(30, 20))
ax.set_title('T-SNE of extracted features')
ax.set_xlabel('t-SNE dimension 1')
ax.set_ylabel('t-SNE dimension 2')
# Colour by `labels_list`; the previously used name `labels` is never defined
# in this notebook and raised a NameError on a fresh kernel.
points = ax.scatter(embeddings[:, 0], embeddings[:, 1], c=labels_list)
fig.colorbar(points, ax=ax)
#fig.savefig('tsne_embeddings_EAT-base_epoch30.png')
plt.show()