In [1]:
import os
import warnings
from datetime import datetime
from pathlib import Path

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler

from unsupervised_meta_learning.callbacks.confinterval import *
from unsupervised_meta_learning.callbacks.image_generation import *
from unsupervised_meta_learning.callbacks.pcacallbacks import *
from unsupervised_meta_learning.callbacks.umapcallbacks import *
from unsupervised_meta_learning.pl_dataloaders import (
    UnlabelledDataModule,
    UnlabelledDataset,
    OracleDataModule,
)
from unsupervised_meta_learning.protoclr import ProtoCLR
from unsupervised_meta_learning.dataclasses.protoclr_container import (
    PCLRParamsContainer, ReRankerContainer,
)

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [2]:
pl.seed_everything(42)

Global seed set to 42


42

In [3]:
dataset='miniimagenet'
datapath='./data/untarred'
lr=1e-3
inner_lr=1e-3
gamma=1.0
distance="euclidean"
ckpt_dir=Path("./ckpts")
ae=False
tau=1.0
eval_ways=5
clustering_alg='kmeans'
clustering_callback=False
km_clusters=25
km_use_nearest=True
km_n_neighbours=50
cl_reduction='mean'
eval_support_shots=5
eval_query_shots=15
n_images=None
n_classes=None
n_support=1
n_query=3
batch_size=200
no_aug_support=True
no_aug_query=False
logging="wandb"
log_images=False
profiler="torch"
train_oracle_mode=False
train_oracle_ways=None
train_oracle_shots=None
num_workers=3
callbacks=True
patience=200
use_plotly=True
use_umap=False
umap_min_dist: float = .25
rdim_n_neighbors: int = 50
rdim_components: int = 2
rerank_kjrd=False
rrk1=20
rrk2=6
rrlambda=0
uuid=None

In [4]:
 params = PCLRParamsContainer(
        dataset,
        datapath,
        seed=42,
        gpus=1,
        lr=lr,
        inner_lr=inner_lr,
        gamma=gamma,
        distance=distance,
        ckpt_dir=Path("./ckpts"),
        ae=ae,
        tau=tau,
        clustering_algo=clustering_alg,
        km_clusters=km_clusters,
        km_use_nearest=km_use_nearest,
        km_n_neighbours=km_n_neighbours,
        cl_reduction=cl_reduction,
        eval_ways=eval_ways,
        eval_support_shots=eval_support_shots,
        eval_query_shots=eval_query_shots,
        n_images=n_images,
        n_classes=n_classes,
        n_support=n_support,
        n_query=n_query,
        batch_size=batch_size,
        no_aug_support=no_aug_support,
        no_aug_query=no_aug_query,
        log_images=log_images,
        train_oracle_mode=train_oracle_mode,
        train_oracle_ways=train_oracle_ways,
        train_oracle_shots=train_oracle_shots,
        num_workers=num_workers,
        use_umap=use_umap,
        umap_min_dist=umap_min_dist,
        rdim_components=int(rdim_components),
        rdim_n_neighbors=int(rdim_n_neighbors),
        rerank_kjrd=rerank_kjrd,
        re_rank_args=ReRankerContainer(k1=rrk1, k2=rrk2, lambda_value=rrlambda)
    )

In [5]:
model = ProtoCLR(params)

Clustering algo in use: kmeans


In [6]:
model = ProtoCLR.load_from_checkpoint("./data/tough-morning-38/epoch=144-step=14499-val_loss=1.46-val_accuracy=0.588.ckpt", params=params)

Clustering algo in use: kmeans


In [8]:
trainer = pl.Trainer(gpus=1, limit_test_batches=100)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [9]:
dm = UnlabelledDataModule(params)

In [10]:
trainer.test(model=model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Supervised data loader for miniimagenet:test.


  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.6507999897003174,
 'test_acc_epoch': 0.6507999897003174,
 'test_loss': 1.3385509252548218,
 'test_loss_epoch': 1.3385509252548218}
--------------------------------------------------------------------------------


[{'test_loss': 1.3385509252548218,
  'test_loss_epoch': 1.3385509252548218,
  'test_acc': 0.6507999897003174,
  'test_acc_epoch': 0.6507999897003174}]