In [1]:
# -*- coding: utf-8 -*-
import os, pickle
import time
import argparse
import shutil
from pathlib import Path

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import pytorch_lightning as pl
from torchmetrics.functional.classification import auroc, stat_scores, average_precision, precision_recall_curve, auc
from pytorch_lightning.loggers import WandbLogger
import wandb

import warnings
warnings.filterwarnings("ignore")

In [2]:
from dwac_args import DWAC
from dres_args import DRES
from resn_args import RESN

In [3]:
# Keyword settings for restoring the DWAC checkpoint; every entry is forwarded
# to DWAC.load_from_checkpoint through **vars(args).
args = argparse.Namespace(**{
    'train_dir': '/net/scratch/hanliu/radiology/explain_teach/data/bm/train',
    'valid_dir': '/net/scratch/hanliu/radiology/explain_teach/data/bm/valid',
    'eval_batch_size': 1,
    'embed_dim': 10,
    'merge_dim': 2,
    'merge_seq': True,
})

# Tag that later cells format into the names of the pickle files they write.
name = 'emb10.merged2'

# Alternative checkpoint kept for reference (DWAC):
# ckpt = '/net/scratch/hanliu/radiology/explain_teach/model/results/dwac-emb10-mrg10/1b62f0sd/checkpoints/epoch=81-valid_loss=0.20.ckpt'
# Checkpoint actually loaded (emb10-mrg2):
ckpt = '/net/scratch/hanliu/radiology/explain_teach/model/results/dwac-emb10-mrg2/3n4ihavu/checkpoints/epoch=53-valid_loss=0.27.ckpt'

model = DWAC.load_from_checkpoint(ckpt, **vars(args))
# Binding the return value to `_` keeps the notebook from echoing it.
_ = model.eval()

Using Guassian kernel


In [20]:
def _dwac_linear_embed(dwac, x):
    """Embed ``x`` with the DWAC model, stopping after the first layer of each head.

    Steps, as performed on the model's sub-modules:
      1. ``dwac.conv[i]`` is applied to ``x[:, i]`` after a new axis is inserted
         at dim 1 and repeated 16 times.
      2. The first sub-module of ``dwac.linear[i]`` is applied to the i-th conv
         output.
      3. ``dwac.fusion[0]`` is applied to all conv outputs concatenated on dim 1.
      4. The linear outputs followed by the fusion output are concatenated on
         dim 1 and returned.

    ``dwac.embed`` and ``dwac.merger`` are not called; the value returned is the
    one taken before any merger step.

    The forward passes run under ``torch.no_grad()`` so no autograd graph is
    recorded for these inference-only tensors.

    Args:
        dwac: model exposing ``conv``, ``linear`` and ``fusion`` module sequences.
        x: input batch; assumed to be (batch, channel, H, W) -- TODO confirm.

    Returns:
        torch.Tensor: the concatenated embedding for every row of ``x``.
    """
    with torch.no_grad():
        conv_out = [branch(x[:, i].unsqueeze(1).repeat(1, 16, 1, 1))
                    for i, branch in enumerate(dwac.conv)]
        # Stack the per-branch outputs on a new dim 1 so head i reads slice i.
        conv_x = torch.cat([c.unsqueeze(1) for c in conv_out], 1)
        linear_out = [head[0](conv_x[:, i]) for i, head in enumerate(dwac.linear)]
        fusion_out = dwac.fusion[0](torch.cat(conv_out, 1))
        return torch.cat(linear_out + [fusion_out], 1)


# NOTE(review): only batch 0 of each loader is consumed below; this covers the
# whole split only if the loader yields a single batch -- confirm.
batches = list(model.val_dataloader())
inputs = batches[0][0]
labels = batches[0][1]
embeds = _dwac_linear_embed(model, inputs)

batch = list(model.ref_dataloader())
ref_x = batch[0][0]
ref_y = batch[0][1]
ref_z = _dwac_linear_embed(model, ref_x)


In [21]:
# File ids are the file names of class folder '0' followed by class folder '1',
# each sorted, with the '.npy' suffix removed.
# NOTE(review): this presumes the dataloaders yield samples in that same order,
# so ids line up row-for-row with inputs/labels/embeds -- confirm.
val_fids = (sorted(os.listdir(os.path.join(model.hparams.valid_dir, '0')))
            + sorted(os.listdir(os.path.join(model.hparams.valid_dir, '1'))))
val_fids = np.asarray([fid.replace('.npy', '') for fid in val_fids])

ref_fids = (sorted(os.listdir(os.path.join(model.hparams.train_dir, '0')))
            + sorted(os.listdir(os.path.join(model.hparams.train_dir, '1'))))
ref_fids = np.asarray([fid.replace('.npy', '') for fid in ref_fids])

