In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to them are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import logging

import torch.nn as nn
import torch.nn.functional as F

import emmental
from cxr_dataset import CXR8Dataset
from emmental import Meta
from emmental.data import EmmentalDataLoader
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.scorer import Scorer
from emmental.task import EmmentalTask
from modules.classification_module import ClassificationModule
from modules.torch_vision_encoder import TorchVisionEncoder
from task_config import CXR8_TASK_NAMES
from transforms import get_data_transforms

In [3]:
# Notebook-level logger; the data-loading cells below use it to report progress.
logger = logging.getLogger(__name__)

In [4]:
# Initialise Emmental with "logs" as the log root. The recorded output shows a
# timestamped run directory being set under it and the default config loaded.
emmental.init("logs")

[2019-05-03 00:52:06,470][INFO] emmental.meta:95 - Setting logging directory to: logs/2019_05_03/00_52_06
[2019-05-03 00:52:06,486][INFO] emmental.meta:56 - Loading Emmental default config from /dfs/scratch1/senwu/mmtl/emmental/src/emmental/emmental-default-config.yaml.


In [5]:
# Override Emmental's defaults for this run. The nested sections are assembled
# separately so that each one can be read (and tweaked) on its own.
lr_scheduler_settings = {
    "warmup_steps": None,
    "warmup_unit": "batch",
    "lr_scheduler": "linear",
    "min_lr": 1e-6,
}
optimizer_settings = {"optimizer": "sgd", "lr": 0.001, "l2": 0.000}

run_config = {
    "meta_config": {"seed": 1701, "device": 0},
    "learner_config": {
        "n_epochs": 20,
        "valid_split": "val",
        "optimizer_config": optimizer_settings,
        "lr_scheduler_config": lr_scheduler_settings,
    },
    "logging_config": {"evaluation_freq": 4000, "checkpointing": False},
}

Meta.update_config(config=run_config)

[2019-05-03 00:52:06,547][INFO] emmental.meta:143 - Updating Emmental config from user provided config.


In [6]:
# Name of the dataset; it is also the key passed to get_data_transforms and
# the name given to every CXR8Dataset.
DATA_NAME = "CXR8"

# Single root for the ChexNet data so the location only has to be changed once.
CXR_DATA_ROOT = "/dfs/scratch1/senwu/mmtl/emmental-tutorials/chexnet/data"
CXRDATA_PATH = f"{CXR_DATA_ROOT}/nih_labels.csv"
CXRIMAGE_PATH = f"{CXR_DATA_ROOT}/images"

BATCH_SIZE = 16
CNN_ENCODER = "densenet121"

# Per-split batch sizes; the training entry reuses BATCH_SIZE so the two
# constants cannot drift apart.
BATCH_SIZES = {"train": BATCH_SIZE, "val": 64, "test": 64}

In [7]:
# Build the image transforms for this dataset; the result is indexed by split
# name when the datasets are created.
# NOTE(review): the recorded output ends with "please use transforms.Resize
# instead." — presumably a deprecation warning raised while building the
# transforms; confirm and update the `transforms` module accordingly.
cxr8_transform = get_data_transforms(DATA_NAME)

  "please use transforms.Resize instead.")


In [8]:
# Arguments that are identical for every split; only `split` and `transform`
# change from one dataset to the next.
shared_dataset_kwargs = dict(
    name=DATA_NAME,
    path_to_images=CXRIMAGE_PATH,
    path_to_labels=CXRDATA_PATH,
    sample=0,
    seed=1701,
)

# One CXR8Dataset per split, keyed by split name.
datasets = {}
for split in ("train", "val", "test"):
    split_dataset = CXR8Dataset(
        split=split, transform=cxr8_transform[split], **shared_dataset_kwargs
    )
    datasets[split] = split_dataset
    logger.info(f"Loaded {split} split for {DATA_NAME}.")

[2019-05-03 00:52:59,061][INFO] __main__:15 - Loaded train split for CXR8.
[2019-05-03 00:53:06,792][INFO] __main__:15 - Loaded val split for CXR8.
[2019-05-03 00:53:21,866][INFO] __main__:15 - Loaded test split for CXR8.


