In [1]:
import os
import sys
# Make the project's `src` directory importable from this notebook.
module_path = os.path.abspath('../src')
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)
    
from dataset.CaptionDataModule import CaptionDataModule
from models.MultiModalityClassifier import MultiModalityClassifier
from models.CaptionModalityClassifier import CaptionModalityClassifier
from utils.caption_utils import load_embedding_matrix
from experiments.microscopy.microscopy import experiment, get_model
import numpy as np
import json
import torch

In [2]:
# --- Text / embedding settings ---------------------------------------------
# Number of words to consider from the embeddings vocabulary.
MAX_NUMBER_WORDS = 20000
# Maximum sentence length.
MAX_WORDS_PER_SENTENCE = 300
# Number of features per embedding.
WORD_DIMENSION = 300

# --- Task settings ----------------------------------------------------------
# 4 microscopy classes.
NUM_CLASSES = 4
BATCH_SIZE = 32

# --- Input locations --------------------------------------------------------
BASE_IMG_DIR = '/workspace/data/'
DATA_PATH = '/workspace/data/multimodality_classification.csv'
EMBEDDINGS = '/workspace/data/embeddings'
# NOTE(review): this checkpoint path is rooted at "./outputs", while the image
# checkpoints in this notebook are read from "../outputs" - confirm intended.
TEXT_MODEL_PATH = "./outputs/dainty-snowflake-10/checkpoint2.pt"

In [3]:
# Build the caption/image data module from the configuration constants, then
# run its prepare and setup steps so the dataloaders can be requested later.
dm = CaptionDataModule(
    BATCH_SIZE,
    DATA_PATH,
    MAX_NUMBER_WORDS,
    MAX_WORDS_PER_SENTENCE,
    BASE_IMG_DIR,
)
dm.prepare_data()
dm.setup()

In [4]:
# Restore the caption (text) classifier from the checkpoint at TEXT_MODEL_PATH.
text_model = CaptionModalityClassifier.load_from_checkpoint(checkpoint_path=TEXT_MODEL_PATH)

In [5]:
def load_shallow_model(model_id, model_dict, num_classes=4, outputs_dir='../outputs'):
    """Build a shallow image model and load its saved weights.

    Args:
        model_id: Key into ``model_dict`` of the form
            ``"<model_name>.<experiment_name>"`` (e.g. ``"resnet50.layer4-2"``).
            Only the part before the dot is passed to ``get_model``.
        model_dict: Mapping from model id to a dict providing ``'layers'``
            (forwarded to ``get_model``) and ``'id'`` (name of the run
            directory under ``outputs_dir`` that holds ``checkpoint.pt``).
        num_classes: Third positional argument of ``get_model``; presumably the
            number of output classes (TODO confirm). Defaults to 4, the value
            that was previously hard-coded.
        outputs_dir: Directory containing one sub-directory per run. Defaults
            to ``'../outputs'``, the previously hard-coded location.

    Returns:
        The model returned by ``get_model`` with the checkpoint loaded through
        ``load_state_dict``.

    Raises:
        ValueError: If ``model_id`` does not contain exactly one ``'.'``.
        KeyError: If ``model_id`` is missing from ``model_dict``, or its entry
            lacks ``'layers'`` or ``'id'``.
    """
    # Strict two-way unpack keeps the original validation of the id format;
    # the experiment part is not needed here.
    model_name, _experiment_name = model_id.split('.')
    config = model_dict[model_id]
    model = get_model(model_name, "shallow", num_classes, layers=config['layers'], pretrained=True)

    checkpoint_path = os.path.join(outputs_dir, str(config['id']), 'checkpoint.pt')
    # Deserialize onto the CPU so a checkpoint written from a GPU run can still
    # be loaded on a machine without CUDA; load_state_dict copies the values
    # into the model's own parameters regardless of where they were loaded.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)

    return model

# Experiment registry for the shallow ResNet-50 variants; it is passed to
# load_shallow_model as the lookup table for the requested model id.
JSON_INPUT_PATH = "../src/experiments/microscopy/shallow-resnet50.json"
# Read the JSON explicitly as UTF-8 instead of relying on the platform's
# locale-dependent default encoding.
with open(JSON_INPUT_PATH, encoding="utf-8") as json_file:
    models = json.load(json_file)

# Image branch: the 'resnet50.layer4-2' entry of the registry.
resnet50_4_2 = load_shallow_model('resnet50.layer4-2', models)

In [6]:
# Combine the text classifier and the shallow ResNet image model into one
# multi-modality classifier (the training summary lists them as `text_model`
# and `image_model`, followed by a Linear `fc` layer).
multi = MultiModalityClassifier(text_model, resnet50_4_2)

In [20]:
# Smoke-test the fused model's forward pass on a single validation batch.
# The original cell called `multi(x, y)` with names that are never defined in
# this notebook, so it failed on a fresh kernel; the batch is now fetched here.
# Batches unpack as (text, image, label), as in the evaluation loops.
sample_text, sample_image, _ = next(iter(dm.val_dataloader()))

# Run on whatever device the model currently lives on.
param_device = next(multi.parameters()).device

# Evaluate without autograd, then restore the previous train/eval mode so this
# check does not change the model's state for later cells.
was_training = multi.training
multi.eval()
with torch.no_grad():
    sample_logits = multi(sample_text.to(param_device), sample_image.to(param_device))
multi.train(was_training)

sample_logits

