In [1]:
import os
import sys
# Make the project's ``src`` directory (one level above this notebook's working
# directory) importable, without adding a duplicate entry when the cell is re-run.
module_path = os.path.abspath('../src')
already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)
    
from models.CaptionModalityClassifier import CaptionModalityClassifier
from dataset.CaptionDataModule import CaptionDataModule
from utils.caption_utils import load_embedding_matrix
import numpy as np

In [2]:
# --- Text / embedding hyper-parameters ---------------------------------------
MAX_NUMBER_WORDS = 20000       # cap on vocabulary size; also the row count of the embedding matrix
                               # (the embedding-matrix cell lowers it when the dataset vocabulary is smaller)
MAX_WORDS_PER_SENTENCE = 300   # sentence maximum length (passed to the model as max_input_length)
WORD_DIMENSION = 300           # number of features per embedding
NUM_CLASSES = 4                # 4 microscopy classes

# --- Data locations ----------------------------------------------------------
# The base directory defaults to the original '/workspace/data' location but can
# be redirected by setting CAPTION_DATA_DIR before starting the kernel, so the
# notebook does not depend on one machine's absolute paths.
DATA_DIR = os.environ.get('CAPTION_DATA_DIR', '/workspace/data')
DATA_PATH = f'{DATA_DIR}/multimodality_classification.csv'
EMBEDDINGS = f'{DATA_DIR}/embeddings'

# --- Training ----------------------------------------------------------------
BATCH_SIZE = 32

In [3]:
# Create the caption data module, passing (in this positional order) the batch
# size, the CSV path, the vocabulary cap and the maximum sentence length
# defined in the configuration cell.
dm = CaptionDataModule(BATCH_SIZE, DATA_PATH, MAX_NUMBER_WORDS, MAX_WORDS_PER_SENTENCE)

In [4]:
# Run the data module's preparation and setup steps explicitly. Later cells read
# dm.vocab_size, dm.word_index and dm.train_dataloader() after this point.
# NOTE(review): presumably this is where the CSV is loaded and tokenised —
# confirm in CaptionDataModule.
dm.prepare_data()
dm.setup()

In [5]:
# Display the vocabulary size reported by the data module; the embedding-matrix
# cell compares it against MAX_NUMBER_WORDS.
dm.vocab_size

7221

In [6]:
# Load the pretrained word vectors found under EMBEDDINGS for the requested
# dimensionality. The result is used below as a word -> vector lookup via
# `.get(word)`, which yields None for words without a pretrained vector.
embeddings_dict = load_embedding_matrix(EMBEDDINGS, WORD_DIMENSION)

Dimension: 300; found 400000 word vectors.


In [7]:
# Build the (MAX_NUMBER_WORDS, WORD_DIMENSION) weight matrix that initialises
# the model's embedding layer.

# Seed for the out-of-vocabulary initialisation, so that re-running the notebook
# yields the same starting embeddings.
EMBEDDING_SEED = 42

# Shrink the vocabulary cap to what the dataset actually contains. One extra row
# is kept because rows are addressed directly by word index.
# NOTE(review): presumably index 0 is reserved (e.g. padding) and word indices
# start at 1 — confirm against CaptionDataModule.
MAX_NUMBER_WORDS = min(MAX_NUMBER_WORDS, dm.vocab_size + 1)
embedding_matrix = np.zeros((MAX_NUMBER_WORDS, WORD_DIMENSION))

# A dedicated generator keeps the random rows independent of whatever else has
# consumed NumPy's global random state earlier in the session.
oov_rng = np.random.RandomState(EMBEDDING_SEED)

for word, idx in dm.word_index.items():
    if idx >= MAX_NUMBER_WORDS:
        # Word falls outside the vocabulary cap: it has no row in the matrix.
        continue
    word_embedding = embeddings_dict.get(word)
    if word_embedding is None:
        # No pretrained vector: start from a standard-normal random vector.
        word_embedding = oov_rng.randn(WORD_DIMENSION)
    embedding_matrix[idx] = word_embedding

In [11]:
# Hyper-parameters of the caption classifier, gathered in one place so they can
# be inspected or logged before the model is built.
model_hparams = dict(
    max_input_length=MAX_WORDS_PER_SENTENCE,
    vocab_size=MAX_NUMBER_WORDS,       # must match the row count of embedding_matrix
    embedding_dim=WORD_DIMENSION,
    filters=100,
    embeddings=embedding_matrix,       # initial embedding weights built above
    num_classes=NUM_CLASSES,
    train_embeddings=True,             # NOTE(review): presumably allows fine-tuning the embeddings — confirm
    lr=1e-4,
)
model = CaptionModalityClassifier(**model_hparams)

In [15]:
import os
from pathlib import Path

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

# Experiment tracking: create the Weights & Biases logger, call save() on the
# run, and print the run's name.
wandb_logger = WandbLogger(project='pytorchlightning')
wandb_logger.experiment.save()
run_name = wandb_logger.experiment.name
print(run_name)