In [9]:
# Map every task name to the label of the same name (an identity mapping).
task_to_label_dict = dict((task, task) for task in CXR8_TASK_NAMES)
print(task_to_label_dict)

{'Atelectasis': 'Atelectasis', 'Cardiomegaly': 'Cardiomegaly', 'Effusion': 'Effusion', 'Infiltration': 'Infiltration', 'Mass': 'Mass', 'Nodule': 'Nodule', 'Pneumonia': 'Pneumonia', 'Pneumothorax': 'Pneumothorax', 'Consolidation': 'Consolidation', 'Edema': 'Edema', 'Emphysema': 'Emphysema', 'Fibrosis': 'Fibrosis', 'Pleural_Thickening': 'Pleural_Thickening', 'Hernia': 'Hernia'}


In [10]:
# Wrap each split's dataset in an EmmentalDataLoader, in train/val/test order.
# Only the training split is shuffled.
dataloaders = []

for split in ("train", "val", "test"):
    split_loader = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=datasets[split],
        split=split,
        shuffle=(split == "train"),
        batch_size=BATCH_SIZES[split],
        num_workers=8,
    )
    dataloaders.append(split_loader)
    logger.info(f"Built dataloader for {datasets[split].name} {split} set.")

[2019-05-03 00:53:21,987][INFO] __main__:14 - Built dataloader for CXR8 train set.
[2019-05-03 00:53:21,988][INFO] __main__:14 - Built dataloader for CXR8 val set.
[2019-05-03 00:53:21,989][INFO] __main__:14 - Built dataloader for CXR8 test set.


# Build Emmental tasks

In [11]:
from functools import partial

In [12]:
def ce_loss(task_name, immediate_ouput, Y, active):
    """Cross-entropy loss for one task, computed over its active examples only.

    Args:
        task_name: Task whose classification output is scored; the logits are
            read from ``immediate_ouput[f"classification_module_{task_name}"][0]``.
        immediate_ouput: Mapping from module name to that module's outputs.
        Y: Labels; flattened with ``view(-1)`` and shifted down by one before
            being used as class indices.
        active: Mask over the batch selecting the examples to score, or
            ``None`` to score every example.
            NOTE(review): assumes a 1-D boolean mask aligned with the batch
            dimension — confirm against the Emmental version in use.

    Returns:
        Scalar tensor with the mean cross-entropy over the selected examples.
    """
    logits = immediate_ouput[f"classification_module_{task_name}"][0]
    targets = Y.view(-1) - 1

    # The mask used to be ignored, so examples it excludes still entered the
    # loss with whatever placeholder label they carry (which, after the -1
    # shift, may not even be a valid class index).
    if active is not None:
        logits = logits[active]
        targets = targets[active]

    return F.cross_entropy(logits, targets)

In [13]:
def output(task_name, immediate_ouput):
    """Return class probabilities for ``task_name``.

    Takes the first output stored under ``classification_module_<task_name>``
    in ``immediate_ouput`` and applies a softmax over dimension 1.
    """
    head_key = f"classification_module_{task_name}"
    logits = immediate_ouput[head_key][0]
    return F.softmax(logits, dim=1)

In [14]:
# Input size handed to the encoder so it can report its output size.
# NOTE(review): presumably (channels, height, width) — confirm.
input_shape = (3, 224, 224)

# A single encoder instance; the same object is placed in every task's module
# pool when the tasks are built.
cnn_module = TorchVisionEncoder(CNN_ENCODER, pretrained=True)
# Size reported by the encoder for `input_shape`; passed as the first argument
# to each task's ClassificationModule.
classification_layer_dim = cnn_module.get_frm_output_size(input_shape)

