In [None]:
from collections import OrderedDict
from glob import glob
import os.path as osp

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import torch
import torch.nn.functional as F
import wandb
from omegaconf import OmegaConf

import model_factory
import dataset_factory


In [None]:
# Start a W&B run; it is needed further down to fetch the checkpoint artifact via `run.use_artifact`.
run = wandb.init()

In [None]:
# Which training run to inspect and which checkpoint alias of its model artifact to pull.
exprement_number = 'ewrilqhd'  # W&B run id (name kept: later cells reference it)
batch_number = 'best'          # artifact alias used in `model-<run id>:<alias>`

# Local config saved by W&B for this run; assumes the usual `wandb/<prefix>-<run id>/files/` layout.
# sorted() makes the pick deterministic if several local run dirs match (glob order is arbitrary).
path_files = sorted(glob(osp.join('wandb', '*-' + exprement_number, 'files', 'config.yaml')))
if not path_files:
    raise FileNotFoundError(f'No wandb config.yaml found for run {exprement_number!r}')
cfg = OmegaConf.load(path_files[0])

# Every top-level section in this file is wrapped under a `value` key; unwrap them in place.
for section in ('dataset', 'model', 'optimize', 'seed', 'train', 'transform', 'wandb'):
    cfg[section] = cfg[section].value

# Local Lightning checkpoint directory for the same run.
path_files = sorted(glob(osp.join(cfg.wandb.project, exprement_number, 'checkpoints', '*')))
if not path_files:
    raise FileNotFoundError(f'No local checkpoint found for run {exprement_number!r}')
pl_ckpt_file = path_files[0]
print(pl_ckpt_file)

In [None]:
# W&B artifact names are always '/'-separated; osp.join would emit '\\' on Windows and break the lookup.
artifact_name = f'{cfg.wandb.entity}/{cfg.wandb.project}/model-{exprement_number}:{batch_number}'
artifact = run.use_artifact(artifact_name, type='model')
artifact_dir = artifact.download()
# This is a filesystem path, so osp.join is correct here.
wandb_ckpt_file = osp.join(artifact_dir, 'model.ckpt')
print(wandb_ckpt_file)

In [None]:
# Number of leading characters dropped from every state_dict key, as in the original code.
# NOTE(review): presumably the 6-char 'model.' attribute prefix of the training module — confirm.
PREFIX_LEN = 6


def load_model_state_dict(ckpt_path):
    """Load a checkpoint's 'state_dict' and strip PREFIX_LEN leading chars from every key.

    map_location='cpu' lets CUDA-saved checkpoints load on CPU-only hosts; the tensors
    are copied onto the model's device later by load_state_dict.
    """
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    return OrderedDict((key[PREFIX_LEN:], tensor) for key, tensor in checkpoint['state_dict'].items())


stated_dict_wandb = load_model_state_dict(wandb_ckpt_file)
stated_dict_pl = load_model_state_dict(pl_ckpt_file)

# Same key set first (avoids KeyError / silently ignoring extra keys), then same shape and
# values per tensor; the generator lets all() stop at the first mismatch.
is_same = stated_dict_wandb.keys() == stated_dict_pl.keys() and all(
    tensor.shape == stated_dict_pl[key].shape and torch.allclose(tensor, stated_dict_pl[key])
    for key, tensor in stated_dict_wandb.items()
)
print(f'Are pl and wandb the same? {is_same}')


In [None]:
# Rebuild the train/val/test loaders with a fixed evaluation batch size of 32.
cfg['train']['batch_size'] = 32
loaders = dataset_factory.factory(cfg)
(train_dataset_loader,
 val_dataset_loader,
 test_dataset_loader) = loaders

In [None]:
# Display the (unwrapped) training section of the run config.
cfg['train']

In [None]:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def _build_model(arch_name):
    """Set cfg.model.name to `arch_name`, build that model via model_factory and move it to `device`."""
    cfg.model.name = arch_name
    return model_factory.factory(cfg).to(device)


# Both architectures come from the same config; afterwards cfg.model.name is left as "DGCNN2".
model_2knn = _build_model("DGCNN")
model_1knn = _build_model("DGCNN2")

In [None]:
def test(loader, model):
    model.eval()
    all_pred = []
    all_true = []
    correct = 0
    total_loss = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out,_ = model(data)
            pred = out.max(dim=1)[1]
            all_pred.append(pred)
            all_true.append(data.y)
            loss = F.nll_loss(out, data.y)
            correct += pred.eq(data.y).sum().item()
            total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset), correct / len(loader.dataset), all_pred, all_true

In [None]:
# Use the local Lightning checkpoint when it matches the W&B artifact, otherwise the artifact.
# Must be load_state_dict: Module.state_dict(d) loads nothing — it writes the model's current
# parameters INTO d (treating it as `destination`), clobbering the dict.
state_dict_to_load = stated_dict_pl if is_same else stated_dict_wandb
model_2knn.load_state_dict(state_dict_to_load)
model_1knn.load_state_dict(state_dict_to_load)

In [None]:
# Evaluate model_2knn on the test split five times, printing (loss, accuracy) each time —
# presumably to check whether evaluation is deterministic.
for repeat in range(5):
    perf = test(test_dataset_loader, model_2knn)
    loss_value, accuracy = perf[0], perf[1]
    print(loss_value, accuracy)

In [None]:
# Indices of every test sample whose first label entry is 'butterfly'.
dataset = test_dataset_loader.dataset
target_label = 'butterfly'
dataset_idx = []
for sample_idx in range(len(dataset)):
    if dataset[sample_idx].label[0] == target_label:
        dataset_idx.append(sample_idx)


