In [2]:
# Reference to ../src: expose the project's source root so the local
# `dataset`, `experiments`, `models` and `utils` packages can be imported.
import sys
import os

module_path = os.path.abspath(os.path.join('../src'))
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

import pandas as pd
import csv
import os

from dataset.MicroscopyTrainDataLoader import MicroscopyTrainDataLoader
from experiments.microscopy.microscopy import experiment, get_model
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from sklearn.metrics import classification_report

import numpy as np; np.random.seed(0)
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt

from models.CaptionModalityClassifier import CaptionModalityClassifier
from models.MultiModalityClassifier import MultiModalityClassifier
from dataset.CaptionDataModule import CaptionDataModule
from utils.caption_utils import load_embedding_matrix

In [3]:
def load_shallow_model(model_id, model_dict, num_classes=4, outputs_dir='../outputs', map_location='cpu'):
    """Build a "shallow" pretrained model via ``get_model`` and load its saved weights.

    Args:
        model_id: key into ``model_dict`` of the form ``"<model_name>.<experiment>"``
            (e.g. ``"resnet50.layer4-2"``). Only the part before the first ``.``
            is passed to ``get_model`` as the model name.
        model_dict: experiment configuration (as loaded from the JSON file). The
            entry for ``model_id`` must provide ``'layers'`` (forwarded to
            ``get_model``) and ``'id'`` (name of the run's output directory).
        num_classes: number of output classes passed to ``get_model`` (default 4).
        outputs_dir: directory holding one sub-directory per run ``id``, each
            containing a ``checkpoint.pt`` state dict.
        map_location: forwarded to ``torch.load`` so checkpoints saved on a GPU
            can still be loaded on a CPU-only machine.

    Returns:
        The model built by ``get_model`` with the checkpoint's state dict loaded.

    Raises:
        KeyError: if ``model_id`` (or one of its required keys) is missing.
        ValueError: if ``model_id`` contains no ``.`` separator.
    """
    # Split only on the first dot so experiment names may themselves contain dots.
    model_name, _experiment = model_id.split('.', 1)
    config = model_dict[model_id]
    model = get_model(model_name, "shallow", num_classes, layers=config['layers'], pretrained=True)

    checkpoint_path = os.path.join(outputs_dir, str(config['id']), 'checkpoint.pt')
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state_dict)

    return model

In [4]:
import json

JSON_INPUT_PATH = "../src/experiments/microscopy/shallow-resnet50.json"

# Experiment registry, presumably keyed by "<model>.<experiment>" ids
# (e.g. 'resnet50.layer4-2'); each entry is read by load_shallow_model.
with open(JSON_INPUT_PATH) as config_file:
    models = json.load(config_file)

In [5]:
# Shallow ResNet-50 for the 'resnet50.layer4-2' experiment, with its saved weights.
resnet50_4_2 = load_shallow_model(model_id='resnet50.layer4-2', model_dict=models)

In [6]:
# Dummy RGB input for a shape smoke test. torch is seeded (as numpy is in the
# import cell) so the random input and the printed logits survive Restart & Run All.
torch.manual_seed(0)
sample_image = torch.rand([3, 224, 224])
# Batch of two identical samples.
x_image = torch.stack([sample_image, sample_image], dim=0)
x_image.shape

torch.Size([2, 3, 224, 224])

In [7]:
# Forward the dummy batch through the image model: one row of 4 raw class
# scores per sample (4 = number of microscopy classes).
resnet50_4_2(x_image)

tensor([[ 0.0799, -0.4657, -0.2040,  0.1104],
        [ 0.0799, -0.4657, -0.2040,  0.1104]], grad_fn=<AddmmBackward>)

In [8]:
# Checkpoint of the trained caption (text) classifier, restored with
# CaptionModalityClassifier.load_from_checkpoint.
# NOTE(review): this path is under "./outputs", while the image checkpoints are
# read from "../outputs" in load_shallow_model — confirm the intended directory.
TEXT_MODEL_PATH = "./outputs/dainty-snowflake-10/checkpoint2.pt"

In [11]:
# Text (caption) pipeline hyper-parameters.
NUM_CLASSES = 4                # 4 microscopy classes
WORD_DIMENSION = 300           # number of features per embedding
MAX_WORDS_PER_SENTENCE = 300   # sentence maximum length
MAX_NUMBER_WORDS = 20000       # number of words to consider from embeddings vocabulary

# Data source and loader settings.
BATCH_SIZE = 32
DATA_PATH = '/workspace/data/multimodality_classification.csv'
# EMBEDDINGS = '/workspace/data/embeddings'

# Build the caption data module, then run its preparation hooks in the usual
# Lightning order (prepare_data before setup).
dm = CaptionDataModule(BATCH_SIZE, DATA_PATH, MAX_NUMBER_WORDS, MAX_WORDS_PER_SENTENCE)
dm.prepare_data()
dm.setup()