# Every run writes its artefacts to ./outputs/<run name>. exist_ok=False makes
# this raise FileExistsError rather than reuse an existing run directory.
output_run_path = Path('./outputs') / run_name
os.makedirs(output_run_path, exist_ok=False)

# Stop training once 'val_loss' has failed to decrease for 5 checks in a row.
early_stop_callback = EarlyStopping(
    monitor='val_loss', min_delta=0.0, patience=5, verbose=True, mode='min'
)

# Train on a single GPU with early stopping, logging to wandb.
trainer = Trainer(gpus=1, early_stop_callback=early_stop_callback, logger=wandb_logger)
trainer.fit(model, dm)

# Persist the trained weights (state_dict only) in the run directory.
checkpoint_path = output_run_path / 'checkpoint.pt'
torch.save(model.state_dict(), checkpoint_path)


Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Wandb version 0.10.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type            | Params
---------------------------------------------
0 | accuracy | Accuracy        | 0     
1 | CNNText  | CNNTextBackbone | 2 M   
2 | fc       | Linear          | 1 K   


dainty-snowflake-10




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

                    When using EvalResult(early_stop_on=X) or TrainResult(early_stop_on=X) the
                    'monitor' key of EarlyStopping has no effect.
                    Remove EarlyStopping(monitor='val_early_stop_on) to fix')
                


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..
Epoch 00007: early stopping triggered.





In [10]:
# Print the classifier's module tree (sub-modules and their configuration).
print(model)

CaptionModalityClassifier(
  (accuracy): Accuracy()
  (CNNText): CNNTextBackbone(
    (embeddings): Embedding(7222, 300)
    (conv1d_1): Conv1d(300, 100, kernel_size=(3,), stride=(1,))
    (relu1): ReLU()
    (maxpool1): MaxPool1d(kernel_size=298, stride=298, padding=0, dilation=1, ceil_mode=False)
    (conv1d_2): Conv1d(300, 100, kernel_size=(4,), stride=(1,))
    (relu2): ReLU()
    (maxpool2): MaxPool1d(kernel_size=297, stride=297, padding=0, dilation=1, ceil_mode=False)
    (conv1d_3): Conv1d(300, 100, kernel_size=(5,), stride=(1,))
    (relu3): ReLU()
    (maxpool3): MaxPool1d(kernel_size=296, stride=296, padding=0, dilation=1, ceil_mode=False)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (fc): Linear(in_features=300, out_features=4, bias=True)
)


In [17]:
# Also save a checkpoint through the Lightning trainer, in the same run
# directory as the state_dict file ('checkpoint.pt') written after training.
trainer.save_checkpoint(str(output_run_path / 'checkpoint2.pt'))

In [18]:
# Grab the first training batch (inputs x, labels y) for a quick sanity check.
# next(iter(...)) raises StopIteration right here if the loader is empty,
# instead of silently leaving x and y undefined as a `for ... break` loop would.
x, y = next(iter(dm.train_dataloader()))

In [19]:
# Sanity-check forward pass on the batch fetched above.
# The model contains a Dropout layer, so switch to eval mode for deterministic
# outputs, and skip gradient tracking since this is inference only.
# (The original cell started with a dangling `x = `, which is a SyntaxError.)
model.eval()
with torch.no_grad():
    logits = model(x)
logits

tensor([[-1.7381, -1.8300,  4.3508, -2.1198],
        [-2.6980,  5.6537, -0.9756, -1.6762],
        [ 1.2824, -2.3209, -1.8930, -0.1409],
        [ 0.2061, -0.0753,  0.9443, -2.9515],
        [-1.4886, -1.1114,  1.9535, -0.8665],
        [-1.9705,  1.8203,  0.7369, -4.1768],
        [-1.6583, -1.6331,  4.3735, -1.5197],
        [-2.6538,  2.0293, -0.1982, -2.0733],
        [-2.8402,  4.2675, -1.3636, -3.1963],
        [ 3.6290, -0.9845, -2.5096, -1.0467],
        [-2.1151,  3.6398, -2.3830, -1.5604],
        [-2.9140,  4.4678, -1.2635, -0.8338],
        [-2.6110,  2.7123, -3.9475, -1.9221],
        [ 3.1484, -0.8544, -1.9854, -1.7324],
        [-2.4984, -0.6780,  2.8862, -2.1551],
        [-2.6663,  0.0908,  3.1733, -2.0626],
        [-1.5970, -1.8656,  2.9868, -1.7876],
        [ 1.1765, -2.5292,  1.9170, -1.8969],
        [-3.0695, -2.4462, -1.2969,  4.1975],
        [-2.6766,  3.7560, -2.0168, -1.2125],
        [-1.5802, -1.5938,  4.4949, -1.2684],
        [-2.9759, -1.0397,  4.8811