tensor([[-1.3023e+00, -7.5245e-02, -5.7086e-01, -4.4363e-01],
        [-7.6666e-01, -3.7439e-01, -2.2228e-01, -7.6566e-01],
        [-7.3730e-01, -7.7391e-02,  3.6623e-01, -4.5575e-01],
        [-9.6148e-01,  1.0762e-01, -3.0328e-01, -6.5552e-01],
        [-9.2193e-01,  3.0935e-01, -3.2081e-01, -4.0472e-01],
        [-7.6309e-01,  2.0461e-01,  2.7947e-01, -3.5668e-01],
        [-7.9378e-01, -2.6768e-01,  2.1916e-01, -4.0320e-01],
        [-7.0279e-01,  4.0420e-02, -4.0649e-01, -2.8392e-01],
        [-5.0083e-01,  3.3044e-01, -5.2417e-01, -9.0567e-01],
        [-2.0338e-01,  1.6602e-01,  1.5647e-01, -1.0180e-01],
        [-9.2463e-01,  1.1310e-01,  3.2507e-01, -4.5493e-01],
        [-7.7767e-01, -2.6316e-01,  2.1692e-01, -5.1206e-01],
        [-1.2173e+00, -3.2166e-01, -4.9767e-02, -3.1664e-01],
        [-9.4197e-01, -2.8521e-02, -5.4986e-01, -5.3572e-01],
        [-9.0618e-01,  3.6854e-02, -5.0338e-01, -7.7679e-01],
        [-5.6917e-01,  3.7597e-01,  1.8245e-01, -1.8527e-01],
        

In [7]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Early stopping on the validation loss (lower is better, patience of 5).
# NOTE(review): the captured training log warns that `monitor` has no effect
# when the model reports EvalResult/TrainResult(early_stop_on=...) - confirm
# which metric actually drives the stop.
early_stop_callback = EarlyStopping(monitor='val_loss', mode='min', min_delta=0.0, patience=5, verbose=True)

# Single-GPU trainer with no logger attached.
trainer = Trainer(
    gpus=1,
    logger=None,
    early_stop_callback=early_stop_callback,
)
trainer.fit(multi, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type                      | Params
----------------------------------------------------------
0 | accuracy    | Accuracy                  | 0     
1 | text_model  | CaptionModalityClassifier | 2 M   
2 | image_model | ShallowResNet             | 23 M  
3 | fc          | Linear                    | 9 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

                    When using EvalResult(early_stop_on=X) or TrainResult(early_stop_on=X) the
                    'monitor' key of EarlyStopping has no effect.
                    Remove EarlyStopping(monitor='val_early_stop_on) to fix')
                


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..
Epoch 00015: early stopping triggered.





1

In [11]:
# Fetch the validation and training loaders from the data module in one step.
# NOTE(review): `tr` is not used by any cell visible in this notebook.
val, tr = dm.val_dataloader(), dm.train_dataloader()

In [14]:
# Collect ground-truth labels and predicted classes over the validation set.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = multi.to(device)

model.eval()
y_true = []
y_pred = []
# Inference only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for x_t, x_v, y in val:
        # Store plain Python ints rather than 0-d tensors.
        # Assumes y is a 1-D tensor of class indices - TODO confirm.
        y_true.extend(y.tolist())

        outputs = model(x_t.to(device), x_v.to(device))
        # Predicted class = index of the largest score along dim 1.
        _, predicted = torch.max(outputs, 1)
        y_pred.extend(predicted.tolist())

In [15]:
# Number of labels collected from the validation loader.
len(y_true)

466

In [18]:
from sklearn.metrics import confusion_matrix, classification_report
# Confusion matrix on the validation split; per scikit-learn's convention,
# rows index the true class and columns the predicted class.
confusion_matrix(y_true, y_pred)

array([[ 49,   1,   0,   2],
       [  1, 187,   0,   0],
       [  1,   2, 154,   0],
       [  0,   1,   2,  66]])

In [19]:
# Collect ground-truth labels and predicted classes over the test set.
# Uses `model` and `device` defined by the validation-evaluation cell.
y_true = []
y_pred = []

# Set eval mode explicitly instead of relying on an earlier cell having done so.
model.eval()
# Inference only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for x_t, x_v, y in dm.test_dataloader():
        # Store plain Python ints rather than 0-d tensors.
        # Assumes y is a 1-D tensor of class indices - TODO confirm.
        y_true.extend(y.tolist())

        outputs = model(x_t.to(device), x_v.to(device))
        # Predicted class = index of the largest score along dim 1.
        _, predicted = torch.max(outputs, 1)
        y_pred.extend(predicted.tolist())

In [20]:
# Confusion matrix on the test split, using the y_true / y_pred lists filled by
# the test loop (rows: true class, columns: predicted class, per scikit-learn).
confusion_matrix(y_true, y_pred)

array([[ 41,  16,  19,  12],
       [  6, 264,  11,   3],
       [  4,   5, 394,   2],
       [ 22,   5,  13,  56]])

In [21]:
# Per-class precision / recall / F1 on the test split, shown with 4 decimals.
# print() is used because the report is a multi-line string.
print(classification_report(y_true, y_pred, digits=4))

              precision    recall  f1-score   support

           0     0.5616    0.4659    0.5093        88
           1     0.9103    0.9296    0.9199       284
           2     0.9016    0.9728    0.9359       405
           3     0.7671    0.5833    0.6627        96

    accuracy                         0.8648       873
   macro avg     0.7852    0.7379    0.7569       873
weighted avg     0.8554    0.8648    0.8576       873