# Convert every row of each tensor to a squeezed NumPy array. The names are
# rebound from tensors to arrays, so the embedding cell must be re-run before
# this cell can be executed a second time.
inputs = np.asarray([row.squeeze().detach().numpy() for row in inputs])
labels = np.asarray([row.squeeze().detach().numpy() for row in labels])
embeds = np.asarray([row.squeeze().detach().numpy() for row in embeds])

ref_x = np.asarray([row.squeeze().detach().numpy() for row in ref_x])
ref_y = np.asarray([row.squeeze().detach().numpy() for row in ref_y])
ref_z = np.asarray([row.squeeze().detach().numpy() for row in ref_z])

# Write each pickle through a context manager so the file handle is flushed
# and closed even if pickling raises.
path = model.hparams.valid_dir.replace('valid', 'embs/dwac_valid_{}.linear.0.pkl'.format(name))
with open(path, "wb") as fh:
    pickle.dump((val_fids, inputs, labels, embeds), fh)
print("Encoded valid embeddings (fids, inputs, labels, embeds) at " + path)

path = model.hparams.train_dir.replace('train', 'embs/dwac_train_{}.linear.0.pkl'.format(name))
with open(path, "wb") as fh:
    pickle.dump((ref_fids, ref_x, ref_y, ref_z), fh)
print("Encoded train embeddings (fids, inputs, labels, embeds) at " + path)

Encoded valid embeddings (fids, inputs, labels, embeds) at /net/scratch/hanliu/radiology/explain_teach/data/bm/embs/dwac_valid_emb10.merged2.linear.0.pkl
Encoded train embeddings (fids, inputs, labels, embeds) at /net/scratch/hanliu/radiology/explain_teach/data/bm/embs/dwac_train_emb10.merged2.linear.0.pkl


In [22]:
# Keyword settings for restoring the DRES checkpoint; every entry is forwarded
# to DRES.load_from_checkpoint through **vars(args).
args = argparse.Namespace(**{
    'train_dir': '/net/scratch/hanliu/radiology/explain_teach/data/bm/train',
    'valid_dir': '/net/scratch/hanliu/radiology/explain_teach/data/bm/valid',
    'eval_batch_size': 1,
    'embed_dim': 2,
})

# Tag that later cells format into the names of the pickle files they write.
name = 'emb2'

# Alternative checkpoints kept for reference (DRES):
# ckpt = '/net/scratch/hanliu/radiology/explain_teach/model/results/dres-emb10/3r59qpvk/checkpoints/epoch=57-valid_loss=0.06.ckpt'
# ckpt = '/net/scratch/hanliu/radiology/explain_teach/model/results/dres-emb10/257qh640/checkpoints/epoch=56-valid_loss=0.25.ckpt'
# Checkpoint actually loaded (emb2):
ckpt = '/net/scratch/hanliu/radiology/explain_teach/model/results/dres-emb2/u15h6xun/checkpoints/epoch=34-valid_loss=0.19.ckpt'

model = DRES.load_from_checkpoint(ckpt, **vars(args))
# Binding the return value to `_` keeps the notebook from echoing it.
_ = model.eval()

Using Guassian kernel


In [35]:
# NOTE(review): only batch 0 of each loader is consumed below; this covers the
# whole split only if the loader yields a single batch -- confirm.
batches = list(model.val_dataloader())
inputs = batches[0][0]
labels = batches[0][1]

batch = list(model.ref_dataloader())
ref_x = batch[0][0]
ref_y = batch[0][1]

# model.embed(...) is not called: embeddings are taken from the first four
# sub-modules of model.fc applied to the feature extractor output.
# The forward passes run under no_grad so no autograd graph is recorded for
# these inference-only tensors.
with torch.no_grad():
    embeds = model.fc[:4](model.feature_extractor(inputs))
    ref_z = model.fc[:4](model.feature_extractor(ref_x))

In [36]:
# File ids are the file names of class folder '0' followed by class folder '1',
# each sorted, with the '.npy' suffix removed.
# NOTE(review): this presumes the dataloaders yield samples in that same order,
# so ids line up row-for-row with inputs/labels/embeds -- confirm.
val_fids = (sorted(os.listdir(os.path.join(model.hparams.valid_dir, '0')))
            + sorted(os.listdir(os.path.join(model.hparams.valid_dir, '1'))))
val_fids = np.asarray([fid.replace('.npy', '') for fid in val_fids])

ref_fids = (sorted(os.listdir(os.path.join(model.hparams.train_dir, '0')))
            + sorted(os.listdir(os.path.join(model.hparams.train_dir, '1'))))
ref_fids = np.asarray([fid.replace('.npy', '') for fid in ref_fids])

# Convert every row of each tensor to a squeezed NumPy array. The names are
# rebound from tensors to arrays, so the embedding cell must be re-run before
# this cell can be executed a second time.
inputs = np.asarray([row.squeeze().detach().numpy() for row in inputs])
labels = np.asarray([row.squeeze().detach().numpy() for row in labels])
embeds = np.asarray([row.squeeze().detach().numpy() for row in embeds])