# embeddings_dict = load_embedding_matrix(EMBEDDINGS, WORD_DIMENSION)

# if dm.vocab_size < MAX_NUMBER_WORDS:
#     MAX_NUMBER_WORDS = dm.vocab_size + 1
# embedding_matrix = np.zeros((MAX_NUMBER_WORDS, WORD_DIMENSION))
    
# for word, idx in dm.word_index.items():    
#     if idx < MAX_NUMBER_WORDS:
#         word_embedding = embeddings_dict.get(word)
#         if word_embedding is not None:
#             embedding_matrix[idx] = word_embedding
#         else:
#             embedding_matrix[idx] = np.random.randn(WORD_DIMENSION)

In [12]:
# text_model = CaptionModalityClassifier(
#                  max_input_length=MAX_WORDS_PER_SENTENCE,
#                  vocab_size=MAX_NUMBER_WORDS,
#                  embedding_dim=WORD_DIMENSION,
#                  filters=100,
#                  embeddings=embedding_matrix,
#                  num_classes=NUM_CLASSES,
#                  train_embeddings=True,
#                  lr=1e-4)

In [13]:
# Restore the trained caption classifier from its Lightning checkpoint.
text_model = CaptionModalityClassifier.load_from_checkpoint(TEXT_MODEL_PATH)

In [14]:
# Training-split loader from the caption data module; used here only to grab
# one batch of caption inputs for the smoke tests that follow.
train_dataloader = dm.train_dataloader()

In [15]:
# Grab a single batch to drive the smoke tests. next(iter(...)) raises
# StopIteration on an empty loader, instead of silently leaving x_text / y
# unset (or stale from an earlier run) as the for/break idiom would.
x_text, y = next(iter(train_dataloader))

In [16]:
# Keep the first two captions of the batch. The author reports a bug when a
# single sample is sent, so two are used.
x_sample = x_text[0:2]
(x_sample.shape, x_text.shape)

(torch.Size([2, 300]), torch.Size([32, 300]))

In [25]:
# Sanity check: the caption classifier yields one row of class scores per caption.
text_model(x_sample).shape

torch.Size([2, 4])

In [27]:
# Wrap the caption model and the image model in the multimodal classifier.
# NOTE(review): both sub-models are passed by reference, so cells that later
# replace their `.fc` heads also change what `multi` wraps — confirm intended.
multi = MultiModalityClassifier(text_model, resnet50_4_2)

In [32]:
# Output shape of the image model for the dummy image batch.
# NOTE(review): the saved output shows batch size 1, but x_image has batch size 2;
# the output appears stale from an earlier kernel state — re-run to refresh.
resnet50_4_2(x_image).shape

torch.Size([1, 4])

In [29]:
# End-to-end forward pass of the multimodal model on (captions, images).
multi(x_sample, x_image)

tensor([[ 0.1857, -0.2383,  0.2918, -0.4302],
        [ 0.0160, -0.2975,  0.4445, -0.3780]], grad_fn=<AddmmBackward>)

In [18]:
# Extract the vectors each model feeds into its `fc` head. The heads are
# swapped for identities *here*, so this cell also works on a fresh kernel.
# With the heads intact, both models return 4-class scores, and the 2348-wide
# fusion layer below would fail. Re-assigning is harmless if already done.
text_model.fc = nn.Identity()
resnet50_4_2.fc = nn.Identity()
text_features = text_model(x_sample)
image_features = resnet50_4_2(x_image)
text_features.shape, image_features.shape

(torch.Size([2, 300]), torch.Size([2, 2048]))

In [23]:
import torch.nn as nn
# Fusion head over the concatenated text + image features. The input width is
# taken from the extracted features (300 + 2048 in the saved run) instead of
# being hard-coded, and the output size comes from the NUM_CLASSES config constant.
fc = nn.Linear(text_features.size(1) + image_features.size(1), NUM_CLASSES)

In [17]:
# Replace both classification heads with identities so each model returns the
# vector it would otherwise pass to `fc`.
for backbone in (text_model, resnet50_4_2):
    backbone.fc = nn.Identity()

In [19]:
# Flatten everything after the batch dimension of the text features.
batch_size = text_features.shape[0]
x1 = text_features.view(batch_size, -1)
x1.shape

torch.Size([2, 300])

In [20]:
# Concatenate text and image features side by side (dim 1) for each sample.
x = torch.cat((text_features, image_features), 1)

In [21]:
# Concatenated feature width per sample (text + image).
x.shape

torch.Size([2, 2348])

In [24]:
# Scores from the freshly initialised (untrained) fusion head.
a = fc(x)

In [25]:
# Expect one row of class scores per sample in the batch.
a.shape

torch.Size([2, 4])

In [26]:
# Raw fusion-head scores (untrained, so values are arbitrary).
a

tensor([[-0.1928,  0.0730, -0.1536,  0.5851],
        [-0.0774, -0.1140, -0.1233,  0.5369]], grad_fn=<AddmmBackward>)