In [15]:
def build_cxr8_task(task_name):
    """Create the EmmentalTask for a single CXR8 task name.

    The task's module pool holds the shared ``cnn_module`` plus its own
    ``ClassificationModule(classification_layer_dim, 2)``. Its flow feeds the
    ``image`` input to the encoder and the encoder's first output to that
    classification module.
    """
    head_name = f"classification_module_{task_name}"

    module_pool = nn.ModuleDict(
        {
            "cnn": cnn_module,
            head_name: ClassificationModule(classification_layer_dim, 2),
        }
    )
    task_flow = [
        {"name": "cnn", "module": "cnn", "inputs": [("_input_", "image")]},
        {"name": head_name, "module": head_name, "inputs": [("cnn", 0)]},
    ]

    return EmmentalTask(
        name=task_name,
        module_pool=module_pool,
        task_flow=task_flow,
        loss_func=partial(ce_loss, task_name),
        output_func=partial(output, task_name),
        scorer=Scorer(metrics=["accuracy", "roc_auc"]),
    )


# One task per CXR8 task name, in the order given by CXR8_TASK_NAMES.
tasks = [build_cxr8_task(task_name) for task_name in CXR8_TASK_NAMES]

[2019-05-03 00:53:23,576][INFO] emmental.task:34 - Created task: Atelectasis
[2019-05-03 00:53:23,578][INFO] emmental.task:34 - Created task: Cardiomegaly
[2019-05-03 00:53:23,581][INFO] emmental.task:34 - Created task: Effusion
[2019-05-03 00:53:23,583][INFO] emmental.task:34 - Created task: Infiltration
[2019-05-03 00:53:23,584][INFO] emmental.task:34 - Created task: Mass
[2019-05-03 00:53:23,586][INFO] emmental.task:34 - Created task: Nodule
[2019-05-03 00:53:23,587][INFO] emmental.task:34 - Created task: Pneumonia
[2019-05-03 00:53:23,589][INFO] emmental.task:34 - Created task: Pneumothorax
[2019-05-03 00:53:23,590][INFO] emmental.task:34 - Created task: Consolidation
[2019-05-03 00:53:23,592][INFO] emmental.task:34 - Created task: Edema
[2019-05-03 00:53:23,593][INFO] emmental.task:34 - Created task: Emphysema
[2019-05-03 00:53:23,594][INFO] emmental.task:34 - Created task: Fibrosis
[2019-05-03 00:53:23,596][INFO] emmental.task:34 - Created task: Pleural_Thickening
[2019-05-03 00:

In [16]:
# Combine all per-task definitions into a single multi-task model.
mtl_model = EmmentalModel(name="Chexnet", tasks=tasks)

[2019-05-03 00:53:23,637][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,156][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,167][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,175][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,182][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,189][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,197][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,204][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,212][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,219][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,227][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:53:28,234][INFO] emmental.model:58 - Moving model to GPU (cuda:0).
[2019-05-03 00:5

In [17]:
# Create the learner that trains the model in the next cell.
emmental_learner = EmmentalLearner()

In [19]:
# Train the multi-task model using the train/val/test dataloaders built above.
emmental_learner.learn(mtl_model, dataloaders)

[2019-05-03 00:53:28,389][INFO] emmental.logging.logging_manager:33 - Evaluating every 4000 batch.
[2019-05-03 00:53:28,390][INFO] emmental.logging.logging_manager:51 - No checkpointing.
[2019-05-03 00:53:28,423][INFO] emmental.learner:283 - Start learning...


HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




In [20]:
# NOTE(review): this repeats the previous cell's call on the same `mtl_model`
# object, and the recorded run ends in a KeyboardInterrupt. Presumably a
# leftover re-run — confirm whether continued training was intended; otherwise
# drop this cell so that Restart & Run All trains only once.
emmental_learner.learn(mtl_model, dataloaders)

[2019-05-03 07:02:35,371][INFO] emmental.logging.logging_manager:33 - Evaluating every 4000 batch.
[2019-05-03 07:02:35,373][INFO] emmental.logging.logging_manager:51 - No checkpointing.
[2019-05-03 07:02:35,460][INFO] emmental.learner:283 - Start learning...


HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)






HBox(children=(IntProgress(value=0, max=4905), HTML(value='')))

KeyboardInterrupt: 