In [2]:
from omegaconf import OmegaConf
import numpy as np
import os
import re
import os.path as osp
import torch
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger

import model_factory
from graph_data_module import GraphDataModule
from train import Runner
from datasets.dataset_factory import create_dataset
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import torchvision
import torchmetrics
from sklearn.metrics import ConfusionMatrixDisplay

In [3]:
# Download the ":best" model artifact of run 0gmu6xyw (entity haraghi, project DGCNN)
# through the Lightning W&B logger helper and keep whatever it returns as artifact_dir.
best_model_artifact = "haraghi/DGCNN/model-0gmu6xyw:best"
artifact_dir = WandbLogger.download_artifact(artifact=best_model_artifact)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [4]:
# Identifiers of the Weights & Biases run that is inspected in this notebook.
entity, project, run_id = "haraghi", "DGCNN", "0gmu6xyw"
# Folder that is scanned for YAML experiment configuration files.
cfg_path = "cfg_folder"

In [5]:
# Keywords that a config file name must contain to count as a candidate.
method = "EST"
dataset = "NCALTECH101"
sampling_number = '1024'

# Collect every entry of cfg_path whose name contains all three keywords.
# The order is exactly what os.listdir returns (documented as arbitrary), so
# positions inside `res` are not guaranteed to be stable between machines.
res = [f for f in os.listdir(cfg_path) if method in f and dataset in f and sampling_number in f]
for f in res:
    print(f)

# Fail with an explicit message instead of a bare IndexError on res[0]
# when nothing in cfg_path matches the keywords.
if not res:
    raise FileNotFoundError(
        f"No config file in {cfg_path!r} contains all of "
        f"{method!r}, {dataset!r} and {sampling_number!r} in its name"
    )

# Default to the first candidate.
pre_cfg_path = osp.join(cfg_path, res[0])

EST_NCALTECH101_1024_ShuffleNet_250epoch_not_pretrained.yaml
EST_NCALTECH101_1024_ShuffleNet_250epoch_remove_outliers_not_pretrained.yaml
EST_NCALTECH101_1024_not_pretrained_Rsnet18.yaml
EST_NCALTECH101_1024_MobileNet_250epoch_not_pretrained.yaml
EST_NCALTECH101_1024_ShuffleNet_500epoch_batch_32_remove_outliers_not_pretrained.yaml
EST_NCALTECH101_1024_MobileNet_250epoch_remove_outliers_not_pretrained.yaml
EST_NCALTECH101_1024_not_pretrained.yaml
EST_NCALTECH101_1024.yaml


In [6]:
# Select the config by file name instead of by position in `res`: `res` comes
# from os.listdir, whose ordering is arbitrary, so a fixed index can silently
# pick a different config on another machine or after files are added.
# This is the file that the listing above showed at index 4.
selected_cfg_name = "EST_NCALTECH101_1024_ShuffleNet_500epoch_batch_32_remove_outliers_not_pretrained.yaml"
if selected_cfg_name not in res:
    raise FileNotFoundError(
        f"{selected_cfg_name!r} is not among the matching config files found in {cfg_path!r}"
    )
pre_cfg_path = osp.join(cfg_path, selected_cfg_name)

In [7]:
# Public W&B API client; also used further down to read the run's config.
api = wandb.Api()

# Only search for a run when no explicit run_id was configured above.
if run_id is None:
    runs = api.runs(
        # W&B paths are always "entity/project" with a forward slash, so build
        # them with string formatting rather than os.path.join (which would
        # insert a backslash on Windows).
        path=f"{entity}/{project}",
        # Match experiment names that contain the method followed by the sampling number.
        filters={"config.wandb.experiment_name": {"$regex": f"^.*{method}.*{sampling_number}.*$"}},
    )
    matching_run_ids = [matched_run.id for matched_run in runs]
    print(matching_run_ids)
    # Give a readable error instead of an IndexError when the filter matches nothing.
    if not matching_run_ids:
        raise LookupError(
            f"No run in {entity}/{project} has an experiment name matching "
            f"{method!r} and {sampling_number!r}"
        )
    run_id = matching_run_ids[0]

print(run_id)

0gmu6xyw