In [None]:
# Forward the first butterfly sample, keeping the model's intermediate outputs in `inters`.
model_2knn.eval()
torch.cuda.empty_cache()
# Inference only: no_grad avoids building an autograd graph for all intermediates.
with torch.no_grad():
    out, inters = model_2knn(dataset[dataset_idx[0]].to(device))

In [None]:
# Display the first butterfly sample.
butterfly_sample = dataset[dataset_idx[0]]
butterfly_sample

In [None]:
# Export the first butterfly sample's positions plus the intermediate tensors from the last
# forward pass to a MATLAB .mat file (keys: pos, x1, x2, out1, out2, out).
# NOTE(review): the filename hard-codes 'image_0014' — confirm it matches dataset_idx[0].
mat_contents = {'pos': dataset[dataset_idx[0]].pos.cpu().numpy()}
for key in ('x1', 'x2', 'out1', 'out2', 'out'):
    mat_contents[key] = inters[key].detach().cpu().numpy()
sio.savemat('butterfly_image_0014.mat', mat_contents)



In [None]:
# Collect the intermediate outputs of model_2knn for every butterfly sample.
# Tensors are moved to CPU so many samples don't accumulate on the GPU.
x1s = []
x2s = []
out1s = []
out2s = []
with torch.no_grad():
    for idx in dataset_idx:
        out, inters = model_2knn(dataset[idx].to(device))
        x1s.append(inters['x1'].cpu())
        x2s.append(inters['x2'].cpu())
        out1s.append(inters['out1'].cpu())
        out2s.append(inters['out2'].cpu())


In [None]:
# Return unoccupied cached CUDA memory to the driver; nothing to do without CUDA.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Forward every butterfly sample; afterwards `out` / `inters` hold the last sample's results.
with torch.no_grad():
    for idx in dataset_idx:
        data = dataset[idx].to(device)
        out, inters = model_2knn(data)


In [None]:
# Compare both architectures (loaded with identical weights) on the test and validation splits.
# After the loop perf_2 / perf_1 hold the validation results, as before.
for split_name, split_loader in (('test', test_dataset_loader), ('val', val_dataset_loader)):
    perf_2 = test(split_loader, model_2knn)
    perf_1 = test(split_loader, model_1knn)
    print(f'[{split_name}] DGCNN  (model_2knn): loss={perf_2[0]:.4f} acc={perf_2[1]:.4f}')
    print(f'[{split_name}] DGCNN2 (model_1knn): loss={perf_1[0]:.4f} acc={perf_1[1]:.4f}')

In [None]:
# Map each class id to its label name using the first sample of that class.
# NOTE(review): `.data` is the collated storage of a PyG InMemoryDataset and `label[i][0]` is
# assumed to be the class name string — both carried over from the original cell.
num_classes = 101
test_dataset = test_dataset_loader.dataset
class_ids = test_dataset.data.y
class_names = []
for cc in range(num_classes):
    matches = (class_ids == cc).nonzero(as_tuple=True)[0]
    if len(matches) == 0:
        # Keep list positions aligned with class ids even when a class has no sample.
        class_names.append(f'<class {cc} absent>')
    else:
        class_names.append(test_dataset.data.label[matches[0].item()][0])
    print(cc, class_names[-1])

In [None]:
%matplotlib widget

y_true = torch.cat(all_true,dim=0).cpu().numpy()
y_pred = torch.cat(all_pred,dim=0).cpu().numpy()

conf_mat = confusion_matrix(y_true, y_pred)
np.fill_diagonal(conf_mat, 0)
# plt.figure(figsize=(10,10))
plt.imshow(conf_mat, interpolation='none')
plt.show()

In [None]:
# Print the `max_number` largest off-diagonal confusion counts as
# (count, true class name, true id, predicted class name, predicted id).
max_number = 20
top_flat = np.argsort(conf_mat, axis=None)[::-1][:max_number]
rows_sorted, columns_sorted = np.unravel_index(top_flat, conf_mat.shape)
for true_id, pred_id in zip(rows_sorted, columns_sorted):
    print(conf_mat[true_id, pred_id], class_names[true_id], true_id,
          class_names[pred_id], pred_id)

In [None]:
# Rank classes by error rate = misclassified test samples / test samples of that class.
max_number = 20
# conf_mat has its diagonal zeroed, so each row sum is that class's number of errors.
num_errors = conf_mat.sum(axis=1)
# Count samples from the very labels that produced conf_mat, so both refer to the same split.
class_counts = np.bincount(y_true, minlength=conf_mat.shape[0])
# Classes without samples get rate 0; a 0/0 NaN would otherwise sort to the top after [::-1].
error_rate = np.divide(num_errors, class_counts,
                       out=np.zeros(len(class_counts)), where=class_counts > 0)
for ii in error_rate.argsort()[::-1][:max_number]:
    print(ii, num_errors[ii], class_counts[ii], error_rate[ii])

In [None]:
# List classes that have test samples but were never predicted correctly.
# Counts are indexed by class id, so `ii` lines up with class_names.
num_classes = 101
samples_per_class = np.bincount(y_true, minlength=num_classes)
correct_per_class = np.bincount(y_true[y_true == y_pred], minlength=num_classes)
for ii in np.flatnonzero((samples_per_class > 0) & (correct_per_class == 0)):
    print(ii, class_names[ii])