ref_x = np.asarray([row.squeeze().detach().numpy() for row in ref_x])
ref_y = np.asarray([row.squeeze().detach().numpy() for row in ref_y])
ref_z = np.asarray([row.squeeze().detach().numpy() for row in ref_z])

# Write each pickle through a context manager so the file handle is flushed
# and closed even if pickling raises.
path = model.hparams.valid_dir.replace('valid', 'embs/dres_valid_{}.linear.0.pkl'.format(name))
with open(path, "wb") as fh:
    pickle.dump((val_fids, inputs, labels, embeds), fh)
print("Encoded valid embeddings (fids, inputs, labels, embeds) at " + path)

path = model.hparams.train_dir.replace('train', 'embs/dres_train_{}.linear.0.pkl'.format(name))
with open(path, "wb") as fh:
    pickle.dump((ref_fids, ref_x, ref_y, ref_z), fh)
print("Encoded train embeddings (fids, inputs, labels, embeds) at " + path)

Encoded valid embeddings (fids, inputs, labels, embeds) at /net/scratch/hanliu/radiology/explain_teach/data/bm/embs/dres_valid_emb2.linear.0.pkl
Encoded train embeddings (fids, inputs, labels, embeds) at /net/scratch/hanliu/radiology/explain_teach/data/bm/embs/dres_train_emb2.linear.0.pkl


In [55]:
# Name of the data split to encode; it is formatted into the data directory
# here and reused by later cells when building the output path.
split = 'valid'

# Keyword settings for restoring the RESN checkpoint; every entry is forwarded
# to RESN.load_from_checkpoint through **vars(args).
args = argparse.Namespace(**{
    'valid_dir': '/net/scratch/hanliu/radiology/explain_teach/data/bm/{}'.format(split),
    'eval_batch_size': 1,
    'embed_dim': 2,
})

# Tag that later cells format into the name of the pickle file they write.
name = 'emb2'

# Alternative checkpoints kept for reference (RESN):
# ckpt = '/net/scratch/hanliu/radiology/explain_teach/model/results/resn-emb10/32vzr4v5/checkpoints/epoch=95-valid_loss=0.30.ckpt'
# ckpt = '/net/scratch/hanliu/radiology/explain_teach/model/results/resn-emb10/39e6y5lf/checkpoints/epoch=97-valid_loss=0.39.ckpt'
# Checkpoint actually loaded (emb2):
ckpt = '/net/scratch/hanliu/radiology/explain_teach/model/results/resn-emb2/2pqtxzwi/checkpoints/epoch=72-valid_loss=0.50.ckpt'

model = RESN.load_from_checkpoint(ckpt, **vars(args))
# Binding the return value to `_` keeps the notebook from echoing it.
_ = model.eval()

In [58]:
# Collect every batch from the validation loader, then split each batch into
# its first element (inputs) and second element (labels).
batches = list(model.val_dataloader())
inputs = [b[0] for b in batches]
labels = [b[1] for b in batches]

# Per-batch embeddings: the first four sub-modules of model.fc applied to the
# feature extractor output (not the raw feature extractor output).
# The forward passes run under no_grad so no autograd graph is recorded for
# these inference-only tensors.
with torch.no_grad():
    embeds = [model.fc[:4](model.feature_extractor(im)) for im in inputs]

In [59]:
# File ids are the file names of class folder '0' followed by class folder '1',
# each sorted, with the extension removed. Sibling cells that list this same
# directory strip '.npy', so both suffixes are removed here to keep the ids
# consistent across the saved pickles.
# NOTE(review): this presumes the dataloader yields samples in that same order
# -- confirm.
fids = (sorted(os.listdir(os.path.join(model.hparams.valid_dir, '0')))
        + sorted(os.listdir(os.path.join(model.hparams.valid_dir, '1'))))
fids = np.asarray([fid.replace('.jpg', '').replace('.npy', '') for fid in fids])

# Each list holds one tensor per batch; every batch is converted to a squeezed
# NumPy array and only batch 0 is kept.
# NOTE(review): keeping batch 0 covers the whole split only if the loader
# yields a single batch -- confirm.
embeds = np.asarray([e.squeeze().detach().numpy() for e in embeds])[0]
inputs = np.asarray([i.squeeze().detach().numpy() for i in inputs])[0]
labels = np.asarray([l.squeeze().detach().numpy() for l in labels])[0]

# Write the pickle through a context manager so the file handle is flushed and
# closed even if pickling raises.
path = model.hparams.valid_dir.replace(split, 'embs/resn_{}_{}.linear.0.pkl'.format(split, name))
with open(path, "wb") as fh:
    pickle.dump((fids, inputs, labels, embeds), fh)
# The message has a single placeholder, so only `split` is passed to format().
print("Encoded {} findings (fids, inputs, labels, embeds) at ".format(split) + path)

Encoded valid findings (fids, inputs, labels, embeds) at /net/scratch/hanliu/radiology/explain_teach/data/bm/embs/resn_valid_emb2.linear.0.pkl