In [8]:
# Start a W&B run with default settings; the returned handle is what the
# checkpoint cell below calls use_artifact on.
run = wandb.init()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mharaghi[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# Reference to the ":best" alias of the model artifact logged by the run.
# W&B references are "/"-separated identifiers, not filesystem paths, so they
# are built with string formatting instead of os.path.join (which would use
# a backslash separator on Windows and produce an invalid reference).
checkpoint_reference = f"{entity}/{project}/model-{run_id}:best"
artifact = run.use_artifact(checkpoint_reference, type='model')
artifact_dir = artifact.download()

# Read the configuration stored with the run and wrap it as an OmegaConf object.
config = api.run(f"{entity}/{project}/{run_id}").config
cfg = OmegaConf.create(config)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [10]:
# Read the selected YAML config from disk and merge it with the config fetched
# from the W&B run. The run config is passed last to OmegaConf.merge.
cfg_file = OmegaConf.load(pre_cfg_path)
cfg = OmegaConf.merge(cfg_file, cfg)

# Dump the merged configuration as YAML for inspection.
merged_cfg_yaml = OmegaConf.to_yaml(cfg)
print(merged_cfg_yaml)

description: ' Applying EST on NCALTECH101 with subsampling 1024 and NOT pretrained
  shufflenet_v2_x0_5 for 500 epoch. And, batch size is 32. Remove outliers. Only the
  learning sceduler is slightly different.'
seed: 0
optimize:
  optimizer: Adam
  lr: 0.0001
  lr_scheduler: ReduceLROnPlateau
  mode: min
  factor: 0.5
  patience: 10
model:
  name: EST
  k: null
  aggr: ''
  num_bins: 9
  cnn_type: shufflenet_v2_x0_5
  resnet_crop_dimension:
  - 224
  - 224
  est_mlp_layers:
  - 1
  - 30
  - 30
  - 1
  est_activation: nn.LeakyReLU(negative_slope=0.1)
  resnet_pretrained: false
train:
  epochs: 500
  batch_size: 32
  loss_fn: nn.CrossEntropyLoss()
  profiler: simple
  ckpt_path: null
dataset:
  name: NCALTECH101
  train_percentage: 0.75
  validation_percentage: 0.1
  image_resolution:
  - 180
  - 240
  num_samples_per_class: null
  num_classes: 101
  dataset_path: null
  num_workers: 8
transform:
  train:
    transform: true
    spatial_centering: null
    temporal_scale: null
    num_

In [11]:
# Override the merged config: switch the `transform` flag of the training
# transform section off (the merged config printed above had it set to true).
cfg.transform.train.transform = False

In [12]:
# Fix the global random seed from the config (workers=True is forwarded as well).
# Seeding alone does not make training fully deterministic.
pl.seed_everything(seed=cfg.seed, workers=True)

# Build the graph data module from the merged config, then write the number of
# classes it reports back into the config.
gdm = GraphDataModule(cfg)
cfg.dataset.num_classes = gdm.num_classes

Global seed set to 0


In [13]:
# Build the 'training' split through the dataset factory, reusing the path,
# dataset name and worker count held by the data module and applying no transform.
training_dataset_kwargs = dict(
    dataset_path=gdm.dataset_path,
    dataset_name=gdm.dataset_name,
    dataset_type='training',
    transform=None,
    num_workers=gdm.num_workers,
)
ds = create_dataset(**training_dataset_kwargs)

In [23]:
# For every sample, take the column-wise maximum over the first two columns of
# `pos` and keep only the maximum values (the argmax indices are not needed).
max_list = []
for data in ds:
    first_two_cols_max = data.pos[:, :2].max(axis=0)
    max_list.append(first_two_cols_max.values)
    


In [22]:
# Gather the samples whose file id and first label match one specific recording.
target_file_id = 'image_0048.bin'
target_label = 'kangaroo'

data_list = []
for data in ds:
    # Skip anything that is not the requested file of the requested class.
    if data.file_id != target_file_id or data.label[0] != target_label:
        continue
    data_list.append(data)


In [23]:
# Display the samples collected for file 'image_0048.bin' with label 'kangaroo'.
data_list

[Data(x=[109377, 1], pos=[109377, 3], file_id='image_0048.bin', label=[1], y=[1])]

In [25]:
# Column-wise maximum (values and their row indices) of the `pos` tensor of the
# first matched sample.
data_list[0].pos.max(axis=0)

torch.return_types.max(
values=tensor([1.9600e+02, 1.7200e+02, 3.0081e+05]),
indices=tensor([   180,    492, 109376]))

In [26]:
# Micro-benchmark: time a .long() call on a small 10-element tensor.
# NOTE(review): x is built from Python ints, so it is presumably already int64
# and .long() may be a no-op here; confirm this measures what was intended.
x = torch.tensor([1,2,3,4,5,6,7,8,9,10])
%timeit x.long()

315 ns ± 13.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
